Files
History_Api/pkg/ai/rag.go
T
2026-06-05 14:18:55 +07:00

157 lines
3.6 KiB
Go

package ai
import (
"context"
"fmt"
"history-api/pkg/config"
"html"
"regexp"
"strings"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/textsplitter"
)
// RagUtils bundles the chat model and the embedder used by this package's
// RAG helpers. Its fields are unexported; construct it with NewRagUtils.
type RagUtils struct {
	llm      llms.Model               // chat model queried by GenerateResponse
	embedder *embeddings.EmbedderImpl // embedding client used by PrepareChunks and EmbedQuery
}
// htmlTagRegex matches a single HTML/XML tag: a '<', any run of characters
// other than '>', then the closing '>'. It is compiled once at package
// initialisation and reused by StripHTML.
var htmlTagRegex = regexp.MustCompile("<[^>]*>")
func NewRagUtils() (*RagUtils, error) {
openRouterAPIKey, err := config.GetConfig("OPEN_ROUTER_API")
if err != nil {
return nil, err
}
model, err := config.GetConfig("OPEN_ROUTER_MODEL")
if err != nil {
model = "qwen/qwen3.5-flash-02-23"
}
embeddingModel, err := config.GetConfig("OPEN_ROUTER_EMBEDDING_MODEL")
if err != nil {
embeddingModel = "qwen/qwen3-embedding-8b"
}
llm, err := openai.New(
openai.WithToken(openRouterAPIKey),
openai.WithBaseURL("https://openrouter.ai/api/v1"),
openai.WithModel(model),
openai.WithEmbeddingModel(embeddingModel),
)
if err != nil {
return nil, fmt.Errorf("failed to init openrouter ai: %w", err)
}
embedder, err := embeddings.NewEmbedder(llm)
if err != nil {
return nil, fmt.Errorf("failed to init embedder: %w", err)
}
return &RagUtils{
llm: llm,
embedder: embedder,
}, nil
}
// StripHTML removes HTML markup from text. Every tag is replaced by a single
// space so that words separated only by markup do not run together. HTML
// entities (e.g. "&amp;") are decoded afterwards.
func (u *RagUtils) StripHTML(text string) string {
	withoutTags := htmlTagRegex.ReplaceAllString(text, " ")
	decoded := html.UnescapeString(withoutTags)
	return decoded
}
func (u *RagUtils) PrepareChunks(ctx context.Context, text string) ([]string, [][]float32, error) {
splitter := textsplitter.NewRecursiveCharacter(
textsplitter.WithChunkSize(1000),
textsplitter.WithChunkOverlap(200),
)
chunks, err := splitter.SplitText(text)
if err != nil || len(chunks) == 0 {
return nil, nil, err
}
vectors, err := u.embedder.EmbedDocuments(ctx, chunks)
if err != nil {
return nil, nil, err
}
// Truncate to 1536 dimensions for pgvector compatibility (HNSW index limit is 2000)
for i := range vectors {
if len(vectors[i]) > 1536 {
vectors[i] = vectors[i][:1536]
}
}
return chunks, vectors, nil
}
func (u *RagUtils) EmbedQuery(ctx context.Context, query string) ([]float32, error) {
vectors, err := u.embedder.EmbedDocuments(ctx, []string{query})
if err != nil || len(vectors) == 0 {
return nil, err
}
vector := vectors[0]
if len(vector) > 1536 {
vector = vector[:1536]
}
return vector, nil
}
// GenerateResponse sends prompt to the configured chat model as a single
// prompt. It returns the completion after stripThinking has removed any
// reasoning around the answer. Errors from the model call are returned as-is.
func (u *RagUtils) GenerateResponse(ctx context.Context, prompt string) (string, error) {
	completion, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt)
	if err != nil {
		return "", err
	}
	answer := stripThinking(completion)
	return answer, nil
}
func stripThinking(raw string) string {
startTag := "<answer>"
endTag := "</answer>"
lastStart := strings.LastIndex(raw, startTag)
if lastStart != -1 {
content := raw[lastStart+len(startTag):]
if endIdx := strings.Index(content, endTag); endIdx != -1 {
return strings.TrimSpace(content[:endIdx])
}
return strings.TrimSpace(content)
}
if !strings.Contains(raw, "* ") {
return strings.TrimSpace(raw)
}
lines := strings.Split(raw, "\n")
answerStart := len(lines)
for i := len(lines) - 1; i >= 0; i-- {
trimmed := strings.TrimSpace(lines[i])
if trimmed == "" || strings.HasPrefix(trimmed, "*") || strings.HasPrefix(trimmed, "- ") {
break
}
answerStart = i
}
if answerStart < len(lines) {
answer := strings.TrimSpace(strings.Join(lines[answerStart:], "\n"))
if answer != "" {
return answer
}
}
lastLine := lines[len(lines)-1]
if idx := strings.LastIndex(lastLine, `"`); idx >= 0 && idx < len(lastLine)-1 {
answer := strings.TrimSpace(lastLine[idx+1:])
if answer != "" {
return answer
}
}
return strings.TrimSpace(raw)
}