diff --git a/README.md b/README.md index 2140801..0140351 100644 --- a/README.md +++ b/README.md @@ -144,8 +144,10 @@ history-api/ RAG_QUERY_REWRITE_MAX_TOKENS=96 RAG_REWRITE_HISTORY_TURNS=3 RAG_RETRIEVAL_CANDIDATES=30 - RAG_CONTEXT_TOP_N=5 - RAG_CONTEXT_MAX_CHARS=8000 + RAG_LEXICAL_SEARCH_ENABLED=true + RAG_LEXICAL_CANDIDATES=20 + RAG_CONTEXT_TOP_N=6 + RAG_CONTEXT_MAX_CHARS=12000 RAG_GENERATION_MAX_RETRIES=1 RAG_GENERATION_RETRY_DELAY_MS=500 RAG_RERANK_ENABLED=true diff --git a/db/migrations/0000018_rag_chunks_content_trgm.down.sql b/db/migrations/0000018_rag_chunks_content_trgm.down.sql new file mode 100644 index 0000000..91c2e57 --- /dev/null +++ b/db/migrations/0000018_rag_chunks_content_trgm.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS idx_rag_chunks_content_trgm; diff --git a/db/migrations/0000018_rag_chunks_content_trgm.up.sql b/db/migrations/0000018_rag_chunks_content_trgm.up.sql new file mode 100644 index 0000000..6952397 --- /dev/null +++ b/db/migrations/0000018_rag_chunks_content_trgm.up.sql @@ -0,0 +1,4 @@ +CREATE EXTENSION IF NOT EXISTS pg_trgm; + +CREATE INDEX IF NOT EXISTS idx_rag_chunks_content_trgm +ON rag_chunks USING GIN (content gin_trgm_ops); diff --git a/db/query/rag.sql b/db/query/rag.sql index 016745a..1eb9675 100644 --- a/db/query/rag.sql +++ b/db/query/rag.sql @@ -19,6 +19,42 @@ WHERE 1=1 ORDER BY embedding <=> sqlc.arg('embedding') LIMIT sqlc.arg('match_count'); +-- name: SearchRagChunksLexical :many +WITH terms AS ( + SELECT DISTINCT trim(term) AS term + FROM unnest(sqlc.arg('terms')::text[]) AS term + WHERE trim(term) <> '' +), +matched AS ( + SELECT + r.id, + r.source_type, + r.source_id, + r.project_id, + r.chunk_index, + r.content, + r.created_at, + r.updated_at, + COUNT(*)::float8 AS similarity + FROM rag_chunks r + JOIN terms t ON r.content ILIKE '%' || t.term || '%' + WHERE (sqlc.narg('project_id')::uuid IS NULL OR r.project_id = sqlc.narg('project_id')::uuid) + GROUP BY + r.id, + r.source_type, + r.source_id, + r.project_id, + r.chunk_index, + r.content, + r.created_at, + r.updated_at +) +SELECT + id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at, similarity +FROM matched +ORDER BY similarity DESC, chunk_index ASC +LIMIT sqlc.arg('match_count'); + -- name: DeleteRagChunksBySourceIDs :exec DELETE FROM rag_chunks WHERE source_type = $1 AND source_id = ANY($2::uuid[]); diff --git a/internal/gen/sqlc/rag.sql.go b/internal/gen/sqlc/rag.sql.go index 05b1ebd..2cbba29 100644 --- a/internal/gen/sqlc/rag.sql.go +++ b/internal/gen/sqlc/rag.sql.go @@ -140,3 +140,88 @@ func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams } return items, nil } + +const searchRagChunksLexical = `-- name: SearchRagChunksLexical :many +WITH terms AS ( + SELECT DISTINCT trim(term) AS term + FROM unnest($2::text[]) AS term + WHERE trim(term) <> '' +), +matched AS ( + SELECT + r.id, + r.source_type, + r.source_id, + r.project_id, + r.chunk_index, + r.content, + r.created_at, + r.updated_at, + COUNT(*)::float8 AS similarity + FROM rag_chunks r + JOIN terms t ON r.content ILIKE '%' || t.term || '%' + WHERE ($3::uuid IS NULL OR r.project_id = $3::uuid) + GROUP BY + r.id, + r.source_type, + r.source_id, + r.project_id, + r.chunk_index, + r.content, + r.created_at, + r.updated_at +) +SELECT + id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at, similarity +FROM matched +ORDER BY similarity DESC, chunk_index ASC +LIMIT $1 +` + +type SearchRagChunksLexicalParams struct { + MatchCount int32 `json:"match_count"` + Terms []string `json:"terms"` + ProjectID pgtype.UUID `json:"project_id"` +} + +type SearchRagChunksLexicalRow struct { + ID pgtype.UUID `json:"id"` + SourceType string `json:"source_type"` + SourceID pgtype.UUID `json:"source_id"` + ProjectID pgtype.UUID `json:"project_id"` + ChunkIndex int32 `json:"chunk_index"` + Content string `json:"content"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` + Similarity float64 `json:"similarity"` +} + +func (q *Queries) SearchRagChunksLexical(ctx context.Context, arg SearchRagChunksLexicalParams) ([]SearchRagChunksLexicalRow, error) { + rows, err := q.db.Query(ctx, searchRagChunksLexical, arg.MatchCount, arg.Terms, arg.ProjectID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []SearchRagChunksLexicalRow{} + for rows.Next() { + var i SearchRagChunksLexicalRow + if err := rows.Scan( + &i.ID, + &i.SourceType, + &i.SourceID, + &i.ProjectID, + &i.ChunkIndex, + &i.Content, + &i.CreatedAt, + &i.UpdatedAt, + &i.Similarity, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/repositories/ragRepository.go b/internal/repositories/ragRepository.go index c649f2e..03e80ac 100644 --- a/internal/repositories/ragRepository.go +++ b/internal/repositories/ragRepository.go @@ -15,6 +15,7 @@ import ( type RagRepository interface { SaveChunk(ctx context.Context, sourceType string, sourceID string, projectID string, index int, content string, vector []float32) error SearchSimilar(ctx context.Context, projectID *string, vector []float32, limit int, threshold float64) ([]*models.RagChunk, error) + SearchLexical(ctx context.Context, projectID *string, terms []string, limit int) ([]*models.RagChunk, error) DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error WithTx(tx pgx.Tx) RagRepository } @@ -80,6 +81,38 @@ func (r *ragRepository) SearchSimilar(ctx context.Context, projectID *string, ve return res, nil } +func (r *ragRepository) SearchLexical(ctx context.Context, projectID *string, terms []string, limit int) ([]*models.RagChunk, error) { + params := sqlc.SearchRagChunksLexicalParams{ + Terms: terms, + MatchCount: int32(limit), + } + if projectID != nil && *projectID != "" { + pID, _ := convert.StringToUUID(*projectID) + params.ProjectID = pID + } + + rows, err := r.q.SearchRagChunksLexical(ctx, params) + if err != nil { + return nil, err + } + + res := make([]*models.RagChunk, len(rows)) + for i, row := range rows { + res[i] = &models.RagChunk{ + ID: convert.UUIDToString(row.ID), + SourceType: row.SourceType, + SourceID: convert.UUIDToString(row.SourceID), + ProjectID: convert.UUIDToString(row.ProjectID), + ChunkIndex: row.ChunkIndex, + Content: row.Content, + Similarity: row.Similarity, + CreatedAt: row.CreatedAt.Time, + UpdatedAt: row.UpdatedAt.Time, + } + } + return res, nil +} + func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error { if len(sourceIDs) == 0 { return nil diff --git a/internal/services/chatbotService.go b/internal/services/chatbotService.go index 1e29d6e..5c06ee5 100644 --- a/internal/services/chatbotService.go +++ b/internal/services/chatbotService.go @@ -11,6 +11,7 @@ import ( "history-api/pkg/config" "history-api/pkg/constants" "history-api/pkg/convert" + "regexp" "strings" "time" @@ -30,6 +31,8 @@ type chatbotService struct { ragUtils *ai.RagUtils } +var lexicalTokenRegex = regexp.MustCompile(`[\p{L}\p{N}]+`) + func NewChatbotService(repo repositories.RagRepository, usageRepo repositories.UsageRepository, chatRepo repositories.ChatRepository, ragUtils *ai.RagUtils) ChatbotService { return &chatbotService{ repo: repo, @@ -173,6 +176,37 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str broadResultCount = len(broadResults) } + var lexicalSearchDuration time.Duration + lexicalResultCount := 0 + lexicalTerms := extractLexicalSearchTerms(searchQuery) + if config.GetBoolConfigWithDefault("RAG_LEXICAL_SEARCH_ENABLED", true) && len(lexicalTerms) > 0 { + lexicalLimit := config.GetIntConfigWithDefault("RAG_LEXICAL_CANDIDATES", 20) + if lexicalLimit < 0 { + lexicalLimit = 0 + } + if lexicalLimit > candidateLimit { + lexicalLimit = candidateLimit + } + if lexicalLimit > 0 { + lexicalSearchStart := time.Now() + lexicalResults, err := s.repo.SearchLexical(ctx, projectID, lexicalTerms, lexicalLimit) + lexicalSearchDuration = time.Since(lexicalSearchStart) + if err == nil { + lexicalResultCount = len(lexicalResults) + results = mergeRagResults(results, lexicalResults, candidateLimit+lexicalLimit) + } else { + log.Warn(). + Err(err). + Str("userID", userID). + Str("projectID", projectIDLog). + Int("lexical_terms", len(lexicalTerms)). + Int("lexical_limit", lexicalLimit). + Dur("lexical_search_duration", lexicalSearchDuration). + Msg("rag lexical search failed") + } + } + } + rerankStart := time.Now() results = s.rerankResults(ctx, searchQuery, results, contextLimit) rerankDuration := time.Since(rerankStart) @@ -276,6 +310,8 @@ Rules: Int("context_limit", contextLimit). Int("initial_results", initialResultCount). Int("broad_results", broadResultCount). + Int("lexical_terms", len(lexicalTerms)). + Int("lexical_results", lexicalResultCount). Int("final_results", len(results)). Int("context_chars", len(contextStr)). Int("prompt_chars", len(prompt)). @@ -287,6 +323,7 @@ Rules: Dur("embed_duration", embedDuration). Dur("vector_search_duration", searchDuration). Dur("broad_search_duration", broadSearchDuration). + Dur("lexical_search_duration", lexicalSearchDuration). Dur("rerank_duration", rerankDuration). Dur("prompt_build_duration", promptBuildDuration). Dur("generate_duration", generateDuration). @@ -330,6 +367,90 @@ func (s *chatbotService) getRewriteHistory(ctx context.Context, userID pgtype.UU return turns } +func extractLexicalSearchTerms(query string) []string { + tokens := lexicalTokenRegex.FindAllString(strings.ToLower(query), -1) + if len(tokens) == 0 { + return nil + } + + terms := make([]string, 0, len(tokens)) + seen := make(map[string]struct{}, len(tokens)) + addTerm := func(term string) { + term = strings.TrimSpace(term) + if term == "" { + return + } + if _, exists := seen[term]; exists { + return + } + seen[term] = struct{}{} + terms = append(terms, term) + } + + for _, token := range tokens { + if hasASCIIDigit(token) || len([]rune(token)) >= 4 { + addTerm(token) + } + } + + for i := 0; i < len(tokens)-1; i++ { + left := tokens[i] + right := tokens[i+1] + if hasASCIIDigit(left) || hasASCIIDigit(right) || len([]rune(left))+len([]rune(right)) >= 5 { + addTerm(left + " " + right) + } + } + + if len(terms) > 12 { + return terms[:12] + } + + return terms +} + +func hasASCIIDigit(s string) bool { + for _, r := range s { + if r >= '0' && r <= '9' { + return true + } + } + return false +} + +func mergeRagResults(primary, extra []*models.RagChunk, limit int) []*models.RagChunk { + if len(primary) == 0 { + return limitRagResults(extra, limit) + } + if len(extra) == 0 { + return limitRagResults(primary, limit) + } + if limit <= 0 { + limit = len(primary) + len(extra) + } + + merged := make([]*models.RagChunk, 0, min(limit, len(primary)+len(extra))) + seen := make(map[string]struct{}, len(primary)+len(extra)) + appendChunk := func(chunk *models.RagChunk) { + if chunk == nil || len(merged) >= limit { + return + } + if _, exists := seen[chunk.ID]; exists { + return + } + seen[chunk.ID] = struct{}{} + merged = append(merged, chunk) + } + + for _, chunk := range primary { + appendChunk(chunk) + } + for _, chunk := range extra { + appendChunk(chunk) + } + + return merged +} + func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk { if len(results) == 0 { return results