DJL-实现手写数字识别

DJL-实现手写数字识别

起男 44 2025-04-11

DJL-实现手写数字识别

导入依赖

    <!-- Import the DJL BOM so the DJL dependencies declared below without a
         <version> element take their version from the BOM (0.26.0). -->
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>0.26.0</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <!-- Core DJL API -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <!-- Built-in datasets (the Mnist dataset used by the training code) -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <!-- Built-in model blocks (the Mlp block used by the training code) -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- OpenCV extension for image handling -->
        <dependency>
            <groupId>ai.djl.opencv</groupId>
            <artifactId>opencv</artifactId>
        </dependency>
        <!-- MXNet engine binding for DJL -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
        </dependency>
        <!-- MXNet native library; its version is pinned explicitly here
             instead of being taken from the BOM. -->
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-native-auto</artifactId>
            <version>1.8.0</version>
        </dependency>
    </dependencies>

训练模型

        // Load the datasets; implement a custom Dataset to train on your own data.
        RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);
        // Validate on the held-out TEST split. Reusing the TRAIN split here would
        // make the reported validation accuracy meaningless, because the model
        // would be evaluated on the same samples it was trained on.
        RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);
        // Build the network: a multilayer perceptron taking a flattened image
        // as input, with two hidden layers (128 and 64 units) and one output
        // per digit class.
        Mlp mlp = new Mlp(Mnist.IMAGE_WIDTH * Mnist.IMAGE_HEIGHT,
                Mnist.NUM_CLASSES, new int[]{128, 64});
        // Create the model and attach the block; try-with-resources releases
        // the model when training is finished.
        try (Model model = Model.newInstance("mlp")) {
            model.setBlock(mlp);
            // Training configuration: softmax cross-entropy loss, accuracy
            // evaluator, and the default logging listeners writing to outputDir.
            String outputDir = "build/mlp";
            DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .addEvaluator(new Accuracy())
                    .addTrainingListeners(TrainingListener.Defaults.logging(outputDir));
            // Create the trainer; it is closed automatically as well.
            try (Trainer trainer = model.newTrainer(config)) {
                // Collect metrics during training.
                trainer.setMetrics(new Metrics());
                // Initialize parameters with the input shape: (batch of 1, flattened image).
                trainer.initialize(new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH));
                // Train for 5 epochs, validating after each one.
                EasyTrain.fit(trainer, 5, trainingSet, validateSet);
                // Print the training result.
                TrainingResult result = trainer.getTrainingResult();
                System.out.println(result);
                // Save the trained model under outputDir with the name "mlp".
                model.save(Paths.get(outputDir), "mlp");
            }
        }

数据集

    /**
     * Builds and prepares the built-in MNIST dataset for the requested split.
     *
     * @param usage which split of the dataset to load (e.g. TRAIN or TEST)
     * @return the prepared MNIST dataset, sampled in shuffled batches of 64
     * @throws java.io.IOException if preparing the dataset fails with an I/O error
     * @throws ai.djl.translate.TranslateException if preparing the dataset fails
     *     during data translation
     */
    private RandomAccessDataset getDataset(Dataset.Usage usage)
            throws java.io.IOException, ai.djl.translate.TranslateException {
        // Built-in dataset; batch size 64, random sampling enabled.
        Mnist mnist = Mnist.builder()
                .setSampling(64, true)
                .optUsage(usage)
                .build();
        // prepare() declares checked exceptions, so they are propagated to the
        // caller rather than swallowed; the ProgressBar reports progress.
        mnist.prepare(new ProgressBar());
        return mnist;
    }

使用

注意：使用前需要在模型目录（本例为 build/mlp）下创建 synset.txt 文件，每行写一个类别标签（此处为 0～9，共 10 行），用来告知模型可能的分类结果

        // Test input: a sample digit image fetched from the DJL resources site.
        Image img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
        // Alternative: load a local image file instead.
        // Image img2 = ImageFactory.getInstance().fromFile(Paths.get("build/img/demo.png"));
        // Directory the trained model was saved to.
        Path modelDir = Paths.get("build/mlp");
        // Model and Predictor are AutoCloseable; try-with-resources guarantees
        // they are released even if loading or prediction throws.
        try (Model model = Model.newInstance("mlp")) {
            // The block must have the same architecture that was used for
            // training so the saved parameters can be loaded into it.
            model.setBlock(new Mlp(Mnist.IMAGE_WIDTH * Mnist.IMAGE_HEIGHT,
                    Mnist.NUM_CLASSES, new int[]{128, 64}));
            model.load(modelDir);
            // Translator: resize the image to 28x28 and convert it to a tensor
            // before it is passed to the model.
            ImageClassificationTranslator translator = ImageClassificationTranslator
                    .builder()
                    .addTransform(new Resize(28, 28))
                    .addTransform(new ToTensor())
                    .build();
            // Predictor bound to this model and translator.
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                // Classify the image and print the result.
                Classifications predict = predictor.predict(img);
                System.out.println(predict);
            }
        }