1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245
| import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler import numpy as np import torchvision from torchvision import datasets, models, transforms import matplotlib.pyplot as plt import time import os import copy import shutil from torch.utils.tensorboard import SummaryWriter
def set_seed(seed=42):
    """Seed every RNG the script relies on so runs are reproducible.

    Covers torch CPU, all CUDA devices, and NumPy (used below for the
    train/val shuffle), and forces cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # FIX: seed *all* GPUs, not just the current one — without this a
    # multi-GPU run is not reproducible. No-op when CUDA is absent.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Deterministic cuDNN kernels; benchmark=False avoids the autotuner
    # picking different (non-deterministic) algorithms per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Seed everything before any shuffling happens below.
set_seed()

# Source tree: one sub-folder per rice class under rice_images/.
# data/ receives the train/val split produced by copy_data().
original_data_dir = 'rice_images'
data_dir = 'data'
def copy_data(original_dir, target_dir, train_ratio=0.8):
    """Split the class folders under original_dir into train/val trees.

    Any existing target_dir is removed first. For each class sub-folder,
    the image files are shuffled and the first train_ratio fraction is
    copied to target_dir/train/<class>, the rest to target_dir/val/<class>.

    Args:
        original_dir: directory containing one sub-folder per class.
        target_dir:   destination root; recreated from scratch.
        train_ratio:  fraction of each class assigned to the train split.
    """
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    for split in ('train', 'val'):
        os.makedirs(os.path.join(target_dir, split), exist_ok=True)

    allowed_ext = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def _copy_with_progress(src_dir, names, dest_dir, every, total, kind_label):
        # Copy each file src_dir/<name> -> dest_dir/<name>, echoing
        # progress after every `every` files.
        for idx, name in enumerate(names):
            shutil.copyfile(os.path.join(src_dir, name),
                            os.path.join(dest_dir, name))
            if (idx + 1) % every == 0:
                print(f" 已复制 {idx + 1}/{total} 张{kind_label}图片...")

    for class_name in os.listdir(original_dir):
        class_dir = os.path.join(original_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # ignore stray files at the top level
        for split in ('train', 'val'):
            os.makedirs(os.path.join(target_dir, split, class_name), exist_ok=True)

        images = [
            f for f in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, f))
            and f.lower().endswith(allowed_ext)
        ]
        num_images = len(images)
        num_train = int(train_ratio * num_images)
        print(f"(✪ω✪) 正在处理类别 '{class_name}',共有 {num_images} 张图片,将复制 {num_train} 张到训练集...")

        np.random.shuffle(images)
        _copy_with_progress(class_dir, images[:num_train],
                            os.path.join(target_dir, 'train', class_name),
                            100, num_train, '训练')
        _copy_with_progress(class_dir, images[num_train:],
                            os.path.join(target_dir, 'val', class_name),
                            50, num_images - num_train, '验证')
# Build the data/ split only once; subsequent runs reuse the existing split
# (which also keeps the train/val membership stable across restarts).
if not os.path.exists(data_dir) or not os.listdir(data_dir):
    print("(。♥‿♥。) 正在从 'rice_images' 复制数据到 'data'...")
    copy_data(original_data_dir, data_dir)
else:
    print(" (´。• ω •。`) 'data' 文件夹已存在且不为空,跳过数据复制…")
# Per-split preprocessing pipelines.
data_transforms = {
    # Training: heavy geometric + photometric augmentation to combat
    # overfitting on what is presumably a modest-sized dataset.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=30),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomPerspective(distortion_scale=0.3, p=0.5),
        transforms.ToTensor(),
        # ImageNet mean/std — must match the pretrained ResNet50 weights.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    # Validation: deterministic resize + center crop, no augmentation.
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
# ImageFolder derives the class labels from the sub-folder names
# under data/train and data/val.
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# num_workers=0 loads in the main process — slowest but most portable;
# NOTE(review): raise it for throughput once the platform is confirmed.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=16, shuffle=True, num_workers=0) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
# Select the compute device, preferring the first CUDA GPU when available.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")
if use_gpu:
    print(f"(✪ω✪) 正在使用 GPU: {torch.cuda.get_device_name(0)} 进行训练!")
    # Compute capability < 7.0 (pre-Volta) lacks fast FP16 tensor cores.
    major, _minor = torch.cuda.get_device_capability(device)
    if major < 7:
        print(" (´・ω・`) GPU 不支持 FP16,将不会进行混合精度运算")
else:
    print("(´。• ω •。`) 没有可用的 GPU,将在 CPU 上训练…呜喵")
# Echo the split sizes and discovered class labels as a sanity check.
print(f"(๑•̀ㅂ•́)و✧ 训练集大小: {dataset_sizes['train']}, 验证集大小: {dataset_sizes['val']}")
print("类别:", class_names)
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, start_epoch=0, best_acc=0.0, best_model_wts=None):
    """Train and validate `model`, checkpointing after every epoch.

    If checkpoints/last_checkpoint.pth exists the run resumes from it
    (overriding start_epoch / best_acc / best_model_wts). Per-phase loss
    and accuracy are logged to TensorBoard under runs/rice_experiment.

    Args:
        model:      nn.Module already moved to the global `device`.
        criterion:  loss function (logits, labels) -> scalar loss.
        optimizer:  optimizer over model.parameters().
        scheduler:  LR scheduler, stepped once per training epoch.
        num_epochs: total epoch count (resume counts toward it).
        start_epoch/best_acc/best_model_wts: initial state when no
            checkpoint exists.

    Returns:
        The model with the best-validation-accuracy weights loaded.
    """
    since = time.time()
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    writer = SummaryWriter('runs/rice_experiment')

    checkpoint_path = 'checkpoints/last_checkpoint.pth'
    if os.path.exists(checkpoint_path):
        print(f"(๑•̀ㅂ•́)و✧ 正在从检查点恢复: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        best_acc = checkpoint['best_acc']
        # BUG FIX: the original did `best_model_wts = model.state_dict()`,
        # which ALIASES the live parameter tensors — every subsequent
        # optimizer.step() silently mutated the saved "best" weights.
        # A deep copy snapshots them properly.
        # NOTE(review): this still snapshots the *last* weights as the
        # current best; restoring from best_checkpoint.pth would be
        # stricter — confirm which behavior is intended.
        best_model_wts = copy.deepcopy(model.state_dict())
        print(f"已从第 {start_epoch} 个 epoch 恢复, 当前最佳准确率: {best_acc:.4f}")
    else:
        print("(´・ω・`) 未找到检查点, 从头开始训练.")
        best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(start_epoch, num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                print('(ง •̀_•́)ง 开始训练...')
                model.train()
            else:
                print('(✪ω✪) 开始验证...')
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            batch_num = 0

            for inputs, labels in dataloaders[phase]:
                batch_num += 1
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                # Build the autograd graph only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Weight batch loss by batch size so the epoch average is
                # exact even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                print(f' [{phase}] Batch {batch_num}/{len(dataloaders[phase])}, Loss: {loss.item():.4f}, Acc: {torch.sum(preds == labels.data).double() / inputs.size(0):.4f}')

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            writer.add_scalar(f'Loss/{phase}', epoch_loss, epoch)
            writer.add_scalar(f'Accuracy/{phase}', epoch_acc, epoch)

            # Cosine-annealing-with-restarts advances once per train epoch.
            if phase == 'train':
                scheduler.step()

            if phase == 'val':
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    print(f"(。・∀・)ノ゙ 新的最佳验证准确率: {best_acc:.4f}, 正在保存最佳检查点...")
                    os.makedirs('checkpoints', exist_ok=True)
                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': best_model_wts,
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'best_acc': best_acc,
                    }, 'checkpoints/best_checkpoint.pth')

        # Always save a resume point at the end of each epoch.
        os.makedirs('checkpoints', exist_ok=True)
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_acc': best_acc,
        }, 'checkpoints/last_checkpoint.pth')
        print('检查点已保存在 checkpoints/last_checkpoint.pth!')

    time_elapsed = time.time() - since
    print(f'(。・∀・)ノ゙训练完成,用时 {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    # FIX: original used ':4f' (min field width 4, default 6 decimals);
    # ':.4f' prints 4 decimal places, consistent with every other log line.
    print(f'历史最佳验证准确率: {best_acc:.4f}')

    writer.close()
    if best_model_wts:
        model.load_state_dict(best_model_wts)
    return model
# --- Model: ImageNet-pretrained ResNet50, full fine-tuning -----------------
print('( •̀ ω •́ )y 加载预训练的 ResNet50 模型...')
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Fine-tune every layer, not just the new head.
# NOTE(review): requires_grad already defaults to True on loaded weights,
# so this loop is a no-op kept for explicitness — confirm before removing.
for param in model.parameters():
    param.requires_grad = True

# Replace the 1000-class ImageNet head with one sized to our classes.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))
model = model.to(device)
criterion = nn.CrossEntropyLoss()

# AdamW decouples weight decay from the gradient update.
learning_rate = 1e-4
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)

# Warm restarts: first cycle 10 epochs, each following cycle twice as long;
# LR anneals from learning_rate down to learning_rate/100 within a cycle.
scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=learning_rate / 100
)

num_epochs = 100
# These are the fallbacks when no checkpoint exists; train_model overrides
# them from checkpoints/last_checkpoint.pth when resuming.
start_epoch = 0
best_acc = 0.0

print('(≧∇≦)/ 开始训练!')
model_ft = train_model(model, criterion, optimizer, scheduler, num_epochs=num_epochs, start_epoch=start_epoch, best_acc=best_acc)

# Persist the best-validation weights returned by train_model.
torch.save(model_ft.state_dict(), 'final_model.pth')
print('训练结束,最终模型已保存!')
|