#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Per-layer margin evaluation (Logit Lens) for multiple-choice / binary-choice tasks.

Outputs per-sample metrics.jsonl and summary.csv.

Added metrics for:
(B) gain curve / amplifier evidence:
- delta_m_curve_*: per-layer margin gain Δm_l = m_l - m_{l-1}
- late_gain_sum_* / late_gain_mean_*: cumulative gain in late layers
- amp_topk_*: top-k layers with largest Δm_l (potential amplifier layers)

(E) more sensitive phase-transition indicators:
- crossed_*: whether any layer has margin > 0
- persist_all_after_cross_*: after first crossing, stays >0 until last layer
- pos_ratio_after_cross_*: fraction of layers after crossing with margin > 0
- last_pos_layer_*: last layer index with margin > 0
- min_margin_after_cross_*: minimum margin after crossing (stability depth)
- max_margin_* / mean_margin_*
"""

import os
import json
import argparse
from typing import Dict, List, Tuple, Any, Optional

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM


# -------------------------
# Prompt builders
# -------------------------

def build_prompt_arc(question: str, choice_texts: List[str], choice_labels: List[str]) -> str:
    """Build the ARC multiple-choice prompt.

    Args:
        question: Question text.
        choice_texts: Option texts, in display order.
        choice_labels: Labels shown in front of each option; paired with
            ``choice_texts`` by position (extra entries on either side are
            dropped by ``zip``).

    Returns:
        The prompt, ending with ``"### Answer:"`` so that the next token is
        the answer label.
    """
    option_block = "\n".join(
        "{}. {}".format(label, text)
        for label, text in zip(choice_labels, choice_texts)
    )
    # Sections are separated by exactly one blank line.
    sections = [
        "### Task:\nChoose the best answer to the following question.",
        f"### Question:\n{question}",
        f"### Options:\n{option_block}",
        "### Answer:",
    ]
    return "\n\n".join(sections)

def build_prompt_boolq(passage: str, question: str) -> str:
    """Build the BoolQ prompt, presenting Yes/No as options A/B.

    Args:
        passage: Supporting passage.
        question: Yes/No question about the passage.

    Returns:
        The prompt, ending with ``"### Answer:"``.
    """
    template = (
        "### Task:\nRead the following passage and only answer the Yes/No question based on it.\n\n"
        "### Passage:\n{passage}\n\n"
        "### Question:\n{question}\n\n"
        "### Options:\nA. Yes\nB. No\n\n"
        "### Answer:"
    )
    # Only the template is parsed for fields; braces inside the values are kept verbatim.
    return template.format(passage=passage, question=question)

def build_prompt_piqa(goal: str, choices: List[str]) -> str:
    """Build the PIQA two-option prompt.

    Args:
        goal: Goal description.
        choices: Exactly two candidate solutions, shown as A and B.

    Returns:
        The prompt, ending with ``"### Answer:"``.
    """
    assert len(choices) == 2
    # One entry per output line; empty strings produce the blank separator lines.
    lines = [
        "### Task:",
        "Choose the most physically plausible solution to achieve the goal.",
        "",
        "### Goal:",
        f"{goal}",
        "",
        "### Options:",
        f"A. {choices[0]}",
        f"B. {choices[1]}",
        "",
        "### Answer:",
    ]
    return "\n".join(lines)

def build_prompt_winogrande(sentence: str, option1: str, option2: str) -> str:
    """Build the Winogrande fill-in-the-blank prompt.

    Args:
        sentence: Sentence containing the blank.
        option1: Candidate shown as option A.
        option2: Candidate shown as option B.

    Returns:
        The prompt, ending with ``"### Answer:"``.
    """
    # (section title, section body) pairs, rendered in order.
    blocks = (
        ("Task", "Choose the correct option to fill in the blank (\"_\") in the sentence."),
        ("Sentence", f"{sentence}"),
        ("Options", f"A. {option1}\nB. {option2}"),
    )
    rendered = "".join(f"### {title}:\n{content}\n\n" for title, content in blocks)
    return rendered + "### Answer:"

def build_prompt_hellaswag(ctx: str, endings: List[str]) -> str:
    """Build the HellaSwag continuation-choice prompt.

    Args:
        ctx: Context to be continued.
        endings: Candidate continuations; labelled A-D by position, and any
            ending beyond the fourth is dropped.

    Returns:
        The prompt, ending with ``"### Answer:"``.
    """
    rendered_options = []
    for letter, ending in zip("ABCD", endings):
        rendered_options.append(letter + ". " + f"{ending}")

    header = "### Task:\nChoose the most plausible continuation of the following context.\n\n"
    context_part = f"### Context:\n{ctx}\n\n"
    options_part = "### Options:\n" + "\n".join(rendered_options) + "\n\n"
    return header + context_part + options_part + "### Answer:"


# -------------------------
# Utilities
# -------------------------

def normalize_gold(gold_raw: Any, letters: List[str]) -> str:
    """Map a raw gold label to one of the option letters.

    Accepted forms, tried in this order:

    1. A label that, after stripping and upper-casing, is already one of
       ``letters``.
    2. A non-negative integer (or its decimal string). ``0 .. len(letters)-1``
       is read as a 0-based index. The value ``len(letters)`` is read as the
       1-based index of the last option. Values in between are ambiguous and
       resolve to the 0-based reading.

    The numeric path indexes ``letters`` itself, so the result is always a
    member of ``letters`` regardless of how many options there are.

    Args:
        gold_raw: Raw label from the dataset (string, int, ...).
        letters: Option letters for the current sample, in display order.

    Returns:
        The matching element of ``letters``.

    Raises:
        ValueError: If ``gold_raw`` cannot be mapped to any letter.
    """
    text = str(gold_raw).strip()
    upper = text.upper()
    if upper in letters:
        return upper

    # isdecimal() (unlike isdigit()) only accepts characters that int() can parse.
    if text.isdecimal():
        index = int(text)
        count = len(letters)
        if 0 <= index < count:
            return letters[index]
        if count and index == count:
            # Only reachable as a 1-based index of the last option.
            return letters[count - 1]

    raise ValueError(f"Cannot normalize gold '{gold_raw}' to letters={letters}")

def get_num_layers_llama(model: AutoModelForCausalLM) -> int:
    """Return the number of layers of ``model``.

    ``len(model.model.layers)`` is preferred when that attribute exists;
    otherwise ``model.config.num_hidden_layers`` is used.

    Raises:
        RuntimeError: If neither attribute is available.
    """
    missing = object()  # sentinel: distinguishes "attribute absent" from any real value

    layers = getattr(getattr(model, "model", None), "layers", missing)
    if layers is not missing:
        return len(layers)

    configured = getattr(getattr(model, "config", None), "num_hidden_layers", missing)
    if configured is not missing:
        return int(configured)

    raise RuntimeError("Cannot determine number of layers from model.")

@torch.no_grad()
def assert_pruned_vs_dense(dense: AutoModelForCausalLM, pruned: AutoModelForCausalLM) -> None:
    """Sanity-check that the pruned model has no more layers than the dense one.

    Prints both layer counts on success.

    Raises:
        AssertionError: If the pruned model reports more layers than the dense model.
    """
    n_dense, n_pruned = get_num_layers_llama(dense), get_num_layers_llama(pruned)
    if n_pruned > n_dense:
        raise AssertionError(f"Pruned layers ({n_pruned}) > Dense layers ({n_dense}) - unexpected.")
    print(f"[LayerCheck] dense_layers={n_dense}, pruned_layers={n_pruned}")

def _single_token_id(tokenizer: AutoTokenizer, text: str) -> Optional[int]:
    """Return the id of ``text`` if it encodes to exactly one token, else ``None``.

    The text is encoded with ``add_special_tokens=False``.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return token_ids[0] if len(token_ids) == 1 else None

def choose_letter_token_ids(tokenizer: AutoTokenizer, letters: List[str]) -> Dict[str, int]:
    """Pick one vocabulary id per option letter.

    For each letter the space-prefixed form (``" A"``) is tried first and the
    bare form (``"A"``) second; the first one that encodes to a single token
    is used.

    Args:
        tokenizer: Tokenizer used to encode the candidate strings.
        letters: Option letters to resolve.

    Returns:
        Mapping from letter to token id.

    Raises:
        ValueError: If neither form of some letter is a single token.
    """
    mapping = {}
    for letter in letters:
        for candidate in (" " + letter, letter):
            token_id = _single_token_id(tokenizer, candidate)
            if token_id is not None:
                mapping[letter] = token_id
                break
        else:
            # Neither spelling was a single token.
            raise ValueError(
                f"Cannot find single-token id for letter '{letter}' or ' {letter}'. "
                f"Try a different template or use option-text scoring."
            )
    return mapping

@torch.no_grad()
def per_layer_letter_logits(
    model: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    letter_token_ids: Dict[str, int],
) -> Tuple[List[Dict[str, float]], int]:
    """Read out the option-letter logits after every layer (logit lens).

    For intermediate layers the hidden state at the last sequence position is
    passed through the model's final norm (when ``model.model.norm`` exists)
    and ``model.lm_head``. For the final layer the model's own output logits
    are used, so the last point of the curve is exactly the model's real
    prediction.

    Only batch element 0 and the last sequence position are read, so the
    caller must pass a single, unpadded (or left-padded) prompt.

    Args:
        model: Causal LM exposing ``lm_head`` and, optionally, ``model.norm``.
        input_ids: Token ids of the prompt.
        attention_mask: Attention mask matching ``input_ids``.
        letter_token_ids: Mapping from option letter to vocabulary id.

    Returns:
        ``(per_layer, num_layers)`` where ``per_layer[i]`` maps each letter to
        its logit after layer ``i + 1``.

    Raises:
        RuntimeError: If the model returns fewer hidden states than
            ``num_layers + 1``.
    """
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        use_cache=False,
    )
    hidden_states = outputs.hidden_states
    num_layers = get_num_layers_llama(model)

    if len(hidden_states) < num_layers + 1:
        raise RuntimeError(
            f"Model returned {len(hidden_states)} hidden states; "
            f"expected at least {num_layers + 1}."
        )
    if len(hidden_states) != num_layers + 1:
        # Keep the trailing num_layers + 1 entries (embedding output + one per layer).
        hidden_states = hidden_states[-(num_layers + 1):]

    final_norm = getattr(getattr(model, "model", None), "norm", None)

    per_layer = []
    for layer in range(1, num_layers + 1):
        if layer == num_layers:
            # Hugging Face Llama-style models return the last hidden state
            # *after* the final norm, so normalising it again would distort
            # the final logits. The model's own logits are always correct.
            last = outputs.logits[0, -1, :]
        else:
            # Slice to the last position first: norm and lm_head act per
            # position, so this avoids projecting the whole sequence.
            hs = hidden_states[layer][:, -1:, :]
            if final_norm is not None:
                hs = final_norm(hs)
            last = model.lm_head(hs)[0, -1, :]
        per_layer.append(
            {letter: float(last[token_id].item()) for letter, token_id in letter_token_ids.items()}
        )
    return per_layer, num_layers

def compute_margin_curve(
    per_layer_logits: List[Dict[str, float]],
    gold: str,
    letters: List[str],
) -> Tuple[List[float], int, float, str]:
    """Compute the per-layer margin of the gold option over its best competitor.

    The margin at a layer is ``logit[gold] - max(logit[x] for x != gold)``.

    Args:
        per_layer_logits: One ``{letter: logit}`` dict per layer, first layer first.
        gold: Gold option letter.
        letters: All option letters; must contain at least one letter other
            than ``gold`` when ``per_layer_logits`` is non-empty.

    Returns:
        ``(m_curve, l_star, m_last, pred_last)``:

        - ``m_curve``: margin per layer.
        - ``l_star``: 1-indexed first layer with margin > 0, or
          ``len(m_curve) + 1`` if the margin never becomes positive.
        - ``m_last``: margin at the last layer (NaN for an empty input).
        - ``pred_last``: letter with the highest logit at the last layer
          (empty string for an empty input).
    """
    m_curve = [
        layer[gold] - max(layer[x] for x in letters if x != gold)
        for layer in per_layer_logits
    ]
    num_layers = len(m_curve)
    l_star = next(
        (idx for idx, margin in enumerate(m_curve, start=1) if margin > 0.0),
        num_layers + 1,
    )

    if num_layers == 0:
        # No layers: there is no last layer to read a prediction from.
        return m_curve, l_star, float("nan"), ""

    last_layer = per_layer_logits[-1]
    pred_last = max(last_layer, key=last_layer.get)
    return m_curve, l_star, m_curve[-1], pred_last

def compute_gain_and_stability(
    m_curve: List[float],
    l_star: int,
    late_frac: float = 0.5,
    topk: int = 5,
) -> Dict[str, Any]:
    """Derive gain and post-crossing stability statistics from a margin curve.

    Layers are 1-indexed: ``m_curve[i]`` is the margin after layer ``i + 1``.
    The gain "to layer l" is ``m_l - m_{l-1}`` and exists for ``l = 2 .. L``.

    Args:
        m_curve: Margin per layer (length ``L``).
        l_star: 1-indexed first layer with margin > 0. Any value outside
            ``1 .. L`` means the margin never became positive.
        late_frac: Fraction of the depth treated as the "late" region. The
            region starts at layer ``max(1, int((1 - late_frac) * L) + 1)``.
        topk: Number of largest gains to report; at least one is reported
            whenever a gain exists.

    Returns:
        A dict with the same keys for every input, including an empty curve:

        - ``delta_m_curve``: the ``L - 1`` gains.
        - ``late_start_layer``: first layer of the late region.
        - ``late_gain_sum`` / ``late_gain_mean``: sum / mean of the gains whose
          target layer lies in the late region (``0.0`` if there are none,
          NaN for an empty curve).
        - ``crossed``: whether the margin ever became positive.
        - ``persist_all_after_cross``: margin stays > 0 from ``l_star`` to the end.
        - ``pos_ratio_after_cross``: fraction of layers from ``l_star`` on with
          margin > 0 (NaN if never crossed).
        - ``last_pos_layer``: last layer with margin > 0 (``0`` if none).
        - ``min_margin_after_cross``: minimum margin from ``l_star`` on
          (NaN if never crossed).
        - ``max_margin`` / ``mean_margin``: over the whole curve.
        - ``amp_topk``: ``{"to_layer", "delta_m"}`` dicts for the largest gains,
          largest first.
    """
    nan = float("nan")
    num_layers = len(m_curve)
    if num_layers == 0:
        return {
            "delta_m_curve": [],
            "late_start_layer": 1,
            "late_gain_sum": nan,
            "late_gain_mean": nan,
            "crossed": False,
            "persist_all_after_cross": False,
            "pos_ratio_after_cross": nan,
            "last_pos_layer": 0,
            "min_margin_after_cross": nan,
            "max_margin": nan,
            "mean_margin": nan,
            "amp_topk": [],
        }

    # delta[i] is the gain to layer i + 2.
    delta = [cur - prev for prev, cur in zip(m_curve, m_curve[1:])]

    # e.g. L=32, late_frac=0.5 => late region starts at layer 17.
    late_start_layer = max(1, int((1.0 - late_frac) * num_layers) + 1)
    # Gains with to_layer >= late_start_layer sit at delta indices >= late_start_layer - 2.
    late_delta = delta[max(0, late_start_layer - 2):]
    if late_delta:
        late_gain_sum = float(sum(late_delta))
        late_gain_mean = float(late_gain_sum / len(late_delta))
    else:
        late_gain_sum = 0.0
        late_gain_mean = 0.0

    # Guard the lower bound too, so an out-of-range l_star never slices from the wrong end.
    crossed = 1 <= l_star <= num_layers
    if crossed:
        after = m_curve[l_star - 1:]  # crossing layer through the last layer
        positives = sum(1 for m in after if m > 0.0)
        persist_all = positives == len(after)
        pos_ratio = float(positives / len(after))
        min_after = float(min(after))
    else:
        persist_all = False
        pos_ratio = nan
        min_after = nan

    last_pos = max((idx for idx, m in enumerate(m_curve, start=1) if m > 0.0), default=0)

    # Stable descending sort: ties keep the shallower layer first.
    ranked = sorted(enumerate(delta, start=2), key=lambda pair: pair[1], reverse=True)
    amp = [
        {"to_layer": int(to_layer), "delta_m": float(gain)}
        for to_layer, gain in ranked[:max(1, topk)]
    ]

    return {
        "delta_m_curve": [float(x) for x in delta],
        "late_start_layer": int(late_start_layer),
        "late_gain_sum": late_gain_sum,
        "late_gain_mean": late_gain_mean,
        "crossed": bool(crossed),
        "persist_all_after_cross": bool(persist_all),
        "pos_ratio_after_cross": pos_ratio,
        "last_pos_layer": int(last_pos),
        "min_margin_after_cross": min_after,
        "max_margin": float(max(m_curve)),
        "mean_margin": float(sum(m_curve) / num_layers),
        "amp_topk": amp,
    }


# -------------------------
# Dataset loaders (parquet)
# -------------------------

def iter_samples(task: str, parquet_path: str, split: str, limit: Optional[int], seed: int):
    """Yield ``(sample_id, prompt, gold_letter, letters)`` for each example of a task.

    Args:
        task: Task name (case-insensitive): ``arc_easy`` / ``arcchallenge`` /
            ``arc_challenge``, ``boolq``, ``piqa``, ``winogrande`` / ``wino``,
            ``hellaswag`` / ``hellas``.
        parquet_path: Parquet file(s) passed to ``load_dataset`` as ``data_files``.
        split: Split name passed to ``load_dataset``.
        limit: If not ``None`` and ``0 <= limit < len(dataset)``, a random
            subset of that size is drawn (seeded by ``seed``). ``None`` or a
            negative value keeps the whole dataset in file order.
        seed: Seed for the subset permutation.

    Yields:
        ``(sample_id, prompt, gold, letters)`` where ``letters`` are the option
        letters shown in the prompt and ``gold`` is one of them.

    Raises:
        ValueError: For an unsupported task, or a gold label that cannot be
            mapped to an option letter.
    """
    arc_tasks = ("arc_easy", "arcchallenge", "arc_challenge")
    wino_tasks = ("winogrande", "wino")
    hellaswag_tasks = ("hellaswag", "hellas")

    task = task.lower()
    # Reject unknown tasks before paying for the dataset load.
    if task not in arc_tasks + ("boolq", "piqa") + wino_tasks + hellaswag_tasks:
        raise ValueError(f"Unsupported task: {task}")

    ds = load_dataset("parquet", data_files=parquet_path, split=split)

    # A negative limit would otherwise slice perm[:limit] and silently drop samples.
    if limit is not None and 0 <= limit < len(ds):
        g = torch.Generator().manual_seed(seed)
        perm = torch.randperm(len(ds), generator=g).tolist()
        ds = ds.select(perm[:limit])

    for item in ds:
        if task in arc_tasks:
            # The dataset's own labels may be letters or digits and the number
            # of options may vary. Re-label the options A, B, C, ... so the
            # letters shown in the prompt are exactly the ones being scored,
            # and locate the gold option by its position.
            raw_labels = [str(lab).strip() for lab in item["choices"]["label"]]
            letters = [chr(ord("A") + i) for i in range(len(raw_labels))]
            prompt = build_prompt_arc(
                question=item["question"],
                choice_texts=item["choices"]["text"],
                choice_labels=letters,
            )
            answer_key = str(item["answerKey"]).strip()
            if answer_key in raw_labels:
                gold = letters[raw_labels.index(answer_key)]
            else:
                gold = normalize_gold(answer_key, letters)
            sid = item.get("id", None)

        elif task == "boolq":
            letters = ["A", "B"]
            prompt = build_prompt_boolq(item["passage"], item["question"])
            gold_raw = item["answer"]
            if isinstance(gold_raw, str):
                # bool("false") is True, so textual answers are decoded explicitly.
                token = gold_raw.strip().lower()
                if token in ("true", "yes", "1"):
                    gold = "A"
                elif token in ("false", "no", "0"):
                    gold = "B"
                else:
                    gold = normalize_gold(gold_raw, letters)
            else:
                gold = "A" if bool(gold_raw) else "B"
            sid = item.get("id", None)

        elif task == "piqa":
            letters = ["A", "B"]
            prompt = build_prompt_piqa(item["question"], item["choices"])
            if "answer_index" in item and item["answer_index"] is not None:
                gold = "A" if int(item["answer_index"]) == 0 else "B"
            else:
                gold = normalize_gold(item["answer"], letters)
            sid = item.get("id", None)

        elif task in wino_tasks:
            letters = ["A", "B"]
            prompt = build_prompt_winogrande(item["sentence"], item["option1"], item["option2"])
            ans = str(item["answer"]).strip()
            if ans in ("1", "2"):
                # Winogrande answers are 1-based option numbers.
                gold = "A" if ans == "1" else "B"
            else:
                gold = normalize_gold(ans, letters)
            sid = item.get("id", None)

        else:  # hellaswag
            letters = ["A", "B", "C", "D"]
            prompt = build_prompt_hellaswag(item["ctx"], item["endings"])
            gold = normalize_gold(str(item["label"]).strip(), letters)
            sid = item.get("ind", item.get("id", None))

        yield sid, prompt, gold, letters


# -------------------------
# Main
# -------------------------

def main():
    """Command-line entry point.

    Loads a dense and a pruned causal LM, scores every sample of the chosen
    task with both models, writes one JSON record per sample to
    ``<out_dir>/metrics.jsonl`` and aggregate statistics to
    ``<out_dir>/summary.csv``.

    Both models receive the same token ids, produced by the dense model's
    tokenizer.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dense_model", type=str, required=True)
    ap.add_argument("--pruned_model", type=str, required=True)
    ap.add_argument("--task", type=str, required=True)
    ap.add_argument("--parquet", type=str, required=True)
    ap.add_argument("--split", type=str, default="train")
    ap.add_argument("--limit", type=int, default=500)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--max_length", type=int, default=1024)
    ap.add_argument("--device", type=str, default="cuda:0")
    ap.add_argument("--out_dir", type=str, required=True)

    # Knobs for the gain (B) and stability (E) metrics.
    ap.add_argument("--late_frac", type=float, default=0.5, help="late region fraction for late gain stats (e.g., 0.5 = last half)")
    ap.add_argument("--amp_topk", type=int, default=5, help="top-k amplifier deltas to record per sample")
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    out_jsonl = os.path.join(args.out_dir, "metrics.jsonl")
    out_summary = os.path.join(args.out_dir, "summary.csv")

    device = torch.device(args.device)

    # One tokenizer (the dense model's) is shared by both models.
    tok = AutoTokenizer.from_pretrained(args.dense_model, use_fast=False)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # The answer is read at the last prompt position, right after "### Answer:".
    # Truncating on the right would cut off the options and that cue for
    # over-length prompts, so drop tokens from the start instead.
    tok.truncation_side = "left"

    dtype = torch.float16 if "cuda" in args.device else torch.float32

    dense = AutoModelForCausalLM.from_pretrained(
        args.dense_model,
        torch_dtype=dtype,
        device_map=None,
    ).to(device).eval()

    pruned = AutoModelForCausalLM.from_pretrained(
        args.pruned_model,
        torch_dtype=dtype,
        device_map=None,
    ).to(device).eval()

    assert_pruned_vs_dense(dense, pruned)

    # Running totals for the summary.
    total = 0
    correct_dense = 0
    correct_pruned = 0
    pos_margin_dense = 0
    pos_margin_pruned = 0
    lstar_sum_dense = 0.0
    lstar_sum_pruned = 0.0

    late_gain_sum_dense = 0.0
    late_gain_sum_pruned = 0.0
    persist_dense = 0
    persist_pruned = 0

    # Letter -> token id lookup, recomputed only when the letter set changes.
    cached_letters = None
    cached_letter_token_ids = None

    with open(out_jsonl, "w", encoding="utf-8") as f:
        for sid, prompt, gold, letters in iter_samples(args.task, args.parquet, args.split, args.limit, args.seed):
            enc = tok(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=args.max_length,
                padding=False,
            )
            input_ids = enc["input_ids"].to(device)
            attention_mask = enc["attention_mask"].to(device)

            if cached_letters != tuple(letters):
                cached_letters = tuple(letters)
                cached_letter_token_ids = choose_letter_token_ids(tok, letters)
                print("[TokenID] letters:", letters, "ids:", cached_letter_token_ids)

            letter_token_ids = cached_letter_token_ids

            # Dense
            per_layer_dense, Ld = per_layer_letter_logits(dense, input_ids, attention_mask, letter_token_ids)
            m_curve_dense, lstar_dense, mlast_dense, pred_dense = compute_margin_curve(per_layer_dense, gold, letters)
            extra_dense = compute_gain_and_stability(
                m_curve_dense, lstar_dense, late_frac=args.late_frac, topk=args.amp_topk
            )

            # Pruned
            per_layer_pruned, Lp = per_layer_letter_logits(pruned, input_ids, attention_mask, letter_token_ids)
            m_curve_pruned, lstar_pruned, mlast_pruned, pred_pruned = compute_margin_curve(per_layer_pruned, gold, letters)
            extra_pruned = compute_gain_and_stability(
                m_curve_pruned, lstar_pruned, late_frac=args.late_frac, topk=args.amp_topk
            )

            # Update running totals.
            total += 1
            cd = int(pred_dense == gold)
            cp = int(pred_pruned == gold)
            correct_dense += cd
            correct_pruned += cp
            pos_margin_dense += int(mlast_dense > 0.0)
            pos_margin_pruned += int(mlast_pruned > 0.0)
            # l_star is (num_layers + 1) for samples whose margin never turns positive.
            lstar_sum_dense += float(lstar_dense)
            lstar_sum_pruned += float(lstar_pruned)

            late_gain_sum_dense += float(extra_dense["late_gain_sum"])
            late_gain_sum_pruned += float(extra_pruned["late_gain_sum"])
            persist_dense += int(extra_dense["persist_all_after_cross"])
            persist_pruned += int(extra_pruned["persist_all_after_cross"])

            rec = {
                "sample_id": sid,
                "task": args.task,
                "gold": gold,
                "letters": letters,

                "dense_layers": Ld,
                "pruned_layers": Lp,

                "pred_dense": pred_dense,
                "pred_pruned": pred_pruned,
                "correct_dense": cd,
                "correct_pruned": cp,

                # core margin metrics
                "m_last_dense": float(mlast_dense),
                "m_last_pruned": float(mlast_pruned),
                "l_star_dense": int(lstar_dense),
                "l_star_pruned": int(lstar_pruned),
                "m_curve_dense": [float(x) for x in m_curve_dense],
                "m_curve_pruned": [float(x) for x in m_curve_pruned],

                # (B) gain / amplifier signals
                "delta_m_curve_dense": extra_dense["delta_m_curve"],
                "delta_m_curve_pruned": extra_pruned["delta_m_curve"],
                "late_start_layer_dense": extra_dense["late_start_layer"],
                "late_start_layer_pruned": extra_pruned["late_start_layer"],
                "late_gain_sum_dense": extra_dense["late_gain_sum"],
                "late_gain_sum_pruned": extra_pruned["late_gain_sum"],
                "late_gain_mean_dense": extra_dense["late_gain_mean"],
                "late_gain_mean_pruned": extra_pruned["late_gain_mean"],
                "amp_topk_dense": extra_dense["amp_topk"],
                "amp_topk_pruned": extra_pruned["amp_topk"],

                # (E) stability / persistence signals
                "crossed_dense": extra_dense["crossed"],
                "crossed_pruned": extra_pruned["crossed"],
                "persist_all_after_cross_dense": extra_dense["persist_all_after_cross"],
                "persist_all_after_cross_pruned": extra_pruned["persist_all_after_cross"],
                "pos_ratio_after_cross_dense": extra_dense["pos_ratio_after_cross"],
                "pos_ratio_after_cross_pruned": extra_pruned["pos_ratio_after_cross"],
                "last_pos_layer_dense": extra_dense["last_pos_layer"],
                "last_pos_layer_pruned": extra_pruned["last_pos_layer"],
                "min_margin_after_cross_dense": extra_dense["min_margin_after_cross"],
                "min_margin_after_cross_pruned": extra_pruned["min_margin_after_cross"],
                "max_margin_dense": extra_dense["max_margin"],
                "max_margin_pruned": extra_pruned["max_margin"],
                "mean_margin_dense": extra_dense["mean_margin"],
                "mean_margin_pruned": extra_pruned["mean_margin"],
            }

            # NOTE: metrics that are undefined for a sample are NaN floats,
            # which json.dumps writes as the bare token NaN (not strict JSON).
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")

            if total % 20 == 0:
                print(f"[Progress] {total} samples")

    # Averages over all scored samples; max(total, 1) avoids dividing by zero on an empty run.
    denom = max(total, 1)
    acc_d = correct_dense / denom
    acc_p = correct_pruned / denom
    pm_d = pos_margin_dense / denom
    pm_p = pos_margin_pruned / denom
    ls_d = lstar_sum_dense / denom
    ls_p = lstar_sum_pruned / denom

    lg_d = late_gain_sum_dense / denom
    lg_p = late_gain_sum_pruned / denom
    per_d = persist_dense / denom
    per_p = persist_pruned / denom

    with open(out_summary, "w", encoding="utf-8") as f:
        f.write(
            "total,acc_dense,acc_pruned,"
            "P(m_last>0)_dense,P(m_last>0)_pruned,"
            "mean_l_star_dense,mean_l_star_pruned,"
            "mean_late_gain_sum_dense,mean_late_gain_sum_pruned,"
            "P(persist_all_after_cross)_dense,P(persist_all_after_cross)_pruned\n"
        )
        f.write(
            f"{total},{acc_d:.6f},{acc_p:.6f},"
            f"{pm_d:.6f},{pm_p:.6f},"
            f"{ls_d:.6f},{ls_p:.6f},"
            f"{lg_d:.6f},{lg_p:.6f},"
            f"{per_d:.6f},{per_p:.6f}\n"
        )

    print("[Done]")
    print("summary:", out_summary)
    print("metrics:", out_jsonl)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()