VectorizationServiceImpl.java
2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package com.xly.milvus.service.impl;
import com.xly.milvus.service.VectorizationService;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Vectorization service backed by LangChain4j's in-process All-MiniLM-L6-v2 embedding model.
 *
 * <p>Every non-empty vector returned by this service is L2-normalized (unit length). Blank or
 * {@code null} text yields an empty vector instead of an error.
 */
@Slf4j
@Service
public class VectorizationServiceImpl implements VectorizationService {

    private final EmbeddingModel embeddingModel;

    /**
     * Creates the service with the default in-process All-MiniLM-L6-v2 embedding model.
     *
     * <p>This is the constructor Spring uses. When a bean class declares several constructors and
     * none is annotated with {@code @Autowired}, Spring falls back to the no-arg constructor.
     */
    public VectorizationServiceImpl() {
        // Initialize the embedding model.
        this(new AllMiniLmL6V2EmbeddingModel());
        log.info("向量化模型初始化成功");
    }

    /**
     * Creates the service around an explicitly supplied embedding model.
     *
     * <p>Useful for tests (a stub model) or for plugging in a differently configured model without
     * editing this class.
     *
     * @param embeddingModel the model used to embed text; must not be {@code null}
     * @throws NullPointerException if {@code embeddingModel} is {@code null}
     */
    public VectorizationServiceImpl(EmbeddingModel embeddingModel) {
        this.embeddingModel = Objects.requireNonNull(embeddingModel, "embeddingModel");
    }

    /**
     * Embeds a single text and returns its L2-normalized vector.
     *
     * @param text the text to embed; may be {@code null}
     * @return the normalized embedding as a mutable list, or an empty list if {@code text} is
     *     {@code null}, empty, or whitespace only
     * @throws RuntimeException if embedding fails; the underlying exception is kept as the cause
     */
    @Override
    public List<Float> textToVector(String text) {
        if (text == null || text.trim().isEmpty()) {
            return new ArrayList<>();
        }
        try {
            // Generate the raw embedding with LangChain4j, then scale it to unit length.
            float[] normalized = normalizeVector(embeddingModel.embed(text).content().vector());
            List<Float> vector = new ArrayList<>(normalized.length);
            for (float v : normalized) {
                vector.add(v);
            }
            return vector;
        } catch (Exception e) {
            log.error("文本向量化失败: {}", e.getMessage(), e);
            throw new RuntimeException("文本向量化失败", e);
        }
    }

    /**
     * Scales a vector to unit L2 length.
     *
     * <p>Squares are accumulated in {@code double}. Computing {@code v * v} in {@code float}
     * would round every term before it is widened into the accumulator.
     *
     * @param vector the vector to normalize; it is not modified
     * @return a new unit-length array, or {@code vector} itself when its norm is not positive
     *     (all-zero input, or NaN components)
     */
    private static float[] normalizeVector(float[] vector) {
        double sumOfSquares = 0.0;
        for (float v : vector) {
            sumOfSquares += (double) v * v;
        }
        double norm = Math.sqrt(sumOfSquares);
        if (!(norm > 0)) {
            // Nothing meaningful to divide by; hand back the input unchanged.
            return vector;
        }
        float[] normalized = new float[vector.length];
        for (int i = 0; i < vector.length; i++) {
            normalized[i] = (float) (vector[i] / norm);
        }
        return normalized;
    }

    /**
     * Embeds each text independently via {@link #textToVector(String)}.
     *
     * <p>NOTE(review): {@code EmbeddingModel#embedAll} could embed the whole batch in one call;
     * blank/{@code null} entries would have to be filtered out first and mapped back to empty
     * vectors to keep this method's contract.
     *
     * @param texts the texts to embed; may be {@code null}
     * @return one vector per input text, in input order (blank or {@code null} entries map to empty
     *     lists); an empty list if {@code texts} is {@code null} or empty
     * @throws RuntimeException if embedding any text fails
     */
    @Override
    public List<List<Float>> batchTextToVector(List<String> texts) {
        if (texts == null || texts.isEmpty()) {
            return new ArrayList<>();
        }
        return texts.stream()
                .map(this::textToVector)
                .collect(Collectors.toList());
    }
}