Tutorial video: https://www.bilibili.com/video/BV1hE411t7RN?p=1 (includes environment setup)
Using torchvision.datasets
PyTorch ships with many built-in datasets for learning and experimentation, so here we pick one to work with; see the official documentation for the full list.
CIFAR10
This note uses the CIFAR10 dataset.
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10_Dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR10_Dataset", train=False, download=True)
# root: directory the data is downloaded into
# train: True for the training set, False for the test set
# download: whether to download from the internet; if the download is slow, you can fetch the archive manually with a download manager such as Thunder (Xunlei):
# https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
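As a quick sanity check (a minimal sketch, assuming the download above finished and train_set / test_set exist), the sizes match what the CIFAR10 page describes: 50,000 training images and 10,000 test images.
print(len(train_set))   # 50000 training images
print(len(test_set))    # 10000 test images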
Let's check what type a single sample from the dataset is:
print(test_set[0])
The output shows that each sample is a (PIL Image, label) pair; the 3 is the class index. CIFAR10 has 10 classes, listed in the official documentation.
(<PIL.Image.Image image mode=RGB size=32x32 at 0x281591063D0>, 3)
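To turn that index into a readable name, the dataset object carries a classes list (a small sketch; img.show() simply opens the PIL image in the system viewer):
img, target = test_set[0]
print(test_set.classes)          # the 10 class names
print(test_set.classes[target])  # index 3 -> 'cat'
img.show()                       # display the 32x32 PIL image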
Converting to the Tensor type
Add the built-in transform argument when constructing the dataset:
test_data = torchvision.datasets.CIFAR10(root="./CIFAR10_Dataset",
                                         train=False,
                                         transform=torchvision.transforms.ToTensor())
print(test_data[0])
What we get back is now the tensor data type:
(tensor([[[0.6196, 0.6235, 0.6471, ..., 0.5373, 0.4941, 0.4549],
[0.5961, 0.5922, 0.6235, ..., 0.5333, 0.4902, 0.4667],
[0.5922, 0.5922, 0.6196, ..., 0.5451, 0.5098, 0.4706],
...,
[0.2667, 0.1647, 0.1216, ..., 0.1490, 0.0510, 0.1569],
[0.2392, 0.1922, 0.1373, ..., 0.1020, 0.1137, 0.0784],
[0.2118, 0.2196, 0.1765, ..., 0.0941, 0.1333, 0.0824]],
[[0.4392, 0.4353, 0.4549, ..., 0.3725, 0.3569, 0.3333],
[0.4392, 0.4314, 0.4471, ..., 0.3725, 0.3569, 0.3451],
[0.4314, 0.4275, 0.4353, ..., 0.3843, 0.3725, 0.3490],
...,
[0.4863, 0.3922, 0.3451, ..., 0.3804, 0.2510, 0.3333],
[0.4549, 0.4000, 0.3333, ..., 0.3216, 0.3216, 0.2510],
[0.4196, 0.4118, 0.3490, ..., 0.3020, 0.3294, 0.2627]],
[[0.1922, 0.1843, 0.2000, ..., 0.1412, 0.1412, 0.1294],
[0.2000, 0.1569, 0.1765, ..., 0.1216, 0.1255, 0.1333],
[0.1843, 0.1294, 0.1412, ..., 0.1333, 0.1333, 0.1294],
...,
[0.6941, 0.5804, 0.5373, ..., 0.5725, 0.4235, 0.4980],
[0.6588, 0.5804, 0.5176, ..., 0.5098, 0.4941, 0.4196],
[0.6275, 0.5843, 0.5176, ..., 0.4863, 0.5059, 0.4314]]]), 3)
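A quicker way to see what ToTensor did (a minimal sketch reusing test_data from above): it converts the PIL image into a 3x32x32 float tensor with pixel values scaled from 0-255 down to 0.0-1.0, while the label stays a plain int.
img, target = test_data[0]
print(type(img))   # <class 'torch.Tensor'>
print(img.shape)   # torch.Size([3, 32, 32]): channels x height x width
print(img.dtype)   # torch.float32, values scaled to 0.0-1.0
print(target)      # 3, the same class index as before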
Using DataLoader
DataLoader is the utility that packages the dataset samples into batches in a form suitable for feeding into a neural network.
from torch.utils.data import DataLoader
test_loader = DataLoader(dataset=test_data,
                         batch_size=4,    # how many samples per batch to load
                         shuffle=True,    # set to True to have the data reshuffled at every epoch
                         num_workers=0,   # how many subprocesses to use for data loading
                         drop_last=False) # set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)  # batch_size = 4: four images, each with 3 channels and 32x32 pixels
    print(targets)     # the class labels of the four images
One batch of output:
torch.Size([4, 3, 32, 32])
tensor([0, 8, 5, 1])
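The interaction between batch_size and drop_last is easy to verify (a small sketch with the same test_data): the 10,000 test images do not divide evenly by 64, so drop_last decides whether the final short batch of 16 is kept.
loader_keep = DataLoader(test_data, batch_size=64, drop_last=False)
loader_drop = DataLoader(test_data, batch_size=64, drop_last=True)
print(len(loader_keep))  # 157: 156 full batches of 64 plus one batch of 16
print(len(loader_drop))  # 156: the incomplete final batch is discarded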
Using TensorBoard to view the batches more intuitively:
from torch.utils.tensorboard import SummaryWriter
test_loader = DataLoader(dataset=test_data,
                         batch_size=64,
                         shuffle=True,
                         num_workers=0,
                         drop_last=False)
writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs, targets = data
    writer.add_images("test_data", imgs, step)  # note: add_images (plural) writes a whole batch of images
    step = step + 1
writer.close()
Each step logs 64 images. Start TensorBoard with tensorboard --logdir=dataloader and open the Images tab to browse them.
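Since shuffle=True reshuffles at every epoch, writing two epochs side by side makes this visible in TensorBoard (a minimal sketch; the log directory and tag names here are just placeholders I picked):
writer = SummaryWriter("dataloader_epochs")
for epoch in range(2):
    step = 0
    for imgs, targets in test_loader:
        # with shuffle=True the same images land in different batches each epoch
        writer.add_images("epoch_{}".format(epoch), imgs, step)
        step = step + 1
writer.close()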