diff --git a/README.md b/README.md
index 2f31512..dae9a60 100644
--- a/README.md
+++ b/README.md
@@ -134,8 +134,22 @@ history-api/
OPEN_ROUTER_API=
OPEN_ROUTER_MODEL=
+ OPEN_ROUTER_FALLBACK_MODEL=qwen/qwen3-30b-a3b-instruct-2507
OPEN_ROUTER_EMBEDDING_MODEL=
+ RAG_QUERY_REWRITE_ENABLED=true
+ RAG_REWRITE_HISTORY_TURNS=3
+ RAG_RETRIEVAL_CANDIDATES=30
+ RAG_CONTEXT_TOP_N=8
+ RAG_GENERATION_MAX_RETRIES=2
+ RAG_GENERATION_RETRY_DELAY_MS=500
+ RAG_RERANK_ENABLED=true
+ RAG_RERANK_MODEL=cohere/rerank-4-pro
+ RAG_RERANK_FALLBACK_MODEL=cohere/rerank-4-fast
+ RAG_RERANK_TIMEOUT_SECONDS=10
+ RAG_RERANK_MAX_RETRIES=2
+ RAG_RERANK_RETRY_DELAY_MS=250
+
GOONG_API_KEY_MAP=
GOONG_API_KEY_REQ=
@@ -214,4 +228,4 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
- [pgvector](https://github.com/pgvector/pgvector) - Vector similarity search for PostgreSQL
- [LangChain Go](https://github.com/tmc/langchaingo) - Framework for LLM applications
- [Swagger UI](https://swagger.io/tools/swagger-ui/) - Interactive API documentation
-- [Rustfs](https://github.com/rustfs/rustfs) - High performance S3 compatible object storage
\ No newline at end of file
+- [Rustfs](https://github.com/rustfs/rustfs) - High performance S3 compatible object storage
diff --git a/db/query/rag.sql b/db/query/rag.sql
index 311e0bf..016745a 100644
--- a/db/query/rag.sql
+++ b/db/query/rag.sql
@@ -9,7 +9,7 @@ RETURNING *;
-- name: SearchRagChunks :many
SELECT
- id, source_type, source_id, project_id, chunk_index, content,
+ id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at,
(1 - (embedding <=> sqlc.arg('embedding')))::float8 AS similarity
FROM rag_chunks
WHERE 1=1
diff --git a/internal/gen/sqlc/rag.sql.go b/internal/gen/sqlc/rag.sql.go
index dcc6a48..05b1ebd 100644
--- a/internal/gen/sqlc/rag.sql.go
+++ b/internal/gen/sqlc/rag.sql.go
@@ -74,7 +74,7 @@ func (q *Queries) DeleteRagChunksBySourceIDs(ctx context.Context, arg DeleteRagC
const searchRagChunks = `-- name: SearchRagChunks :many
SELECT
- id, source_type, source_id, project_id, chunk_index, content,
+ id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at,
(1 - (embedding <=> $1))::float8 AS similarity
FROM rag_chunks
WHERE 1=1
@@ -94,13 +94,15 @@ type SearchRagChunksParams struct {
}
type SearchRagChunksRow struct {
- ID pgtype.UUID `json:"id"`
- SourceType string `json:"source_type"`
- SourceID pgtype.UUID `json:"source_id"`
- ProjectID pgtype.UUID `json:"project_id"`
- ChunkIndex int32 `json:"chunk_index"`
- Content string `json:"content"`
- Similarity float64 `json:"similarity"`
+ ID pgtype.UUID `json:"id"`
+ SourceType string `json:"source_type"`
+ SourceID pgtype.UUID `json:"source_id"`
+ ProjectID pgtype.UUID `json:"project_id"`
+ ChunkIndex int32 `json:"chunk_index"`
+ Content string `json:"content"`
+ CreatedAt pgtype.Timestamptz `json:"created_at"`
+ UpdatedAt pgtype.Timestamptz `json:"updated_at"`
+ Similarity float64 `json:"similarity"`
}
func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams) ([]SearchRagChunksRow, error) {
@@ -125,6 +127,8 @@ func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams
&i.ProjectID,
&i.ChunkIndex,
&i.Content,
+ &i.CreatedAt,
+ &i.UpdatedAt,
&i.Similarity,
); err != nil {
return nil, err
diff --git a/internal/repositories/ragRepository.go b/internal/repositories/ragRepository.go
index 259ff2a..c649f2e 100644
--- a/internal/repositories/ragRepository.go
+++ b/internal/repositories/ragRepository.go
@@ -67,8 +67,14 @@ func (r *ragRepository) SearchSimilar(ctx context.Context, projectID *string, ve
for i, row := range rows {
res[i] = &models.RagChunk{
ID: convert.UUIDToString(row.ID),
+ SourceType: row.SourceType,
+ SourceID: convert.UUIDToString(row.SourceID),
+ ProjectID: convert.UUIDToString(row.ProjectID),
+ ChunkIndex: row.ChunkIndex,
Content: row.Content,
Similarity: row.Similarity,
+ CreatedAt: row.CreatedAt.Time,
+ UpdatedAt: row.UpdatedAt.Time,
}
}
return res, nil
@@ -78,7 +84,7 @@ func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string
if len(sourceIDs) == 0 {
return nil
}
-
+
uids := make([]pgtype.UUID, 0, len(sourceIDs))
for _, id := range sourceIDs {
uid, err := convert.StringToUUID(id)
@@ -86,9 +92,9 @@ func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string
uids = append(uids, uid)
}
}
-
+
return r.q.DeleteRagChunksBySourceIDs(ctx, sqlc.DeleteRagChunksBySourceIDsParams{
SourceType: sourceType,
- Column2: uids,
+ Column2: uids,
})
}
diff --git a/internal/services/chatbotService.go b/internal/services/chatbotService.go
index 3f6c993..ff8875f 100644
--- a/internal/services/chatbotService.go
+++ b/internal/services/chatbotService.go
@@ -8,6 +8,7 @@ import (
"history-api/internal/models"
"history-api/internal/repositories"
"history-api/pkg/ai"
+ "history-api/pkg/config"
"history-api/pkg/constants"
"history-api/pkg/convert"
"strings"
@@ -52,25 +53,49 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
return "", fmt.Errorf("invalid user id: %w", err)
}
- qVector, err := s.ragUtils.EmbedQuery(ctx, question)
+ searchQuery := strings.TrimSpace(question)
+ rewriteHistory := s.getRewriteHistory(ctx, pgUserID)
+ rewrittenQuery, err := s.ragUtils.RewriteSearchQuery(ctx, question, rewriteHistory)
+ if err != nil {
+ log.Warn().Err(err).Msg("failed to rewrite RAG search query")
+ } else if strings.TrimSpace(rewrittenQuery) != "" {
+ searchQuery = strings.TrimSpace(rewrittenQuery)
+ }
+
+ candidateLimit := config.GetIntConfigWithDefault("RAG_RETRIEVAL_CANDIDATES", 30)
+ if candidateLimit < 8 {
+ candidateLimit = 8
+ }
+
+ contextLimit := config.GetIntConfigWithDefault("RAG_CONTEXT_TOP_N", 8)
+ if contextLimit <= 0 {
+ contextLimit = 8
+ }
+ if contextLimit > candidateLimit {
+ contextLimit = candidateLimit
+ }
+
+ qVector, err := s.ragUtils.EmbedQuery(ctx, searchQuery)
if err != nil {
return "", fmt.Errorf("failed to embed question: %w", err)
}
- results, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.50)
+ results, err := s.repo.SearchSimilar(ctx, projectID, qVector, candidateLimit, 0.45)
if err != nil {
return "", fmt.Errorf("failed to search similar content: %w", err)
}
- if len(results) < 3 {
- broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.35)
+ if len(results) < contextLimit {
+ broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, candidateLimit, 0.30)
if err == nil && len(broadResults) > len(results) {
results = broadResults
}
}
+ results = s.rerankResults(ctx, searchQuery, results, contextLimit)
+
var contextBuilder strings.Builder
- contextBuilder.Grow(len(results) * 96)
+ contextBuilder.Grow(len(results) * 128)
for i, res := range results {
contextBuilder.WriteString(fmt.Sprintf("\n%s\n\n\n", i+1, res.Similarity, res.Content))
}
@@ -115,6 +140,8 @@ Rules:
- Do not cite documents unless the user asks.
- Your final response MUST be wrapped inside tags.
- Do not output anything outside tags.
+- Before writing the final answer, silently verify that each factual sentence is supported by Context.
+- Keep the answer focused on the user's question. Do not add background information that is not needed.
- Answer in complete, natural, grammatically correct sentences.`, contextStr, question)
}
@@ -141,6 +168,99 @@ Rules:
return response, nil
}
+func (s *chatbotService) getRewriteHistory(ctx context.Context, userID pgtype.UUID) []ai.ChatTurn {
+ limit := config.GetIntConfigWithDefault("RAG_REWRITE_HISTORY_TURNS", 3)
+ if limit <= 0 {
+ return nil
+ }
+ if limit > 5 {
+ limit = 5
+ }
+
+ history, err := s.chatRepo.GetChatbotHistory(ctx, sqlc.GetChatbotHistoryParams{
+ UserID: userID,
+ Limit: int32(limit),
+ })
+ if err != nil {
+ log.Warn().Err(err).Msg("failed to load chatbot history for RAG query rewrite")
+ return nil
+ }
+
+ turns := make([]ai.ChatTurn, 0, len(history))
+ for _, item := range history {
+ if item == nil {
+ continue
+ }
+ turns = append(turns, ai.ChatTurn{
+ Question: item.Question,
+ Answer: item.Answer,
+ })
+ }
+
+ return turns
+}
+
+func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk {
+ if len(results) == 0 {
+ return results
+ }
+ if limit <= 0 {
+ limit = len(results)
+ }
+ if limit > len(results) {
+ limit = len(results)
+ }
+
+ documents := make([]string, len(results))
+ for i, result := range results {
+ documents[i] = result.Content
+ }
+
+ ranked, err := s.ragUtils.RerankDocuments(ctx, query, documents, limit)
+ if err != nil {
+ log.Warn().Err(err).Msg("failed to rerank RAG results")
+ return limitRagResults(results, limit)
+ }
+ if len(ranked) == 0 {
+ return limitRagResults(results, limit)
+ }
+
+ selected := make([]*models.RagChunk, 0, limit)
+ seen := make(map[int]struct{}, limit)
+ for _, item := range ranked {
+ if item.Index < 0 || item.Index >= len(results) {
+ continue
+ }
+ if _, exists := seen[item.Index]; exists {
+ continue
+ }
+ seen[item.Index] = struct{}{}
+ selected = append(selected, results[item.Index])
+ if len(selected) >= limit {
+ return selected
+ }
+ }
+
+ for i, result := range results {
+ if _, exists := seen[i]; exists {
+ continue
+ }
+ selected = append(selected, result)
+ if len(selected) >= limit {
+ break
+ }
+ }
+
+ return selected
+}
+
+func limitRagResults(results []*models.RagChunk, limit int) []*models.RagChunk {
+ if limit <= 0 || len(results) <= limit {
+ return results
+ }
+ return results[:limit]
+}
+
func normalizeAnswer(s string) string {
s = strings.TrimSpace(s)
@@ -157,6 +277,7 @@ func normalizeAnswer(s string) string {
return s
}
+
func (s *chatbotService) GetHistory(ctx context.Context, userID string, dto *request.GetChatbotHistoryDto) ([]*models.ChatbotHistoryEntity, error) {
pgUserID, err := convert.StringToUUID(userID)
if err != nil {
diff --git a/pkg/ai/rag.go b/pkg/ai/rag.go
index f34c063..8687154 100644
--- a/pkg/ai/rag.go
+++ b/pkg/ai/rag.go
@@ -1,12 +1,17 @@
package ai
import (
+ "bytes"
"context"
"fmt"
"history-api/pkg/config"
+ json "history-api/pkg/jsonx"
"html"
+ "io"
+ "net/http"
"regexp"
"strings"
+ "time"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms"
@@ -15,31 +20,49 @@ import (
)
type RagUtils struct {
- llm llms.Model
- embedder *embeddings.EmbedderImpl
+ llm llms.Model
+ embedder *embeddings.EmbedderImpl
+ httpClient *http.Client
+ openRouterAPIKey string
+ model string
+ fallbackModel string
+ generationMaxRetries int
+ generationRetryDelay time.Duration
+ rerankEnabled bool
+ rerankModel string
+ rerankFallbackModel string
+ rerankMaxRetries int
+ rerankRetryDelay time.Duration
+ queryRewriteEnabled bool
}
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
+const openRouterBaseURL = "https://openrouter.ai/api/v1"
+
+type RerankResult struct {
+ Index int
+ Score float64
+}
+
+type ChatTurn struct {
+ Question string
+ Answer string
+}
+
func NewRagUtils() (*RagUtils, error) {
- openRouterAPIKey, err := config.GetConfig("OPEN_ROUTER_API")
- if err != nil {
- return nil, err
+ openRouterAPIKey := config.GetConfigWithDefault("OPEN_ROUTER_API", "")
+ if openRouterAPIKey == "" {
+ return nil, fmt.Errorf("OPEN_ROUTER_API is not set")
}
- model, err := config.GetConfig("OPEN_ROUTER_MODEL")
- if err != nil {
- model = "qwen/qwen3.5-flash-02-23"
- }
+ model := config.GetConfigWithDefault("OPEN_ROUTER_MODEL", "qwen/qwen3.5-flash-02-23")
- embeddingModel, err := config.GetConfig("OPEN_ROUTER_EMBEDDING_MODEL")
- if err != nil {
- embeddingModel = "qwen/qwen3-embedding-8b"
- }
+ embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
llm, err := openai.New(
openai.WithToken(openRouterAPIKey),
- openai.WithBaseURL("https://openrouter.ai/api/v1"),
+ openai.WithBaseURL(openRouterBaseURL),
openai.WithModel(model),
openai.WithEmbeddingModel(embeddingModel),
)
@@ -52,9 +75,52 @@ func NewRagUtils() (*RagUtils, error) {
return nil, fmt.Errorf("failed to init embedder: %w", err)
}
+ timeoutSeconds := config.GetIntConfigWithDefault("RAG_RERANK_TIMEOUT_SECONDS", 10)
+ if timeoutSeconds <= 0 {
+ timeoutSeconds = 10
+ }
+
+ rerankMaxRetries := config.GetIntConfigWithDefault("RAG_RERANK_MAX_RETRIES", 2)
+ if rerankMaxRetries < 0 {
+ rerankMaxRetries = 0
+ }
+ if rerankMaxRetries > 5 {
+ rerankMaxRetries = 5
+ }
+
+ rerankRetryDelayMs := config.GetIntConfigWithDefault("RAG_RERANK_RETRY_DELAY_MS", 250)
+ if rerankRetryDelayMs <= 0 {
+ rerankRetryDelayMs = 250
+ }
+
+ generationMaxRetries := config.GetIntConfigWithDefault("RAG_GENERATION_MAX_RETRIES", 2)
+ if generationMaxRetries < 0 {
+ generationMaxRetries = 0
+ }
+ if generationMaxRetries > 5 {
+ generationMaxRetries = 5
+ }
+
+ generationRetryDelayMs := config.GetIntConfigWithDefault("RAG_GENERATION_RETRY_DELAY_MS", 500)
+ if generationRetryDelayMs <= 0 {
+ generationRetryDelayMs = 500
+ }
+
return &RagUtils{
- llm: llm,
- embedder: embedder,
+ llm: llm,
+ embedder: embedder,
+ httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
+ openRouterAPIKey: openRouterAPIKey,
+ model: model,
+ fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
+ generationMaxRetries: generationMaxRetries,
+ generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
+ rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
+ rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
+ rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
+ rerankMaxRetries: rerankMaxRetries,
+ rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
+ queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
}, nil
}
@@ -103,14 +169,308 @@ func (u *RagUtils) EmbedQuery(ctx context.Context, query string) ([]float32, err
return vector, nil
}
-func (u *RagUtils) GenerateResponse(ctx context.Context, prompt string) (string, error) {
+func (u *RagUtils) RewriteSearchQuery(ctx context.Context, question string, history []ChatTurn) (string, error) {
+ question = normalizeWhitespace(question)
+ if question == "" || !u.queryRewriteEnabled {
+ return question, nil
+ }
+
+ historyText := buildHistoryForRewrite(history)
+ if historyText == "" {
+ historyText = "No previous chat turns."
+ }
+
+ prompt := fmt.Sprintf(`You rewrite user questions into precise retrieval queries for a history RAG system.
+
+Previous chat turns:
+%s
+
+Current user question:
+%s
+
+Rules:
+- Output exactly one retrieval query inside <query></query> tags.
+- Do not answer the question.
+- Do not add names, dates, places, causes, results, or facts unless they are explicitly present in the current question or previous chat turns.
+- Use previous chat turns only to resolve clear references such as "đó", "ông ấy", "bà ấy", "sự kiện này", "that event", or "he/she".
+- Preserve proper nouns, dates, quoted text, and historical terms exactly when possible.
+- If the current question is already clear, return it unchanged.
+- If a reference is ambiguous, keep the ambiguous wording instead of guessing.
+- Use the same language as the current question.
+- No markdown, no bullet points, no explanation.
+
+`, historyText, question)
+
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt)
if err != nil {
return "", err
}
+
+ query := normalizeGeneratedQuery(raw)
+ if query == "" {
+ return question, nil
+ }
+
+ return query, nil
+}
+
+func (u *RagUtils) RerankDocuments(ctx context.Context, query string, documents []string, topN int) ([]RerankResult, error) {
+ if !u.rerankEnabled || len(documents) == 0 {
+ return nil, nil
+ }
+ if topN <= 0 || topN > len(documents) {
+ topN = len(documents)
+ }
+
+ model := strings.TrimSpace(u.rerankModel)
+ if model == "" {
+ model = "cohere/rerank-4-pro"
+ }
+
+ results, err := u.rerankDocumentsWithModel(ctx, model, query, documents, topN)
+ if err == nil {
+ return results, nil
+ }
+
+ fallbackModel := strings.TrimSpace(u.rerankFallbackModel)
+ if fallbackModel == "" || fallbackModel == model {
+ return nil, err
+ }
+
+ fallbackResults, fallbackErr := u.rerankDocumentsWithModel(ctx, fallbackModel, query, documents, topN)
+ if fallbackErr == nil {
+ return fallbackResults, nil
+ }
+
+ return nil, fmt.Errorf("rerank failed with model %s: %w; fallback model %s failed: %v", model, err, fallbackModel, fallbackErr)
+}
+
+func (u *RagUtils) rerankDocumentsWithModel(ctx context.Context, model, query string, documents []string, topN int) ([]RerankResult, error) {
+ reqBody := rerankRequest{
+ Model: model,
+ Query: query,
+ Documents: documents,
+ TopN: topN,
+ }
+
+ payload, err := json.Marshal(reqBody)
+ if err != nil {
+ return nil, err
+ }
+
+ var lastErr error
+ for attempt := 0; attempt <= u.rerankMaxRetries; attempt++ {
+ results, retryable, err := u.sendRerankRequest(ctx, payload, documents)
+ if err == nil {
+ return results, nil
+ }
+
+ lastErr = err
+ if !retryable || attempt == u.rerankMaxRetries {
+ break
+ }
+
+ delay := u.rerankRetryDelay * time.Duration(1<<attempt)
+ if delay > 2*time.Second {
+ delay = 2 * time.Second
+ }
+ if sleepErr := sleepWithContext(ctx, delay); sleepErr != nil {
+ return nil, sleepErr
+ }
+ }
+
+ return nil, fmt.Errorf("rerank failed with model %s: %w", model, lastErr)
+}
+
+func (u *RagUtils) sendRerankRequest(ctx context.Context, payload []byte, documents []string) ([]RerankResult, bool, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, openRouterBaseURL+"/rerank", bytes.NewReader(payload))
+ if err != nil {
+ return nil, false, err
+ }
+ req.Header.Set("Authorization", "Bearer "+u.openRouterAPIKey)
+ req.Header.Set("Content-Type", "application/json")
+
+ res, err := u.httpClient.Do(req)
+ if err != nil {
+ if ctx.Err() != nil {
+ return nil, false, ctx.Err()
+ }
+ return nil, true, err
+ }
+ defer res.Body.Close()
+
+ if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
+ body, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
+ return nil, isRetryableRerankStatus(res.StatusCode), fmt.Errorf("rerank request failed with status %d: %s", res.StatusCode, strings.TrimSpace(string(body)))
+ }
+
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ return nil, true, err
+ }
+
+ var rerankRes rerankResponse
+ if err := json.Unmarshal(body, &rerankRes); err != nil {
+ return nil, false, err
+ }
+
+ results := make([]RerankResult, 0, len(rerankRes.Results))
+ for _, item := range rerankRes.Results {
+ if item.Index < 0 || item.Index >= len(documents) {
+ continue
+ }
+ results = append(results, RerankResult{
+ Index: item.Index,
+ Score: item.RelevanceScore,
+ })
+ }
+
+ return results, false, nil
+}
+
+func (u *RagUtils) GenerateResponse(ctx context.Context, prompt string) (string, error) {
+ raw, err := u.generateSinglePromptWithRetry(ctx, prompt, u.model)
+ if err != nil && u.fallbackModel != "" && u.fallbackModel != u.model {
+ raw, err = u.generateSinglePromptWithRetry(ctx, prompt, u.fallbackModel)
+ }
+ if err != nil {
+ return "", err
+ }
return stripThinking(raw), nil
}
+func (u *RagUtils) generateSinglePromptWithRetry(ctx context.Context, prompt, model string) (string, error) {
+ var lastErr error
+ options := make([]llms.CallOption, 0, 1)
+ if strings.TrimSpace(model) != "" {
+ options = append(options, llms.WithModel(model))
+ }
+
+ for attempt := 0; attempt <= u.generationMaxRetries; attempt++ {
+ raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt, options...)
+ if err == nil {
+ return raw, nil
+ }
+
+ if ctx.Err() != nil {
+ return "", ctx.Err()
+ }
+
+ lastErr = err
+ if attempt == u.generationMaxRetries {
+ break
+ }
+
+ delay := u.generationRetryDelay * time.Duration(1<<attempt)
+ if delay > 3*time.Second {
+ delay = 3 * time.Second
+ }
+ if sleepErr := sleepWithContext(ctx, delay); sleepErr != nil {
+ return "", sleepErr
+ }
+ }
+
+ if strings.TrimSpace(model) == "" {
+ return "", lastErr
+ }
+
+ return "", fmt.Errorf("generation failed with model %s: %w", model, lastErr)
+}
+
+type rerankRequest struct {
+ Model string `json:"model"`
+ Query string `json:"query"`
+ Documents []string `json:"documents"`
+ TopN int `json:"top_n"`
+}
+
+type rerankResponse struct {
+ Results []struct {
+ Index int `json:"index"`
+ RelevanceScore float64 `json:"relevance_score"`
+ } `json:"results"`
+}
+
+func buildHistoryForRewrite(history []ChatTurn) string {
+ if len(history) == 0 {
+ return ""
+ }
+
+ var builder strings.Builder
+ for i, turn := range history {
+ question := normalizeWhitespace(turn.Question)
+ answer := normalizeWhitespace(turn.Answer)
+ if question == "" && answer == "" {
+ continue
+ }
+ builder.WriteString(fmt.Sprintf("Turn %d\nUser: %s\nAssistant: %s\n", i+1, question, answer))
+ }
+
+ return strings.TrimSpace(builder.String())
+}
+
+func normalizeGeneratedQuery(raw string) string {
+ raw = strings.TrimSpace(raw)
+
+ if query := extractTag(raw, "query"); query != "" {
+ return normalizeWhitespace(query)
+ }
+
+ raw = strings.TrimSpace(strings.TrimPrefix(raw, "Query:"))
+ raw = strings.TrimSpace(strings.TrimPrefix(raw, "Retrieval query:"))
+ raw = strings.Trim(raw, "\"'`")
+
+ return normalizeWhitespace(raw)
+}
+
+func extractTag(raw, tag string) string {
+ startTag := "<" + tag + ">"
+ endTag := "</" + tag + ">"
+
+ start := strings.LastIndex(raw, startTag)
+ if start == -1 {
+ return ""
+ }
+
+ content := raw[start+len(startTag):]
+ end := strings.Index(content, endTag)
+ if end == -1 {
+ return strings.TrimSpace(content)
+ }
+
+ return strings.TrimSpace(content[:end])
+}
+
+func normalizeWhitespace(s string) string {
+ return strings.Join(strings.Fields(strings.TrimSpace(s)), " ")
+}
+
+func isRetryableRerankStatus(statusCode int) bool {
+ switch statusCode {
+ case http.StatusRequestTimeout,
+ http.StatusTooManyRequests,
+ http.StatusInternalServerError,
+ http.StatusBadGateway,
+ http.StatusServiceUnavailable,
+ http.StatusGatewayTimeout:
+ return true
+ default:
+ return false
+ }
+}
+
+func sleepWithContext(ctx context.Context, delay time.Duration) error {
+ timer := time.NewTimer(delay)
+ defer timer.Stop()
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-timer.C:
+ return nil
+ }
+}
+
func stripThinking(raw string) string {
startTag := ""
endTag := ""