From a77b856973ac249fe577ce6911d5547780029f4a Mon Sep 17 00:00:00 2001 From: AzenKain Date: Mon, 8 Jun 2026 13:46:32 +0700 Subject: [PATCH] feat: enhance RAG functionality with context character limits, query rewriting options, and improved timeout settings --- README.md | 13 ++-- internal/services/chatbotService.go | 29 ++++++++ pkg/ai/rag.go | 108 ++++++++++++++++++++-------- 3 files changed, 117 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index dae9a60..2140801 100644 --- a/README.md +++ b/README.md @@ -133,15 +133,20 @@ history-api/ GOOGLE_AI_EMBEDDING_MODEL= OPEN_ROUTER_API= - OPEN_ROUTER_MODEL= - OPEN_ROUTER_FALLBACK_MODEL=qwen/qwen3-30b-a3b-instruct-2507 + OPEN_ROUTER_MODEL=qwen/qwen3-30b-a3b-instruct-2507 + OPEN_ROUTER_FALLBACK_MODEL=google/gemini-2.5-flash-lite OPEN_ROUTER_EMBEDDING_MODEL= + RAG_LLM_TIMEOUT_SECONDS=20 RAG_QUERY_REWRITE_ENABLED=true + RAG_QUERY_REWRITE_MODEL=google/gemini-2.5-flash-lite + RAG_QUERY_REWRITE_TIMEOUT_SECONDS=5 + RAG_QUERY_REWRITE_MAX_TOKENS=96 RAG_REWRITE_HISTORY_TURNS=3 RAG_RETRIEVAL_CANDIDATES=30 - RAG_CONTEXT_TOP_N=8 - RAG_GENERATION_MAX_RETRIES=2 + RAG_CONTEXT_TOP_N=5 + RAG_CONTEXT_MAX_CHARS=8000 + RAG_GENERATION_MAX_RETRIES=1 RAG_GENERATION_RETRY_DELAY_MS=500 RAG_RERANK_ENABLED=true RAG_RERANK_MODEL=cohere/rerank-4-pro diff --git a/internal/services/chatbotService.go b/internal/services/chatbotService.go index 0d36056..1e29d6e 100644 --- a/internal/services/chatbotService.go +++ b/internal/services/chatbotService.go @@ -176,6 +176,7 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str rerankStart := time.Now() results = s.rerankResults(ctx, searchQuery, results, contextLimit) rerankDuration := time.Since(rerankStart) + results = limitRagResultsByContextChars(results, config.GetIntConfigWithDefault("RAG_CONTEXT_MAX_CHARS", 8000)) promptBuildStart := time.Now() var contextBuilder strings.Builder @@ -390,6 +391,34 @@ func limitRagResults(results []*models.RagChunk, limit int) []*models.RagChunk { return results[:limit] } +func limitRagResultsByContextChars(results []*models.RagChunk, maxChars int) []*models.RagChunk { + if maxChars <= 0 || len(results) == 0 { + return results + } + + selected := make([]*models.RagChunk, 0, len(results)) + used := 0 + for _, result := range results { + if result == nil { + continue + } + + nextLen := len(result.Content) + 64 + if len(selected) > 0 && used+nextLen > maxChars { + break + } + + selected = append(selected, result) + used += nextLen + } + + if len(selected) == 0 { + return results[:1] + } + + return selected +} + func normalizeAnswer(s string) string { s = strings.TrimSpace(s) diff --git a/pkg/ai/rag.go b/pkg/ai/rag.go index a2592c2..976473a 100644 --- a/pkg/ai/rag.go +++ b/pkg/ai/rag.go @@ -21,20 +21,23 @@ import ( ) type RagUtils struct { - llm llms.Model - embedder *embeddings.EmbedderImpl - httpClient *http.Client - openRouterAPIKey string - model string - fallbackModel string - generationMaxRetries int - generationRetryDelay time.Duration - rerankEnabled bool - rerankModel string - rerankFallbackModel string - rerankMaxRetries int - rerankRetryDelay time.Duration - queryRewriteEnabled bool + llm llms.Model + embedder *embeddings.EmbedderImpl + httpClient *http.Client + openRouterAPIKey string + model string + fallbackModel string + generationMaxRetries int + generationRetryDelay time.Duration + rerankEnabled bool + rerankModel string + rerankFallbackModel string + rerankMaxRetries int + rerankRetryDelay time.Duration + queryRewriteEnabled bool + queryRewriteModel string + queryRewriteTimeout time.Duration + queryRewriteMaxTokens int } var htmlTagRegex = regexp.MustCompile(`<[^>]*>`) @@ -61,11 +64,17 @@ func NewRagUtils() (*RagUtils, error) { embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b") + llmTimeoutSeconds := config.GetIntConfigWithDefault("RAG_LLM_TIMEOUT_SECONDS", 20) + if llmTimeoutSeconds <= 0 { + llmTimeoutSeconds = 20 + } + llm, err := openai.New( openai.WithToken(openRouterAPIKey), openai.WithBaseURL(openRouterBaseURL), openai.WithModel(model), openai.WithEmbeddingModel(embeddingModel), + openai.WithHTTPClient(&http.Client{Timeout: time.Duration(llmTimeoutSeconds) * time.Second}), ) if err != nil { return nil, fmt.Errorf("failed to init openrouter ai: %w", err) @@ -107,21 +116,42 @@ func NewRagUtils() (*RagUtils, error) { generationRetryDelayMs = 500 } + queryRewriteModel := strings.TrimSpace(config.GetConfigWithDefault("RAG_QUERY_REWRITE_MODEL", model)) + if queryRewriteModel == "" { + queryRewriteModel = model + } + + queryRewriteTimeoutSeconds := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_TIMEOUT_SECONDS", 5) + if queryRewriteTimeoutSeconds < 0 { + queryRewriteTimeoutSeconds = 0 + } + + queryRewriteMaxTokens := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_MAX_TOKENS", 96) + if queryRewriteMaxTokens <= 0 { + queryRewriteMaxTokens = 96 + } + if queryRewriteMaxTokens > 512 { + queryRewriteMaxTokens = 512 + } + return &RagUtils{ - llm: llm, - embedder: embedder, - httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second}, - openRouterAPIKey: openRouterAPIKey, - model: model, - fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"), - generationMaxRetries: generationMaxRetries, - generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond, - rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true), - rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"), - rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"), - rerankMaxRetries: rerankMaxRetries, - rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond, - queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true), + llm: llm, + embedder: embedder, + httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second}, + openRouterAPIKey: openRouterAPIKey, + model: model, + fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"), + generationMaxRetries: generationMaxRetries, + generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond, + rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true), + rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"), + rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"), + rerankMaxRetries: rerankMaxRetries, + rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond, + queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true), + queryRewriteModel: queryRewriteModel, + queryRewriteTimeout: time.Duration(queryRewriteTimeoutSeconds) * time.Second, + queryRewriteMaxTokens: queryRewriteMaxTokens, }, nil } @@ -202,10 +232,30 @@ Rules: `, historyText, question) - raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt) + rewriteCtx := ctx + cancel := func() {} + if u.queryRewriteTimeout > 0 { + rewriteCtx, cancel = context.WithTimeout(ctx, u.queryRewriteTimeout) + } + defer cancel() + + options := []llms.CallOption{ + llms.WithModel(u.queryRewriteModel), + llms.WithMaxTokens(u.queryRewriteMaxTokens), + llms.WithTemperature(0), + } + + rewriteStart := time.Now() + raw, err := llms.GenerateFromSinglePrompt(rewriteCtx, u.llm, prompt, options...) if err != nil { return "", err } + log.Info(). + Str("model", u.queryRewriteModel). + Int("prompt_chars", len(prompt)). + Int("response_chars", len(raw)). + Dur("duration", time.Since(rewriteStart)). + Msg("rag query rewrite succeeded") query := normalizeGeneratedQuery(raw) if query == "" {