feat: enhance RAG functionality with context character limits, query rewriting options, and improved timeout settings
Build and Release / release (push) Successful in 1m31s

This commit is contained in:
2026-06-08 13:46:32 +07:00
parent 872692d8d2
commit a77b856973
3 changed files with 117 additions and 33 deletions
+79 -29
View File
@@ -21,20 +21,23 @@ import (
)
// RagUtils bundles the dependencies and tuning knobs that NewRagUtils builds
// from configuration: the LLM client and embedder, an HTTP client, the
// OpenRouter API key and model names, generation/rerank retry settings, and
// the query-rewrite settings.
type RagUtils struct {
llm llms.Model
embedder *embeddings.EmbedderImpl
httpClient *http.Client
openRouterAPIKey string
model string
fallbackModel string
generationMaxRetries int
generationRetryDelay time.Duration
rerankEnabled bool
rerankModel string
rerankFallbackModel string
rerankMaxRetries int
rerankRetryDelay time.Duration
queryRewriteEnabled bool
// NOTE(review): the field list repeats from here on; this looks like the
// re-aligned (post-change) copy produced by the diff rendering, and it is
// the copy that carries the query-rewrite additions. Confirm against the
// real file.
llm llms.Model
embedder *embeddings.EmbedderImpl
httpClient *http.Client
openRouterAPIKey string
model string
fallbackModel string
generationMaxRetries int
generationRetryDelay time.Duration
rerankEnabled bool
rerankModel string
rerankFallbackModel string
rerankMaxRetries int
rerankRetryDelay time.Duration
queryRewriteEnabled bool
// queryRewriteModel is the model name passed to the query-rewrite call.
// NewRagUtils falls back to model when RAG_QUERY_REWRITE_MODEL is blank.
queryRewriteModel string
// queryRewriteTimeout bounds the query-rewrite call. A zero value means no
// extra deadline is added on top of the caller's context; NewRagUtils
// clamps negative configured values to zero.
queryRewriteTimeout time.Duration
// queryRewriteMaxTokens is the max-token option for the rewrite call.
// NewRagUtils defaults it to 96 when non-positive and caps it at 512.
queryRewriteMaxTokens int
}
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
@@ -61,11 +64,17 @@ func NewRagUtils() (*RagUtils, error) {
embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
llmTimeoutSeconds := config.GetIntConfigWithDefault("RAG_LLM_TIMEOUT_SECONDS", 20)
if llmTimeoutSeconds <= 0 {
llmTimeoutSeconds = 20
}
llm, err := openai.New(
openai.WithToken(openRouterAPIKey),
openai.WithBaseURL(openRouterBaseURL),
openai.WithModel(model),
openai.WithEmbeddingModel(embeddingModel),
openai.WithHTTPClient(&http.Client{Timeout: time.Duration(llmTimeoutSeconds) * time.Second}),
)
if err != nil {
return nil, fmt.Errorf("failed to init openrouter ai: %w", err)
@@ -107,21 +116,42 @@ func NewRagUtils() (*RagUtils, error) {
generationRetryDelayMs = 500
}
queryRewriteModel := strings.TrimSpace(config.GetConfigWithDefault("RAG_QUERY_REWRITE_MODEL", model))
if queryRewriteModel == "" {
queryRewriteModel = model
}
queryRewriteTimeoutSeconds := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_TIMEOUT_SECONDS", 5)
if queryRewriteTimeoutSeconds < 0 {
queryRewriteTimeoutSeconds = 0
}
queryRewriteMaxTokens := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_MAX_TOKENS", 96)
if queryRewriteMaxTokens <= 0 {
queryRewriteMaxTokens = 96
}
if queryRewriteMaxTokens > 512 {
queryRewriteMaxTokens = 512
}
return &RagUtils{
llm: llm,
embedder: embedder,
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
openRouterAPIKey: openRouterAPIKey,
model: model,
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
generationMaxRetries: generationMaxRetries,
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
rerankMaxRetries: rerankMaxRetries,
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
llm: llm,
embedder: embedder,
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
openRouterAPIKey: openRouterAPIKey,
model: model,
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
generationMaxRetries: generationMaxRetries,
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
rerankMaxRetries: rerankMaxRetries,
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
queryRewriteModel: queryRewriteModel,
queryRewriteTimeout: time.Duration(queryRewriteTimeoutSeconds) * time.Second,
queryRewriteMaxTokens: queryRewriteMaxTokens,
}, nil
}
@@ -202,10 +232,30 @@ Rules:
<query>`, historyText, question)
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt)
rewriteCtx := ctx
cancel := func() {}
if u.queryRewriteTimeout > 0 {
rewriteCtx, cancel = context.WithTimeout(ctx, u.queryRewriteTimeout)
}
defer cancel()
options := []llms.CallOption{
llms.WithModel(u.queryRewriteModel),
llms.WithMaxTokens(u.queryRewriteMaxTokens),
llms.WithTemperature(0),
}
rewriteStart := time.Now()
raw, err := llms.GenerateFromSinglePrompt(rewriteCtx, u.llm, prompt, options...)
if err != nil {
return "", err
}
log.Info().
Str("model", u.queryRewriteModel).
Int("prompt_chars", len(prompt)).
Int("response_chars", len(raw)).
Dur("duration", time.Since(rewriteStart)).
Msg("rag query rewrite succeeded")
query := normalizeGeneratedQuery(raw)
if query == "" {