diff --git a/internal/services/chatbotService.go b/internal/services/chatbotService.go index caa8698..042652a 100644 --- a/internal/services/chatbotService.go +++ b/internal/services/chatbotService.go @@ -2,7 +2,6 @@ package services import ( "context" - "errors" "fmt" "history-api/internal/dtos/request" "history-api/internal/gen/sqlc" @@ -11,8 +10,10 @@ import ( "history-api/pkg/ai" "history-api/pkg/constants" "history-api/pkg/convert" + "strings" "github.com/jackc/pgx/v5/pgtype" + "github.com/rs/zerolog/log" ) type ChatbotService interface { @@ -43,23 +44,32 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str } if usage >= constants.MaxDailyAIUsage { - return "", errors.New("you have reached your daily limit of 10 questions. Please come back tomorrow") + return "", fmt.Errorf("you have reached your daily limit of %d questions. Please come back tomorrow", constants.MaxDailyAIUsage) } qVector, err := s.ragUtils.EmbedQuery(ctx, question) if err != nil { return "", fmt.Errorf("failed to embed question: %w", err) } - results, err := s.repo.SearchSimilar(ctx, projectID, qVector, 5, 0.65) + + results, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.50) if err != nil { return "", fmt.Errorf("failed to search similar content: %w", err) } - contextStr := "" - for i, res := range results { - contextStr += fmt.Sprintf("[Document %d]: %s\n", i+1, res.Content) + if len(results) < 3 { + broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.35) + if err == nil && len(broadResults) > len(results) { + results = broadResults + } } + var contextBuilder strings.Builder + for i, res := range results { + contextBuilder.WriteString(fmt.Sprintf("[Document %d (score: %.2f)]: %s\n", i+1, res.Similarity, res.Content)) + } + contextStr := contextBuilder.String() + pgUserID, err := convert.StringToUUID(userID) if err != nil { return "", fmt.Errorf("invalid user id: %w", err) @@ -70,13 +80,14 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str Limit: 10, }) if err != nil { - fmt.Printf("Warning: failed to get chatbot history: %v\n", err) + log.Warn().Err(err).Msg("failed to get chatbot history") } - historyStr := "" + var historyBuilder strings.Builder for _, h := range histories { - historyStr += fmt.Sprintf("User: %s\nAssistant: %s\n\n", h.Question, h.Answer) + historyBuilder.WriteString(fmt.Sprintf("User: %s\nAssistant: %s\n\n", h.Question, h.Answer)) } + historyStr := historyBuilder.String() var prompt string if contextStr == "" { @@ -115,7 +126,9 @@ Question: %s`, contextStr, historyStr, question) if err != nil { return "", err } - _, _ = s.usageRepo.IncrementAIUsage(ctx, userID) + if _, err := s.usageRepo.IncrementAIUsage(ctx, userID); err != nil { + log.Warn().Err(err).Str("userID", userID).Msg("failed to increment AI usage") + } _, err = s.chatRepo.CreateChatbotHistory(ctx, sqlc.CreateChatbotHistoryParams{ UserID: pgUserID, @@ -123,7 +136,7 @@ Question: %s`, contextStr, historyStr, question) Answer: response, }) if err != nil { - fmt.Printf("Warning: failed to save chatbot history: %v\n", err) + log.Warn().Err(err).Msg("failed to save chatbot history") } return response, nil