diff --git a/README.md b/README.md index 2f31512..dae9a60 100644 --- a/README.md +++ b/README.md @@ -134,8 +134,22 @@ history-api/ OPEN_ROUTER_API= OPEN_ROUTER_MODEL= + OPEN_ROUTER_FALLBACK_MODEL=qwen/qwen3-30b-a3b-instruct-2507 OPEN_ROUTER_EMBEDDING_MODEL= + RAG_QUERY_REWRITE_ENABLED=true + RAG_REWRITE_HISTORY_TURNS=3 + RAG_RETRIEVAL_CANDIDATES=30 + RAG_CONTEXT_TOP_N=8 + RAG_GENERATION_MAX_RETRIES=2 + RAG_GENERATION_RETRY_DELAY_MS=500 + RAG_RERANK_ENABLED=true + RAG_RERANK_MODEL=cohere/rerank-4-pro + RAG_RERANK_FALLBACK_MODEL=cohere/rerank-4-fast + RAG_RERANK_TIMEOUT_SECONDS=10 + RAG_RERANK_MAX_RETRIES=2 + RAG_RERANK_RETRY_DELAY_MS=250 + GOONG_API_KEY_MAP= GOONG_API_KEY_REQ= @@ -214,4 +228,4 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file - [pgvector](https://github.com/pgvector/pgvector) - Vector similarity search for PostgreSQL - [LangChain Go](https://github.com/tmc/langchaingo) - Framework for LLM applications - [Swagger UI](https://swagger.io/tools/swagger-ui/) - Interactive API documentation -- [Rustfs](https://github.com/rustfs/rustfs) - High performance S3 compatible object storage \ No newline at end of file +- [Rustfs](https://github.com/rustfs/rustfs) - High performance S3 compatible object storage diff --git a/db/query/rag.sql b/db/query/rag.sql index 311e0bf..016745a 100644 --- a/db/query/rag.sql +++ b/db/query/rag.sql @@ -9,7 +9,7 @@ RETURNING *; -- name: SearchRagChunks :many SELECT - id, source_type, source_id, project_id, chunk_index, content, + id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at, (1 - (embedding <=> sqlc.arg('embedding')))::float8 AS similarity FROM rag_chunks WHERE 1=1 diff --git a/internal/gen/sqlc/rag.sql.go b/internal/gen/sqlc/rag.sql.go index dcc6a48..05b1ebd 100644 --- a/internal/gen/sqlc/rag.sql.go +++ b/internal/gen/sqlc/rag.sql.go @@ -74,7 +74,7 @@ func (q *Queries) DeleteRagChunksBySourceIDs(ctx context.Context, arg DeleteRagC const searchRagChunks = `-- name: SearchRagChunks :many SELECT - id, source_type, source_id, project_id, chunk_index, content, + id, source_type, source_id, project_id, chunk_index, content, created_at, updated_at, (1 - (embedding <=> $1))::float8 AS similarity FROM rag_chunks WHERE 1=1 @@ -94,13 +94,15 @@ type SearchRagChunksParams struct { } type SearchRagChunksRow struct { - ID pgtype.UUID `json:"id"` - SourceType string `json:"source_type"` - SourceID pgtype.UUID `json:"source_id"` - ProjectID pgtype.UUID `json:"project_id"` - ChunkIndex int32 `json:"chunk_index"` - Content string `json:"content"` - Similarity float64 `json:"similarity"` + ID pgtype.UUID `json:"id"` + SourceType string `json:"source_type"` + SourceID pgtype.UUID `json:"source_id"` + ProjectID pgtype.UUID `json:"project_id"` + ChunkIndex int32 `json:"chunk_index"` + Content string `json:"content"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` + Similarity float64 `json:"similarity"` } func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams) ([]SearchRagChunksRow, error) { @@ -125,6 +127,8 @@ func (q *Queries) SearchRagChunks(ctx context.Context, arg SearchRagChunksParams &i.ProjectID, &i.ChunkIndex, &i.Content, + &i.CreatedAt, + &i.UpdatedAt, &i.Similarity, ); err != nil { return nil, err diff --git a/internal/repositories/ragRepository.go b/internal/repositories/ragRepository.go index 259ff2a..c649f2e 100644 --- a/internal/repositories/ragRepository.go +++ b/internal/repositories/ragRepository.go @@ -67,8 +67,14 @@ func (r *ragRepository) SearchSimilar(ctx context.Context, projectID *string, ve for i, row := range rows { res[i] = &models.RagChunk{ ID: convert.UUIDToString(row.ID), + SourceType: row.SourceType, + SourceID: convert.UUIDToString(row.SourceID), + ProjectID: convert.UUIDToString(row.ProjectID), + ChunkIndex: row.ChunkIndex, Content: row.Content, Similarity: row.Similarity, + CreatedAt: row.CreatedAt.Time, + UpdatedAt: row.UpdatedAt.Time, } } return res, nil @@ -78,7 +84,7 @@ func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string if len(sourceIDs) == 0 { return nil } - + uids := make([]pgtype.UUID, 0, len(sourceIDs)) for _, id := range sourceIDs { uid, err := convert.StringToUUID(id) @@ -86,9 +92,9 @@ func (r *ragRepository) DeleteBySourceIDs(ctx context.Context, sourceType string uids = append(uids, uid) } } - + return r.q.DeleteRagChunksBySourceIDs(ctx, sqlc.DeleteRagChunksBySourceIDsParams{ SourceType: sourceType, - Column2: uids, + Column2: uids, }) } diff --git a/internal/services/chatbotService.go b/internal/services/chatbotService.go index 3f6c993..ff8875f 100644 --- a/internal/services/chatbotService.go +++ b/internal/services/chatbotService.go @@ -8,6 +8,7 @@ import ( "history-api/internal/models" "history-api/internal/repositories" "history-api/pkg/ai" + "history-api/pkg/config" "history-api/pkg/constants" "history-api/pkg/convert" "strings" @@ -52,25 +53,49 @@ func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *str return "", fmt.Errorf("invalid user id: %w", err) } - qVector, err := s.ragUtils.EmbedQuery(ctx, question) + searchQuery := strings.TrimSpace(question) + rewriteHistory := s.getRewriteHistory(ctx, pgUserID) + rewrittenQuery, err := s.ragUtils.RewriteSearchQuery(ctx, question, rewriteHistory) + if err != nil { + log.Warn().Err(err).Msg("failed to rewrite RAG search query") + } else if strings.TrimSpace(rewrittenQuery) != "" { + searchQuery = strings.TrimSpace(rewrittenQuery) + } + + candidateLimit := config.GetIntConfigWithDefault("RAG_RETRIEVAL_CANDIDATES", 30) + if candidateLimit < 8 { + candidateLimit = 8 + } + + contextLimit := config.GetIntConfigWithDefault("RAG_CONTEXT_TOP_N", 8) + if contextLimit <= 0 { + contextLimit = 8 + } + if contextLimit > candidateLimit { + contextLimit = candidateLimit + } + + qVector, err := s.ragUtils.EmbedQuery(ctx, searchQuery) if err != nil { return "", fmt.Errorf("failed to embed question: %w", err) } - results, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.50) + results, err := s.repo.SearchSimilar(ctx, projectID, qVector, candidateLimit, 0.45) if err != nil { return "", fmt.Errorf("failed to search similar content: %w", err) } - if len(results) < 3 { - broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, 8, 0.35) + if len(results) < contextLimit { + broadResults, err := s.repo.SearchSimilar(ctx, projectID, qVector, candidateLimit, 0.30) if err == nil && len(broadResults) > len(results) { results = broadResults } } + results = s.rerankResults(ctx, searchQuery, results, contextLimit) + var contextBuilder strings.Builder - contextBuilder.Grow(len(results) * 96) + contextBuilder.Grow(len(results) * 128) for i, res := range results { contextBuilder.WriteString(fmt.Sprintf("\n%s\n\n\n", i+1, res.Similarity, res.Content)) } @@ -115,6 +140,8 @@ Rules: - Do not cite documents unless the user asks. - Your final response MUST be wrapped inside tags. - Do not output anything outside tags. +- Before writing the final answer, silently verify that each factual sentence is supported by Context. +- Keep the answer focused on the user's question. Do not add background information that is not needed. - Answer in complete, natural, grammatically correct sentences.`, contextStr, question) } @@ -141,6 +168,99 @@ Rules: return response, nil } +func (s *chatbotService) getRewriteHistory(ctx context.Context, userID pgtype.UUID) []ai.ChatTurn { + limit := config.GetIntConfigWithDefault("RAG_REWRITE_HISTORY_TURNS", 3) + if limit <= 0 { + return nil + } + if limit > 5 { + limit = 5 + } + + history, err := s.chatRepo.GetChatbotHistory(ctx, sqlc.GetChatbotHistoryParams{ + UserID: userID, + Limit: int32(limit), + }) + if err != nil { + log.Warn().Err(err).Msg("failed to load chatbot history for RAG query rewrite") + return nil + } + + turns := make([]ai.ChatTurn, 0, len(history)) + for _, item := range history { + if item == nil { + continue + } + turns = append(turns, ai.ChatTurn{ + Question: item.Question, + Answer: item.Answer, + }) + } + + return turns +} + +func (s *chatbotService) rerankResults(ctx context.Context, query string, results []*models.RagChunk, limit int) []*models.RagChunk { + if len(results) == 0 { + return results + } + if limit <= 0 { + limit = len(results) + } + if limit > len(results) { + limit = len(results) + } + + documents := make([]string, len(results)) + for i, result := range results { + documents[i] = result.Content + } + + ranked, err := s.ragUtils.RerankDocuments(ctx, query, documents, limit) + if err != nil { + log.Warn().Err(err).Msg("failed to rerank RAG results") + return limitRagResults(results, limit) + } + if len(ranked) == 0 { + return limitRagResults(results, limit) + } + + selected := make([]*models.RagChunk, 0, limit) + seen := make(map[int]struct{}, limit) + for _, item := range ranked { + if item.Index < 0 || item.Index >= len(results) { + continue + } + if _, exists := seen[item.Index]; exists { + continue + } + seen[item.Index] = struct{}{} + selected = append(selected, results[item.Index]) + if len(selected) >= limit { + return selected + } + } + + for i, result := range results { + if _, exists := seen[i]; exists { + continue + } + selected = append(selected, result) + if len(selected) >= limit { + break + } + } + + return selected +} + +func limitRagResults(results []*models.RagChunk, limit int) []*models.RagChunk { + if limit <= 0 || len(results) <= limit { + return results + } + return results[:limit] +} + func normalizeAnswer(s string) string { s = strings.TrimSpace(s) @@ -157,6 +277,7 @@ func normalizeAnswer(s string) string { return s } + func (s *chatbotService) GetHistory(ctx context.Context, userID string, dto *request.GetChatbotHistoryDto) ([]*models.ChatbotHistoryEntity, error) { pgUserID, err := convert.StringToUUID(userID) if err != nil { diff --git a/pkg/ai/rag.go b/pkg/ai/rag.go index f34c063..8687154 100644 --- a/pkg/ai/rag.go +++ b/pkg/ai/rag.go @@ -1,12 +1,17 @@ package ai import ( + "bytes" "context" "fmt" "history-api/pkg/config" + json "history-api/pkg/jsonx" "html" + "io" + "net/http" "regexp" "strings" + "time" "github.com/tmc/langchaingo/embeddings" "github.com/tmc/langchaingo/llms" @@ -15,31 +20,49 @@ import ( ) type RagUtils struct { - llm llms.Model - embedder *embeddings.EmbedderImpl + llm llms.Model + embedder *embeddings.EmbedderImpl + httpClient *http.Client + openRouterAPIKey string + model string + fallbackModel string + generationMaxRetries int + generationRetryDelay time.Duration + rerankEnabled bool + rerankModel string + rerankFallbackModel string + rerankMaxRetries int + rerankRetryDelay time.Duration + queryRewriteEnabled bool } var htmlTagRegex = regexp.MustCompile(`<[^>]*>`) +const openRouterBaseURL = "https://openrouter.ai/api/v1" + +type RerankResult struct { + Index int + Score float64 +} + +type ChatTurn struct { + Question string + Answer string +} + func NewRagUtils() (*RagUtils, error) { - openRouterAPIKey, err := config.GetConfig("OPEN_ROUTER_API") - if err != nil { - return nil, err + openRouterAPIKey := config.GetConfigWithDefault("OPEN_ROUTER_API", "") + if openRouterAPIKey == "" { + return nil, fmt.Errorf("OPEN_ROUTER_API is not set") } - model, err := config.GetConfig("OPEN_ROUTER_MODEL") - if err != nil { - model = "qwen/qwen3.5-flash-02-23" - } + model := config.GetConfigWithDefault("OPEN_ROUTER_MODEL", "qwen/qwen3.5-flash-02-23") - embeddingModel, err := config.GetConfig("OPEN_ROUTER_EMBEDDING_MODEL") - if err != nil { - embeddingModel = "qwen/qwen3-embedding-8b" - } + embeddingModel := config.GetConfigWithDefault("OPEN_ROUTER_EMBEDDING_MODEL", "qwen/qwen3-embedding-8b") llm, err := openai.New( openai.WithToken(openRouterAPIKey), - openai.WithBaseURL("https://openrouter.ai/api/v1"), + openai.WithBaseURL(openRouterBaseURL), openai.WithModel(model), openai.WithEmbeddingModel(embeddingModel), ) @@ -52,9 +75,52 @@ func NewRagUtils() (*RagUtils, error) { return nil, fmt.Errorf("failed to init embedder: %w", err) } + timeoutSeconds := config.GetIntConfigWithDefault("RAG_RERANK_TIMEOUT_SECONDS", 10) + if timeoutSeconds <= 0 { + timeoutSeconds = 10 + } + + rerankMaxRetries := config.GetIntConfigWithDefault("RAG_RERANK_MAX_RETRIES", 2) + if rerankMaxRetries < 0 { + rerankMaxRetries = 0 + } + if rerankMaxRetries > 5 { + rerankMaxRetries = 5 + } + + rerankRetryDelayMs := config.GetIntConfigWithDefault("RAG_RERANK_RETRY_DELAY_MS", 250) + if rerankRetryDelayMs <= 0 { + rerankRetryDelayMs = 250 + } + + generationMaxRetries := config.GetIntConfigWithDefault("RAG_GENERATION_MAX_RETRIES", 2) + if generationMaxRetries < 0 { + generationMaxRetries = 0 + } + if generationMaxRetries > 5 { + generationMaxRetries = 5 + } + + generationRetryDelayMs := config.GetIntConfigWithDefault("RAG_GENERATION_RETRY_DELAY_MS", 500) + if generationRetryDelayMs <= 0 { + generationRetryDelayMs = 500 + } + return &RagUtils{ - llm: llm, - embedder: embedder, + llm: llm, + embedder: embedder, + httpClient: &http.Client{Timeout: time.Duration(timeoutSeconds) * time.Second}, + openRouterAPIKey: openRouterAPIKey, + model: model, + fallbackModel: config.GetConfigWithDefault("OPEN_ROUTER_FALLBACK_MODEL", "qwen/qwen3-30b-a3b-instruct-2507"), + generationMaxRetries: generationMaxRetries, + generationRetryDelay: time.Duration(generationRetryDelayMs) * time.Millisecond, + rerankEnabled: config.GetBoolConfigWithDefault("RAG_RERANK_ENABLED", true), + rerankModel: config.GetConfigWithDefault("RAG_RERANK_MODEL", "cohere/rerank-4-pro"), + rerankFallbackModel: config.GetConfigWithDefault("RAG_RERANK_FALLBACK_MODEL", "cohere/rerank-4-fast"), + rerankMaxRetries: rerankMaxRetries, + rerankRetryDelay: time.Duration(rerankRetryDelayMs) * time.Millisecond, + queryRewriteEnabled: config.GetBoolConfigWithDefault("RAG_QUERY_REWRITE_ENABLED", true), }, nil } @@ -103,14 +169,308 @@ func (u *RagUtils) EmbedQuery(ctx context.Context, query string) ([]float32, err return vector, nil } -func (u *RagUtils) GenerateResponse(ctx context.Context, prompt string) (string, error) { +func (u *RagUtils) RewriteSearchQuery(ctx context.Context, question string, history []ChatTurn) (string, error) { + question = normalizeWhitespace(question) + if question == "" || !u.queryRewriteEnabled { + return question, nil + } + + historyText := buildHistoryForRewrite(history) + if historyText == "" { + historyText = "No previous chat turns." + } + + prompt := fmt.Sprintf(`You rewrite user questions into precise retrieval queries for a history RAG system. + +Previous chat turns: +%s + +Current user question: +%s + +Rules: +- Output exactly one retrieval query inside tags. +- Do not answer the question. +- Do not add names, dates, places, causes, results, or facts unless they are explicitly present in the current question or previous chat turns. +- Use previous chat turns only to resolve clear references such as "đó", "ông ấy", "bà ấy", "sự kiện này", "that event", or "he/she". +- Preserve proper nouns, dates, quoted text, and historical terms exactly when possible. +- If the current question is already clear, return it unchanged. +- If a reference is ambiguous, keep the ambiguous wording instead of guessing. +- Use the same language as the current question. +- No markdown, no bullet points, no explanation. + +`, historyText, question) + raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt) if err != nil { return "", err } + + query := normalizeGeneratedQuery(raw) + if query == "" { + return question, nil + } + + return query, nil +} + +func (u *RagUtils) RerankDocuments(ctx context.Context, query string, documents []string, topN int) ([]RerankResult, error) { + if !u.rerankEnabled || len(documents) == 0 { + return nil, nil + } + if topN <= 0 || topN > len(documents) { + topN = len(documents) + } + + model := strings.TrimSpace(u.rerankModel) + if model == "" { + model = "cohere/rerank-4-pro" + } + + results, err := u.rerankDocumentsWithModel(ctx, model, query, documents, topN) + if err == nil { + return results, nil + } + + fallbackModel := strings.TrimSpace(u.rerankFallbackModel) + if fallbackModel == "" || fallbackModel == model { + return nil, err + } + + fallbackResults, fallbackErr := u.rerankDocumentsWithModel(ctx, fallbackModel, query, documents, topN) + if fallbackErr == nil { + return fallbackResults, nil + } + + return nil, fmt.Errorf("rerank failed with model %s: %w; fallback model %s failed: %v", model, err, fallbackModel, fallbackErr) +} + +func (u *RagUtils) rerankDocumentsWithModel(ctx context.Context, model, query string, documents []string, topN int) ([]RerankResult, error) { + reqBody := rerankRequest{ + Model: model, + Query: query, + Documents: documents, + TopN: topN, + } + + payload, err := json.Marshal(reqBody) + if err != nil { + return nil, err + } + + var lastErr error + for attempt := 0; attempt <= u.rerankMaxRetries; attempt++ { + results, retryable, err := u.sendRerankRequest(ctx, payload, documents) + if err == nil { + return results, nil + } + + lastErr = err + if !retryable || attempt == u.rerankMaxRetries { + break + } + + delay := u.rerankRetryDelay * time.Duration(1< 2*time.Second { + delay = 2 * time.Second + } + if sleepErr := sleepWithContext(ctx, delay); sleepErr != nil { + return nil, sleepErr + } + } + + return nil, fmt.Errorf("rerank failed with model %s: %w", model, lastErr) +} + +func (u *RagUtils) sendRerankRequest(ctx context.Context, payload []byte, documents []string) ([]RerankResult, bool, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, openRouterBaseURL+"/rerank", bytes.NewReader(payload)) + if err != nil { + return nil, false, err + } + req.Header.Set("Authorization", "Bearer "+u.openRouterAPIKey) + req.Header.Set("Content-Type", "application/json") + + res, err := u.httpClient.Do(req) + if err != nil { + if ctx.Err() != nil { + return nil, false, ctx.Err() + } + return nil, true, err + } + defer res.Body.Close() + + if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices { + body, _ := io.ReadAll(io.LimitReader(res.Body, 4096)) + return nil, isRetryableRerankStatus(res.StatusCode), fmt.Errorf("rerank request failed with status %d: %s", res.StatusCode, strings.TrimSpace(string(body))) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, true, err + } + + var rerankRes rerankResponse + if err := json.Unmarshal(body, &rerankRes); err != nil { + return nil, false, err + } + + results := make([]RerankResult, 0, len(rerankRes.Results)) + for _, item := range rerankRes.Results { + if item.Index < 0 || item.Index >= len(documents) { + continue + } + results = append(results, RerankResult{ + Index: item.Index, + Score: item.RelevanceScore, + }) + } + + return results, false, nil +} + +func (u *RagUtils) GenerateResponse(ctx context.Context, prompt string) (string, error) { + raw, err := u.generateSinglePromptWithRetry(ctx, prompt, u.model) + if err != nil && u.fallbackModel != "" && u.fallbackModel != u.model { + raw, err = u.generateSinglePromptWithRetry(ctx, prompt, u.fallbackModel) + } + if err != nil { + return "", err + } return stripThinking(raw), nil } +func (u *RagUtils) generateSinglePromptWithRetry(ctx context.Context, prompt, model string) (string, error) { + var lastErr error + options := make([]llms.CallOption, 0, 1) + if strings.TrimSpace(model) != "" { + options = append(options, llms.WithModel(model)) + } + + for attempt := 0; attempt <= u.generationMaxRetries; attempt++ { + raw, err := llms.GenerateFromSinglePrompt(ctx, u.llm, prompt, options...) + if err == nil { + return raw, nil + } + + if ctx.Err() != nil { + return "", ctx.Err() + } + + lastErr = err + if attempt == u.generationMaxRetries { + break + } + + delay := u.generationRetryDelay * time.Duration(1< 3*time.Second { + delay = 3 * time.Second + } + if sleepErr := sleepWithContext(ctx, delay); sleepErr != nil { + return "", sleepErr + } + } + + if strings.TrimSpace(model) == "" { + return "", lastErr + } + + return "", fmt.Errorf("generation failed with model %s: %w", model, lastErr) +} + +type rerankRequest struct { + Model string `json:"model"` + Query string `json:"query"` + Documents []string `json:"documents"` + TopN int `json:"top_n"` +} + +type rerankResponse struct { + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + } `json:"results"` +} + +func buildHistoryForRewrite(history []ChatTurn) string { + if len(history) == 0 { + return "" + } + + var builder strings.Builder + for i, turn := range history { + question := normalizeWhitespace(turn.Question) + answer := normalizeWhitespace(turn.Answer) + if question == "" && answer == "" { + continue + } + builder.WriteString(fmt.Sprintf("Turn %d\nUser: %s\nAssistant: %s\n", i+1, question, answer)) + } + + return strings.TrimSpace(builder.String()) +} + +func normalizeGeneratedQuery(raw string) string { + raw = strings.TrimSpace(raw) + + if query := extractTag(raw, "query"); query != "" { + return normalizeWhitespace(query) + } + + raw = strings.TrimSpace(strings.TrimPrefix(raw, "Query:")) + raw = strings.TrimSpace(strings.TrimPrefix(raw, "Retrieval query:")) + raw = strings.Trim(raw, "\"'`") + + return normalizeWhitespace(raw) +} + +func extractTag(raw, tag string) string { + startTag := "<" + tag + ">" + endTag := "" + + start := strings.LastIndex(raw, startTag) + if start == -1 { + return "" + } + + content := raw[start+len(startTag):] + end := strings.Index(content, endTag) + if end == -1 { + return strings.TrimSpace(content) + } + + return strings.TrimSpace(content[:end]) +} + +func normalizeWhitespace(s string) string { + return strings.Join(strings.Fields(strings.TrimSpace(s)), " ") +} + +func isRetryableRerankStatus(statusCode int) bool { + switch statusCode { + case http.StatusRequestTimeout, + http.StatusTooManyRequests, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout: + return true + default: + return false + } +} + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + func stripThinking(raw string) string { startTag := "" endTag := ""