feat: enhance RAG functionality with query rewriting, result reranking, and additional metadata in responses
Build and Release / release (push) Successful in 1m28s

This commit is contained in:
2026-06-08 13:01:55 +07:00
parent 15e81ac12b
commit 40fced75d9
6 changed files with 540 additions and 35 deletions
+126 -5
View File
@@ -8,6 +8,7 @@ import (
"history-api/internal/models"
"history-api/internal/repositories"
"history-api/pkg/ai"
"history-api/pkg/config"
"history-api/pkg/constants"
"history-api/pkg/convert"
"strings"
@@ -52,25 +53,49 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
return "", fmt.Errorf("invalid user id: %w", err)
}
qVector, err := s.ragUtils.EmbedQuery(ctx, question)
searchQuery := strings.TrimSpace(question)
rewriteHistory := s.getRewriteHistory(ctx, pgUserID)
rewrittenQuery, err := s.ragUtils.RewriteSearchQuery(ctx, question, rewriteHistory)
if err != nil {
log.Warn().Err(err).Msg("failed to rewrite RAG search query")
} else if strings.TrimSpace(rewrittenQuery) != "" {
searchQuery = strings.TrimSpace(rewrittenQuery)
}
candidateLimit := config.GetIntConfigWithDefault("RAG_RETRIEVAL_CANDIDATES", 30)
if candidateLimit < 8 {
candidateLimit = 8
}
contextLimit := config.GetIntConfigWithDefault("RAG_CONTEXT_TOP_N", 8)
if contextLimit <= 0 {
contextLimit = 8
}
if contextLimit > candidateLimit {
contextLimit = candidateLimit
}
qVector, err := s.ragUtils.EmbedQuery(ctx, searchQuery)
if err != nil {
return "", fmt.Errorf("failed to embed question: %w", err)
}
results, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.50)
results, err := s.repo.SearchSimilar(ctx, projectID, qVector, candidateLimit, 0.45)
if err != nil {
return "", fmt.Errorf("failed to search similar content: %w", err)
}
if len(results) < 3 {
broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.35)
if len(results) < contextLimit {
broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, candidateLimit, 0.30)
if err == nil && len(broadResults) > len(results) {
results = broadResults
}
}
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
var contextBuilder strings.Builder
contextBuilder.Grow(len(results) * 96)
contextBuilder.Grow(len(results) * 128)
for i, res := range results {
contextBuilder.WriteString(fmt.Sprintf("<doc id=\"%d\" score=\"%.2f\">\n%s\n</doc>\n\n", i+1, res.Similarity, res.Content))
}
@@ -115,6 +140,8 @@ Rules:
- Do not cite documents unless the user asks.
- Your final response MUST be wrapped inside <answer> tags.
- Do not output anything outside <answer> tags.
- Before writing the final answer, silently verify that each factual sentence is supported by Context.
- Keep the answer focused on the user's question. Do not add background information that is not needed.
- Answer in complete, natural, grammatically correct sentences.`, contextStr, question)
}
@@ -141,6 +168,99 @@ Rules:
return response, nil
}
func (s *chatbotService) getRewriteHistory(ctx context.Context, userID pgtype.UUID) []ai.ChatTurn {
limit := config.GetIntConfigWithDefault("RAG_REWRITE_HISTORY_TURNS", 3)
if limit <= 0 {
return nil
}
if limit > 5 {
limit = 5
}
history, err := s.chatRepo.GetChatbotHistory(ctx, sqlc.GetChatbotHistoryParams{
UserID: userID,
Limit: int32(limit),
})
if err != nil {
log.Warn().Err(err).Msg("failed to load chatbot history for RAG query rewrite")
return nil
}
turns := make([]ai.ChatTurn, 0, len(history))
for _, item := range history {
if item == nil {
continue
}
turns = append(turns, ai.ChatTurn{
Question: item.Question,
Answer: item.Answer,
})
}
return turns
}
func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk {
if len(results) == 0 {
return results
}
if limit <= 0 {
limit = len(results)
}
if limit > len(results) {
limit = len(results)
}
documents := make([]string, len(results))
for i, result := range results {
documents[i] = result.Content
}
ranked, err := s.ragUtils.RerankDocuments(ctx, query, documents, limit)
if err != nil {
log.Warn().Err(err).Msg("failed to rerank RAG results")
return limitRagResults(results, limit)
}
if len(ranked) == 0 {
return limitRagResults(results, limit)
}
selected := make([]*models.RagChunk, 0, limit)
seen := make(map[int]struct{}, limit)
for _, item := range ranked {
if item.Index < 0 || item.Index >= len(results) {
continue
}
if _, exists := seen[item.Index]; exists {
continue
}
seen[item.Index] = struct{}{}
selected = append(selected, results[item.Index])
if len(selected) >= limit {
return selected
}
}
for i, result := range results {
if _, exists := seen[i]; exists {
continue
}
selected = append(selected, result)
if len(selected) >= limit {
break
}
}
return selected
}
func limitRagResults(results []*models.RagChunk, limit int) []*models.RagChunk {
if limit <= 0 || len(results) <= limit {
return results
}
return results[:limit]
}
func normalizeAnswer(s string) string {
s = strings.TrimSpace(s)
@@ -157,6 +277,7 @@ func normalizeAnswer(s string) string {
return s
}
func (s *chatbotService) GetHistory(ctx context.Context, userID string, dto *request.GetChatbotHistoryDto) ([]*models.ChatbotHistoryEntity, error) {
pgUserID, err := convert.StringToUUID(userID)
if err != nil {