// OcrUtil.java 5.66 KB
package com.xly.ocr.util;

import cn.hutool.core.util.StrUtil;
import com.benjaminwan.ocrlibrary.OcrResult;
import io.github.mymonstercat.Model;
import io.github.mymonstercat.ocr.InferenceEngine;
import io.github.mymonstercat.ocr.config.ParamConfig;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.util.UUID;

@Slf4j
public class OcrUtil {

    private static volatile InferenceEngine engine;
    private static final Object LOCK = new Object();
    private static final String TEMP_PATH = "D:/ocr_temp";

    static {
        try {
            File tempDir = new File(TEMP_PATH);
            if (!tempDir.exists()) tempDir.mkdirs();
            System.setProperty("java.io.tmpdir", TEMP_PATH);
            System.setProperty("ORT_TMP_DIR", TEMP_PATH);
            log.info("环境初始化完成");
        } catch (Exception e) {
            log.error("初始化失败", e);
        }
    }

    // ✅ 原版 V4,绝不崩溃
    private static InferenceEngine getEngine() {
        if (engine == null) {
            synchronized (LOCK) {
                if (engine == null) {
                    log.info("初始化 OCR 引擎 (PP-OCRv4)...");
                    engine = InferenceEngine.getInstance(Model.ONNX_PPOCR_V4);
                    log.info("OCR 引擎初始化成功");
                }
            }
        }
        return engine;
    }

    public static String ocrFile(MultipartFile imageFile, String tempDir) {
        File tempImageFile = null;
        String processedPath = null;
        try {
            if (imageFile.isEmpty()) return StrUtil.EMPTY;
            tempImageFile = multipartFileToFile(imageFile, TEMP_PATH);
            BufferedImage img = preprocessImage(tempImageFile);
            processedPath = saveProcessedImage(img, TEMP_PATH);
            return performOcr(processedPath);
        } catch (Exception e) {
            log.error("识别失败", e);
            return StrUtil.EMPTY;
        } finally {
            if (tempImageFile != null) tempImageFile.delete();
            if (processedPath != null) new File(processedPath).delete();
        }
    }

    private static File multipartFileToFile(MultipartFile file, String dir) throws IOException {
        File temp = new File(dir, UUID.randomUUID() + "_" + file.getOriginalFilename());
        try (InputStream in = file.getInputStream()) {
            Files.copy(in, temp.toPath());
        }
        return temp;
    }

    // ✅ 安全预处理
    private static BufferedImage preprocessImage(File file) throws IOException {
        BufferedImage original = ImageIO.read(file);
        BufferedImage rgb = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_RGB);
        Graphics2D g = rgb.createGraphics();
        g.drawImage(original, 0, 0, null);
        g.dispose();
        if (rgb.getWidth() > 1600) {
            return resizeImage(rgb, 1280, 1280);
        }
        return rgb;
    }

    private static BufferedImage resizeImage(BufferedImage img, int w, int h) {
        int width = img.getWidth();
        int height = img.getHeight();
        double ratio = Math.min((double) w / width, (double) h / height);
        if (ratio >= 1) return img;
        int newW = (int) (width * ratio);
        int newH = (int) (height * ratio);
        BufferedImage resized = new BufferedImage(newW, newH, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resized.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC);
        g.drawImage(img, 0, 0, newW, newH, null);
        g.dispose();
        return resized;
    }

    private static String saveProcessedImage(BufferedImage img, String dir) throws IOException {
        String name = UUID.randomUUID() + "_proc.png";
        File out = new File(dir, name);
        ImageIO.write(img, "png", out);
        return out.getAbsolutePath();
    }

    // ================================
    // ✅【关键修复】参数永不崩溃
    // ================================
    private static String performOcr(String path) {
        try {
            InferenceEngine engine = getEngine();
            ParamConfig c = new ParamConfig();

            // ❌ 错误的致命参数 👇 已经删除
            // config.setPadding(50);

            // ✅ 安全稳定参数
            c.setPadding(5);           // 很小,不崩溃
            c.setMaxSideLen(1536);
            c.setBoxScoreThresh(0.4f);
            c.setBoxThresh(0.3f);
            c.setUnClipRatio(1.3f);
            c.setDoAngle(true);
            c.setMostAngle(true);

            OcrResult res = engine.runOcr(path.replace("\\", "/"), c);
            return res.getStrRes().trim();
        } catch (Exception e) {
            return StrUtil.EMPTY;
        }
    }

    // 测试
    public static void main(String[] args) {
        try {
            getEngine();
            System.out.println("✅ 引擎启动成功");

            File img = new File("E:/aa/b.jpg");
            if (img.exists()) {
                BufferedImage proc = preprocessImage(img);
                String p = saveProcessedImage(proc, TEMP_PATH);
                String result = performOcr(p);

                System.out.println("=====================================");
                System.out.println(result);
                System.out.println("=====================================");
                new File(p).delete();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}