"""Shared DistilBERT inference classes (importable by train_distilbert + evaluate).

Defining these in a stable module path means joblib bundles produced by
train_distilbert can be unpickled by evaluate without the
`Can't get attribute 'DistilBertPipeline' on <module ...>` error.
"""
from __future__ import annotations

import time
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)


@dataclass
class DistilBertPipeline:
    """Pass-through feature stage for the DistilBERT bundle.

    No features are computed here: ``transform`` hands the rows back
    untouched. ``model_dir`` and ``label_order`` are stored on the instance
    but are not read by any method of this class.
    """

    model_dir: str
    label_order: list[str]

    def transform(self, rows):
        """Return ``rows`` itself -- same object, no copy, no validation."""
        return rows


@dataclass
class DistilBertClassifier:
    """sklearn-style wrapper around a fine-tuned HF model on disk.

    The tokenizer and model are loaded lazily from ``model_dir`` on the first
    non-empty inference call and cached on the instance. Each row must offer
    ``.get``; the ``"preceding_context"`` and ``"span_text"`` entries are
    tokenized together as a text pair.

    Attributes:
        model_dir: Path/identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        label_order: Class labels; only its length is used here (the column
            count of the empty-input result).
        max_len: ``max_length`` handed to the tokenizer (with truncation).
        batch_size: Number of rows per forward pass; must be >= 1.
        device_id: CUDA device index, or ``None`` to pick ``"cuda"`` when
            available and ``"cpu"`` otherwise.
        use_fp16: Cast the model to half precision when not running on CPU.
    """

    model_dir: str
    label_order: list[str]
    max_len: int = 512
    batch_size: int = 64
    # None = use "cuda" if torch reports it available, else "cpu".
    # NOTE(review): an explicit index is used as-is, with no CUDA
    # availability check -- confirm callers only set it on GPU hosts.
    device_id: Optional[int] = None
    use_fp16: bool = False             # half precision; ignored on CPU
    # Lazily populated caches. They stay constructor parameters for backward
    # compatibility, but are excluded from repr/eq so a loaded instance does
    # not print the whole model or compare by load state.
    _model: Optional[object] = field(default=None, repr=False, compare=False)
    _tok: Optional[object] = field(default=None, repr=False, compare=False)
    _device: Optional[str] = field(default=None, repr=False, compare=False)

    def _load(self):
        """Load tokenizer and model once and cache them on the instance.

        Nothing is cached unless every step succeeds, so a failed load (for
        example the device move raising) is retried on the next call instead
        of leaving a partially prepared model behind.
        """
        if self._model is not None:
            return
        tok = AutoTokenizer.from_pretrained(self.model_dir)
        if self.device_id is not None:
            device = f"cuda:{self.device_id}"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load on CPU first (low_cpu_mem_usage=False = eager loading, no meta
        # tensors), then optionally cast to fp16 and move to target device.
        model = AutoModelForSequenceClassification.from_pretrained(
            self.model_dir,
            low_cpu_mem_usage=False,
        )
        if self.use_fp16 and device != "cpu":
            model = model.half()
        model = model.to(device)
        model.eval()
        # Publish only after the model is fully prepared.
        self._tok = tok
        self._device = device
        self._model = model

    @torch.no_grad()
    def _logits(self, rows, verbose: bool = True):
        """Return raw logits for ``rows`` as a 2-D float32 numpy array.

        Args:
            rows: Sized, sliceable sequence of mappings (read via ``.get``).
            verbose: When true and there are at least 100 batches, print a
                progress line roughly every 5% of batches.

        Returns:
            One row of logits per input row. Empty input yields an array of
            shape ``(0, len(self.label_order))`` without loading the model.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if self.batch_size < 1:
            # A negative step would make range() empty and silently return
            # no predictions for non-empty input; fail loudly instead.
            raise ValueError(
                f"batch_size must be a positive integer, got {self.batch_size!r}"
            )
        n_rows = len(rows)
        if n_rows == 0:
            # Skip the expensive model load; float32 matches the dtype of
            # the non-empty path below.
            return np.zeros((0, len(self.label_order)), dtype=np.float32)

        self._load()
        n_batches = (n_rows + self.batch_size - 1) // self.batch_size
        show_progress = verbose and n_batches >= 100
        report_every = max(1, n_batches // 20)
        all_logits = []
        t_start = time.perf_counter()  # monotonic: safe for elapsed/ETA
        for batch_idx, start in enumerate(range(0, n_rows, self.batch_size)):
            batch = rows[start:start + self.batch_size]
            # A missing/blank context becomes the literal "<empty>"
            # placeholder; a missing/blank span becomes "".
            contexts = [
                (r.get("preceding_context") or "").strip() or "<empty>"
                for r in batch
            ]
            spans = [(r.get("span_text") or "").strip() for r in batch]
            enc = self._tok(
                contexts,
                spans,
                truncation=True, max_length=self.max_len,
                padding=True, return_tensors="pt",
            ).to(self._device)
            # .float() upcasts fp16 outputs before leaving the device.
            all_logits.append(self._model(**enc).logits.float().cpu().numpy())
            done = batch_idx + 1
            if show_progress and done % report_every == 0:
                elapsed = time.perf_counter() - t_start
                eta = elapsed / done * (n_batches - done)
                pct = 100 * done / n_batches
                print(f"    [{self._device}] {pct:.0f}% ({done}/{n_batches} batches, "
                      f"eta {eta:.0f}s)", flush=True)
        return np.concatenate(all_logits, axis=0)

    def predict(self, rows):
        """Return the argmax column index of the logits for each row.

        NOTE(review): indices presumably map into ``label_order`` -- confirm
        against the training code.
        """
        return self._logits(rows).argmax(axis=1)

    def predict_proba(self, rows):
        """Return softmax probabilities over the logit columns for each row."""
        logits = self._logits(rows)
        return torch.softmax(torch.tensor(logits), dim=-1).numpy()
