feat: implement lexical search functionality in RAG, including new SQL queries and service methods
Build and Release / release (push) Failing after 1m34s
Build and Release / release (push) Failing after 1m34s
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"history-api/pkg/config"
|
||||
"history-api/pkg/constants"
|
||||
"history-api/pkg/convert"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -30,6 +31,8 @@ type chatbotService struct {
|
||||
ragUtils *ai.RagUtils
|
||||
}
|
||||
|
||||
// lexicalTokenRegex matches maximal runs of Unicode letters (\p{L}) and
// Unicode numbers (\p{N}). Every other character acts as a token separator.
// It is compiled once at package scope and used by extractLexicalSearchTerms
// to tokenize the lowercased query.
var lexicalTokenRegex = regexp.MustCompile(`[\p{L}\p{N}]+`)
|
||||
|
||||
func NewChatbotService(repo repositories.RagRepository, usageRepo repositories.UsageRepository, chatRepo repositories.ChatRepository, ragUtils *ai.RagUtils) ChatbotService {
|
||||
return &chatbotService{
|
||||
repo: repo,
|
||||
@@ -173,6 +176,37 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str
|
||||
broadResultCount = len(broadResults)
|
||||
}
|
||||
|
||||
var lexicalSearchDuration time.Duration
|
||||
lexicalResultCount := 0
|
||||
lexicalTerms := extractLexicalSearchTerms(searchQuery)
|
||||
if config.GetBoolConfigWithDefault("RAG_LEXICAL_SEARCH_ENABLED", true) && len(lexicalTerms) > 0 {
|
||||
lexicalLimit := config.GetIntConfigWithDefault("RAG_LEXICAL_CANDIDATES", 20)
|
||||
if lexicalLimit < 0 {
|
||||
lexicalLimit = 0
|
||||
}
|
||||
if lexicalLimit > candidateLimit {
|
||||
lexicalLimit = candidateLimit
|
||||
}
|
||||
if lexicalLimit > 0 {
|
||||
lexicalSearchStart := time.Now()
|
||||
lexicalResults, err := s.repo.SearchLexical(ctx, projectID, lexicalTerms, lexicalLimit)
|
||||
lexicalSearchDuration = time.Since(lexicalSearchStart)
|
||||
if err == nil {
|
||||
lexicalResultCount = len(lexicalResults)
|
||||
results = mergeRagResults(results, lexicalResults, candidateLimit+lexicalLimit)
|
||||
} else {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Str("userID", userID).
|
||||
Str("projectID", projectIDLog).
|
||||
Int("lexical_terms", len(lexicalTerms)).
|
||||
Int("lexical_limit", lexicalLimit).
|
||||
Dur("lexical_search_duration", lexicalSearchDuration).
|
||||
Msg("rag lexical search failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rerankStart := time.Now()
|
||||
results = s.rerankResults(ctx, searchQuery, results, contextLimit)
|
||||
rerankDuration := time.Since(rerankStart)
|
||||
@@ -276,6 +310,8 @@ Rules:
|
||||
Int("context_limit", contextLimit).
|
||||
Int("initial_results", initialResultCount).
|
||||
Int("broad_results", broadResultCount).
|
||||
Int("lexical_terms", len(lexicalTerms)).
|
||||
Int("lexical_results", lexicalResultCount).
|
||||
Int("final_results", len(results)).
|
||||
Int("context_chars", len(contextStr)).
|
||||
Int("prompt_chars", len(prompt)).
|
||||
@@ -287,6 +323,7 @@ Rules:
|
||||
Dur("embed_duration", embedDuration).
|
||||
Dur("vector_search_duration", searchDuration).
|
||||
Dur("broad_search_duration", broadSearchDuration).
|
||||
Dur("lexical_search_duration", lexicalSearchDuration).
|
||||
Dur("rerank_duration", rerankDuration).
|
||||
Dur("prompt_build_duration", promptBuildDuration).
|
||||
Dur("generate_duration", generateDuration).
|
||||
@@ -330,6 +367,90 @@ func (s *chatbotService) getRewriteHistory(ctx context.Context, userID pgtype.UU
|
||||
return turns
|
||||
}
|
||||
|
||||
func extractLexicalSearchTerms(query string) []string {
|
||||
tokens := lexicalTokenRegex.FindAllString(strings.ToLower(query), -1)
|
||||
if len(tokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
terms := make([]string, 0, len(tokens))
|
||||
seen := make(map[string]struct{}, len(tokens))
|
||||
addTerm := func(term string) {
|
||||
term = strings.TrimSpace(term)
|
||||
if term == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := seen[term]; exists {
|
||||
return
|
||||
}
|
||||
seen[term] = struct{}{}
|
||||
terms = append(terms, term)
|
||||
}
|
||||
|
||||
for _, token := range tokens {
|
||||
if hasASCIIDigit(token) || len([]rune(token)) >= 4 {
|
||||
addTerm(token)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(tokens)-1; i++ {
|
||||
left := tokens[i]
|
||||
right := tokens[i+1]
|
||||
if hasASCIIDigit(left) || hasASCIIDigit(right) || len([]rune(left))+len([]rune(right)) >= 5 {
|
||||
addTerm(left + " " + right)
|
||||
}
|
||||
}
|
||||
|
||||
if len(terms) > 12 {
|
||||
return terms[:12]
|
||||
}
|
||||
|
||||
return terms
|
||||
}
|
||||
|
||||
// hasASCIIDigit reports whether s contains at least one of the ASCII digits
// '0' through '9'. Digits from other scripts (for example Arabic-Indic) do
// not count.
func hasASCIIDigit(s string) bool {
	// ASCII digits are single bytes that never occur inside a multi-byte
	// UTF-8 sequence, so scanning bytes gives the same answer as scanning
	// runes.
	for i := 0; i < len(s); i++ {
		if c := s[i]; c >= '0' && c <= '9' {
			return true
		}
	}
	return false
}
|
||||
|
||||
func mergeRagResults(primary, extra []*models.RagChunk, limit int) []*models.RagChunk {
|
||||
if len(primary) == 0 {
|
||||
return limitRagResults(extra, limit)
|
||||
}
|
||||
if len(extra) == 0 {
|
||||
return limitRagResults(primary, limit)
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = len(primary) + len(extra)
|
||||
}
|
||||
|
||||
merged := make([]*models.RagChunk, 0, min(limit, len(primary)+len(extra)))
|
||||
seen := make(map[string]struct{}, len(primary)+len(extra))
|
||||
appendChunk := func(chunk *models.RagChunk) {
|
||||
if chunk == nil || len(merged) >= limit {
|
||||
return
|
||||
}
|
||||
if _, exists := seen[chunk.ID]; exists {
|
||||
return
|
||||
}
|
||||
seen[chunk.ID] = struct{}{}
|
||||
merged = append(merged, chunk)
|
||||
}
|
||||
|
||||
for _, chunk := range primary {
|
||||
appendChunk(chunk)
|
||||
}
|
||||
for _, chunk := range extra {
|
||||
appendChunk(chunk)
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk {
|
||||
if len(results) == 0 {
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user