AI知识分享AI知识分享
✿导航
  • 人工智能
  • 神经网络
  • 机器学习
  • 深度学习
  • 强化学习
  • 自然语言处理
  • 计算机视觉
  • 大模型基础
  • 动手学深度学习
  • 理论理解
  • 工程实践
  • 应用开发
  • AI For Everyone
  • AIGC_2024大会
  • AIGC_2025大会
  • Transformer
  • Pytorch
  • HuggingFace
  • 蒸馏
  • RAG
  • 目标检测
  • MCP
  • 概念
  • 意图识别
  • 工具
✿导航
  • 人工智能
  • 神经网络
  • 机器学习
  • 深度学习
  • 强化学习
  • 自然语言处理
  • 计算机视觉
  • 大模型基础
  • 动手学深度学习
  • 理论理解
  • 工程实践
  • 应用开发
  • AI For Everyone
  • AIGC_2024大会
  • AIGC_2025大会
  • Transformer
  • Pytorch
  • HuggingFace
  • 蒸馏
  • RAG
  • 目标检测
  • MCP
  • 概念
  • 意图识别
  • 工具
  • 大模型基础

    • 语言模型基础

      • 概述
      • 基于统计方法的语言模型
      • 基于神经网络的语言模型
      • 语言模型的采样方法
      • 语言模型的评测
    • 大语言模型架构

      • 概述
      • 主流模型架构
      • Encoder-only
      • Encoder-Decoder
      • Decoder-only
      • 非Transformer 架构
    • Prompt工程

      • 工程简介
      • 上下文学习
      • 思维链
      • 技巧
    • 参数高效微调

      • 概述
      • 参数附加方法
      • 参数选择方法
      • 低秩适配方法
      • 实践与应用
    • 模型编辑

      • 简介
      • 方法
      • 附加参数法
      • 定位编辑法
    • RAG

      • 基础
      • 架构
      • 知识检索
      • 生成增强
  • 动手学深度学习

    • 深度学习基础

      • 引言
      • 数据操作
      • 数据预处理
      • 数学知识(线代、矩阵计算、求导)
      • 线性回归
      • 基础优化方法
      • Softmax回归
      • 感知机
      • 模型选择
      • 过拟合和欠拟合
      • 环境和分布偏移
      • 权重衰减
      • Dropout
      • 数值稳定性
    • 卷积神经网络

      • 模型基本操作
      • 从全连接层到卷积
      • 填充和步长
      • 多个输入和输出通道
      • 池化层
      • LeNet
      • AlexNet
      • VGG
      • NiN网络
      • GoogleNet
      • 批量归一化
      • ResNet
    • 计算机视觉

      • 图像增广
      • 微调
      • 目标检测
      • 锚框
      • 区域卷积神经网络
      • 单发多框检测
      • 一次看完
      • 语义分割
      • 转置卷积
      • 全连接卷积神经网络
      • 样式迁移
    • 循环神经网络

      • 序列模型
      • 语言模型
      • 循环神经网络
      • 序列到序列学习
      • 搜索策略
    • 注意力机制

      • 优化算法

图像增广

参考文章

  • PyTorch(五):图像增广(image augmentation)

概念

图像增广是增加数据多样性的技术,属于广义上的图像数据增强范畴。图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从而扩大训练数据集的规模。图像增广的另一种解释是,随机改变训练样本可以降低模型 对某些属性的依赖,从而提高模型的泛化能力。

为了在预测时得到确定的结果,我们通常只将图像增广应用在训练样本上,而不在预测时使用含随机操作的图像增广。

作用:在数据驱动的深度学习模型中,加入合适的图像增广是提升模型准确率的有效途径,往往能使模型准确率提升 2%-5%甚至更多。

技术方法分类

文章将图像增广技术分为两大类:

类别主要方法/思想具体技术举例
图像变换对单张图像进行颜色或几何上的调整几何变换:平移、旋转、裁剪、缩放、翻转、错切、仿射变换
颜色变换:色调分离,以及锐度、亮度、饱和度、对比度调整
图像合成组合多张图像或其特征来生成新样本简单合成:Mixup, CutMix, AugMix
自动生成:生成对抗网络(GAN)、对抗学习

PyTorch 框架下实现图像增广

  • 翻转
    • 左右翻转
    • 上下翻转
  • 颜色调整
    • 亮度(Brightness)
    • 对比度(Contrast)
    • 饱和度(Saturation)
    • 色调(Hue)
  • 旋转
  • 裁剪
    • 随机裁剪 (Random Crop)
    • 中心裁剪 (Center Crop)
  • 变换
    • 随机仿射变换 (Random Affine)
    • 随机透视变换 (Random Perspective)
    • 随机灰度化 (Random Grayscale)
    • 随机擦除 (Random Erasing)
    • 弹性形变 (Elastic Transform)
    • 高斯模糊 (Gaussian Blur)
  • 组合变换 (Compose)
    • 翻转、颜色、旋转、变换等

实践

详情
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch

# 加载图像
img_path = 'cat.png'  # 替换为你的图像路径
img = Image.open(img_path)

# 定义各个增广函数
def plot_transformed_images(original_img, transform, title):
    """
    对给定的图像应用变换并绘制原始和变换后的图像。
    :param original_img: 原始图像
    :param transform: 要应用的变换
    :param title: 变换的名称(用于图表标题)
    """
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(original_img)
    axs[0].set_title('Original Image')
    axs[0].axis('off')

    transformed_img = transform(original_img)
    axs[1].imshow(transformed_img)
    axs[1].set_title(title)
    axs[1].axis('off')

    plt.show()

# 左右翻转
transform_hflip = transforms.RandomHorizontalFlip(p=1)  # p=1 表示总是翻转
plot_transformed_images(img, transform_hflip, "Horizontal Flip")

# 上下翻转
transform_vflip = transforms.RandomVerticalFlip(p=1)  # p=1 表示总是翻转
plot_transformed_images(img, transform_vflip, "Vertical Flip")

# 颜色调整 - 亮度
transform_brightness = transforms.ColorJitter(brightness=1.5, contrast=0, saturation=0, hue=0)
plot_transformed_images(img, transform_brightness, "Brightness Adjusted")

# 颜色调整 - 对比度
transform_contrast = transforms.ColorJitter(brightness=0, contrast=1.5, saturation=0, hue=0)
plot_transformed_images(img, transform_contrast, "Contrast Adjusted")

# 颜色调整 - 饱和度
transform_saturation = transforms.ColorJitter(brightness=0, contrast=0, saturation=1.5, hue=0)
plot_transformed_images(img, transform_saturation, "Saturation Adjusted")

# 颜色调整 - 色调
transform_hue = transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.1)
plot_transformed_images(img, transform_hue, "Hue Adjusted")

# 旋转
transform_rotate = transforms.RandomRotation(degrees=45)  # 旋转角度范围为 [-45, 45]
plot_transformed_images(img, transform_rotate, "Rotated")

png

png

png

png

png

png

png

# ---------------------------------------------------------
# 1. 随机裁剪 (Random Crop)
# 从图像中随机裁剪出指定大小的区域,常用于让模型关注局部特征
# ---------------------------------------------------------
transform_random_crop = transforms.RandomCrop(size=(200, 200)) 
# 注意:如果原图小于200x200,需要先用 Pad 或 Resize 处理,这里假设原图足够大
plot_transformed_images(img, transform_random_crop, "Random Crop (200x200)")

# ---------------------------------------------------------
# 2. 中心裁剪 (Center Crop)
# 从图像中心裁剪出指定大小,常用于测试阶段或去除边缘无关信息
# ---------------------------------------------------------
transform_center_crop = transforms.CenterCrop(size=(200, 200))
plot_transformed_images(img, transform_center_crop, "Center Crop (200x200)")

# ---------------------------------------------------------
# 3. 随机仿射变换 (Random Affine)
# 包含旋转、平移、缩放和剪切的组合变换,模拟更复杂的视角变化
# degrees: 旋转角度范围
# translate: 平移比例 (宽, 高)
# scale: 缩放比例范围
# shear: 剪切角度范围
# ---------------------------------------------------------
transform_affine = transforms.RandomAffine(
    degrees=30, 
    translate=(0.1, 0.1), 
    scale=(0.8, 1.2), 
    shear=10
)
plot_transformed_images(img, transform_affine, "Random Affine")

# ---------------------------------------------------------
# 4. 随机透视变换 (Random Perspective)
# 模拟透视变形,让图像看起来像是从不同角度拍摄的平面物体
# distortion_scale: 控制变形程度 (0-1)
# ---------------------------------------------------------
transform_perspective = transforms.RandomPerspective(distortion_scale=0.5, p=1.0)
plot_transformed_images(img, transform_perspective, "Random Perspective")

# ---------------------------------------------------------
# 5. 弹性形变 (Elastic Transform) - 需较新版本 torchvision
# 模拟局部非线性变形,常用于医学图像或增加纹理多样性
# alpha: 变形强度
# sigma: 高斯核标准差,控制平滑度
# ---------------------------------------------------------
try:
    transform_elastic = transforms.ElasticTransform(alpha=50.0, sigma=5.0)
    plot_transformed_images(img, transform_elastic, "Elastic Transform")
except AttributeError:
    print("当前 torchvision 版本可能不支持 ElasticTransform,请升级库。")

# ---------------------------------------------------------
# 6. 高斯模糊 (Gaussian Blur)
# 模拟相机失焦或运动模糊,提高模型对模糊图像的鲁棒性
# kernel_size: 模糊核大小 (必须是奇数或序列)
# sigma: 标准差范围
# ---------------------------------------------------------
transform_blur = transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2.0))
plot_transformed_images(img, transform_blur, "Gaussian Blur")

# ---------------------------------------------------------
# 7. 随机灰度化 (Random Grayscale)
# 以一定概率将图像转换为灰度图,迫使模型不依赖颜色特征
# p: 转换为灰度的概率
# ---------------------------------------------------------
transform_grayscale = transforms.RandomGrayscale(p=1.0) # p=1 强制变灰度
plot_transformed_images(img, transform_grayscale, "Random Grayscale")

# ---------------------------------------------------------
# 8. 随机擦除 (Random Erasing) - 修复警告版本
# ---------------------------------------------------------
transform_to_tensor = transforms.ToTensor()
tensor_img = transform_to_tensor(img)

transform_erasing = transforms.RandomErasing(p=1.0, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random')
erased_tensor = transform_erasing(tensor_img)

# 【关键修复步骤】
# 1. 使用 clamp 确保数据严格在 [0, 1] 之间
erased_tensor_clamped = torch.clamp(erased_tensor, 0.0, 1.0)

# 2. 将 Tensor (C, H, W) 转换为 numpy (H, W, C) 供 matplotlib 使用
# permute(1, 2, 0) 改变维度顺序
original_np = tensor_img.permute(1, 2, 0).numpy()
erased_np = erased_tensor_clamped.permute(1, 2, 0).numpy()

# 绘制 Random Erasing
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(original_np) 
axs[0].set_title('Original (Tensor)')
axs[0].axis('off')

axs[1].imshow(erased_np)
axs[1].set_title('Random Erasing (Clamped)')
axs[1].axis('off')

plt.tight_layout()
plt.show()

# ---------------------------------------------------------
# 9. 组合变换 (Compose) - 实际训练中的常用方式
# 将多个变换串联在一起,一次性应用
# ---------------------------------------------------------
combined_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(degrees=15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
])
plot_transformed_images(img, combined_transform, "Combined Augmentations")

png

png

png

png

png

png

png

png

png

最近更新: 2026/3/25 06:52
Contributors: klc407073648
Next
微调