// OcrUtil.java 12.4 KB
package com.xly.ocr.util;

import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.StrUtil;
import com.benjaminwan.ocrlibrary.OcrResult;
import com.benjaminwan.ocrlibrary.TextBlock;
import io.github.mymonstercat.Model;
import io.github.mymonstercat.ocr.InferenceEngine;
import io.github.mymonstercat.ocr.config.ParamConfig;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.UUID;

@Slf4j
public class OcrUtil {

    /** Images whose width or height exceeds this many pixels are downscaled before OCR. */
    private static final int RESIZE_TRIGGER_SIDE = 2000;

    /** Maximum width/height (in pixels) that an oversized image is downscaled to. */
    private static final int RESIZE_TARGET_SIDE = 1600;

    /** Per-channel multiplier applied by {@link #enhanceContrast(BufferedImage)}; results are clamped to 255. */
    private static final double CHANNEL_GAIN = 1.15;

    /** Extension used for temp files when the upload's original name has no usable extension. */
    private static final String DEFAULT_EXTENSION = ".jpg";

    /** Longest extension (including the leading dot) accepted from a client-supplied filename. */
    private static final int MAX_EXTENSION_LENGTH = 10;

    // Shared engine instance, created lazily on first use (see getEngine()).
    private static volatile InferenceEngine engine;
    private static final Object LOCK = new Object();

    /**
     * Returns the shared OCR engine, creating it on first call.
     *
     * <p>Uses double-checked locking on the {@code volatile} field so initialization runs at most once.
     *
     * @return the initialized engine
     * @throws RuntimeException if the engine cannot be initialized (the original cause is attached)
     */
    private static InferenceEngine getEngine() {
        // Read the volatile field once into a local to avoid a second volatile read on the fast path.
        InferenceEngine instance = engine;
        if (instance == null) {
            synchronized (LOCK) {
                instance = engine;
                if (instance == null) {
                    try {
                        log.info("初始化 OCR 引擎 (PP-OCRv4)...");
                        instance = InferenceEngine.getInstance(Model.ONNX_PPOCR_V4);
                        engine = instance;
                        log.info("OCR 引擎初始化成功");
                    } catch (Exception e) {
                        log.error("OCR 引擎初始化失败: {}", e.getMessage(), e);
                        throw new RuntimeException("OCR 引擎初始化失败", e);
                    }
                }
            }
        }
        return instance;
    }

    /**
     * Recognizes the text in an uploaded image.
     *
     * <p>The upload is copied into {@code tempDir}, preprocessed (downscaled if large, channel gain applied),
     * written back as PNG and passed to the OCR engine. Both temporary files are removed afterwards.
     *
     * @param imageFile the uploaded image; {@code null} or empty yields an empty result
     * @param tempDir   working directory for temporary files; blank means {@code java.io.tmpdir}
     * @return the recognized text, or an empty string if the input is empty or any step fails (never {@code null})
     */
    public static String ocrFile(MultipartFile imageFile, String tempDir) {
        File tempImageFile = null;
        String processedImagePath = null;

        try {
            // 1. Validate before dereferencing, so a null argument produces the "empty" warning rather than an NPE.
            if (imageFile == null || imageFile.isEmpty()) {
                log.warn("图片文件为空");
                return StrUtil.EMPTY;
            }
            log.info("开始 OCR 识别,文件: {}", imageFile.getOriginalFilename());

            // 2. Resolve (blank -> java.io.tmpdir) and create the working directory.
            //    Every later path must use the resolved value, not the raw argument.
            String workDir = ensureTempDirExists(tempDir);

            // 3. Copy the upload to a local file.
            tempImageFile = multipartFileToFile(imageFile, workDir);
            if (tempImageFile == null || !tempImageFile.exists()) {
                log.error("转换临时文件失败");
                return StrUtil.EMPTY;
            }

            // 4. Preprocess; throws IOException if the image cannot be decoded (handled below).
            BufferedImage processedImage = preprocessImage(tempImageFile);

            // 5. Persist the preprocessed image, because the engine takes a file path.
            processedImagePath = saveProcessedImage(processedImage, workDir);
            if (processedImagePath == null) {
                log.error("保存预处理图片失败");
                return StrUtil.EMPTY;
            }

            // 6. Run OCR.
            String text = performOcr(processedImagePath);

            // 7. Log the outcome.
            if (StrUtil.isNotBlank(text)) {
                log.info("OCR 识别成功,文字长度: {} 字符", text.length());
                log.debug("识别结果: {}", text);
            } else {
                log.warn("OCR 识别结果为空");
            }

            return text;

        } catch (Exception e) {
            log.error("OCR 识别失败: {}", e.getMessage(), e);
            return StrUtil.EMPTY;
        } finally {
            cleanupTempFiles(tempImageFile, processedImagePath);
        }
    }

    /**
     * Resolves the working directory and makes sure it exists.
     *
     * @param tempDir requested directory; blank means {@code java.io.tmpdir}
     * @return the resolved directory path that callers must use for subsequent file operations
     */
    private static String ensureTempDirExists(String tempDir) {
        String dirPath = StrUtil.isBlank(tempDir) ? System.getProperty("java.io.tmpdir") : tempDir;

        File dir = new File(dirPath);
        if (!dir.exists()) {
            if (dir.mkdirs()) {
                log.debug("创建临时目录: {}", dirPath);
            } else if (!dir.isDirectory()) {
                // mkdirs() also returns false when another thread created the directory first,
                // so only warn if it really is missing.
                log.warn("无法创建临时目录: {}", dirPath);
            }
        }
        return dirPath;
    }

    /**
     * Copies an uploaded file to a uniquely named file in {@code tempDir}.
     *
     * @param multipartFile the upload
     * @param tempDir       target directory; blank means {@code java.io.tmpdir}. It is created if missing.
     * @return the created file, or {@code null} if {@code multipartFile} is {@code null} or empty
     * @throws IOException if the upload cannot be written
     */
    public static File multipartFileToFile(MultipartFile multipartFile, String tempDir) throws IOException {
        if (multipartFile == null || multipartFile.isEmpty()) {
            return null;
        }

        String dirPath = ensureTempDirExists(tempDir);

        // The extension is cosmetic (ImageIO sniffs content) but is sanitized because it comes from the client.
        String extension = getFileExtension(multipartFile.getOriginalFilename());
        String uniqueFilename = UUID.randomUUID().toString() + extension;

        // Use an absolute destination. Servlet-based MultipartFile implementations resolve relative paths
        // against the container's multipart location, not the working directory used by File.exists().
        File file = new File(dirPath, uniqueFilename).getAbsoluteFile();
        multipartFile.transferTo(file);

        log.debug("创建临时文件: {}", file.getPath());
        return file;
    }

    /**
     * Runs the OCR engine on an image file.
     *
     * @param imagePath path of the image to recognize
     * @return the trimmed recognized text, or an empty string if the engine fails or returns nothing
     */
    private static String performOcr(String imagePath) {
        try {
            InferenceEngine ocrEngine = getEngine();
            ParamConfig config = createOptimizedParamConfig();

            long startNanos = System.nanoTime();
            OcrResult ocrResult = ocrEngine.runOcr(imagePath, config);
            long elapsedMs = (System.nanoTime() - startNanos) / 1_000_000L;

            log.info("OCR 识别耗时: {} ms", elapsedMs);

            if (ocrResult == null) {
                return StrUtil.EMPTY;
            }

            // Per-block details are only assembled when DEBUG logging is on.
            if (log.isDebugEnabled() && ocrResult.getTextBlocks() != null) {
                List<TextBlock> textBlocks = ocrResult.getTextBlocks();
                log.debug("识别到 {} 个文本块", textBlocks.size());
                for (int i = 0; i < textBlocks.size(); i++) {
                    TextBlock block = textBlocks.get(i);
                    log.debug("  块{}: {} (置信度: {})",
                            i + 1, block.getText(), block.getBoxScore());
                }
            }

            String text = ocrResult.getStrRes();
            return text == null ? StrUtil.EMPTY : text.trim();

        } catch (Exception e) {
            log.error("执行 OCR 识别失败: {}", e.getMessage(), e);
            return StrUtil.EMPTY;
        }
    }

    /**
     * Writes a preprocessed image to {@code tempDir} as a uniquely named PNG.
     *
     * @param image   image to save; {@code null} yields {@code null}
     * @param tempDir target directory
     * @return path of the written file, or {@code null} if {@code image} is {@code null} or no PNG writer handled it
     * @throws IOException if writing fails
     */
    private static String saveProcessedImage(BufferedImage image, String tempDir) throws IOException {
        if (image == null) {
            return null;
        }

        String filename = "processed_" + System.currentTimeMillis() + "_" + UUID.randomUUID().toString() + ".png";
        File outputFile = new File(tempDir, filename);

        // ImageIO.write reports "no suitable writer" by returning false rather than throwing.
        if (!ImageIO.write(image, "png", outputFile)) {
            return null;
        }

        String filePath = outputFile.getPath();
        log.debug("保存预处理图片: {}", filePath);
        return filePath;
    }

    /**
     * Best-effort removal of the temporary files created by {@link #ocrFile}.
     *
     * @param tempImageFile      the copied upload; may be {@code null}
     * @param processedImagePath path of the preprocessed PNG; may be {@code null} or blank
     */
    private static void cleanupTempFiles(File tempImageFile, String processedImagePath) {
        deleteTempFile(tempImageFile, "临时文件");

        if (StrUtil.isNotBlank(processedImagePath)) {
            deleteTempFile(new File(processedImagePath), "预处理图片");
        }
    }

    /**
     * Deletes a file if it exists. If deletion fails, the file is scheduled for deletion at JVM exit.
     *
     * @param file  file to delete; {@code null} is ignored
     * @param label human-readable kind of file, used in log messages
     */
    private static void deleteTempFile(File file, String label) {
        if (file == null || !file.exists()) {
            return;
        }
        if (file.delete()) {
            log.debug("删除{}: {}", label, file.getPath());
        } else {
            log.warn("删除{}失败: {}", label, file.getPath());
            file.deleteOnExit();
        }
    }

    /**
     * Builds the OCR parameter set used for every recognition call.
     *
     * @return a new, fully populated configuration
     */
    private static ParamConfig createOptimizedParamConfig() {
        ParamConfig config = new ParamConfig();

        // Padding added around the image before detection.
        config.setPadding(50);

        // Maximum side length; 0 means no limit.
        config.setMaxSideLen(0);

        // Detection thresholds for text boxes.
        config.setBoxScoreThresh(0.4f);
        config.setBoxThresh(0.25f);

        // How far detected text regions are expanded.
        config.setUnClipRatio(1.8f);

        // Text direction (angle) detection.
        config.setDoAngle(true);
        config.setMostAngle(true);

        log.debug("OCR 参数配置: padding={}, unClipRatio={}",
                config.getPadding(), config.getUnClipRatio());

        return config;
    }

    /**
     * Decodes an image and prepares it for OCR.
     *
     * <p>Images larger than {@value #RESIZE_TRIGGER_SIDE}px on either side are scaled to fit
     * {@value #RESIZE_TARGET_SIDE}px. The result then goes through {@link #enhanceContrast(BufferedImage)}.
     *
     * @param imageFile image file to read
     * @return the processed image (never {@code null})
     * @throws IOException if the file cannot be read or decoded
     */
    private static BufferedImage preprocessImage(File imageFile) throws IOException {
        BufferedImage original = ImageIO.read(imageFile);
        if (original == null) {
            throw new IOException("无法读取图片: " + imageFile.getPath());
        }

        log.debug("原始图片尺寸: {}x{}", original.getWidth(), original.getHeight());

        BufferedImage processed = original;

        // 1. Downscale very large images.
        if (processed.getWidth() > RESIZE_TRIGGER_SIDE || processed.getHeight() > RESIZE_TRIGGER_SIDE) {
            processed = resizeImage(processed, RESIZE_TARGET_SIDE, RESIZE_TARGET_SIDE);
            log.debug("缩小图片尺寸: {}x{}", processed.getWidth(), processed.getHeight());
        }

        // 2. Apply the per-channel gain.
        return enhanceContrast(processed);
    }

    /**
     * Scales an image down, preserving aspect ratio, so that it fits within {@code maxWidth} x {@code maxHeight}.
     *
     * @param image     source image
     * @param maxWidth  maximum width in pixels
     * @param maxHeight maximum height in pixels
     * @return the source image unchanged if it already fits, otherwise a new {@code TYPE_INT_RGB} image
     */
    private static BufferedImage resizeImage(BufferedImage image, int maxWidth, int maxHeight) {
        int w = image.getWidth();
        int h = image.getHeight();

        double ratio = Math.min((double) maxWidth / w, (double) maxHeight / h);
        if (ratio >= 1.0) {
            return image;
        }

        // Clamp to 1px: extreme aspect ratios would otherwise truncate a side to 0, which BufferedImage rejects.
        int newW = Math.max(1, (int) (w * ratio));
        int newH = Math.max(1, (int) (h * ratio));

        BufferedImage resized = new BufferedImage(newW, newH, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resized.createGraphics();
        try {
            g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC);
            g.setRenderingHint(RenderingHints.KEY_RENDERING, RenderingHints.VALUE_RENDER_QUALITY);
            g.drawImage(image, 0, 0, newW, newH, null);
        } finally {
            g.dispose();
        }

        return resized;
    }

    /**
     * Brightens an image by multiplying each RGB channel by {@link #CHANNEL_GAIN}, clamped to 255.
     *
     * <p>Despite the name, this is a linear gain, not a histogram stretch. Alpha is discarded and the output is
     * opaque.
     *
     * @param image source image of any type
     * @return a new {@code TYPE_INT_RGB} image
     */
    private static BufferedImage enhanceContrast(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();

        // Always TYPE_INT_RGB. image.getType() may be TYPE_CUSTOM (0), which the BufferedImage constructor
        // rejects, or an indexed type whose palette would quantize the adjusted colors.
        BufferedImage result = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);

        // Process one row at a time with the bulk ARGB accessors instead of per-pixel Color objects.
        int[] row = new int[width];
        for (int y = 0; y < height; y++) {
            image.getRGB(0, y, width, 1, row, 0, width);
            for (int x = 0; x < width; x++) {
                int argb = row[x];
                int r = scaleChannel((argb >> 16) & 0xFF);
                int g = scaleChannel((argb >> 8) & 0xFF);
                int b = scaleChannel(argb & 0xFF);
                row[x] = 0xFF000000 | (r << 16) | (g << 8) | b;
            }
            result.setRGB(0, y, width, 1, row, 0, width);
        }

        return result;
    }

    /** Applies {@link #CHANNEL_GAIN} to one 0-255 channel value, truncating and clamping to 255. */
    private static int scaleChannel(int value) {
        return Math.min(255, (int) (value * CHANNEL_GAIN));
    }

    /**
     * Extracts a safe file extension (including the dot) from a client-supplied filename.
     *
     * @param filename original filename, possibly {@code null} or containing a client-side path
     * @return the extension if it is 1 to {@code MAX_EXTENSION_LENGTH - 1} ASCII letters or digits,
     *         otherwise {@code ".jpg"}
     */
    private static String getFileExtension(String filename) {
        if (StrUtil.isBlank(filename)) {
            return DEFAULT_EXTENSION;
        }

        // Some clients send a full path; only the last segment is the actual file name.
        int lastSeparator = Math.max(filename.lastIndexOf('/'), filename.lastIndexOf('\\'));
        String name = filename.substring(lastSeparator + 1);

        int lastDotIndex = name.lastIndexOf('.');
        if (lastDotIndex == -1) {
            return DEFAULT_EXTENSION;
        }

        String extension = name.substring(lastDotIndex);
        return isSafeExtension(extension) ? extension : DEFAULT_EXTENSION;
    }

    /** True if {@code extension} is a dot followed by 1..(MAX_EXTENSION_LENGTH-1) ASCII letters or digits. */
    private static boolean isSafeExtension(String extension) {
        if (extension.length() < 2 || extension.length() > MAX_EXTENSION_LENGTH) {
            return false;
        }
        for (int i = 1; i < extension.length(); i++) {
            char c = extension.charAt(i);
            boolean asciiAlnum = (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9');
            if (!asciiAlnum) {
                return false;
            }
        }
        return true;
    }

    /**
     * Manual smoke test that runs the preprocessing and OCR pipeline on a local image.
     *
     * @param args optional: {@code args[0]} image path, {@code args[1]} temp directory
     *             (defaults: {@code E:/aa/b.jpg}, {@code D:/temp/ocrJava})
     */
    public static void main(String[] args) {
        String imagePath = (args != null && args.length > 0) ? args[0] : "E:/aa/b.jpg";
        String tempDir = (args != null && args.length > 1) ? args[1] : "D:/temp/ocrJava";

        File imageFile = new File(imagePath);
        if (!imageFile.exists()) {
            System.err.println("图片文件不存在: " + imagePath);
            return;
        }

        String processedPath = null;
        try {
            // Real callers go through ocrFile(MultipartFile, String); this drives the same steps from a local file.
            String workDir = ensureTempDirExists(tempDir);
            BufferedImage processedImage = preprocessImage(imageFile);
            processedPath = saveProcessedImage(processedImage, workDir);
            if (processedPath == null) {
                System.err.println("保存预处理图片失败");
                return;
            }

            String result = performOcr(processedPath);
            System.out.println("识别结果: " + result);

        } catch (Exception e) {
            log.error("OCR 测试失败: {}", e.getMessage(), e);
        } finally {
            // Remove the processed image even when recognition throws.
            cleanupTempFiles(null, processedPath);
        }
    }
}