Files
kb/engine/kb/search.py
T
steve 6fec627503 Upload-time duplicate detection, FTS5 query sanitization, release guard
- Reject duplicate uploads at the API boundary (HTTP 409) instead of
  silently skipping in the background worker. Checks both ingested
  documents and in-flight jobs via content_hash on the jobs table.
- Go client handles 409 with distinct messages for already-imported
  documents vs already-queued jobs.
- Sanitize FTS5 search queries by quoting each token to prevent syntax
  errors from special characters like ?, *, ", (), AND, OR, NOT.
- Add try/except safety net around FTS5 execute for edge cases.
- Add main branch guard to release.sh to prevent releasing from
  feature branches.
- Update specs and README to reflect new behaviour.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 23:05:07 +00:00

317 lines
9.2 KiB
Python

"""Hybrid search — FTS5 + sqlite-vec with Reciprocal Rank Fusion."""
import json
import logging
import struct
import sqlite3
# Module-level logger for the search code; _fts_search uses it to emit a
# warning when an FTS5 query fails to execute.
logger = logging.getLogger("kb.search")
def hybrid_search(
conn: sqlite3.Connection,
query: str,
cfg,
top: int = 10,
tags: list[str] | None = None,
doc_type: str | None = None,
fts_only: bool = False,
vec_only: bool = False,
threshold: float | None = None,
) -> dict:
"""Run hybrid search and return merged, enriched results.
Args:
conn: SQLite connection (with row_factory = sqlite3.Row).
query: User search query string.
cfg: Config object with ``model`` and ``device`` attributes.
top: Maximum number of results to return.
tags: Optional tag filter — documents must have *all* listed tags.
doc_type: Optional document-type filter.
fts_only: Only use FTS5 (skip vector search).
vec_only: Only use vector search (skip FTS5).
threshold: Optional minimum score; results below are dropped.
Returns:
Dict with keys: query, results, total_matches, returned.
"""
candidate_count = top * 3
fts_results: dict[int, float] = {}
vec_results: dict[int, float] = {}
if not vec_only:
fts_results = _fts_search(conn, query, candidate_count, tags, doc_type)
if not fts_only:
vec_results = _vector_search(conn, query, candidate_count, tags, doc_type)
# --- merge ---------------------------------------------------------------
if fts_only:
merged = sorted(fts_results.items(), key=lambda x: x[1], reverse=True)
elif vec_only:
merged = sorted(vec_results.items(), key=lambda x: x[1], reverse=True)
else:
merged = _rrf_merge(fts_results, vec_results)
# Apply threshold filter — use config default if not specified per-query
effective_threshold = threshold if threshold is not None else cfg.search_threshold
if effective_threshold > 0:
merged = [(cid, score) for cid, score in merged if score >= effective_threshold]
total_matches = len(merged)
merged = merged[:top]
# --- enrich --------------------------------------------------------------
results = _enrich(conn, merged)
return {
"query": query,
"results": results,
"total_matches": total_matches,
"returned": len(results),
}
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _sanitize_fts_query(query: str) -> str:
    """Turn a raw user query into a literal-only FTS5 MATCH expression.

    The query is split on whitespace; double quotes are removed from every
    piece, pieces that end up empty are discarded, and each remaining piece
    is wrapped in double quotes so FTS5 reads it as a string literal rather
    than as query syntax.  The quoted pieces are joined with single spaces.

    Returns:
        The quoted expression, or ``""`` when no usable token is left.
    """
    unquoted = (piece.replace('"', "") for piece in query.split())
    return " ".join(f'"{piece}"' for piece in unquoted if piece)
def _fts_search(
    conn: sqlite3.Connection,
    query: str,
    limit: int,
    tags: list[str] | None,
    doc_type: str | None,
) -> dict[int, float]:
    """Full-text search against the ``chunks_fts`` FTS5 table.

    The user query is sanitized first; an empty sanitized query short-circuits
    to an empty result.  When ``tags`` or ``doc_type`` are given, the FTS
    table is joined to ``chunks`` / ``documents`` (plus one
    ``document_tags`` / ``tags`` join pair per tag, so every tag must match).

    Returns:
        {chunk_id: score} where score is the negated ``bm25()`` value, so
        higher = better.  Empty dict if the query is empty or SQLite raises
        ``OperationalError`` while executing it.
    """
    match_expr = _sanitize_fts_query(query)
    if not match_expr:
        return {}

    join_clauses: list[str] = []
    conditions: list[str] = ["chunks_fts MATCH ?"]
    bindings: list = [match_expr]

    # Filters live on the documents table, which is reached via chunks.
    if tags or doc_type:
        join_clauses += [
            "JOIN chunks c ON f.rowid = c.id",
            "JOIN documents d ON c.document_id = d.id",
        ]
    if doc_type:
        conditions.append("d.doc_type = ?")
        bindings.append(doc_type)
    # One aliased join pair per tag: all tags are required simultaneously.
    for idx, tag in enumerate(tags or []):
        join_clauses += [
            f"JOIN document_tags dt{idx} ON d.id = dt{idx}.document_id",
            f"JOIN tags t{idx} ON dt{idx}.tag_id = t{idx}.id",
        ]
        conditions.append(f"t{idx}.name = ?")
        bindings.append(tag.strip().lower())
    bindings.append(limit)

    statement = (
        "SELECT f.rowid AS chunk_id, bm25(chunks_fts) AS rank FROM chunks_fts f"
        f" {' '.join(join_clauses)}"
        f" WHERE {' AND '.join(conditions)}"
        " ORDER BY rank LIMIT ?"
    )

    try:
        matches = conn.execute(statement, bindings).fetchall()
    except sqlite3.OperationalError:
        # Safety net for inputs the sanitizer does not neutralise.
        logger.warning("FTS5 query failed for input: %r", query)
        return {}

    # bm25() is lower-is-better (negative for good matches); flip the sign.
    return {match[0]: -match[1] for match in matches}
def _vector_search(
    conn: sqlite3.Connection,
    query: str,
    limit: int,
    tags: list[str] | None,
    doc_type: str | None,
) -> dict[int, float]:
    """Embed *query* and run a nearest-neighbour lookup on ``chunks_vec``.

    The embedding is packed as C floats (``struct`` format ``f``) and matched
    against the ``embedding`` column.  Tag / doc_type filters are applied
    after the lookup, one candidate at a time, so fewer than ``limit``
    entries may come back when filters are active.

    Returns:
        {chunk_id: similarity} where similarity = 1 / (1 + distance).
    """
    from kb.embeddings import embed_texts

    vector = embed_texts([query])[0]
    packed = struct.pack(f"{len(vector)}f", *vector)

    knn_sql = """
        SELECT chunk_id, distance
        FROM chunks_vec
        WHERE embedding MATCH ?
        ORDER BY distance
        LIMIT ?
    """
    neighbours = conn.execute(knn_sql, (packed, limit)).fetchall()

    filtering = bool(tags or doc_type)
    hits: dict[int, float] = {}
    for chunk_id, distance in neighbours:
        similarity = 1.0 / (1.0 + distance)
        # Post-hoc tag / doc_type filtering for vector results
        if filtering and not _passes_filters(conn, chunk_id, tags, doc_type):
            continue
        hits[chunk_id] = similarity
    return hits
def _passes_filters(
conn: sqlite3.Connection,
chunk_id: int,
tags: list[str] | None,
doc_type: str | None,
) -> bool:
"""Return True if chunk passes tag and doc_type filters."""
sql = """
SELECT d.id FROM chunks c
JOIN documents d ON c.document_id = d.id
WHERE c.id = ?
"""
params: list = [chunk_id]
if doc_type:
sql += " AND d.doc_type = ?"
params.append(doc_type)
doc_row = conn.execute(sql, params).fetchone()
if not doc_row:
return False
if tags:
doc_id = doc_row[0]
placeholders = ",".join("?" * len(tags))
normalised = [t.strip().lower() for t in tags]
count = conn.execute(
f"""
SELECT COUNT(DISTINCT t.name) FROM document_tags dt
JOIN tags t ON dt.tag_id = t.id
WHERE dt.document_id = ? AND t.name IN ({placeholders})
""",
[doc_id, *normalised],
).fetchone()[0]
if count < len(tags):
return False
return True
def _rrf_merge(
    fts_results: dict[int, float],
    vec_results: dict[int, float],
    k: int = 60,
) -> list[tuple[int, float]]:
    """Fuse two scored result sets with Reciprocal Rank Fusion.

    Both inputs are converted to 1-based ranks (best score = rank 1).  A
    chunk's fused score is the sum of ``1 / (k + rank)`` over every input in
    which it appears; chunks missing from an input simply get no
    contribution from it.

    Returns:
        List of (chunk_id, rrf_score) ordered from highest to lowest score.
    """
    fts_ranks = _rank_by_score(fts_results)
    vec_ranks = _rank_by_score(vec_results)

    fused: list[tuple[int, float]] = []
    for chunk_id in set(fts_ranks) | set(vec_ranks):
        total = 0.0
        # FTS contribution first, then vector, for each list the chunk is in.
        for ranking in (fts_ranks, vec_ranks):
            if chunk_id in ranking:
                total += 1.0 / (k + ranking[chunk_id])
        fused.append((chunk_id, total))

    return sorted(fused, key=lambda pair: pair[1], reverse=True)
def _rank_by_score(results: dict[int, float]) -> dict[int, int]:
    """Map every id to its 1-based position when sorted by score, best first.

    Ids with equal scores keep the relative order they have in ``results``.
    """
    by_score = sorted(results.items(), key=lambda item: item[1], reverse=True)
    ranks: dict[int, int] = {}
    for position, (cid, _score) in enumerate(by_score, start=1):
        ranks[cid] = position
    return ranks
def _enrich(
    conn: sqlite3.Connection,
    merged: list[tuple[int, float]],
) -> list[dict]:
    """Expand (chunk_id, score) pairs into full result dictionaries.

    For each pair the chunk row and its parent document are loaded, the
    chunk's ``metadata`` column is JSON-decoded (empty / NULL becomes ``{}``),
    and the document's tag names are fetched in alphabetical order.  Chunk
    ids that no longer resolve to a row are skipped silently.  Scores are
    rounded to six decimal places.

    Returns:
        One dict per resolved chunk, in the same order as ``merged``.
    """
    chunk_sql = """
        SELECT c.id, c.text, c.chunk_index, c.metadata AS chunk_meta,
               d.id AS doc_id, d.title, d.doc_type, d.source_path,
               d.created_at
        FROM chunks c
        JOIN documents d ON c.document_id = d.id
        WHERE c.id = ?
    """
    tags_sql = """
        SELECT t.name FROM tags t
        JOIN document_tags dt ON t.id = dt.tag_id
        WHERE dt.document_id = ?
        ORDER BY t.name
    """

    enriched: list[dict] = []
    for chunk_id, score in merged:
        record = conn.execute(chunk_sql, (chunk_id,)).fetchone()
        if record is None:
            continue

        # Column order mirrors the SELECT list in chunk_sql.
        (
            found_id,
            text,
            chunk_index,
            raw_meta,
            doc_id,
            title,
            doc_type,
            source_path,
            created_at,
        ) = record

        decoded_meta = json.loads(raw_meta) if raw_meta else {}
        tag_names = [
            tag_row[0] for tag_row in conn.execute(tags_sql, (doc_id,)).fetchall()
        ]

        enriched.append({
            "chunk_id": found_id,
            "score": round(score, 6),
            "text": text,
            "chunk_index": chunk_index,
            "chunk_metadata": decoded_meta,
            "title": title,
            "doc_type": doc_type,
            "source_path": source_path,
            "created_at": created_at,
            "tags": tag_names,
        })
    return enriched