From 99b4e674dbbe8256752f003c7543c489514a7a13 Mon Sep 17 00:00:00 2001 From: Teerth Sharma Date: Fri, 19 Jun 2026 09:56:04 +0530 Subject: [PATCH 1/4] feat: add adaptive rag prefetching Signed-off-by: Teerth Sharma --- agent-schema.json | 42 +++ .../2026-06-18-adaptive-rag-prefetcher.md | 88 +++++ ...26-06-18-adaptive-rag-prefetcher-design.md | 79 +++++ docs/tools/rag/index.md | 26 +- examples/rag/adaptive_prefetch.yaml | 55 ++++ pkg/config/latest/types.go | 13 + pkg/config/schema_test.go | 1 + pkg/rag/builder.go | 16 + pkg/rag/manager.go | 85 +++-- pkg/rag/manager_test.go | 60 ++++ pkg/rag/prefetch/prefetch.go | 300 ++++++++++++++++++ pkg/rag/prefetch/prefetch_test.go | 93 ++++++ 12 files changed, 824 insertions(+), 34 deletions(-) create mode 100644 docs/superpowers/plans/2026-06-18-adaptive-rag-prefetcher.md create mode 100644 docs/superpowers/specs/2026-06-18-adaptive-rag-prefetcher-design.md create mode 100644 examples/rag/adaptive_prefetch.yaml create mode 100644 pkg/rag/prefetch/prefetch.go create mode 100644 pkg/rag/prefetch/prefetch_test.go diff --git a/agent-schema.json b/agent-schema.json index ef22f4332..03f85e5fd 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -2690,6 +2690,48 @@ ], "additionalProperties": false }, + "prefetch": { + "type": "object", + "description": "Optional adaptive query prefetching. When enabled, docker-agent caches repeated RAG queries and warms bounded follow-up candidates in the background.", + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable adaptive RAG query prefetching.", + "default": false + }, + "max_entries": { + "type": "integer", + "description": "Maximum number of cached query result sets.", + "minimum": 1, + "default": 32 + }, + "max_candidates": { + "type": "integer", + "description": "Maximum number of follow-up query candidates to prefetch after a cache miss.", + "minimum": 1, + "default": 2 + }, + "min_similarity": { + "type": "number", + "description": "Minimum result similarity required before a source path can seed a follow-up prefetch candidate.", + "minimum": 0, + "maximum": 1, + "default": 0.5 + }, + "drift_threshold": { + "type": "number", + "description": "Maximum topology drift score that still allows background prefetching.", + "minimum": 0, + "default": 0.8 + }, + "timeout": { + "type": "string", + "description": "Maximum duration for a background prefetch query, using Go duration syntax such as '10s' or '1m'.", + "default": "10s" + } + }, + "additionalProperties": false + }, "deduplicate": { "type": "boolean", "description": "Remove duplicate documents across strategies", diff --git a/docs/superpowers/plans/2026-06-18-adaptive-rag-prefetcher.md b/docs/superpowers/plans/2026-06-18-adaptive-rag-prefetcher.md new file mode 100644 index 000000000..507d18ee0 --- /dev/null +++ b/docs/superpowers/plans/2026-06-18-adaptive-rag-prefetcher.md @@ -0,0 +1,88 @@ +# Adaptive RAG Prefetcher Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add an opt-in adaptive RAG prefetcher that caches exact repeat queries and warms deterministic follow-up candidates when recent RAG topology is stable. + +**Architecture:** Implement a small `pkg/rag/prefetch` package with bounded cache, topology tracker, and background scheduler. Wire it into `pkg/rag.Manager` as an optional layer around the existing query pipeline, leaving strategy implementations unchanged. + +**Tech Stack:** Go, existing RAG strategy interfaces, existing config/latest schema, existing docs/examples. + +## Global Constraints + +- Do not add an Aether-Lang dependency. +- Only change latest config, not frozen config versions. +- Feature is opt-in with `results.prefetch.enabled`. +- Background prefetch must be bounded, cancellable, and non-blocking. +- Commits must use DCO sign-off via `git commit -s`. +- Validate with `task lint` and `task test` before PR. + +--- + +### Task 1: Prefetch Package + +**Files:** +- Create: `pkg/rag/prefetch/prefetch.go` +- Create: `pkg/rag/prefetch/prefetch_test.go` + +**Interfaces:** +- Produces: `Config`, `Prefetcher`, `New(Config) *Prefetcher` +- Produces: `Get(query string) ([]database.SearchResult, bool)`, `Store(query string, results []database.SearchResult)`, `Observe(query string, results []database.SearchResult)`, `Candidates(query string, results []database.SearchResult) []string`, `Prefetch(ctx context.Context, query string, fn FetchFunc)` + +- [ ] Write tests for disabled config, bounded cache eviction, stable candidate generation, and drift suppression. +- [ ] Implement config defaults and normalization. +- [ ] Implement topology tracker using query length, term count, result count, and average similarity. +- [ ] Implement candidate generation from query text and top source path basenames. +- [ ] Implement bounded background prefetch with timeout and in-flight de-duplication. +- [ ] Run `go test ./pkg/rag/prefetch`. +- [ ] Commit with `git commit -s -m "feat: add rag prefetch primitives"`. + +### Task 2: Manager Integration + +**Files:** +- Modify: `pkg/rag/manager.go` +- Modify: `pkg/rag/builder.go` +- Modify: `pkg/rag/manager_test.go` + +**Interfaces:** +- Consumes: `prefetch.Config`, `prefetch.Prefetcher` +- Produces: optional manager-level prefetch behavior for `Manager.Query` + +- [ ] Add prefetch config to `rag.Config` and create the prefetcher in `New`. +- [ ] Extract current query logic into an unexported `queryUncached(ctx, query string)` helper. +- [ ] Make `Query` check exact cache hits first when enabled. +- [ ] Store successful final results and schedule background candidate prefetches after cache misses. +- [ ] Add integration tests using a fake strategy to prove a second identical query is served from cache. +- [ ] Run `go test ./pkg/rag`. +- [ ] Commit with `git commit -s -m "feat: wire adaptive prefetch into rag manager"`. + +### Task 3: Config, Schema, Docs, Example + +**Files:** +- Modify: `pkg/config/latest/types.go` +- Modify: `agent-schema.json` +- Modify: `docs/tools/rag/index.md` +- Create: `examples/rag/adaptive_prefetch.yaml` + +**Interfaces:** +- Produces: `latest.RAGPrefetchConfig` +- Wires: `latest.RAGResultsConfig.Prefetch *RAGPrefetchConfig` + +- [ ] Add `RAGPrefetchConfig` with `enabled`, `max_entries`, `max_candidates`, `min_similarity`, `drift_threshold`, and `timeout`. +- [ ] Add schema definition/properties matching Go JSON tags. +- [ ] Document the feature and its conservative defaults in RAG docs. +- [ ] Add a runnable example using hybrid RAG plus `results.prefetch`. +- [ ] Run `go test ./pkg/config`. +- [ ] Commit with `git commit -s -m "docs: document adaptive rag prefetching"`. + +### Task 4: Final Validation and PR + +**Files:** +- Modify as needed from validation findings only. + +- [ ] Run `task lint`. +- [ ] Run `task test`. +- [ ] Run `task build`. +- [ ] Inspect `git diff --stat main...HEAD`. +- [ ] Push branch to fork remote. +- [ ] Open draft PR against `docker/docker-agent:main` linking issue `#3164`. diff --git a/docs/superpowers/specs/2026-06-18-adaptive-rag-prefetcher-design.md b/docs/superpowers/specs/2026-06-18-adaptive-rag-prefetcher-design.md new file mode 100644 index 000000000..beb4705d0 --- /dev/null +++ b/docs/superpowers/specs/2026-06-18-adaptive-rag-prefetcher-design.md @@ -0,0 +1,79 @@ +# Adaptive RAG Prefetcher Design + +## Context + +Issue: https://github.com/docker/docker-agent/issues/3164 + +Aether-Lang contains useful algorithms for sparse attention graphs, hierarchical block metadata, adaptive epsilon thresholds, and centroid drift detection. docker-agent should not import Aether-Lang or add Rust/runtime dependencies. The contribution should translate the useful ideas into small Go primitives that fit the existing RAG manager and strategy interfaces. + +## Goal + +Add an opt-in adaptive RAG prefetcher that reduces repeated retrieval latency and warms likely follow-up queries without blocking the active user turn. + +## Non-Goals + +- No Aether-Lang dependency. +- No new DSL, kernel, or model provider. +- No replacement of existing RAG strategies, fusion, or reranking. +- No hidden behavior when config does not enable the feature. + +## Design + +The RAG manager gets an optional prefetcher configured under `results.prefetch`. When disabled, `Manager.Query` keeps current behavior. + +When enabled: + +1. `Manager.Query` checks a small in-memory cache keyed by normalized query text. +2. Cache hit returns the final post-processed results from the earlier query. +3. Cache miss runs the normal strategy/fusion/reranking pipeline. +4. The prefetcher observes the query and result metadata. +5. If topology is stable enough, it schedules a bounded background prefetch for one or more deterministic follow-up candidates. + +The prefetcher computes a lightweight state vector from query/result metadata: + +- query length +- token count +- result count +- average similarity + +It tracks a centroid and drift score over recent observations. A simple adaptive threshold accepts stable sessions and suppresses prefetch during strong topic shifts. + +## Configuration + +Add to the latest config only: + +```yaml +results: + prefetch: + enabled: true + max_entries: 32 + max_candidates: 2 + min_similarity: 0.5 + drift_threshold: 0.8 + timeout: 10s +``` + +Defaults are conservative. `enabled` defaults false. + +## Components + +- `pkg/rag/prefetch`: owns cache, topology tracker, candidate generation, and background scheduling. +- `pkg/rag/manager.go`: wires cache lookup, observation, and background scheduling around existing query behavior. +- `pkg/config/latest/types.go`: latest-only config structs. +- `agent-schema.json`: schema sync for config fields. +- `docs/tools/rag/index.md`: user-facing docs. +- `examples/rag/adaptive_prefetch.yaml`: runnable example config. + +## Safety + +- Background prefetch uses a timeout and a max in-flight limit. +- Prefetch errors are logged at debug level and never fail the active query. +- Results cache is bounded and evicts oldest entries. +- Query work is skipped while drift exceeds threshold. +- Prefetch stores only normal RAG search results already available to the running process. + +## Testing + +- Unit tests for cache eviction, candidate generation, drift suppression, and scheduling. +- Manager integration test for cache hit avoiding a second strategy call. +- Config/schema test coverage through existing schema sync tests. diff --git a/docs/tools/rag/index.md b/docs/tools/rag/index.md index 35576f9d4..3c524fc7b 100644 --- a/docs/tools/rag/index.md +++ b/docs/tools/rag/index.md @@ -16,6 +16,7 @@ The `rag` toolset lets agents search through your documents to find relevant inf - **Multiple strategies** — Semantic embeddings, BM25 keyword search, and LLM-enhanced search - **Hybrid search** — Combine strategies with result fusion for best results - **Reranking** — Re-score results with specialized models for improved relevance +- **Adaptive prefetching** — Cache repeated queries and warm stable follow-up searches ## Quick Start @@ -156,6 +157,23 @@ results: Supported reranking providers: **DMR** (native `/rerank` endpoint), **OpenAI**, **Anthropic**, **Gemini**. +## Adaptive Prefetching + +Adaptive prefetching is opt-in. It caches repeated RAG queries and, when query topology is stable, warms a small number of deterministic follow-up candidates in the background. + +```yaml +results: + prefetch: + enabled: true + max_entries: 32 + max_candidates: 2 + min_similarity: 0.5 + drift_threshold: 0.8 + timeout: 10s +``` + +The prefetcher is bounded and non-blocking. Exact repeated queries can be served from cache. Background candidate prefetch is skipped when reranking is enabled so reranker errors and fallback behavior stay tied to the foreground query. + ## Code-Aware Chunking For source code, enable AST-based chunking to keep functions and methods intact: @@ -263,6 +281,12 @@ Look for log tags: `[RAG Manager]`, `[Chunked-Embeddings Strategy]`, `[BM25 Stra | `include_score` | bool | `false` | Include relevance scores in results | | `return_full_content` | bool | `false` | Return full document content instead of just matched chunks | | `reranking.model` | string | — | Reranking model reference | -| `reranking.top_k` | int | (`limit`) | Only rerank top K results. Defaults to the results `limit` when set. | +| `reranking.top_k` | int | (`limit`) | Only rerank top K results. Defaults to the results `limit` when set. | | `reranking.threshold` | float | `0.5` | Minimum relevance score after reranking | | `reranking.criteria` | string | — | Custom relevance guidance for the reranking model | +| `prefetch.enabled` | bool | `false` | Enable adaptive query prefetching | +| `prefetch.max_entries` | int | `32` | Maximum cached query result sets | +| `prefetch.max_candidates` | int | `2` | Maximum follow-up candidates warmed after a cache miss | +| `prefetch.min_similarity` | float | `0.5` | Minimum similarity score for source-derived candidates | +| `prefetch.drift_threshold` | float | `0.8` | Maximum topology drift that still allows background prefetch | +| `prefetch.timeout` | string | `10s` | Timeout for each background prefetch query | diff --git a/examples/rag/adaptive_prefetch.yaml b/examples/rag/adaptive_prefetch.yaml new file mode 100644 index 000000000..29602a07f --- /dev/null +++ b/examples/rag/adaptive_prefetch.yaml @@ -0,0 +1,55 @@ +# This example demonstrates adaptive RAG prefetching for repeated and +# topology-stable follow-up queries. + +agents: + root: + model: openai/gpt-5-mini + description: assistant with adaptive RAG prefetching + instruction: | + You are a helpful assistant with access to hybrid retrieval. + Use the knowledge base before answering questions about blorks. + toolsets: + - type: rag + ref: prefetched_knowledge + +rag: + prefetched_knowledge: + tool: + description: to be used to search for information about blorks + docs: + - ./blork_field_guide.txt + strategies: + - type: chunked-embeddings + embedding_model: openai/text-embedding-3-small + database: ./adaptive_prefetch_embeddings.db + vector_dimensions: 1536 + similarity_metric: cosine_similarity + threshold: 0.5 + limit: 20 + chunking: + size: 1000 + overlap: 100 + respect_word_boundaries: true + - type: bm25 + database: ./adaptive_prefetch_bm25.db + k1: 1.5 + b: 0.75 + threshold: 0.3 + limit: 15 + chunking: + size: 1000 + overlap: 100 + respect_word_boundaries: true + results: + fusion: + strategy: rrf + k: 60 + deduplicate: true + limit: 5 + prefetch: + enabled: true + max_entries: 32 + max_candidates: 2 + min_similarity: 0.5 + drift_threshold: 0.8 + timeout: 10s diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index b00fc4f13..c1ed336bc 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -1813,11 +1813,22 @@ type RAGResultsConfig struct { Limit int `json:"limit,omitempty"` // Maximum number of results to return (top K) Fusion *RAGFusionConfig `json:"fusion,omitempty"` // How to combine results from multiple strategies Reranking *RAGRerankingConfig `json:"reranking,omitempty"` // Optional reranking configuration + Prefetch *RAGPrefetchConfig `json:"prefetch,omitempty"` // Optional adaptive query prefetching Deduplicate bool `json:"deduplicate,omitempty"` // Remove duplicate documents across strategies IncludeScore bool `json:"include_score,omitempty"` // Include relevance scores in results ReturnFullContent bool `json:"return_full_content,omitempty"` // Return full document content instead of just matched chunks } +// RAGPrefetchConfig configures adaptive RAG query prefetching. +type RAGPrefetchConfig struct { + Enabled bool `json:"enabled,omitempty"` + MaxEntries int `json:"max_entries,omitempty"` + MaxCandidates int `json:"max_candidates,omitempty"` + MinSimilarity float64 `json:"min_similarity,omitempty"` + DriftThreshold float64 `json:"drift_threshold,omitempty"` + Timeout Duration `json:"timeout,omitempty"` +} + // RAGRerankingConfig represents reranking configuration type RAGRerankingConfig struct { Model string `json:"model"` // Model reference for reranking (e.g., "hf.co/ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF") @@ -1871,6 +1882,7 @@ func (r *RAGResultsConfig) UnmarshalYAML(unmarshal func(any) error) error { Limit int `json:"limit,omitempty"` Fusion *RAGFusionConfig `json:"fusion,omitempty"` Reranking *RAGRerankingConfig `json:"reranking,omitempty"` + Prefetch *RAGPrefetchConfig `json:"prefetch,omitempty"` Deduplicate *bool `json:"deduplicate,omitempty"` IncludeScore *bool `json:"include_score,omitempty"` ReturnFullContent *bool `json:"return_full_content,omitempty"` @@ -1889,6 +1901,7 @@ func (r *RAGResultsConfig) UnmarshalYAML(unmarshal func(any) error) error { } r.Fusion = raw.Fusion r.Reranking = raw.Reranking + r.Prefetch = raw.Prefetch if raw.Deduplicate != nil { r.Deduplicate = *raw.Deduplicate diff --git a/pkg/config/schema_test.go b/pkg/config/schema_test.go index 2c3d6c981..df9f062ec 100644 --- a/pkg/config/schema_test.go +++ b/pkg/config/schema_test.go @@ -126,6 +126,7 @@ func TestSchemaMatchesGoTypes(t *testing.T) { {reflect.TypeFor[latest.RAGResultsConfig](), []string{"RAGConfig", "results"}, "RAGResultsConfig (RAGConfig.results)"}, {reflect.TypeFor[latest.RAGFusionConfig](), []string{"RAGConfig", "results", "fusion"}, "RAGFusionConfig (RAGConfig.results.fusion)"}, {reflect.TypeFor[latest.RAGRerankingConfig](), []string{"RAGConfig", "results", "reranking"}, "RAGRerankingConfig (RAGConfig.results.reranking)"}, + {reflect.TypeFor[latest.RAGPrefetchConfig](), []string{"RAGConfig", "results", "prefetch"}, "RAGPrefetchConfig (RAGConfig.results.prefetch)"}, {reflect.TypeFor[latest.RAGChunkingConfig](), []string{"RAGConfig", "strategies", "*", "chunking"}, "RAGChunkingConfig (RAGConfig.strategies[].chunking)"}, } diff --git a/pkg/rag/builder.go b/pkg/rag/builder.go index 2553bf9a3..8853c00c7 100644 --- a/pkg/rag/builder.go +++ b/pkg/rag/builder.go @@ -11,6 +11,7 @@ import ( "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/rag/prefetch" "github.com/docker/docker-agent/pkg/rag/rerank" "github.com/docker/docker-agent/pkg/rag/strategy" "github.com/docker/docker-agent/pkg/rag/types" @@ -135,9 +136,24 @@ func buildManagerConfig( Results: results, FusionConfig: fusionCfg, StrategyConfigs: strategyConfigs, + PrefetchConfig: buildPrefetchConfig(ragCfg.Results.Prefetch), }, nil } +func buildPrefetchConfig(cfg *latest.RAGPrefetchConfig) prefetch.Config { + if cfg == nil { + return prefetch.Config{} + } + return prefetch.Config{ + Enabled: cfg.Enabled, + MaxEntries: cfg.MaxEntries, + MaxCandidates: cfg.MaxCandidates, + MinSimilarity: cfg.MinSimilarity, + DriftThreshold: cfg.DriftThreshold, + Timeout: cfg.Timeout.Duration, + } +} + // buildRerankingConfig constructs a RerankingConfig from the configuration. func buildRerankingConfig( ctx context.Context, diff --git a/pkg/rag/manager.go b/pkg/rag/manager.go index bbd918fb0..99d50e7d1 100644 --- a/pkg/rag/manager.go +++ b/pkg/rag/manager.go @@ -15,6 +15,7 @@ import ( "github.com/docker/docker-agent/pkg/modelerrors" "github.com/docker/docker-agent/pkg/rag/database" "github.com/docker/docker-agent/pkg/rag/fusion" + "github.com/docker/docker-agent/pkg/rag/prefetch" "github.com/docker/docker-agent/pkg/rag/rerank" "github.com/docker/docker-agent/pkg/rag/strategy" "github.com/docker/docker-agent/pkg/rag/types" @@ -35,6 +36,7 @@ type Config struct { Results ResultsConfig FusionConfig *FusionConfig StrategyConfigs []strategy.Config + PrefetchConfig prefetch.Config } // ResultsConfig captures result-postprocessing behavior for the manager. @@ -64,6 +66,7 @@ type Manager struct { reranker rerank.Reranker // Optional reranker for result re-scoring rerankDisabled atomic.Bool // Set after a non-retryable reranking error to stop doomed requests events <-chan types.Event // Shared event channel from strategies and other RAG operations + prefetcher *prefetch.Prefetcher } // FusionConfig holds configuration for result fusion @@ -138,6 +141,7 @@ func New(_ context.Context, name string, config Config, strategyEvents <-chan ty fusion: fusionStrategy, reranker: reranker, events: strategyEvents, + prefetcher: prefetch.New(config.PrefetchConfig), } return m, nil @@ -220,6 +224,41 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "num_strategies", len(m.strategies), "query_length", len(query)) + if cached, ok := m.prefetcher.Get(query); ok { + slog.DebugContext(ctx, "[RAG Manager] Returning prefetched RAG results", + "rag_name", m.name, + "result_count", len(cached)) + return cached, nil + } + + results, err := m.queryUncached(ctx, query) + if err != nil { + return nil, err + } + + if ctx.Err() == nil { + m.prefetcher.Store(query, results) + m.prefetcher.Observe(query, results) + } + if ctx.Err() == nil && m.reranker == nil { + for _, candidate := range m.prefetcher.Candidates(query, results) { + m.prefetcher.Prefetch(ctx, candidate, m.queryUncached) + } + } + + return results, nil +} + +func (m *Manager) queryUncached(ctx context.Context, query string) ([]database.SearchResult, error) { + results, err := m.queryStrategies(ctx, query) + if err != nil { + return nil, err + } + + return m.postprocessQueryResults(ctx, query, results), nil +} + +func (m *Manager) queryStrategies(ctx context.Context, query string) ([]database.SearchResult, error) { // Single retrieval strategy if len(m.strategies) == 1 { for strategyName, strategyImpl := range m.strategies { @@ -245,31 +284,6 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "strategy", strategyName, "num_results", len(results)) - // Apply reranking if configured - results = m.rerank(ctx, query, results) - - if limit := m.config.Results.Limit; limit > 0 && len(results) > limit { - slog.DebugContext(ctx, "[RAG Manager] Truncating to global result limit", - "rag_name", m.name, - "strategy", strategyName, - "before", len(results), - "after", limit) - results = results[:limit] - } - - // Reconstruct full documents if configured - if m.config.Results.ReturnFullContent { - results = m.reconstructFullDocuments(ctx, results) - } - - if m.config.Results.Deduplicate { - results = m.deduplicateResults(results) - slog.DebugContext(ctx, "[RAG Manager] Deduplicated single-strategy results", - "rag_name", m.name, - "strategy", strategyName, - "num_results", len(results)) - } - return results, nil } } @@ -352,37 +366,41 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "fused_results", len(fusedResults), "result_limit", m.config.Results.Limit) + return fusedResults, nil +} + +func (m *Manager) postprocessQueryResults(ctx context.Context, query string, results []database.SearchResult) []database.SearchResult { // Apply reranking if configured (before limit and deduplication) - fusedResults = m.rerank(ctx, query, fusedResults) + results = m.rerank(ctx, query, results) // Apply result limit if configured - if limit := m.config.Results.Limit; limit > 0 && len(fusedResults) > limit { + if limit := m.config.Results.Limit; limit > 0 && len(results) > limit { slog.DebugContext(ctx, "[RAG Manager] Truncating to result limit", "rag_name", m.name, - "before", len(fusedResults), + "before", len(results), "after", limit) - fusedResults = fusedResults[:limit] + results = results[:limit] } // Reconstruct full documents if configured if m.config.Results.ReturnFullContent { - fusedResults = m.reconstructFullDocuments(ctx, fusedResults) + results = m.reconstructFullDocuments(ctx, results) } // Optionally deduplicate based on the final content that will be returned // (full documents or chunks). if m.config.Results.Deduplicate { - fusedResults = m.deduplicateResults(fusedResults) + results = m.deduplicateResults(results) slog.DebugContext(ctx, "[RAG Manager] Deduplicated fused results", "rag_name", m.name, - "num_results", len(fusedResults)) + "num_results", len(results)) } // TODO: Track and emit query embedding usage // For queries during agent execution, usage should be added to agent's session // This requires passing session context through the RAG tool - return fusedResults, nil + return results } // Helper to get strategy names for logging @@ -437,6 +455,7 @@ func (m *Manager) CheckAndReindexChangedFiles(ctx context.Context) error { return fmt.Errorf("strategy %s failed: %w", strategyName, err) } } + m.prefetcher.Clear() return nil } diff --git a/pkg/rag/manager_test.go b/pkg/rag/manager_test.go index 2ebadd1df..74ef01bcd 100644 --- a/pkg/rag/manager_test.go +++ b/pkg/rag/manager_test.go @@ -1,12 +1,18 @@ package rag import ( + "context" "os" "path/filepath" + "sync/atomic" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/rag/database" + "github.com/docker/docker-agent/pkg/rag/prefetch" + "github.com/docker/docker-agent/pkg/rag/strategy" ) func TestGetAbsolutePaths_WithBasePath(t *testing.T) { @@ -31,3 +37,57 @@ func TestGetAbsolutePaths_NilInput(t *testing.T) { result := GetAbsolutePaths("/base", nil) assert.Nil(t, result) } + +type countingStrategy struct { + calls atomic.Int64 + results []database.SearchResult +} + +func (s *countingStrategy) Initialize(context.Context, []string, strategy.ChunkingConfig) error { + return nil +} + +func (s *countingStrategy) Query(context.Context, string, int, float64) ([]database.SearchResult, error) { + s.calls.Add(1) + return append([]database.SearchResult(nil), s.results...), nil +} + +func (s *countingStrategy) CheckAndReindexChangedFiles(context.Context, []string, strategy.ChunkingConfig) error { + return nil +} + +func (s *countingStrategy) StartFileWatcher(context.Context, []string, strategy.ChunkingConfig) error { + return nil +} + +func (s *countingStrategy) Close() error { return nil } + +func TestQueryUsesPrefetchCacheForRepeatedQuery(t *testing.T) { + strat := &countingStrategy{results: []database.SearchResult{{ + Document: database.Document{ID: "1", SourcePath: "docs/rag.md", Content: "doc one"}, + Similarity: 0.9, + }}} + m, err := New(t.Context(), "test", Config{ + Results: ResultsConfig{Limit: 15}, + PrefetchConfig: prefetch.Config{Enabled: true, MaxEntries: 4}, + StrategyConfigs: []strategy.Config{{ + Name: "counting", + Strategy: strat, + Limit: 5, + Threshold: 0.5, + }}, + }, nil) + require.NoError(t, err) + + first, err := m.Query(t.Context(), "RAG cache") + require.NoError(t, err) + require.Len(t, first, 1) + first[0].Document.Content = "caller mutation" + + second, err := m.Query(t.Context(), " rag cache ") + require.NoError(t, err) + require.Len(t, second, 1) + + assert.Equal(t, int64(1), strat.calls.Load()) + assert.Equal(t, "doc one", second[0].Document.Content) +} diff --git a/pkg/rag/prefetch/prefetch.go b/pkg/rag/prefetch/prefetch.go new file mode 100644 index 000000000..e2250eca5 --- /dev/null +++ b/pkg/rag/prefetch/prefetch.go @@ -0,0 +1,300 @@ +package prefetch + +import ( + "context" + "path/filepath" + "slices" + "strings" + "sync" + "time" + + "github.com/docker/docker-agent/pkg/rag/database" +) + +const ( + defaultMaxEntries = 32 + defaultMaxCandidates = 2 + defaultMinSimilarity = 0.5 + defaultDriftThreshold = 0.8 + defaultTimeout = 10 * time.Second +) + +// Config controls adaptive RAG prefetching. The zero value disables it. +type Config struct { + Enabled bool + MaxEntries int + MaxCandidates int + MinSimilarity float64 + DriftThreshold float64 + Timeout time.Duration +} + +func (c Config) withDefaults() Config { + if c.MaxEntries <= 0 { + c.MaxEntries = defaultMaxEntries + } + if c.MaxCandidates <= 0 { + c.MaxCandidates = defaultMaxCandidates + } + if c.MinSimilarity <= 0 { + c.MinSimilarity = defaultMinSimilarity + } + if c.DriftThreshold <= 0 { + c.DriftThreshold = defaultDriftThreshold + } + if c.Timeout <= 0 { + c.Timeout = defaultTimeout + } + return c +} + +// FetchFunc runs an uncached query for a prefetch candidate. +type FetchFunc func(context.Context, string) ([]database.SearchResult, error) + +// Prefetcher owns bounded result cache, lightweight topology state, and +// background prefetch scheduling for one RAG manager. +type Prefetcher struct { + cfg Config + + mu sync.Mutex + cache map[string][]database.SearchResult + order []string + inflight map[string]struct{} + tracker tracker +} + +// New creates a prefetcher. It returns nil when disabled so callers can keep +// the hot path branch small. +func New(cfg Config) *Prefetcher { + if !cfg.Enabled { + return nil + } + cfg = cfg.withDefaults() + return &Prefetcher{ + cfg: cfg, + cache: make(map[string][]database.SearchResult, cfg.MaxEntries), + inflight: make(map[string]struct{}), + } +} + +// Get returns cached final results for query. +func (p *Prefetcher) Get(query string) ([]database.SearchResult, bool) { + if p == nil { + return nil, false + } + key := normalize(query) + if key == "" { + return nil, false + } + + p.mu.Lock() + defer p.mu.Unlock() + results, ok := p.cache[key] + if !ok { + return nil, false + } + return cloneResults(results), true +} + +// Store records final post-processed results for query. +func (p *Prefetcher) Store(query string, results []database.SearchResult) { + if p == nil || len(results) == 0 { + return + } + key := normalize(query) + if key == "" { + return + } + + p.mu.Lock() + defer p.mu.Unlock() + p.storeLocked(key, results) +} + +func (p *Prefetcher) storeLocked(key string, results []database.SearchResult) { + if _, exists := p.cache[key]; !exists { + p.order = append(p.order, key) + } + p.cache[key] = cloneResults(results) + + for len(p.order) > p.cfg.MaxEntries { + oldest := p.order[0] + p.order = p.order[1:] + delete(p.cache, oldest) + } +} + +// Clear drops cached and in-flight query state after index changes. +func (p *Prefetcher) Clear() { + if p == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + clear(p.cache) + clear(p.inflight) + p.order = nil +} + +// Observe updates the topology tracker with query/result metadata. +func (p *Prefetcher) Observe(query string, results []database.SearchResult) { + if p == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + p.tracker.observe(query, results) +} + +// Candidates returns deterministic follow-up queries worth warming. +func (p *Prefetcher) Candidates(query string, results []database.SearchResult) []string { + if p == nil || len(results) == 0 { + return nil + } + + p.mu.Lock() + stable := p.tracker.stable(p.cfg.DriftThreshold) + p.mu.Unlock() + if !stable { + return nil + } + + base := normalize(query) + seen := map[string]struct{}{base: {}} + candidates := make([]string, 0, p.cfg.MaxCandidates) + + for _, result := range results { + if len(candidates) >= p.cfg.MaxCandidates { + break + } + if result.Similarity < p.cfg.MinSimilarity { + continue + } + name := sourceToken(result.Document.SourcePath) + if name == "" { + continue + } + candidate := strings.TrimSpace(base + " " + name) + if candidate == "" { + continue + } + if _, ok := seen[candidate]; ok { + continue + } + seen[candidate] = struct{}{} + candidates = append(candidates, candidate) + } + + return candidates +} + +// Prefetch schedules one bounded background fetch for query. +func (p *Prefetcher) Prefetch(ctx context.Context, query string, fetch FetchFunc) { + if p == nil || fetch == nil { + return + } + key := normalize(query) + if key == "" { + return + } + + p.mu.Lock() + if _, ok := p.cache[key]; ok { + p.mu.Unlock() + return + } + if _, ok := p.inflight[key]; ok { + p.mu.Unlock() + return + } + p.inflight[key] = struct{}{} + p.mu.Unlock() + + go func() { + defer func() { + p.mu.Lock() + delete(p.inflight, key) + p.mu.Unlock() + }() + + prefetchCtx, cancel := context.WithTimeout(ctx, p.cfg.Timeout) + defer cancel() + + results, err := fetch(prefetchCtx, key) + if err != nil || len(results) == 0 { + return + } + + p.mu.Lock() + p.storeLocked(key, results) + p.mu.Unlock() + }() +} + +func normalize(query string) string { + return strings.Join(strings.Fields(strings.ToLower(query)), " ") +} + +func sourceToken(path string) string { + base := filepath.Base(path) + if base == "." || base == string(filepath.Separator) { + return "" + } + ext := filepath.Ext(base) + return strings.TrimSuffix(base, ext) +} + +func cloneResults(results []database.SearchResult) []database.SearchResult { + return slices.Clone(results) +} + +type tracker struct { + seen int + centroid [4]float64 + drift float64 +} + +func (t *tracker) observe(query string, results []database.SearchResult) { + point := pointFor(query, results) + t.seen++ + if t.seen == 1 { + t.centroid = point + t.drift = 0 + return + } + + var dist float64 + for i := range point { + diff := point[i] - t.centroid[i] + dist += diff * diff + } + t.drift = dist + + weight := 1 / float64(t.seen) + for i := range t.centroid { + t.centroid[i] += (point[i] - t.centroid[i]) * weight + } +} + +func (t *tracker) stable(threshold float64) bool { + if t.seen < 2 { + return true + } + return t.drift <= threshold +} + +func pointFor(query string, results []database.SearchResult) [4]float64 { + var avgSimilarity float64 + for _, result := range results { + avgSimilarity += result.Similarity + } + if len(results) > 0 { + avgSimilarity /= float64(len(results)) + } + return [4]float64{ + float64(len(query)) / 256, + float64(len(strings.Fields(query))) / 32, + float64(len(results)) / 32, + avgSimilarity, + } +} diff --git a/pkg/rag/prefetch/prefetch_test.go b/pkg/rag/prefetch/prefetch_test.go new file mode 100644 index 000000000..f25c4e76a --- /dev/null +++ b/pkg/rag/prefetch/prefetch_test.go @@ -0,0 +1,93 @@ +package prefetch + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/rag/database" +) + +func TestDisabledPrefetcher(t *testing.T) { + assert.Nil(t, New(Config{})) +} + +func TestStoreGetAndEvictOldest(t *testing.T) { + p := New(Config{Enabled: true, MaxEntries: 2}) + + p.Store("Alpha Query", result("a.go", 0.9)) + p.Store("Beta Query", result("b.go", 0.9)) + p.Store("Gamma Query", result("c.go", 0.9)) + + _, ok := p.Get("alpha query") + assert.False(t, ok) + _, ok = p.Get("beta query") + assert.True(t, ok) + got, ok := p.Get("GAMMA query") + require.True(t, ok) + assert.Equal(t, "c.go", got[0].Document.SourcePath) +} + +func TestCandidatesUseStableTopologyAndSourceNames(t *testing.T) { + p := New(Config{Enabled: true, MaxCandidates: 2, MinSimilarity: 0.5, DriftThreshold: 1}) + results := []database.SearchResult{ + result("pkg/rag/manager.go", 0.9)[0], + result("pkg/rag/vector_store.go", 0.7)[0], + result("pkg/rag/weak.go", 0.1)[0], + } + + p.Observe("RAG manager", results) + p.Observe("RAG manager cache", results) + + assert.Equal(t, []string{"rag manager manager", "rag manager vector_store"}, p.Candidates("RAG manager", results)) +} + +func TestCandidatesSuppressedWhenDrifting(t *testing.T) { + p := New(Config{Enabled: true, DriftThreshold: 0.0001}) + results := result("pkg/rag/manager.go", 0.9) + + p.Observe("short", results) + p.Observe("this is a completely different and much longer query with more tokens", results) + + assert.Empty(t, p.Candidates("short", results)) +} + +func TestPrefetchDeduplicatesInFlightAndStoresResult(t *testing.T) { + p := New(Config{Enabled: true, Timeout: time.Second}) + var calls atomic.Int64 + done := make(chan struct{}) + + fetch := func(context.Context, string) ([]database.SearchResult, error) { + calls.Add(1) + close(done) + return result("pkg/rag/manager.go", 0.9), nil + } + + p.Prefetch(t.Context(), "RAG manager", fetch) + p.Prefetch(t.Context(), "rag manager", fetch) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("prefetch did not run") + } + require.Eventually(t, func() bool { + _, ok := p.Get("rag manager") + return ok + }, time.Second, 10*time.Millisecond) + assert.Equal(t, int64(1), calls.Load()) +} + +func result(path string, similarity float64) []database.SearchResult { + return []database.SearchResult{{ + Document: database.Document{ + SourcePath: path, + Content: "content", + }, + Similarity: similarity, + }} +} From 53f09fff758826847d34fdb0f25821d4698bf23a Mon Sep 17 00:00:00 2001 From: Teerth Sharma Date: Fri, 19 Jun 2026 16:12:51 +0530 Subject: [PATCH 2/4] fix: resolve adaptive rag lint failures Signed-off-by: Teerth Sharma --- pkg/config/latest/types.go | 2 +- pkg/rag/manager_test.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index c1ed336bc..ec07b8036 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -1826,7 +1826,7 @@ type RAGPrefetchConfig struct { MaxCandidates int `json:"max_candidates,omitempty"` MinSimilarity float64 `json:"min_similarity,omitempty"` DriftThreshold float64 `json:"drift_threshold,omitempty"` - Timeout Duration `json:"timeout,omitempty"` + Timeout Duration `json:"timeout,omitzero"` } // RAGRerankingConfig represents reranking configuration diff --git a/pkg/rag/manager_test.go b/pkg/rag/manager_test.go index 74ef01bcd..1fd500d49 100644 --- a/pkg/rag/manager_test.go +++ b/pkg/rag/manager_test.go @@ -63,7 +63,7 @@ func (s *countingStrategy) StartFileWatcher(context.Context, []string, strategy. func (s *countingStrategy) Close() error { return nil } func TestQueryUsesPrefetchCacheForRepeatedQuery(t *testing.T) { - strat := &countingStrategy{results: []database.SearchResult{{ + searchStrategy := &countingStrategy{results: []database.SearchResult{{ Document: database.Document{ID: "1", SourcePath: "docs/rag.md", Content: "doc one"}, Similarity: 0.9, }}} @@ -72,7 +72,7 @@ func TestQueryUsesPrefetchCacheForRepeatedQuery(t *testing.T) { PrefetchConfig: prefetch.Config{Enabled: true, MaxEntries: 4}, StrategyConfigs: []strategy.Config{{ Name: "counting", - Strategy: strat, + Strategy: searchStrategy, Limit: 5, Threshold: 0.5, }}, @@ -88,6 +88,6 @@ func TestQueryUsesPrefetchCacheForRepeatedQuery(t *testing.T) { require.NoError(t, err) require.Len(t, second, 1) - assert.Equal(t, int64(1), strat.calls.Load()) + assert.Equal(t, int64(1), searchStrategy.calls.Load()) assert.Equal(t, "doc one", second[0].Document.Content) } From b5315bb8424d0a1d67ccb9b2130798040c763409 Mon Sep 17 00:00:00 2001 From: Teerth Sharma Date: Fri, 19 Jun 2026 18:41:16 +0530 Subject: [PATCH 3/4] feat: prove topology-assisted rag prefetch Signed-off-by: Teerth Sharma --- .../2026-06-18-adaptive-rag-prefetcher.md | 88 ------- ...26-06-18-adaptive-rag-prefetcher-design.md | 79 ------- docs/tools/rag/index.md | 4 +- pkg/rag/manager.go | 38 ++- pkg/rag/manager_test.go | 27 +++ pkg/rag/prefetch/prefetch.go | 120 +++++++++- pkg/rag/prefetch/prefetch_test.go | 216 +++++++++++++++++- pkg/rag/prefetch/proofs/TopologyHit.lean | 32 +++ 8 files changed, 419 insertions(+), 185 deletions(-) delete mode 100644 docs/superpowers/plans/2026-06-18-adaptive-rag-prefetcher.md delete mode 100644 docs/superpowers/specs/2026-06-18-adaptive-rag-prefetcher-design.md create mode 100644 pkg/rag/prefetch/proofs/TopologyHit.lean diff --git a/docs/superpowers/plans/2026-06-18-adaptive-rag-prefetcher.md b/docs/superpowers/plans/2026-06-18-adaptive-rag-prefetcher.md deleted file mode 100644 index 507d18ee0..000000000 --- a/docs/superpowers/plans/2026-06-18-adaptive-rag-prefetcher.md +++ /dev/null @@ -1,88 +0,0 @@ -# Adaptive RAG Prefetcher Implementation Plan - -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Add an opt-in adaptive RAG prefetcher that caches exact repeat queries and warms deterministic follow-up candidates when recent RAG topology is stable. - -**Architecture:** Implement a small `pkg/rag/prefetch` package with bounded cache, topology tracker, and background scheduler. Wire it into `pkg/rag.Manager` as an optional layer around the existing query pipeline, leaving strategy implementations unchanged. - -**Tech Stack:** Go, existing RAG strategy interfaces, existing config/latest schema, existing docs/examples. - -## Global Constraints - -- Do not add an Aether-Lang dependency. -- Only change latest config, not frozen config versions. -- Feature is opt-in with `results.prefetch.enabled`. -- Background prefetch must be bounded, cancellable, and non-blocking. -- Commits must use DCO sign-off via `git commit -s`. -- Validate with `task lint` and `task test` before PR. - ---- - -### Task 1: Prefetch Package - -**Files:** -- Create: `pkg/rag/prefetch/prefetch.go` -- Create: `pkg/rag/prefetch/prefetch_test.go` - -**Interfaces:** -- Produces: `Config`, `Prefetcher`, `New(Config) *Prefetcher` -- Produces: `Get(query string) ([]database.SearchResult, bool)`, `Store(query string, results []database.SearchResult)`, `Observe(query string, results []database.SearchResult)`, `Candidates(query string, results []database.SearchResult) []string`, `Prefetch(ctx context.Context, query string, fn FetchFunc)` - -- [ ] Write tests for disabled config, bounded cache eviction, stable candidate generation, and drift suppression. -- [ ] Implement config defaults and normalization. -- [ ] Implement topology tracker using query length, term count, result count, and average similarity. -- [ ] Implement candidate generation from query text and top source path basenames. -- [ ] Implement bounded background prefetch with timeout and in-flight de-duplication. -- [ ] Run `go test ./pkg/rag/prefetch`. -- [ ] Commit with `git commit -s -m "feat: add rag prefetch primitives"`. - -### Task 2: Manager Integration - -**Files:** -- Modify: `pkg/rag/manager.go` -- Modify: `pkg/rag/builder.go` -- Modify: `pkg/rag/manager_test.go` - -**Interfaces:** -- Consumes: `prefetch.Config`, `prefetch.Prefetcher` -- Produces: optional manager-level prefetch behavior for `Manager.Query` - -- [ ] Add prefetch config to `rag.Config` and create the prefetcher in `New`. -- [ ] Extract current query logic into an unexported `queryUncached(ctx, query string)` helper. -- [ ] Make `Query` check exact cache hits first when enabled. -- [ ] Store successful final results and schedule background candidate prefetches after cache misses. -- [ ] Add integration tests using a fake strategy to prove a second identical query is served from cache. -- [ ] Run `go test ./pkg/rag`. -- [ ] Commit with `git commit -s -m "feat: wire adaptive prefetch into rag manager"`. - -### Task 3: Config, Schema, Docs, Example - -**Files:** -- Modify: `pkg/config/latest/types.go` -- Modify: `agent-schema.json` -- Modify: `docs/tools/rag/index.md` -- Create: `examples/rag/adaptive_prefetch.yaml` - -**Interfaces:** -- Produces: `latest.RAGPrefetchConfig` -- Wires: `latest.RAGResultsConfig.Prefetch *RAGPrefetchConfig` - -- [ ] Add `RAGPrefetchConfig` with `enabled`, `max_entries`, `max_candidates`, `min_similarity`, `drift_threshold`, and `timeout`. -- [ ] Add schema definition/properties matching Go JSON tags. -- [ ] Document the feature and its conservative defaults in RAG docs. -- [ ] Add a runnable example using hybrid RAG plus `results.prefetch`. -- [ ] Run `go test ./pkg/config`. -- [ ] Commit with `git commit -s -m "docs: document adaptive rag prefetching"`. - -### Task 4: Final Validation and PR - -**Files:** -- Modify as needed from validation findings only. - -- [ ] Run `task lint`. -- [ ] Run `task test`. -- [ ] Run `task build`. -- [ ] Inspect `git diff --stat main...HEAD`. -- [ ] Push branch to fork remote. -- [ ] Open draft PR against `docker/docker-agent:main` linking issue `#3164`. diff --git a/docs/superpowers/specs/2026-06-18-adaptive-rag-prefetcher-design.md b/docs/superpowers/specs/2026-06-18-adaptive-rag-prefetcher-design.md deleted file mode 100644 index beb4705d0..000000000 --- a/docs/superpowers/specs/2026-06-18-adaptive-rag-prefetcher-design.md +++ /dev/null @@ -1,79 +0,0 @@ -# Adaptive RAG Prefetcher Design - -## Context - -Issue: https://github.com/docker/docker-agent/issues/3164 - -Aether-Lang contains useful algorithms for sparse attention graphs, hierarchical block metadata, adaptive epsilon thresholds, and centroid drift detection. docker-agent should not import Aether-Lang or add Rust/runtime dependencies. The contribution should translate the useful ideas into small Go primitives that fit the existing RAG manager and strategy interfaces. - -## Goal - -Add an opt-in adaptive RAG prefetcher that reduces repeated retrieval latency and warms likely follow-up queries without blocking the active user turn. - -## Non-Goals - -- No Aether-Lang dependency. -- No new DSL, kernel, or model provider. -- No replacement of existing RAG strategies, fusion, or reranking. -- No hidden behavior when config does not enable the feature. - -## Design - -The RAG manager gets an optional prefetcher configured under `results.prefetch`. When disabled, `Manager.Query` keeps current behavior. - -When enabled: - -1. `Manager.Query` checks a small in-memory cache keyed by normalized query text. -2. Cache hit returns the final post-processed results from the earlier query. -3. Cache miss runs the normal strategy/fusion/reranking pipeline. -4. The prefetcher observes the query and result metadata. -5. If topology is stable enough, it schedules a bounded background prefetch for one or more deterministic follow-up candidates. - -The prefetcher computes a lightweight state vector from query/result metadata: - -- query length -- token count -- result count -- average similarity - -It tracks a centroid and drift score over recent observations. A simple adaptive threshold accepts stable sessions and suppresses prefetch during strong topic shifts. - -## Configuration - -Add to the latest config only: - -```yaml -results: - prefetch: - enabled: true - max_entries: 32 - max_candidates: 2 - min_similarity: 0.5 - drift_threshold: 0.8 - timeout: 10s -``` - -Defaults are conservative. `enabled` defaults false. - -## Components - -- `pkg/rag/prefetch`: owns cache, topology tracker, candidate generation, and background scheduling. -- `pkg/rag/manager.go`: wires cache lookup, observation, and background scheduling around existing query behavior. -- `pkg/config/latest/types.go`: latest-only config structs. -- `agent-schema.json`: schema sync for config fields. -- `docs/tools/rag/index.md`: user-facing docs. -- `examples/rag/adaptive_prefetch.yaml`: runnable example config. - -## Safety - -- Background prefetch uses a timeout and a max in-flight limit. -- Prefetch errors are logged at debug level and never fail the active query. -- Results cache is bounded and evicts oldest entries. -- Query work is skipped while drift exceeds threshold. -- Prefetch stores only normal RAG search results already available to the running process. - -## Testing - -- Unit tests for cache eviction, candidate generation, drift suppression, and scheduling. -- Manager integration test for cache hit avoiding a second strategy call. -- Config/schema test coverage through existing schema sync tests. diff --git a/docs/tools/rag/index.md b/docs/tools/rag/index.md index 3c524fc7b..d96c894bd 100644 --- a/docs/tools/rag/index.md +++ b/docs/tools/rag/index.md @@ -159,7 +159,7 @@ Supported reranking providers: **DMR** (native `/rerank` endpoint), **OpenAI**, ## Adaptive Prefetching -Adaptive prefetching is opt-in. It caches repeated RAG queries and, when query topology is stable, warms a small number of deterministic follow-up candidates in the background. +Adaptive prefetching is opt-in. It caches repeated RAG queries and, when query topology is stable, can reuse warmed results for closely related follow-up queries that share source-derived anchors. The topology path is gated by query-token overlap, source-path anchors, and drift detection so unrelated topic shifts fall back to normal retrieval. ```yaml results: @@ -172,7 +172,7 @@ results: timeout: 10s ``` -The prefetcher is bounded and non-blocking. Exact repeated queries can be served from cache. Background candidate prefetch is skipped when reranking is enabled so reranker errors and fallback behavior stay tied to the foreground query. +The prefetcher is bounded and non-blocking. Exact repeated queries are served from cache first; related-query topology hits are considered only after an exact miss. Background candidate prefetch is skipped when reranking is enabled so reranker errors and fallback behavior stay tied to the foreground query. The deterministic replay benchmark in `pkg/rag/prefetch` reports exact-repeat and topology-assisted hit rates separately. ## Code-Aware Chunking diff --git a/pkg/rag/manager.go b/pkg/rag/manager.go index 99d50e7d1..6fb364970 100644 --- a/pkg/rag/manager.go +++ b/pkg/rag/manager.go @@ -79,7 +79,7 @@ type FusionConfig struct { // New creates a new RAG manager with one or more strategies. // Pass multiple strategy configs to enable hybrid retrieval. // The strategyEvents channel should be shared across all strategies for this manager. -func New(_ context.Context, name string, config Config, strategyEvents <-chan types.Event) (*Manager, error) { +func New(ctx context.Context, name string, config Config, strategyEvents <-chan types.Event) (*Manager, error) { if len(config.StrategyConfigs) == 0 { return nil, errors.New("at least one strategy required") } @@ -127,12 +127,13 @@ func New(_ context.Context, name string, config Config, strategyEvents <-chan ty var reranker rerank.Reranker if config.Results.RerankingConfig != nil { reranker = config.Results.RerankingConfig.Reranker - slog.Debug("[RAG Manager] Reranking enabled", + slog.DebugContext(ctx, "[RAG Manager] Reranking enabled", "rag_name", name, "top_k", config.Results.RerankingConfig.TopK, "threshold", config.Results.RerankingConfig.Threshold) } + prefetcher := prefetch.New(config.PrefetchConfig) m := &Manager{ name: name, config: config, @@ -140,13 +141,42 @@ func New(_ context.Context, name string, config Config, strategyEvents <-chan ty strategyConfigs: strategyConfigMap, fusion: fusionStrategy, reranker: reranker, - events: strategyEvents, - prefetcher: prefetch.New(config.PrefetchConfig), + events: forwardEvents(ctx, strategyEvents, prefetcher), + prefetcher: prefetcher, } return m, nil } +func forwardEvents(ctx context.Context, in <-chan types.Event, prefetcher *prefetch.Prefetcher) <-chan types.Event { + if in == nil { + return nil + } + out := make(chan types.Event, 500) + go func() { + defer close(out) + for { + select { + case <-ctx.Done(): + return + case event, ok := <-in: + if !ok { + return + } + if event.Type == types.EventTypeIndexingComplete { + prefetcher.Clear() + } + select { + case out <- event: + default: + slog.WarnContext(ctx, "RAG manager event channel full, dropping event", "event_type", event.Type) + } + } + } + }() + return out +} + // Initialize indexes all documents using all configured strategies // Each strategy indexes its own document set (shared + strategy-specific) // Strategies are initialized in parallel for better performance diff --git a/pkg/rag/manager_test.go b/pkg/rag/manager_test.go index 1fd500d49..efa23b407 100644 --- a/pkg/rag/manager_test.go +++ b/pkg/rag/manager_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -13,6 +14,7 @@ import ( "github.com/docker/docker-agent/pkg/rag/database" "github.com/docker/docker-agent/pkg/rag/prefetch" "github.com/docker/docker-agent/pkg/rag/strategy" + "github.com/docker/docker-agent/pkg/rag/types" ) func TestGetAbsolutePaths_WithBasePath(t *testing.T) { @@ -91,3 +93,28 @@ func TestQueryUsesPrefetchCacheForRepeatedQuery(t *testing.T) { assert.Equal(t, int64(1), searchStrategy.calls.Load()) assert.Equal(t, "doc one", second[0].Document.Content) } + +func TestManagerClearsPrefetchCacheOnIndexingCompleteEvent(t *testing.T) { + events := make(chan types.Event, 1) + m, err := New(t.Context(), "test", Config{ + Results: ResultsConfig{Limit: 15}, + PrefetchConfig: prefetch.Config{Enabled: true, MaxEntries: 4}, + StrategyConfigs: []strategy.Config{{ + Name: "counting", + Strategy: &countingStrategy{}, + }}, + }, events) + require.NoError(t, err) + + m.prefetcher.Store("RAG cache", []database.SearchResult{{ + Document: database.Document{ID: "1", SourcePath: "docs/rag.md", Content: "stale"}, + Similarity: 0.9, + }}) + + events <- types.Event{Type: types.EventTypeIndexingComplete} + + require.Eventually(t, func() bool { + _, ok := m.prefetcher.Get("RAG cache") + return !ok + }, time.Second, 10*time.Millisecond) +} diff --git a/pkg/rag/prefetch/prefetch.go b/pkg/rag/prefetch/prefetch.go index e2250eca5..33d470c1f 100644 --- a/pkg/rag/prefetch/prefetch.go +++ b/pkg/rag/prefetch/prefetch.go @@ -57,12 +57,17 @@ type Prefetcher struct { cfg Config mu sync.Mutex - cache map[string][]database.SearchResult + cache map[string]cacheEntry order []string inflight map[string]struct{} tracker tracker } +type cacheEntry struct { + results []database.SearchResult + anchors []string +} + // New creates a prefetcher. It returns nil when disabled so callers can keep // the hot path branch small. func New(cfg Config) *Prefetcher { @@ -72,7 +77,7 @@ func New(cfg Config) *Prefetcher { cfg = cfg.withDefaults() return &Prefetcher{ cfg: cfg, - cache: make(map[string][]database.SearchResult, cfg.MaxEntries), + cache: make(map[string]cacheEntry, cfg.MaxEntries), inflight: make(map[string]struct{}), } } @@ -91,9 +96,9 @@ func (p *Prefetcher) Get(query string) ([]database.SearchResult, bool) { defer p.mu.Unlock() results, ok := p.cache[key] if !ok { - return nil, false + return p.getTopologyLocked(key) } - return cloneResults(results), true + return cloneResults(results.results), true } // Store records final post-processed results for query. @@ -115,7 +120,10 @@ func (p *Prefetcher) storeLocked(key string, results []database.SearchResult) { if _, exists := p.cache[key]; !exists { p.order = append(p.order, key) } - p.cache[key] = cloneResults(results) + p.cache[key] = cacheEntry{ + results: cloneResults(results), + anchors: anchorsFor(results), + } for len(p.order) > p.cfg.MaxEntries { oldest := p.order[0] @@ -136,6 +144,21 @@ func (p *Prefetcher) Clear() { p.order = nil } +func (p *Prefetcher) getTopologyLocked(key string) ([]database.SearchResult, bool) { + if !p.tracker.stable(p.cfg.DriftThreshold) { + return nil, false + } + queryTokens := tokenSet(key) + for _, entryKey := range slices.Backward(p.order) { + entry := p.cache[entryKey] + if !topologyRelated(entryKey, queryTokens, entry.anchors) { + continue + } + return cloneResults(entry.results), true + } + return nil, false +} + // Observe updates the topology tracker with query/result metadata. func (p *Prefetcher) Observe(query string, results []database.SearchResult) { if p == nil { @@ -160,6 +183,7 @@ func (p *Prefetcher) Candidates(query string, results []database.SearchResult) [ } base := normalize(query) + baseTokens := tokenSet(base) seen := map[string]struct{}{base: {}} candidates := make([]string, 0, p.cfg.MaxCandidates) @@ -170,11 +194,11 @@ func (p *Prefetcher) Candidates(query string, results []database.SearchResult) [ if result.Similarity < p.cfg.MinSimilarity { continue } - name := sourceToken(result.Document.SourcePath) - if name == "" { + suffix := sourceSuffix(result.Document.SourcePath, baseTokens) + if suffix == "" { continue } - candidate := strings.TrimSpace(base + " " + name) + candidate := strings.TrimSpace(base + " " + suffix) if candidate == "" { continue } @@ -217,7 +241,7 @@ func (p *Prefetcher) Prefetch(ctx context.Context, query string, fetch FetchFunc p.mu.Unlock() }() - prefetchCtx, cancel := context.WithTimeout(ctx, p.cfg.Timeout) + prefetchCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), p.cfg.Timeout) defer cancel() results, err := fetch(prefetchCtx, key) @@ -244,6 +268,81 @@ func sourceToken(path string) string { return strings.TrimSuffix(base, ext) } +func sourceSuffix(path string, seen map[string]struct{}) string { + parts := strings.FieldsFunc(sourceToken(path), isTokenSeparator) + novel := make([]string, 0, len(parts)) + for _, part := range parts { + part = normalize(part) + if part == "" { + continue + } + if _, ok := seen[part]; ok { + continue + } + novel = append(novel, part) + } + return strings.Join(novel, " ") +} + +func anchorsFor(results []database.SearchResult) []string { + seen := map[string]struct{}{} + anchors := make([]string, 0, len(results)) + for _, result := range results { + for _, part := range strings.FieldsFunc(sourceToken(result.Document.SourcePath), isTokenSeparator) { + part = normalize(part) + if len(part) < 3 { + continue + } + if _, ok := seen[part]; ok { + continue + } + seen[part] = struct{}{} + anchors = append(anchors, part) + } + } + return anchors +} + +func isTokenSeparator(r rune) bool { + return r == '_' || r == '-' || r == '.' +} + +func topologyRelated(entryKey string, queryTokens map[string]struct{}, anchors []string) bool { + hasAnchor := false + for _, anchor := range anchors { + if _, ok := queryTokens[anchor]; ok { + hasAnchor = true + break + } + } + if !hasAnchor { + return false + } + return jaccard(tokenSet(entryKey), queryTokens) >= 0.25 +} + +func tokenSet(query string) map[string]struct{} { + tokens := map[string]struct{}{} + for token := range strings.FieldsSeq(query) { + tokens[token] = struct{}{} + } + return tokens +} + +func jaccard(a, b map[string]struct{}) float64 { + var intersection int + for token := range a { + if _, ok := b[token]; ok { + intersection++ + } + } + union := len(a) + len(b) - intersection + if union == 0 { + return 0 + } + return float64(intersection) / float64(union) +} + func cloneResults(results []database.SearchResult) []database.SearchResult { return slices.Clone(results) } @@ -270,9 +369,8 @@ func (t *tracker) observe(query string, results []database.SearchResult) { } t.drift = dist - weight := 1 / float64(t.seen) for i := range t.centroid { - t.centroid[i] += (point[i] - t.centroid[i]) * weight + t.centroid[i] += (point[i] - t.centroid[i]) * 0.35 } } diff --git a/pkg/rag/prefetch/prefetch_test.go b/pkg/rag/prefetch/prefetch_test.go index f25c4e76a..6f472f5c6 100644 --- a/pkg/rag/prefetch/prefetch_test.go +++ b/pkg/rag/prefetch/prefetch_test.go @@ -43,7 +43,33 @@ func TestCandidatesUseStableTopologyAndSourceNames(t *testing.T) { p.Observe("RAG manager", results) p.Observe("RAG manager cache", results) - assert.Equal(t, []string{"rag manager manager", "rag manager vector_store"}, p.Candidates("RAG manager", results)) + assert.Equal(t, []string{"rag manager vector store"}, p.Candidates("RAG manager", results)) +} + +func TestGetUsesTopologyForRelatedFollowupQuery(t *testing.T) { + p := New(Config{Enabled: true}) + p.Store("how does rag manager query work", []database.SearchResult{ + result("pkg/rag/manager.go", 0.92)[0], + result("pkg/rag/prefetch/prefetch.go", 0.76)[0], + }) + + got, ok := p.Get("rag manager cache behavior") + + require.True(t, ok) + require.Len(t, got, 2) + assert.Equal(t, "pkg/rag/manager.go", got[0].Document.SourcePath) +} + +func TestGetDoesNotUseTopologyAcrossUnrelatedQueries(t *testing.T) { + p := New(Config{Enabled: true}) + p.Store("docker model provider auth config validation", []database.SearchResult{ + result("pkg/model/provider/anthropic/federation/federation.go", 0.86)[0], + result("pkg/config/latest/auth.go", 0.78)[0], + }) + + _, ok := p.Get("tui message rendering scroll behavior") + + assert.False(t, ok) } func TestCandidatesSuppressedWhenDrifting(t *testing.T) { @@ -82,6 +108,67 @@ func TestPrefetchDeduplicatesInFlightAndStoresResult(t *testing.T) { assert.Equal(t, int64(1), calls.Load()) } +func TestPrefetchSurvivesForegroundContextCancellation(t *testing.T) { + p := New(Config{Enabled: true, Timeout: time.Second}) + ctx, cancel := context.WithCancel(t.Context()) + started := make(chan struct{}) + allowReturn := make(chan struct{}) + + fetch := func(ctx context.Context, _ string) ([]database.SearchResult, error) { + close(started) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-allowReturn: + return result("pkg/rag/manager.go", 0.9), nil + } + } + + p.Prefetch(ctx, "RAG manager", fetch) + <-started + cancel() + close(allowReturn) + + require.Eventually(t, func() bool { + _, ok := p.Get("rag manager") + return ok + }, time.Second, 10*time.Millisecond) +} + +func TestReplayHitRatesTopologyBeatsExactRepeat(t *testing.T) { + trace := replayTrace() + + exact := replayExact(trace) + topology := replayTopology(trace) + + assert.Equal(t, replayMetrics{exactHits: 2, topologyHits: 0, misses: 8}, exact) + assert.Equal(t, replayMetrics{exactHits: 2, topologyHits: 2, misses: 6}, topology) + assert.Greater(t, topology.hitRate(len(trace)), exact.hitRate(len(trace))) +} + +func BenchmarkReplayHitRates(b *testing.B) { + trace := replayTrace() + + b.Run("exact-repeat", func(b *testing.B) { + var metrics replayMetrics + for range b.N { + metrics = replayExact(trace) + } + b.ReportMetric(metrics.hitRate(len(trace))*100, "hit_percent") + b.ReportMetric(float64(metrics.misses), "retrievals/op") + }) + + b.Run("topology-assisted", func(b *testing.B) { + var metrics replayMetrics + for range b.N { + metrics = replayTopology(trace) + } + b.ReportMetric(metrics.hitRate(len(trace))*100, "hit_percent") + b.ReportMetric(float64(metrics.misses), "retrievals/op") + b.ReportMetric(float64(metrics.topologyHits), "topology_hits/op") + }) +} + func result(path string, similarity float64) []database.SearchResult { return []database.SearchResult{{ Document: database.Document{ @@ -91,3 +178,130 @@ func result(path string, similarity float64) []database.SearchResult { Similarity: similarity, }} } + +type replayTurn struct { + query string + results []database.SearchResult +} + +type replayMetrics struct { + exactHits int + topologyHits int + misses int +} + +func (m replayMetrics) hitRate(total int) float64 { + return float64(m.exactHits+m.topologyHits) / float64(total) +} + +func replayExact(trace []replayTurn) replayMetrics { + cache := map[string]struct{}{} + var metrics replayMetrics + for _, turn := range trace { + key := normalize(turn.query) + if _, ok := cache[key]; ok { + metrics.exactHits++ + continue + } + metrics.misses++ + cache[key] = struct{}{} + } + return metrics +} + +func replayTopology(trace []replayTurn) replayMetrics { + p := New(Config{Enabled: true}) + exactSeen := map[string]struct{}{} + var metrics replayMetrics + for _, turn := range trace { + key := normalize(turn.query) + if _, ok := exactSeen[key]; ok { + metrics.exactHits++ + continue + } + if _, ok := p.Get(turn.query); ok { + metrics.topologyHits++ + exactSeen[key] = struct{}{} + continue + } + metrics.misses++ + exactSeen[key] = struct{}{} + p.Store(turn.query, turn.results) + } + return metrics +} + +func replayTrace() []replayTurn { + return []replayTurn{ + { + query: "how does rag manager query work", + results: []database.SearchResult{ + result("pkg/rag/manager.go", 0.92)[0], + result("pkg/rag/prefetch/prefetch.go", 0.76)[0], + }, + }, + { + query: "rag manager cache behavior", + results: []database.SearchResult{ + result("pkg/rag/manager.go", 0.89)[0], + result("pkg/rag/prefetch/prefetch.go", 0.81)[0], + }, + }, + { + query: "how does rag manager query work", + results: []database.SearchResult{ + result("pkg/rag/manager.go", 0.92)[0], + result("pkg/rag/prefetch/prefetch.go", 0.76)[0], + }, + }, + { + query: "prefetch drift threshold behavior", + results: []database.SearchResult{ + result("pkg/rag/prefetch/prefetch.go", 0.91)[0], + result("pkg/rag/prefetch/prefetch_test.go", 0.84)[0], + }, + }, + { + query: "background prefetch should survive turn cancellation", + results: []database.SearchResult{ + result("pkg/rag/prefetch/prefetch.go", 0.88)[0], + result("pkg/tools/builtin/agent/agent.go", 0.67)[0], + }, + }, + { + query: "how are rag documents reindexed after file changes", + results: []database.SearchResult{ + result("pkg/rag/strategy/vector_store.go", 0.9)[0], + result("pkg/rag/strategy/bm25.go", 0.82)[0], + }, + }, + { + query: "docker model provider auth config validation", + results: []database.SearchResult{ + result("pkg/model/provider/anthropic/federation/federation.go", 0.86)[0], + result("pkg/config/latest/auth.go", 0.78)[0], + }, + }, + { + query: "anthropic auth config validation", + results: []database.SearchResult{ + result("pkg/config/latest/auth.go", 0.83)[0], + result("pkg/model/provider/anthropic/federation/federation.go", 0.79)[0], + }, + }, + { + query: "docker model provider auth config validation", + results: []database.SearchResult{ + result("pkg/model/provider/anthropic/federation/federation.go", 0.86)[0], + result("pkg/config/latest/auth.go", 0.78)[0], + }, + }, + { + query: "tui message rendering scroll behavior", + results: []database.SearchResult{ + result("pkg/tui/components/messages/messages.go", 0.88)[0], + result("pkg/tui/components/scrollview/scrollview.go", 0.74)[0], + }, + }, + } +} diff --git a/pkg/rag/prefetch/proofs/TopologyHit.lean b/pkg/rag/prefetch/proofs/TopologyHit.lean new file mode 100644 index 000000000..b0c5fd528 --- /dev/null +++ b/pkg/rag/prefetch/proofs/TopologyHit.lean @@ -0,0 +1,32 @@ +namespace RAGPrefetch + +def ExactHit {Query : Type} (cache : Query -> Prop) (query : Query) : Prop := + cache query + +def TopologyHit {Query : Type} (cache : Query -> Prop) (related : Query -> Query -> Prop) (query : Query) : Prop := + exists cached, cache cached /\ related query cached + +theorem exact_hit_is_topology_hit_when_related_self + {Query : Type} + {cache : Query -> Prop} + {related : Query -> Query -> Prop} + {query : Query} + (hit : ExactHit cache query) + (selfRelated : related query query) : + TopologyHit cache related query := by + exists query + +theorem topology_hit_can_strictly_extend_exact_hit + {Query : Type} + {cache : Query -> Prop} + {related : Query -> Query -> Prop} + {query cached : Query} + (cachedHit : cache cached) + (topologyRelated : related query cached) + (exactMiss : Not (cache query)) : + TopologyHit cache related query /\ Not (ExactHit cache query) := by + constructor + · exists cached + · exact exactMiss + +end RAGPrefetch From 0a048e1b5d10eb7282e26674522d7ab184f2c130 Mon Sep 17 00:00:00 2001 From: Teerth Sharma Date: Fri, 19 Jun 2026 21:58:12 +0530 Subject: [PATCH 4/4] fix: tighten rag cache topology behavior Signed-off-by: Teerth Sharma --- agent-schema.json | 45 +-- docs/tools/rag/index.md | 18 +- ...daptive_prefetch.yaml => query_cache.yaml} | 21 +- pkg/config/latest/types.go | 48 +-- pkg/config/schema_test.go | 1 + pkg/rag/builder.go | 30 +- pkg/rag/manager.go | 35 +- pkg/rag/manager_test.go | 82 ++++- pkg/rag/prefetch/prefetch.go | 312 +----------------- pkg/rag/prefetch/prefetch_test.go | 267 ++------------- pkg/rag/prefetch/proofs/TopologyHit.lean | 32 -- pkg/rag/strategy/bm25.go | 119 +++++-- pkg/rag/strategy/bm25_test.go | 61 ++++ pkg/rag/strategy/vector_store.go | 8 +- pkg/rag/topology/prior.go | 161 +++++++++ pkg/rag/topology/prior_test.go | 67 ++++ 16 files changed, 594 insertions(+), 713 deletions(-) rename examples/rag/{adaptive_prefetch.yaml => query_cache.yaml} (71%) delete mode 100644 pkg/rag/prefetch/proofs/TopologyHit.lean create mode 100644 pkg/rag/strategy/bm25_test.go create mode 100644 pkg/rag/topology/prior.go create mode 100644 pkg/rag/topology/prior_test.go diff --git a/agent-schema.json b/agent-schema.json index 03f85e5fd..6f85c9e8d 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -2692,11 +2692,11 @@ }, "prefetch": { "type": "object", - "description": "Optional adaptive query prefetching. When enabled, docker-agent caches repeated RAG queries and warms bounded follow-up candidates in the background.", + "description": "Optional exact-repeat RAG query caching. When enabled, docker-agent caches final results for repeated normalized queries.", "properties": { "enabled": { "type": "boolean", - "description": "Enable adaptive RAG query prefetching.", + "description": "Enable exact-repeat RAG query caching.", "default": false }, "max_entries": { @@ -2704,30 +2704,31 @@ "description": "Maximum number of cached query result sets.", "minimum": 1, "default": 32 + } + }, + "additionalProperties": false + }, + "topology_prior": { + "type": "object", + "description": "Optional topology-based score prior. When enabled, docker-agent runs normal retrieval first, then applies a small capped score bias to the current query's retrieved results based on query/source topology and recent result sources.", + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable topology-based score biasing.", + "default": false }, - "max_candidates": { - "type": "integer", - "description": "Maximum number of follow-up query candidates to prefetch after a cache miss.", - "minimum": 1, - "default": 2 - }, - "min_similarity": { - "type": "number", - "description": "Minimum result similarity required before a source path can seed a follow-up prefetch candidate.", - "minimum": 0, - "maximum": 1, - "default": 0.5 - }, - "drift_threshold": { + "weight": { "type": "number", - "description": "Maximum topology drift score that still allows background prefetching.", + "description": "Maximum topology contribution blended into each result score. Values above 0.2 are clamped in code.", "minimum": 0, - "default": 0.8 + "maximum": 0.2, + "default": 0.05 }, - "timeout": { - "type": "string", - "description": "Maximum duration for a background prefetch query, using Go duration syntax such as '10s' or '1m'.", - "default": "10s" + "max_source_history": { + "type": "integer", + "description": "Maximum number of recent result source paths kept for topology scoring.", + "minimum": 1, + "default": 32 } }, "additionalProperties": false diff --git a/docs/tools/rag/index.md b/docs/tools/rag/index.md index d96c894bd..70da0dff5 100644 --- a/docs/tools/rag/index.md +++ b/docs/tools/rag/index.md @@ -16,7 +16,7 @@ The `rag` toolset lets agents search through your documents to find relevant inf - **Multiple strategies** — Semantic embeddings, BM25 keyword search, and LLM-enhanced search - **Hybrid search** — Combine strategies with result fusion for best results - **Reranking** — Re-score results with specialized models for improved relevance -- **Adaptive prefetching** — Cache repeated queries and warm stable follow-up searches +- **Query caching** — Cache exact repeated queries after result post-processing ## Quick Start @@ -157,22 +157,18 @@ results: Supported reranking providers: **DMR** (native `/rerank` endpoint), **OpenAI**, **Anthropic**, **Gemini**. -## Adaptive Prefetching +## Query Caching -Adaptive prefetching is opt-in. It caches repeated RAG queries and, when query topology is stable, can reuse warmed results for closely related follow-up queries that share source-derived anchors. The topology path is gated by query-token overlap, source-path anchors, and drift detection so unrelated topic shifts fall back to normal retrieval. +Query caching is opt-in. It caches final RAG results for exact repeated queries after whitespace and case normalization. Related but different queries always run normal retrieval so results are scored for the user's current query. ```yaml results: prefetch: enabled: true max_entries: 32 - max_candidates: 2 - min_similarity: 0.5 - drift_threshold: 0.8 - timeout: 10s ``` -The prefetcher is bounded and non-blocking. Exact repeated queries are served from cache first; related-query topology hits are considered only after an exact miss. Background candidate prefetch is skipped when reranking is enabled so reranker errors and fallback behavior stay tied to the foreground query. The deterministic replay benchmark in `pkg/rag/prefetch` reports exact-repeat and topology-assisted hit rates separately. +The cache is bounded per RAG manager and stores cloned result slices so callers cannot mutate cached entries. It is cleared whenever the manager receives an indexing-complete event from initialization or live file-watcher reindexing, which prevents serving results from a previous index version. ## Code-Aware Chunking @@ -284,9 +280,5 @@ Look for log tags: `[RAG Manager]`, `[Chunked-Embeddings Strategy]`, `[BM25 Stra | `reranking.top_k` | int | (`limit`) | Only rerank top K results. Defaults to the results `limit` when set. | | `reranking.threshold` | float | `0.5` | Minimum relevance score after reranking | | `reranking.criteria` | string | — | Custom relevance guidance for the reranking model | -| `prefetch.enabled` | bool | `false` | Enable adaptive query prefetching | +| `prefetch.enabled` | bool | `false` | Enable exact-repeat query caching | | `prefetch.max_entries` | int | `32` | Maximum cached query result sets | -| `prefetch.max_candidates` | int | `2` | Maximum follow-up candidates warmed after a cache miss | -| `prefetch.min_similarity` | float | `0.5` | Minimum similarity score for source-derived candidates | -| `prefetch.drift_threshold` | float | `0.8` | Maximum topology drift that still allows background prefetch | -| `prefetch.timeout` | string | `10s` | Timeout for each background prefetch query | diff --git a/examples/rag/adaptive_prefetch.yaml b/examples/rag/query_cache.yaml similarity index 71% rename from examples/rag/adaptive_prefetch.yaml rename to examples/rag/query_cache.yaml index 29602a07f..ad668129d 100644 --- a/examples/rag/adaptive_prefetch.yaml +++ b/examples/rag/query_cache.yaml @@ -1,19 +1,18 @@ -# This example demonstrates adaptive RAG prefetching for repeated and -# topology-stable follow-up queries. +# This example demonstrates exact-repeat RAG query caching and a small topology prior. agents: root: model: openai/gpt-5-mini - description: assistant with adaptive RAG prefetching + description: assistant with RAG query caching and topology ranking instruction: | You are a helpful assistant with access to hybrid retrieval. Use the knowledge base before answering questions about blorks. toolsets: - type: rag - ref: prefetched_knowledge + ref: cached_knowledge rag: - prefetched_knowledge: + cached_knowledge: tool: description: to be used to search for information about blorks docs: @@ -21,7 +20,7 @@ rag: strategies: - type: chunked-embeddings embedding_model: openai/text-embedding-3-small - database: ./adaptive_prefetch_embeddings.db + database: ./query_cache_embeddings.db vector_dimensions: 1536 similarity_metric: cosine_similarity threshold: 0.5 @@ -31,7 +30,7 @@ rag: overlap: 100 respect_word_boundaries: true - type: bm25 - database: ./adaptive_prefetch_bm25.db + database: ./query_cache_bm25.db k1: 1.5 b: 0.75 threshold: 0.3 @@ -49,7 +48,7 @@ rag: prefetch: enabled: true max_entries: 32 - max_candidates: 2 - min_similarity: 0.5 - drift_threshold: 0.8 - timeout: 10s + topology_prior: + enabled: true + weight: 0.05 + max_source_history: 32 diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index ec07b8036..a3f193270 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -1810,23 +1810,27 @@ func (c *RAGChunkingConfig) UnmarshalYAML(unmarshal func(any) error) error { // RAGResultsConfig represents result post-processing configuration (common across strategies) type RAGResultsConfig struct { - Limit int `json:"limit,omitempty"` // Maximum number of results to return (top K) - Fusion *RAGFusionConfig `json:"fusion,omitempty"` // How to combine results from multiple strategies - Reranking *RAGRerankingConfig `json:"reranking,omitempty"` // Optional reranking configuration - Prefetch *RAGPrefetchConfig `json:"prefetch,omitempty"` // Optional adaptive query prefetching - Deduplicate bool `json:"deduplicate,omitempty"` // Remove duplicate documents across strategies - IncludeScore bool `json:"include_score,omitempty"` // Include relevance scores in results - ReturnFullContent bool `json:"return_full_content,omitempty"` // Return full document content instead of just matched chunks + Limit int `json:"limit,omitempty"` // Maximum number of results to return (top K) + Fusion *RAGFusionConfig `json:"fusion,omitempty"` // How to combine results from multiple strategies + Reranking *RAGRerankingConfig `json:"reranking,omitempty"` // Optional reranking configuration + Prefetch *RAGPrefetchConfig `json:"prefetch,omitempty"` // Optional exact-repeat query cache + TopologyPrior *RAGTopologyPriorConfig `json:"topology_prior,omitempty"` // Optional topology-based score prior + Deduplicate bool `json:"deduplicate,omitempty"` // Remove duplicate documents across strategies + IncludeScore bool `json:"include_score,omitempty"` // Include relevance scores in results + ReturnFullContent bool `json:"return_full_content,omitempty"` // Return full document content instead of just matched chunks } -// RAGPrefetchConfig configures adaptive RAG query prefetching. +// RAGPrefetchConfig configures the exact-repeat RAG query cache. type RAGPrefetchConfig struct { - Enabled bool `json:"enabled,omitempty"` - MaxEntries int `json:"max_entries,omitempty"` - MaxCandidates int `json:"max_candidates,omitempty"` - MinSimilarity float64 `json:"min_similarity,omitempty"` - DriftThreshold float64 `json:"drift_threshold,omitempty"` - Timeout Duration `json:"timeout,omitzero"` + Enabled bool `json:"enabled,omitempty"` + MaxEntries int `json:"max_entries,omitempty"` +} + +// RAGTopologyPriorConfig configures topology-based score biasing. +type RAGTopologyPriorConfig struct { + Enabled bool `json:"enabled,omitempty"` + Weight float64 `json:"weight,omitempty"` + MaxSourceHistory int `json:"max_source_history,omitempty"` } // RAGRerankingConfig represents reranking configuration @@ -1879,13 +1883,14 @@ func defaultRAGResultsConfig() RAGResultsConfig { // UnmarshalYAML implements custom unmarshaling so we can apply sensible defaults func (r *RAGResultsConfig) UnmarshalYAML(unmarshal func(any) error) error { var raw struct { - Limit int `json:"limit,omitempty"` - Fusion *RAGFusionConfig `json:"fusion,omitempty"` - Reranking *RAGRerankingConfig `json:"reranking,omitempty"` - Prefetch *RAGPrefetchConfig `json:"prefetch,omitempty"` - Deduplicate *bool `json:"deduplicate,omitempty"` - IncludeScore *bool `json:"include_score,omitempty"` - ReturnFullContent *bool `json:"return_full_content,omitempty"` + Limit int `json:"limit,omitempty"` + Fusion *RAGFusionConfig `json:"fusion,omitempty"` + Reranking *RAGRerankingConfig `json:"reranking,omitempty"` + Prefetch *RAGPrefetchConfig `json:"prefetch,omitempty"` + TopologyPrior *RAGTopologyPriorConfig `json:"topology_prior,omitempty"` + Deduplicate *bool `json:"deduplicate,omitempty"` + IncludeScore *bool `json:"include_score,omitempty"` + ReturnFullContent *bool `json:"return_full_content,omitempty"` } if err := unmarshal(&raw); err != nil { @@ -1902,6 +1907,7 @@ func (r *RAGResultsConfig) UnmarshalYAML(unmarshal func(any) error) error { r.Fusion = raw.Fusion r.Reranking = raw.Reranking r.Prefetch = raw.Prefetch + r.TopologyPrior = raw.TopologyPrior if raw.Deduplicate != nil { r.Deduplicate = *raw.Deduplicate diff --git a/pkg/config/schema_test.go b/pkg/config/schema_test.go index df9f062ec..ca2784f1c 100644 --- a/pkg/config/schema_test.go +++ b/pkg/config/schema_test.go @@ -127,6 +127,7 @@ func TestSchemaMatchesGoTypes(t *testing.T) { {reflect.TypeFor[latest.RAGFusionConfig](), []string{"RAGConfig", "results", "fusion"}, "RAGFusionConfig (RAGConfig.results.fusion)"}, {reflect.TypeFor[latest.RAGRerankingConfig](), []string{"RAGConfig", "results", "reranking"}, "RAGRerankingConfig (RAGConfig.results.reranking)"}, {reflect.TypeFor[latest.RAGPrefetchConfig](), []string{"RAGConfig", "results", "prefetch"}, "RAGPrefetchConfig (RAGConfig.results.prefetch)"}, + {reflect.TypeFor[latest.RAGTopologyPriorConfig](), []string{"RAGConfig", "results", "topology_prior"}, "RAGTopologyPriorConfig (RAGConfig.results.topology_prior)"}, {reflect.TypeFor[latest.RAGChunkingConfig](), []string{"RAGConfig", "strategies", "*", "chunking"}, "RAGChunkingConfig (RAGConfig.strategies[].chunking)"}, } diff --git a/pkg/rag/builder.go b/pkg/rag/builder.go index 8853c00c7..2fb5a1097 100644 --- a/pkg/rag/builder.go +++ b/pkg/rag/builder.go @@ -132,11 +132,12 @@ func buildManagerConfig( Description: ragCfg.Tool.Description, Instruction: ragCfg.Tool.Instruction, }, - Docs: GetAbsolutePaths(buildCfg.ParentDir, ragCfg.Docs), - Results: results, - FusionConfig: fusionCfg, - StrategyConfigs: strategyConfigs, - PrefetchConfig: buildPrefetchConfig(ragCfg.Results.Prefetch), + Docs: GetAbsolutePaths(buildCfg.ParentDir, ragCfg.Docs), + Results: results, + FusionConfig: fusionCfg, + StrategyConfigs: strategyConfigs, + PrefetchConfig: buildPrefetchConfig(ragCfg.Results.Prefetch), + TopologyPriorConfig: buildTopologyPriorConfig(ragCfg.Results.TopologyPrior), }, nil } @@ -145,12 +146,19 @@ func buildPrefetchConfig(cfg *latest.RAGPrefetchConfig) prefetch.Config { return prefetch.Config{} } return prefetch.Config{ - Enabled: cfg.Enabled, - MaxEntries: cfg.MaxEntries, - MaxCandidates: cfg.MaxCandidates, - MinSimilarity: cfg.MinSimilarity, - DriftThreshold: cfg.DriftThreshold, - Timeout: cfg.Timeout.Duration, + Enabled: cfg.Enabled, + MaxEntries: cfg.MaxEntries, + } +} + +func buildTopologyPriorConfig(cfg *latest.RAGTopologyPriorConfig) TopologyPriorConfig { + if cfg == nil { + return TopologyPriorConfig{} + } + return TopologyPriorConfig{ + Enabled: cfg.Enabled, + Weight: cfg.Weight, + MaxSourceHistory: cfg.MaxSourceHistory, } } diff --git a/pkg/rag/manager.go b/pkg/rag/manager.go index 6fb364970..cdd38c3eb 100644 --- a/pkg/rag/manager.go +++ b/pkg/rag/manager.go @@ -18,6 +18,7 @@ import ( "github.com/docker/docker-agent/pkg/rag/prefetch" "github.com/docker/docker-agent/pkg/rag/rerank" "github.com/docker/docker-agent/pkg/rag/strategy" + "github.com/docker/docker-agent/pkg/rag/topology" "github.com/docker/docker-agent/pkg/rag/types" ) @@ -31,14 +32,17 @@ type ToolConfig struct { // Config represents RAG manager configuration in domain terms, // independent of any particular config schema version. type Config struct { - Tool ToolConfig - Docs []string - Results ResultsConfig - FusionConfig *FusionConfig - StrategyConfigs []strategy.Config - PrefetchConfig prefetch.Config + Tool ToolConfig + Docs []string + Results ResultsConfig + FusionConfig *FusionConfig + StrategyConfigs []strategy.Config + PrefetchConfig prefetch.Config + TopologyPriorConfig TopologyPriorConfig } +type TopologyPriorConfig = topology.Config + // ResultsConfig captures result-postprocessing behavior for the manager. type ResultsConfig struct { Limit int // Maximum number of results to return (top K) @@ -67,6 +71,7 @@ type Manager struct { rerankDisabled atomic.Bool // Set after a non-retryable reranking error to stop doomed requests events <-chan types.Event // Shared event channel from strategies and other RAG operations prefetcher *prefetch.Prefetcher + topologyPrior *topology.Prior } // FusionConfig holds configuration for result fusion @@ -134,6 +139,7 @@ func New(ctx context.Context, name string, config Config, strategyEvents <-chan } prefetcher := prefetch.New(config.PrefetchConfig) + topologyPrior := topology.NewPrior(config.TopologyPriorConfig) m := &Manager{ name: name, config: config, @@ -141,14 +147,15 @@ func New(ctx context.Context, name string, config Config, strategyEvents <-chan strategyConfigs: strategyConfigMap, fusion: fusionStrategy, reranker: reranker, - events: forwardEvents(ctx, strategyEvents, prefetcher), + events: forwardEvents(ctx, strategyEvents, prefetcher, topologyPrior), prefetcher: prefetcher, + topologyPrior: topologyPrior, } return m, nil } -func forwardEvents(ctx context.Context, in <-chan types.Event, prefetcher *prefetch.Prefetcher) <-chan types.Event { +func forwardEvents(ctx context.Context, in <-chan types.Event, prefetcher *prefetch.Prefetcher, topologyPrior *topology.Prior) <-chan types.Event { if in == nil { return nil } @@ -165,6 +172,7 @@ func forwardEvents(ctx context.Context, in <-chan types.Event, prefetcher *prefe } if event.Type == types.EventTypeIndexingComplete { prefetcher.Clear() + topologyPrior.Clear() } select { case out <- event: @@ -255,7 +263,7 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "query_length", len(query)) if cached, ok := m.prefetcher.Get(query); ok { - slog.DebugContext(ctx, "[RAG Manager] Returning prefetched RAG results", + slog.DebugContext(ctx, "[RAG Manager] Returning cached RAG results", "rag_name", m.name, "result_count", len(cached)) return cached, nil @@ -268,12 +276,7 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes if ctx.Err() == nil { m.prefetcher.Store(query, results) - m.prefetcher.Observe(query, results) - } - if ctx.Err() == nil && m.reranker == nil { - for _, candidate := range m.prefetcher.Candidates(query, results) { - m.prefetcher.Prefetch(ctx, candidate, m.queryUncached) - } + m.topologyPrior.Observe(query, results) } return results, nil @@ -402,6 +405,7 @@ func (m *Manager) queryStrategies(ctx context.Context, query string) ([]database func (m *Manager) postprocessQueryResults(ctx context.Context, query string, results []database.SearchResult) []database.SearchResult { // Apply reranking if configured (before limit and deduplication) results = m.rerank(ctx, query, results) + results = m.topologyPrior.Apply(query, results) // Apply result limit if configured if limit := m.config.Results.Limit; limit > 0 && len(results) > limit { @@ -486,6 +490,7 @@ func (m *Manager) CheckAndReindexChangedFiles(ctx context.Context) error { } } m.prefetcher.Clear() + m.topologyPrior.Clear() return nil } diff --git a/pkg/rag/manager_test.go b/pkg/rag/manager_test.go index efa23b407..9db7a6a4e 100644 --- a/pkg/rag/manager_test.go +++ b/pkg/rag/manager_test.go @@ -41,16 +41,22 @@ func TestGetAbsolutePaths_NilInput(t *testing.T) { } type countingStrategy struct { - calls atomic.Int64 - results []database.SearchResult + calls atomic.Int64 + results []database.SearchResult + resultsByQuery map[string][]database.SearchResult } func (s *countingStrategy) Initialize(context.Context, []string, strategy.ChunkingConfig) error { return nil } -func (s *countingStrategy) Query(context.Context, string, int, float64) ([]database.SearchResult, error) { +func (s *countingStrategy) Query(_ context.Context, query string, _ int, _ float64) ([]database.SearchResult, error) { s.calls.Add(1) + if s.resultsByQuery != nil { + if results, ok := s.resultsByQuery[query]; ok { + return append([]database.SearchResult(nil), results...), nil + } + } return append([]database.SearchResult(nil), s.results...), nil } @@ -118,3 +124,73 @@ func TestManagerClearsPrefetchCacheOnIndexingCompleteEvent(t *testing.T) { return !ok }, time.Second, 10*time.Millisecond) } + +func TestManagerClearsTopologyPriorOnIndexingCompleteEvent(t *testing.T) { + events := make(chan types.Event, 1) + m, err := New(t.Context(), "test", Config{ + TopologyPriorConfig: TopologyPriorConfig{Enabled: true, Weight: 0.05}, + StrategyConfigs: []strategy.Config{{ + Name: "counting", + Strategy: &countingStrategy{}, + }}, + }, events) + require.NoError(t, err) + + m.topologyPrior.Observe("how does rag manager query work", []database.SearchResult{{ + Document: database.Document{ID: "1", SourcePath: "pkg/rag/manager.go", Content: "old"}, + Similarity: 0.91, + }}) + + events <- types.Event{Type: types.EventTypeIndexingComplete} + + require.Eventually(t, func() bool { + got := m.topologyPrior.Apply("rag manager cache behavior", []database.SearchResult{ + {Document: database.Document{ID: "2", SourcePath: "pkg/model/provider/client.go", Content: "provider"}, Similarity: 0.72}, + {Document: database.Document{ID: "3", SourcePath: "pkg/rag/manager.go", Content: "manager"}, Similarity: 0.70}, + }) + return got[0].Document.SourcePath == "pkg/model/provider/client.go" + }, time.Second, 10*time.Millisecond) +} + +func TestTopologyPriorReranksOnlyFreshCurrentQueryResults(t *testing.T) { + searchStrategy := &countingStrategy{resultsByQuery: map[string][]database.SearchResult{ + "how does rag manager query work": { + { + Document: database.Document{ID: "1", SourcePath: "pkg/rag/manager.go", Content: "manager"}, + Similarity: 0.91, + }, + }, + "rag manager cache behavior": { + { + Document: database.Document{ID: "2", SourcePath: "pkg/model/provider/client.go", Content: "provider"}, + Similarity: 0.72, + }, + { + Document: database.Document{ID: "3", SourcePath: "pkg/rag/manager.go", Content: "manager current"}, + Similarity: 0.70, + }, + }, + }} + m, err := New(t.Context(), "test", Config{ + Results: ResultsConfig{Limit: 15}, + TopologyPriorConfig: TopologyPriorConfig{Enabled: true, Weight: 0.05, MaxSourceHistory: 8}, + StrategyConfigs: []strategy.Config{{ + Name: "counting", + Strategy: searchStrategy, + Limit: 5, + Threshold: 0.5, + }}, + }, nil) + require.NoError(t, err) + + _, err = m.Query(t.Context(), "how does rag manager query work") + require.NoError(t, err) + + got, err := m.Query(t.Context(), "rag manager cache behavior") + require.NoError(t, err) + + require.Len(t, got, 2) + assert.Equal(t, int64(2), searchStrategy.calls.Load()) + assert.Equal(t, "pkg/rag/manager.go", got[0].Document.SourcePath) + assert.Equal(t, "manager current", got[0].Document.Content) +} diff --git a/pkg/rag/prefetch/prefetch.go b/pkg/rag/prefetch/prefetch.go index 33d470c1f..7747d0899 100644 --- a/pkg/rag/prefetch/prefetch.go +++ b/pkg/rag/prefetch/prefetch.go @@ -1,71 +1,37 @@ package prefetch import ( - "context" - "path/filepath" "slices" "strings" "sync" - "time" "github.com/docker/docker-agent/pkg/rag/database" ) const ( - defaultMaxEntries = 32 - defaultMaxCandidates = 2 - defaultMinSimilarity = 0.5 - defaultDriftThreshold = 0.8 - defaultTimeout = 10 * time.Second + defaultMaxEntries = 32 ) -// Config controls adaptive RAG prefetching. The zero value disables it. +// Config controls the exact-repeat RAG query cache. The zero value disables it. type Config struct { - Enabled bool - MaxEntries int - MaxCandidates int - MinSimilarity float64 - DriftThreshold float64 - Timeout time.Duration + Enabled bool + MaxEntries int } func (c Config) withDefaults() Config { if c.MaxEntries <= 0 { c.MaxEntries = defaultMaxEntries } - if c.MaxCandidates <= 0 { - c.MaxCandidates = defaultMaxCandidates - } - if c.MinSimilarity <= 0 { - c.MinSimilarity = defaultMinSimilarity - } - if c.DriftThreshold <= 0 { - c.DriftThreshold = defaultDriftThreshold - } - if c.Timeout <= 0 { - c.Timeout = defaultTimeout - } return c } -// FetchFunc runs an uncached query for a prefetch candidate. -type FetchFunc func(context.Context, string) ([]database.SearchResult, error) - -// Prefetcher owns bounded result cache, lightweight topology state, and -// background prefetch scheduling for one RAG manager. +// Prefetcher owns a bounded result cache for one RAG manager. type Prefetcher struct { cfg Config - mu sync.Mutex - cache map[string]cacheEntry - order []string - inflight map[string]struct{} - tracker tracker -} - -type cacheEntry struct { - results []database.SearchResult - anchors []string + mu sync.Mutex + cache map[string][]database.SearchResult + order []string } // New creates a prefetcher. It returns nil when disabled so callers can keep @@ -76,13 +42,12 @@ func New(cfg Config) *Prefetcher { } cfg = cfg.withDefaults() return &Prefetcher{ - cfg: cfg, - cache: make(map[string]cacheEntry, cfg.MaxEntries), - inflight: make(map[string]struct{}), + cfg: cfg, + cache: make(map[string][]database.SearchResult, cfg.MaxEntries), } } -// Get returns cached final results for query. +// Get returns cached final results for the exact normalized query. func (p *Prefetcher) Get(query string) ([]database.SearchResult, bool) { if p == nil { return nil, false @@ -96,9 +61,9 @@ func (p *Prefetcher) Get(query string) ([]database.SearchResult, bool) { defer p.mu.Unlock() results, ok := p.cache[key] if !ok { - return p.getTopologyLocked(key) + return nil, false } - return cloneResults(results.results), true + return cloneResults(results), true } // Store records final post-processed results for query. @@ -120,10 +85,7 @@ func (p *Prefetcher) storeLocked(key string, results []database.SearchResult) { if _, exists := p.cache[key]; !exists { p.order = append(p.order, key) } - p.cache[key] = cacheEntry{ - results: cloneResults(results), - anchors: anchorsFor(results), - } + p.cache[key] = cloneResults(results) for len(p.order) > p.cfg.MaxEntries { oldest := p.order[0] @@ -140,259 +102,13 @@ func (p *Prefetcher) Clear() { p.mu.Lock() defer p.mu.Unlock() clear(p.cache) - clear(p.inflight) p.order = nil } -func (p *Prefetcher) getTopologyLocked(key string) ([]database.SearchResult, bool) { - if !p.tracker.stable(p.cfg.DriftThreshold) { - return nil, false - } - queryTokens := tokenSet(key) - for _, entryKey := range slices.Backward(p.order) { - entry := p.cache[entryKey] - if !topologyRelated(entryKey, queryTokens, entry.anchors) { - continue - } - return cloneResults(entry.results), true - } - return nil, false -} - -// Observe updates the topology tracker with query/result metadata. -func (p *Prefetcher) Observe(query string, results []database.SearchResult) { - if p == nil { - return - } - p.mu.Lock() - defer p.mu.Unlock() - p.tracker.observe(query, results) -} - -// Candidates returns deterministic follow-up queries worth warming. -func (p *Prefetcher) Candidates(query string, results []database.SearchResult) []string { - if p == nil || len(results) == 0 { - return nil - } - - p.mu.Lock() - stable := p.tracker.stable(p.cfg.DriftThreshold) - p.mu.Unlock() - if !stable { - return nil - } - - base := normalize(query) - baseTokens := tokenSet(base) - seen := map[string]struct{}{base: {}} - candidates := make([]string, 0, p.cfg.MaxCandidates) - - for _, result := range results { - if len(candidates) >= p.cfg.MaxCandidates { - break - } - if result.Similarity < p.cfg.MinSimilarity { - continue - } - suffix := sourceSuffix(result.Document.SourcePath, baseTokens) - if suffix == "" { - continue - } - candidate := strings.TrimSpace(base + " " + suffix) - if candidate == "" { - continue - } - if _, ok := seen[candidate]; ok { - continue - } - seen[candidate] = struct{}{} - candidates = append(candidates, candidate) - } - - return candidates -} - -// Prefetch schedules one bounded background fetch for query. -func (p *Prefetcher) Prefetch(ctx context.Context, query string, fetch FetchFunc) { - if p == nil || fetch == nil { - return - } - key := normalize(query) - if key == "" { - return - } - - p.mu.Lock() - if _, ok := p.cache[key]; ok { - p.mu.Unlock() - return - } - if _, ok := p.inflight[key]; ok { - p.mu.Unlock() - return - } - p.inflight[key] = struct{}{} - p.mu.Unlock() - - go func() { - defer func() { - p.mu.Lock() - delete(p.inflight, key) - p.mu.Unlock() - }() - - prefetchCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), p.cfg.Timeout) - defer cancel() - - results, err := fetch(prefetchCtx, key) - if err != nil || len(results) == 0 { - return - } - - p.mu.Lock() - p.storeLocked(key, results) - p.mu.Unlock() - }() -} - func normalize(query string) string { return strings.Join(strings.Fields(strings.ToLower(query)), " ") } -func sourceToken(path string) string { - base := filepath.Base(path) - if base == "." || base == string(filepath.Separator) { - return "" - } - ext := filepath.Ext(base) - return strings.TrimSuffix(base, ext) -} - -func sourceSuffix(path string, seen map[string]struct{}) string { - parts := strings.FieldsFunc(sourceToken(path), isTokenSeparator) - novel := make([]string, 0, len(parts)) - for _, part := range parts { - part = normalize(part) - if part == "" { - continue - } - if _, ok := seen[part]; ok { - continue - } - novel = append(novel, part) - } - return strings.Join(novel, " ") -} - -func anchorsFor(results []database.SearchResult) []string { - seen := map[string]struct{}{} - anchors := make([]string, 0, len(results)) - for _, result := range results { - for _, part := range strings.FieldsFunc(sourceToken(result.Document.SourcePath), isTokenSeparator) { - part = normalize(part) - if len(part) < 3 { - continue - } - if _, ok := seen[part]; ok { - continue - } - seen[part] = struct{}{} - anchors = append(anchors, part) - } - } - return anchors -} - -func isTokenSeparator(r rune) bool { - return r == '_' || r == '-' || r == '.' -} - -func topologyRelated(entryKey string, queryTokens map[string]struct{}, anchors []string) bool { - hasAnchor := false - for _, anchor := range anchors { - if _, ok := queryTokens[anchor]; ok { - hasAnchor = true - break - } - } - if !hasAnchor { - return false - } - return jaccard(tokenSet(entryKey), queryTokens) >= 0.25 -} - -func tokenSet(query string) map[string]struct{} { - tokens := map[string]struct{}{} - for token := range strings.FieldsSeq(query) { - tokens[token] = struct{}{} - } - return tokens -} - -func jaccard(a, b map[string]struct{}) float64 { - var intersection int - for token := range a { - if _, ok := b[token]; ok { - intersection++ - } - } - union := len(a) + len(b) - intersection - if union == 0 { - return 0 - } - return float64(intersection) / float64(union) -} - func cloneResults(results []database.SearchResult) []database.SearchResult { return slices.Clone(results) } - -type tracker struct { - seen int - centroid [4]float64 - drift float64 -} - -func (t *tracker) observe(query string, results []database.SearchResult) { - point := pointFor(query, results) - t.seen++ - if t.seen == 1 { - t.centroid = point - t.drift = 0 - return - } - - var dist float64 - for i := range point { - diff := point[i] - t.centroid[i] - dist += diff * diff - } - t.drift = dist - - for i := range t.centroid { - t.centroid[i] += (point[i] - t.centroid[i]) * 0.35 - } -} - -func (t *tracker) stable(threshold float64) bool { - if t.seen < 2 { - return true - } - return t.drift <= threshold -} - -func pointFor(query string, results []database.SearchResult) [4]float64 { - var avgSimilarity float64 - for _, result := range results { - avgSimilarity += result.Similarity - } - if len(results) > 0 { - avgSimilarity /= float64(len(results)) - } - return [4]float64{ - float64(len(query)) / 256, - float64(len(strings.Fields(query))) / 32, - float64(len(results)) / 32, - avgSimilarity, - } -} diff --git a/pkg/rag/prefetch/prefetch_test.go b/pkg/rag/prefetch/prefetch_test.go index 6f472f5c6..708c2ad0f 100644 --- a/pkg/rag/prefetch/prefetch_test.go +++ b/pkg/rag/prefetch/prefetch_test.go @@ -1,10 +1,7 @@ package prefetch import ( - "context" - "sync/atomic" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -32,141 +29,42 @@ func TestStoreGetAndEvictOldest(t *testing.T) { assert.Equal(t, "c.go", got[0].Document.SourcePath) } -func TestCandidatesUseStableTopologyAndSourceNames(t *testing.T) { - p := New(Config{Enabled: true, MaxCandidates: 2, MinSimilarity: 0.5, DriftThreshold: 1}) - results := []database.SearchResult{ - result("pkg/rag/manager.go", 0.9)[0], - result("pkg/rag/vector_store.go", 0.7)[0], - result("pkg/rag/weak.go", 0.1)[0], - } - - p.Observe("RAG manager", results) - p.Observe("RAG manager cache", results) - - assert.Equal(t, []string{"rag manager vector store"}, p.Candidates("RAG manager", results)) -} - -func TestGetUsesTopologyForRelatedFollowupQuery(t *testing.T) { +func TestGetOnlyMatchesExactNormalizedQuery(t *testing.T) { p := New(Config{Enabled: true}) p.Store("how does rag manager query work", []database.SearchResult{ result("pkg/rag/manager.go", 0.92)[0], result("pkg/rag/prefetch/prefetch.go", 0.76)[0], }) - got, ok := p.Get("rag manager cache behavior") - - require.True(t, ok) - require.Len(t, got, 2) - assert.Equal(t, "pkg/rag/manager.go", got[0].Document.SourcePath) -} - -func TestGetDoesNotUseTopologyAcrossUnrelatedQueries(t *testing.T) { - p := New(Config{Enabled: true}) - p.Store("docker model provider auth config validation", []database.SearchResult{ - result("pkg/model/provider/anthropic/federation/federation.go", 0.86)[0], - result("pkg/config/latest/auth.go", 0.78)[0], - }) - - _, ok := p.Get("tui message rendering scroll behavior") + _, ok := p.Get("rag manager cache behavior") assert.False(t, ok) } -func TestCandidatesSuppressedWhenDrifting(t *testing.T) { - p := New(Config{Enabled: true, DriftThreshold: 0.0001}) - results := result("pkg/rag/manager.go", 0.9) - - p.Observe("short", results) - p.Observe("this is a completely different and much longer query with more tokens", results) - - assert.Empty(t, p.Candidates("short", results)) -} - -func TestPrefetchDeduplicatesInFlightAndStoresResult(t *testing.T) { - p := New(Config{Enabled: true, Timeout: time.Second}) - var calls atomic.Int64 - done := make(chan struct{}) - - fetch := func(context.Context, string) ([]database.SearchResult, error) { - calls.Add(1) - close(done) - return result("pkg/rag/manager.go", 0.9), nil - } - - p.Prefetch(t.Context(), "RAG manager", fetch) - p.Prefetch(t.Context(), "rag manager", fetch) - - select { - case <-done: - case <-time.After(time.Second): - t.Fatal("prefetch did not run") - } - require.Eventually(t, func() bool { - _, ok := p.Get("rag manager") - return ok - }, time.Second, 10*time.Millisecond) - assert.Equal(t, int64(1), calls.Load()) -} - -func TestPrefetchSurvivesForegroundContextCancellation(t *testing.T) { - p := New(Config{Enabled: true, Timeout: time.Second}) - ctx, cancel := context.WithCancel(t.Context()) - started := make(chan struct{}) - allowReturn := make(chan struct{}) - - fetch := func(ctx context.Context, _ string) ([]database.SearchResult, error) { - close(started) - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-allowReturn: - return result("pkg/rag/manager.go", 0.9), nil - } - } - - p.Prefetch(ctx, "RAG manager", fetch) - <-started - cancel() - close(allowReturn) - - require.Eventually(t, func() bool { - _, ok := p.Get("rag manager") - return ok - }, time.Second, 10*time.Millisecond) -} +func TestStoreAndGetCloneResults(t *testing.T) { + p := New(Config{Enabled: true}) + results := result("docs/rag.md", 0.9) -func TestReplayHitRatesTopologyBeatsExactRepeat(t *testing.T) { - trace := replayTrace() + p.Store("RAG cache", results) + results[0].Document.Content = "mutated after store" - exact := replayExact(trace) - topology := replayTopology(trace) + got, ok := p.Get("rag cache") + require.True(t, ok) + got[0].Document.Content = "mutated after get" - assert.Equal(t, replayMetrics{exactHits: 2, topologyHits: 0, misses: 8}, exact) - assert.Equal(t, replayMetrics{exactHits: 2, topologyHits: 2, misses: 6}, topology) - assert.Greater(t, topology.hitRate(len(trace)), exact.hitRate(len(trace))) + again, ok := p.Get("rag cache") + require.True(t, ok) + assert.Equal(t, "content", again[0].Document.Content) } -func BenchmarkReplayHitRates(b *testing.B) { - trace := replayTrace() +func TestClearDropsCachedResults(t *testing.T) { + p := New(Config{Enabled: true}) + p.Store("RAG cache", result("docs/rag.md", 0.9)) - b.Run("exact-repeat", func(b *testing.B) { - var metrics replayMetrics - for range b.N { - metrics = replayExact(trace) - } - b.ReportMetric(metrics.hitRate(len(trace))*100, "hit_percent") - b.ReportMetric(float64(metrics.misses), "retrievals/op") - }) + p.Clear() - b.Run("topology-assisted", func(b *testing.B) { - var metrics replayMetrics - for range b.N { - metrics = replayTopology(trace) - } - b.ReportMetric(metrics.hitRate(len(trace))*100, "hit_percent") - b.ReportMetric(float64(metrics.misses), "retrievals/op") - b.ReportMetric(float64(metrics.topologyHits), "topology_hits/op") - }) + _, ok := p.Get("RAG cache") + assert.False(t, ok) } func result(path string, similarity float64) []database.SearchResult { @@ -178,130 +76,3 @@ func result(path string, similarity float64) []database.SearchResult { Similarity: similarity, }} } - -type replayTurn struct { - query string - results []database.SearchResult -} - -type replayMetrics struct { - exactHits int - topologyHits int - misses int -} - -func (m replayMetrics) hitRate(total int) float64 { - return float64(m.exactHits+m.topologyHits) / float64(total) -} - -func replayExact(trace []replayTurn) replayMetrics { - cache := map[string]struct{}{} - var metrics replayMetrics - for _, turn := range trace { - key := normalize(turn.query) - if _, ok := cache[key]; ok { - metrics.exactHits++ - continue - } - metrics.misses++ - cache[key] = struct{}{} - } - return metrics -} - -func replayTopology(trace []replayTurn) replayMetrics { - p := New(Config{Enabled: true}) - exactSeen := map[string]struct{}{} - var metrics replayMetrics - for _, turn := range trace { - key := normalize(turn.query) - if _, ok := exactSeen[key]; ok { - metrics.exactHits++ - continue - } - if _, ok := p.Get(turn.query); ok { - metrics.topologyHits++ - exactSeen[key] = struct{}{} - continue - } - metrics.misses++ - exactSeen[key] = struct{}{} - p.Store(turn.query, turn.results) - } - return metrics -} - -func replayTrace() []replayTurn { - return []replayTurn{ - { - query: "how does rag manager query work", - results: []database.SearchResult{ - result("pkg/rag/manager.go", 0.92)[0], - result("pkg/rag/prefetch/prefetch.go", 0.76)[0], - }, - }, - { - query: "rag manager cache behavior", - results: []database.SearchResult{ - result("pkg/rag/manager.go", 0.89)[0], - result("pkg/rag/prefetch/prefetch.go", 0.81)[0], - }, - }, - { - query: "how does rag manager query work", - results: []database.SearchResult{ - result("pkg/rag/manager.go", 0.92)[0], - result("pkg/rag/prefetch/prefetch.go", 0.76)[0], - }, - }, - { - query: "prefetch drift threshold behavior", - results: []database.SearchResult{ - result("pkg/rag/prefetch/prefetch.go", 0.91)[0], - result("pkg/rag/prefetch/prefetch_test.go", 0.84)[0], - }, - }, - { - query: "background prefetch should survive turn cancellation", - results: []database.SearchResult{ - result("pkg/rag/prefetch/prefetch.go", 0.88)[0], - result("pkg/tools/builtin/agent/agent.go", 0.67)[0], - }, - }, - { - query: "how are rag documents reindexed after file changes", - results: []database.SearchResult{ - result("pkg/rag/strategy/vector_store.go", 0.9)[0], - result("pkg/rag/strategy/bm25.go", 0.82)[0], - }, - }, - { - query: "docker model provider auth config validation", - results: []database.SearchResult{ - result("pkg/model/provider/anthropic/federation/federation.go", 0.86)[0], - result("pkg/config/latest/auth.go", 0.78)[0], - }, - }, - { - query: "anthropic auth config validation", - results: []database.SearchResult{ - result("pkg/config/latest/auth.go", 0.83)[0], - result("pkg/model/provider/anthropic/federation/federation.go", 0.79)[0], - }, - }, - { - query: "docker model provider auth config validation", - results: []database.SearchResult{ - result("pkg/model/provider/anthropic/federation/federation.go", 0.86)[0], - result("pkg/config/latest/auth.go", 0.78)[0], - }, - }, - { - query: "tui message rendering scroll behavior", - results: []database.SearchResult{ - result("pkg/tui/components/messages/messages.go", 0.88)[0], - result("pkg/tui/components/scrollview/scrollview.go", 0.74)[0], - }, - }, - } -} diff --git a/pkg/rag/prefetch/proofs/TopologyHit.lean b/pkg/rag/prefetch/proofs/TopologyHit.lean deleted file mode 100644 index b0c5fd528..000000000 --- a/pkg/rag/prefetch/proofs/TopologyHit.lean +++ /dev/null @@ -1,32 +0,0 @@ -namespace RAGPrefetch - -def ExactHit {Query : Type} (cache : Query -> Prop) (query : Query) : Prop := - cache query - -def TopologyHit {Query : Type} (cache : Query -> Prop) (related : Query -> Query -> Prop) (query : Query) : Prop := - exists cached, cache cached /\ related query cached - -theorem exact_hit_is_topology_hit_when_related_self - {Query : Type} - {cache : Query -> Prop} - {related : Query -> Query -> Prop} - {query : Query} - (hit : ExactHit cache query) - (selfRelated : related query query) : - TopologyHit cache related query := by - exists query - -theorem topology_hit_can_strictly_extend_exact_hit - {Query : Type} - {cache : Query -> Prop} - {related : Query -> Query -> Prop} - {query cached : Query} - (cachedHit : cache cached) - (topologyRelated : related query cached) - (exactMiss : Not (cache query)) : - TopologyHit cache related query /\ Not (ExactHit cache query) := by - constructor - · exists cached - · exact exactMiss - -end RAGPrefetch diff --git a/pkg/rag/strategy/bm25.go b/pkg/rag/strategy/bm25.go index 4524b0cb7..f346f760f 100644 --- a/pkg/rag/strategy/bm25.go +++ b/pkg/rag/strategy/bm25.go @@ -672,6 +672,89 @@ func (s *BM25Strategy) addPathToWatcher(ctx context.Context, path string) error return nil } +func (s *BM25Strategy) reindexChangedFiles(ctx context.Context, docPaths, changedFiles []string) int { + filesToReindex := make([]string, 0, len(changedFiles)) + for _, file := range changedFiles { + select { + case <-ctx.Done(): + return 0 + default: + } + + matches, matchErr := fsx.Matches(file, docPaths) + if matchErr != nil { + slog.ErrorContext(ctx, "Failed to match path", "file", file, "error", matchErr) + continue + } + if !matches { + continue + } + if s.shouldIgnore != nil && s.shouldIgnore(file) { + slog.DebugContext(ctx, "File changed but is ignored by filter, skipping", "path", file) + continue + } + + needsIndexing, err := s.needsIndexing(ctx, file) + if err != nil || !needsIndexing { + continue + } + filesToReindex = append(filesToReindex, file) + } + + if len(filesToReindex) == 0 { + return 0 + } + + s.emitEvent(types.Event{ + Type: types.EventTypeIndexingStarted, + Message: fmt.Sprintf("Re-indexing %d changed file(s)", len(filesToReindex)), + }) + + indexed := 0 + for _, file := range filesToReindex { + select { + case <-ctx.Done(): + return indexed + default: + } + + slog.DebugContext(ctx, "Indexing file", "path", file, "strategy", s.name) + if err := s.indexFile(ctx, file); err != nil { + slog.ErrorContext(ctx, "Failed to re-index file", "path", file, "error", err) + s.emitEvent(types.Event{ + Type: types.EventTypeError, + Message: "Failed to re-index: " + filepath.Base(file), + Error: err, + }) + continue + } + + indexed++ + s.emitEvent(types.Event{ + Type: types.EventTypeIndexingProgress, + Message: "Re-indexing: " + filepath.Base(file), + Progress: &types.Progress{ + Current: indexed, + Total: len(filesToReindex), + }, + }) + } + + if indexed == 0 { + return 0 + } + + if err := s.calculateAvgDocLength(ctx); err != nil { + slog.ErrorContext(ctx, "Failed to recalculate average document length", "error", err) + } + + s.emitEvent(types.Event{ + Type: types.EventTypeIndexingComplete, + Message: fmt.Sprintf("Re-indexed %d file(s)", indexed), + }) + return indexed +} + func (s *BM25Strategy) watchLoop(ctx context.Context, docPaths []string) { // Capture watcher reference at goroutine start to avoid racing with Close() // which sets s.watcher = nil under watcherMu. @@ -700,41 +783,7 @@ func (s *BM25Strategy) watchLoop(ctx context.Context, docPaths []string) { return } - for _, file := range changedFiles { - // Check for context cancellation - select { - case <-ctx.Done(): - return // Stop processing if context is cancelled - default: - } - - // Check if the file matches any of the configured document paths/patterns - matches, matchErr := fsx.Matches(file, docPaths) - if matchErr != nil { - slog.ErrorContext(ctx, "Failed to match path", "file", file, "error", matchErr) - continue - } - if !matches { - continue - } - // Check if the file should be ignored (e.g., gitignore) - if s.shouldIgnore != nil && s.shouldIgnore(file) { - slog.DebugContext(ctx, "File changed but is ignored by filter, skipping", "path", file) - continue - } - - needsIndexing, err := s.needsIndexing(ctx, file) - if err != nil || !needsIndexing { - continue - } - - slog.DebugContext(ctx, "Indexing file", "path", file, "strategy", s.name) - if err := s.indexFile(ctx, file); err != nil { - slog.ErrorContext(ctx, "Failed to re-index file", "path", file, "error", err) - } - } - - _ = s.calculateAvgDocLength(ctx) + s.reindexChangedFiles(ctx, docPaths, changedFiles) } for { diff --git a/pkg/rag/strategy/bm25_test.go b/pkg/rag/strategy/bm25_test.go new file mode 100644 index 000000000..997107f19 --- /dev/null +++ b/pkg/rag/strategy/bm25_test.go @@ -0,0 +1,61 @@ +package strategy + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/rag/types" +) + +func TestBM25LiveReindexEmitsIndexingComplete(t *testing.T) { + events := make(chan types.Event, 16) + dir := t.TempDir() + docPath := filepath.Join(dir, "doc.txt") + require.NoError(t, os.WriteFile(docPath, []byte("initial blork content"), 0o644)) + + db, err := newBM25DB(filepath.Join(dir, "bm25.db"), "bm25") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + + strategy := newBM25Strategy("bm25", db, events, 1.5, 0.75, ChunkingConfig{Size: 1024}, nil) + require.NoError(t, strategy.Initialize(t.Context(), []string{docPath}, ChunkingConfig{Size: 1024})) + drainEvents(events) + + require.NoError(t, os.WriteFile(docPath, []byte("updated blork content"), 0o644)) + + indexed := strategy.reindexChangedFiles(t.Context(), []string{docPath}, []string{docPath}) + + assert.Equal(t, 1, indexed) + assertEventType(t, events, types.EventTypeIndexingComplete) +} + +func drainEvents(events <-chan types.Event) { + for { + select { + case <-events: + default: + return + } + } +} + +func assertEventType(t *testing.T, events <-chan types.Event, want types.EventTye) { + t.Helper() + for range cap(events) { + select { + case event := <-events: + if event.Type == want { + return + } + default: + t.Fatalf("event %q was not emitted", want) + } + } + t.Fatalf("event %q was not emitted", want) +} diff --git a/pkg/rag/strategy/vector_store.go b/pkg/rag/strategy/vector_store.go index 344c6a0cd..e800c0533 100644 --- a/pkg/rag/strategy/vector_store.go +++ b/pkg/rag/strategy/vector_store.go @@ -912,7 +912,7 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { if len(filesToReindex) > 0 { s.emitEvent(types.Event{ - Type: "indexing_started", + Type: types.EventTypeIndexingStarted, Message: fmt.Sprintf("Re-indexing %d changed file(s)", len(filesToReindex)), }) @@ -926,7 +926,7 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { } s.emitEvent(types.Event{ - Type: "indexing_progress", + Type: types.EventTypeIndexingProgress, Message: "Re-indexing: " + filepath.Base(file), Progress: &types.Progress{ Current: i + 1, @@ -937,7 +937,7 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { if err := s.indexFile(ctx, file); err != nil { slog.ErrorContext(ctx, "Failed to re-index file", "path", file, "error", err) s.emitEvent(types.Event{ - Type: "error", + Type: types.EventTypeError, Message: "Failed to re-index: " + filepath.Base(file), Error: err, }) @@ -953,7 +953,7 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { } s.emitEvent(types.Event{ - Type: "indexing_completed", + Type: types.EventTypeIndexingComplete, Message: fmt.Sprintf("Re-indexed %d file(s)", len(filesToReindex)), }) } diff --git a/pkg/rag/topology/prior.go b/pkg/rag/topology/prior.go new file mode 100644 index 000000000..b9e34f526 --- /dev/null +++ b/pkg/rag/topology/prior.go @@ -0,0 +1,161 @@ +package topology + +import ( + "cmp" + "math" + "slices" + "strings" + "sync" + + "github.com/docker/docker-agent/pkg/rag/database" +) + +const ( + defaultWeight = 0.05 + defaultMaxSourceHistory = 32 + maxWeight = 0.2 +) + +// Config controls the topology prior. The zero value disables it. +type Config struct { + Enabled bool + Weight float64 + MaxSourceHistory int +} + +// Prior applies a small source-topology score to already-retrieved results. +type Prior struct { + cfg Config + + mu sync.Mutex + sources []sourcePoint +} + +type sourcePoint struct { + path string + tokens map[string]struct{} +} + +// NewPrior creates a disabled-by-default topology prior. +func NewPrior(cfg Config) *Prior { + if !cfg.Enabled { + return nil + } + if cfg.Weight <= 0 { + cfg.Weight = defaultWeight + } + cfg.Weight = math.Min(cfg.Weight, maxWeight) + if cfg.MaxSourceHistory <= 0 { + cfg.MaxSourceHistory = defaultMaxSourceHistory + } + return &Prior{cfg: cfg} +} + +// Apply blends a capped topology score into the current query's retrieved results. +func (p *Prior) Apply(query string, results []database.SearchResult) []database.SearchResult { + if p == nil || len(results) == 0 { + return results + } + + p.mu.Lock() + history := slices.Clone(p.sources) + p.mu.Unlock() + + queryTokens := tokenSet(query) + scored := slices.Clone(results) + for i := range scored { + sourceTokens := sourceTokenSet(scored[i].Document.SourcePath) + score := 0.7*jaccard(queryTokens, sourceTokens) + 0.3*historyScore(sourceTokens, history) + scored[i].Similarity += p.cfg.Weight * score + } + slices.SortStableFunc(scored, func(a, b database.SearchResult) int { + return cmp.Compare(b.Similarity, a.Similarity) + }) + return scored +} + +// Observe records source topology from completed foreground retrievals. +func (p *Prior) Observe(_ string, results []database.SearchResult) { + if p == nil || len(results) == 0 { + return + } + + p.mu.Lock() + defer p.mu.Unlock() + for _, result := range results { + path := result.Document.SourcePath + if path == "" || containsSource(p.sources, path) { + continue + } + p.sources = append(p.sources, sourcePoint{ + path: path, + tokens: sourceTokenSet(path), + }) + } + for len(p.sources) > p.cfg.MaxSourceHistory { + p.sources = p.sources[1:] + } +} + +// Clear drops topology history after index changes. +func (p *Prior) Clear() { + if p == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + p.sources = nil +} + +func containsSource(sources []sourcePoint, path string) bool { + for _, source := range sources { + if source.path == path { + return true + } + } + return false +} + +func historyScore(tokens map[string]struct{}, history []sourcePoint) float64 { + var best float64 + for _, source := range history { + best = math.Max(best, jaccard(tokens, source.tokens)) + } + return best +} + +func tokenSet(text string) map[string]struct{} { + tokens := map[string]struct{}{} + for _, token := range strings.FieldsFunc(strings.ToLower(text), isTokenSeparator) { + if len(token) < 2 { + continue + } + tokens[token] = struct{}{} + } + return tokens +} + +func sourceTokenSet(path string) map[string]struct{} { + return tokenSet(path) +} + +func isTokenSeparator(r rune) bool { + return r == '/' || r == '\\' || r == '.' || r == '-' || r == '_' || r == ' ' +} + +func jaccard(a, b map[string]struct{}) float64 { + if len(a) == 0 || len(b) == 0 { + return 0 + } + var intersection int + for token := range a { + if _, ok := b[token]; ok { + intersection++ + } + } + union := len(a) + len(b) - intersection + if union == 0 { + return 0 + } + return float64(intersection) / float64(union) +} diff --git a/pkg/rag/topology/prior_test.go b/pkg/rag/topology/prior_test.go new file mode 100644 index 000000000..78cd81821 --- /dev/null +++ b/pkg/rag/topology/prior_test.go @@ -0,0 +1,67 @@ +package topology + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/rag/database" +) + +func TestPriorReranksCurrentResultsWithSmallTopologyBoost(t *testing.T) { + prior := NewPrior(Config{Enabled: true, Weight: 0.05, MaxSourceHistory: 8}) + prior.Observe("how does rag manager query work", []database.SearchResult{ + result("pkg/rag/manager.go", 0.91), + }) + results := []database.SearchResult{ + result("pkg/model/provider/client.go", 0.72), + result("pkg/rag/manager.go", 0.70), + } + + got := prior.Apply("rag manager cache behavior", results) + + require.Len(t, got, 2) + assert.Equal(t, "pkg/rag/manager.go", got[0].Document.SourcePath) + assert.Greater(t, got[0].Similarity, 0.70) + assert.LessOrEqual(t, got[0].Similarity, 0.75) + assert.Equal(t, "pkg/model/provider/client.go", got[1].Document.SourcePath) +} + +func TestDisabledPriorReturnsResultsUnchanged(t *testing.T) { + prior := NewPrior(Config{}) + results := []database.SearchResult{ + result("pkg/rag/manager.go", 0.70), + result("pkg/model/provider/client.go", 0.72), + } + + got := prior.Apply("rag manager cache behavior", results) + + assert.Equal(t, results, got) +} + +func TestClearDropsSourceHistory(t *testing.T) { + prior := NewPrior(Config{Enabled: true, Weight: 0.05}) + prior.Observe("how does rag manager query work", []database.SearchResult{ + result("pkg/rag/manager.go", 0.91), + }) + + prior.Clear() + got := prior.Apply("rag manager cache behavior", []database.SearchResult{ + result("pkg/model/provider/client.go", 0.72), + result("pkg/rag/manager.go", 0.70), + }) + + require.Len(t, got, 2) + assert.Equal(t, "pkg/model/provider/client.go", got[0].Document.SourcePath) +} + +func result(path string, similarity float64) database.SearchResult { + return database.SearchResult{ + Document: database.Document{ + SourcePath: path, + Content: "content", + }, + Similarity: similarity, + } +}