基于ResNet50的植物病害识别训练过程

深入剖析 PyTorch 训练脚本

一个用于图像分类(以水稻图像为例)的 PyTorch 训练脚本。虽然它看起来可能是一个标准的迁移学习流程,但仔细观察会发现其中蕴含着一些值得学习的设计理念和实用技巧。本文将带你解读这份代码的独特之处、关键原理以及可能的“创新点”(或者说,是值得借鉴的优秀实践)。

核心流程概览

在深入细节之前,我们先快速浏览一下脚本的核心功能:

  1. 环境设置:设置随机种子以保证实验的可复现性。
  2. 数据准备:自动将原始数据集 (rice_images) 按比例(默认 80/20)分割成训练集 (train) 和验证集 (val),并存放到 data 目录下。如果 data 目录已存在且非空,则跳过复制。
  3. 数据加载与增强
    • 使用 torchvision.datasets.ImageFolder 加载数据。
    • 定义了强大的数据增强 (data_transforms) 策略,特别是对训练集应用了多种随机变换(裁剪、翻转、旋转、仿射、颜色抖动、透视变换),以提高模型的泛化能力。验证集则使用标准的中心裁剪和缩放。
    • 使用 torch.utils.data.DataLoader 创建数据加载器。
  4. 模型准备
    • 加载预训练的 ResNet50 模型 (pretrained=True),利用迁移学习加速训练并提升性能。
    • 解冻所有模型参数 (param.requires_grad = True),允许整个网络进行微调。
    • 修改最后的全连接层 (model.fc) 以匹配目标数据集的类别数量。
  5. 训练环境:检测 GPU 是否可用,并选择相应的设备 (cudacpu)。
  6. 核心训练循环 (train_model):这是脚本的“心脏”,包含了训练和验证的主要逻辑,我们稍后会详细分析。
  7. 优化器与学习率调度器:选择了 AdamW 优化器,并搭配了 CosineAnnealingWarmRestarts 学习率调度策略。
  8. 执行训练与保存:调用 train_model 函数执行训练,并在结束后保存最终模型 (final_model.pth) 和训练过程中的最佳模型 (checkpoints/best_checkpoint.pth)。

亮点解读:是什么让这份代码与众不同?

现在,让我们聚焦于那些让这份代码脱颖而出的关键特性。

亮点 1:强大的数据增强组合拳 🥊

1
2
3
4
5
6
7
8
9
10
11
12
13
# ... 部分数据增强代码 ...
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(), # 增加了垂直翻转
transforms.RandomRotation(degrees=30),
transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=5), # 仿射变换
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # 颜色抖动
transforms.RandomPerspective(distortion_scale=0.3, p=0.5), # 透视变换
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
# ...

脚本采用了非常丰富的数据增强技术组合。除了常见的随机裁剪和水平翻转,还加入了垂直翻转、随机旋转、仿射变换(平移、缩放、剪切)、颜色抖动(亮度、对比度、饱和度、色调)以及随机透视变换。

  • 原理:数据增强通过在训练过程中对输入图像应用各种随机变换,人为地增加了训练数据的多样性。这有助于模型学习到对位置、角度、光照、颜色等变化更不敏感的特征,从而提高模型的泛化能力,有效防止过拟合
  • 独特之处:这种组合式的、较为激进的增强策略,我希望模型能应对各种复杂的实际场景变化,尤其适用于数据量不是特别巨大或者图像本身形态、角度、光照变化较大的场景。

亮点 2:高级学习率调度策略 - CosineAnnealingWarmRestarts 📈📉

1
2
3
4
5
6
7
8
9
10
11
# (。・∀・)ノ゙ 优化器和学习率调度器
learning_rate = 1e-4
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4) # 使用 AdamW

# (。・∀・)ノ゙ 使用 CosineAnnealingWarmRestarts
scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=10, # 第一次重启前的 epoch 数
T_mult=2, # 每次重启后,周期变为原来的 T_mult 倍
eta_min=learning_rate / 100 # 学习率下限
)

脚本没有使用简单的固定学习率或 StepLR,而是选择了 CosineAnnealingWarmRestarts 学习率调度器,并搭配了 AdamW 优化器。

  • AdamW 原理:AdamW 是 Adam 优化器的一个改进版本,它将权重衰减(Weight Decay)的处理方式从 L2 正则化中分离出来,直接在梯度更新步骤中减去权重的一小部分。实践证明,这种方式对于 Adam 这类自适应学习率算法通常效果更好。
  • CosineAnnealingWarmRestarts 原理
    • Cosine Annealing(余弦退火): 在一个周期(T_0 个 epochs)内,学习率从初始值按照余弦函数平滑地下降到最小值 (eta_min)。这种平滑下降有助于模型在训练后期更稳定地收敛到最优解附近。
    • Warm Restarts(热重启): 在每个周期结束时,学习率突然“重启”回初始值(或一个较高的值),然后开始下一个余弦退火周期。并且,每个新周期的长度可以是上一个周期的 T_mult 倍。
  • 独特之处/优势:这种策略结合了平滑下降周期性重启的优点。平滑下降有助于精细搜索,而周期性的“热重启”可以帮助模型跳出可能陷入的局部最优解,重新探索参数空间。T_mult 参数使得后续周期的探索时间更长,可能找到更好的解。这是一种相对高级且有效的学习率调整策略。

以500epoch为例的训练过程图像
alt 训练曲线

亮点 3:精心设计的训练循环与强大的检查点机制 💾🔄

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# 训练函数 train_model 内部关键逻辑

# ... 加载上一个检查点 ...
checkpoint_path = 'checkpoints/last_checkpoint.pth'
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 加载优化器状态
scheduler.load_state_dict(checkpoint['scheduler_state_dict']) # 加载调度器状态
best_acc = checkpoint['best_acc']
# ...

for epoch in range(start_epoch, num_epochs):
# ... 训练和验证 ...

# ... (在验证阶段 val) ...
if phase == 'val':
if epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict())
# (。・∀・)ノ゙ 保存最佳检查点
torch.save({
'epoch': epoch + 1,
'model_state_dict': best_model_wts,
'optimizer_state_dict': optimizer.state_dict(), # 保存优化器
'scheduler_state_dict': scheduler.state_dict(), # 保存调度器
'best_acc': best_acc,
}, 'checkpoints/best_checkpoint.pth')

# (在每个 epoch 结束时)
# 保存最后一个检查点
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), # 保存优化器
'scheduler_state_dict': scheduler.state_dict(), # 保存调度器
'best_acc': best_acc, # 保存当前的最佳准确率
}, 'checkpoints/last_checkpoint.pth')
print('检查点已保存在 checkpoints/last_checkpoint.pth!')

# ... (训练结束后) ...
# 加载最佳权重
model.load_state_dict(best_model_wts)
return model

train_model 函数的设计体现了良好的工程实践,特别是在检查点(Checkpointing) 方面:

  1. 支持从中断处恢复:脚本在开始训练前会检查 checkpoints/last_checkpoint.pth 是否存在。如果存在,它会加载模型的权重、优化器的状态、学习率调度器的状态以及上一个 epoch 的编号和当时记录的最佳准确率。这使得训练可以在意外中断(如断电、程序崩溃)后无缝恢复,继续之前的进度。加载优化器和调度器状态对于保持训练动态(如 Adam 的动量、CosineAnnealing 的周期状态)至关重要。
  2. 保存最佳模型:在每个 epoch 的验证阶段结束后,如果当前模型的验证准确率 (epoch_acc) 超过了历史最佳准确率 (best_acc),脚本会深度拷贝 (copy.deepcopy) 当前模型的权重,并将其连同优化器、调度器状态等信息保存为 checkpoints/best_checkpoint.pth。这确保了即使训练后期模型性能有所下降,你也能找回整个训练过程中表现最好的那个模型状态。
  3. 保存最新模型:在每个 epoch 结束时,脚本都会无条件地保存当前的模型、优化器、调度器状态到 checkpoints/last_checkpoint.pth。这保证了无论训练何时停止,你总能从最后一个完成的 epoch 状态恢复。
  4. 详细的 Batch 级别日志:在每个 batch 处理后打印 Loss 和 Acc,提供了非常细致的训练过程监控,有助于快速发现训练初期的异常。虽然可能输出信息较多,但对于调试很有帮助。
  5. TensorBoard 集成:使用 SummaryWriter 将每个 epoch 的训练和验证损失、准确率写入 TensorBoard 日志,方便后续可视化分析训练过程。
  • 独特之处/优势:这种双检查点策略(最佳 + 最新) 结合了鲁棒性和最优性。它既能保证训练的可恢复性,又能确保最终得到验证集上表现最好的模型。同时,完整地保存和加载优化器/调度器状态是实现真正无缝恢复的关键,这一点做得非常好。

总结

这份 PyTorch 训练脚本虽然基于常见的 ResNet50 迁移学习范式,但其在以下几个方面展现了优秀的实践和值得借鉴的设计:

  1. 强大的数据增强:通过组合多种变换提升模型泛化能力。
  2. 先进的学习率策略:采用 AdamWCosineAnnealingWarmRestarts 优化训练动态,有助于跳出局部最优。
  3. 鲁棒的检查点机制:同时保存最佳和最新状态(包含优化器和调度器),确保训练的可恢复性和最终模型的最优性。
  4. 明确的设计选择:移除 Early Stopping,选择固定 Epoch 数训练。
  5. 良好的监控与日志:集成了 TensorBoard,并提供了详细的 Batch 级别日志。

以下为训练脚本源代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import shutil
from torch.utils.tensorboard import SummaryWriter

def set_seed(seed=42):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

set_seed()

original_data_dir = 'rice_images'
data_dir = 'data'

def copy_data(original_dir, target_dir, train_ratio=0.8):
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
os.makedirs(target_dir)

os.makedirs(os.path.join(target_dir, 'train'), exist_ok=True)
os.makedirs(os.path.join(target_dir, 'val'), exist_ok=True)

for class_name in os.listdir(original_dir):
class_dir = os.path.join(original_dir, class_name)
if os.path.isdir(class_dir):
os.makedirs(os.path.join(target_dir, 'train', class_name), exist_ok=True)
os.makedirs(os.path.join(target_dir, 'val', class_name), exist_ok=True)

images = [f for f in os.listdir(class_dir) if os.path.isfile(os.path.join(class_dir, f)) and f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]
num_images = len(images)
num_train = int(train_ratio * num_images)
print(f"(✪ω✪) 正在处理类别 '{class_name}',共有 {num_images} 张图片,将复制 {num_train} 张到训练集...")

np.random.shuffle(images)
train_images = images[:num_train]
val_images = images[num_train:]

for i, image in enumerate(train_images):
src = os.path.join(class_dir, image)
dst = os.path.join(target_dir, 'train', class_name, image)
shutil.copyfile(src, dst)
if (i + 1) % 100 == 0:
print(f" 已复制 {i + 1}/{num_train} 张训练图片...")

for i, image in enumerate(val_images):
src = os.path.join(class_dir, image)
dst = os.path.join(target_dir, 'val', class_name, image)
shutil.copyfile(src, dst)
if (i + 1) % 50 == 0:
print(f" 已复制 {i + 1}/{len(val_images)} 张验证图片...")

if not os.path.exists(data_dir) or not os.listdir(data_dir):
print("(。♥‿♥。) 正在从 'rice_images' 复制数据到 'data'...")
copy_data(original_data_dir, data_dir)
else:
print(" (´。• ω •。`) 'data' 文件夹已存在且不为空,跳过数据复制…")

data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(degrees=30),
transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=5),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.RandomPerspective(distortion_scale=0.3, p=0.5),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=16, shuffle=True, num_workers=0) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

if torch.cuda.is_available():
device = torch.device("cuda:0")
print(f"(✪ω✪) 正在使用 GPU: {torch.cuda.get_device_name(0)} 进行训练!")
if torch.cuda.get_device_capability(device)[0] < 7:
print(" (´・ω・`) GPU 不支持 FP16,将不会进行混合精度运算")
else:
device = torch.device("cpu")
print("(´。• ω •。`) 没有可用的 GPU,将在 CPU 上训练…呜喵")

print(f"(๑•̀ㅂ•́)و✧ 训练集大小: {dataset_sizes['train']}, 验证集大小: {dataset_sizes['val']}")
print("类别:", class_names)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25, start_epoch=0, best_acc=0.0, best_model_wts=None):
since = time.time()
if best_model_wts is not None:
model.load_state_dict(best_model_wts)

writer = SummaryWriter('runs/rice_experiment')

checkpoint_path = 'checkpoints/last_checkpoint.pth'
if os.path.exists(checkpoint_path):
print(f"(๑•̀ㅂ•́)و✧ 正在从检查点恢复: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path)
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
best_acc = checkpoint['best_acc']
best_model_wts = model.state_dict() # 确保恢复时更新 best_model_wts
print(f"已从第 {start_epoch} 个 epoch 恢复, 当前最佳准确率: {best_acc:.4f}")
else:
print("(´・ω・`) 未找到检查点, 从头开始训练.")
best_model_wts = copy.deepcopy(model.state_dict()) # 从头训练时初始化 best_model_wts


for epoch in range(start_epoch, num_epochs):
print(f'Epoch {epoch + 1}/{num_epochs}')
print('-' * 10)

for phase in ['train', 'val']:
if phase == 'train':
print('(ง •̀_•́)ง 开始训练...')
model.train()
else:
print('(✪ω✪) 开始验证...')
model.eval()

running_loss = 0.0
running_corrects = 0
batch_num = 0

for inputs, labels in dataloaders[phase]:
batch_num += 1
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()

with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)

if phase == 'train':
loss.backward()
optimizer.step()

running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)

print(f' [{phase}] Batch {batch_num}/{len(dataloaders[phase])}, Loss: {loss.item():.4f}, Acc: {torch.sum(preds == labels.data).double() / inputs.size(0):.4f}')

epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase]

print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

writer.add_scalar(f'Loss/{phase}', epoch_loss, epoch)
writer.add_scalar(f'Accuracy/{phase}', epoch_acc, epoch)

if phase == 'train':
scheduler.step()

if phase == 'val':
if epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict())
print(f"(。・∀・)ノ゙ 新的最佳验证准确率: {best_acc:.4f}, 正在保存最佳检查点...")
if not os.path.exists('checkpoints'):
os.makedirs('checkpoints')
torch.save({
'epoch': epoch + 1,
'model_state_dict': best_model_wts,
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_acc': best_acc,
}, 'checkpoints/best_checkpoint.pth')


if not os.path.exists('checkpoints'):
os.makedirs('checkpoints')
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_acc': best_acc,
}, 'checkpoints/last_checkpoint.pth')
print('检查点已保存在 checkpoints/last_checkpoint.pth!')

time_elapsed = time.time() - since
print(f'(。・∀・)ノ゙训练完成,用时 {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'历史最佳验证准确率: {best_acc:4f}')

writer.close()
if best_model_wts:
model.load_state_dict(best_model_wts)
return model

print('( •̀ ω •́ )y 加载预训练的 ResNet50 模型...')
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

for param in model.parameters():
param.requires_grad = True

num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))
model = model.to(device)
criterion = nn.CrossEntropyLoss()

learning_rate = 1e-4
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)

scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=10,
T_mult=2,
eta_min=learning_rate / 100
)

num_epochs = 100

start_epoch = 0
best_acc = 0.0

print('(≧∇≦)/ 开始训练!')
model_ft = train_model(model, criterion, optimizer, scheduler, num_epochs=num_epochs, start_epoch=start_epoch, best_acc=best_acc)


torch.save(model_ft.state_dict(), 'final_model.pth')
print('训练结束,最终模型已保存!')