AI知识分享AI知识分享
✿导航
  • 人工智能
  • 神经网络
  • 机器学习
  • 深度学习
  • 强化学习
  • 自然语言处理
  • 计算机视觉
  • 大模型基础
  • 动手学深度学习
  • 理论理解
  • 工程实践
  • 应用开发
  • AI For Everyone
  • AIGC_2024大会
  • AIGC_2025大会
  • Transformer
  • Pytorch
  • HuggingFace
  • 蒸馏
  • RAG
  • 目标检测
  • MCP
  • 概念
  • 意图识别
  • 工具
✿导航
  • 人工智能
  • 神经网络
  • 机器学习
  • 深度学习
  • 强化学习
  • 自然语言处理
  • 计算机视觉
  • 大模型基础
  • 动手学深度学习
  • 理论理解
  • 工程实践
  • 应用开发
  • AI For Everyone
  • AIGC_2024大会
  • AIGC_2025大会
  • Transformer
  • Pytorch
  • HuggingFace
  • 蒸馏
  • RAG
  • 目标检测
  • MCP
  • 概念
  • 意图识别
  • 工具
  • Transformer

    • Transformer - 概述
    • Transformer - Encoding and Decoding Context with Attention
    • Transformer - Tokenizers
    • Transformer - 架构
    • Transformer - Block
    • Transformer - 自注意力
    • Transformer - MoE
    • Transformer - Transformer
  • Pytorch

    • Pytorch - Dataset
    • Pytorch - TensorBoard
    • Pytorch - transforms
    • Pytorch - DataLoader
    • Pytorch - nn
    • Pytorch - Model
    • Pytorch - train
    • Pytorch - Practice
    • Pytorch - pytorch
  • HuggingFace

    • HuggingFace - Transformers
    • HuggingFace - Pipeline
    • HuggingFace - Tokenizer
    • HuggingFace - Model
    • HuggingFace - Datasets
    • HuggingFace - Evaluate
    • HuggingFace - Trainer

Pytorch - 数据集加载

  1. 数据集加载学习,运行程序前需要解压dataset.zip
  2. 关注MyData类需要继承Dataset类,重写__getitem__和__len__的方法
详情
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

class MyData(Dataset):

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.image_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img,label

    def __len__(self):
        return len(self.image_path)

root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset





最近更新: 2025/4/14 07:18
Contributors: klc407073648
Next
Pytorch - TensorBoard