6fec627503
- Reject duplicate uploads at the API boundary (HTTP 409) instead of silently skipping in the background worker. Checks both ingested documents and in-flight jobs via content_hash on the jobs table. - Go client handles 409 with distinct messages for already-imported documents vs already-queued jobs. - Sanitize FTS5 search queries by quoting each token to prevent syntax errors from special characters like ?, *, ", (), AND, OR, NOT. - Add try/except safety net around FTS5 execute for edge cases. - Add main branch guard to release.sh to prevent releasing from feature branches. - Update specs and README to reflect new behaviour. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
317 lines
9.2 KiB
Python
317 lines
9.2 KiB
Python
"""Hybrid search — FTS5 + sqlite-vec with Reciprocal Rank Fusion."""
|
|
|
|
import json
|
|
import logging
|
|
import struct
|
|
import sqlite3
|
|
|
|
# Module logger under the fixed name "kb.search"; used by _fts_search to report FTS5 query failures.
logger = logging.getLogger("kb.search")
|
|
|
|
|
|
def hybrid_search(
|
|
conn: sqlite3.Connection,
|
|
query: str,
|
|
cfg,
|
|
top: int = 10,
|
|
tags: list[str] | None = None,
|
|
doc_type: str | None = None,
|
|
fts_only: bool = False,
|
|
vec_only: bool = False,
|
|
threshold: float | None = None,
|
|
) -> dict:
|
|
"""Run hybrid search and return merged, enriched results.
|
|
|
|
Args:
|
|
conn: SQLite connection (with row_factory = sqlite3.Row).
|
|
query: User search query string.
|
|
cfg: Config object with ``model`` and ``device`` attributes.
|
|
top: Maximum number of results to return.
|
|
tags: Optional tag filter — documents must have *all* listed tags.
|
|
doc_type: Optional document-type filter.
|
|
fts_only: Only use FTS5 (skip vector search).
|
|
vec_only: Only use vector search (skip FTS5).
|
|
threshold: Optional minimum score; results below are dropped.
|
|
|
|
Returns:
|
|
Dict with keys: query, results, total_matches, returned.
|
|
"""
|
|
candidate_count = top * 3
|
|
|
|
fts_results: dict[int, float] = {}
|
|
vec_results: dict[int, float] = {}
|
|
|
|
if not vec_only:
|
|
fts_results = _fts_search(conn, query, candidate_count, tags, doc_type)
|
|
|
|
if not fts_only:
|
|
vec_results = _vector_search(conn, query, candidate_count, tags, doc_type)
|
|
|
|
# --- merge ---------------------------------------------------------------
|
|
if fts_only:
|
|
merged = sorted(fts_results.items(), key=lambda x: x[1], reverse=True)
|
|
elif vec_only:
|
|
merged = sorted(vec_results.items(), key=lambda x: x[1], reverse=True)
|
|
else:
|
|
merged = _rrf_merge(fts_results, vec_results)
|
|
|
|
# Apply threshold filter — use config default if not specified per-query
|
|
effective_threshold = threshold if threshold is not None else cfg.search_threshold
|
|
if effective_threshold > 0:
|
|
merged = [(cid, score) for cid, score in merged if score >= effective_threshold]
|
|
|
|
total_matches = len(merged)
|
|
merged = merged[:top]
|
|
|
|
# --- enrich --------------------------------------------------------------
|
|
results = _enrich(conn, merged)
|
|
|
|
return {
|
|
"query": query,
|
|
"results": results,
|
|
"total_matches": total_matches,
|
|
"returned": len(results),
|
|
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Internal helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _sanitize_fts_query(query: str) -> str:
    """Turn free-form user input into a literal-only FTS5 MATCH expression.

    The input is split on whitespace and every double quote is removed from
    each piece. Each non-empty piece is then wrapped in double quotes, so FTS5
    reads it as a plain string rather than as operators (AND/OR/NOT, ``*``,
    parentheses, ...). The quoted pieces are space-joined, which FTS5 treats
    as an implicit AND. An empty string means no usable tokens were found.
    """
    unquoted = (piece.replace('"', '') for piece in query.split())
    return " ".join('"' + piece + '"' for piece in unquoted if piece)
|
|
|
|
|
|
def _fts_search(
    conn: sqlite3.Connection,
    query: str,
    limit: int,
    tags: list[str] | None,
    doc_type: str | None,
) -> dict[int, float]:
    """Full-text search over ``chunks_fts`` ranked by BM25.

    Documents are joined in only when a tag or doc_type filter is given. Each
    requested tag adds its own ``document_tags``/``tags`` join pair, so a
    document must carry every tag. Tags are compared after ``strip().lower()``.
    If SQLite raises ``OperationalError`` for the MATCH expression, a warning
    is logged and no results are returned.

    Returns:
        {chunk_id: score} where score is the negated BM25 value, so larger
        means a better match.
    """
    match_expr = _sanitize_fts_query(query)
    if not match_expr:
        return {}

    join_clauses: list[str] = []
    conditions: list[str] = ["chunks_fts MATCH ?"]
    bind: list = [match_expr]

    if tags or doc_type:
        join_clauses += [
            "JOIN chunks c ON f.rowid = c.id",
            "JOIN documents d ON c.document_id = d.id",
        ]

    if doc_type:
        conditions.append("d.doc_type = ?")
        bind.append(doc_type)

    for idx, tag in enumerate(tags or []):
        join_clauses.append(f"JOIN document_tags dt{idx} ON d.id = dt{idx}.document_id")
        join_clauses.append(f"JOIN tags t{idx} ON dt{idx}.tag_id = t{idx}.id")
        conditions.append(f"t{idx}.name = ?")
        bind.append(tag.strip().lower())

    bind.append(limit)

    statement = (
        "SELECT f.rowid AS chunk_id, bm25(chunks_fts) AS rank FROM chunks_fts f"
        f" {' '.join(join_clauses)}"
        f" WHERE {' AND '.join(conditions)}"
        " ORDER BY rank LIMIT ?"
    )

    try:
        hits = conn.execute(statement, bind).fetchall()
    except sqlite3.OperationalError:
        logger.warning("FTS5 query failed for input: %r", query)
        return {}

    # bm25() is smaller for better matches; flip the sign so callers can
    # treat every score as "higher is better".
    return {chunk_id: -bm25_value for chunk_id, bm25_value in hits}
|
|
|
|
|
|
def _vector_search(
    conn: sqlite3.Connection,
    query: str,
    limit: int,
    tags: list[str] | None,
    doc_type: str | None,
) -> dict[int, float]:
    """Nearest-neighbour search of ``chunks_vec`` for the embedded *query*.

    The query is embedded via ``kb.embeddings.embed_texts`` (imported lazily),
    packed as native-order ``f`` floats, and matched against ``chunks_vec``.
    The closest ``limit`` rows by distance are taken first. Tag and doc_type
    filters are then applied per row, so fewer than ``limit`` results can come
    back when filters are set.

    Returns:
        {chunk_id: similarity} with similarity = 1 / (1 + distance).
    """
    from kb.embeddings import embed_texts

    vector = embed_texts([query])[0]
    packed = struct.pack(f"{len(vector)}f", *vector)

    knn_rows = conn.execute(
        """
        SELECT chunk_id, distance
        FROM chunks_vec
        WHERE embedding MATCH ?
        ORDER BY distance
        LIMIT ?
        """,
        (packed, limit),
    ).fetchall()

    filtering = bool(tags or doc_type)
    scored: dict[int, float] = {}

    for chunk_id, distance in knn_rows:
        similarity = 1.0 / (1.0 + distance)
        # Post-hoc filter: runs after the KNN LIMIT, one lookup per candidate.
        if filtering and not _passes_filters(conn, chunk_id, tags, doc_type):
            continue
        scored[chunk_id] = similarity

    return scored
|
|
|
|
|
|
def _passes_filters(
|
|
conn: sqlite3.Connection,
|
|
chunk_id: int,
|
|
tags: list[str] | None,
|
|
doc_type: str | None,
|
|
) -> bool:
|
|
"""Return True if chunk passes tag and doc_type filters."""
|
|
sql = """
|
|
SELECT d.id FROM chunks c
|
|
JOIN documents d ON c.document_id = d.id
|
|
WHERE c.id = ?
|
|
"""
|
|
params: list = [chunk_id]
|
|
|
|
if doc_type:
|
|
sql += " AND d.doc_type = ?"
|
|
params.append(doc_type)
|
|
|
|
doc_row = conn.execute(sql, params).fetchone()
|
|
if not doc_row:
|
|
return False
|
|
|
|
if tags:
|
|
doc_id = doc_row[0]
|
|
placeholders = ",".join("?" * len(tags))
|
|
normalised = [t.strip().lower() for t in tags]
|
|
count = conn.execute(
|
|
f"""
|
|
SELECT COUNT(DISTINCT t.name) FROM document_tags dt
|
|
JOIN tags t ON dt.tag_id = t.id
|
|
WHERE dt.document_id = ? AND t.name IN ({placeholders})
|
|
""",
|
|
[doc_id, *normalised],
|
|
).fetchone()[0]
|
|
if count < len(tags):
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
def _rrf_merge(
    fts_results: dict[int, float],
    vec_results: dict[int, float],
    k: int = 60,
) -> list[tuple[int, float]]:
    """Reciprocal Rank Fusion over two scored result sets.

    Each set is ranked independently (highest score first, rank starts at 1).
    The RRF score for a chunk is the sum of 1/(k + rank) over the sets it
    appears in.

    Ties on the fused score are common, e.g. a chunk ranked 1st only by FTS
    and another ranked 1st only by vectors. They are broken by ascending
    chunk id, so the order, and therefore which chunks survive truncation,
    is deterministic instead of depending on set iteration order.

    Args:
        fts_results: {chunk_id: score} from full-text search (higher = better).
        vec_results: {chunk_id: score} from vector search (higher = better).
        k: RRF smoothing constant; larger values flatten rank differences.

    Returns:
        Sorted list of (chunk_id, rrf_score), highest score first.
    """
    fused: dict[int, float] = {}
    # FTS contributions are added before vector ones, matching the original
    # summation order of the floating-point terms.
    for ranked in (_rank_by_score(fts_results), _rank_by_score(vec_results)):
        for chunk_id, rank in ranked.items():
            fused[chunk_id] = fused.get(chunk_id, 0.0) + 1.0 / (k + rank)

    return sorted(fused.items(), key=lambda item: (-item[1], item[0]))


def _rank_by_score(results: dict[int, float]) -> dict[int, int]:
    """Map each id to its 1-based rank, best (highest) score first.

    Equal scores keep their original insertion order (sort stability).
    """
    ordered = sorted(results.items(), key=lambda pair: pair[1], reverse=True)
    return {cid: position for position, (cid, _score) in enumerate(ordered, start=1)}
|
|
|
|
|
|
def _enrich(
    conn: sqlite3.Connection,
    merged: list[tuple[int, float]],
) -> list[dict]:
    """Expand (chunk_id, score) pairs into full result records.

    For each pair, in order, the chunk row and its parent document are loaded,
    the chunk's JSON metadata is decoded (empty metadata becomes ``{}``), and
    the document's tag names are collected alphabetically. Ids with no matching
    chunk+document row are skipped. Scores are rounded to 6 decimal places.
    """
    chunk_sql = """
        SELECT c.id, c.text, c.chunk_index, c.metadata AS chunk_meta,
               d.id AS doc_id, d.title, d.doc_type, d.source_path,
               d.created_at
        FROM chunks c
        JOIN documents d ON c.document_id = d.id
        WHERE c.id = ?
    """
    tag_sql = """
        SELECT t.name FROM tags t
        JOIN document_tags dt ON t.id = dt.tag_id
        WHERE dt.document_id = ?
        ORDER BY t.name
    """

    enriched: list[dict] = []
    for chunk_id, score in merged:
        found = conn.execute(chunk_sql, (chunk_id,)).fetchone()
        if found is None:
            continue

        (
            cid, body, position, raw_meta,
            doc_id, title, kind, source, created,
        ) = found

        meta = json.loads(raw_meta) if raw_meta else {}
        tag_names = [tag_row[0] for tag_row in conn.execute(tag_sql, (doc_id,)).fetchall()]

        enriched.append({
            "chunk_id": cid,
            "score": round(score, 6),
            "text": body,
            "chunk_index": position,
            "chunk_metadata": meta,
            "title": title,
            "doc_type": kind,
            "source_path": source,
            "created_at": created,
            "tags": tag_names,
        })

    return enriched
|