DJL-实现手写数字识别
导入依赖
<!-- Import the DJL bill of materials so that the DJL artifacts declared in
     <dependencies> can omit their own <version> element. -->
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>bom</artifactId>
            <version>0.26.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
<dependencies>
    <!-- Core DJL API -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
    </dependency>
    <!-- Built-in datasets (provides the Mnist dataset used below) -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>basicdataset</artifactId>
    </dependency>
    <!-- Model zoo (provides the Mlp block used below) -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>model-zoo</artifactId>
    </dependency>
    <!-- OpenCV extension for DJL -->
    <dependency>
        <groupId>ai.djl.opencv</groupId>
        <artifactId>opencv</artifactId>
    </dependency>
    <!-- Apache MXNet engine -->
    <dependency>
        <groupId>ai.djl.mxnet</groupId>
        <artifactId>mxnet-engine</artifactId>
    </dependency>
    <!-- MXNet native library.
         NOTE(review): this is the only artifact with an explicit version rather
         than a BOM-managed one; confirm 1.8.0 is compatible with DJL 0.26.0. -->
    <dependency>
        <groupId>ai.djl.mxnet</groupId>
        <artifactId>mxnet-native-auto</artifactId>
        <version>1.8.0</version>
    </dependency>
</dependencies>
训练模型
// Load the datasets. To train on your own data, implement a custom Dataset instead.
// The validation set must come from the TEST split: validating on the TRAIN split
// would only measure how well the model fits data it has already seen.
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);

// Build the network: a multilayer perceptron taking one flattened image
// (IMAGE_WIDTH * IMAGE_HEIGHT inputs), with hidden layers of 128 and 64 units
// and NUM_CLASSES outputs.
Mlp mlp = new Mlp(Mnist.IMAGE_WIDTH * Mnist.IMAGE_HEIGHT,
        Mnist.NUM_CLASSES, new int[]{128, 64});

// Create the model and attach the block; try-with-resources closes the model afterwards.
try (Model model = Model.newInstance("mlp")) {
    model.setBlock(mlp);

    // Training configuration: softmax cross-entropy loss, accuracy evaluator,
    // and the default logging listeners writing to outputDir.
    String outputDir = "build/mlp";
    DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
            .addEvaluator(new Accuracy())
            .addTrainingListeners(TrainingListener.Defaults.logging(outputDir));

    try (Trainer trainer = model.newTrainer(config)) {
        // Collect metrics during training.
        trainer.setMetrics(new Metrics());
        // Initialize the parameters with the input shape: a batch of 1 flattened image.
        trainer.initialize(new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH));
        // Train for 5 epochs, validating against validateSet.
        EasyTrain.fit(trainer, 5, trainingSet, validateSet);
        // Print the training result.
        TrainingResult result = trainer.getTrainingResult();
        System.out.println(result);
        // Save the trained model under outputDir with the name "mlp".
        model.save(Paths.get(outputDir), "mlp");
    }
}
数据集
/**
 * Builds and prepares the built-in MNIST dataset for the requested split.
 *
 * @param usage the dataset split to load (for example TRAIN or TEST)
 * @return the prepared MNIST dataset, sampled randomly in batches of 64
 * @throws java.io.IOException propagated from {@code Mnist#prepare}
 * @throws ai.djl.translate.TranslateException propagated from {@code Mnist#prepare}
 */
private RandomAccessDataset getDataset(Dataset.Usage usage)
        throws java.io.IOException, ai.djl.translate.TranslateException {
    // Built-in dataset: batch size 64 with random sampling enabled.
    Mnist mnist = Mnist.builder()
            .setSampling(64, true)
            .optUsage(usage)
            .build();
    // prepare() declares checked exceptions, hence the throws clause above;
    // the ProgressBar reports preparation progress.
    mnist.prepare(new ProgressBar());
    return mnist;
}
使用
注意:使用前需要在模型目录(build/mlp)下创建 synset.txt 文件,每行写一个类别名称(本例为 0~9,共 10 行),用来告知模型可能的分类结果
// Test image. To classify a local file instead, use
// ImageFactory.getInstance().fromFile(Paths.get("build/img/demo.png")).
Image img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");

// Load the model from the directory the training step saved it to.
Path modelDir = Paths.get("build/mlp");
// Model and Predictor are AutoCloseable; try-with-resources closes them
// even when loading or prediction throws.
try (Model model = Model.newInstance("mlp")) {
    // The block is built with the same Mlp arguments as in the training step.
    model.setBlock(new Mlp(Mnist.IMAGE_WIDTH * Mnist.IMAGE_HEIGHT,
            Mnist.NUM_CLASSES, new int[]{128, 64}));
    model.load(modelDir);

    // Translator: resize the image to 28x28 and convert it to a tensor.
    // The class labels come from the synset.txt file mentioned in the note above.
    ImageClassificationTranslator translator = ImageClassificationTranslator
            .builder()
            .addTransform(new Resize(28, 28))
            .addTransform(new ToTensor())
            .build();

    // Predictor: run the image through the model and print the classifications.
    try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
        Classifications predict = predictor.predict(img);
        System.out.println(predict);
    }
}