"""Hybrid search — FTS5 + sqlite-vec with Reciprocal Rank Fusion.""" import json import logging import struct import sqlite3 logger = logging.getLogger("kb.search") def hybrid_search( conn: sqlite3.Connection, query: str, cfg, top: int = 10, tags: list[str] | None = None, doc_type: str | None = None, fts_only: bool = False, vec_only: bool = False, threshold: float | None = None, ) -> dict: """Run hybrid search and return merged, enriched results. Args: conn: SQLite connection (with row_factory = sqlite3.Row). query: User search query string. cfg: Config object with ``model`` and ``device`` attributes. top: Maximum number of results to return. tags: Optional tag filter — documents must have *all* listed tags. doc_type: Optional document-type filter. fts_only: Only use FTS5 (skip vector search). vec_only: Only use vector search (skip FTS5). threshold: Optional minimum score; results below are dropped. Returns: Dict with keys: query, results, total_matches, returned. """ candidate_count = top * 3 fts_results: dict[int, float] = {} vec_results: dict[int, float] = {} if not vec_only: fts_results = _fts_search(conn, query, candidate_count, tags, doc_type) if not fts_only: vec_results = _vector_search(conn, query, candidate_count, tags, doc_type) # --- merge --------------------------------------------------------------- if fts_only: merged = sorted(fts_results.items(), key=lambda x: x[1], reverse=True) elif vec_only: merged = sorted(vec_results.items(), key=lambda x: x[1], reverse=True) else: merged = _rrf_merge(fts_results, vec_results) # Apply threshold filter — use config default if not specified per-query effective_threshold = threshold if threshold is not None else cfg.search_threshold if effective_threshold > 0: merged = [(cid, score) for cid, score in merged if score >= effective_threshold] total_matches = len(merged) merged = merged[:top] # --- enrich -------------------------------------------------------------- results = _enrich(conn, merged) return { "query": query, "results": results, "total_matches": total_matches, "returned": len(results), } # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _sanitize_fts_query(query: str) -> str: """Escape a raw user query for safe use with FTS5 MATCH. Splits on whitespace, strips double quotes from each token, wraps each token in double quotes (making FTS5 treat all content as literals), and joins with spaces. Returns empty string if no valid tokens remain. """ tokens = [] for token in query.split(): token = token.replace('"', '') if token: tokens.append(f'"{token}"') return " ".join(tokens) def _fts_search( conn: sqlite3.Connection, query: str, limit: int, tags: list[str] | None, doc_type: str | None, ) -> dict[int, float]: """FTS5 search on ``chunks_fts``. Returns: {chunk_id: bm25_score} where scores are positive (higher = better). """ safe_query = _sanitize_fts_query(query) if not safe_query: return {} sql = "SELECT f.rowid AS chunk_id, bm25(chunks_fts) AS rank FROM chunks_fts f" joins: list[str] = [] where: list[str] = ["chunks_fts MATCH ?"] params: list = [safe_query] if tags or doc_type: joins.append("JOIN chunks c ON f.rowid = c.id") joins.append("JOIN documents d ON c.document_id = d.id") if doc_type: where.append("d.doc_type = ?") params.append(doc_type) if tags: for i, tag in enumerate(tags): joins.append(f"JOIN document_tags dt{i} ON d.id = dt{i}.document_id") joins.append(f"JOIN tags t{i} ON dt{i}.tag_id = t{i}.id") where.append(f"t{i}.name = ?") params.append(tag.strip().lower()) sql += " " + " ".join(joins) sql += " WHERE " + " AND ".join(where) sql += " ORDER BY rank LIMIT ?" params.append(limit) try: rows = conn.execute(sql, params).fetchall() except sqlite3.OperationalError: logger.warning("FTS5 query failed for input: %r", query) return {} # BM25 returns negative values (lower = better match); negate so # higher = better. return {row[0]: -row[1] for row in rows} def _vector_search( conn: sqlite3.Connection, query: str, limit: int, tags: list[str] | None, doc_type: str | None, ) -> dict[int, float]: """Embed *query* and search ``chunks_vec`` via sqlite-vec. Returns: {chunk_id: similarity} where similarity = 1 / (1 + distance). """ from kb.embeddings import embed_texts query_embedding = embed_texts([query])[0] blob = struct.pack(f"{len(query_embedding)}f", *query_embedding) rows = conn.execute( """ SELECT chunk_id, distance FROM chunks_vec WHERE embedding MATCH ? ORDER BY distance LIMIT ? """, (blob, limit), ).fetchall() results: dict[int, float] = {} for row in rows: chunk_id = row[0] distance = row[1] similarity = 1.0 / (1.0 + distance) # Post-hoc tag / doc_type filtering for vector results if tags or doc_type: if not _passes_filters(conn, chunk_id, tags, doc_type): continue results[chunk_id] = similarity return results def _passes_filters( conn: sqlite3.Connection, chunk_id: int, tags: list[str] | None, doc_type: str | None, ) -> bool: """Return True if chunk passes tag and doc_type filters.""" sql = """ SELECT d.id FROM chunks c JOIN documents d ON c.document_id = d.id WHERE c.id = ? """ params: list = [chunk_id] if doc_type: sql += " AND d.doc_type = ?" params.append(doc_type) doc_row = conn.execute(sql, params).fetchone() if not doc_row: return False if tags: doc_id = doc_row[0] placeholders = ",".join("?" * len(tags)) normalised = [t.strip().lower() for t in tags] count = conn.execute( f""" SELECT COUNT(DISTINCT t.name) FROM document_tags dt JOIN tags t ON dt.tag_id = t.id WHERE dt.document_id = ? AND t.name IN ({placeholders}) """, [doc_id, *normalised], ).fetchone()[0] if count < len(tags): return False return True def _rrf_merge( fts_results: dict[int, float], vec_results: dict[int, float], k: int = 60, ) -> list[tuple[int, float]]: """Reciprocal Rank Fusion over two scored result sets. Each set is ranked independently (highest score first, rank starts at 1). RRF score for a document = sum of 1/(k + rank) across sets it appears in. Returns: Sorted list of (chunk_id, rrf_score), highest first. """ fts_ranked = _rank_by_score(fts_results) vec_ranked = _rank_by_score(vec_results) all_ids = set(fts_ranked) | set(vec_ranked) scores: list[tuple[int, float]] = [] for chunk_id in all_ids: rrf = 0.0 if chunk_id in fts_ranked: rrf += 1.0 / (k + fts_ranked[chunk_id]) if chunk_id in vec_ranked: rrf += 1.0 / (k + vec_ranked[chunk_id]) scores.append((chunk_id, rrf)) scores.sort(key=lambda x: x[1], reverse=True) return scores def _rank_by_score(results: dict[int, float]) -> dict[int, int]: """Return {id: 1-based rank} sorted by score descending.""" ordered = sorted(results, key=results.get, reverse=True) return {cid: rank for rank, cid in enumerate(ordered, start=1)} def _enrich( conn: sqlite3.Connection, merged: list[tuple[int, float]], ) -> list[dict]: """Fetch chunk text, document metadata, chunk metadata, and tags.""" results: list[dict] = [] for chunk_id, score in merged: row = conn.execute( """ SELECT c.id, c.text, c.chunk_index, c.metadata AS chunk_meta, d.id AS doc_id, d.title, d.doc_type, d.source_path, d.created_at FROM chunks c JOIN documents d ON c.document_id = d.id WHERE c.id = ? """, (chunk_id,), ).fetchone() if row is None: continue chunk_meta = json.loads(row[3]) if row[3] else {} tag_rows = conn.execute( """ SELECT t.name FROM tags t JOIN document_tags dt ON t.id = dt.tag_id WHERE dt.document_id = ? ORDER BY t.name """, (row[4],), # doc_id ).fetchall() results.append({ "chunk_id": row[0], "score": round(score, 6), "text": row[1], "chunk_index": row[2], "chunk_metadata": chunk_meta, "title": row[5], "doc_type": row[6], "source_path": row[7], "created_at": row[8], "tags": [t[0] for t in tag_rows], }) return results