feat: enhance RAG functionality with query rewriting, result reranking, and additional metadata in responses
Build and Release / release (push) Successful in 1m28s
Build and Release / release (push) Successful in 1m28s
This commit is contained in:
+377
-17
@@ -1,12 +1,17 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"history-api/pkg/config"
|
||||
json "history-api/pkg/jsonx"
|
||||
"html"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tmc/langchaingo/embeddings"
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
@@ -15,31 +20,49 @@ import (
|
||||
)
|
||||
|
||||
type RagUtils struct {
|
||||
llm llms.Model
|
||||
embedder *embeddings.EmbedderImpl
|
||||
llm llms.Model
|
||||
embedder *embeddings.EmbedderImpl
|
||||
httpClient *http.Client
|
||||
openRouterAPIKey string
|
||||
model string
|
||||
fallbackModel string
|
||||
generationMaxRetries int
|
||||
generationRetryDelay time.Duration
|
||||
rerankEnabled bool
|
||||
rerankModel string
|
||||
rerankFallbackModel string
|
||||
rerankMaxRetries int
|
||||
rerankRetryDelay time.Duration
|
||||
queryRewriteEnabled bool
|
||||
}
|
||||
|
||||
// htmlTagRegex matches any run of text that starts with '<' and extends
// through the next '>', i.e. anything shaped like a markup tag.
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)

// openRouterBaseURL is the root of the OpenRouter v1 HTTP API. Endpoint
// paths such as "/rerank" are appended to it.
const openRouterBaseURL = "https://openrouter.ai/api/v1"
|
||||
|
||||
// RerankResult is one scored entry returned by the reranker.
type RerankResult struct {
	// Index is the position of the document in the slice that was submitted
	// for reranking.
	Index int
	// Score is the relevance score reported by the rerank endpoint for that
	// document.
	Score float64
}
|
||||
|
||||
// ChatTurn is a single question/answer exchange from the conversation
// history. It is rendered as a "User:" / "Assistant:" pair when building the
// query-rewrite prompt.
type ChatTurn struct {
	// Question is the user's message for this turn.
	Question string
	// Answer is the assistant's reply for this turn.
	Answer string
}
|
||||
|
||||
func NewRagUtils() (*RagUtils, error) {
|
||||
openRouterAPIKey, err := config.GetConfig("OPEN_ROUTER_API")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
openRouterAPIKey := config.GetConfigWithDefault("OPEN_ROUTER_API", "")
|
||||
if openRouterAPIKey == "" {
|
||||
return nil, fmt.Errorf("OPEN_ROUTER_API is not set")
|
||||
}
|
||||
|
||||
model, err := config.GetConfig("OPEN_ROUTER_MODEL")
|
||||
if err != nil {
|
||||
model = "qwen/qwen3.5-flash-02-23"
|
||||
}
|
||||
model := config.GetConfigWithDefault("OPEN_ROUTER_MODEL", "qwen/qwen3.5-flash-02-23")
|
||||
|
||||
embeddingModel, err := config.GetConfig("OPEN_ROUTER_EMBEDDING_MODEL")
|
||||
if err != nil {
|
||||
embeddingModel = "qwen/qwen3-embedding-8b"
|
||||
}
|
||||
embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
|
||||
|
||||
llm, err := openai.New(
|
||||
openai.WithToken(openRouterAPIKey),
|
||||
openai.WithBaseURL("https://openrouter.ai/api/v1"),
|
||||
openai.WithBaseURL(openRouterBaseURL),
|
||||
openai.WithModel(model),
|
||||
openai.WithEmbeddingModel(embeddingModel),
|
||||
)
|
||||
@@ -52,9 +75,52 @@ func NewRagUtils() (*RagUtils, error) {
|
||||
return nil, fmt.Errorf("failed to init embedder: %w", err)
|
||||
}
|
||||
|
||||
timeoutSeconds := config.GetIntConfigWithDefault("RAG_RERANK_TIMEOUT_SECONDS", 10)
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 10
|
||||
}
|
||||
|
||||
rerankMaxRetries := config.GetIntConfigWithDefault("RAG_RERANK_MAX_RETRIES", 2)
|
||||
if rerankMaxRetries < 0 {
|
||||
rerankMaxRetries = 0
|
||||
}
|
||||
if rerankMaxRetries > 5 {
|
||||
rerankMaxRetries = 5
|
||||
}
|
||||
|
||||
rerankRetryDelayMs := config.GetIntConfigWithDefault("RAG_RERANK_RETRY_DELAY_MS", 250)
|
||||
if rerankRetryDelayMs <= 0 {
|
||||
rerankRetryDelayMs = 250
|
||||
}
|
||||
|
||||
generationMaxRetries := config.GetIntConfigWithDefault("RAG_GENERATION_MAX_RETRIES", 2)
|
||||
if generationMaxRetries < 0 {
|
||||
generationMaxRetries = 0
|
||||
}
|
||||
if generationMaxRetries > 5 {
|
||||
generationMaxRetries = 5
|
||||
}
|
||||
|
||||
generationRetryDelayMs := config.GetIntConfigWithDefault("RAG_GENERATION_RETRY_DELAY_MS", 500)
|
||||
if generationRetryDelayMs <= 0 {
|
||||
generationRetryDelayMs = 500
|
||||
}
|
||||
|
||||
return &RagUtils{
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
|
||||
openRouterAPIKey: openRouterAPIKey,
|
||||
model: model,
|
||||
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
|
||||
generationMaxRetries: generationMaxRetries,
|
||||
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
|
||||
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
|
||||
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
|
||||
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
|
||||
rerankMaxRetries: rerankMaxRetries,
|
||||
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
|
||||
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -103,14 +169,308 @@ func (u *RagUtils) EmbedQuery(ctx context.Context, query string) ([]float32, err
|
||||
return vector, nil
|
||||
}
|
||||
|
||||
func (u *RagUtils) GenerateResponse(ctx context.Context, prompt string) (string, error) {
|
||||
func (u *RagUtils) RewriteSearchQuery(ctx context.Context, question string, history []ChatTurn) (string, error) {
|
||||
question = normalizeWhitespace(question)
|
||||
if question == "" || !u.queryRewriteEnabled {
|
||||
return question, nil
|
||||
}
|
||||
|
||||
historyText := buildHistoryForRewrite(history)
|
||||
if historyText == "" {
|
||||
historyText = "No previous chat turns."
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(`You rewrite user questions into precise retrieval queries for a history RAG system.
|
||||
|
||||
Previous chat turns:
|
||||
%s
|
||||
|
||||
Current user question:
|
||||
%s
|
||||
|
||||
Rules:
|
||||
- Output exactly one retrieval query inside <query> tags.
|
||||
- Do not answer the question.
|
||||
- Do not add names, dates, places, causes, results, or facts unless they are explicitly present in the current question or previous chat turns.
|
||||
- Use previous chat turns only to resolve clear references such as "đó", "ông ấy", "bà ấy", "sự kiện này", "that event", or "he/she".
|
||||
- Preserve proper nouns, dates, quoted text, and historical terms exactly when possible.
|
||||
- If the current question is already clear, return it unchanged.
|
||||
- If a reference is ambiguous, keep the ambiguous wording instead of guessing.
|
||||
- Use the same language as the current question.
|
||||
- No markdown, no bullet points, no explanation.
|
||||
|
||||
<query>`, historyText, question)
|
||||
|
||||
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
query := normalizeGeneratedQuery(raw)
|
||||
if query == "" {
|
||||
return question, nil
|
||||
}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func (u *RagUtils) RerankDocuments(ctx context.Context, query string, documents []string, topN int) ([]RerankResult, error) {
|
||||
if !u.rerankEnabled || len(documents) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if topN <= 0 || topN > len(documents) {
|
||||
topN = len(documents)
|
||||
}
|
||||
|
||||
model := strings.TrimSpace(u.rerankModel)
|
||||
if model == "" {
|
||||
model = "cohere/rerank-4-pro"
|
||||
}
|
||||
|
||||
results, err := u.rerankDocumentsWithModel(ctx, model, query, documents, topN)
|
||||
if err == nil {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
fallbackModel := strings.TrimSpace(u.rerankFallbackModel)
|
||||
if fallbackModel == "" || fallbackModel == model {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fallbackResults, fallbackErr := u.rerankDocumentsWithModel(ctx, fallbackModel, query, documents, topN)
|
||||
if fallbackErr == nil {
|
||||
return fallbackResults, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("rerank failed with model %s: %w; fallback model %s failed: %v", model, err, fallbackModel, fallbackErr)
|
||||
}
|
||||
|
||||
func (u *RagUtils) rerankDocumentsWithModel(ctx context.Context, model, query string, documents []string, topN int) ([]RerankResult, error) {
|
||||
reqBody := rerankRequest{
|
||||
Model: model,
|
||||
Query: query,
|
||||
Documents: documents,
|
||||
TopN: topN,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= u.rerankMaxRetries; attempt++ {
|
||||
results, retryable, err := u.sendRerankRequest(ctx, payload, documents)
|
||||
if err == nil {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
if !retryable || attempt == u.rerankMaxRetries {
|
||||
break
|
||||
}
|
||||
|
||||
delay := u.rerankRetryDelay * time.Duration(1<<attempt)
|
||||
if delay > 2*time.Second {
|
||||
delay = 2 * time.Second
|
||||
}
|
||||
if sleepErr := sleepWithContext(ctx, delay); sleepErr != nil {
|
||||
return nil, sleepErr
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("rerank failed with model %s: %w", model, lastErr)
|
||||
}
|
||||
|
||||
func (u *RagUtils) sendRerankRequest(ctx context.Context, payload []byte, documents []string) ([]RerankResult, bool, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, openRouterBaseURL+"/rerank", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+u.openRouterAPIKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := u.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return nil, false, ctx.Err()
|
||||
}
|
||||
return nil, true, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
|
||||
body, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
|
||||
return nil, isRetryableRerankStatus(res.StatusCode), fmt.Errorf("rerank request failed with status %d: %s", res.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
|
||||
var rerankRes rerankResponse
|
||||
if err := json.Unmarshal(body, &rerankRes); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
results := make([]RerankResult, 0, len(rerankRes.Results))
|
||||
for _, item := range rerankRes.Results {
|
||||
if item.Index < 0 || item.Index >= len(documents) {
|
||||
continue
|
||||
}
|
||||
results = append(results, RerankResult{
|
||||
Index: item.Index,
|
||||
Score: item.RelevanceScore,
|
||||
})
|
||||
}
|
||||
|
||||
return results, false, nil
|
||||
}
|
||||
|
||||
func (u *RagUtils) GenerateResponse(ctx context.Context, prompt string) (string, error) {
|
||||
raw, err := u.generateSinglePromptWithRetry(ctx, prompt, u.model)
|
||||
if err != nil && u.fallbackModel != "" && u.fallbackModel != u.model {
|
||||
raw, err = u.generateSinglePromptWithRetry(ctx, prompt, u.fallbackModel)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return stripThinking(raw), nil
|
||||
}
|
||||
|
||||
func (u *RagUtils) generateSinglePromptWithRetry(ctx context.Context, prompt, model string) (string, error) {
|
||||
var lastErr error
|
||||
options := make([]llms.CallOption, 0, 1)
|
||||
if strings.TrimSpace(model) != "" {
|
||||
options = append(options, llms.WithModel(model))
|
||||
}
|
||||
|
||||
for attempt := 0; attempt <= u.generationMaxRetries; attempt++ {
|
||||
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt, options...)
|
||||
if err == nil {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
if attempt == u.generationMaxRetries {
|
||||
break
|
||||
}
|
||||
|
||||
delay := u.generationRetryDelay * time.Duration(1<<attempt)
|
||||
if delay > 3*time.Second {
|
||||
delay = 3 * time.Second
|
||||
}
|
||||
if sleepErr := sleepWithContext(ctx, delay); sleepErr != nil {
|
||||
return "", sleepErr
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("generation failed with model %s: %w", model, lastErr)
|
||||
}
|
||||
|
||||
// rerankRequest is the JSON body posted to the "/rerank" endpoint.
type rerankRequest struct {
	// Model is the identifier of the rerank model to use.
	Model string `json:"model"`
	// Query is the text the documents are scored against.
	Query string `json:"query"`
	// Documents are the candidate texts to score.
	Documents []string `json:"documents"`
	// TopN is sent as "top_n"; presumably it caps how many results the
	// service returns (confirm against the rerank API reference).
	TopN int `json:"top_n"`
}
|
||||
|
||||
// rerankResponse is the subset of the "/rerank" response body that is decoded.
type rerankResponse struct {
	// Results holds one entry per scored document.
	Results []struct {
		// Index refers to a position in the submitted documents slice.
		Index int `json:"index"`
		// RelevanceScore is the score the service assigned to that document.
		RelevanceScore float64 `json:"relevance_score"`
	} `json:"results"`
}
|
||||
|
||||
func buildHistoryForRewrite(history []ChatTurn) string {
|
||||
if len(history) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
for i, turn := range history {
|
||||
question := normalizeWhitespace(turn.Question)
|
||||
answer := normalizeWhitespace(turn.Answer)
|
||||
if question == "" && answer == "" {
|
||||
continue
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("Turn %d\nUser: %s\nAssistant: %s\n", i+1, question, answer))
|
||||
}
|
||||
|
||||
return strings.TrimSpace(builder.String())
|
||||
}
|
||||
|
||||
func normalizeGeneratedQuery(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
|
||||
if query := extractTag(raw, "query"); query != "" {
|
||||
return normalizeWhitespace(query)
|
||||
}
|
||||
|
||||
raw = strings.TrimSpace(strings.TrimPrefix(raw, "Query:"))
|
||||
raw = strings.TrimSpace(strings.TrimPrefix(raw, "Retrieval query:"))
|
||||
raw = strings.Trim(raw, "\"'`")
|
||||
|
||||
return normalizeWhitespace(raw)
|
||||
}
|
||||
|
||||
// extractTag returns the trimmed text enclosed by the last <tag> … </tag>
// pair in raw.
//
// It is lenient about incomplete markup in both directions:
//   - if the opening tag is present but the closing tag is missing, everything
//     after the last opening tag is returned;
//   - if the opening tag is missing but a closing tag is present, everything
//     before the first closing tag is returned. This covers completions that
//     continue a prompt which already ended with the opening tag, so the
//     model only emits "text</tag>".
//
// It returns "" when raw contains neither tag.
func extractTag(raw, tag string) string {
	startTag := "<" + tag + ">"
	endTag := "</" + tag + ">"

	content := raw
	if start := strings.LastIndex(raw, startTag); start != -1 {
		content = raw[start+len(startTag):]
	} else if !strings.Contains(raw, endTag) {
		// Neither tag is present: nothing to extract.
		return ""
	}

	if end := strings.Index(content, endTag); end != -1 {
		content = content[:end]
	}
	return strings.TrimSpace(content)
}
|
||||
|
||||
// normalizeWhitespace collapses every run of whitespace in s into a single
// space and removes leading and trailing whitespace.
func normalizeWhitespace(s string) string {
	// strings.Fields already discards leading and trailing whitespace.
	words := strings.Fields(s)
	return strings.Join(words, " ")
}
|
||||
|
||||
// isRetryableRerankStatus reports whether an HTTP status code from the rerank
// endpoint is treated as transient: 408, 429, 500, 502, 503 or 504.
func isRetryableRerankStatus(statusCode int) bool {
	return statusCode == http.StatusRequestTimeout ||
		statusCode == http.StatusTooManyRequests ||
		statusCode == http.StatusInternalServerError ||
		statusCode == http.StatusBadGateway ||
		statusCode == http.StatusServiceUnavailable ||
		statusCode == http.StatusGatewayTimeout
}
|
||||
|
||||
// sleepWithContext blocks until delay has elapsed or ctx is done, whichever
// happens first. It returns nil after a full wait and ctx.Err() when the
// context ended the wait.
func sleepWithContext(ctx context.Context, delay time.Duration) error {
	wait := time.NewTimer(delay)
	// Release the timer whichever branch wins.
	defer wait.Stop()

	select {
	case <-wait.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
|
||||
|
||||
func stripThinking(raw string) string {
|
||||
startTag := "<answer>"
|
||||
endTag := "</answer>"
|
||||
|
||||
Reference in New Issue
Block a user