// EmbeddingServiceImpl.java (2.74 KB)
package com.xly.milvus.service.impl;

import com.xly.milvus.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

@Slf4j
@Service("embeddingService")
@RequiredArgsConstructor
public class EmbeddingServiceImpl implements EmbeddingService {

    private final EmbeddingModel embeddingModel;

    /**
     * 生成单个文本的向量
     */
    public List<Float> generateEmbedding(String text) {
        if (text == null || text.trim().isEmpty()) {
            log.warn("Input text is empty");
            return null;
        }

        try {
            // 0.35.0 API: embed 方法返回 Response<Embedding>
            Embedding embedding = embeddingModel.embed(text).content();

            // 将 float[] 转换为 List<Float> 供 Milvus 使用
            float[] vectorArray = embedding.vector();
            List<Float> vectorList = new ArrayList<>(vectorArray.length);
            for (float v : vectorArray) {
                vectorList.add(v);
            }
            return vectorList;
        } catch (Exception e) {
            log.error("Error generating embedding for text: {}", text, e);
            return null;
        }
    }

    /**
     * 批量生成向量(高效版)
     * 利用 LangChain4j 内置的并行化能力,显著提升性能
     */
    public List<List<Float>> generateEmbeddings(List<String> texts) {
        if (texts == null || texts.isEmpty()) {
            return Collections.emptyList();
        }

        try {
            // 将文本转换为 TextSegment 列表
            List<TextSegment> segments = texts.stream()
                    .map(TextSegment::from)
                    .collect(Collectors.toList());

            // 批量嵌入,内部自动并行化
            List<Embedding> embeddings = embeddingModel.embedAll(segments).content();

            // 转换格式
            return embeddings.stream()
                    .map(embedding -> {
                        float[] array = embedding.vector();
                        List<Float> list = new ArrayList<>(array.length);
                        for (float v : array) {
                            list.add(v);
                        }
                        return list;
                    })
                    .collect(Collectors.toList());

        } catch (Exception e) {
            log.error("Error generating embeddings in batch", e);
            return Collections.emptyList();
        }
    }
}