Source code for paper_firehose.processors.st_ranker
"""
Sentence-Transformers based ranking processor.
Minimal implementation: computes cosine similarity between a topic query
and entry texts, and returns scores suitable for writing into papers.db
(`rank_score`).
This module is intentionally lean and resilient: if sentence-transformers
is not available or the model cannot be loaded, it logs and returns an
empty result so callers can decide how to proceed.
"""
from __future__ import annotations
# Set before any heavy imports to silence HF tokenizers warning.
import os as _os
_os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import logging
from typing import Iterable, List, Tuple, Optional
logger = logging.getLogger(__name__)
[docs]
class STRanker:
    """Rank entries against a query with a Sentence-Transformers model.

    The optional ``sentence_transformers`` dependency is imported inside the
    constructor. If the import or the model load fails, the instance stays
    usable but reports ``available() == False`` and ``score_entries`` returns
    an empty list.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None:
        """Import sentence-transformers and load ``model_name``.

        The model is loaded here, at construction time; only the import of
        the optional dependency is deferred to this point. Any failure is
        logged as a warning and leaves the ranker unavailable instead of
        raising.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``.
        """
        self.model_name = model_name
        self._model = None
        self._util = None
        try:
            from sentence_transformers import SentenceTransformer, util  # type: ignore

            self._model = SentenceTransformer(model_name)
            self._util = util
        except Exception as e:  # pragma: no cover - optional dependency
            # Deliberate best-effort boundary: a missing package or a failed
            # model load must not crash the caller.
            logger.warning(
                "sentence-transformers unavailable or model load failed (%s). Ranking will be skipped.",
                e,
            )

    def available(self) -> bool:
        """Return True when the embedding model loaded successfully."""
        return self._model is not None and self._util is not None

    def score_entries(
        self,
        query: str,
        entries: Iterable[Tuple[str, str, str]],
        *,
        use_summary: bool = False,
    ) -> List[Tuple[str, str, float]]:
        """Compute similarity scores between ``query`` and each entry text.

        Args:
            query: Natural-language ranking query. It is stripped before
                encoding; ``None`` is treated as an empty string, the same
                way a missing entry text is.
            entries: Iterable of ``(entry_id, topic, text)``. It is consumed
                exactly once, so a generator is fine. ``text`` may be
                ``None`` and is stripped before encoding.
            use_summary: Accepted for interface compatibility; this method
                does not read it. Callers that want the summary scored are
                expected to include it in ``text`` themselves.

        Returns:
            List of ``(entry_id, topic, score)`` tuples in input order, or an
            empty list when the model is unavailable or there are no entries.
        """
        if not self.available():  # graceful no-op
            return []
        model = self._model
        util = self._util
        # Explicit guard instead of ``assert``: asserts are stripped under
        # ``python -O``, and this also narrows the Optional types.
        if model is None or util is None:
            return []

        # Materialise once so one-shot iterables work, normalising text as we go.
        rows = [(eid, topic, (text or "").strip()) for eid, topic, text in entries]
        if not rows:
            return []
        docs = [text for _, _, text in rows]

        # Same None-tolerance for the query as for the entry text above.
        q_emb = model.encode([(query or "").strip()], normalize_embeddings=True)
        d_emb = model.encode(docs, normalize_embeddings=True)
        # cos_sim yields a 1 x len(docs) matrix; row 0 holds the query's scores.
        sims = util.cos_sim(q_emb, d_emb).tolist()[0]
        return [(eid, topic, score) for (eid, topic, _), score in zip(rows, sims)]