feat: enhance RAG functionality with context character limits, query rewriting options, and improved timeout settings
Build and Release / release (push) Successful in 1m31s

This commit is contained in:
2026-06-08 13:46:32 +07:00
parent 872692d8d2
commit a77b856973
3 changed files with 117 additions and 33 deletions
+9 -4
View File
@@ -133,15 +133,20 @@ history-api/
GOOGLE_AI_EMBEDDING_MODEL=
OPEN_ROUTER_API=
OPEN_ROUTER_MODEL=
OPEN_ROUTER_FALLBACK_MODEL=qwen/qwen3-30b-a3b-instruct-2507
OPEN_ROUTER_MODEL=qwen/qwen3-30b-a3b-instruct-2507
OPEN_ROUTER_FALLBACK_MODEL=google/gemini-2.5-flash-lite
OPEN_ROUTER_EMBEDDING_MODEL=
RAG_LLM_TIMEOUT_SECONDS=20
RAG_QUERY_REWRITE_ENABLED=true
RAG_QUERY_REWRITE_MODEL=google/gemini-2.5-flash-lite
RAG_QUERY_REWRITE_TIMEOUT_SECONDS=5
RAG_QUERY_REWRITE_MAX_TOKENS=96
RAG_REWRITE_HISTORY_TURNS=3
RAG_RETRIEVAL_CANDIDATES=30
RAG_CONTEXT_TOP_N=8
RAG_GENERATION_MAX_RETRIES=2
RAG_CONTEXT_TOP_N=5
RAG_CONTEXT_MAX_CHARS=8000
RAG_GENERATION_MAX_RETRIES=1
RAG_GENERATION_RETRY_DELAY_MS=500
RAG_RERANK_ENABLED=true
RAG_RERANK_MODEL=cohere/rerank-4-pro
+29
View File
@@ -176,6 +176,7 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
rerankStart := time.Now()
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
rerankDuration := time.Since(rerankStart)
results = limitRagResultsByContextChars(results, config.GetIntConfigWithDefault("RAG_CONTEXT_MAX_CHARS", 8000))
promptBuildStart := time.Now()
var contextBuilder strings.Builder
@@ -390,6 +391,34 @@ func limitRagResults(results []*models.RagChunk, limit int) []*models.RagChunk {
return results[:limit]
}
func limitRagResultsByContextChars(results []*models.RagChunk, maxChars int) []*models.RagChunk {
if maxChars <= 0 || len(results) == 0 {
return results
}
selected := make([]*models.RagChunk, 0, len(results))
used := 0
for _, result := range results {
if result == nil {
continue
}
nextLen := len(result.Content) + 64
if len(selected) > 0 && used+nextLen > maxChars {
break
}
selected = append(selected, result)
used += nextLen
}
if len(selected) == 0 {
return results[:1]
}
return selected
}
func normalizeAnswer(s string) string {
s = strings.TrimSpace(s)
+79 -29
View File
@@ -21,20 +21,23 @@ import (
)
type RagUtils struct {
llm llms.Model
embedder *embeddings.EmbedderImpl
httpClient *http.Client
openRouterAPIKey string
model string
fallbackModel string
generationMaxRetries int
generationRetryDelay time.Duration
rerankEnabled bool
rerankModel string
rerankFallbackModel string
rerankMaxRetries int
rerankRetryDelay time.Duration
queryRewriteEnabled bool
llm llms.Model
embedder *embeddings.EmbedderImpl
httpClient *http.Client
openRouterAPIKey string
model string
fallbackModel string
generationMaxRetries int
generationRetryDelay time.Duration
rerankEnabled bool
rerankModel string
rerankFallbackModel string
rerankMaxRetries int
rerankRetryDelay time.Duration
queryRewriteEnabled bool
queryRewriteModel string
queryRewriteTimeout time.Duration
queryRewriteMaxTokens int
}
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
@@ -61,11 +64,17 @@ func NewRagUtils() (*RagUtils, error) {
embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
llmTimeoutSeconds := config.GetIntConfigWithDefault("RAG_LLM_TIMEOUT_SECONDS", 20)
if llmTimeoutSeconds <= 0 {
llmTimeoutSeconds = 20
}
llm, err := openai.New(
openai.WithToken(openRouterAPIKey),
openai.WithBaseURL(openRouterBaseURL),
openai.WithModel(model),
openai.WithEmbeddingModel(embeddingModel),
openai.WithHTTPClient(&http.Client{Timeout: time.Duration(llmTimeoutSeconds) * time.Second}),
)
if err != nil {
return nil, fmt.Errorf("failed to init openrouter ai: %w", err)
@@ -107,21 +116,42 @@ func NewRagUtils() (*RagUtils, error) {
generationRetryDelayMs = 500
}
queryRewriteModel := strings.TrimSpace(config.GetConfigWithDefault("RAG_QUERY_REWRITE_MODEL", model))
if queryRewriteModel == "" {
queryRewriteModel = model
}
queryRewriteTimeoutSeconds := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_TIMEOUT_SECONDS", 5)
if queryRewriteTimeoutSeconds < 0 {
queryRewriteTimeoutSeconds = 0
}
queryRewriteMaxTokens := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_MAX_TOKENS", 96)
if queryRewriteMaxTokens <= 0 {
queryRewriteMaxTokens = 96
}
if queryRewriteMaxTokens > 512 {
queryRewriteMaxTokens = 512
}
return &RagUtils{
llm: llm,
embedder: embedder,
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
openRouterAPIKey: openRouterAPIKey,
model: model,
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
generationMaxRetries: generationMaxRetries,
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
rerankMaxRetries: rerankMaxRetries,
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
llm: llm,
embedder: embedder,
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
openRouterAPIKey: openRouterAPIKey,
model: model,
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
generationMaxRetries: generationMaxRetries,
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
rerankMaxRetries: rerankMaxRetries,
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
queryRewriteModel: queryRewriteModel,
queryRewriteTimeout: time.Duration(queryRewriteTimeoutSeconds) * time.Second,
queryRewriteMaxTokens: queryRewriteMaxTokens,
}, nil
}
@@ -202,10 +232,30 @@ Rules:
<query>`, historyText, question)
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt)
rewriteCtx := ctx
cancel := func() {}
if u.queryRewriteTimeout > 0 {
rewriteCtx, cancel = context.WithTimeout(ctx, u.queryRewriteTimeout)
}
defer cancel()
options := []llms.CallOption{
llms.WithModel(u.queryRewriteModel),
llms.WithMaxTokens(u.queryRewriteMaxTokens),
llms.WithTemperature(0),
}
rewriteStart := time.Now()
raw, err := llms.GenerateFromSinglePrompt(rewriteCtx, u.llm, prompt, options...)
if err != nil {
return "", err
}
log.Info().
Str("model", u.queryRewriteModel).
Int("prompt_chars", len(prompt)).
Int("response_chars", len(raw)).
Dur("duration", time.Since(rewriteStart)).
Msg("rag query rewrite succeeded")
query := normalizeGeneratedQuery(raw)
if query == "" {