LOADING

加载过慢请开启缓存 浏览器默认开启

ML-day3

2025/1/17 ml

经过前面2天的学习,对基本的ML有了了解,但是为了更好的学习和应用,我决定转化到fastai的课程上去学习,他们的思维模式和教学模式更符合我的需求。也就是不是告诉你XX是YY,而且告诉你 what can u do & what u need do。

Practical Deep Learning for Coders - Practical Deep Learning

下面的第一节课的代码,我自己重写了下,思路不变还是利用fastai的便利,自带学习机和模型,只需要提供数据集即可。

用Unsplash来获取图片数据集。

我非常认可fastai的理念,训练并不是需要大量的数据集和超硬核的机器。

import os
import requests
from fastai.vision.all import *
from PIL import Image
from pathlib import Path

# 设置 Unsplash API Access Key
UNSPLASH_ACCESS_KEY = 'x'

# 1. 使用 Unsplash API 获取图片
def search_images(query, max_images=10):
    """
    使用 Unsplash API 搜索图像并返回图片 URL。
    """
    url = f"https://api.unsplash.com/photos/random?query={query}&count={max_images}&client_id={UNSPLASH_ACCESS_KEY}"
    response = requests.get(url)
    
    if response.status_code == 200:
        results = response.json()
        image_urls = [result['urls']['regular'] for result in results]
        return image_urls
    else:
        print(f"Error: {response.status_code}")
        return []

# 2. 下载图像到指定文件夹
def download_images(urls, folder_name):
    """
    下载图像到指定文件夹。
    """
    if not os.path.exists(folder_name):
        os.makedirs(folder_name)

    for i, url in enumerate(urls):
        try:
            img_data = requests.get(url).content
            with open(f"{folder_name}/{i+1}.jpg", "wb") as f:
                f.write(img_data)
            print(f"Downloaded image {i+1}")
        except Exception as e:
            print(f"Error downloading image {i+1}: {e}")

# 3. 下载鸟类和森林图片
bird_urls = search_images("bird", max_images=10)
forest_urls = search_images("forest", max_images=10)

download_images(bird_urls, 'bird_images')
download_images(forest_urls, 'forest_images')

# 4. 准备数据集
def prepare_dataset(bird_folder, forest_folder):
    """
    将下载的图片移动到分类文件夹中,生成所需的文件夹结构。
    """
    # 创建目标文件夹
    target_folder = 'images'
    if not os.path.exists(target_folder):
        os.makedirs(target_folder)

    # 将鸟类图片移动到 "bird" 文件夹
    bird_target = os.path.join(target_folder, 'bird')
    if not os.path.exists(bird_target):
        os.makedirs(bird_target)

    # 将森林类图片移动到 "forest" 文件夹
    forest_target = os.path.join(target_folder, 'forest')
    if not os.path.exists(forest_target):
        os.makedirs(forest_target)

    # 移动文件
    for img in os.listdir(bird_folder):
        img_path = os.path.join(bird_folder, img)
        if img.endswith('.jpg'):
            os.rename(img_path, os.path.join(bird_target, img))

    for img in os.listdir(forest_folder):
        img_path = os.path.join(forest_folder, img)
        if img.endswith('.jpg'):
            os.rename(img_path, os.path.join(forest_target, img))

prepare_dataset('bird_images', 'forest_images')

# 5. 加载数据集
path = Path('images')
dls = ImageDataLoaders.from_folder(
    path, 
    valid_pct=0.2,  # 使用 20% 的数据进行验证
    seed=42,  # 固定随机种子,确保结果一致
    item_tfms=Resize(224),  # 图像预处理,将图片大小调整为224x224
    batch_tfms=aug_transforms()  # 数据增强
)

# 6. 创建 CNN 模型
learn = cnn_learner(dls, resnet34, metrics=accuracy)

# 7. 训练模型
learn.fine_tune(1)

# 8. 评估模型
learn.show_results()

# 9. 进行预测
img_path = 'xxxx.jpg'  # 替换为你要进行预测的图像路径
img = PILImage.create(img_path)
pred, pred_idx, probs = learn.predict(img)

# 打印预测结果
print(f'预测标签: {pred}; 预测概率: {probs}')

# 10. 保存和加载模型
learn.save('bird_forest_model')  # 保存模型

# 加载保存的模型
learn.load('bird_forest_model')