from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import re
import inspect
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Path handed to `from_pretrained` for both the tokenizer and the classifier.
nli_model_name = r"./nli_model"
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name)
nli_model.to(device)
nli_model.eval()  # evaluation mode: this module only runs inference

# Update the _SECTION_PATTERN to account for both section structures
_SECTION_PATTERN = re.compile(
    r"Premise/Evidence:\s*(?P<premises>.*?)"
    r"(Reasoning|Explanation):\s*(?P<reasoning_or_explanation>.*?)"
    r"Conclusion:\s*(?P<conclusion>.*)\Z",
    re.S | re.I,
)

def _parse_four_sections(text: str):
    m = _SECTION_PATTERN.search(text.strip())
    if not m:
        return None
    # Determine which structure is used and parse accordingly
    if m.group(2).lower() == "reasoning":
        secs = {
            "premises": m.group("premises").strip(),
            "reasoning": m.group("reasoning_or_explanation").strip(),
            "conclusion": m.group("conclusion").strip(),
        }
    elif m.group(2).lower() == "explanation":
        secs = {
            "premises": m.group("premises").strip(),
            "explanation": m.group("reasoning_or_explanation").strip(),
            "conclusion": m.group("conclusion").strip(),
        }
    return secs


def _pooled_logits(
        a: str,
        b: str,
        max_length: int = 512,
        stride: int = 128,
        agg: str = "mean",
        truncation_side: str = True,
) -> torch.Tensor:
    """
    Windowed inference over long sequences and pool logits across windows.
    """
    enc = nli_tokenizer(
        a, b,
        return_tensors="pt",
        truncation=truncation_side,
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
        padding="max_length",
    )
    allowed = set(inspect.signature(nli_model.forward).parameters.keys())
    model_inputs = {k: v.to(device) for k, v in enc.items() if k in allowed}

    with torch.no_grad():
        logits = nli_model(**model_inputs).logits  # [num_windows, 3]

    # Apply aggregation method
    if logits.ndim == 1:
        pooled = logits
    else:
        if agg == "max":
            pooled = logits.max(dim=0).values
        elif agg == "mean":
            pooled = logits.mean(dim=0)
        elif agg == "logsumexp":
            pooled = torch.logsumexp(logits, dim=0)
        else:
            raise ValueError(f"Unknown agg: {agg}")

    return pooled  # [3]

def nli_label_probs(
    p1: str,
    p2: str,
    *,
    structured: bool = True,               # Enable section-wise comparison (default True; set False to disable)
    section_weights: dict | None = None,   # Section weights; defaults to equal weighting
    section_agg: str = "mean",             # Section-level aggregation: "logsumexp" | "mean" | "max"
):
    # ------------------ Unstructured path: original behavior ------------------
    if not structured:
        pooled = _pooled_logits(p1, p2)
        probs = torch.softmax(pooled, dim=-1).detach().cpu().tolist()
        return probs

    # ------------------ Structured path: section-wise equal-weight pooling ------------------
    secs1 = _parse_four_sections(p1)
    secs2 = _parse_four_sections(p2)
    if secs1 is None or secs2 is None:
        # Fallback to unstructured if parsing fails
        pooled = _pooled_logits(p1, p2)
        probs = torch.softmax(pooled, dim=-1).detach().cpu().tolist()
        return probs

    # Default: equal weights for sections
    if section_weights is None:
        section_weights = {"premises": 1.0, "reasoning": 1.0, "conclusion": 1.0}

    # Light sanitization: remove <final> tags to avoid noise
    def _sanitize(s: str) -> str:
        return re.sub(r"</?final>", "", s, flags=re.I).strip()

    logits_list, weights_list = [], []
    for name in ("premises", "reasoning", "conclusion"):
        # Handle reasoning/explanation based on the detected section header
        a = _sanitize(secs1.get(name, "") if name != "reasoning" else secs1.get("reasoning", ""))
        b = _sanitize(secs2.get(name, "") if name != "reasoning" else secs2.get("reasoning", ""))
        if a and b:
            lg = _pooled_logits(a, b)  # [3]
            logits_list.append(lg)
            weights_list.append(float(section_weights.get(name, 1.0)))

    # Edge case: all sections empty -> fallback to whole-text comparison
    if not logits_list:
        pooled = _pooled_logits(p1, p2)
        probs = torch.softmax(pooled, dim=-1).detach().cpu().tolist()
        return probs

    L = torch.stack(logits_list, dim=0)  # [num_sections, 3]
    w = torch.tensor(weights_list, device=L.device).clamp(min=0)

    # Section-level aggregation
    if section_agg == "mean":
        denom = w.sum() if w.sum() > 0 else torch.tensor(1.0, device=L.device)
        pooled_sections = (L * w.unsqueeze(-1)).sum(dim=0) / denom
    elif section_agg == "max":
        mask = w > 0
        pooled_sections = L[mask].max(dim=0).values if mask.any() else L.mean(dim=0)
    elif section_agg == "logsumexp":
        # Weighted logsumexp: log sum_i w_i * exp(l_i) = logsumexp(l_i + log w_i)
        logw = torch.where(w > 0, torch.log(w), torch.full_like(w, float("-inf")))
        pooled_sections = torch.logsumexp(L + logw.unsqueeze(-1), dim=0)
    else:
        raise ValueError(f"Unknown section_agg: {section_agg}")

    probs = torch.softmax(pooled_sections, dim=-1).detach().cpu().tolist()  # [c, n, e]
    return probs


def nli_entailment_score(p1: str, p2: str) -> float:
    """
    Score the single direction p1 -> p2.

    Runs `nli_label_probs` with its default settings and returns the entry at
    index 2 of the resulting probability list, which this module treats as
    the entailment class.
    """
    label_probs = nli_label_probs(p1, p2)
    entailment = label_probs[2]
    return float(entailment)

def mutual_entailment_score(
    p1: str,
    p2: str,
    mode: str = "min",
    lambda_penalty: float = 1,
    tie_break_eps: float = 1e-6,   # When ties occur, use the mean entailment to slightly break ties; 0 disables
) -> float:
    """
    Combine the two directional NLI results (p1→p2 and p2→p1) into one score.

    Below, c/n/e are the contradiction/neutral/entailment probabilities and
    the suffixes 12 / 21 denote the directions p1→p2 / p2→p1.

    mode ∈ {
        "min","mean","max","f1","sym","hybrid",
        "a2b","b2a","cont_max","neut_max",
        "cont_a2b","cont_mean","neut_a2b","cont_b2a","neut_b2a","neut_mean",
        "vote", "b2a_penalized", "mean_penalized"
    }

    - min / mean / max : min, mean or max of e12 and e21.
    - f1     : harmonic mean of e12 and e21 (0.0 when both are 0).
    - sym    : 1 - |e12 - e21|.
    - hybrid : mean(e12, e21) - lambda_penalty * |e12 - e21|, floored at 0.
    - a2b / b2a : e12 / e21.
    - cont_* / neut_* : reciprocal of the contradiction / neutral probability
             for one direction (a2b, b2a), or the max / mean of both
             reciprocals. A probability of exactly 0 yields +inf instead of
             raising ZeroDivisionError.
    - vote : Voting scheme. Each direction (p1→p2 and p2→p1) votes by argmax over (contradiction, neutral, entailment),
             mapped to {-1, 0, +1}. The final score is the sum of the two votes.
             If you want ties to prefer stronger entailment, set tie_break_eps>0 to add a small multiple of the mean entailment.
    - b2a_penalized  : e21 - lambda_penalty * c21.
    - mean_penalized : mean(e12, e21) - lambda_penalty * mean(c12, c21).

    Raises:
        ValueError: If `mode` is not one of the names above. Checked before
            any inference is run.
    """
    valid_modes = {
        "min", "mean", "max", "f1", "sym", "hybrid",
        "a2b", "b2a", "cont_max", "neut_max",
        "cont_a2b", "cont_mean", "neut_a2b", "cont_b2a", "neut_b2a", "neut_mean",
        "vote", "b2a_penalized", "mean_penalized",
    }
    # Reject a bad mode up front rather than after two inference passes.
    if mode not in valid_modes:
        raise ValueError(f"Unknown mode: {mode}")

    # Compute both directional probabilities once to avoid redundant inference
    c12, n12, e12 = map(float, nli_label_probs(p1, p2))  # [c, n, e]
    c21, n21, e21 = map(float, nli_label_probs(p2, p1))  # [c, n, e]

    def _inv(p: float) -> float:
        # A softmax output can underflow to exactly 0.0; the reciprocal is
        # then unbounded, which is reported as +inf rather than an exception.
        return float("inf") if p == 0.0 else 1.0 / p

    def _vote(c: float, n: float, e: float) -> int:
        # Directional vote: argmax([contradiction, neutral, entailment]) → {-1, 0, +1}
        # On ties neutral is preferred, then entailment, then contradiction.
        if n >= c and n >= e:
            return 0
        if e >= c and e >= n:
            return 1
        return -1

    if mode == "min":
        return min(e12, e21)
    if mode == "mean":
        return (e12 + e21) / 2
    if mode == "max":
        return max(e12, e21)
    if mode == "f1":
        if e12 + e21 == 0:
            return 0.0
        return 2 * e12 * e21 / (e12 + e21)
    if mode == "sym":
        return 1.0 - abs(e12 - e21)
    if mode == "hybrid":
        avg = (e12 + e21) / 2
        diff = abs(e12 - e21)
        return max(0.0, avg - lambda_penalty * diff)
    if mode == "a2b":
        return e12
    if mode == "b2a":
        return e21
    if mode == "cont_max":
        return max(_inv(c12), _inv(c21))
    if mode == "neut_max":
        return max(_inv(n12), _inv(n21))
    if mode == "cont_a2b":
        return _inv(c12)
    if mode == "cont_mean":
        return (_inv(c12) + _inv(c21)) / 2
    if mode == "neut_a2b":
        return _inv(n12)
    if mode == "cont_b2a":
        return _inv(c21)
    if mode == "neut_b2a":
        return _inv(n21)
    if mode == "neut_mean":
        return (_inv(n12) + _inv(n21)) / 2

    # Voting scheme
    if mode == "vote":
        score = _vote(c12, n12, e12) + _vote(c21, n21, e21)  # ∈ {-2,-1,0,1,2}
        entail_mean = (e12 + e21) / 2.0
        # Optional tie-break: prefer higher mean entailment
        if tie_break_eps > 0.0:
            return float(score + tie_break_eps * entail_mean)
        return float(score)

    if mode == "b2a_penalized":
        # b2a entailment minus λ * contradiction
        return e21 - lambda_penalty * c21

    if mode == "mean_penalized":
        # Mean entailment minus λ * mean contradiction
        avg_entail = (e12 + e21) / 2
        avg_contra = (c12 + c21) / 2
        return avg_entail - lambda_penalty * avg_contra

    # Safety net in case `valid_modes` and the branches above ever diverge.
    raise ValueError(f"Unknown mode: {mode}")


def nli_tournament_judge_pairs(pairs: list[list[str]], mode: str = "min", debug: bool = True):
    """
    Tournament-style comparison using an NLI model.

    Every pair is scored once with `mutual_entailment_score`; the winner is
    the pair with the highest score, and on equal scores the earlier pair
    keeps the title.

    Scoring is best-effort: if scoring a pair raises, that pair is reported
    with a score of 0.0 in `score_list`, the failure is printed when `debug`
    is true, and the pair is ranked below every successfully scored pair, so
    a placeholder 0.0 can never beat a genuine negative score (possible in
    modes such as "vote" or the penalized ones).

    Args:
      - pairs: [[premise, hypothesis], ...]
      - mode : NLI scoring mode (passed to mutual_entailment_score)
      - debug: If True, print per-round comparison logs and scoring failures

    Returns:
      - best_idx   : 1-based index of the best pair
      - best_score : score of the best pair (0.0 if every pair failed)
      - score_list : list of scores for all pairs (aligned with the input order)

    Raises:
      - ValueError: if `pairs` is empty.
    """
    if not pairs:
        raise ValueError("pairs must not be empty")

    # Compute all scores once (useful for debugging/analysis and avoids recomputation)
    score_list: list[float] = []      # what the caller sees (0.0 on failure)
    ranking_scores: list[float] = []  # what the tournament compares (-inf on failure)
    for pair_no, (p1, p2) in enumerate(pairs, start=1):
        try:
            s = float(mutual_entailment_score(p1, p2, mode=mode))
        except Exception as exc:
            # Deliberate best-effort: one bad pair must not abort the whole
            # tournament, but the failure should not pass unnoticed.
            s = 0.0
            ranking_scores.append(float("-inf"))
            if debug:
                print(
                    f"[{mode}] [pair {pair_no}] scoring failed "
                    f"({type(exc).__name__}: {exc}); reported as 0.0 and ranked last"
                )
        else:
            ranking_scores.append(s)
        score_list.append(s)

    # Knockout tournament based on the precomputed scores
    current_winner_idx = 0
    for challenger_idx in range(1, len(pairs)):
        score_current = ranking_scores[current_winner_idx]
        score_challenger = ranking_scores[challenger_idx]

        prev_winner_idx = current_winner_idx
        # Strict comparison: the reigning winner keeps the title on a tie.
        if score_challenger > score_current:
            current_winner_idx = challenger_idx

        if debug:
            print(
                f"[{mode}] [compare {prev_winner_idx + 1} vs {challenger_idx + 1}] "
                f"score_current={score_current:.4f}, score_challenger={score_challenger:.4f}, "
                f"winner={current_winner_idx + 1}"
            )

    # Return: best index (1-based), best score, and per-pair scores
    return current_winner_idx + 1, score_list[current_winner_idx], score_list
