# -*- coding: utf-8 -*-
from typing import Dict, Tuple, List
import numpy as np, torch
from sentence_transformers import CrossEncoder
from config import NLI_MODEL_NAME, NLI_BATCH_SIZE

# Default device for model loading: the GPU when torch reports one, else CPU.
_DEVICE: str = "cpu" if not torch.cuda.is_available() else "cuda"
# Shared cross-encoder instance; stays None until init_nli() assigns it.
_CE = None
# Logit index -> lower-case label name; populated by init_nli().
_ID2LABEL: Dict[int, str] = dict()

def init_nli(model_name: str = NLI_MODEL_NAME, device: str = _DEVICE):
    """Load the cross-encoder and build the logit-index -> label mapping.

    The model is stored in the module-level ``_CE`` and the mapping in
    ``_ID2LABEL``. Label names come, lower-cased, from the model config's
    ``id2label`` table. When that table is missing, malformed, empty, or
    holds only auto-generated ``LABEL_<n>`` placeholders, the indices are
    inferred instead by scoring three probe sentence pairs.

    Args:
        model_name: Model identifier passed to ``CrossEncoder``.
        device: Device string passed to ``CrossEncoder``.
    """
    global _CE, _ID2LABEL
    _CE = CrossEncoder(model_name, device=device)

    try:
        id2label = _CE.model.config.id2label
        labels = {int(i): id2label[i].lower() for i in id2label}
    except (AttributeError, KeyError, TypeError, ValueError):
        # Only lookup/conversion problems trigger the probe; anything else
        # is a genuine error and should surface to the caller.
        labels = {}

    # Hugging Face configs auto-generate "LABEL_<n>" names when a checkpoint
    # declares none; such names carry no NLI meaning, so treat them like a
    # missing table.
    only_placeholders = all(
        name.startswith("label_") and name[len("label_"):].isdigit()
        for name in labels.values()
    )
    if labels and not only_placeholders:
        _ID2LABEL = labels
        return

    # Probe: each pair is built so that one class should dominate
    # (identical sentences, a direct negation, an unrelated sentence).
    pairs = [
        ("A dog is an animal.", "A dog is an animal."),
        ("A dog is an animal.", "A dog is not an animal."),
        ("A dog is an animal.", "The sky is blue."),
    ]
    scores = np.array(_CE.predict(pairs))
    ent_idx = int(scores[0].argmax())
    con_idx = int(scores[1].argmax())
    neu_idx = int(scores[2].argmax())
    # If two probes pick the same index, the later entry wins.
    _ID2LABEL = {ent_idx: "entailment", con_idx: "contradiction", neu_idx: "neutral"}

def _softmax(v: np.ndarray) -> np.ndarray:
    """Return the softmax of *v*, normalised over all of its elements.

    The maximum is subtracted before exponentiating so that large logits
    do not overflow.
    """
    exps = np.exp(v - np.max(v))
    total = exps.sum()
    return exps / total

class NLICache:
    """Memoise raw NLI logits per (premise, hypothesis) pair.

    Pairs that were scored before are answered from an in-memory dict;
    only unseen pairs are sent to the module-level cross-encoder ``_CE``.
    """

    def __init__(self, batch_size: int = NLI_BATCH_SIZE):
        """Create an empty cache.

        Args:
            batch_size: Number of pairs handed to the model per call.
        """
        # (premise, hypothesis) -> logits array returned by the model.
        self._logit_cache: Dict[Tuple[str, str], np.ndarray] = {}
        self._batch_size = batch_size

    def predict_logits(self, pairs: List[Tuple[str, str]]) -> List[np.ndarray]:
        """Return the logits for every pair, in the order given.

        Args:
            pairs: (premise, hypothesis) tuples; duplicates are allowed.

        Returns:
            One logits array per input pair. The arrays are the cached
            objects themselves, not copies.

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if _CE is None:
            raise RuntimeError("NLI model is not initialised; call init_nli() first")

        pairs = list(pairs)
        cache = self._logit_cache
        # Unseen pairs, de-duplicated while keeping first-seen order, so a
        # pair repeated within one call is scored only once.
        pending = list(dict.fromkeys(key for key in pairs if key not in cache))

        step = self._batch_size
        for start in range(0, len(pending), step):
            chunk = pending[start:start + step]
            # Forward the batch size; otherwise the model falls back to its
            # own default and the configured value only affects slicing.
            scores = _CE.predict(chunk, batch_size=step)
            for key, logits in zip(chunk, scores):
                cache[key] = np.array(logits, dtype=float)

        return [cache[key] for key in pairs]

    def label_prob(self, prem: str, hyp: str) -> Tuple[str, float, List[float]]:
        """Classify a single premise/hypothesis pair.

        Args:
            prem: Premise sentence.
            hyp: Hypothesis sentence.

        Returns:
            ``(label, prob, probs)`` where ``label`` is the name mapped to the
            highest logit (``"neutral"`` if the index is not in
            ``_ID2LABEL``), ``prob`` is the softmax probability of that
            index, and ``probs`` is the full softmax distribution as a list.
        """
        logits = self.predict_logits([(prem, hyp)])[0]
        probs = _softmax(logits)
        idx = int(logits.argmax())
        label = _ID2LABEL.get(idx, "neutral")
        return label, float(probs[idx]), probs.tolist()

# Module-wide cache instance used by nli_many() and nli_one().
_NLI = NLICache(batch_size=NLI_BATCH_SIZE)

def nli_many(pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], Tuple[str, float, List[float]]]:
    """Classify several premise/hypothesis pairs through the shared cache.

    Args:
        pairs: (premise, hypothesis) tuples to classify.

    Returns:
        A dict keyed by pair. Each value is ``(label, prob, probs)``: the
        label mapped to the highest logit (``"neutral"`` when the index is
        not in ``_ID2LABEL``), the softmax probability of that index, and
        the full softmax distribution as a list. Repeated pairs collapse
        into a single entry.
    """
    all_logits = _NLI.predict_logits(pairs)
    results: Dict[Tuple[str, str], Tuple[str, float, List[float]]] = {}
    for pair, pair_logits in zip(pairs, all_logits):
        distribution = _softmax(pair_logits)
        best = int(pair_logits.argmax())
        label = _ID2LABEL.get(best, "neutral")
        results[pair] = (label, float(distribution[best]), distribution.tolist())
    return results

def nli_one(p: str, h: str) -> Tuple[str, float, List[float]]:
    """Classify one premise/hypothesis pair through the shared cache.

    Args:
        p: Premise sentence.
        h: Hypothesis sentence.

    Returns:
        The ``(label, prob, probs)`` tuple produced by
        ``NLICache.label_prob`` on the module-level cache.
    """
    verdict = _NLI.label_prob(p, h)
    return verdict
