// VectorizationServiceImpl.java (2.54 KB)
package com.xly.milvus.service.impl;

import com.xly.milvus.service.VectorizationService;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * {@link VectorizationService} implementation that produces text embeddings with
 * LangChain4j's {@link AllMiniLmL6V2EmbeddingModel}.
 *
 * <p>Embeddings are L2-normalized (scaled to unit length) before being returned,
 * except when the raw embedding has a zero norm, in which case it is returned as-is.
 * Blank or {@code null} input text yields an empty list rather than a vector.
 */
@Slf4j
@Service
public class VectorizationServiceImpl implements VectorizationService {

    private final EmbeddingModel embeddingModel;

    /**
     * Creates the service and instantiates the embedding model.
     */
    public VectorizationServiceImpl() {
        this.embeddingModel = new AllMiniLmL6V2EmbeddingModel();
        log.info("向量化模型初始化成功");
    }

    /**
     * Converts a single text into an L2-normalized embedding vector.
     *
     * @param text the text to embed; may be {@code null}
     * @return the normalized embedding as a mutable list, or an empty mutable list
     *         when {@code text} is {@code null} or contains only whitespace
     * @throws RuntimeException if the embedding model or the conversion fails; the
     *         original exception is preserved as the cause
     */
    @Override
    public List<Float> textToVector(String text) {
        if (text == null || text.trim().isEmpty()) {
            return new ArrayList<>();
        }

        try {
            float[] raw = embeddingModel.embed(text).content().vector();
            float[] unit = normalizeVector(raw);

            // Presize: the final length is known, so avoid incremental regrowth.
            List<Float> vector = new ArrayList<>(unit.length);
            for (float component : unit) {
                vector.add(component);
            }
            return vector;
        } catch (Exception e) {
            log.error("文本向量化失败: {}", e.getMessage(), e);
            throw new RuntimeException("文本向量化失败", e);
        }
    }

    /**
     * Scales a vector to unit L2 length.
     *
     * @param vector the vector to normalize; not modified
     * @return a new array holding the normalized components, or the same
     *         {@code vector} instance when its norm is not greater than zero
     */
    private float[] normalizeVector(float[] vector) {
        double sumOfSquares = 0.0;
        for (float v : vector) {
            // Widen before multiplying so the square is computed in double precision;
            // a float * float product would be rounded (and could overflow) first.
            sumOfSquares += (double) v * v;
        }
        double norm = Math.sqrt(sumOfSquares);

        if (norm > 0) {
            float[] normalized = new float[vector.length];
            for (int i = 0; i < vector.length; i++) {
                normalized[i] = (float) (vector[i] / norm);
            }
            return normalized;
        }
        // Zero norm: dividing is impossible, so hand back the input unchanged.
        return vector;
    }

    /**
     * Converts each text into an embedding vector by delegating to
     * {@link #textToVector(String)}, preserving input order.
     *
     * @param texts the texts to embed; may be {@code null}
     * @return one vector per input text (an empty vector for blank or {@code null}
     *         entries), or an empty list when {@code texts} is {@code null} or empty
     * @throws RuntimeException if embedding any individual text fails
     */
    @Override
    public List<List<Float>> batchTextToVector(List<String> texts) {
        if (texts == null || texts.isEmpty()) {
            return new ArrayList<>();
        }

        return texts.stream()
                .map(this::textToVector)
                .collect(Collectors.toList());
    }
}