feat: enhance RAG functionality with context character limits, query rewriting options, and improved timeout settings
Build and Release / release (push) Successful in 1m31s
Build and Release / release (push) Successful in 1m31s
This commit is contained in:
@@ -133,15 +133,20 @@ history-api/
|
||||
GOOGLE_AI_EMBEDDING_MODEL=
|
||||
|
||||
OPEN_ROUTER_API=
|
||||
OPEN_ROUTER_MODEL=
|
||||
OPEN_ROUTER_FALLBACK_MODEL=qwen/qwen3-30b-a3b-instruct-2507
|
||||
OPEN_ROUTER_MODEL=qwen/qwen3-30b-a3b-instruct-2507
|
||||
OPEN_ROUTER_FALLBACK_MODEL=google/gemini-2.5-flash-lite
|
||||
OPEN_ROUTER_EMBEDDING_MODEL=
|
||||
|
||||
RAG_LLM_TIMEOUT_SECONDS=20
|
||||
RAG_QUERY_REWRITE_ENABLED=true
|
||||
RAG_QUERY_REWRITE_MODEL=google/gemini-2.5-flash-lite
|
||||
RAG_QUERY_REWRITE_TIMEOUT_SECONDS=5
|
||||
RAG_QUERY_REWRITE_MAX_TOKENS=96
|
||||
RAG_REWRITE_HISTORY_TURNS=3
|
||||
RAG_RETRIEVAL_CANDIDATES=30
|
||||
RAG_CONTEXT_TOP_N=8
|
||||
RAG_GENERATION_MAX_RETRIES=2
|
||||
RAG_CONTEXT_TOP_N=5
|
||||
RAG_CONTEXT_MAX_CHARS=8000
|
||||
RAG_GENERATION_MAX_RETRIES=1
|
||||
RAG_GENERATION_RETRY_DELAY_MS=500
|
||||
RAG_RERANK_ENABLED=true
|
||||
RAG_RERANK_MODEL=cohere/rerank-4-pro
|
||||
|
||||
@@ -176,6 +176,7 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
|
||||
rerankStart := time.Now()
|
||||
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
|
||||
rerankDuration := time.Since(rerankStart)
|
||||
results = limitRagResultsByContextChars(results, config.GetIntConfigWithDefault("RAG_CONTEXT_MAX_CHARS", 8000))
|
||||
|
||||
promptBuildStart := time.Now()
|
||||
var contextBuilder strings.Builder
|
||||
@@ -390,6 +391,34 @@ func limitRagResults(results []*models.RagChunk, limit int) []*models.RagChunk {
|
||||
return results[:limit]
|
||||
}
|
||||
|
||||
func limitRagResultsByContextChars(results []*models.RagChunk, maxChars int) []*models.RagChunk {
|
||||
if maxChars <= 0 || len(results) == 0 {
|
||||
return results
|
||||
}
|
||||
|
||||
selected := make([]*models.RagChunk, 0, len(results))
|
||||
used := 0
|
||||
for _, result := range results {
|
||||
if result == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
nextLen := len(result.Content) + 64
|
||||
if len(selected) > 0 && used+nextLen > maxChars {
|
||||
break
|
||||
}
|
||||
|
||||
selected = append(selected, result)
|
||||
used += nextLen
|
||||
}
|
||||
|
||||
if len(selected) == 0 {
|
||||
return results[:1]
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
func normalizeAnswer(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
|
||||
+79
-29
@@ -21,20 +21,23 @@ import (
|
||||
)
|
||||
|
||||
type RagUtils struct {
|
||||
llm llms.Model
|
||||
embedder *embeddings.EmbedderImpl
|
||||
httpClient *http.Client
|
||||
openRouterAPIKey string
|
||||
model string
|
||||
fallbackModel string
|
||||
generationMaxRetries int
|
||||
generationRetryDelay time.Duration
|
||||
rerankEnabled bool
|
||||
rerankModel string
|
||||
rerankFallbackModel string
|
||||
rerankMaxRetries int
|
||||
rerankRetryDelay time.Duration
|
||||
queryRewriteEnabled bool
|
||||
llm llms.Model
|
||||
embedder *embeddings.EmbedderImpl
|
||||
httpClient *http.Client
|
||||
openRouterAPIKey string
|
||||
model string
|
||||
fallbackModel string
|
||||
generationMaxRetries int
|
||||
generationRetryDelay time.Duration
|
||||
rerankEnabled bool
|
||||
rerankModel string
|
||||
rerankFallbackModel string
|
||||
rerankMaxRetries int
|
||||
rerankRetryDelay time.Duration
|
||||
queryRewriteEnabled bool
|
||||
queryRewriteModel string
|
||||
queryRewriteTimeout time.Duration
|
||||
queryRewriteMaxTokens int
|
||||
}
|
||||
|
||||
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
|
||||
@@ -61,11 +64,17 @@ func NewRagUtils() (*RagUtils, error) {
|
||||
|
||||
embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
|
||||
|
||||
llmTimeoutSeconds := config.GetIntConfigWithDefault("RAG_LLM_TIMEOUT_SECONDS", 20)
|
||||
if llmTimeoutSeconds <= 0 {
|
||||
llmTimeoutSeconds = 20
|
||||
}
|
||||
|
||||
llm, err := openai.New(
|
||||
openai.WithToken(openRouterAPIKey),
|
||||
openai.WithBaseURL(openRouterBaseURL),
|
||||
openai.WithModel(model),
|
||||
openai.WithEmbeddingModel(embeddingModel),
|
||||
openai.WithHTTPClient(&http.Client{Timeout: time.Duration(llmTimeoutSeconds) * time.Second}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init openrouter ai: %w", err)
|
||||
@@ -107,21 +116,42 @@ func NewRagUtils() (*RagUtils, error) {
|
||||
generationRetryDelayMs = 500
|
||||
}
|
||||
|
||||
queryRewriteModel := strings.TrimSpace(config.GetConfigWithDefault("RAG_QUERY_REWRITE_MODEL", model))
|
||||
if queryRewriteModel == "" {
|
||||
queryRewriteModel = model
|
||||
}
|
||||
|
||||
queryRewriteTimeoutSeconds := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_TIMEOUT_SECONDS", 5)
|
||||
if queryRewriteTimeoutSeconds < 0 {
|
||||
queryRewriteTimeoutSeconds = 0
|
||||
}
|
||||
|
||||
queryRewriteMaxTokens := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_MAX_TOKENS", 96)
|
||||
if queryRewriteMaxTokens <= 0 {
|
||||
queryRewriteMaxTokens = 96
|
||||
}
|
||||
if queryRewriteMaxTokens > 512 {
|
||||
queryRewriteMaxTokens = 512
|
||||
}
|
||||
|
||||
return &RagUtils{
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
|
||||
openRouterAPIKey: openRouterAPIKey,
|
||||
model: model,
|
||||
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
|
||||
generationMaxRetries: generationMaxRetries,
|
||||
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
|
||||
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
|
||||
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
|
||||
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
|
||||
rerankMaxRetries: rerankMaxRetries,
|
||||
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
|
||||
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
|
||||
openRouterAPIKey: openRouterAPIKey,
|
||||
model: model,
|
||||
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
|
||||
generationMaxRetries: generationMaxRetries,
|
||||
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
|
||||
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
|
||||
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
|
||||
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
|
||||
rerankMaxRetries: rerankMaxRetries,
|
||||
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
|
||||
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
|
||||
queryRewriteModel: queryRewriteModel,
|
||||
queryRewriteTimeout: time.Duration(queryRewriteTimeoutSeconds) * time.Second,
|
||||
queryRewriteMaxTokens: queryRewriteMaxTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -202,10 +232,30 @@ Rules:
|
||||
|
||||
<query>`, historyText, question)
|
||||
|
||||
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt)
|
||||
rewriteCtx := ctx
|
||||
cancel := func() {}
|
||||
if u.queryRewriteTimeout > 0 {
|
||||
rewriteCtx, cancel = context.WithTimeout(ctx, u.queryRewriteTimeout)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
options := []llms.CallOption{
|
||||
llms.WithModel(u.queryRewriteModel),
|
||||
llms.WithMaxTokens(u.queryRewriteMaxTokens),
|
||||
llms.WithTemperature(0),
|
||||
}
|
||||
|
||||
rewriteStart := time.Now()
|
||||
raw, err := llms.GenerateFromSinglePrompt(rewriteCtx, u.llm, prompt, options...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
log.Info().
|
||||
Str("model", u.queryRewriteModel).
|
||||
Int("prompt_chars", len(prompt)).
|
||||
Int("response_chars", len(raw)).
|
||||
Dur("duration", time.Since(rewriteStart)).
|
||||
Msg("rag query rewrite succeeded")
|
||||
|
||||
query := normalizeGeneratedQuery(raw)
|
||||
if query == "" {
|
||||
|
||||
Reference in New Issue
Block a user