feat: enhance RAG functionality with query rewriting, result reranking, and additional metadata in responses
Build and Release / release (push) Successful in 1m28s
Build and Release / release (push) Successful in 1m28s
This commit is contained in:
@@ -134,8 +134,22 @@ history-api/
|
|||||||
|
|
||||||
OPEN_ROUTER_API=
|
OPEN_ROUTER_API=
|
||||||
OPEN_ROUTER_MODEL=
|
OPEN_ROUTER_MODEL=
|
||||||
|
OPEN_ROUTER_FALLBACK_MODEL=qwen/qwen3-30b-a3b-instruct-2507
|
||||||
OPEN_ROUTER_EMBEDDING_MODEL=
|
OPEN_ROUTER_EMBEDDING_MODEL=
|
||||||
|
|
||||||
|
RAG_QUERY_REWRITE_ENABLED=true
|
||||||
|
RAG_REWRITE_HISTORY_TURNS=3
|
||||||
|
RAG_RETRIEVAL_CANDIDATES=30
|
||||||
|
RAG_CONTEXT_TOP_N=8
|
||||||
|
RAG_GENERATION_MAX_RETRIES=2
|
||||||
|
RAG_GENERATION_RETRY_DELAY_MS=500
|
||||||
|
RAG_RERANK_ENABLED=true
|
||||||
|
RAG_RERANK_MODEL=cohere/rerank-4-pro
|
||||||
|
RAG_RERANK_FALLBACK_MODEL=cohere/rerank-4-fast
|
||||||
|
RAG_RERANK_TIMEOUT_SECONDS=10
|
||||||
|
RAG_RERANK_MAX_RETRIES=2
|
||||||
|
RAG_RERANK_RETRY_DELAY_MS=250
|
||||||
|
|
||||||
GOONG_API_KEY_MAP=
|
GOONG_API_KEY_MAP=
|
||||||
GOONG_API_KEY_REQ=
|
GOONG_API_KEY_REQ=
|
||||||
|
|
||||||
@@ -214,4 +228,4 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
|||||||
- [pgvector](https://github.com/pgvector/pgvector) - Vector similarity search for PostgreSQL
|
- [pgvector](https://github.com/pgvector/pgvector) - Vector similarity search for PostgreSQL
|
||||||
- [LangChain Go](https://github.com/tmc/langchaingo) - Framework for LLM applications
|
- [LangChain Go](https://github.com/tmc/langchaingo) - Framework for LLM applications
|
||||||
- [Swagger UI](https://swagger.io/tools/swagger-ui/) - Interactive API documentation
|
- [Swagger UI](https://swagger.io/tools/swagger-ui/) - Interactive API documentation
|
||||||
- [Rustfs](https://github.com/rustfs/rustfs) - High performance S3 compatible object storage
|
- [Rustfs](https://github.com/rustfs/rustfs) - High performance S3 compatible object storage
|
||||||
|
|||||||
+1
-1
@@ -9,7 +9,7 @@ RETURNING *;
|
|||||||
|
|
||||||
-- name: SearchRagChunks :many
|
-- name: SearchRagChunks :many
|
||||||
SELECT
|
SELECT
|
||||||
id, source_type, source_id, project_id, chunk_index, content,
|
id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at,
|
||||||
(1 - (embedding <=> sqlc.arg('embedding')))::float8 AS similarity
|
(1 - (embedding <=> sqlc.arg('embedding')))::float8 AS similarity
|
||||||
FROM rag_chunks
|
FROM rag_chunks
|
||||||
WHERE 1=1
|
WHERE 1=1
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func (q *Queries) DeleteRagChunksBySourceIDs(ctx context.Context, arg DeleteRagC
|
|||||||
|
|
||||||
const searchRagChunks = `-- name: SearchRagChunks :many
|
const searchRagChunks = `-- name: SearchRagChunks :many
|
||||||
SELECT
|
SELECT
|
||||||
id, source_type, source_id, project_id, chunk_index, content,
|
id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at,
|
||||||
(1 - (embedding <=> $1))::float8 AS similarity
|
(1 - (embedding <=> $1))::float8 AS similarity
|
||||||
FROM rag_chunks
|
FROM rag_chunks
|
||||||
WHERE 1=1
|
WHERE 1=1
|
||||||
@@ -94,13 +94,15 @@ type SearchRagChunksParams struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SearchRagChunksRow struct {
|
type SearchRagChunksRow struct {
|
||||||
ID pgtype.UUID `json:"id"`
|
ID pgtype.UUID `json:"id"`
|
||||||
SourceType string `json:"source_type"`
|
SourceType string `json:"source_type"`
|
||||||
SourceID pgtype.UUID `json:"source_id"`
|
SourceID pgtype.UUID `json:"source_id"`
|
||||||
ProjectID pgtype.UUID `json:"project_id"`
|
ProjectID pgtype.UUID `json:"project_id"`
|
||||||
ChunkIndex int32 `json:"chunk_index"`
|
ChunkIndex int32 `json:"chunk_index"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Similarity float64 `json:"similarity"`
|
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||||
|
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||||
|
Similarity float64 `json:"similarity"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams) ([]SearchRagChunksRow, error) {
|
func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams) ([]SearchRagChunksRow, error) {
|
||||||
@@ -125,6 +127,8 @@ func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams
|
|||||||
&i.ProjectID,
|
&i.ProjectID,
|
||||||
&i.ChunkIndex,
|
&i.ChunkIndex,
|
||||||
&i.Content,
|
&i.Content,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
&i.Similarity,
|
&i.Similarity,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -67,8 +67,14 @@ func (r *ragRepository) SearchSimilar(ctx context.Context, projectID *string, ve
|
|||||||
for i, row := range rows {
|
for i, row := range rows {
|
||||||
res[i] = &models.RagChunk{
|
res[i] = &models.RagChunk{
|
||||||
ID: convert.UUIDToString(row.ID),
|
ID: convert.UUIDToString(row.ID),
|
||||||
|
SourceType: row.SourceType,
|
||||||
|
SourceID: convert.UUIDToString(row.SourceID),
|
||||||
|
ProjectID: convert.UUIDToString(row.ProjectID),
|
||||||
|
ChunkIndex: row.ChunkIndex,
|
||||||
Content: row.Content,
|
Content: row.Content,
|
||||||
Similarity: row.Similarity,
|
Similarity: row.Similarity,
|
||||||
|
CreatedAt: row.CreatedAt.Time,
|
||||||
|
UpdatedAt: row.UpdatedAt.Time,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return res, nil
|
return res, nil
|
||||||
@@ -78,7 +84,7 @@ func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string
|
|||||||
if len(sourceIDs) == 0 {
|
if len(sourceIDs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
uids := make([]pgtype.UUID, 0, len(sourceIDs))
|
uids := make([]pgtype.UUID, 0, len(sourceIDs))
|
||||||
for _, id := range sourceIDs {
|
for _, id := range sourceIDs {
|
||||||
uid, err := convert.StringToUUID(id)
|
uid, err := convert.StringToUUID(id)
|
||||||
@@ -86,9 +92,9 @@ func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string
|
|||||||
uids = append(uids, uid)
|
uids = append(uids, uid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.q.DeleteRagChunksBySourceIDs(ctx, sqlc.DeleteRagChunksBySourceIDsParams{
|
return r.q.DeleteRagChunksBySourceIDs(ctx, sqlc.DeleteRagChunksBySourceIDsParams{
|
||||||
SourceType: sourceType,
|
SourceType: sourceType,
|
||||||
Column2: uids,
|
Column2: uids,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"history-api/internal/models"
|
"history-api/internal/models"
|
||||||
"history-api/internal/repositories"
|
"history-api/internal/repositories"
|
||||||
"history-api/pkg/ai"
|
"history-api/pkg/ai"
|
||||||
|
"history-api/pkg/config"
|
||||||
"history-api/pkg/constants"
|
"history-api/pkg/constants"
|
||||||
"history-api/pkg/convert"
|
"history-api/pkg/convert"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -52,25 +53,49 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
|
|||||||
return "", fmt.Errorf("invalid user id: %w", err)
|
return "", fmt.Errorf("invalid user id: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
qVector, err := s.ragUtils.EmbedQuery(ctx, question)
|
searchQuery := strings.TrimSpace(question)
|
||||||
|
rewriteHistory := s.getRewriteHistory(ctx, pgUserID)
|
||||||
|
rewrittenQuery, err := s.ragUtils.RewriteSearchQuery(ctx, question, rewriteHistory)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("failed to rewrite RAG search query")
|
||||||
|
} else if strings.TrimSpace(rewrittenQuery) != "" {
|
||||||
|
searchQuery = strings.TrimSpace(rewrittenQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
candidateLimit := config.GetIntConfigWithDefault("RAG_RETRIEVAL_CANDIDATES", 30)
|
||||||
|
if candidateLimit < 8 {
|
||||||
|
candidateLimit = 8
|
||||||
|
}
|
||||||
|
|
||||||
|
contextLimit := config.GetIntConfigWithDefault("RAG_CONTEXT_TOP_N", 8)
|
||||||
|
if contextLimit <= 0 {
|
||||||
|
contextLimit = 8
|
||||||
|
}
|
||||||
|
if contextLimit > candidateLimit {
|
||||||
|
contextLimit = candidateLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
qVector, err := s.ragUtils.EmbedQuery(ctx, searchQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to embed question: %w", err)
|
return "", fmt.Errorf("failed to embed question: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.50)
|
results, err := s.repo.SearchSimilar(ctx, projectID, qVector, candidateLimit, 0.45)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to search similar content: %w", err)
|
return "", fmt.Errorf("failed to search similar content: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(results) < 3 {
|
if len(results) < contextLimit {
|
||||||
broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.35)
|
broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, candidateLimit, 0.30)
|
||||||
if err == nil && len(broadResults) > len(results) {
|
if err == nil && len(broadResults) > len(results) {
|
||||||
results = broadResults
|
results = broadResults
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
|
||||||
|
|
||||||
var contextBuilder strings.Builder
|
var contextBuilder strings.Builder
|
||||||
contextBuilder.Grow(len(results) * 96)
|
contextBuilder.Grow(len(results) * 128)
|
||||||
for i, res := range results {
|
for i, res := range results {
|
||||||
contextBuilder.WriteString(fmt.Sprintf("<doc id=\"%d\" score=\"%.2f\">\n%s\n</doc>\n\n", i+1, res.Similarity, res.Content))
|
contextBuilder.WriteString(fmt.Sprintf("<doc id=\"%d\" score=\"%.2f\">\n%s\n</doc>\n\n", i+1, res.Similarity, res.Content))
|
||||||
}
|
}
|
||||||
@@ -115,6 +140,8 @@ Rules:
|
|||||||
- Do not cite documents unless the user asks.
|
- Do not cite documents unless the user asks.
|
||||||
- Your final response MUST be wrapped inside <answer> tags.
|
- Your final response MUST be wrapped inside <answer> tags.
|
||||||
- Do not output anything outside <answer> tags.
|
- Do not output anything outside <answer> tags.
|
||||||
|
- Before writing the final answer, silently verify that each factual sentence is supported by Context.
|
||||||
|
- Keep the answer focused on the user's question. Do not add background information that is not needed.
|
||||||
- Answer in complete, natural, grammatically correct sentences.`, contextStr, question)
|
- Answer in complete, natural, grammatically correct sentences.`, contextStr, question)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,6 +168,99 @@ Rules:
|
|||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *chatbotService) getRewriteHistory(ctx context.Context, userID pgtype.UUID) []ai.ChatTurn {
|
||||||
|
limit := config.GetIntConfigWithDefault("RAG_REWRITE_HISTORY_TURNS", 3)
|
||||||
|
if limit <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if limit > 5 {
|
||||||
|
limit = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
history, err := s.chatRepo.GetChatbotHistory(ctx, sqlc.GetChatbotHistoryParams{
|
||||||
|
UserID: userID,
|
||||||
|
Limit: int32(limit),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("failed to load chatbot history for RAG query rewrite")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
turns := make([]ai.ChatTurn, 0, len(history))
|
||||||
|
for _, item := range history {
|
||||||
|
if item == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
turns = append(turns, ai.ChatTurn{
|
||||||
|
Question: item.Question,
|
||||||
|
Answer: item.Answer,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return turns
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk {
|
||||||
|
if len(results) == 0 {
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = len(results)
|
||||||
|
}
|
||||||
|
if limit > len(results) {
|
||||||
|
limit = len(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
documents := make([]string, len(results))
|
||||||
|
for i, result := range results {
|
||||||
|
documents[i] = result.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
ranked, err := s.ragUtils.RerankDocuments(ctx, query, documents, limit)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("failed to rerank RAG results")
|
||||||
|
return limitRagResults(results, limit)
|
||||||
|
}
|
||||||
|
if len(ranked) == 0 {
|
||||||
|
return limitRagResults(results, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
selected := make([]*models.RagChunk, 0, limit)
|
||||||
|
seen := make(map[int]struct{}, limit)
|
||||||
|
for _, item := range ranked {
|
||||||
|
if item.Index < 0 || item.Index >= len(results) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := seen[item.Index]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[item.Index] = struct{}{}
|
||||||
|
selected = append(selected, results[item.Index])
|
||||||
|
if len(selected) >= limit {
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, result := range results {
|
||||||
|
if _, exists := seen[i]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
selected = append(selected, result)
|
||||||
|
if len(selected) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
func limitRagResults(results []*models.RagChunk, limit int) []*models.RagChunk {
|
||||||
|
if limit <= 0 || len(results) <= limit {
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
return results[:limit]
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeAnswer(s string) string {
|
func normalizeAnswer(s string) string {
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
@@ -157,6 +277,7 @@ func normalizeAnswer(s string) string {
|
|||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *chatbotService) GetHistory(ctx context.Context, userID string, dto *request.GetChatbotHistoryDto) ([]*models.ChatbotHistoryEntity, error) {
|
func (s *chatbotService) GetHistory(ctx context.Context, userID string, dto *request.GetChatbotHistoryDto) ([]*models.ChatbotHistoryEntity, error) {
|
||||||
pgUserID, err := convert.StringToUUID(userID)
|
pgUserID, err := convert.StringToUUID(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
+377
-17
@@ -1,12 +1,17 @@
|
|||||||
package ai
|
package ai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"history-api/pkg/config"
|
"history-api/pkg/config"
|
||||||
|
json "history-api/pkg/jsonx"
|
||||||
"html"
|
"html"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/embeddings"
|
"github.com/tmc/langchaingo/embeddings"
|
||||||
"github.com/tmc/langchaingo/llms"
|
"github.com/tmc/langchaingo/llms"
|
||||||
@@ -15,31 +20,49 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RagUtils struct {
|
type RagUtils struct {
|
||||||
llm llms.Model
|
llm llms.Model
|
||||||
embedder *embeddings.EmbedderImpl
|
embedder *embeddings.EmbedderImpl
|
||||||
|
httpClient *http.Client
|
||||||
|
openRouterAPIKey string
|
||||||
|
model string
|
||||||
|
fallbackModel string
|
||||||
|
generationMaxRetries int
|
||||||
|
generationRetryDelay time.Duration
|
||||||
|
rerankEnabled bool
|
||||||
|
rerankModel string
|
||||||
|
rerankFallbackModel string
|
||||||
|
rerankMaxRetries int
|
||||||
|
rerankRetryDelay time.Duration
|
||||||
|
queryRewriteEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
|
var htmlTagRegex = regexp.MustCompile(`<[^>]*>`)
|
||||||
|
|
||||||
|
const openRouterBaseURL = "https://openrouter.ai/api/v1"
|
||||||
|
|
||||||
|
type RerankResult struct {
|
||||||
|
Index int
|
||||||
|
Score float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatTurn struct {
|
||||||
|
Question string
|
||||||
|
Answer string
|
||||||
|
}
|
||||||
|
|
||||||
func NewRagUtils() (*RagUtils, error) {
|
func NewRagUtils() (*RagUtils, error) {
|
||||||
openRouterAPIKey, err := config.GetConfig("OPEN_ROUTER_API")
|
openRouterAPIKey := config.GetConfigWithDefault("OPEN_ROUTER_API", "")
|
||||||
if err != nil {
|
if openRouterAPIKey == "" {
|
||||||
return nil, err
|
return nil, fmt.Errorf("OPEN_ROUTER_API is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
model, err := config.GetConfig("OPEN_ROUTER_MODEL")
|
model := config.GetConfigWithDefault("OPEN_ROUTER_MODEL", "qwen/qwen3.5-flash-02-23")
|
||||||
if err != nil {
|
|
||||||
model = "qwen/qwen3.5-flash-02-23"
|
|
||||||
}
|
|
||||||
|
|
||||||
embeddingModel, err := config.GetConfig("OPEN_ROUTER_EMBEDDING_MODEL")
|
embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")
|
||||||
if err != nil {
|
|
||||||
embeddingModel = "qwen/qwen3-embedding-8b"
|
|
||||||
}
|
|
||||||
|
|
||||||
llm, err := openai.New(
|
llm, err := openai.New(
|
||||||
openai.WithToken(openRouterAPIKey),
|
openai.WithToken(openRouterAPIKey),
|
||||||
openai.WithBaseURL("https://openrouter.ai/api/v1"),
|
openai.WithBaseURL(openRouterBaseURL),
|
||||||
openai.WithModel(model),
|
openai.WithModel(model),
|
||||||
openai.WithEmbeddingModel(embeddingModel),
|
openai.WithEmbeddingModel(embeddingModel),
|
||||||
)
|
)
|
||||||
@@ -52,9 +75,52 @@ func NewRagUtils() (*RagUtils, error) {
|
|||||||
return nil, fmt.Errorf("failed to init embedder: %w", err)
|
return nil, fmt.Errorf("failed to init embedder: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
timeoutSeconds := config.GetIntConfigWithDefault("RAG_RERANK_TIMEOUT_SECONDS", 10)
|
||||||
|
if timeoutSeconds <= 0 {
|
||||||
|
timeoutSeconds = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
rerankMaxRetries := config.GetIntConfigWithDefault("RAG_RERANK_MAX_RETRIES", 2)
|
||||||
|
if rerankMaxRetries < 0 {
|
||||||
|
rerankMaxRetries = 0
|
||||||
|
}
|
||||||
|
if rerankMaxRetries > 5 {
|
||||||
|
rerankMaxRetries = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
rerankRetryDelayMs := config.GetIntConfigWithDefault("RAG_RERANK_RETRY_DELAY_MS", 250)
|
||||||
|
if rerankRetryDelayMs <= 0 {
|
||||||
|
rerankRetryDelayMs = 250
|
||||||
|
}
|
||||||
|
|
||||||
|
generationMaxRetries := config.GetIntConfigWithDefault("RAG_GENERATION_MAX_RETRIES", 2)
|
||||||
|
if generationMaxRetries < 0 {
|
||||||
|
generationMaxRetries = 0
|
||||||
|
}
|
||||||
|
if generationMaxRetries > 5 {
|
||||||
|
generationMaxRetries = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
generationRetryDelayMs := config.GetIntConfigWithDefault("RAG_GENERATION_RETRY_DELAY_MS", 500)
|
||||||
|
if generationRetryDelayMs <= 0 {
|
||||||
|
generationRetryDelayMs = 500
|
||||||
|
}
|
||||||
|
|
||||||
return &RagUtils{
|
return &RagUtils{
|
||||||
llm: llm,
|
llm: llm,
|
||||||
embedder: embedder,
|
embedder: embedder,
|
||||||
|
httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second},
|
||||||
|
openRouterAPIKey: openRouterAPIKey,
|
||||||
|
model: model,
|
||||||
|
fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"),
|
||||||
|
generationMaxRetries: generationMaxRetries,
|
||||||
|
generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond,
|
||||||
|
rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true),
|
||||||
|
rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"),
|
||||||
|
rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"),
|
||||||
|
rerankMaxRetries: rerankMaxRetries,
|
||||||
|
rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond,
|
||||||
|
queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,14 +169,308 @@ func (u *RagUtils) EmbedQuery(ctx context.Context, query string) ([]float32, err
|
|||||||
return vector, nil
|
return vector, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RagUtils) GenerateResponse(ctx context.Context, prompt string) (string, error) {
|
func (u *RagUtils) RewriteSearchQuery(ctx context.Context, question string, history []ChatTurn) (string, error) {
|
||||||
|
question = normalizeWhitespace(question)
|
||||||
|
if question == "" || !u.queryRewriteEnabled {
|
||||||
|
return question, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
historyText := buildHistoryForRewrite(history)
|
||||||
|
if historyText == "" {
|
||||||
|
historyText = "No previous chat turns."
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := fmt.Sprintf(`You rewrite user questions into precise retrieval queries for a history RAG system.
|
||||||
|
|
||||||
|
Previous chat turns:
|
||||||
|
%s
|
||||||
|
|
||||||
|
Current user question:
|
||||||
|
%s
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- Output exactly one retrieval query inside <query> tags.
|
||||||
|
- Do not answer the question.
|
||||||
|
- Do not add names, dates, places, causes, results, or facts unless they are explicitly present in the current question or previous chat turns.
|
||||||
|
- Use previous chat turns only to resolve clear references such as "đó", "ông ấy", "bà ấy", "sự kiện này", "that event", or "he/she".
|
||||||
|
- Preserve proper nouns, dates, quoted text, and historical terms exactly when possible.
|
||||||
|
- If the current question is already clear, return it unchanged.
|
||||||
|
- If a reference is ambiguous, keep the ambiguous wording instead of guessing.
|
||||||
|
- Use the same language as the current question.
|
||||||
|
- No markdown, no bullet points, no explanation.
|
||||||
|
|
||||||
|
<query>`, historyText, question)
|
||||||
|
|
||||||
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt)
|
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
query := normalizeGeneratedQuery(raw)
|
||||||
|
if query == "" {
|
||||||
|
return question, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RagUtils) RerankDocuments(ctx context.Context, query string, documents []string, topN int) ([]RerankResult, error) {
|
||||||
|
if !u.rerankEnabled || len(documents) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if topN <= 0 || topN > len(documents) {
|
||||||
|
topN = len(documents)
|
||||||
|
}
|
||||||
|
|
||||||
|
model := strings.TrimSpace(u.rerankModel)
|
||||||
|
if model == "" {
|
||||||
|
model = "cohere/rerank-4-pro"
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := u.rerankDocumentsWithModel(ctx, model, query, documents, topN)
|
||||||
|
if err == nil {
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbackModel := strings.TrimSpace(u.rerankFallbackModel)
|
||||||
|
if fallbackModel == "" || fallbackModel == model {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbackResults, fallbackErr := u.rerankDocumentsWithModel(ctx, fallbackModel, query, documents, topN)
|
||||||
|
if fallbackErr == nil {
|
||||||
|
return fallbackResults, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("rerank failed with model %s: %w; fallback model %s failed: %v", model, err, fallbackModel, fallbackErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RagUtils) rerankDocumentsWithModel(ctx context.Context, model, query string, documents []string, topN int) ([]RerankResult, error) {
|
||||||
|
reqBody := rerankRequest{
|
||||||
|
Model: model,
|
||||||
|
Query: query,
|
||||||
|
Documents: documents,
|
||||||
|
TopN: topN,
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 0; attempt <= u.rerankMaxRetries; attempt++ {
|
||||||
|
results, retryable, err := u.sendRerankRequest(ctx, payload, documents)
|
||||||
|
if err == nil {
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lastErr = err
|
||||||
|
if !retryable || attempt == u.rerankMaxRetries {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
delay := u.rerankRetryDelay * time.Duration(1<<attempt)
|
||||||
|
if delay > 2*time.Second {
|
||||||
|
delay = 2 * time.Second
|
||||||
|
}
|
||||||
|
if sleepErr := sleepWithContext(ctx, delay); sleepErr != nil {
|
||||||
|
return nil, sleepErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("rerank failed with model %s: %w", model, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RagUtils) sendRerankRequest(ctx context.Context, payload []byte, documents []string) ([]RerankResult, bool, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, openRouterBaseURL+"/rerank", bytes.NewReader(payload))
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+u.openRouterAPIKey)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
res, err := u.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, false, ctx.Err()
|
||||||
|
}
|
||||||
|
return nil, true, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
|
||||||
|
return nil, isRetryableRerankStatus(res.StatusCode), fmt.Errorf("rerank request failed with status %d: %s", res.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, true, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var rerankRes rerankResponse
|
||||||
|
if err := json.Unmarshal(body, &rerankRes); err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make([]RerankResult, 0, len(rerankRes.Results))
|
||||||
|
for _, item := range rerankRes.Results {
|
||||||
|
if item.Index < 0 || item.Index >= len(documents) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
results = append(results, RerankResult{
|
||||||
|
Index: item.Index,
|
||||||
|
Score: item.RelevanceScore,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RagUtils) GenerateResponse(ctx context.Context, prompt string) (string, error) {
|
||||||
|
raw, err := u.generateSinglePromptWithRetry(ctx, prompt, u.model)
|
||||||
|
if err != nil && u.fallbackModel != "" && u.fallbackModel != u.model {
|
||||||
|
raw, err = u.generateSinglePromptWithRetry(ctx, prompt, u.fallbackModel)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
return stripThinking(raw), nil
|
return stripThinking(raw), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *RagUtils) generateSinglePromptWithRetry(ctx context.Context, prompt, model string) (string, error) {
|
||||||
|
var lastErr error
|
||||||
|
options := make([]llms.CallOption, 0, 1)
|
||||||
|
if strings.TrimSpace(model) != "" {
|
||||||
|
options = append(options, llms.WithModel(model))
|
||||||
|
}
|
||||||
|
|
||||||
|
for attempt := 0; attempt <= u.generationMaxRetries; attempt++ {
|
||||||
|
raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt, options...)
|
||||||
|
if err == nil {
|
||||||
|
return raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return "", ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
lastErr = err
|
||||||
|
if attempt == u.generationMaxRetries {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
delay := u.generationRetryDelay * time.Duration(1<<attempt)
|
||||||
|
if delay > 3*time.Second {
|
||||||
|
delay = 3 * time.Second
|
||||||
|
}
|
||||||
|
if sleepErr := sleepWithContext(ctx, delay); sleepErr != nil {
|
||||||
|
return "", sleepErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(model) == "" {
|
||||||
|
return "", lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("generation failed with model %s: %w", model, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type rerankRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Query string `json:"query"`
|
||||||
|
Documents []string `json:"documents"`
|
||||||
|
TopN int `json:"top_n"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type rerankResponse struct {
|
||||||
|
Results []struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
RelevanceScore float64 `json:"relevance_score"`
|
||||||
|
} `json:"results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildHistoryForRewrite(history []ChatTurn) string {
|
||||||
|
if len(history) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var builder strings.Builder
|
||||||
|
for i, turn := range history {
|
||||||
|
question := normalizeWhitespace(turn.Question)
|
||||||
|
answer := normalizeWhitespace(turn.Answer)
|
||||||
|
if question == "" && answer == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
builder.WriteString(fmt.Sprintf("Turn %d\nUser: %s\nAssistant: %s\n", i+1, question, answer))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(builder.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeGeneratedQuery(raw string) string {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
|
||||||
|
if query := extractTag(raw, "query"); query != "" {
|
||||||
|
return normalizeWhitespace(query)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw = strings.TrimSpace(strings.TrimPrefix(raw, "Query:"))
|
||||||
|
raw = strings.TrimSpace(strings.TrimPrefix(raw, "Retrieval query:"))
|
||||||
|
raw = strings.Trim(raw, "\"'`")
|
||||||
|
|
||||||
|
return normalizeWhitespace(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTag(raw, tag string) string {
|
||||||
|
startTag := "<" + tag + ">"
|
||||||
|
endTag := "</" + tag + ">"
|
||||||
|
|
||||||
|
start := strings.LastIndex(raw, startTag)
|
||||||
|
if start == -1 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
content := raw[start+len(startTag):]
|
||||||
|
end := strings.Index(content, endTag)
|
||||||
|
if end == -1 {
|
||||||
|
return strings.TrimSpace(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(content[:end])
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeWhitespace(s string) string {
|
||||||
|
return strings.Join(strings.Fields(strings.TrimSpace(s)), " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRetryableRerankStatus(statusCode int) bool {
|
||||||
|
switch statusCode {
|
||||||
|
case http.StatusRequestTimeout,
|
||||||
|
http.StatusTooManyRequests,
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
http.StatusBadGateway,
|
||||||
|
http.StatusServiceUnavailable,
|
||||||
|
http.StatusGatewayTimeout:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sleepWithContext(ctx context.Context, delay time.Duration) error {
|
||||||
|
timer := time.NewTimer(delay)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func stripThinking(raw string) string {
|
func stripThinking(raw string) string {
|
||||||
startTag := "<answer>"
|
startTag := "<answer>"
|
||||||
endTag := "</answer>"
|
endTag := "</answer>"
|
||||||
|
|||||||
Reference in New Issue
Block a user