1.下载pytorch官方制作的vs模板
LibTorch Project - Visual Studio Marketplace
2.下载mnist手写体数据集
MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
下载后为压缩包,注意libtorch加载的文件名格式为“t10k-images-idx3-ubyte”(Linux下ungzip命令解压),不是“t10k-images.idx3-ubyte”(win10下zip解压)(注意小点.的区别,否则会产生c10:erro),路径设置格式如下:
std::string data_path = "F:\\C++_Study\\ml\\TorchProject1\\MNIST";
3.libtorch加载数据集代码
auto dataset = torch::data::datasets::MNIST(data_path)
.map(torch::data::transforms::Stack<>());
auto dataloader = torch::data::make_data_loader(std::move(dataset), 64);
或者
auto data_loader = torch::data::make_data_loader(
torch::data::datasets::MNIST(data_path).map(
torch::data::transforms::Stack<>()), 64);
//将60000个3阶张量(1x28x28)转化为4阶张量(60000x1x28x28)
.map(torch::data::transforms::Stack<>())
//batch_size = 64
make_data_loader(std::move(dataset), 64);