feat: enhance RAG functionality with context character limits, query rewriting options, and improved timeout settings
Build and Release / release (push) Successful in 1m31s
Build and Release / release (push) Successful in 1m31s
This commit is contained in:
@@ -133,15 +133,20 @@ history-api/
|
|||||||
GOOGLE_AI_EMBEDDING_MODEL=
|
GOOGLE_AI_EMBEDDING_MODEL=
|
||||||
|
|
||||||
OPEN_ROUTER_API=
|
OPEN_ROUTER_API=
|
||||||
OPEN_ROUTER_MODEL=
|
OPEN_ROUTER_MODEL=qwen/qwen3-30b-a3b-instruct-2507
|
||||||
OPEN_ROUTER_FALLBACK_MODEL=qwen/qwen3-30b-a3b-instruct-2507
|
OPEN_ROUTER_FALLBACK_MODEL=google/gemini-2.5-flash-lite
|
||||||
OPEN_ROUTER_EMBEDDING_MODEL=
|
OPEN_ROUTER_EMBEDDING_MODEL=
|
||||||
|
|
||||||
|
RAG_LLM_TIMEOUT_SECONDS=20
|
||||||
RAG_QUERY_REWRITE_ENABLED=true
|
RAG_QUERY_REWRITE_ENABLED=true
|
||||||
|
RAG_QUERY_REWRITE_MODEL=google/gemini-2.5-flash-lite
|
||||||
|
RAG_QUERY_REWRITE_TIMEOUT_SECONDS=5
|
||||||
|
RAG_QUERY_REWRITE_MAX_TOKENS=96
|
||||||
RAG_REWRITE_HISTORY_TURNS=3
|
RAG_REWRITE_HISTORY_TURNS=3
|
||||||
RAG_RETRIEVAL_CANDIDATES=30
|
RAG_RETRIEVAL_CANDIDATES=30
|
||||||
RAG_CONTEXT_TOP_N=8
|
RAG_CONTEXT_TOP_N=5
|
||||||
RAG_GENERATION_MAX_RETRIES=2
|
RAG_CONTEXT_MAX_CHARS=8000
|
||||||
|
RAG_GENERATION_MAX_RETRIES=1
|
||||||
RAG_GENERATION_RETRY_DELAY_MS=500
|
RAG_GENERATION_RETRY_DELAY_MS=500
|
||||||
RAG_RERANK_ENABLED=true
|
RAG_RERANK_ENABLED=true
|
||||||
RAG_RERANK_MODEL=cohere/rerank-4-pro
|
RAG_RERANK_MODEL=cohere/rerank-4-pro
|
||||||
|
|||||||
@@ -176,6 +176,7 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
|
|||||||
rerankStart := time.Now()
|
rerankStart := time.Now()
|
||||||
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
|
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
|
||||||
rerankDuration := time.Since(rerankStart)
|
rerankDuration := time.Since(rerankStart)
|
||||||
|
results = limitRagResultsByContextChars(results, config.GetIntConfigWithDefault("RAG_CONTEXT_MAX_CHARS", 8000))
|
||||||
|
|
||||||
promptBuildStart := time.Now()
|
promptBuildStart := time.Now()
|
||||||
var contextBuilder strings.Builder
|
var contextBuilder strings.Builder
|
||||||
@@ -390,6 +391,34 @@ func limitRagResults(results []*models.RagChunk, limit int) []*models.RagChunk {
|
|||||||
return results[:limit]
|
return results[:limit]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func limitRagResultsByContextChars(results []*models.RagChunk, maxChars int) []*models.RagChunk {
|
||||||
|
if maxChars <= 0 || len(results) == 0 {
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
selected := make([]*models.RagChunk, 0, len(results))
|
||||||
|
used := 0
|
||||||
|
for _, result := range results {
|
||||||
|
if result == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
nextLen := len(result.Content) + 64
|
||||||
|
if len(selected) > 0 && used+nextLen > maxChars {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
selected = append(selected, result)
|
||||||
|
used += nextLen
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(selected) == 0 {
|
||||||
|
return results[:1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeAnswer(s string) string {
|
func normalizeAnswer(s string) string {
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
|
|||||||
+79
-29
@@ -21,20 +21,23 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RagUtils struct {
|
type RagUtils struct {
|
||||||
llm llms.Model
|
llm llms.Model
|
||||||
embedder *embeddings.EmbedderImpl
|
embedder *embeddings.EmbedderImpl
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
openRouterAPIKey string
|
openRouterAPIKey string
|
||||||
model string
|
model string
|
||||||
fallbackModel string
|
fallbackModel string
|
||||||
generationMaxRetries int
|
generationMaxRetries int
|
||||||
generationRetryDelay time.Duration
|
generationRetryDelay time.Duration
|
||||||
rerankEnabled bool
|
rerankEnabled bool
|
||||||
rerankModel string
|
rerankModel string
|
||||||
rerankFallbackModel string
|
rerankFallbackModel string
|
||||||
rerankMaxRetries int
|
rerankMaxRetries int
|
||||||
rerankRetryDelay time.Duration
|
rerankRetryDelay time.Duration
|
||||||
queryRewriteEnabled bool
|
queryRewriteEnabled bool
|
||||||
|
queryRewriteModel string
|
||||||
|
queryRewriteTimeout time.Duration
|
||||||
|
queryRewriteMaxTokens int
|
||||||
}
|
}
|
||||||
|
|
||||||
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
|
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
|
||||||
@@ -61,11 +64,17 @@ func NewRagUtils() (*RagUtils, error) {
|
|||||||
|
|
||||||
embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
|
embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
|
||||||
|
|
||||||
|
llmTimeoutSeconds := config.GetIntConfigWithDefault("RAG_LLM_TIMEOUT_SECONDS", 20)
|
||||||
|
if llmTimeoutSeconds <= 0 {
|
||||||
|
llmTimeoutSeconds = 20
|
||||||
|
}
|
||||||
|
|
||||||
llm, err := openai.New(
|
llm, err := openai.New(
|
||||||
openai.WithToken(openRouterAPIKey),
|
openai.WithToken(openRouterAPIKey),
|
||||||
openai.WithBaseURL(openRouterBaseURL),
|
openai.WithBaseURL(openRouterBaseURL),
|
||||||
openai.WithModel(model),
|
openai.WithModel(model),
|
||||||
openai.WithEmbeddingModel(embeddingModel),
|
openai.WithEmbeddingModel(embeddingModel),
|
||||||
|
openai.WithHTTPClient(&http.Client{Timeout: time.Duration(llmTimeoutSeconds) * time.Second}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to init openrouter ai: %w", err)
|
return nil, fmt.Errorf("failed to init openrouter ai: %w", err)
|
||||||
@@ -107,21 +116,42 @@ func NewRagUtils() (*RagUtils, error) {
|
|||||||
generationRetryDelayMs = 500
|
generationRetryDelayMs = 500
|
||||||
}
|
}
|
||||||
|
|
||||||
|
queryRewriteModel := strings.TrimSpace(config.GetConfigWithDefault("RAG_QUERY_REWRITE_MODEL", model))
|
||||||
|
if queryRewriteModel == "" {
|
||||||
|
queryRewriteModel = model
|
||||||
|
}
|
||||||
|
|
||||||
|
queryRewriteTimeoutSeconds := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_TIMEOUT_SECONDS", 5)
|
||||||
|
if queryRewriteTimeoutSeconds < 0 {
|
||||||
|
queryRewriteTimeoutSeconds = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
queryRewriteMaxTokens := config.GetIntConfigWithDefault("RAG_QUERY_REWRITE_MAX_TOKENS", 96)
|
||||||
|
if queryRewriteMaxTokens <= 0 {
|
||||||
|
queryRewriteMaxTokens = 96
|
||||||
|
}
|
||||||
|
if queryRewriteMaxTokens > 512 {
|
||||||
|
queryRewriteMaxTokens = 512
|
||||||
|
}
|
||||||
|
|
||||||
return &RagUtils{
|
return &RagUtils{
|
||||||
llm: llm,
|
llm: llm,
|
||||||
embedder: embedder,
|
embedder: embedder,
|
||||||
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
|
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
|
||||||
openRouterAPIKey: openRouterAPIKey,
|
openRouterAPIKey: openRouterAPIKey,
|
||||||
model: model,
|
model: model,
|
||||||
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
|
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
|
||||||
generationMaxRetries: generationMaxRetries,
|
generationMaxRetries: generationMaxRetries,
|
||||||
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
|
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
|
||||||
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
|
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
|
||||||
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
|
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
|
||||||
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
|
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
|
||||||
rerankMaxRetries: rerankMaxRetries,
|
rerankMaxRetries: rerankMaxRetries,
|
||||||
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
|
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
|
||||||
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
|
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
|
||||||
|
queryRewriteModel: queryRewriteModel,
|
||||||
|
queryRewriteTimeout: time.Duration(queryRewriteTimeoutSeconds) * time.Second,
|
||||||
|
queryRewriteMaxTokens: queryRewriteMaxTokens,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,10 +232,30 @@ Rules:
|
|||||||
|
|
||||||
<query>`, historyText, question)
|
<query>`, historyText, question)
|
||||||
|
|
||||||
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt)
|
rewriteCtx := ctx
|
||||||
|
cancel := func() {}
|
||||||
|
if u.queryRewriteTimeout > 0 {
|
||||||
|
rewriteCtx, cancel = context.WithTimeout(ctx, u.queryRewriteTimeout)
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
options := []llms.CallOption{
|
||||||
|
llms.WithModel(u.queryRewriteModel),
|
||||||
|
llms.WithMaxTokens(u.queryRewriteMaxTokens),
|
||||||
|
llms.WithTemperature(0),
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriteStart := time.Now()
|
||||||
|
raw, err := llms.GenerateFromSinglePrompt(rewriteCtx, u.llm, prompt, options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
log.Info().
|
||||||
|
Str("model", u.queryRewriteModel).
|
||||||
|
Int("prompt_chars", len(prompt)).
|
||||||
|
Int("response_chars", len(raw)).
|
||||||
|
Dur("duration", time.Since(rewriteStart)).
|
||||||
|
Msg("rag query rewrite succeeded")
|
||||||
|
|
||||||
query := normalizeGeneratedQuery(raw)
|
query := normalizeGeneratedQuery(raw)
|
||||||
if query == "" {
|
if query == "" {
|
||||||
|
|||||||
Reference in New Issue
Block a user