# Java Programmers Learn Deep Learning with DJL, Part 9: Transfer Learning on the CIFAR-10 Dataset



  • I. Overview
  • II. Procedure
  • 1. Load the pre-trained ResNet50V1 model
  • 2. Prepare the dataset
  • 3. Configure training parameters
  • 4. Train the model
  • 5. Save the model
  • Source code

# I. Overview


This article uses transfer learning to train an image classification model. As mentioned in an earlier installment, transfer learning trains a model on one problem and then applies it to a second, related problem. Compared with training a task-specific model from scratch, it has fewer features to learn, takes less training time, and produces a more flexible model.

We train our own model on the CIFAR-10 dataset, which contains 60,000 32x32 color images in 10 classes.

The pre-trained model is ResNet50v1, a 50-layer deep network trained on ImageNet, a dataset of more than 1.2 million images in 1,000 classes. We modify this ImageNet model so that it classifies the 10 classes of CIFAR-10 instead.

Note: this experiment has not yet succeeded for me; loading the pre-trained model fails.

*(Figure: the CIFAR-10 dataset)*

# II. Procedure

## 1. Load the pre-trained ResNet50V1 model

ResNet50V1 is available in the DJL ModelZoo. It was trained on the ImageNet dataset and has 1,000 output classes. Since we want to retrain it for the 10 classes of CIFAR-10, we remove the last layer, add a new Linear layer with 10 output units, and set the modified block back on the model.

```java
// load the pre-trained model and change its last layer
Criteria<Image, Classifications> criteria = Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optProgress(new ProgressBar())
        .optArtifactId("resnet")
        .optFilter("layers", "50")
        .optFilter("flavor", "v1")
        .build();
Model model = criteria.loadModel();
SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();                             // drop the 1000-way ImageNet output layer
newBlock.add(block);                                 // keep the pre-trained feature extractor
newBlock.add(Blocks.batchFlattenBlock());            // flatten to (batch, features)
newBlock.add(Linear.builder().setUnits(10).build()); // new 10-way output layer for CIFAR-10
model.setBlock(newBlock);
```

## 2. Prepare the dataset

When building the dataset you can set the sizes of the training and test splits, the batch size, and a preprocessing pipeline.
The pipeline preprocesses the data: ToTensor converts a color-image NDArray of shape (32, 32, 3) with values in [0, 255] into one of shape (3, 32, 32) with values in [0, 1].
Normalize then standardizes the input using the dataset's per-channel mean and standard deviation.

```java
int batchSize = 32;
int limit = Integer.MAX_VALUE; // change to a small value (e.g. 160) for a dry run
Pipeline pipeline = new Pipeline(
        new ToTensor(), // (32, 32, 3) in [0, 255] -> (3, 32, 32) in [0, 1]
        new Normalize(
                new float[] {0.4914f, 0.4822f, 0.4465f},   // per-channel mean of CIFAR-10
                new float[] {0.2023f, 0.1994f, 0.2010f})); // per-channel std of CIFAR-10
Cifar10 trainDataset =
        Cifar10.builder()
                .setSampling(batchSize, true) // batch size, shuffled
                .optUsage(Dataset.Usage.TRAIN)
                .optLimit(limit)
                .optPipeline(pipeline)
                .build();
trainDataset.prepare(new ProgressBar());
```
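
The held-out test split can be built with the same builder if you want to validate the model later. A minimal sketch, assuming the `pipeline`, `batchSize`, and `limit` defined above (the variable name `testDataset` is mine):

```java
// CIFAR-10's TEST split: the 10,000 held-out images, with the same preprocessing
Cifar10 testDataset =
        Cifar10.builder()
                .setSampling(batchSize, true)
                .optUsage(Dataset.Usage.TEST)
                .optLimit(limit)
                .optPipeline(pipeline)
                .build();
testDataset.prepare(new ProgressBar());
```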
## 3. Configure training parameters
Because we build on a pre-trained model, we train for only 10 epochs.
```java
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
        // softmaxCrossEntropyLoss is the standard loss for classification problems
        .addEvaluator(new Accuracy()) // report accuracy so we can see how well the model performs
        .optDevices(Device.getDevices(1)) // limit to one device; more GPUs can actually slow convergence
        .addTrainingListeners(TrainingListener.Defaults.logging());

// with the training configuration ready, create a trainer for the model
Trainer trainer = model.newTrainer(config);
```
## 4. Train the model
```java
int epoch = 10;
Shape inputShape = new Shape(1, 3, 32, 32); // (batch, channels, height, width)
trainer.initialize(inputShape); // allocate and initialize the parameters
```
```java
for (int i = 0; i < epoch; ++i) {
    for (Batch batch : trainer.iterateDataset(trainDataset)) {
        EasyTrain.trainBatch(trainer, batch); // forward and backward pass
        trainer.step();                       // update the weights
        batch.close();
    }

    // reset the training and validation evaluators at the end of each epoch
    trainer.notifyListeners(listener -> listener.onEpoch(trainer));
}
```
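
After training you can measure accuracy on held-out data. A minimal sketch, assuming the `testDataset` from the sketch in section 2; `EasyTrain.evaluateDataset` runs one pass that updates the evaluators without touching the weights:

```java
// one evaluation pass over the held-out split; no weight updates
EasyTrain.evaluateDataset(trainer, testDataset);
System.out.println(trainer.getTrainingResult()); // loss and accuracy summary
```

Alternatively, `EasyTrain.fit(trainer, epoch, trainDataset, testDataset)` wraps the whole epoch loop, including per-epoch validation.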

## 5. Save the model
```java
Path modelDir = Paths.get("build/resnet");
Files.createDirectories(modelDir);

model.setProperty("Epoch", String.valueOf(epoch)); // record training metadata alongside the weights
model.save(modelDir, "resnet");
```
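
To reuse the saved parameters in a new process, rebuild the same architecture first and then load them; `Model.load` restores parameters, not the block structure. A minimal sketch, assuming the modified block is reconstructed exactly as in section 1 (`loaded` is a hypothetical name):

```java
// rebuild the modified ResNet block as in section 1, then restore the trained weights
Model loaded = Model.newInstance("resnet");
loaded.setBlock(newBlock); // the same SequentialBlock construction as in section 1
loaded.load(Paths.get("build/resnet"), "resnet");
```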

# Source code
```java
package com.xundh;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.PaddingStackBatchifier;
import ai.djl.translate.TranslateException;
import org.apache.commons.csv.CSVFormat;

import java.io.IOException;
import java.nio.file.Paths;

public class PyTorchLearn {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // choose the base model to download according to the active deep learning engine
        // MXNet base model
        String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip";
        if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
            modelUrls = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
        }

        Criteria<NDList, NDList> criteria = Criteria.builder()
                .optApplication(Application.NLP.WORD_EMBEDDING)
                .setTypes(NDList.class, NDList.class)
                .optModelUrls(modelUrls)
                .optProgress(new ProgressBar())
                .build();
        ZooModel<NDList, NDList> embedding = criteria.loadModel();
        Predictor<NDList, NDList> embedder = embedding.newPredictor();
        Block classifier = new SequentialBlock()
                // text embedding layer
                .add(ndList -> {
                    NDArray data = ndList.singletonOrThrow();
                    NDList inputs = new NDList();
                    long batchSize = data.getShape().get(0);
                    float maxLength = data.getShape().get(1);

                    if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
                        inputs.add(data.toType(DataType.INT64, false));
                        inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                        inputs.add(data.getManager().arange(maxLength)
                                .toType(DataType.INT64, false)
                                .broadcast(data.getShape()));
                    } else {
                        inputs.add(data);
                        inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                    }
                    // run embedding
                    try {
                        return embedder.predict(inputs);
                    } catch (TranslateException e) {
                        throw new IllegalArgumentException("embedding error", e);
                    }
                })
                // classification layer
                .add(Linear.builder().setUnits(768).build()) // pre-classifier
                .add(Activation::relu)
                .add(Dropout.builder().optRate(0.2f).build())
                .add(Linear.builder().setUnits(5).build()) // 5-star rating
                .addSingleton(nd -> nd.get(":,0")); // take [CLS] as the head
        Model model = Model.newInstance("AmazonReviewRatingClassification");
        model.setBlock(classifier);

        // prepare the vocabulary
        SimpleVocabulary vocabulary = SimpleVocabulary.builder()
                .optMinFrequency(1)
                .addFromTextFile(embedding.getArtifact("vocab.txt"))
                .optUnknownToken("[UNK]")
                .build();
        // prepare the dataset
        int maxTokenLength = 64; // cut reviews off at 64 tokens
        int batchSize = 8;
        // int limit = Integer.MAX_VALUE; // use for a full run
        int limit = 512; // small limit for quick testing

        BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
        CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
        // split the data into train and validation sets at a 7:3 ratio
        RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
        RandomAccessDataset trainingSet = datasets[0];
        RandomAccessDataset validationSet = datasets[1];
        SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
        listener.setSaveModelCallback(trainer -> {
            TrainingResult result = trainer.getTrainingResult();
            Model model1 = trainer.getModel();
            // track accuracy and loss
            float accuracy = result.getValidateEvaluation("Accuracy");
            model1.setProperty("Accuracy", String.format("%.5f", accuracy));
            model1.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
        });
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
                .addEvaluator(new Accuracy())
                .optDevices(new Device[]{Device.cpu()}) // train on the CPU
                .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
                .addTrainingListeners(listener);

        int epoch = 2;

        Trainer trainer = model.newTrainer(config);
        trainer.setMetrics(new Metrics());
        Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
        // initialize the trainer with the proper input shape
        trainer.initialize(encoderInputShape);
        EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
        System.out.println(trainer.getTrainingResult());

        model.save(Paths.get("build/model"), "amazon-review.param");

        String review = "It works great, but it takes too long to update itself and slows the system";
        Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));
        System.out.println(predictor.predict(review));
    }

    /**
     * Download the data and create the dataset object.
     */
    static CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {
        String amazonReview = "https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Digital_Software_v1_00.tsv.gz";
        float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
        return CsvDataset.builder()
                .optCsvUrl(amazonReview) // load from the URL
                .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // TSV loading format
                .setSampling(batchSize, true) // set sample size and random access
                .optLimit(limit)
                .addFeature(new CsvDataset.Feature("review_body", new BertFeaturizer(tokenizer, maxLength)))
                .addLabel(new CsvDataset.Feature("star_rating", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))
                .optDataBatchifier(PaddingStackBatchifier.builder().optIncludeValidLengths(false)
                        .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                        .build()) // define how to pad the dataset to a fixed length
                .build();
    }
}

```

pom.xml
```xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xundh</groupId>
    <artifactId>djl-learning</artifactId>
    <version>0.1-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <java.version>8</java.version>
        <djl.version>0.12.0</djl.version>
    </properties>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- PyTorch -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.9.0</version>
        </dependency>
    </dependencies>
</project>

```

