feat: implement lexical search functionality in RAG, including new SQL queries and service methods
Build and Release / release (push) Failing after 1m34s

This commit is contained in:
2026-06-08 14:00:20 +07:00
parent a77b856973
commit 53e1e4b8ea
7 changed files with 284 additions and 2 deletions
+4 -2
View File
@@ -144,8 +144,10 @@ history-api/
RAG_QUERY_REWRITE_MAX_TOKENS=96 RAG_QUERY_REWRITE_MAX_TOKENS=96
RAG_REWRITE_HISTORY_TURNS=3 RAG_REWRITE_HISTORY_TURNS=3
RAG_RETRIEVAL_CANDIDATES=30 RAG_RETRIEVAL_CANDIDATES=30
RAG_CONTEXT_TOP_N=5 RAG_LEXICAL_SEARCH_ENABLED=true
RAG_CONTEXT_MAX_CHARS=8000 RAG_LEXICAL_CANDIDATES=20
RAG_CONTEXT_TOP_N=6
RAG_CONTEXT_MAX_CHARS=12000
RAG_GENERATION_MAX_RETRIES=1 RAG_GENERATION_MAX_RETRIES=1
RAG_GENERATION_RETRY_DELAY_MS=500 RAG_GENERATION_RETRY_DELAY_MS=500
RAG_RERANK_ENABLED=true RAG_RERANK_ENABLED=true
@@ -0,0 +1 @@
DROP INDEX IF EXISTS idx_rag_chunks_content_trgm;
@@ -0,0 +1,4 @@
CREATE EXTENSION IF NOT EXISTS pg_trgm;
CREATE INDEX IF NOT EXISTS idx_rag_chunks_content_trgm
ON rag_chunks USING GIN (content gin_trgm_ops);
+36
View File
@@ -19,6 +19,42 @@ WHERE 1=1
ORDER BY embedding <=> sqlc.arg('embedding') ORDER BY embedding <=> sqlc.arg('embedding')
LIMIT sqlc.arg('match_count'); LIMIT sqlc.arg('match_count');
-- name: SearchRagChunksLexical :many
WITH terms AS (
SELECT DISTINCT trim(term) AS term
FROM unnest(sqlc.arg('terms')::text[]) AS term
WHERE trim(term) <> ''
),
matched AS (
SELECT
r.id,
r.source_type,
r.source_id,
r.project_id,
r.chunk_index,
r.content,
r.created_at,
r.updated_at,
COUNT(*)::float8 AS similarity
FROM rag_chunks r
JOIN terms t ON r.content ILIKE '%' || t.term || '%'
WHERE (sqlc.narg('project_id')::uuid IS NULL OR r.project_id = sqlc.narg('project_id')::uuid)
GROUP BY
r.id,
r.source_type,
r.source_id,
r.project_id,
r.chunk_index,
r.content,
r.created_at,
r.updated_at
)
SELECT
id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at, similarity
FROM matched
ORDER BY similarity DESC, chunk_index ASC
LIMIT sqlc.arg('match_count');
-- name: DeleteRagChunksBySourceIDs :exec -- name: DeleteRagChunksBySourceIDs :exec
DELETE FROM rag_chunks DELETE FROM rag_chunks
WHERE source_type = $1 AND source_id = ANY($2::uuid[]); WHERE source_type = $1 AND source_id = ANY($2::uuid[]);
+85
View File
@@ -140,3 +140,88 @@ func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams
} }
return items, nil return items, nil
} }
const searchRagChunksLexical = `-- name: SearchRagChunksLexical :many
WITH terms AS (
SELECT DISTINCT trim(term) AS term
FROM unnest($2::text[]) AS term
WHERE trim(term) <> ''
),
matched AS (
SELECT
r.id,
r.source_type,
r.source_id,
r.project_id,
r.chunk_index,
r.content,
r.created_at,
r.updated_at,
COUNT(*)::float8 AS similarity
FROM rag_chunks r
JOIN terms t ON r.content ILIKE '%' || t.term || '%'
WHERE ($3::uuid IS NULL OR r.project_id = $3::uuid)
GROUP BY
r.id,
r.source_type,
r.source_id,
r.project_id,
r.chunk_index,
r.content,
r.created_at,
r.updated_at
)
SELECT
id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at, similarity
FROM matched
ORDER BY similarity DESC, chunk_index ASC
LIMIT $1
`
type SearchRagChunksLexicalParams struct {
MatchCount int32 `json:"match_count"`
Terms []string `json:"terms"`
ProjectID pgtype.UUID `json:"project_id"`
}
type SearchRagChunksLexicalRow struct {
ID pgtype.UUID `json:"id"`
SourceType string `json:"source_type"`
SourceID pgtype.UUID `json:"source_id"`
ProjectID pgtype.UUID `json:"project_id"`
ChunkIndex int32 `json:"chunk_index"`
Content string `json:"content"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
Similarity float64 `json:"similarity"`
}
func (q *Queries) SearchRagChunksLexical(ctx context.Context, arg SearchRagChunksLexicalParams) ([]SearchRagChunksLexicalRow, error) {
rows, err := q.db.Query(ctx, searchRagChunksLexical, arg.MatchCount, arg.Terms, arg.ProjectID)
if err != nil {
return nil, err
}
defer rows.Close()
items := []SearchRagChunksLexicalRow{}
for rows.Next() {
var i SearchRagChunksLexicalRow
if err := rows.Scan(
&i.ID,
&i.SourceType,
&i.SourceID,
&i.ProjectID,
&i.ChunkIndex,
&i.Content,
&i.CreatedAt,
&i.UpdatedAt,
&i.Similarity,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
+33
View File
@@ -15,6 +15,7 @@ import (
type RagRepository interface { type RagRepository interface {
SaveChunk(ctx context.Context, sourceType string, sourceID string, projectID string, index int, content string, vector []float32) error SaveChunk(ctx context.Context, sourceType string, sourceID string, projectID string, index int, content string, vector []float32) error
SearchSimilar(ctx context.Context, projectID *string, vector []float32, limit int, threshold float64) ([]*models.RagChunk, error) SearchSimilar(ctx context.Context, projectID *string, vector []float32, limit int, threshold float64) ([]*models.RagChunk, error)
SearchLexical(ctx context.Context, projectID *string, terms []string, limit int) ([]*models.RagChunk, error)
DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error
WithTx(tx pgx.Tx) RagRepository WithTx(tx pgx.Tx) RagRepository
} }
@@ -80,6 +81,38 @@ func (r *ragRepository) SearchSimilar(ctx context.Context, projectID *string, ve
return res, nil return res, nil
} }
func (r *ragRepository) SearchLexical(ctx context.Context, projectID *string, terms []string, limit int) ([]*models.RagChunk, error) {
params := sqlc.SearchRagChunksLexicalParams{
Terms: terms,
MatchCount: int32(limit),
}
if projectID != nil && *projectID != "" {
pID, _ := convert.StringToUUID(*projectID)
params.ProjectID = pID
}
rows, err := r.q.SearchRagChunksLexical(ctx, params)
if err != nil {
return nil, err
}
res := make([]*models.RagChunk, len(rows))
for i, row := range rows {
res[i] = &models.RagChunk{
ID: convert.UUIDToString(row.ID),
SourceType: row.SourceType,
SourceID: convert.UUIDToString(row.SourceID),
ProjectID: convert.UUIDToString(row.ProjectID),
ChunkIndex: row.ChunkIndex,
Content: row.Content,
Similarity: row.Similarity,
CreatedAt: row.CreatedAt.Time,
UpdatedAt: row.UpdatedAt.Time,
}
}
return res, nil
}
func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error { func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error {
if len(sourceIDs) == 0 { if len(sourceIDs) == 0 {
return nil return nil
+121
View File
@@ -11,6 +11,7 @@ import (
"history-api/pkg/config" "history-api/pkg/config"
"history-api/pkg/constants" "history-api/pkg/constants"
"history-api/pkg/convert" "history-api/pkg/convert"
"regexp"
"strings" "strings"
"time" "time"
@@ -30,6 +31,8 @@ type chatbotService struct {
ragUtils *ai.RagUtils ragUtils *ai.RagUtils
} }
var lexicalTokenRegex = regexp.MustCompile(`[\p{L}\p{N}]+`)
func NewChatbotService(repo repositories.RagRepository, usageRepo repositories.UsageRepository, chatRepo repositories.ChatRepository, ragUtils *ai.RagUtils) ChatbotService { func NewChatbotService(repo repositories.RagRepository, usageRepo repositories.UsageRepository, chatRepo repositories.ChatRepository, ragUtils *ai.RagUtils) ChatbotService {
return &chatbotService{ return &chatbotService{
repo: repo, repo: repo,
@@ -173,6 +176,37 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
broadResultCount = len(broadResults) broadResultCount = len(broadResults)
} }
var lexicalSearchDuration time.Duration
lexicalResultCount := 0
lexicalTerms := extractLexicalSearchTerms(searchQuery)
if config.GetBoolConfigWithDefault("RAG_LEXICAL_SEARCH_ENABLED", true) && len(lexicalTerms) > 0 {
lexicalLimit := config.GetIntConfigWithDefault("RAG_LEXICAL_CANDIDATES", 20)
if lexicalLimit < 0 {
lexicalLimit = 0
}
if lexicalLimit > candidateLimit {
lexicalLimit = candidateLimit
}
if lexicalLimit > 0 {
lexicalSearchStart := time.Now()
lexicalResults, err := s.repo.SearchLexical(ctx, projectID, lexicalTerms, lexicalLimit)
lexicalSearchDuration = time.Since(lexicalSearchStart)
if err == nil {
lexicalResultCount = len(lexicalResults)
results = mergeRagResults(results, lexicalResults, candidateLimit+lexicalLimit)
} else {
log.Warn().
Err(err).
Str("userID", userID).
Str("projectID", projectIDLog).
Int("lexical_terms", len(lexicalTerms)).
Int("lexical_limit", lexicalLimit).
Dur("lexical_search_duration", lexicalSearchDuration).
Msg("rag lexical search failed")
}
}
}
rerankStart := time.Now() rerankStart := time.Now()
results = s.rerankResults(ctx, searchQuery, results, contextLimit) results = s.rerankResults(ctx, searchQuery, results, contextLimit)
rerankDuration := time.Since(rerankStart) rerankDuration := time.Since(rerankStart)
@@ -276,6 +310,8 @@ Rules:
Int("context_limit", contextLimit). Int("context_limit", contextLimit).
Int("initial_results", initialResultCount). Int("initial_results", initialResultCount).
Int("broad_results", broadResultCount). Int("broad_results", broadResultCount).
Int("lexical_terms", len(lexicalTerms)).
Int("lexical_results", lexicalResultCount).
Int("final_results", len(results)). Int("final_results", len(results)).
Int("context_chars", len(contextStr)). Int("context_chars", len(contextStr)).
Int("prompt_chars", len(prompt)). Int("prompt_chars", len(prompt)).
@@ -287,6 +323,7 @@ Rules:
Dur("embed_duration", embedDuration). Dur("embed_duration", embedDuration).
Dur("vector_search_duration", searchDuration). Dur("vector_search_duration", searchDuration).
Dur("broad_search_duration", broadSearchDuration). Dur("broad_search_duration", broadSearchDuration).
Dur("lexical_search_duration", lexicalSearchDuration).
Dur("rerank_duration", rerankDuration). Dur("rerank_duration", rerankDuration).
Dur("prompt_build_duration", promptBuildDuration). Dur("prompt_build_duration", promptBuildDuration).
Dur("generate_duration", generateDuration). Dur("generate_duration", generateDuration).
@@ -330,6 +367,90 @@ func (s *chatbotService) getRewriteHistory(ctx context.Context, userID pgtype.UU
return turns return turns
} }
func extractLexicalSearchTerms(query string) []string {
tokens := lexicalTokenRegex.FindAllString(strings.ToLower(query), -1)
if len(tokens) == 0 {
return nil
}
terms := make([]string, 0, len(tokens))
seen := make(map[string]struct{}, len(tokens))
addTerm := func(term string) {
term = strings.TrimSpace(term)
if term == "" {
return
}
if _, exists := seen[term]; exists {
return
}
seen[term] = struct{}{}
terms = append(terms, term)
}
for _, token := range tokens {
if hasASCIIDigit(token) || len([]rune(token)) >= 4 {
addTerm(token)
}
}
for i := 0; i < len(tokens)-1; i++ {
left := tokens[i]
right := tokens[i+1]
if hasASCIIDigit(left) || hasASCIIDigit(right) || len([]rune(left))+len([]rune(right)) >= 5 {
addTerm(left + " " + right)
}
}
if len(terms) > 12 {
return terms[:12]
}
return terms
}
func hasASCIIDigit(s string) bool {
for _, r := range s {
if r >= '0' && r <= '9' {
return true
}
}
return false
}
func mergeRagResults(primary, extra []*models.RagChunk, limit int) []*models.RagChunk {
if len(primary) == 0 {
return limitRagResults(extra, limit)
}
if len(extra) == 0 {
return limitRagResults(primary, limit)
}
if limit <= 0 {
limit = len(primary) + len(extra)
}
merged := make([]*models.RagChunk, 0, min(limit, len(primary)+len(extra)))
seen := make(map[string]struct{}, len(primary)+len(extra))
appendChunk := func(chunk *models.RagChunk) {
if chunk == nil || len(merged) >= limit {
return
}
if _, exists := seen[chunk.ID]; exists {
return
}
seen[chunk.ID] = struct{}{}
merged = append(merged, chunk)
}
for _, chunk := range primary {
appendChunk(chunk)
}
for _, chunk := range extra {
appendChunk(chunk)
}
return merged
}
func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk { func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk {
if len(results) == 0 { if len(results) == 0 {
return results return results