#!/usr/bin/env python3
"""Evaluation helpers for LeanCheck."""

from __future__ import annotations

import csv
import random
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import torch

from leancheck_model import ID_TO_LABEL, LABEL_TO_ID, LeanCheckModel


@torch.no_grad()
def evaluate_loader(model, loader, tokenizer, device: str) -> Dict[str, float]:
    """Score ``model`` on every batch of ``loader`` and return averaged metrics.

    Each batch is a mapping with ``input_ids``, ``attention_mask``,
    ``cls_labels`` and ``lm_labels`` tensors; they are moved to ``device``
    and passed positionally to ``model`` together with ``tokenizer``. The
    model output is read through the keys ``lm_loss``, ``cons_loss``,
    ``cls_logits`` and ``logits``.

    Returns:
        A dict with four entries, each divided by the number of examples
        seen (or by 1 when the loader is empty):

        * ``final_lm_loss`` / ``final_cons_loss``: per-batch losses weighted
          by batch size.
        * ``cls_claim_acc``: how often ``argmax(cls_logits)`` equals
          ``cls_labels``.
        * ``gen_claim_acc``: how often the comparison of the ``[FAILS]`` and
          ``[VERIFIES]`` LM logits at the claim position agrees with
          ``cls_labels``.
    """
    model.eval()
    to_id = tokenizer.convert_tokens_to_ids
    verifies_id = to_id("[VERIFIES]")
    fails_id = to_id("[FAILS]")
    claim_id = to_id("[CLAIM]")

    seen = 0
    lm_sum = 0.0
    cons_sum = 0.0
    head_hits = 0
    gen_hits = 0

    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        cls_labels = batch["cls_labels"].to(device)
        lm_labels = batch["lm_labels"].to(device)
        out = model(input_ids, attention_mask, lm_labels, cls_labels, tokenizer)

        batch_size = input_ids.size(0)
        seen += batch_size
        # Losses are taken as batch-level scalars, so weight them by batch size.
        lm_sum += float(out["lm_loss"].detach().cpu()) * batch_size
        cons_sum += float(out["cons_loss"].detach().cpu()) * batch_size

        head_pred = out["cls_logits"].argmax(dim=-1)
        head_hits += int((head_pred == cls_labels).sum().detach().cpu())

        logits = out["logits"]
        last_pos = logits.size(1) - 1
        for row, row_ids in enumerate(input_ids.tolist()):
            if claim_id in row_ids:
                pos = row_ids.index(claim_id)
            else:
                # No [CLAIM] marker: fall back to two positions before the
                # end of the attended tokens, never below zero.
                pos = max(0, int(attention_mask[row].sum().item()) - 2)
            pair = logits[row, min(pos, last_pos), [fails_id, verifies_id]]
            # Index 1 of the pair is the [VERIFIES] logit.
            # NOTE(review): assumes class id 1 means VERIFIES - confirm
            # against LABEL_TO_ID.
            predicted = 1 if int(pair.argmax().item()) == 1 else 0
            gen_hits += int(predicted == int(cls_labels[row].item()))

    denom = max(seen, 1)
    return {
        "final_lm_loss": lm_sum / denom,
        "final_cons_loss": cons_sum / denom,
        "cls_claim_acc": head_hits / denom,
        "gen_claim_acc": gen_hits / denom,
    }


@torch.no_grad()
def predict_labels(model, loader, tokenizer, device: str) -> List[str]:
    """Return the consistency-head label predicted for every example in ``loader``.

    For each batch the ``input_ids`` and ``attention_mask`` tensors are moved
    to ``device`` and passed to ``model`` (with ``tokenizer`` as a keyword
    argument). The argmax over the last dimension of ``cls_logits`` is mapped
    to a label string through ``ID_TO_LABEL``. Labels come back in loader
    order.
    """
    model.eval()
    labels: List[str] = []
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        outputs = model(input_ids, attention_mask, tokenizer=tokenizer)
        class_ids = outputs["cls_logits"].argmax(dim=-1).detach().cpu().tolist()
        for class_id in class_ids:
            labels.append(ID_TO_LABEL[class_id])
    return labels


def counterfactual_metrics(preds: List[str], rows: List[Dict[str, object]]) -> Dict[str, float]:
    """Measure how often predictions track the rationale label vs. the proof label.

    ``preds`` and ``rows`` are walked in lockstep (extra items on the longer
    side are ignored). For each row the prediction is compared with
    ``str(row["rationale_label"])`` and ``str(row["proof_label"])``; either
    key falls back to ``row["label"]`` when absent.

    Returns:
        ``cfact_cls_follows_swap``: fraction of predictions equal to the
        rationale label. ``cfact_cls_follows_orig``: fraction equal to the
        proof label. Both are 0.0 when there is nothing to compare.
    """
    matches_rationale = 0
    matches_proof = 0
    compared = 0
    for compared, (pred, row) in enumerate(zip(preds, rows), start=1):
        default_label = row.get("label")
        rationale_label = str(row.get("rationale_label", default_label))
        proof_label = str(row.get("proof_label", default_label))
        matches_rationale += int(pred == rationale_label)
        matches_proof += int(pred == proof_label)
    denom = max(compared, 1)
    return {
        "cfact_cls_follows_swap": matches_rationale / denom,
        "cfact_cls_follows_orig": matches_proof / denom,
    }


def minimal_pair_flip_rate(preds: List[str], rows: List[Dict[str, object]]) -> Dict[str, float]:
    """Return the fraction of minimal pairs whose two members are both classified correctly.

    ``preds`` and ``rows`` are walked in lockstep. Rows are grouped by
    ``int(row["pair_id"])`` and keyed inside each group by
    ``str(row["pair_role"])``; a later row with the same id and role replaces
    an earlier one. Only groups holding both an ``"accepted"`` and a
    ``"rejected"`` member are scored, and a group counts as correct when the
    accepted member was predicted ``"VERIFIES"`` and the rejected member
    ``"FAILS"``.

    Rows without a ``pair_id`` key are skipped, matching
    ``_pair_minimal_rows``. Bucketing them under a shared sentinel id would
    let unrelated rows form a phantom pair and distort the rate.

    Returns:
        ``{"minimal_pair_flip_acc": rate}``, where ``rate`` is 0.0 when no
        complete pair exists.
    """
    by_pair: Dict[int, Dict[str, str]] = {}
    for pred, row in zip(preds, rows):
        if "pair_id" not in row:
            # Unpaired row: it belongs to no minimal pair.
            continue
        pid = int(row["pair_id"])
        role = str(row.get("pair_role", ""))
        by_pair.setdefault(pid, {})[role] = pred

    complete = [pair for pair in by_pair.values() if "accepted" in pair and "rejected" in pair]
    ok = sum(
        int(pair["accepted"] == "VERIFIES" and pair["rejected"] == "FAILS")
        for pair in complete
    )
    return {"minimal_pair_flip_acc": ok / max(len(complete), 1)}


def _encode_text(tokenizer, text: str, max_len: int, device: str) -> Dict[str, torch.Tensor]:
    if hasattr(tokenizer, "split"):
        enc = tokenizer.encode(text, max_length=max_len)
        return {k: v.unsqueeze(0).to(device).long() for k, v in enc.items()}
    enc = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    return {k: v.to(device).long() for k, v in enc.items()}


def _claim_pos(input_ids: torch.Tensor, attention_mask: torch.Tensor, tokenizer) -> int:
    claim_id = tokenizer.convert_tokens_to_ids("[CLAIM]")
    ids = input_ids[0].detach().cpu().tolist()
    try:
        return ids.index(claim_id)
    except ValueError:
        return max(0, int(attention_mask[0].sum().item()) - 2)


def _label_id(tokenizer, label: str) -> int:
    tok = "[VERIFIES]" if label == "VERIFIES" else "[FAILS]"
    return tokenizer.convert_tokens_to_ids(tok)


def _span_mask(model: LeanCheckModel, tokenizer, enc: Dict[str, torch.Tensor], span: str) -> torch.Tensor:
    if span == "theorem":
        variant = "wrong_span"
    elif span == "proof":
        variant = "proof_only"
    elif span == "rationale":
        variant = "rationale_only"
    elif span == "full":
        variant = "full_sequence"
    elif span == "claim":
        variant = "claim_only"
    else:
        variant = span
    marker = model.marker_ids(tokenizer)
    return model.span_mask(enc["input_ids"], enc["attention_mask"], marker, variant)[0]


def _random_mask_like(mask: torch.Tensor, attention_mask: torch.Tensor, rng: random.Random) -> torch.Tensor:
    length = int(mask.sum().item())
    valid = torch.nonzero(attention_mask[0].bool(), as_tuple=False).flatten().detach().cpu().tolist()
    if not valid:
        return mask.clone()
    length = max(1, min(length, len(valid)))
    start_choices = valid[: max(1, len(valid) - length + 1)]
    start = rng.choice(start_choices)
    out = torch.zeros_like(mask)
    out[start : min(start + length, out.numel())] = True
    return out & attention_mask[0].bool()


def _pair_minimal_rows(rows: List[Dict[str, object]]) -> List[Tuple[Dict[str, object], Dict[str, object]]]:
    by_pair: Dict[int, Dict[str, Dict[str, object]]] = {}
    for row in rows:
        if "pair_id" not in row:
            continue
        by_pair.setdefault(int(row["pair_id"]), {})[str(row.get("pair_role", ""))] = row
    pairs = []
    for pair in by_pair.values():
        if "accepted" in pair and "rejected" in pair:
            pairs.append((pair["accepted"], pair["rejected"]))
    return pairs


@torch.no_grad()
def activation_patching_diagnostic(
    model: LeanCheckModel,
    tokenizer,
    rows: List[Dict[str, object]],
    device: str,
    max_len: int,
    layers: Sequence[int],
    samples: int = 100,
    output_csv: Optional[Path] = None,
    seed: int = 0,
) -> Dict[str, float]:
    """Patch hidden states and measure LM-logit plus consistency-head shifts.

    For each accepted/rejected minimal pair, this runs both directions:
    accepted <- rejected and rejected <- accepted. The LM target is the source
    example's claim-label token logit at the base claim position. The head
    target is the source verifier class logit after pooling the patched final
    hidden states with the model variant's normal consistency-head span.
    Positive values mean the patch moved the base example toward the source
    verifier outcome.

    Three base spans are patched per (layer, direction): the rationale span,
    the "theorem" span, and a random window sized like the rationale span.
    The rationale and random spans receive the source's rationale
    activations; the theorem span receives the source's theorem activations.
    When the base span is longer than the source span, the remaining base
    positions are filled with the mean of the source span activations.

    Args:
        model: Model exposing ``is_tiny``, ``backbone``, ``variant``,
            ``marker_ids``, ``span_mask`` and ``consistency_head``.
        tokenizer: Tokenizer handed to the encoding and marker helpers.
        rows: Example rows; only rows forming complete minimal pairs are used.
            Each used row is read through its ``text`` and ``label`` keys.
        device: Device the encoded tensors are moved to.
        max_len: Maximum encoded length passed to the tokenizer.
        layers: Indices into ``model.backbone.transformer.h``; out-of-range
            or negative indices are skipped.
        samples: Maximum number of patch directions (two per pair) evaluated.
        output_csv: Optional path receiving one CSV row per measurement.
        seed: Seed for pair shuffling and random-window placement.

    Returns:
        A dict of eight mean effects (``patch_*`` for the LM logit,
        ``head_patch_*`` for the consistency head). An entry is the empty
        string when it has no measurements. Every entry is the empty string
        when the model is tiny or its backbone does not expose
        ``transformer.h``.
    """
    blank = {
        "patch_rationale_effect": "",
        "patch_theorem_effect": "",
        "patch_random_effect": "",
        "patch_rationale_minus_random": "",
        "head_patch_rationale_effect": "",
        "head_patch_theorem_effect": "",
        "head_patch_random_effect": "",
        "head_patch_rationale_minus_random": "",
    }
    if model.is_tiny or not hasattr(model.backbone, "transformer"):
        return blank
    blocks = getattr(model.backbone.transformer, "h", None)
    if blocks is None:
        return blank

    def cls_target_id(label: str) -> int:
        return LABEL_TO_ID["VERIFIES" if label == "VERIFIES" else "FAILS"]

    def cls_logit_from_hidden(hidden: torch.Tensor, enc: Dict[str, torch.Tensor], target_cls: int) -> float:
        # Mean-pool the hidden states over the model variant's span, then
        # read the requested class logit from the consistency head.
        marker = model.marker_ids(tokenizer)
        smask = model.span_mask(enc["input_ids"], enc["attention_mask"], marker, model.variant)
        denom = smask.sum(dim=1).clamp_min(1).unsqueeze(-1)
        pooled = (hidden * smask.unsqueeze(-1)).sum(dim=1) / denom
        cls_logits = model.consistency_head(pooled)
        return float(cls_logits[0, target_cls].detach().cpu())

    def run_backbone(enc: Dict[str, torch.Tensor], *, with_hidden_states: bool, block=None, hook=None):
        # Run the backbone once, optionally with a forward hook on ``block``.
        # The hook is always removed, even if the forward pass raises, so a
        # failure cannot leave a stale hook attached to the model.
        kwargs = {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "return_dict": True,
            "use_cache": False,
        }
        if with_hidden_states:
            kwargs["output_hidden_states"] = True
        if hook is None:
            return model.backbone(**kwargs)
        handle = block.register_forward_hook(hook)
        try:
            return model.backbone(**kwargs)
        finally:
            handle.remove()

    def make_patch_hook(base_positions: torch.Tensor, src_values: torch.Tensor, src_mean: torch.Tensor):
        # Bind the tensors as arguments so the hook does not depend on loop
        # variables that change on the next iteration.
        def patch_hook(_module, _inputs, output):
            if isinstance(output, tuple):
                hidden = output[0].clone()
                rest = output[1:]
            else:
                hidden = output.clone()
                rest = None
            take = min(base_positions.numel(), src_values.size(0))
            hidden[0, base_positions[:take], :] = src_values[:take]
            if take < base_positions.numel():
                # Base span is longer than the source span: fill the tail
                # with the mean source activation.
                hidden[0, base_positions[take:], :] = src_mean
            return (hidden,) + rest if rest is not None else hidden

        return patch_hook

    def prepare_direction(base_row: Dict[str, object], source_row: Dict[str, object]):
        # Everything here is independent of the patched layer.
        base = _encode_text(tokenizer, str(base_row["text"]), max_len, device)
        source = _encode_text(tokenizer, str(source_row["text"]), max_len, device)
        target_token = _label_id(tokenizer, str(source_row["label"]))
        target_cls = cls_target_id(str(source_row["label"]))
        claim_pos = _claim_pos(base["input_ids"], base["attention_mask"], tokenizer)
        clean = run_backbone(base, with_hidden_states=True)
        lm_before = float(clean.logits[0, claim_pos, target_token].detach().cpu())
        head_before = cls_logit_from_hidden(clean.hidden_states[-1], base, target_cls)
        base_span_masks = {
            "rationale": _span_mask(model, tokenizer, base, "rationale"),
            "theorem": _span_mask(model, tokenizer, base, "theorem"),
        }
        src_span_masks = {
            "rationale": _span_mask(model, tokenizer, source, "rationale"),
            "theorem": _span_mask(model, tokenizer, source, "theorem"),
        }
        return (
            base,
            source,
            target_token,
            target_cls,
            claim_pos,
            lm_before,
            head_before,
            base_span_masks,
            src_span_masks,
        )

    rng = random.Random(seed)
    pairs = _pair_minimal_rows(rows)
    rng.shuffle(pairs)
    directions: List[Tuple[Dict[str, object], Dict[str, object]]] = []
    for accepted, rejected in pairs:
        directions.append((accepted, rejected))
        directions.append((rejected, accepted))
    directions = directions[:samples]

    patch_rows: List[Dict[str, object]] = []
    lm_effects: Dict[str, List[float]] = {"rationale": [], "theorem": [], "random": []}
    head_effects: Dict[str, List[float]] = {"rationale": [], "theorem": [], "random": []}
    model.eval()
    old_cache = getattr(model.backbone.config, "use_cache", False)
    model.backbone.config.use_cache = False

    # Per-direction baseline (encodings, clean logits, span masks), filled
    # lazily on first use and reused for every subsequent layer.
    prepared: Dict[int, tuple] = {}

    try:
        for layer in layers:
            if layer < 0 or layer >= len(blocks):
                continue
            block = blocks[layer]
            for index, (base_row, source_row) in enumerate(directions):
                if index not in prepared:
                    prepared[index] = prepare_direction(base_row, source_row)
                (
                    base,
                    source,
                    target_token,
                    target_cls,
                    claim_pos,
                    lm_before,
                    head_before,
                    base_span_masks,
                    src_span_masks,
                ) = prepared[index]

                source_hidden: Dict[str, torch.Tensor] = {}

                def capture_hook(_module, _inputs, output, _store=source_hidden):
                    hidden = output[0] if isinstance(output, tuple) else output
                    _store["value"] = hidden.detach()

                run_backbone(source, with_hidden_states=False, block=block, hook=capture_hook)
                if "value" not in source_hidden:
                    continue

                base_masks = dict(base_span_masks)
                # Control span: a random window sized like the rationale span.
                base_masks["random"] = _random_mask_like(base_masks["rationale"], base["attention_mask"], rng)

                for span_name, base_mask in base_masks.items():
                    # The random control is patched with the source's
                    # rationale activations, like the rationale span itself.
                    src_mask = src_span_masks["theorem"] if span_name == "theorem" else src_span_masks["rationale"]
                    src_positions = torch.nonzero(src_mask, as_tuple=False).flatten()
                    base_positions = torch.nonzero(base_mask, as_tuple=False).flatten()
                    if src_positions.numel() == 0 or base_positions.numel() == 0:
                        continue
                    src_values = source_hidden["value"][0, src_positions, :]
                    src_mean = src_values.mean(dim=0, keepdim=True)

                    patched = run_backbone(
                        base,
                        with_hidden_states=True,
                        block=block,
                        hook=make_patch_hook(base_positions, src_values, src_mean),
                    )
                    lm_after = float(patched.logits[0, claim_pos, target_token].detach().cpu())
                    head_after = cls_logit_from_hidden(patched.hidden_states[-1], base, target_cls)
                    lm_effect = lm_after - lm_before
                    head_effect = head_after - head_before
                    lm_effects[span_name].append(lm_effect)
                    head_effects[span_name].append(head_effect)
                    patch_rows.append({
                        "layer": layer,
                        "span": span_name,
                        "base_label": base_row["label"],
                        "source_label": source_row["label"],
                        "target_logit_before": lm_before,
                        "target_logit_after": lm_after,
                        "patch_effect": lm_effect,
                        "lm_patch_effect": lm_effect,
                        "head_target_logit_before": head_before,
                        "head_target_logit_after": head_after,
                        "head_patch_effect": head_effect,
                    })
    finally:
        model.backbone.config.use_cache = old_cache

    if output_csv is not None:
        output_csv.parent.mkdir(parents=True, exist_ok=True)
        fields = [
            "layer",
            "span",
            "base_label",
            "source_label",
            "target_logit_before",
            "target_logit_after",
            "patch_effect",
            "lm_patch_effect",
            "head_target_logit_before",
            "head_target_logit_after",
            "head_patch_effect",
        ]
        with output_csv.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fields)
            writer.writeheader()
            writer.writerows(patch_rows)

    def avg(bucket: Dict[str, List[float]], name: str):
        # Mean of the bucket, or "" when there were no measurements.
        vals = bucket[name]
        return sum(vals) / len(vals) if vals else ""

    lm_rat = avg(lm_effects, "rationale")
    lm_rnd = avg(lm_effects, "random")
    head_rat = avg(head_effects, "rationale")
    head_rnd = avg(head_effects, "random")
    return {
        "patch_rationale_effect": lm_rat,
        "patch_theorem_effect": avg(lm_effects, "theorem"),
        "patch_random_effect": lm_rnd,
        "patch_rationale_minus_random": (lm_rat - lm_rnd) if isinstance(lm_rat, float) and isinstance(lm_rnd, float) else "",
        "head_patch_rationale_effect": head_rat,
        "head_patch_theorem_effect": avg(head_effects, "theorem"),
        "head_patch_random_effect": head_rnd,
        "head_patch_rationale_minus_random": (head_rat - head_rnd) if isinstance(head_rat, float) and isinstance(head_rnd, float) else "",
    }

