feat: enhance RAG functionality with query rewriting, result reranking, and additional metadata in responses
Build and Release / release (push) Successful in 1m28s
Build and Release / release (push) Successful in 1m28s
This commit is contained in:
@@ -74,7 +74,7 @@ func (q *Queries) DeleteRagChunksBySourceIDs(ctx context.Context, arg DeleteRagC
|
||||
|
||||
const searchRagChunks = `-- name: SearchRagChunks :many
|
||||
SELECT
|
||||
id, source_type, source_id, project_id, chunk_index, content,
|
||||
id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at,
|
||||
(1 - (embedding <=> $1))::float8 AS similarity
|
||||
FROM rag_chunks
|
||||
WHERE 1=1
|
||||
@@ -94,13 +94,15 @@ type SearchRagChunksParams struct {
|
||||
}
|
||||
|
||||
type SearchRagChunksRow struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
SourceType string `json:"source_type"`
|
||||
SourceID pgtype.UUID `json:"source_id"`
|
||||
ProjectID pgtype.UUID `json:"project_id"`
|
||||
ChunkIndex int32 `json:"chunk_index"`
|
||||
Content string `json:"content"`
|
||||
Similarity float64 `json:"similarity"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
SourceType string `json:"source_type"`
|
||||
SourceID pgtype.UUID `json:"source_id"`
|
||||
ProjectID pgtype.UUID `json:"project_id"`
|
||||
ChunkIndex int32 `json:"chunk_index"`
|
||||
Content string `json:"content"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
Similarity float64 `json:"similarity"`
|
||||
}
|
||||
|
||||
func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams) ([]SearchRagChunksRow, error) {
|
||||
@@ -125,6 +127,8 @@ func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams
|
||||
&i.ProjectID,
|
||||
&i.ChunkIndex,
|
||||
&i.Content,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Similarity,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -67,8 +67,14 @@ func (r *ragRepository) SearchSimilar(ctx context.Context, projectID *string, ve
|
||||
for i, row := range rows {
|
||||
res[i] = &models.RagChunk{
|
||||
ID: convert.UUIDToString(row.ID),
|
||||
SourceType: row.SourceType,
|
||||
SourceID: convert.UUIDToString(row.SourceID),
|
||||
ProjectID: convert.UUIDToString(row.ProjectID),
|
||||
ChunkIndex: row.ChunkIndex,
|
||||
Content: row.Content,
|
||||
Similarity: row.Similarity,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
UpdatedAt: row.UpdatedAt.Time,
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
@@ -78,7 +84,7 @@ func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string
|
||||
if len(sourceIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
uids := make([]pgtype.UUID, 0, len(sourceIDs))
|
||||
for _, id := range sourceIDs {
|
||||
uid, err := convert.StringToUUID(id)
|
||||
@@ -86,9 +92,9 @@ func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string
|
||||
uids = append(uids, uid)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return r.q.DeleteRagChunksBySourceIDs(ctx, sqlc.DeleteRagChunksBySourceIDsParams{
|
||||
SourceType: sourceType,
|
||||
Column2: uids,
|
||||
Column2: uids,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"history-api/internal/models"
|
||||
"history-api/internal/repositories"
|
||||
"history-api/pkg/ai"
|
||||
"history-api/pkg/config"
|
||||
"history-api/pkg/constants"
|
||||
"history-api/pkg/convert"
|
||||
"strings"
|
||||
@@ -52,25 +53,49 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
|
||||
return "", fmt.Errorf("invalid user id: %w", err)
|
||||
}
|
||||
|
||||
qVector, err := s.ragUtils.EmbedQuery(ctx, question)
|
||||
searchQuery := strings.TrimSpace(question)
|
||||
rewriteHistory := s.getRewriteHistory(ctx, pgUserID)
|
||||
rewrittenQuery, err := s.ragUtils.RewriteSearchQuery(ctx, question, rewriteHistory)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("failed to rewrite RAG search query")
|
||||
} else if strings.TrimSpace(rewrittenQuery) != "" {
|
||||
searchQuery = strings.TrimSpace(rewrittenQuery)
|
||||
}
|
||||
|
||||
candidateLimit := config.GetIntConfigWithDefault("RAG_RETRIEVAL_CANDIDATES", 30)
|
||||
if candidateLimit < 8 {
|
||||
candidateLimit = 8
|
||||
}
|
||||
|
||||
contextLimit := config.GetIntConfigWithDefault("RAG_CONTEXT_TOP_N", 8)
|
||||
if contextLimit <= 0 {
|
||||
contextLimit = 8
|
||||
}
|
||||
if contextLimit > candidateLimit {
|
||||
contextLimit = candidateLimit
|
||||
}
|
||||
|
||||
qVector, err := s.ragUtils.EmbedQuery(ctx, searchQuery)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to embed question: %w", err)
|
||||
}
|
||||
|
||||
results, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.50)
|
||||
results, err := s.repo.SearchSimilar(ctx, projectID, qVector, candidateLimit, 0.45)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to search similar content: %w", err)
|
||||
}
|
||||
|
||||
if len(results) < 3 {
|
||||
broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.35)
|
||||
if len(results) < contextLimit {
|
||||
broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, candidateLimit, 0.30)
|
||||
if err == nil && len(broadResults) > len(results) {
|
||||
results = broadResults
|
||||
}
|
||||
}
|
||||
|
||||
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
|
||||
|
||||
var contextBuilder strings.Builder
|
||||
contextBuilder.Grow(len(results) * 96)
|
||||
contextBuilder.Grow(len(results) * 128)
|
||||
for i, res := range results {
|
||||
contextBuilder.WriteString(fmt.Sprintf("<doc id=\"%d\" score=\"%.2f\">\n%s\n</doc>\n\n", i+1, res.Similarity, res.Content))
|
||||
}
|
||||
@@ -115,6 +140,8 @@ Rules:
|
||||
- Do not cite documents unless the user asks.
|
||||
- Your final response MUST be wrapped inside <answer> tags.
|
||||
- Do not output anything outside <answer> tags.
|
||||
- Before writing the final answer, silently verify that each factual sentence is supported by Context.
|
||||
- Keep the answer focused on the user's question. Do not add background information that is not needed.
|
||||
- Answer in complete, natural, grammatically correct sentences.`, contextStr, question)
|
||||
}
|
||||
|
||||
@@ -141,6 +168,99 @@ Rules:
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *chatbotService) getRewriteHistory(ctx context.Context, userID pgtype.UUID) []ai.ChatTurn {
|
||||
limit := config.GetIntConfigWithDefault("RAG_REWRITE_HISTORY_TURNS", 3)
|
||||
if limit <= 0 {
|
||||
return nil
|
||||
}
|
||||
if limit > 5 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
history, err := s.chatRepo.GetChatbotHistory(ctx, sqlc.GetChatbotHistoryParams{
|
||||
UserID: userID,
|
||||
Limit: int32(limit),
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("failed to load chatbot history for RAG query rewrite")
|
||||
return nil
|
||||
}
|
||||
|
||||
turns := make([]ai.ChatTurn, 0, len(history))
|
||||
for _, item := range history {
|
||||
if item == nil {
|
||||
continue
|
||||
}
|
||||
turns = append(turns, ai.ChatTurn{
|
||||
Question: item.Question,
|
||||
Answer: item.Answer,
|
||||
})
|
||||
}
|
||||
|
||||
return turns
|
||||
}
|
||||
|
||||
func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk {
|
||||
if len(results) == 0 {
|
||||
return results
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = len(results)
|
||||
}
|
||||
if limit > len(results) {
|
||||
limit = len(results)
|
||||
}
|
||||
|
||||
documents := make([]string, len(results))
|
||||
for i, result := range results {
|
||||
documents[i] = result.Content
|
||||
}
|
||||
|
||||
ranked, err := s.ragUtils.RerankDocuments(ctx, query, documents, limit)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("failed to rerank RAG results")
|
||||
return limitRagResults(results, limit)
|
||||
}
|
||||
if len(ranked) == 0 {
|
||||
return limitRagResults(results, limit)
|
||||
}
|
||||
|
||||
selected := make([]*models.RagChunk, 0, limit)
|
||||
seen := make(map[int]struct{}, limit)
|
||||
for _, item := range ranked {
|
||||
if item.Index < 0 || item.Index >= len(results) {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[item.Index]; exists {
|
||||
continue
|
||||
}
|
||||
seen[item.Index] = struct{}{}
|
||||
selected = append(selected, results[item.Index])
|
||||
if len(selected) >= limit {
|
||||
return selected
|
||||
}
|
||||
}
|
||||
|
||||
for i, result := range results {
|
||||
if _, exists := seen[i]; exists {
|
||||
continue
|
||||
}
|
||||
selected = append(selected, result)
|
||||
if len(selected) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
func limitRagResults(results []*models.RagChunk, limit int) []*models.RagChunk {
|
||||
if limit <= 0 || len(results) <= limit {
|
||||
return results
|
||||
}
|
||||
return results[:limit]
|
||||
}
|
||||
|
||||
func normalizeAnswer(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
@@ -157,6 +277,7 @@ func normalizeAnswer(s string) string {
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *chatbotService) GetHistory(ctx context.Context, userID string, dto *request.GetChatbotHistoryDto) ([]*models.ChatbotHistoryEntity, error) {
|
||||
pgUserID, err := convert.StringToUUID(userID)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user