from typing import Dict, Tuple
import torch
from transformers import pipeline

def _pick_device_index(dev: str) -> int:
    """Return HF pipeline device index: 0 for CUDA if available, -1 for CPU."""
    if dev == "auto":
        return 0 if torch.cuda.is_available() else -1
    return 0 if dev.startswith("cuda") else -1

# Maps raw pipeline label strings onto the canonical lowercase NLI label names.
# Each canonical name is registered under its lowercase and UPPERCASE spelling.
LABEL_MAP = {
    spelling: canonical
    for canonical in ("entailment", "neutral", "contradiction")
    for spelling in (canonical, canonical.upper())
}

class DualTE:
    """
    Dual textual entailment wrapper around two NLI models.

    Each model is loaded as a Hugging Face ``text-classification`` pipeline
    configured to return a score for every label (``top_k=None``) and to
    truncate over-long inputs. ``infer`` runs one (premise, hypothesis) pair
    through both models and returns one score dict per model, each keyed by
    ``entailment`` / ``neutral`` / ``contradiction``.
    """

    def __init__(self, model1: str, model2: str, device: str = "auto"):
        """Load both NLI pipelines on the same device.

        Args:
            model1: Model identifier passed to ``pipeline`` for the first
                model; also used to load its tokenizer.
            model2: Model identifier for the second model, used the same way.
            device: Device spec resolved by ``_pick_device_index``
                (e.g. ``"auto"``, ``"cpu"``, ``"cuda"``, ``"cuda:1"``).
        """
        dev = _pick_device_index(device)
        # top_k=None: return scores for all labels, not only the argmax label.
        self.pipe1 = pipeline("text-classification", model=model1, tokenizer=model1,
                              device=dev, top_k=None, truncation=True)
        self.pipe2 = pipeline("text-classification", model=model2, tokenizer=model2,
                              device=dev, top_k=None, truncation=True)

    @staticmethod
    def _scores(out) -> Dict[str, float]:
        """Normalize raw pipeline output into a dict with keys entailment/neutral/contradiction.

        Accepts a list of ``{"label", "score"}`` dicts, that list wrapped in
        one or more outer lists, or a single such dict. The wrapping depends
        on the transformers version and on the input form.

        Labels are canonicalised through ``LABEL_MAP``. A label not found there
        is kept in lowercased form. For example, a model emitting generic
        ``LABEL_0`` names yields a ``label_0`` key. In that case the three
        canonical keys all fall back to 0.0, so check the model's label names.

        Returns:
            A dict containing at least ``entailment``, ``neutral`` and
            ``contradiction``. A canonical label absent from ``out`` maps
            to 0.0.
        """
        # Peel outer list wrappers until we reach the per-label entries.
        while isinstance(out, list) and out and isinstance(out[0], list):
            out = out[0]
        if isinstance(out, dict):  # single top-1 prediction
            out = [out]
        label2p: Dict[str, float] = {}
        for entry in out or ():
            raw = entry["label"]
            label2p[LABEL_MAP.get(raw, raw.lower())] = float(entry["score"])
        # Ensure all three keys exist
        for k in ("entailment", "neutral", "contradiction"):
            label2p.setdefault(k, 0.0)
        return label2p

    def infer(self, premise: str, hypothesis: str) -> Tuple[Dict[str, float], Dict[str, float]]:
        """Return (scores_model1, scores_model2) for a single (premise, hypothesis) pair.

        Each element is the output of ``_scores`` for the corresponding model.
        """
        # The {"text", "text_pair"} dict is the pipeline's documented input for
        # sentence pairs; the tokenizer receives premise and hypothesis as a pair.
        pair = {"text": premise, "text_pair": hypothesis}
        s1 = self._scores(self.pipe1(pair))
        s2 = self._scores(self.pipe2(pair))
        return s1, s2
