经过前面2天的学习,对基本的ML有了了解,但是为了更好的学习和应用,我决定转化到fastai的课程上去学习,他们的思维模式和教学模式更符合我的需求。也就是不是告诉你XX是YY,而且告诉你 what can u do & what u need do。
Practical Deep Learning for Coders - Practical Deep Learning
下面的第一节课的代码,我自己重写了下,思路不变还是利用fastai的便利,自带学习机和模型,只需要提供数据集即可。
用Unsplash来获取图片数据集。
我非常认可fastai的理念,训练并不是需要大量的数据集和超硬核的机器。
import os
import requests
from fastai.vision.all import *
from PIL import Image
from pathlib import Path
# 设置 Unsplash API Access Key
UNSPLASH_ACCESS_KEY = 'x'
# 1. 使用 Unsplash API 获取图片
def search_images(query, max_images=10):
"""
使用 Unsplash API 搜索图像并返回图片 URL。
"""
url = f"https://api.unsplash.com/photos/random?query={query}&count={max_images}&client_id={UNSPLASH_ACCESS_KEY}"
response = requests.get(url)
if response.status_code == 200:
results = response.json()
image_urls = [result['urls']['regular'] for result in results]
return image_urls
else:
print(f"Error: {response.status_code}")
return []
# 2. 下载图像到指定文件夹
def download_images(urls, folder_name):
"""
下载图像到指定文件夹。
"""
if not os.path.exists(folder_name):
os.makedirs(folder_name)
for i, url in enumerate(urls):
try:
img_data = requests.get(url).content
with open(f"{folder_name}/{i+1}.jpg", "wb") as f:
f.write(img_data)
print(f"Downloaded image {i+1}")
except Exception as e:
print(f"Error downloading image {i+1}: {e}")
# 3. 下载鸟类和森林图片
bird_urls = search_images("bird", max_images=10)
forest_urls = search_images("forest", max_images=10)
download_images(bird_urls, 'bird_images')
download_images(forest_urls, 'forest_images')
# 4. 准备数据集
def prepare_dataset(bird_folder, forest_folder):
"""
将下载的图片移动到分类文件夹中,生成所需的文件夹结构。
"""
# 创建目标文件夹
target_folder = 'images'
if not os.path.exists(target_folder):
os.makedirs(target_folder)
# 将鸟类图片移动到 "bird" 文件夹
bird_target = os.path.join(target_folder, 'bird')
if not os.path.exists(bird_target):
os.makedirs(bird_target)
# 将森林类图片移动到 "forest" 文件夹
forest_target = os.path.join(target_folder, 'forest')
if not os.path.exists(forest_target):
os.makedirs(forest_target)
# 移动文件
for img in os.listdir(bird_folder):
img_path = os.path.join(bird_folder, img)
if img.endswith('.jpg'):
os.rename(img_path, os.path.join(bird_target, img))
for img in os.listdir(forest_folder):
img_path = os.path.join(forest_folder, img)
if img.endswith('.jpg'):
os.rename(img_path, os.path.join(forest_target, img))
prepare_dataset('bird_images', 'forest_images')
# 5. 加载数据集
path = Path('images')
dls = ImageDataLoaders.from_folder(
path,
valid_pct=0.2, # 使用 20% 的数据进行验证
seed=42, # 固定随机种子,确保结果一致
item_tfms=Resize(224), # 图像预处理,将图片大小调整为224x224
batch_tfms=aug_transforms() # 数据增强
)
# 6. 创建 CNN 模型
learn = cnn_learner(dls, resnet34, metrics=accuracy)
# 7. 训练模型
learn.fine_tune(1)
# 8. 评估模型
learn.show_results()
# 9. 进行预测
img_path = 'xxxx.jpg' # 替换为你要进行预测的图像路径
img = PILImage.create(img_path)
pred, pred_idx, probs = learn.predict(img)
# 打印预测结果
print(f'预测标签: {pred}; 预测概率: {probs}')
# 10. 保存和加载模型
learn.save('bird_forest_model') # 保存模型
# 加载保存的模型
learn.load('bird_forest_model')