// Source file: History_Api/pkg/ai/rag.go (Go, 610 lines, 16 KiB).

package ai
import (
"bytes"
"context"
"fmt"
"history-api/pkg/config"
json "history-api/pkg/jsonx"
"html"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/rs/zerolog/log"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/textsplitter"
)
// RagUtils bundles the clients and tuning parameters used by the RAG helpers
// in this file: embedding, query rewriting, reranking and answer generation.
// Instances are created by NewRagUtils from configuration values.
type RagUtils struct {
// llm is the chat/completion client created by openai.New against openRouterBaseURL.
llm llms.Model
// embedder produces vectors for document chunks and queries; it is built on top of llm.
embedder *embeddings.EmbedderImpl
// httpClient sends the rerank HTTP requests; its timeout comes from RAG_RERANK_TIMEOUT_SECONDS.
httpClient *http.Client
// openRouterAPIKey is sent as the Bearer token on rerank requests.
openRouterAPIKey string
// model is the primary generation model (OPEN_ROUTER_MODEL).
model string
// fallbackModel is tried by GenerateResponse when generation with model fails.
fallbackModel string
// generationMaxRetries is the number of extra generation attempts per model (clamped to 0..5).
generationMaxRetries int
// generationRetryDelay is the base delay between generation attempts; it doubles per attempt.
generationRetryDelay time.Duration
// rerankEnabled turns RerankDocuments into a no-op when false.
rerankEnabled bool
// rerankModel is the primary reranker model identifier.
rerankModel string
// rerankFallbackModel is tried when reranking with rerankModel fails.
rerankFallbackModel string
// rerankMaxRetries is the number of extra rerank attempts per model (clamped to 0..5).
rerankMaxRetries int
// rerankRetryDelay is the base delay between rerank attempts; it doubles per attempt.
rerankRetryDelay time.Duration
// queryRewriteEnabled makes RewriteSearchQuery return the question unchanged when false.
queryRewriteEnabled bool
// queryRewriteModel is the model used for query rewriting; defaults to model.
queryRewriteModel string
// queryRewriteTimeout bounds a rewrite call; zero means no additional timeout is applied.
queryRewriteTimeout time.Duration
// queryRewriteMaxTokens caps the rewrite response length (at most 512).
queryRewriteMaxTokens int
}
// htmlTagRegex matches a '<', any run of characters other than '>', and the
// closing '>'. StripHTML uses it to blank out tag-like markup.
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
// openRouterBaseURL is the root of the OpenRouter API. It is passed to the
// OpenAI-compatible client and used as the prefix of the /rerank endpoint.
const openRouterBaseURL = "https://openrouter.ai/api/v1"
// RerankResult is one entry returned by the reranker.
type RerankResult struct {
// Index is the position of the document in the slice that was submitted for
// reranking; out-of-range indexes are dropped before results are returned.
Index int
// Score is the relevance_score reported by the rerank endpoint.
Score float64
}
// ChatTurn is one prior exchange of a conversation. It is rendered into the
// query-rewrite prompt as a "User:" / "Assistant:" pair.
type ChatTurn struct {
// Question is the user's message for this turn.
Question string
// Answer is the assistant's reply for this turn.
Answer string
}
// NewRagUtils builds a RagUtils from configuration.
//
// OPEN_ROUTER_API must be set; otherwise an error is returned. An error is also
// returned when the OpenRouter client or the embedder cannot be created.
// Numeric settings are sanitised before use: non-positive timeouts, delays and
// token budgets fall back to their defaults, retry counts are clamped to
// [0, 5], the query-rewrite token budget is capped at 512, and a negative
// query-rewrite timeout becomes 0 (no additional timeout).
func NewRagUtils() (*RagUtils, error) {
	apiKey := config.GetConfigWithDefault("OPEN_ROUTER_API", "")
	if apiKey == "" {
		return nil, fmt.Errorf("OPEN_ROUTER_API is not set")
	}

	// positiveOr keeps value when it is strictly positive, otherwise yields fallback.
	positiveOr := func(value, fallback int) int {
		if value > 0 {
			return value
		}
		return fallback
	}
	// clampRetries limits a retry budget to the inclusive range [0, 5].
	clampRetries := func(value int) int {
		switch {
		case value < 0:
			return 0
		case value > 5:
			return 5
		default:
			return value
		}
	}
	seconds := func(n int) time.Duration { return time.Duration(n) * time.Second }
	millis := func(n int) time.Duration { return time.Duration(n) * time.Millisecond }

	chatModel := config.GetConfigWithDefault("OPEN_ROUTER_MODEL", "qwen/qwen3.5-flash-02-23")
	embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
	llmTimeout := seconds(positiveOr(config.GetIntConfigWithDefault("RAG_LLM_TIMEOUT_SECONDS", 20), 20))

	client, err := openai.New(
		openai.WithToken(apiKey),
		openai.WithBaseURL(openRouterBaseURL),
		openai.WithModel(chatModel),
		openai.WithEmbeddingModel(embeddingModel),
		openai.WithHTTPClient(&http.Client{Timeout: llmTimeout}),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to init openrouter ai: %w", err)
	}

	embedder, err := embeddings.NewEmbedder(client)
	if err != nil {
		return nil, fmt.Errorf("failed to init embedder: %w", err)
	}

	rerankTimeout := seconds(positiveOr(config.GetIntConfigWithDefault("RAG_RERANK_TIMEOUT_SECONDS", 10), 10))
	rerankRetries := clampRetries(config.GetIntConfigWithDefault("RAG_RERANK_MAX_RETRIES", 2))
	rerankDelay := millis(positiveOr(config.GetIntConfigWithDefault("RAG_RERANK_RETRY_DELAY_MS", 250), 250))

	generationRetries := clampRetries(config.GetIntConfigWithDefault("RAG_GENERATION_MAX_RETRIES", 2))
	generationDelay := millis(positiveOr(config.GetIntConfigWithDefault("RAG_GENERATION_RETRY_DELAY_MS", 500), 500))

	// An empty (or whitespace-only) rewrite model means "reuse the main model".
	rewriteModel := strings.TrimSpace(config.GetConfigWithDefault("RAG_QUERY_REWRITE_MODEL", chatModel))
	if rewriteModel == "" {
		rewriteModel = chatModel
	}

	// Zero is a legal value here: it disables the dedicated rewrite timeout.
	rewriteTimeoutSeconds := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_TIMEOUT_SECONDS", 5)
	if rewriteTimeoutSeconds < 0 {
		rewriteTimeoutSeconds = 0
	}

	rewriteMaxTokens := positiveOr(config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_MAX_TOKENS", 96), 96)
	if rewriteMaxTokens > 512 {
		rewriteMaxTokens = 512
	}

	utils := &RagUtils{
		llm:                   client,
		embedder:              embedder,
		httpClient:            &http.Client{Timeout: rerankTimeout},
		openRouterAPIKey:      apiKey,
		model:                 chatModel,
		fallbackModel:         config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
		generationMaxRetries:  generationRetries,
		generationRetryDelay:  generationDelay,
		rerankEnabled:         config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
		rerankModel:           config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
		rerankFallbackModel:   config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
		rerankMaxRetries:      rerankRetries,
		rerankRetryDelay:      rerankDelay,
		queryRewriteEnabled:   config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
		queryRewriteModel:     rewriteModel,
		queryRewriteTimeout:   seconds(rewriteTimeoutSeconds),
		queryRewriteMaxTokens: rewriteMaxTokens,
	}
	return utils, nil
}
// StripHTML replaces every tag-like span matched by htmlTagRegex with a single
// space and then decodes HTML entities (e.g. "&amp;") in the remaining text.
func (u *RagUtils) StripHTML(text string) string {
	withoutTags := htmlTagRegex.ReplaceAllString(text, " ")
	return html.UnescapeString(withoutTags)
}
// PrepareChunks splits text into overlapping chunks (chunk size 1000, overlap
// 200) and embeds each chunk.
//
// It returns the chunks together with one vector per chunk. When the splitter
// produces no chunks, all three results are nil. Splitter and embedder errors
// are returned unchanged.
func (u *RagUtils) PrepareChunks(ctx context.Context, text string) ([]string, [][]float32, error) {
	// Truncate to 1536 dimensions for pgvector compatibility (HNSW index limit is 2000).
	const maxDims = 1536

	splitter := textsplitter.NewRecursiveCharacter(
		textsplitter.WithChunkSize(1000),
		textsplitter.WithChunkOverlap(200),
	)
	chunks, err := splitter.SplitText(text)
	switch {
	case err != nil:
		return nil, nil, err
	case len(chunks) == 0:
		return nil, nil, nil
	}

	vectors, err := u.embedder.EmbedDocuments(ctx, chunks)
	if err != nil {
		return nil, nil, err
	}

	for i, vec := range vectors {
		if len(vec) > maxDims {
			vectors[i] = vec[:maxDims]
		}
	}
	return chunks, vectors, nil
}
// EmbedQuery embeds a single query string and returns its vector, truncated to
// at most 1536 dimensions so it matches the vectors produced by PrepareChunks.
//
// Embedder errors are returned unchanged. If the embedder reports success but
// yields no vector, an error is returned instead of a nil vector with a nil
// error, so callers never receive an unusable result silently.
func (u *RagUtils) EmbedQuery(ctx context.Context, query string) ([]float32, error) {
	const maxDims = 1536

	vectors, err := u.embedder.EmbedDocuments(ctx, []string{query})
	if err != nil {
		return nil, err
	}
	if len(vectors) == 0 {
		return nil, fmt.Errorf("embedder returned no vector for query")
	}

	vector := vectors[0]
	if len(vector) > maxDims {
		vector = vector[:maxDims]
	}
	return vector, nil
}
// RewriteSearchQuery asks the query-rewrite model to turn question into a
// standalone retrieval query, using history to resolve references.
//
// The whitespace-normalised question is returned unchanged when it is empty,
// when rewriting is disabled, or when the model output normalises to an empty
// string. A model error is returned as ("", err). When queryRewriteTimeout is
// positive, the model call is bounded by that timeout in addition to ctx.
func (u *RagUtils) RewriteSearchQuery(ctx context.Context, question string, history []ChatTurn) (string, error) {
	const rewritePromptTemplate = `You rewrite user questions into precise retrieval queries for a history RAG system.
Previous chat turns:
%s
Current user question:
%s
Rules:
- Output exactly one retrieval query inside <query> tags.
- Do not answer the question.
- Do not add names, dates, places, causes, results, or facts unless they are explicitly present in the current question or previous chat turns.
- Use previous chat turns only to resolve clear references such as "đó", "ông ấy", "bà ấy", "sự kiện này", "that event", or "he/she".
- Preserve proper nouns, dates, quoted text, and historical terms exactly when possible.
- If the current question is already clear, return it unchanged.
- If a reference is ambiguous, keep the ambiguous wording instead of guessing.
- Use the same language as the current question.
- No markdown, no bullet points, no explanation.
<query>`

	question = normalizeWhitespace(question)
	if !u.queryRewriteEnabled || question == "" {
		return question, nil
	}

	turns := buildHistoryForRewrite(history)
	if turns == "" {
		turns = "No previous chat turns."
	}
	prompt := fmt.Sprintf(rewritePromptTemplate, turns, question)

	// Only derive a child context when a dedicated rewrite timeout is configured.
	callCtx := ctx
	if u.queryRewriteTimeout > 0 {
		var cancel context.CancelFunc
		callCtx, cancel = context.WithTimeout(ctx, u.queryRewriteTimeout)
		defer cancel()
	}

	started := time.Now()
	raw, err := llms.GenerateFromSinglePrompt(
		callCtx,
		u.llm,
		prompt,
		llms.WithModel(u.queryRewriteModel),
		llms.WithMaxTokens(u.queryRewriteMaxTokens),
		llms.WithTemperature(0),
	)
	if err != nil {
		return "", err
	}

	log.Info().
		Str("model", u.queryRewriteModel).
		Int("prompt_chars", len(prompt)).
		Int("response_chars", len(raw)).
		Dur("duration", time.Since(started)).
		Msg("rag query rewrite succeeded")

	if rewritten := normalizeGeneratedQuery(raw); rewritten != "" {
		return rewritten, nil
	}
	return question, nil
}
// RerankDocuments scores documents against query with the configured reranker
// and returns up to topN results.
//
// It returns (nil, nil) when reranking is disabled or documents is empty. A
// topN that is non-positive or larger than len(documents) is treated as
// len(documents). The primary model is tried first; if it fails and a distinct
// fallback model is configured, the fallback is tried. The fallback is skipped
// when ctx is already cancelled or past its deadline, because that attempt
// could not succeed and would only add noise to the logs and the error.
// When both models fail, the returned error wraps the primary error and
// includes the fallback error text.
func (u *RagUtils) RerankDocuments(ctx context.Context, query string, documents []string, topN int) ([]RerankResult, error) {
	if !u.rerankEnabled || len(documents) == 0 {
		return nil, nil
	}
	if topN <= 0 || topN > len(documents) {
		topN = len(documents)
	}

	model := strings.TrimSpace(u.rerankModel)
	if model == "" {
		model = "cohere/rerank-4-pro"
	}

	results, err := u.rerankDocumentsWithModel(ctx, model, query, documents, topN)
	if err == nil {
		return results, nil
	}

	// The caller gave up (or ran out of time): do not start another attempt.
	if ctx.Err() != nil {
		return nil, err
	}

	fallbackModel := strings.TrimSpace(u.rerankFallbackModel)
	if fallbackModel == "" || fallbackModel == model {
		return nil, err
	}

	fallbackResults, fallbackErr := u.rerankDocumentsWithModel(ctx, fallbackModel, query, documents, topN)
	if fallbackErr == nil {
		return fallbackResults, nil
	}
	return nil, fmt.Errorf("rerank failed with model %s: %w; fallback model %s failed: %v", model, err, fallbackModel, fallbackErr)
}
// rerankDocumentsWithModel sends one rerank request for model, retrying up to
// rerankMaxRetries additional times while the failure is reported as retryable.
//
// The wait between attempts starts at rerankRetryDelay, doubles on every
// attempt and is capped at two seconds. Each attempt is logged. If ctx ends
// during a wait, the context error is returned as-is; otherwise the last
// attempt error is wrapped with the model name.
func (u *RagUtils) rerankDocumentsWithModel(ctx context.Context, model, query string, documents []string, topN int) ([]RerankResult, error) {
	const maxBackoff = 2 * time.Second

	payload, err := json.Marshal(rerankRequest{
		Model:     model,
		Query:     query,
		Documents: documents,
		TopN:      topN,
	})
	if err != nil {
		return nil, err
	}

	var lastErr error
	for attempt := 0; attempt <= u.rerankMaxRetries; attempt++ {
		started := time.Now()
		ranked, retryable, sendErr := u.sendRerankRequest(ctx, payload, documents)
		if sendErr != nil {
			log.Warn().
				Err(sendErr).
				Str("model", model).
				Int("attempt", attempt+1).
				Int("documents", len(documents)).
				Int("top_n", topN).
				Bool("retryable", retryable).
				Dur("duration", time.Since(started)).
				Msg("rag rerank attempt failed")
			lastErr = sendErr

			// Stop on permanent failures or once the retry budget is spent.
			if attempt == u.rerankMaxRetries || !retryable {
				break
			}

			backoff := u.rerankRetryDelay * time.Duration(1<<attempt)
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			if waitErr := sleepWithContext(ctx, backoff); waitErr != nil {
				return nil, waitErr
			}
			continue
		}

		log.Info().
			Str("model", model).
			Int("attempt", attempt+1).
			Int("documents", len(documents)).
			Int("top_n", topN).
			Int("results", len(ranked)).
			Dur("duration", time.Since(started)).
			Msg("rag rerank attempt succeeded")
		return ranked, nil
	}
	return nil, fmt.Errorf("rerank failed with model %s: %w", model, lastErr)
}
// sendRerankRequest performs a single POST to the OpenRouter /rerank endpoint
// with the pre-encoded payload and decodes the response.
//
// The boolean result tells the caller whether a failure is worth retrying:
// transport errors (unless ctx has ended), body read errors and the statuses
// accepted by isRetryableRerankStatus are retryable; request construction
// errors, context errors, other statuses and JSON decoding errors are not.
// Result entries whose index falls outside documents are discarded.
func (u *RagUtils) sendRerankRequest(ctx context.Context, payload []byte, documents []string) ([]RerankResult, bool, error) {
	endpoint := openRouterBaseURL + "/rerank"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, false, err
	}
	req.Header.Set("Authorization", "Bearer "+u.openRouterAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := u.httpClient.Do(req)
	if err != nil {
		// A failure caused by the caller's context must not be retried.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, false, ctxErr
		}
		return nil, true, err
	}
	defer resp.Body.Close()

	status := resp.StatusCode
	if succeeded := status >= http.StatusOK && status < http.StatusMultipleChoices; !succeeded {
		// Only a bounded prefix of the error body is kept for the message.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		failure := fmt.Errorf("rerank request failed with status %d: %s", status, strings.TrimSpace(string(snippet)))
		return nil, isRetryableRerankStatus(status), failure
	}

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, true, err
	}

	var decoded rerankResponse
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return nil, false, err
	}

	ranked := make([]RerankResult, 0, len(decoded.Results))
	for i := range decoded.Results {
		idx := decoded.Results[i].Index
		if idx >= 0 && idx < len(documents) {
			ranked = append(ranked, RerankResult{
				Index: idx,
				Score: decoded.Results[i].RelevanceScore,
			})
		}
	}
	return ranked, false, nil
}
// GenerateResponse generates an answer for prompt with the primary model and,
// if that fails, with the fallback model (when one is configured and differs
// from the primary). The raw model output is passed through stripThinking
// before it is returned.
//
// The fallback is not attempted when ctx is already cancelled or past its
// deadline: that attempt could not succeed and would only produce an extra
// call and a misleading "canceled" log entry for the fallback model.
func (u *RagUtils) GenerateResponse(ctx context.Context, prompt string) (string, error) {
	raw, err := u.generateSinglePromptWithRetry(ctx, prompt, u.model)

	canFallBack := u.fallbackModel != "" && u.fallbackModel != u.model
	if err != nil && canFallBack && ctx.Err() == nil {
		raw, err = u.generateSinglePromptWithRetry(ctx, prompt, u.fallbackModel)
	}
	if err != nil {
		return "", err
	}
	return stripThinking(raw), nil
}
// generateSinglePromptWithRetry runs prompt against model, retrying up to
// generationMaxRetries additional times on failure.
//
// A blank model means "use the client's default model" (no model option is
// passed). The wait between attempts starts at generationRetryDelay, doubles
// on every attempt and is capped at three seconds. If ctx has ended, the
// context error is returned immediately. After the final failure the last
// error is returned, wrapped with the model name when a model was given.
func (u *RagUtils) generateSinglePromptWithRetry(ctx context.Context, prompt, model string) (string, error) {
	const maxBackoff = 3 * time.Second

	hasModel := strings.TrimSpace(model) != ""
	var callOpts []llms.CallOption
	if hasModel {
		callOpts = append(callOpts, llms.WithModel(model))
	}

	var lastErr error
	for attempt := 0; attempt <= u.generationMaxRetries; attempt++ {
		started := time.Now()
		raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt, callOpts...)

		switch {
		case err == nil:
			log.Info().
				Str("model", model).
				Int("attempt", attempt+1).
				Int("prompt_chars", len(prompt)).
				Int("response_chars", len(raw)).
				Dur("duration", time.Since(started)).
				Msg("rag generation attempt succeeded")
			return raw, nil
		case ctx.Err() != nil:
			// The caller's context ended: report that instead of the model error.
			log.Warn().
				Err(ctx.Err()).
				Str("model", model).
				Int("attempt", attempt+1).
				Int("prompt_chars", len(prompt)).
				Dur("duration", time.Since(started)).
				Msg("rag generation attempt canceled")
			return "", ctx.Err()
		}

		log.Warn().
			Err(err).
			Str("model", model).
			Int("attempt", attempt+1).
			Int("prompt_chars", len(prompt)).
			Dur("duration", time.Since(started)).
			Msg("rag generation attempt failed")
		lastErr = err

		if attempt == u.generationMaxRetries {
			break
		}

		backoff := u.generationRetryDelay * time.Duration(1<<attempt)
		if backoff > maxBackoff {
			backoff = maxBackoff
		}
		if waitErr := sleepWithContext(ctx, backoff); waitErr != nil {
			return "", waitErr
		}
	}

	if !hasModel {
		return "", lastErr
	}
	return "", fmt.Errorf("generation failed with model %s: %w", model, lastErr)
}
// rerankRequest is the JSON body posted to the /rerank endpoint.
type rerankRequest struct {
// Model is the reranker model identifier.
Model string `json:"model"`
// Query is the text the documents are scored against.
Query string `json:"query"`
// Documents are the candidate texts, in the order the caller supplied them.
Documents []string `json:"documents"`
// TopN is the number of results requested.
TopN int `json:"top_n"`
}
// rerankResponse is the subset of the /rerank response body that is decoded.
type rerankResponse struct {
// Results holds one entry per ranked document.
Results []struct {
// Index refers to a position in the submitted documents slice.
Index int `json:"index"`
// RelevanceScore is the score assigned by the reranker.
RelevanceScore float64 `json:"relevance_score"`
} `json:"results"`
}
// buildHistoryForRewrite renders history as numbered "Turn N / User / Assistant"
// blocks for the query-rewrite prompt. Questions and answers are
// whitespace-normalised; turns where both are empty are omitted but still
// count towards the numbering. An empty history yields "".
func buildHistoryForRewrite(history []ChatTurn) string {
	var out strings.Builder
	for idx := range history {
		userText := normalizeWhitespace(history[idx].Question)
		assistantText := normalizeWhitespace(history[idx].Answer)
		if userText != "" || assistantText != "" {
			fmt.Fprintf(&out, "Turn %d\nUser: %s\nAssistant: %s\n", idx+1, userText, assistantText)
		}
	}
	return strings.TrimSpace(out.String())
}
// normalizeGeneratedQuery extracts the retrieval query from raw model output.
//
// Preferred form is "<query>text</query>" (or an unterminated "<query>text").
// Because the rewrite prompt itself ends with an open <query> tag, a model may
// continue it and emit only "text</query>"; in that case everything from the
// closing tag onwards is dropped so the tag never leaks into the search query.
// Stray "<query>" markers, a leading "Query:" / "Retrieval query:" label and
// surrounding quotes or backticks are removed as well. The result is
// whitespace-normalised and may be empty, which callers treat as "no rewrite".
func normalizeGeneratedQuery(raw string) string {
	raw = strings.TrimSpace(raw)
	if query := extractTag(raw, "query"); query != "" {
		return normalizeWhitespace(query)
	}

	// No usable opening tag: keep only what precedes a closing tag, if any,
	// and discard leftover opening markers such as an empty "<query></query>".
	raw, _, _ = strings.Cut(raw, "</query>")
	raw = strings.TrimSpace(strings.ReplaceAll(raw, "<query>", ""))

	raw = strings.TrimSpace(strings.TrimPrefix(raw, "Query:"))
	raw = strings.TrimSpace(strings.TrimPrefix(raw, "Retrieval query:"))
	raw = strings.Trim(raw, "\"'`")
	return normalizeWhitespace(raw)
}
// extractTag returns the trimmed text that follows the last "<tag>" in raw, up
// to the first "</tag>" after it. When the closing tag is missing, everything
// after the opening tag is returned. It returns "" when raw has no opening tag.
func extractTag(raw, tag string) string {
	opening, closing := "<"+tag+">", "</"+tag+">"

	at := strings.LastIndex(raw, opening)
	if at < 0 {
		return ""
	}

	// Cut yields the whole remainder when the closing tag is absent.
	inner, _, _ := strings.Cut(raw[at+len(opening):], closing)
	return strings.TrimSpace(inner)
}
// normalizeWhitespace collapses every run of whitespace in s into a single
// space and removes leading and trailing whitespace.
func normalizeWhitespace(s string) string {
	// Fields already ignores leading and trailing whitespace.
	words := strings.Fields(s)
	return strings.Join(words, " ")
}
// isRetryableRerankStatus reports whether an HTTP status from the rerank
// endpoint indicates a transient failure: 408, 429, 500, 502, 503 or 504.
func isRetryableRerankStatus(statusCode int) bool {
	return statusCode == http.StatusRequestTimeout ||
		statusCode == http.StatusTooManyRequests ||
		statusCode == http.StatusInternalServerError ||
		statusCode == http.StatusBadGateway ||
		statusCode == http.StatusServiceUnavailable ||
		statusCode == http.StatusGatewayTimeout
}
// sleepWithContext blocks for delay or until ctx is done, whichever happens
// first. It returns nil when the full delay elapsed and ctx.Err() when the
// context ended first.
func sleepWithContext(ctx context.Context, delay time.Duration) error {
	wait := time.NewTimer(delay)
	defer wait.Stop()

	select {
	case <-wait.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// stripThinking removes model "thinking" from raw and returns only the answer.
//
// Resolution order:
//  1. If raw contains "<answer>", the text after the last opening tag (up to
//     "</answer>" when present) is returned.
//  2. If raw contains no "* " bullet marker, it is returned trimmed.
//  3. Otherwise the trailing paragraph is returned: the lines after the last
//     blank line or bullet line ("*" or "- " prefix).
//  4. If the text ends on a bullet line, whatever follows the last double
//     quote on that line is returned, when non-empty.
//  5. As a last resort the whole trimmed text is returned.
//
// raw is trimmed before the line-based heuristics run, so trailing newlines or
// spaces emitted by the model do not stop the backwards scan on its first step
// and defeat the stripping.
func stripThinking(raw string) string {
	const (
		startTag = "<answer>"
		endTag   = "</answer>"
	)

	raw = strings.TrimSpace(raw)

	if start := strings.LastIndex(raw, startTag); start != -1 {
		content := raw[start+len(startTag):]
		if end := strings.Index(content, endTag); end != -1 {
			content = content[:end]
		}
		return strings.TrimSpace(content)
	}

	if !strings.Contains(raw, "* ") {
		return raw
	}

	// Walk backwards to find where the final paragraph starts.
	lines := strings.Split(raw, "\n")
	answerStart := len(lines)
	for i := len(lines) - 1; i >= 0; i-- {
		trimmed := strings.TrimSpace(lines[i])
		if trimmed == "" || strings.HasPrefix(trimmed, "*") || strings.HasPrefix(trimmed, "- ") {
			break
		}
		answerStart = i
	}
	if answerStart < len(lines) {
		if answer := strings.TrimSpace(strings.Join(lines[answerStart:], "\n")); answer != "" {
			return answer
		}
	}

	// The text ends on a bullet: salvage anything that follows a quoted draft.
	lastLine := lines[len(lines)-1]
	if idx := strings.LastIndex(lastLine, `"`); idx >= 0 && idx < len(lastLine)-1 {
		if answer := strings.TrimSpace(lastLine[idx+1:]); answer != "" {
			return answer
		}
	}
	return raw
}