LOADING

加载过慢请开启缓存 浏览器默认开启

ML-house_price_predict-day4

2025/1/20 ml

直接上 Kaggle 进行实战,选了个比较好入门的比赛题目:房价预测。

House Prices - Advanced Regression Techniques | Kaggle

1038/4635

Image

代码

还得是 fastai,简简单单就能做到一个不错的结果。

from fastai.tabular.all import *
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from fastai.metrics import R2Score

# Pick the compute device: CUDA when torch reports it as available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the training and test tables from CSV files in the working directory.
train_df, test_df = pd.read_csv('train.csv'), pd.read_csv('test.csv')

# Name the target column, then split the training frame's columns by dtype.
# Note that num_cols is taken from the raw training frame, so it still
# contains the target column (and any other numeric column such as 'Id').
target = 'SalePrice'
num_cols = list(train_df.select_dtypes(include='number').columns)
cat_cols = list(train_df.select_dtypes(exclude='number').columns)

# Missing-value handling: impute every numeric feature (the target is
# skipped) with the median of the corresponding TRAINING column. The same
# training median is used for the test frame so that no statistic is
# computed from test data.
for col in num_cols:
    if col == target:
        continue
    # Compute the median once; it is the fill value for both frames.
    train_median = train_df[col].median()
    # Assign the filled column back instead of calling
    # `df[col].fillna(..., inplace=True)`: the chained in-place form is
    # deprecated in pandas 2.x and, with Copy-on-Write enabled, operates on
    # a temporary object and leaves the parent DataFrame unchanged.
    train_df[col] = train_df[col].fillna(train_median)
    test_df[col] = test_df[col].fillna(train_median)

# Replace the target with log(1 + target); predictions are mapped back with
# the inverse transform after inference.
train_df[target] = np.log1p(train_df[target])

# Feature frames: drop the target and the 'Id' column from the training
# data, and the 'Id' column from the test data.
X_train = train_df.drop(columns=['Id', target])
X_test = test_df.drop(columns='Id')

# Split the remaining feature columns into categorical and continuous name
# lists, keeping the column order of X_train in both.
cat_names = [name for name in X_train.columns if name in cat_cols]
cont_names = [name for name in X_train.columns if name not in cat_cols]

# Build the fastai tabular DataLoaders from the full training frame. Only the
# columns listed in cat_names / cont_names are used as features and `target`
# is the dependent variable.
# NOTE(review): no validation indices or seed are passed here, so the
# train/validation split is whatever `from_df` does by default — confirm.
train_dl = TabularDataLoaders.from_df(
    train_df,
    path='.',
    procs=[Categorify, FillMissing, Normalize],
    cat_names=cat_names,
    cont_names=cont_names,
    y_names=target,
    bs=64,
)

# Create the tabular model with four hidden layers and weight decay 1e-2,
# tracked with several evaluation metrics (RMSE, MAE and R^2).
learn = tabular_learner(
    train_dl,
    layers=[4000, 2000, 1000, 500],
    metrics=[rmse, mae, R2Score()],
    wd=1e-2,
)

# Move the model to the selected device (GPU when available).
learn.to(device)

# Run the learning-rate finder. Its return value is not used: the training
# call below passes a hard-coded lr_max instead.
learn.lr_find()

# Plot the learning-rate finder curve and save it to 'lr_find_plot.png'.
plt.figure()
learn.recorder.plot_lr_find()
plt.savefig('lr_find_plot.png')

# Train with the one-cycle policy for 500 epochs at a peak learning rate
# of 1e-3.
learn.fit_one_cycle(500, lr_max=1e-3)

# Plot the recorded loss curve and save it to 'loss_plot.png'.
plt.figure()
learn.recorder.plot_loss()
plt.savefig('loss_plot.png')

# Predict on the test set: wrap the test features in a DataLoader that reuses
# the training preprocessing, then collect the model outputs.
test_dl = learn.dls.test_dl(X_test)
test_preds, _ = learn.get_preds(dl=test_dl)

# Convert the prediction tensor to a flat 1-D NumPy array explicitly, so the
# steps below do not depend on NumPy ufuncs accepting a torch tensor or on
# pandas coercing a 2-D value into a single column.
test_preds = test_preds.detach().cpu().numpy().reshape(-1)

# Undo the log1p transform applied to the target before training.
test_preds = np.expm1(test_preds)

# Save the predictions: one 'Id' / 'SalePrice' row per test record.
test_df['SalePrice'] = test_preds
test_df[['Id', 'SalePrice']].to_csv('predictions.csv', index=False)

print("预测结果已保存到 predictions.csv")

预测CSV点此下载

predictions.csv