feat: enhance RAG functionality with context character limits, query rewriting options, and improved timeout settings
Build and Release / release (push) Successful in 1m31s
Build and Release / release (push) Successful in 1m31s
This commit is contained in:
+79
-29
@@ -21,20 +21,23 @@ import (
|
||||
)
|
||||
|
||||
type RagUtils struct {
|
||||
llm llms.Model
|
||||
embedder *embeddings.EmbedderImpl
|
||||
httpClient *http.Client
|
||||
openRouterAPIKey string
|
||||
model string
|
||||
fallbackModel string
|
||||
generationMaxRetries int
|
||||
generationRetryDelay time.Duration
|
||||
rerankEnabled bool
|
||||
rerankModel string
|
||||
rerankFallbackModel string
|
||||
rerankMaxRetries int
|
||||
rerankRetryDelay time.Duration
|
||||
queryRewriteEnabled bool
|
||||
llm llms.Model
|
||||
embedder *embeddings.EmbedderImpl
|
||||
httpClient *http.Client
|
||||
openRouterAPIKey string
|
||||
model string
|
||||
fallbackModel string
|
||||
generationMaxRetries int
|
||||
generationRetryDelay time.Duration
|
||||
rerankEnabled bool
|
||||
rerankModel string
|
||||
rerankFallbackModel string
|
||||
rerankMaxRetries int
|
||||
rerankRetryDelay time.Duration
|
||||
queryRewriteEnabled bool
|
||||
queryRewriteModel string
|
||||
queryRewriteTimeout time.Duration
|
||||
queryRewriteMaxTokens int
|
||||
}
|
||||
|
||||
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
|
||||
@@ -61,11 +64,17 @@ func NewRagUtils() (*RagUtils, error) {
|
||||
|
||||
embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
|
||||
|
||||
llmTimeoutSeconds := config.GetIntConfigWithDefault("RAG_LLM_TIMEOUT_SECONDS", 20)
|
||||
if llmTimeoutSeconds <= 0 {
|
||||
llmTimeoutSeconds = 20
|
||||
}
|
||||
|
||||
llm, err := openai.New(
|
||||
openai.WithToken(openRouterAPIKey),
|
||||
openai.WithBaseURL(openRouterBaseURL),
|
||||
openai.WithModel(model),
|
||||
openai.WithEmbeddingModel(embeddingModel),
|
||||
openai.WithHTTPClient(&http.Client{Timeout: time.Duration(llmTimeoutSeconds) * time.Second}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init openrouter ai: %w", err)
|
||||
@@ -107,21 +116,42 @@ func NewRagUtils() (*RagUtils, error) {
|
||||
generationRetryDelayMs = 500
|
||||
}
|
||||
|
||||
queryRewriteModel := strings.TrimSpace(config.GetConfigWithDefault("RAG_QUERY_REWRITE_MODEL", model))
|
||||
if queryRewriteModel == "" {
|
||||
queryRewriteModel = model
|
||||
}
|
||||
|
||||
queryRewriteTimeoutSeconds := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_TIMEOUT_SECONDS", 5)
|
||||
if queryRewriteTimeoutSeconds < 0 {
|
||||
queryRewriteTimeoutSeconds = 0
|
||||
}
|
||||
|
||||
queryRewriteMaxTokens := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_MAX_TOKENS", 96)
|
||||
if queryRewriteMaxTokens <= 0 {
|
||||
queryRewriteMaxTokens = 96
|
||||
}
|
||||
if queryRewriteMaxTokens > 512 {
|
||||
queryRewriteMaxTokens = 512
|
||||
}
|
||||
|
||||
return &RagUtils{
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
|
||||
openRouterAPIKey: openRouterAPIKey,
|
||||
model: model,
|
||||
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
|
||||
generationMaxRetries: generationMaxRetries,
|
||||
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
|
||||
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
|
||||
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
|
||||
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
|
||||
rerankMaxRetries: rerankMaxRetries,
|
||||
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
|
||||
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
|
||||
openRouterAPIKey: openRouterAPIKey,
|
||||
model: model,
|
||||
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
|
||||
generationMaxRetries: generationMaxRetries,
|
||||
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
|
||||
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
|
||||
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
|
||||
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
|
||||
rerankMaxRetries: rerankMaxRetries,
|
||||
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
|
||||
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
|
||||
queryRewriteModel: queryRewriteModel,
|
||||
queryRewriteTimeout: time.Duration(queryRewriteTimeoutSeconds) * time.Second,
|
||||
queryRewriteMaxTokens: queryRewriteMaxTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -202,10 +232,30 @@ Rules:
|
||||
|
||||
<query>`, historyText, question)
|
||||
|
||||
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt)
|
||||
rewriteCtx := ctx
|
||||
cancel := func() {}
|
||||
if u.queryRewriteTimeout > 0 {
|
||||
rewriteCtx, cancel = context.WithTimeout(ctx, u.queryRewriteTimeout)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
options := []llms.CallOption{
|
||||
llms.WithModel(u.queryRewriteModel),
|
||||
llms.WithMaxTokens(u.queryRewriteMaxTokens),
|
||||
llms.WithTemperature(0),
|
||||
}
|
||||
|
||||
rewriteStart := time.Now()
|
||||
raw, err := llms.GenerateFromSinglePrompt(rewriteCtx, u.llm, prompt, options...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
log.Info().
|
||||
Str("model", u.queryRewriteModel).
|
||||
Int("prompt_chars", len(prompt)).
|
||||
Int("response_chars", len(raw)).
|
||||
Dur("duration", time.Since(rewriteStart)).
|
||||
Msg("rag query rewrite succeeded")
|
||||
|
||||
query := normalizeGeneratedQuery(raw)
|
||||
if query == "" {
|
||||
|
||||
Reference in New Issue
Block a user