feat: implement lexical search functionality in RAG, including new SQL queries and service methods
Build and Release / release (push) Failing after 1m34s
Build and Release / release (push) Failing after 1m34s
This commit is contained in:
@@ -144,8 +144,10 @@ history-api/
|
|||||||
RAG_QUERY_REWRITE_MAX_TOKENS=96
|
RAG_QUERY_REWRITE_MAX_TOKENS=96
|
||||||
RAG_REWRITE_HISTORY_TURNS=3
|
RAG_REWRITE_HISTORY_TURNS=3
|
||||||
RAG_RETRIEVAL_CANDIDATES=30
|
RAG_RETRIEVAL_CANDIDATES=30
|
||||||
RAG_CONTEXT_TOP_N=5
|
RAG_LEXICAL_SEARCH_ENABLED=true
|
||||||
RAG_CONTEXT_MAX_CHARS=8000
|
RAG_LEXICAL_CANDIDATES=20
|
||||||
|
RAG_CONTEXT_TOP_N=6
|
||||||
|
RAG_CONTEXT_MAX_CHARS=12000
|
||||||
RAG_GENERATION_MAX_RETRIES=1
|
RAG_GENERATION_MAX_RETRIES=1
|
||||||
RAG_GENERATION_RETRY_DELAY_MS=500
|
RAG_GENERATION_RETRY_DELAY_MS=500
|
||||||
RAG_RERANK_ENABLED=true
|
RAG_RERANK_ENABLED=true
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
DROP INDEX IF EXISTS idx_rag_chunks_content_trgm;
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
CREATE EXTENSION IF NOT EXISTS pg_trgm;
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_rag_chunks_content_trgm
|
||||||
|
ON rag_chunks USING GIN (content gin_trgm_ops);
|
||||||
@@ -19,6 +19,42 @@ WHERE 1=1
|
|||||||
ORDER BY embedding <=> sqlc.arg('embedding')
|
ORDER BY embedding <=> sqlc.arg('embedding')
|
||||||
LIMIT sqlc.arg('match_count');
|
LIMIT sqlc.arg('match_count');
|
||||||
|
|
||||||
|
-- name: SearchRagChunksLexical :many
|
||||||
|
WITH terms AS (
|
||||||
|
SELECT DISTINCT trim(term) AS term
|
||||||
|
FROM unnest(sqlc.arg('terms')::text[]) AS term
|
||||||
|
WHERE trim(term) <> ''
|
||||||
|
),
|
||||||
|
matched AS (
|
||||||
|
SELECT
|
||||||
|
r.id,
|
||||||
|
r.source_type,
|
||||||
|
r.source_id,
|
||||||
|
r.project_id,
|
||||||
|
r.chunk_index,
|
||||||
|
r.content,
|
||||||
|
r.created_at,
|
||||||
|
r.updated_at,
|
||||||
|
COUNT(*)::float8 AS similarity
|
||||||
|
FROM rag_chunks r
|
||||||
|
JOIN terms t ON r.content ILIKE '%' || t.term || '%'
|
||||||
|
WHERE (sqlc.narg('project_id')::uuid IS NULL OR r.project_id = sqlc.narg('project_id')::uuid)
|
||||||
|
GROUP BY
|
||||||
|
r.id,
|
||||||
|
r.source_type,
|
||||||
|
r.source_id,
|
||||||
|
r.project_id,
|
||||||
|
r.chunk_index,
|
||||||
|
r.content,
|
||||||
|
r.created_at,
|
||||||
|
r.updated_at
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at, similarity
|
||||||
|
FROM matched
|
||||||
|
ORDER BY similarity DESC, chunk_index ASC
|
||||||
|
LIMIT sqlc.arg('match_count');
|
||||||
|
|
||||||
-- name: DeleteRagChunksBySourceIDs :exec
|
-- name: DeleteRagChunksBySourceIDs :exec
|
||||||
DELETE FROM rag_chunks
|
DELETE FROM rag_chunks
|
||||||
WHERE source_type = $1 AND source_id = ANY($2::uuid[]);
|
WHERE source_type = $1 AND source_id = ANY($2::uuid[]);
|
||||||
|
|||||||
@@ -140,3 +140,88 @@ func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams
|
|||||||
}
|
}
|
||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const searchRagChunksLexical = `-- name: SearchRagChunksLexical :many
|
||||||
|
WITH terms AS (
|
||||||
|
SELECT DISTINCT trim(term) AS term
|
||||||
|
FROM unnest($2::text[]) AS term
|
||||||
|
WHERE trim(term) <> ''
|
||||||
|
),
|
||||||
|
matched AS (
|
||||||
|
SELECT
|
||||||
|
r.id,
|
||||||
|
r.source_type,
|
||||||
|
r.source_id,
|
||||||
|
r.project_id,
|
||||||
|
r.chunk_index,
|
||||||
|
r.content,
|
||||||
|
r.created_at,
|
||||||
|
r.updated_at,
|
||||||
|
COUNT(*)::float8 AS similarity
|
||||||
|
FROM rag_chunks r
|
||||||
|
JOIN terms t ON r.content ILIKE '%' || t.term || '%'
|
||||||
|
WHERE ($3::uuid IS NULL OR r.project_id = $3::uuid)
|
||||||
|
GROUP BY
|
||||||
|
r.id,
|
||||||
|
r.source_type,
|
||||||
|
r.source_id,
|
||||||
|
r.project_id,
|
||||||
|
r.chunk_index,
|
||||||
|
r.content,
|
||||||
|
r.created_at,
|
||||||
|
r.updated_at
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at, similarity
|
||||||
|
FROM matched
|
||||||
|
ORDER BY similarity DESC, chunk_index ASC
|
||||||
|
LIMIT $1
|
||||||
|
`
|
||||||
|
|
||||||
|
type SearchRagChunksLexicalParams struct {
|
||||||
|
MatchCount int32 `json:"match_count"`
|
||||||
|
Terms []string `json:"terms"`
|
||||||
|
ProjectID pgtype.UUID `json:"project_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SearchRagChunksLexicalRow struct {
|
||||||
|
ID pgtype.UUID `json:"id"`
|
||||||
|
SourceType string `json:"source_type"`
|
||||||
|
SourceID pgtype.UUID `json:"source_id"`
|
||||||
|
ProjectID pgtype.UUID `json:"project_id"`
|
||||||
|
ChunkIndex int32 `json:"chunk_index"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||||
|
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||||
|
Similarity float64 `json:"similarity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) SearchRagChunksLexical(ctx context.Context, arg SearchRagChunksLexicalParams) ([]SearchRagChunksLexicalRow, error) {
|
||||||
|
rows, err := q.db.Query(ctx, searchRagChunksLexical, arg.MatchCount, arg.Terms, arg.ProjectID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
items := []SearchRagChunksLexicalRow{}
|
||||||
|
for rows.Next() {
|
||||||
|
var i SearchRagChunksLexicalRow
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.SourceType,
|
||||||
|
&i.SourceID,
|
||||||
|
&i.ProjectID,
|
||||||
|
&i.ChunkIndex,
|
||||||
|
&i.Content,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
&i.Similarity,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
type RagRepository interface {
|
type RagRepository interface {
|
||||||
SaveChunk(ctx context.Context, sourceType string, sourceID string, projectID string, index int, content string, vector []float32) error
|
SaveChunk(ctx context.Context, sourceType string, sourceID string, projectID string, index int, content string, vector []float32) error
|
||||||
SearchSimilar(ctx context.Context, projectID *string, vector []float32, limit int, threshold float64) ([]*models.RagChunk, error)
|
SearchSimilar(ctx context.Context, projectID *string, vector []float32, limit int, threshold float64) ([]*models.RagChunk, error)
|
||||||
|
SearchLexical(ctx context.Context, projectID *string, terms []string, limit int) ([]*models.RagChunk, error)
|
||||||
DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error
|
DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error
|
||||||
WithTx(tx pgx.Tx) RagRepository
|
WithTx(tx pgx.Tx) RagRepository
|
||||||
}
|
}
|
||||||
@@ -80,6 +81,38 @@ func (r *ragRepository) SearchSimilar(ctx context.Context, projectID *string, ve
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *ragRepository) SearchLexical(ctx context.Context, projectID *string, terms []string, limit int) ([]*models.RagChunk, error) {
|
||||||
|
params := sqlc.SearchRagChunksLexicalParams{
|
||||||
|
Terms: terms,
|
||||||
|
MatchCount: int32(limit),
|
||||||
|
}
|
||||||
|
if projectID != nil && *projectID != "" {
|
||||||
|
pID, _ := convert.StringToUUID(*projectID)
|
||||||
|
params.ProjectID = pID
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := r.q.SearchRagChunksLexical(ctx, params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
res := make([]*models.RagChunk, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
res[i] = &models.RagChunk{
|
||||||
|
ID: convert.UUIDToString(row.ID),
|
||||||
|
SourceType: row.SourceType,
|
||||||
|
SourceID: convert.UUIDToString(row.SourceID),
|
||||||
|
ProjectID: convert.UUIDToString(row.ProjectID),
|
||||||
|
ChunkIndex: row.ChunkIndex,
|
||||||
|
Content: row.Content,
|
||||||
|
Similarity: row.Similarity,
|
||||||
|
CreatedAt: row.CreatedAt.Time,
|
||||||
|
UpdatedAt: row.UpdatedAt.Time,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error {
|
func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string, sourceIDs []string) error {
|
||||||
if len(sourceIDs) == 0 {
|
if len(sourceIDs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"history-api/pkg/config"
|
"history-api/pkg/config"
|
||||||
"history-api/pkg/constants"
|
"history-api/pkg/constants"
|
||||||
"history-api/pkg/convert"
|
"history-api/pkg/convert"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -30,6 +31,8 @@ type chatbotService struct {
|
|||||||
ragUtils *ai.RagUtils
|
ragUtils *ai.RagUtils
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var lexicalTokenRegex = regexp.MustCompile(`[\p{L}\p{N}]+`)
|
||||||
|
|
||||||
func NewChatbotService(repo repositories.RagRepository, usageRepo repositories.UsageRepository, chatRepo repositories.ChatRepository, ragUtils *ai.RagUtils) ChatbotService {
|
func NewChatbotService(repo repositories.RagRepository, usageRepo repositories.UsageRepository, chatRepo repositories.ChatRepository, ragUtils *ai.RagUtils) ChatbotService {
|
||||||
return &chatbotService{
|
return &chatbotService{
|
||||||
repo: repo,
|
repo: repo,
|
||||||
@@ -173,6 +176,37 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
|
|||||||
broadResultCount = len(broadResults)
|
broadResultCount = len(broadResults)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var lexicalSearchDuration time.Duration
|
||||||
|
lexicalResultCount := 0
|
||||||
|
lexicalTerms := extractLexicalSearchTerms(searchQuery)
|
||||||
|
if config.GetBoolConfigWithDefault("RAG_LEXICAL_SEARCH_ENABLED", true) && len(lexicalTerms) > 0 {
|
||||||
|
lexicalLimit := config.GetIntConfigWithDefault("RAG_LEXICAL_CANDIDATES", 20)
|
||||||
|
if lexicalLimit < 0 {
|
||||||
|
lexicalLimit = 0
|
||||||
|
}
|
||||||
|
if lexicalLimit > candidateLimit {
|
||||||
|
lexicalLimit = candidateLimit
|
||||||
|
}
|
||||||
|
if lexicalLimit > 0 {
|
||||||
|
lexicalSearchStart := time.Now()
|
||||||
|
lexicalResults, err := s.repo.SearchLexical(ctx, projectID, lexicalTerms, lexicalLimit)
|
||||||
|
lexicalSearchDuration = time.Since(lexicalSearchStart)
|
||||||
|
if err == nil {
|
||||||
|
lexicalResultCount = len(lexicalResults)
|
||||||
|
results = mergeRagResults(results, lexicalResults, candidateLimit+lexicalLimit)
|
||||||
|
} else {
|
||||||
|
log.Warn().
|
||||||
|
Err(err).
|
||||||
|
Str("userID", userID).
|
||||||
|
Str("projectID", projectIDLog).
|
||||||
|
Int("lexical_terms", len(lexicalTerms)).
|
||||||
|
Int("lexical_limit", lexicalLimit).
|
||||||
|
Dur("lexical_search_duration", lexicalSearchDuration).
|
||||||
|
Msg("rag lexical search failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
rerankStart := time.Now()
|
rerankStart := time.Now()
|
||||||
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
|
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
|
||||||
rerankDuration := time.Since(rerankStart)
|
rerankDuration := time.Since(rerankStart)
|
||||||
@@ -276,6 +310,8 @@ Rules:
|
|||||||
Int("context_limit", contextLimit).
|
Int("context_limit", contextLimit).
|
||||||
Int("initial_results", initialResultCount).
|
Int("initial_results", initialResultCount).
|
||||||
Int("broad_results", broadResultCount).
|
Int("broad_results", broadResultCount).
|
||||||
|
Int("lexical_terms", len(lexicalTerms)).
|
||||||
|
Int("lexical_results", lexicalResultCount).
|
||||||
Int("final_results", len(results)).
|
Int("final_results", len(results)).
|
||||||
Int("context_chars", len(contextStr)).
|
Int("context_chars", len(contextStr)).
|
||||||
Int("prompt_chars", len(prompt)).
|
Int("prompt_chars", len(prompt)).
|
||||||
@@ -287,6 +323,7 @@ Rules:
|
|||||||
Dur("embed_duration", embedDuration).
|
Dur("embed_duration", embedDuration).
|
||||||
Dur("vector_search_duration", searchDuration).
|
Dur("vector_search_duration", searchDuration).
|
||||||
Dur("broad_search_duration", broadSearchDuration).
|
Dur("broad_search_duration", broadSearchDuration).
|
||||||
|
Dur("lexical_search_duration", lexicalSearchDuration).
|
||||||
Dur("rerank_duration", rerankDuration).
|
Dur("rerank_duration", rerankDuration).
|
||||||
Dur("prompt_build_duration", promptBuildDuration).
|
Dur("prompt_build_duration", promptBuildDuration).
|
||||||
Dur("generate_duration", generateDuration).
|
Dur("generate_duration", generateDuration).
|
||||||
@@ -330,6 +367,90 @@ func (s *chatbotService) getRewriteHistory(ctx context.Context, userID pgtype.UU
|
|||||||
return turns
|
return turns
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractLexicalSearchTerms(query string) []string {
|
||||||
|
tokens := lexicalTokenRegex.FindAllString(strings.ToLower(query), -1)
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
terms := make([]string, 0, len(tokens))
|
||||||
|
seen := make(map[string]struct{}, len(tokens))
|
||||||
|
addTerm := func(term string) {
|
||||||
|
term = strings.TrimSpace(term)
|
||||||
|
if term == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, exists := seen[term]; exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[term] = struct{}{}
|
||||||
|
terms = append(terms, term)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, token := range tokens {
|
||||||
|
if hasASCIIDigit(token) || len([]rune(token)) >= 4 {
|
||||||
|
addTerm(token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(tokens)-1; i++ {
|
||||||
|
left := tokens[i]
|
||||||
|
right := tokens[i+1]
|
||||||
|
if hasASCIIDigit(left) || hasASCIIDigit(right) || len([]rune(left))+len([]rune(right)) >= 5 {
|
||||||
|
addTerm(left + " " + right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(terms) > 12 {
|
||||||
|
return terms[:12]
|
||||||
|
}
|
||||||
|
|
||||||
|
return terms
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasASCIIDigit(s string) bool {
|
||||||
|
for _, r := range s {
|
||||||
|
if r >= '0' && r <= '9' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeRagResults(primary, extra []*models.RagChunk, limit int) []*models.RagChunk {
|
||||||
|
if len(primary) == 0 {
|
||||||
|
return limitRagResults(extra, limit)
|
||||||
|
}
|
||||||
|
if len(extra) == 0 {
|
||||||
|
return limitRagResults(primary, limit)
|
||||||
|
}
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = len(primary) + len(extra)
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := make([]*models.RagChunk, 0, min(limit, len(primary)+len(extra)))
|
||||||
|
seen := make(map[string]struct{}, len(primary)+len(extra))
|
||||||
|
appendChunk := func(chunk *models.RagChunk) {
|
||||||
|
if chunk == nil || len(merged) >= limit {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, exists := seen[chunk.ID]; exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[chunk.ID] = struct{}{}
|
||||||
|
merged = append(merged, chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chunk := range primary {
|
||||||
|
appendChunk(chunk)
|
||||||
|
}
|
||||||
|
for _, chunk := range extra {
|
||||||
|
appendChunk(chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk {
|
func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk {
|
||||||
if len(results) == 0 {
|
if len(results) == 0 {
|
||||||
return results
|
return results
|
||||||
|
|||||||
Reference in New Issue
Block a user