图像增广
参考文章
概念
图像增广是增加数据多样性的技术,属于广义上的图像数据增强范畴。图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从而扩大训练数据集的规模。图像增广的另一种解释是,随机改变训练样本可以降低模型 对某些属性的依赖,从而提高模型的泛化能力。
为了在预测时得到确定的结果,我们通常只将图像增广应用在训练样本上,而不在预测时使用含随机操作的图像增广。
作用:在数据驱动的深度学习模型中,加入合适的图像增广是提升模型准确率的有效途径,往往能使模型准确率提升 2%-5%甚至更多。
技术方法分类
文章将图像增广技术分为两大类:
| 类别 | 主要方法/思想 | 具体技术举例 |
|---|---|---|
| 图像变换 | 对单张图像进行颜色或几何上的调整 | 几何变换:平移、旋转、裁剪、缩放、翻转、错切、仿射变换 颜色变换:色调分离,以及锐度、亮度、饱和度、对比度调整 |
| 图像合成 | 组合多张图像或其特征来生成新样本 | 简单合成:Mixup, CutMix, AugMix 自动生成:生成对抗网络(GAN)、对抗学习 |
PyTorch 框架下实现图像增广
- 翻转
- 左右翻转
- 上下翻转
- 颜色调整
- 亮度(Brightness)
- 对比度(Contrast)
- 饱和度(Saturation)
- 色调(Hue)
- 旋转
- 裁剪
- 随机裁剪 (Random Crop)
- 中心裁剪 (Center Crop)
- 变换
- 随机仿射变换 (Random Affine)
- 随机透视变换 (Random Perspective)
- 随机灰度化 (Random Grayscale)
- 随机擦除 (Random Erasing)
- 弹性形变 (Elastic Transform)
- 高斯模糊 (Gaussian Blur)
- 组合变换 (Compose)
- 翻转、颜色、旋转、变换等
实践
详情
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch
# 加载图像
img_path = 'cat.png' # 替换为你的图像路径
img = Image.open(img_path)
# 定义各个增广函数
def plot_transformed_images(original_img, transform, title):
"""
对给定的图像应用变换并绘制原始和变换后的图像。
:param original_img: 原始图像
:param transform: 要应用的变换
:param title: 变换的名称(用于图表标题)
"""
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(original_img)
axs[0].set_title('Original Image')
axs[0].axis('off')
transformed_img = transform(original_img)
axs[1].imshow(transformed_img)
axs[1].set_title(title)
axs[1].axis('off')
plt.show()
# 左右翻转
transform_hflip = transforms.RandomHorizontalFlip(p=1) # p=1 表示总是翻转
plot_transformed_images(img, transform_hflip, "Horizontal Flip")
# 上下翻转
transform_vflip = transforms.RandomVerticalFlip(p=1) # p=1 表示总是翻转
plot_transformed_images(img, transform_vflip, "Vertical Flip")
# 颜色调整 - 亮度
transform_brightness = transforms.ColorJitter(brightness=1.5, contrast=0, saturation=0, hue=0)
plot_transformed_images(img, transform_brightness, "Brightness Adjusted")
# 颜色调整 - 对比度
transform_contrast = transforms.ColorJitter(brightness=0, contrast=1.5, saturation=0, hue=0)
plot_transformed_images(img, transform_contrast, "Contrast Adjusted")
# 颜色调整 - 饱和度
transform_saturation = transforms.ColorJitter(brightness=0, contrast=0, saturation=1.5, hue=0)
plot_transformed_images(img, transform_saturation, "Saturation Adjusted")
# 颜色调整 - 色调
transform_hue = transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.1)
plot_transformed_images(img, transform_hue, "Hue Adjusted")
# 旋转
transform_rotate = transforms.RandomRotation(degrees=45) # 旋转角度范围为 [-45, 45]
plot_transformed_images(img, transform_rotate, "Rotated")







# ---------------------------------------------------------
# 1. 随机裁剪 (Random Crop)
# 从图像中随机裁剪出指定大小的区域,常用于让模型关注局部特征
# ---------------------------------------------------------
transform_random_crop = transforms.RandomCrop(size=(200, 200))
# 注意:如果原图小于200x200,需要先用 Pad 或 Resize 处理,这里假设原图足够大
plot_transformed_images(img, transform_random_crop, "Random Crop (200x200)")
# ---------------------------------------------------------
# 2. 中心裁剪 (Center Crop)
# 从图像中心裁剪出指定大小,常用于测试阶段或去除边缘无关信息
# ---------------------------------------------------------
transform_center_crop = transforms.CenterCrop(size=(200, 200))
plot_transformed_images(img, transform_center_crop, "Center Crop (200x200)")
# ---------------------------------------------------------
# 3. 随机仿射变换 (Random Affine)
# 包含旋转、平移、缩放和剪切的组合变换,模拟更复杂的视角变化
# degrees: 旋转角度范围
# translate: 平移比例 (宽, 高)
# scale: 缩放比例范围
# shear: 剪切角度范围
# ---------------------------------------------------------
transform_affine = transforms.RandomAffine(
degrees=30,
translate=(0.1, 0.1),
scale=(0.8, 1.2),
shear=10
)
plot_transformed_images(img, transform_affine, "Random Affine")
# ---------------------------------------------------------
# 4. 随机透视变换 (Random Perspective)
# 模拟透视变形,让图像看起来像是从不同角度拍摄的平面物体
# distortion_scale: 控制变形程度 (0-1)
# ---------------------------------------------------------
transform_perspective = transforms.RandomPerspective(distortion_scale=0.5, p=1.0)
plot_transformed_images(img, transform_perspective, "Random Perspective")
# ---------------------------------------------------------
# 5. 弹性形变 (Elastic Transform) - 需较新版本 torchvision
# 模拟局部非线性变形,常用于医学图像或增加纹理多样性
# alpha: 变形强度
# sigma: 高斯核标准差,控制平滑度
# ---------------------------------------------------------
try:
transform_elastic = transforms.ElasticTransform(alpha=50.0, sigma=5.0)
plot_transformed_images(img, transform_elastic, "Elastic Transform")
except AttributeError:
print("当前 torchvision 版本可能不支持 ElasticTransform,请升级库。")
# ---------------------------------------------------------
# 6. 高斯模糊 (Gaussian Blur)
# 模拟相机失焦或运动模糊,提高模型对模糊图像的鲁棒性
# kernel_size: 模糊核大小 (必须是奇数或序列)
# sigma: 标准差范围
# ---------------------------------------------------------
transform_blur = transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2.0))
plot_transformed_images(img, transform_blur, "Gaussian Blur")
# ---------------------------------------------------------
# 7. 随机灰度化 (Random Grayscale)
# 以一定概率将图像转换为灰度图,迫使模型不依赖颜色特征
# p: 转换为灰度的概率
# ---------------------------------------------------------
transform_grayscale = transforms.RandomGrayscale(p=1.0) # p=1 强制变灰度
plot_transformed_images(img, transform_grayscale, "Random Grayscale")
# ---------------------------------------------------------
# 8. 随机擦除 (Random Erasing) - 修复警告版本
# ---------------------------------------------------------
transform_to_tensor = transforms.ToTensor()
tensor_img = transform_to_tensor(img)
transform_erasing = transforms.RandomErasing(p=1.0, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random')
erased_tensor = transform_erasing(tensor_img)
# 【关键修复步骤】
# 1. 使用 clamp 确保数据严格在 [0, 1] 之间
erased_tensor_clamped = torch.clamp(erased_tensor, 0.0, 1.0)
# 2. 将 Tensor (C, H, W) 转换为 numpy (H, W, C) 供 matplotlib 使用
# permute(1, 2, 0) 改变维度顺序
original_np = tensor_img.permute(1, 2, 0).numpy()
erased_np = erased_tensor_clamped.permute(1, 2, 0).numpy()
# 绘制 Random Erasing
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(original_np)
axs[0].set_title('Original (Tensor)')
axs[0].axis('off')
axs[1].imshow(erased_np)
axs[1].set_title('Random Erasing (Clamped)')
axs[1].axis('off')
plt.tight_layout()
plt.show()
# ---------------------------------------------------------
# 9. 组合变换 (Compose) - 实际训练中的常用方式
# 将多个变换串联在一起,一次性应用
# ---------------------------------------------------------
combined_transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.RandomRotation(degrees=15),
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
])
plot_transformed_images(img, combined_transform, "Combined Augmentations")









