feat: enhance RAG functionality with context character limits, query rewriting options, and improved timeout settings
Build and Release / release (push) Successful in 1m31s

This commit is contained in:
2026-06-08 13:46:32 +07:00
parent 872692d8d2
commit a77b856973
3 changed files with 117 additions and 33 deletions
+9 -4
View File
@@ -133,15 +133,20 @@ history-api/
GOOGLE_AI_EMBEDDING_MODEL= GOOGLE_AI_EMBEDDING_MODEL=
OPEN_ROUTER_API= OPEN_ROUTER_API=
OPEN_ROUTER_MODEL= OPEN_ROUTER_MODEL=qwen/qwen3-30b-a3b-instruct-2507
OPEN_ROUTER_FALLBACK_MODEL=qwen/qwen3-30b-a3b-instruct-2507 OPEN_ROUTER_FALLBACK_MODEL=google/gemini-2.5-flash-lite
OPEN_ROUTER_EMBEDDING_MODEL= OPEN_ROUTER_EMBEDDING_MODEL=
RAG_LLM_TIMEOUT_SECONDS=20
RAG_QUERY_REWRITE_ENABLED=true RAG_QUERY_REWRITE_ENABLED=true
RAG_QUERY_REWRITE_MODEL=google/gemini-2.5-flash-lite
RAG_QUERY_REWRITE_TIMEOUT_SECONDS=5
RAG_QUERY_REWRITE_MAX_TOKENS=96
RAG_REWRITE_HISTORY_TURNS=3 RAG_REWRITE_HISTORY_TURNS=3
RAG_RETRIEVAL_CANDIDATES=30 RAG_RETRIEVAL_CANDIDATES=30
RAG_CONTEXT_TOP_N=8 RAG_CONTEXT_TOP_N=5
RAG_GENERATION_MAX_RETRIES=2 RAG_CONTEXT_MAX_CHARS=8000
RAG_GENERATION_MAX_RETRIES=1
RAG_GENERATION_RETRY_DELAY_MS=500 RAG_GENERATION_RETRY_DELAY_MS=500
RAG_RERANK_ENABLED=true RAG_RERANK_ENABLED=true
RAG_RERANK_MODEL=cohere/rerank-4-pro RAG_RERANK_MODEL=cohere/rerank-4-pro
+29
View File
@@ -176,6 +176,7 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
rerankStart := time.Now() rerankStart := time.Now()
results = s.rerankResults(ctx, searchQuery, results, contextLimit) results = s.rerankResults(ctx, searchQuery, results, contextLimit)
rerankDuration := time.Since(rerankStart) rerankDuration := time.Since(rerankStart)
results = limitRagResultsByContextChars(results, config.GetIntConfigWithDefault("RAG_CONTEXT_MAX_CHARS", 8000))
promptBuildStart := time.Now() promptBuildStart := time.Now()
var contextBuilder strings.Builder var contextBuilder strings.Builder
@@ -390,6 +391,34 @@ func limitRagResults(results []*models.RagChunk, limit int) []*models.RagChunk {
return results[:limit] return results[:limit]
} }
// limitRagResultsByContextChars keeps the leading chunks of results whose
// combined estimated size stays within maxChars, preserving their order.
//
// Each chunk is charged len(Content) plus a fixed allowance of 64. Nil entries
// are skipped and charged nothing. The first non-nil chunk is always kept,
// even when it alone is larger than maxChars; selection stops at the first
// later chunk that would push the total over the limit.
//
// results is returned untouched when it is empty or when maxChars is not
// positive (the limit is treated as disabled).
//
// NOTE(review): if Content is a string, len counts bytes rather than runes, so
// multi-byte text consumes the budget faster than its character count —
// confirm this is intended.
func limitRagResultsByContextChars(results []*models.RagChunk, maxChars int) []*models.RagChunk {
	if len(results) == 0 || maxChars <= 0 {
		return results
	}

	// Flat per-chunk allowance added on top of the content length; presumably
	// covers the prompt framing written around each chunk — TODO confirm.
	const perChunkOverhead = 64

	kept := make([]*models.RagChunk, 0, len(results))
	budget := maxChars
	for i := range results {
		chunk := results[i]
		if chunk == nil {
			continue
		}

		cost := len(chunk.Content) + perChunkOverhead
		// Only enforce the budget once something has been kept, so the first
		// usable chunk always gets through.
		if cost > budget && len(kept) != 0 {
			break
		}

		kept = append(kept, chunk)
		budget -= cost
	}

	if len(kept) > 0 {
		return kept
	}
	// Nothing was kept, which can only happen when every entry is nil; fall
	// back to the first element of the input as the original contract does.
	return results[:1]
}
func normalizeAnswer(s string) string { func normalizeAnswer(s string) string {
s = strings.TrimSpace(s) s = strings.TrimSpace(s)
+79 -29
View File
@@ -21,20 +21,23 @@ import (
) )
type RagUtils struct { type RagUtils struct {
llm llms.Model llm llms.Model
embedder *embeddings.EmbedderImpl embedder *embeddings.EmbedderImpl
httpClient *http.Client httpClient *http.Client
openRouterAPIKey string openRouterAPIKey string
model string model string
fallbackModel string fallbackModel string
generationMaxRetries int generationMaxRetries int
generationRetryDelay time.Duration generationRetryDelay time.Duration
rerankEnabled bool rerankEnabled bool
rerankModel string rerankModel string
rerankFallbackModel string rerankFallbackModel string
rerankMaxRetries int rerankMaxRetries int
rerankRetryDelay time.Duration rerankRetryDelay time.Duration
queryRewriteEnabled bool queryRewriteEnabled bool
queryRewriteModel string
queryRewriteTimeout time.Duration
queryRewriteMaxTokens int
} }
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`) var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
@@ -61,11 +64,17 @@ func NewRagUtils() (*RagUtils, error) {
embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b") embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
llmTimeoutSeconds := config.GetIntConfigWithDefault("RAG_LLM_TIMEOUT_SECONDS", 20)
if llmTimeoutSeconds <= 0 {
llmTimeoutSeconds = 20
}
llm, err := openai.New( llm, err := openai.New(
openai.WithToken(openRouterAPIKey), openai.WithToken(openRouterAPIKey),
openai.WithBaseURL(openRouterBaseURL), openai.WithBaseURL(openRouterBaseURL),
openai.WithModel(model), openai.WithModel(model),
openai.WithEmbeddingModel(embeddingModel), openai.WithEmbeddingModel(embeddingModel),
openai.WithHTTPClient(&http.Client{Timeout: time.Duration(llmTimeoutSeconds) * time.Second}),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to init openrouter ai: %w", err) return nil, fmt.Errorf("failed to init openrouter ai: %w", err)
@@ -107,21 +116,42 @@ func NewRagUtils() (*RagUtils, error) {
generationRetryDelayMs = 500 generationRetryDelayMs = 500
} }
queryRewriteModel := strings.TrimSpace(config.GetConfigWithDefault("RAG_QUERY_REWRITE_MODEL", model))
if queryRewriteModel == "" {
queryRewriteModel = model
}
queryRewriteTimeoutSeconds := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_TIMEOUT_SECONDS", 5)
if queryRewriteTimeoutSeconds < 0 {
queryRewriteTimeoutSeconds = 0
}
queryRewriteMaxTokens := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_MAX_TOKENS", 96)
if queryRewriteMaxTokens <= 0 {
queryRewriteMaxTokens = 96
}
if queryRewriteMaxTokens > 512 {
queryRewriteMaxTokens = 512
}
return &RagUtils{ return &RagUtils{
llm: llm, llm: llm,
embedder: embedder, embedder: embedder,
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second}, httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
openRouterAPIKey: openRouterAPIKey, openRouterAPIKey: openRouterAPIKey,
model: model, model: model,
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"), fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
generationMaxRetries: generationMaxRetries, generationMaxRetries: generationMaxRetries,
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond, generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true), rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"), rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"), rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
rerankMaxRetries: rerankMaxRetries, rerankMaxRetries: rerankMaxRetries,
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond, rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true), queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
queryRewriteModel: queryRewriteModel,
queryRewriteTimeout: time.Duration(queryRewriteTimeoutSeconds) * time.Second,
queryRewriteMaxTokens: queryRewriteMaxTokens,
}, nil }, nil
} }
@@ -202,10 +232,30 @@ Rules:
<query>`, historyText, question) <query>`, historyText, question)
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt) rewriteCtx := ctx
cancel := func() {}
if u.queryRewriteTimeout > 0 {
rewriteCtx, cancel = context.WithTimeout(ctx, u.queryRewriteTimeout)
}
defer cancel()
options := []llms.CallOption{
llms.WithModel(u.queryRewriteModel),
llms.WithMaxTokens(u.queryRewriteMaxTokens),
llms.WithTemperature(0),
}
rewriteStart := time.Now()
raw, err := llms.GenerateFromSinglePrompt(rewriteCtx, u.llm, prompt, options...)
if err != nil { if err != nil {
return "", err return "", err
} }
log.Info().
Str("model", u.queryRewriteModel).
Int("prompt_chars", len(prompt)).
Int("response_chars", len(raw)).
Dur("duration", time.Since(rewriteStart)).
Msg("rag query rewrite succeeded")
query := normalizeGeneratedQuery(raw) query := normalizeGeneratedQuery(raw)
if query == "" { if query == "" {