# Source code for paper_firehose.commands.query

"""Query paper databases for entries with flexible filtering and output."""

import json
import logging
from typing import Any, Dict, List, Optional

from ..core.command_context import CommandContext
from ..core.model_manager import ensure_local_model
from ..processors.st_ranker import STRanker

# Module logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)

# Default columns shown in human-readable table output per database, keyed by
# the ``db_key`` accepted by ``run``. Used only when the caller does not pass an
# explicit ``fields`` list; ``run`` falls back to ['title'] for unknown keys.
_DEFAULT_TABLE_FIELDS = {
    'current': ['rank_score', 'published_date', 'title', 'topic', 'authors'],
    'history': ['rank_score', 'published_date', 'matched_date', 'title', 'topics'],
    # all_feeds has no rank_score column, so it leads with dates instead.
    'all_feeds': ['published_date', 'first_seen', 'title', 'feed_name', 'authors'],
}

_SORT_MAP = {
    'rank': 'rank_score DESC',
    'date': 'published_date DESC',
    'title': 'title ASC',
}


def _resolve_sort(sort_arg: str, db_key: str) -> str:
    """Map a short sort name to an ORDER BY clause."""
    order = _SORT_MAP.get(sort_arg)
    if order is None:
        raise ValueError(f"Unknown sort key '{sort_arg}'. Choose from: {', '.join(_SORT_MAP)}")
    if sort_arg == 'rank' and db_key == 'all_feeds':
        return 'published_date DESC'
    return order


def _truncate(text: Optional[str], width: int) -> str:
    if not text:
        return ''
    text = ' '.join(text.split())  # collapse whitespace
    if len(text) <= width:
        return text
    return text[:width - 1] + '\u2026'


def _format_table(rows: List[Dict[str, Any]], total: int,
                  fields: List[str], offset: int, limit: int) -> str:
    """Format rows as a human-readable table."""
    if not rows:
        return 'No entries found.'

    lines: list[str] = []
    end = offset + len(rows)
    lines.append(f'Found {total} entries (showing {offset + 1}-{end})')
    lines.append('')

    # Column widths
    col_widths: Dict[str, int] = {}
    display_rows: list[dict] = []
    for row in rows:
        display: dict = {}
        for f in fields:
            val = row.get(f)
            if f == 'rank_score' and val is not None:
                display[f] = f'{val:.3f}'
            elif f in ('title', 'authors'):
                display[f] = _truncate(str(val) if val else '', 55)
            else:
                display[f] = str(val) if val is not None else ''
        display_rows.append(display)
        for f in fields:
            col_widths[f] = max(col_widths.get(f, len(f)), len(display[f]))

    # Header
    hdr = ' #  ' + '  '.join(f.ljust(col_widths[f]) for f in fields)
    lines.append(hdr)

    # Rows
    for i, display in enumerate(display_rows, start=offset + 1):
        num = str(i).rjust(2)
        cells = '  '.join(display[f].ljust(col_widths[f]) for f in fields)
        lines.append(f'{num}  {cells}')

    if total > end:
        lines.append('')
        lines.append(f'Showing {len(rows)} of {total}. Use --offset {end} for next page.')

    return '\n'.join(lines)


def _format_json(rows: List[Dict[str, Any]], total: int,
                 fields: Optional[List[str]],
                 offset: int, limit: int) -> str:
    """Format rows as a JSON object."""
    if fields:
        rows = [{k: v for k, v in r.items() if k in fields} for r in rows]
    obj = {
        'total': total,
        'offset': offset,
        'limit': limit,
        'entries': rows,
    }
    return json.dumps(obj, indent=2, default=str)


# Map db_key -> (id column, topic/group column, abstract/text column).
# Reranking in ``run`` uses the id and group columns together as the key that
# matches ranker scores back to rows, and the text column (combined with the
# title) as the text to embed.
_DB_FIELD_MAP = {
    'current': ('id', 'topic', 'abstract'),
    'history': ('entry_id', 'topics', 'abstract'),
    'all_feeds': ('entry_id', 'feed_name', 'summary'),
}


def _build_rerank_text(row: Dict[str, Any], text_col: str) -> str:
    """Build the text to embed for reranking: title + abstract/summary."""
    title = (row.get('title') or '').strip()
    body = (row.get(text_col) or '').strip()
    if body:
        return f"{title} {body}"
    return title


def _paginate(rows: List[Dict[str, Any]], limit: int, offset: int) -> List[Dict[str, Any]]:
    """Return the ``offset``/``limit`` window of *rows*; ``limit == 0`` means no limit."""
    if limit:
        return rows[offset:offset + limit]
    return rows[offset:]


def _rerank_rows(rows: List[Dict[str, Any]], query: str, db_key: str) -> None:
    """Score *rows* against *query* with a sentence-transformer and sort them in place.

    Each row gains a ``rerank_score`` key (rounded to 4 decimals; 0.0 when the
    ranker returns no score for it). Rows end up sorted by that score,
    highest first.

    Raises:
        RuntimeError: If the sentence-transformer model is unavailable.
    """
    model_name = ensure_local_model("all-MiniLM-L6-v2")
    ranker = STRanker(model_name=model_name)
    if not ranker.available():
        raise RuntimeError(
            "Sentence-transformer model unavailable. "
            "Install sentence-transformers or check model path."
        )

    id_col, group_col, text_col = _DB_FIELD_MAP[db_key]
    batch = [
        (row[id_col], row.get(group_col, ''), _build_rerank_text(row, text_col))
        for row in rows
    ]
    scores = ranker.score_entries(query, batch)

    # Scores come back as (id, group, score); key them by (id, group) so they
    # can be matched back to the rows they were computed for.
    score_map = {(eid, grp): score for eid, grp, score in scores}
    for row in rows:
        key = (row[id_col], row.get(group_col, ''))
        row['rerank_score'] = round(score_map.get(key, 0.0), 4)
    rows.sort(key=lambda r: r['rerank_score'], reverse=True)


def run(
    config_path: Optional[str],
    *,
    db_key: str = 'current',
    topic: Optional[str] = None,
    min_rank: Optional[float] = None,
    status: Optional[str] = None,
    has_doi: bool = False,
    has_abstract: bool = False,
    since: Optional[str] = None,
    until: Optional[str] = None,
    search: Optional[str] = None,
    fuzzy: Optional[str] = None,
    rerank: Optional[str] = None,
    sort: str = 'rank',
    limit: int = 20,
    offset: int = 0,
    output_json: bool = False,
    count_only: bool = False,
    fields: Optional[str] = None,
) -> None:
    """Execute a query against one of the paper databases and print the result.

    Args:
        config_path: Configuration path handed to ``CommandContext``.
        db_key: Database to query: ``'current'``, ``'history'`` or
            ``'all_feeds'``.
        topic, min_rank, status, has_doi, has_abstract, since, until, fuzzy:
            Filters forwarded to ``ctx.db.query_entries``.
        search: Keyword query. Matches are re-sorted client-side by
            ``bm25_score`` (lower = more relevant) before pagination.
        rerank: Free-text query. When set, matches are re-scored with a
            sentence-transformer model and sorted by ``rerank_score``.
        sort: ``'rank'``, ``'date'`` or ``'title'``.
        limit: Page size; ``0`` means no limit.
        offset: Number of matching entries to skip.
        output_json: Print a JSON object instead of a table.
        count_only: Print only the total number of matches.
        fields: Comma-separated column names to include in the output.

    Raises:
        ValueError: For options unsupported by *db_key*, an unknown *sort*,
            or a negative *limit* / *offset*.
        RuntimeError: If reranking is requested but the model is unavailable.
    """
    # Validate incompatible options before touching the database.
    if db_key == 'all_feeds':
        if min_rank is not None:
            raise ValueError("--min-rank is not available for --all-feeds (no rank_score column)")
        if status:
            raise ValueError("--status is not available for --all-feeds")
        if has_abstract:
            raise ValueError("--has-abstract is not available for --all-feeds (no abstract column)")
    if db_key == 'history' and status:
        raise ValueError("--status is not available for --history")
    # Negative values would make the client-side slicing below count from the
    # end of the list (and behave differently from SQL-level pagination).
    if limit < 0:
        raise ValueError(f"--limit must be >= 0 (got {limit})")
    if offset < 0:
        raise ValueError(f"--offset must be >= 0 (got {offset})")

    # Resolve the sort first so an unknown key is rejected before the
    # configuration and database are loaded.
    order_by = _resolve_sort(sort, db_key)
    ctx = CommandContext(config_path)

    # When reranking or BM25-sorting search results, fetch all candidates
    # (no SQL-level pagination) so we can re-sort before applying limit/offset.
    needs_client_sort = bool(rerank) or bool(search)
    rows, total = ctx.db.query_entries(
        db_key=db_key,
        topic=topic,
        min_rank=min_rank,
        status=status,
        has_doi=has_doi or None,
        has_abstract=has_abstract or None,
        since=since,
        until=until,
        search=search,
        fuzzy=fuzzy,
        order_by=order_by,
        limit=0 if needs_client_sort else limit,
        offset=0 if needs_client_sort else offset,
    )

    if needs_client_sort and rows:
        if rerank:
            # Semantic reranking takes precedence over BM25 ordering.
            _rerank_rows(rows, rerank, db_key)
        else:
            # FTS5 rank is negative; lower = more relevant.
            rows.sort(key=lambda r: r.get('bm25_score') or 0.0)
        total = len(rows)
        rows = _paginate(rows, limit, offset)

    if count_only:
        print(total)
        return

    # Drop blank names so inputs like "title," don't produce empty columns.
    field_list: Optional[List[str]] = None
    if fields:
        field_list = [f.strip() for f in fields.split(',') if f.strip()] or None

    if output_json:
        print(_format_json(rows, total, field_list, offset, limit))
        return

    table_fields = field_list or _DEFAULT_TABLE_FIELDS.get(db_key, ['title'])
    if rerank and field_list is None:
        # Surface the score that determined the ordering.
        table_fields = ['rerank_score'] + table_fields
    print(_format_table(rows, total, table_fields, offset, limit))