
Deep Learning for Java Programmers: DJL Hands-On, Part 7: Using the PyTorch Engine



  • I. Adding the PyTorch engine to a DJL project with Maven
      • 1. Adding pytorch-engine
      • 2. Adding the pytorch-native-auto library
  • II. The PyTorch Model Zoo of pretrained models
  • III. Converting PyTorch models
  • IV. Loading a PyTorch model
      • 1. Preparing the model
      • 2. Creating the Translator
      • 3. Loading your own model
      • 4. Loading the image
      • 5. Running inference
  • V. Source code
      • 1. pom.xml
      • 2. Java
  • VI. Loading a local model
  • VII. Model optimization tips

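This article shows how to call the PyTorch engine from DJL and work with PyTorch objects. DJL's PyTorch engine only loads models in TorchScript format, so your own PyTorch models must be converted first. The first sections explain how to convert a model. The demo that follows downloads a ResNet-18 model that has already been converted to TorchScript.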

I. Adding the PyTorch engine to a DJL project with Maven

1. Adding pytorch-engine

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.13.0-SNAPSHOT</version>
    <scope>runtime</scope>
</dependency>

2. Adding the pytorch-native-auto library

Each pytorch-engine release currently works with exactly one PyTorch version. The mapping is:

PyTorch engine version     PyTorch native library version
pytorch-engine:0.13.0      pytorch-native-auto:1.9.0
pytorch-engine:0.12.0      pytorch-native-auto:1.8.1
pytorch-engine:0.11.0      pytorch-native-auto:1.8.1
pytorch-engine:0.10.0      pytorch-native-auto:1.7.1
pytorch-engine:0.9.0       pytorch-native-auto:1.7.0
pytorch-engine:0.8.0       pytorch-native-auto:1.6.0
pytorch-engine:0.7.0       pytorch-native-auto:1.6.0
pytorch-engine:0.6.0       pytorch-native-auto:1.5.0
pytorch-engine:0.5.0       pytorch-native-auto:1.4.0
pytorch-engine:0.4.0       pytorch-native-auto:1.4.0
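One way to keep the two dependency versions in step is to encode the compatibility table in a small lookup, for example in a startup check or build helper. The class below is a hypothetical sketch and is not part of DJL. Treating a `-SNAPSHOT` engine version like its release version is an assumption made for convenience.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper (not a DJL API): encodes the pytorch-engine -> pytorch-native-auto
// compatibility table so the matching native version can be looked up in code.
public class EngineVersions {
    private static final Map<String, String> NATIVE_FOR_ENGINE = new LinkedHashMap<>();

    static {
        NATIVE_FOR_ENGINE.put("0.13.0", "1.9.0");
        NATIVE_FOR_ENGINE.put("0.12.0", "1.8.1");
        NATIVE_FOR_ENGINE.put("0.11.0", "1.8.1");
        NATIVE_FOR_ENGINE.put("0.10.0", "1.7.1");
        NATIVE_FOR_ENGINE.put("0.9.0", "1.7.0");
        NATIVE_FOR_ENGINE.put("0.8.0", "1.6.0");
        NATIVE_FOR_ENGINE.put("0.7.0", "1.6.0");
        NATIVE_FOR_ENGINE.put("0.6.0", "1.5.0");
        NATIVE_FOR_ENGINE.put("0.5.0", "1.4.0");
        NATIVE_FOR_ENGINE.put("0.4.0", "1.4.0");
    }

    /** Returns the matching pytorch-native-auto version, or null if the engine version is unknown. */
    static String nativeVersionFor(String engineVersion) {
        // Assumption: a snapshot build pairs with the same native version as its release.
        String release = engineVersion.replace("-SNAPSHOT", "");
        return NATIVE_FOR_ENGINE.get(release);
    }

    public static void main(String[] args) {
        System.out.println(nativeVersionFor("0.13.0-SNAPSHOT")); // 1.9.0
    }
}
```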

Example:

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <version>1.9.0</version>
    <scope>runtime</scope>
</dependency>

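The native library you need also depends on the CPU, the operating system and architecture, the GPU, and so on. pytorch-native-auto detects the platform and selects the matching build automatically.
If auto-detection does not work, look up the library your platform needs at http://docs.djl.ai/engines/pytorch/pytorch-engine/index.html and set the dependency manually.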
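When you pick a native library by hand, you first need to know which OS and CPU architecture the JVM reports. The sketch below turns the standard `os.name` and `os.arch` system properties into a readable key. The key format (`linux-x86_64` and so on) is an illustrative assumption and does not claim to match DJL's internal naming. It also does not detect GPUs or CUDA.

```java
import java.util.Locale;

// Hypothetical helper: builds an "os-arch" key from JVM system properties, i.e. the
// facts you need when looking up a native library manually. The format is illustrative only.
public class PlatformInfo {
    static String platformKey(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String family;
        if (os.startsWith("win")) {
            family = "win";
        } else if (os.startsWith("mac") || os.contains("os x")) {
            family = "osx";
        } else if (os.startsWith("linux")) {
            family = "linux";
        } else {
            family = os.replace(' ', '_');
        }
        String arch = osArch.toLowerCase(Locale.ROOT);
        // 64-bit Intel/AMD is reported as "amd64" on some platforms and "x86_64" on others.
        if (arch.equals("amd64") || arch.equals("x86_64")) {
            arch = "x86_64";
        }
        return family + "-" + arch;
    }

    public static void main(String[] args) {
        System.out.println(platformKey(System.getProperty("os.name"), System.getProperty("os.arch")));
    }
}
```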

II. The PyTorch Model Zoo of pretrained models

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.13.0-SNAPSHOT</version>
</dependency>

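The pretrained models in the Model Zoo are mainly computer vision models. Among others, they include:

  • Image classification
  • Object detection
  • Style transfer
  • Image generation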

III. Converting PyTorch models

A PyTorch model must be converted to TorchScript format. There are two main ways to do this: tracing and scripting.
An example script using tracing:

import torch
import torchvision

# Point this at your own model
model = torchvision.models.resnet18(pretrained=True)

# Switch to evaluation (inference) mode
model.eval()

# Provide example input for the model's forward() method
example = torch.rand(1, 3, 224, 224)

# Run the trace
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")

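IV. Loading a PyTorch model

1. Preparing the model

The following example assumes a TorchScript model is already available. Here we use a pretrained ResNet-18 model.
DownloadUtils.download fetches the model from the web. The target folder is build/pytorch_models.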

DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
Downloading: 100% |████████████████████████████████████████| resnet18.pt

The ResNet-18 model also needs a label file. It is downloaded with DownloadUtils as well.

DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
Downloading: 100% |████████████████████████████████████████| synset.txt
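The label file maps each model output index to a class name. Line i of synset.txt is assumed to hold the label for output index i, which is the usual ImageNet convention. The pure-Java sketch below reads the file that way. It only illustrates the idea; it is not how DJL loads synsets internally.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// Sketch: read synset.txt so that list index i is the label for model output index i.
// Assumes one label per line, in output order, with no blank lines between labels.
public class SynsetLabels {
    static List<String> readLabels(Path synsetFile) throws IOException {
        List<String> labels = new ArrayList<>();
        for (String line : Files.readAllLines(synsetFile, StandardCharsets.UTF_8)) {
            String trimmed = line.trim();
            if (!trimmed.isEmpty()) { // ignore stray empty lines
                labels.add(trimmed);
            }
        }
        return labels;
    }

    public static void main(String[] args) throws IOException {
        List<String> labels = readLabels(Paths.get("build/pytorch_models/resnet18/synset.txt"));
        System.out.println(labels.size() + " labels, first: " + labels.get(0));
    }
}
```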

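2. Creating the Translator

First, define the pipeline, i.e. the preprocessing every image goes through. In Python with torchvision it looks like this: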

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
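The last two steps of this pipeline are simple per-pixel arithmetic. ToTensor scales each 8-bit channel value from [0, 255] to [0, 1] and also reorders the layout to channels-first. Normalize then computes (x - mean) / std per channel, using the ImageNet statistics shown above. The sketch below applies this arithmetic to a single RGB pixel and skips the layout change.

One caveat about the Java equivalent: as far as I can tell, DJL's `new Resize(256)` resizes to a fixed 256×256. torchvision's `transforms.Resize(256)` instead scales the shorter side to 256 and keeps the aspect ratio, so the two pipelines are close but not identical.

```java
// Sketch of the per-pixel arithmetic behind ToTensor + Normalize for one RGB pixel.
// MEAN and STD are the ImageNet statistics used in the pipeline above.
public class NormalizeDemo {
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    /** rgb holds 8-bit channel values (0..255); returns the normalized channel values. */
    static float[] normalizePixel(int[] rgb) {
        float[] out = new float[3];
        for (int c = 0; c < 3; c++) {
            float scaled = rgb[c] / 255f;          // ToTensor: [0, 255] -> [0, 1]
            out[c] = (scaled - MEAN[c]) / STD[c];  // Normalize: (x - mean) / std
        }
        return out;
    }

    public static void main(String[] args) {
        float[] v = normalizePixel(new int[] {255, 128, 0});
        System.out.printf("%.4f %.4f %.4f%n", v[0], v[1], v[2]);
    }
}
```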

Then create the equivalent Translator in DJL:

Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
        .addTransform(new Resize(256))
        .addTransform(new CenterCrop(224, 224))
        .addTransform(new ToTensor())
        .addTransform(new Normalize(
                new float[] {0.485f, 0.456f, 0.406f},
                new float[] {0.229f, 0.224f, 0.225f}))
        .optApplySoftmax(true)
        .build();
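`optApplySoftmax(true)` is needed because a torchvision ResNet returns raw scores (logits) rather than probabilities. The translator is therefore asked to turn those scores into probabilities that sum to 1. The sketch below shows a numerically stable softmax in plain Java, as an illustration of the math rather than DJL's implementation.

```java
// Sketch: numerically stable softmax, converting raw model scores (logits) into probabilities.
public class SoftmaxDemo {
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double x : logits) {
            max = Math.max(max, x);
        }
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max); // subtracting the max avoids overflow in exp
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum; // normalize so the probabilities sum to 1
        }
        return out;
    }

    public static void main(String[] args) {
        double[] p = softmax(new double[] {2.0, 1.0, 0.1});
        System.out.println(java.util.Arrays.toString(p));
    }
}
```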

3. Loading your own model

Loading the model takes a few parameters. For example, optModelPath tells DJL where the model is.

Criteria<Image, Classifications> criteria = Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optModelPath(Paths.get("build/pytorch_models/resnet18"))
        .optTranslator(translator)
        .optProgress(new ProgressBar())
        .build();

ZooModel<Image, Classifications> model = criteria.loadModel();
Loading:     100% |████████████████████████████████████████|
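As I understand DJL's convention, when optModelPath points to a folder and no model name is given, the folder name is used as the model name. That is why the file is saved as resnet18.pt inside resnet18/. The hypothetical helper below is not a DJL API. It checks that layout, plus the synset.txt label file, before `criteria.loadModel()` is called, so a wrong path fails with a clear message.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical pre-flight check (not a DJL API): verifies that a folder contains a
// TorchScript file named after the folder (e.g. resnet18/resnet18.pt) and synset.txt.
public class ModelDirCheck {
    static boolean looksLikeModelDir(Path dir) {
        Path fileName = dir.getFileName();
        if (fileName == null || !Files.isDirectory(dir)) {
            return false;
        }
        String name = fileName.toString();
        return Files.isRegularFile(dir.resolve(name + ".pt"))
                && Files.isRegularFile(dir.resolve("synset.txt"));
    }

    public static void main(String[] args) {
        Path dir = Paths.get("build/pytorch_models/resnet18");
        System.out.println(dir + " ready: " + looksLikeModelDir(dir));
    }
}
```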

4. Loading the image

Image img = ImageFactory.getInstance().fromUrl("https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg");
img.getWrappedImage(); // in a notebook this displays the image; in plain Java it can be omitted

5. Running inference

Predictor<Image, Classifications> predictor = model.newPredictor();
Classifications classifications = predictor.predict(img);

Print the result:

System.out.println(classifications);

[
class: "n02111889 Samoyed, Samoyede", probability: 0.94256
class: "n02114548 white wolf, Arctic wolf, Canis lupus tundrarum", probability: 0.02820
class: "n02111500 Great Pyrenees", probability: 0.01032
class: "n02120079 Arctic fox, white fox, Alopex lagopus", probability: 0.00412
class: "n02109961 Eskimo dog, husky", probability: 0.00279
]
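The printout lists the five most probable classes. The plain-Java sketch below shows how such a list can be derived from a probability vector and the synset labels. It illustrates the idea only and is not DJL's `Classifications` implementation; the output format string is an assumption.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;

// Sketch: return the k most probable classes as "label: probability" strings.
public class TopK {
    static List<String> topK(double[] probs, List<String> labels, int k) {
        Integer[] idx = new Integer[probs.length];
        for (int i = 0; i < idx.length; i++) {
            idx[i] = i;
        }
        // Sort class indices by descending probability.
        Arrays.sort(idx, Comparator.comparingDouble((Integer i) -> probs[i]).reversed());
        List<String> result = new ArrayList<>();
        for (int i = 0; i < Math.min(k, idx.length); i++) {
            result.add(String.format(Locale.ROOT, "%s: %.5f", labels.get(idx[i]), probs[idx[i]]));
        }
        return result;
    }

    public static void main(String[] args) {
        double[] probs = {0.02, 0.90, 0.08};
        System.out.println(topK(probs, Arrays.asList("cat", "dog", "fox"), 2));
    }
}
```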

V. Source code

1. pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xundh</groupId>
    <artifactId>djl-learning</artifactId>
    <version>0.1-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <java.version>8</java.version>
        <djl.version>0.13.0-SNAPSHOT</djl.version>
    </properties>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- Pytorch -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.9.0</version>
        </dependency>
    </dependencies>
</project>

2. Java

package com.xundh;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;

import java.io.IOException;
import java.nio.file.Paths;

public class PyTorchLearn {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // Download the traced (TorchScript) ResNet-18 model and its label file
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());

        // Preprocessing (resize, crop, tensor, normalize) and softmax postprocessing
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .addTransform(new Resize(256))
                .addTransform(new CenterCrop(224, 224))
                .addTransform(new ToTensor())
                .addTransform(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}))
                .optApplySoftmax(true)
                .build();

        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get("build/pytorch_models/resnet18"))
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();

        Image img = ImageFactory.getInstance().fromUrl("https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg");

        // ZooModel and Predictor hold native resources, so close them when done
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
             Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Classifications classifications = predictor.predict(img);
            System.out.println(classifications);
        }
    }
}

VI. Loading a local model

Instead of going through Criteria, the local model can also be loaded directly with Model.newInstance and Model.load:

package com.xundh;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;

public class PyTorchLearn {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException {
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());

        Path modelDir = Paths.get("build/pytorch_models/resnet18");

        // Center-crop to a square, resize to 224x224, convert to a tensor and
        // normalize with the ImageNet mean/std the model was trained with
        Pipeline pipeline = new Pipeline();
        pipeline.add(new CenterCrop())
                .add(new Resize(224, 224))
                .add(new ToTensor())
                .add(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}));

        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .optSynsetArtifactName("synset.txt")
                .optApplySoftmax(true)
                .build();

        Image img = ImageFactory.getInstance().fromUrl("https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg");

        // Loads build/pytorch_models/resnet18/resnet18.pt; Model and Predictor are closed afterwards
        try (Model model = Model.newInstance("resnet")) {
            model.load(modelDir, "resnet18");
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                Classifications classifications = predictor.predict(img);
                System.out.println(classifications);
            }
        }
    }
}

VII. Model optimization tips

See:
https://github.com/deepjavalibrary/djl/blob/master/docs/pytorch/how_to_optimize_inference_performance.md

