feat: implement lexical search functionality in RAG, including new SQL queries and service methods
Build and Release / release (push) Failing after 1m34s
Build and Release / release (push) Failing after 1m34s
This commit is contained in:
@@ -144,8 +144,10 @@ history-api/
|
||||
RAG_QUERY_REWRITE_MAX_TOKENS=96
|
||||
RAG_REWRITE_HISTORY_TURNS=3
|
||||
RAG_RETRIEVAL_CANDIDATES=30
|
||||
RAG_CONTEXT_TOP_N=5
|
||||
RAG_CONTEXT_MAX_CHARS=8000
|
||||
RAG_LEXICAL_SEARCH_ENABLED=true
|
||||
RAG_LEXICAL_CANDIDATES=20
|
||||
RAG_CONTEXT_TOP_N=6
|
||||
RAG_CONTEXT_MAX_CHARS=12000
|
||||
RAG_GENERATION_MAX_RETRIES=1
|
||||
RAG_GENERATION_RETRY_DELAY_MS=500
|
||||
RAG_RERANK_ENABLED=true
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS idx_rag_chunks_content_trgm;
|
||||
@@ -0,0 +1,4 @@
|
||||
CREATE EXTENSION IF NOT EXISTS pg_trgm;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_rag_chunks_content_trgm
|
||||
ON rag_chunks USING GIN (content gin_trgm_ops);
|
||||
@@ -19,6 +19,42 @@ WHERE 1=1
|
||||
ORDER BY embedding <=> sqlc.arg('embedding')
|
||||
LIMIT sqlc.arg('match_count');
|
||||
|
||||
-- name: SearchRagChunksLexical :many
|
||||
WITH terms AS (
|
||||
SELECT DISTINCT trim(term) AS term
|
||||
FROM unnest(sqlc.arg('terms')::text[]) AS term
|
||||
WHERE trim(term) <> ''
|
||||
),
|
||||
matched AS (
|
||||
SELECT
|
||||
r.id,
|
||||
r.source_type,
|
||||
r.source_id,
|
||||
r.project_id,
|
||||
r.chunk_index,
|
||||
r.content,
|
||||
r.created_at,
|
||||
r.updated_at,
|
||||
COUNT(*)::float8 AS similarity
|
||||
FROM rag_chunks r
|
||||
JOIN terms t ON r.content ILIKE '%' || t.term || '%'
|
||||
WHERE (sqlc.narg('project_id')::uuid IS NULL OR r.project_id = sqlc.narg('project_id')::uuid)
|
||||
GROUP BY
|
||||
r.id,
|
||||
r.source_type,
|
||||
r.source_id,
|
||||
r.project_id,
|
||||
r.chunk_index,
|
||||
r.content,
|
||||
r.created_at,
|
||||
r.updated_at
|
||||
)
|
||||
SELECT
|
||||
id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at, similarity
|
||||
FROM matched
|
||||
ORDER BY similarity DESC, chunk_index ASC
|
||||
LIMIT sqlc.arg('match_count');
|
||||
|
||||
-- name: DeleteRagChunksBySourceIDs :exec
|
||||
DELETE FROM rag_chunks
|
||||
WHERE source_type = $1 AND source_id = ANY($2::uuid[]);
|
||||
|
||||
@@ -140,3 +140,88 @@ func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const searchRagChunksLexical = `-- name: SearchRagChunksLexical :many
|
||||
WITH terms AS (
|
||||
SELECT DISTINCT trim(term) AS term
|
||||
FROM unnest($2::text[]) AS term
|
||||
WHERE trim(term) <> ''
|
||||
),
|
||||
matched AS (
|
||||
SELECT
|
||||
r.id,
|
||||
r.source_type,
|
||||
r.source_id,
|
||||
r.project_id,
|
||||
r.chunk_index,
|
||||
r.content,
|
||||
r.created_at,
|
||||
r.updated_at,
|
||||
COUNT(*)::float8 AS similarity
|
||||
FROM rag_chunks r
|
||||
JOIN terms t ON r.content ILIKE '%' || t.term || '%'
|
||||
WHERE ($3::uuid IS NULL OR r.project_id = $3::uuid)
|
||||
GROUP BY
|
||||
r.id,
|
||||
r.source_type,
|
||||
r.source_id,
|
||||
r.project_id,
|
||||
r.chunk_index,
|
||||
r.content,
|
||||
r.created_at,
|
||||
r.updated_at
|
||||
)
|
||||
SELECT
|
||||
id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at, similarity
|
||||
FROM matched
|
||||
ORDER BY similarity DESC, chunk_index ASC
|
||||
LIMIT $1
|
||||
`
|
||||
|
||||
type SearchRagChunksLexicalParams struct {
|
||||
MatchCount int32 `json:"match_count"`
|
||||
Terms []string `json:"terms"`
|
||||
ProjectID pgtype.UUID `json:"project_id"`
|
||||
}
|
||||
|
||||
type SearchRagChunksLexicalRow struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
SourceType string `json:"source_type"`
|
||||
SourceID pgtype.UUID `json:"source_id"`
|
||||
ProjectID pgtype.UUID `json:"project_id"`
|
||||
ChunkIndex int32 `json:"chunk_index"`
|
||||
Content string `json:"content"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
Similarity float64 `json:"similarity"`
|
||||
}
|
||||
|
||||
func (q *Queries) SearchRagChunksLexical(ctx context.Context, arg SearchRagChunksLexicalParams) ([]SearchRagChunksLexicalRow, error) {
|
||||
rows, err := q.db.Query(ctx, searchRagChunksLexical, arg.MatchCount, arg.Terms, arg.ProjectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []SearchRagChunksLexicalRow{}
|
||||
for rows.Next() {
|
||||
var i SearchRagChunksLexicalRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SourceType,
|
||||
&i.SourceID,
|
||||
&i.ProjectID,
|
||||
&i.ChunkIndex,
|
||||
&i.Content,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Similarity,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
type RagRepository interface {
|
||||
SaveChunk(ctx context.Context, sourceType string, sourceID string, projectID string, index int, content string, vector []float32) error
|
||||
SearchSimilar(ctx context.Context, projectID *string, vector []float32, limit int, threshold float64) ([]*models.RagChunk, error)
|
||||
SearchLexical(ctx context.Context, projectID *string, terms []string, limit int) ([]*models.RagChunk, error)
|
||||
DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error
|
||||
WithTx(tx pgx.Tx) RagRepository
|
||||
}
|
||||
@@ -80,6 +81,38 @@ func (r *ragRepository) SearchSimilar(ctx context.Context, projectID *string, ve
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (r *ragRepository) SearchLexical(ctx context.Context, projectID *string, terms []string, limit int) ([]*models.RagChunk, error) {
|
||||
params := sqlc.SearchRagChunksLexicalParams{
|
||||
Terms: terms,
|
||||
MatchCount: int32(limit),
|
||||
}
|
||||
if projectID != nil && *projectID != "" {
|
||||
pID, _ := convert.StringToUUID(*projectID)
|
||||
params.ProjectID = pID
|
||||
}
|
||||
|
||||
rows, err := r.q.SearchRagChunksLexical(ctx, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := make([]*models.RagChunk, len(rows))
|
||||
for i, row := range rows {
|
||||
res[i] = &models.RagChunk{
|
||||
ID: convert.UUIDToString(row.ID),
|
||||
SourceType: row.SourceType,
|
||||
SourceID: convert.UUIDToString(row.SourceID),
|
||||
ProjectID: convert.UUIDToString(row.ProjectID),
|
||||
ChunkIndex: row.ChunkIndex,
|
||||
Content: row.Content,
|
||||
Similarity: row.Similarity,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
UpdatedAt: row.UpdatedAt.Time,
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error {
|
||||
if len(sourceIDs) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"history-api/pkg/config"
|
||||
"history-api/pkg/constants"
|
||||
"history-api/pkg/convert"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -30,6 +31,8 @@ type chatbotService struct {
|
||||
ragUtils *ai.RagUtils
|
||||
}
|
||||
|
||||
var lexicalTokenRegex = regexp.MustCompile(`[\p{L}\p{N}]+`)
|
||||
|
||||
func NewChatbotService(repo repositories.RagRepository, usageRepo repositories.UsageRepository, chatRepo repositories.ChatRepository, ragUtils *ai.RagUtils) ChatbotService {
|
||||
return &chatbotService{
|
||||
repo: repo,
|
||||
@@ -173,6 +176,37 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
|
||||
broadResultCount = len(broadResults)
|
||||
}
|
||||
|
||||
var lexicalSearchDuration time.Duration
|
||||
lexicalResultCount := 0
|
||||
lexicalTerms := extractLexicalSearchTerms(searchQuery)
|
||||
if config.GetBoolConfigWithDefault("RAG_LEXICAL_SEARCH_ENABLED", true) && len(lexicalTerms) > 0 {
|
||||
lexicalLimit := config.GetIntConfigWithDefault("RAG_LEXICAL_CANDIDATES", 20)
|
||||
if lexicalLimit < 0 {
|
||||
lexicalLimit = 0
|
||||
}
|
||||
if lexicalLimit > candidateLimit {
|
||||
lexicalLimit = candidateLimit
|
||||
}
|
||||
if lexicalLimit > 0 {
|
||||
lexicalSearchStart := time.Now()
|
||||
lexicalResults, err := s.repo.SearchLexical(ctx, projectID, lexicalTerms, lexicalLimit)
|
||||
lexicalSearchDuration = time.Since(lexicalSearchStart)
|
||||
if err == nil {
|
||||
lexicalResultCount = len(lexicalResults)
|
||||
results = mergeRagResults(results, lexicalResults, candidateLimit+lexicalLimit)
|
||||
} else {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Str("userID", userID).
|
||||
Str("projectID", projectIDLog).
|
||||
Int("lexical_terms", len(lexicalTerms)).
|
||||
Int("lexical_limit", lexicalLimit).
|
||||
Dur("lexical_search_duration", lexicalSearchDuration).
|
||||
Msg("rag lexical search failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rerankStart := time.Now()
|
||||
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
|
||||
rerankDuration := time.Since(rerankStart)
|
||||
@@ -276,6 +310,8 @@ Rules:
|
||||
Int("context_limit", contextLimit).
|
||||
Int("initial_results", initialResultCount).
|
||||
Int("broad_results", broadResultCount).
|
||||
Int("lexical_terms", len(lexicalTerms)).
|
||||
Int("lexical_results", lexicalResultCount).
|
||||
Int("final_results", len(results)).
|
||||
Int("context_chars", len(contextStr)).
|
||||
Int("prompt_chars", len(prompt)).
|
||||
@@ -287,6 +323,7 @@ Rules:
|
||||
Dur("embed_duration", embedDuration).
|
||||
Dur("vector_search_duration", searchDuration).
|
||||
Dur("broad_search_duration", broadSearchDuration).
|
||||
Dur("lexical_search_duration", lexicalSearchDuration).
|
||||
Dur("rerank_duration", rerankDuration).
|
||||
Dur("prompt_build_duration", promptBuildDuration).
|
||||
Dur("generate_duration", generateDuration).
|
||||
@@ -330,6 +367,90 @@ func (s *chatbotService) getRewriteHistory(ctx context.Context, userID pgtype.UU
|
||||
return turns
|
||||
}
|
||||
|
||||
func extractLexicalSearchTerms(query string) []string {
|
||||
tokens := lexicalTokenRegex.FindAllString(strings.ToLower(query), -1)
|
||||
if len(tokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
terms := make([]string, 0, len(tokens))
|
||||
seen := make(map[string]struct{}, len(tokens))
|
||||
addTerm := func(term string) {
|
||||
term = strings.TrimSpace(term)
|
||||
if term == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := seen[term]; exists {
|
||||
return
|
||||
}
|
||||
seen[term] = struct{}{}
|
||||
terms = append(terms, term)
|
||||
}
|
||||
|
||||
for _, token := range tokens {
|
||||
if hasASCIIDigit(token) || len([]rune(token)) >= 4 {
|
||||
addTerm(token)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(tokens)-1; i++ {
|
||||
left := tokens[i]
|
||||
right := tokens[i+1]
|
||||
if hasASCIIDigit(left) || hasASCIIDigit(right) || len([]rune(left))+len([]rune(right)) >= 5 {
|
||||
addTerm(left + " " + right)
|
||||
}
|
||||
}
|
||||
|
||||
if len(terms) > 12 {
|
||||
return terms[:12]
|
||||
}
|
||||
|
||||
return terms
|
||||
}
|
||||
|
||||
func hasASCIIDigit(s string) bool {
|
||||
for _, r := range s {
|
||||
if r >= '0' && r <= '9' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mergeRagResults(primary, extra []*models.RagChunk, limit int) []*models.RagChunk {
|
||||
if len(primary) == 0 {
|
||||
return limitRagResults(extra, limit)
|
||||
}
|
||||
if len(extra) == 0 {
|
||||
return limitRagResults(primary, limit)
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = len(primary) + len(extra)
|
||||
}
|
||||
|
||||
merged := make([]*models.RagChunk, 0, min(limit, len(primary)+len(extra)))
|
||||
seen := make(map[string]struct{}, len(primary)+len(extra))
|
||||
appendChunk := func(chunk *models.RagChunk) {
|
||||
if chunk == nil || len(merged) >= limit {
|
||||
return
|
||||
}
|
||||
if _, exists := seen[chunk.ID]; exists {
|
||||
return
|
||||
}
|
||||
seen[chunk.ID] = struct{}{}
|
||||
merged = append(merged, chunk)
|
||||
}
|
||||
|
||||
for _, chunk := range primary {
|
||||
appendChunk(chunk)
|
||||
}
|
||||
for _, chunk := range extra {
|
||||
appendChunk(chunk)
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk {
|
||||
if len(results) == 0 {
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user