Source code for paper_firehose.core.model_manager
"""
Model management utilities for Sentence-Transformers models.
Handles vendoring (downloading and caching) of transformer models locally
to avoid runtime network dependencies in production environments.
"""
from __future__ import annotations
import logging
import re
import shutil
from pathlib import Path
from .paths import get_system_path, resolve_data_dir
logger = logging.getLogger(__name__)
[docs]
def has_model_files(path: str) -> bool:
    """Report whether a folder looks like a local Sentence-Transformers model.

    This is only a heuristic: the folder is accepted as soon as one of the
    well-known marker files is present; file contents are never inspected.

    Args:
        path: Path to the candidate model directory.

    Returns:
        True if ``path`` is an existing directory that contains
        ``config.json`` or ``modules.json``; False otherwise.
    """
    model_dir = Path(path)
    if model_dir.exists() and model_dir.is_dir():
        # Marker files commonly found in Sentence-Transformers model folders.
        for marker_name in ("config.json", "modules.json"):
            if (model_dir / marker_name).exists():
                return True
    return False
[docs]
def ensure_local_model(model_spec: str) -> str:
    """Ensure a local model directory exists for the given spec and return the path or original spec.

    Behavior:
        - If spec is empty, return it unchanged without touching the
          filesystem or the network (an empty spec cannot name a model).
        - If spec is already a valid local model directory, return it as is.
        - If spec is the default alias 'all-MiniLM-L6-v2':
          use '<models>/all-MiniLM-L6-v2' with repo id
          'sentence-transformers/all-MiniLM-L6-v2'.
        - If spec contains '/' and is not an existing path (e.g.
          'sentence-transformers/x' or 'intfloat/e5-small'): treat it as a
          repo id and vendor it to '<models>/<sanitized-last-segment>'.
        - Otherwise: infer the repo id 'sentence-transformers/<name>' from the
          final path component. An absolute path is used as the target
          directory itself; anything else maps to '<models>/<name>'.
        - If the target lacks model files, first try to copy a same-named
          folder from the system models directory, then fall back to a
          best-effort download via ``huggingface_hub.snapshot_download``.
        - On any download failure (e.g. no network, package missing), return
          the original spec and let the caller resolve it.

    Args:
        model_spec: Model specification (repo ID, alias, or local path).

    Returns:
        Local path to the model directory, or the original spec if vendoring
        failed.
    """
    # Path("") is the current directory and would yield the repo id
    # "sentence-transformers/" with the models root itself as target, so bail
    # out before any directory is created or any download is attempted.
    if not model_spec:
        return model_spec

    spec_path = Path(model_spec)

    # Try local path directly if it's already valid
    if spec_path.exists() and has_model_files(model_spec):
        return model_spec

    models_root = resolve_data_dir('models', ensure_exists=True)
    system_models_root = get_system_path('models')

    if model_spec == "all-MiniLM-L6-v2":
        # Case 1: default alias
        repo_id = "sentence-transformers/all-MiniLM-L6-v2"
        target_dir = models_root / "all-MiniLM-L6-v2"
    elif "/" in model_spec and not spec_path.exists():
        # Case 2: looks like HF repo id "org/name"
        repo_id = model_spec
        last = model_spec.rsplit("/", 1)[-1]
        # sanitize last segment for filesystem safety just in case
        last = re.sub(r"[^A-Za-z0-9._\-]", "_", last)
        target_dir = models_root / last
    else:
        # Case 3: non-default spec that may be a local folder name or alias;
        # infer the repo as sentence-transformers/<name>
        name = spec_path.name if spec_path.name else str(model_spec)
        repo_id = f"sentence-transformers/{name}"
        target_dir = spec_path if spec_path.is_absolute() else models_root / name

    # If the target already looks valid, use it
    if has_model_files(str(target_dir)):
        return str(target_dir)

    # If the system bundle ships the model, copy it into the runtime directory
    if system_models_root.exists():
        system_candidate = system_models_root / target_dir.name
        try:
            if system_candidate.exists() and system_candidate.resolve() != target_dir.resolve():
                # dirs_exist_ok=True: the target may already exist without
                # model files (e.g. created by the mkdir below before a
                # download that then failed). Without it copytree raises
                # FileExistsError and the bundled model could never be seeded.
                shutil.copytree(system_candidate, target_dir, dirs_exist_ok=True)
                if has_model_files(str(target_dir)):
                    return str(target_dir)
        except OSError as e:
            # Seeding is best-effort; fall through to the download attempt.
            logger.debug("Model seed copy failed for %s -> %s: %s", system_candidate, target_dir, e)

    # Attempt download (best-effort)
    try:
        # Imported lazily so the module works without huggingface_hub installed.
        from huggingface_hub import snapshot_download  # type: ignore

        target_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): newer huggingface_hub releases appear to deprecate
        # local_dir_use_symlinks - confirm against the pinned version.
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(target_dir),
            local_dir_use_symlinks=False,
        )
        return str(target_dir)
    except Exception as e:  # pragma: no cover - network optional
        logger.warning("Model vendor failed for '%s' -> %s: %s", repo_id, target_dir, e)

    # Fall back to original spec; STRanker will try to resolve
    return model_spec