#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Single-sample correct-argmax vs layer (logit-lens)

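For a given sample i:
  For each layer ℓ:
    - Take the hidden state at the decision position (the last prompt token)
    - Apply the final norm, then the lm_head, to get vocabulary logits
    - Compare the logits of the answer-option tokens
    - Record whether the correct option has the highest logit (1) or not (0)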

Outputs:
  - correct_argmax_per_layer.csv
  - correct_argmax_vs_layer.png
"""

import os
import re
import argparse
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from data_utils import get_sft_dataset, collate_sft

# Cap the number of intra-op CPU threads torch may use.
torch.set_num_threads(4)

# -----------------------------
# option parsing
# -----------------------------
# Option lines: "A. ..." .. "D. ..." (either case) or "1. ..." .. "4. ...".
_OPT_LETTER = re.compile(r"^\s*([A-D])\.\s", re.I)
_OPT_NUM = re.compile(r"^\s*([1-4])\.\s")
# First standalone option key (letter A-D or digit 1-4) in a piece of text.
_ANSWER_RE = re.compile(r"\b([A-D]|[1-4])\b")
# Numeric option keys mapped onto their letter equivalents.
NUM2ABC = dict(zip("1234", "ABCD"))


def normalize_key(k: str) -> str:
    """Return ``k`` upper-cased, with numeric keys "1".."4" mapped to "A".."D".

    Any upper-cased key that is not a key of ``NUM2ABC`` is returned as-is.
    """
    upper_key = k.upper()
    if upper_key in NUM2ABC:
        return NUM2ABC[upper_key]
    return upper_key


def extract_option_keys(prompt: str) -> List[str]:
    """Return the answer-option keys (as letters) listed in ``prompt``.

    A line counts as an option when it starts (after optional whitespace) with
    ``A.``-``D.`` (either case) or ``1.``-``4.`` followed by whitespace; numeric
    keys are mapped to letters via ``NUM2ABC``.

    If the found keys contain a contiguous run starting at ``A`` of length >= 2
    (``A,B``, ``A,B,C`` or ``A,B,C,D``), that canonical run is returned, so 2-,
    3- and 4-option prompts are all handled. Otherwise the keys are returned
    de-duplicated in order of first appearance (possibly an empty list).
    """
    found: List[str] = []
    for line in prompt.splitlines():
        m = _OPT_LETTER.match(line)
        if m:
            key = m.group(1).upper()
        else:
            m = _OPT_NUM.match(line)
            if not m:
                continue
            key = NUM2ABC[m.group(1)]
        if key not in found:
            found.append(key)

    # Longest prefix of A, B, C, D whose letters were all seen. Stray
    # option-like lines beyond that run (e.g. "D." when "C." is absent) are
    # ignored, and a three-option prompt keeps its third option.
    prefix: List[str] = []
    for letter in "ABCD":
        if letter not in found:
            break
        prefix.append(letter)
    if len(prefix) >= 2:
        return prefix
    return found


def extract_gold_key(tokenizer, input_ids, labels, attn) -> str:
    """Parse the gold option key from a sample's supervised label tokens.

    The label ids that are not ``-100`` and sit at positions where ``attn`` is 1
    are decoded (special tokens kept). The first standalone ``A``-``D`` /
    ``1``-``4`` in that text is normalised to a letter with ``normalize_key``.

    Args:
        tokenizer: object providing ``decode``.
        input_ids: not used here; kept so existing call sites keep working.
        labels: label ids of one sample (``-100`` marks unsupervised positions);
            presumably aligned with ``attn`` — confirm against the collator.
        attn: attention mask for the same positions.

    Raises:
        RuntimeError: if no supervised tokens exist or no key can be parsed.
    """
    answer_mask = (labels != -100) & (attn == 1)
    if not bool(answer_mask.any()):
        raise RuntimeError("No gold answer tokens")
    answer_text = tokenizer.decode(labels[answer_mask].tolist(), skip_special_tokens=False)
    match = _ANSWER_RE.search(answer_text)
    if match is None:
        raise RuntimeError(f"Cannot parse gold answer from: {answer_text!r}")
    return normalize_key(match.group(1))


# -----------------------------
# model helpers
# -----------------------------
def get_final_norm(model):
    """Return the final normalization module applied before the LM head.

    Layouts are checked in this order:
      * ``model.model.norm``                     (e.g. Llama/Mistral-style)
      * ``model.transformer.ln_f``               (GPT-2-style)
      * ``model.gpt_neox.final_layer_norm``      (GPT-NeoX/Pythia; pairs with
        the ``embed_out`` head handled by ``get_lm_head``)
      * ``model.model.decoder.final_layer_norm`` (OPT)
      * ``model.model.final_layernorm``          (Phi)

    Raises:
        AttributeError: if none of the layouts provides a module.
    """
    if hasattr(model, "model") and hasattr(model.model, "norm"):
        return model.model.norm
    if hasattr(model, "transformer") and hasattr(model.transformer, "ln_f"):
        return model.transformer.ln_f

    # Further layouts; a missing attribute or a None module (e.g. an OPT
    # variant built without a final layer norm) falls through to the error.
    extra_paths = (
        ("gpt_neox", "final_layer_norm"),
        ("model", "decoder", "final_layer_norm"),
        ("model", "final_layernorm"),
    )
    for path in extra_paths:
        obj = model
        for name in path:
            obj = getattr(obj, name, None)
            if obj is None:
                break
        if obj is not None:
            return obj
    raise AttributeError("final norm not found")


def get_lm_head(model):
    """Return the model's output projection onto the vocabulary.

    ``lm_head`` is preferred; ``embed_out`` (the name used e.g. by GPT-NeoX
    models) is the fallback.

    Raises:
        AttributeError: if the model has neither attribute.
    """
    for attr_name in ("lm_head", "embed_out"):
        if hasattr(model, attr_name):
            return getattr(model, attr_name)
    raise AttributeError("lm_head not found")


def candidate_token_ids(tokenizer, key: str) -> List[int]:
    """Return the distinct single-token ids that spell ``key``.

    The spellings tried are ``key`` alone and prefixed by ``" "``, ``"\\n"``
    and ``"\\n "``, each encoded without special tokens. A spelling is kept only
    when it encodes to exactly one token; duplicates are dropped, keeping the
    order of first occurrence.
    """
    single_token_ids: List[int] = []
    for prefix in ("", " ", "\n", "\n "):
        encoded = tokenizer.encode(prefix + key, add_special_tokens=False)
        if len(encoded) == 1 and encoded[0] not in single_token_ids:
            single_token_ids.append(encoded[0])
    return single_token_ids


def choose_token_id(tokenizer, model, device, prompt_ids, key: str) -> int:
    """Pick the single-token spelling of ``key`` the model scores highest.

    Runs one forward pass of ``model`` over ``prompt_ids`` (with a batch axis
    added) and compares the last-position logits of every id returned by
    ``candidate_token_ids``.

    Raises:
        RuntimeError: if ``key`` has no single-token spelling.
    """
    candidates = candidate_token_ids(tokenizer, key)
    if len(candidates) == 0:
        raise RuntimeError(f"No token id for option {key}")
    batch_ids = prompt_ids.unsqueeze(0).to(device)
    with torch.no_grad():
        next_token_logits = model(batch_ids).logits[0, -1]
    return max(candidates, key=lambda token_id: float(next_token_logits[token_id]))


# -----------------------------
# main
# -----------------------------
def parse_args():
    """Parse the command-line options of this script from ``sys.argv``."""
    parser = argparse.ArgumentParser()
    # (flag, keyword arguments) in the order they appear in --help.
    option_specs = [
        ("--model_name_or_path", {"required": True}),
        ("--output_dir", {"required": True}),
        ("--sft_dataset", {"default": "hellaswag"}),
        ("--eval_split", {"default": "validation"}),
        ("--sample_index", {"type": int, "required": True}),
        ("--max_length", {"type": int, "default": 512}),
        ("--dtype", {"default": "bf16", "choices": ["bf16", "fp16", "fp32"]}),
        ("--dpi", {"type": int, "default": 200}),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def _select_batch(dataloader, sample_index: int):
    """Return the collated batch at position ``sample_index`` of ``dataloader``.

    A single iterator is consumed. Calling ``next(iter(dataloader))`` repeatedly
    would restart at the first batch every time and ignore ``sample_index``.

    Raises:
        IndexError: if ``sample_index`` is negative or past the end of the data.
    """
    if sample_index < 0:
        raise IndexError(f"sample_index must be >= 0, got {sample_index}")
    for position, batch in enumerate(dataloader):
        if position == sample_index:
            return batch
    raise IndexError(f"sample_index {sample_index} is out of range for this split")


def _pick_option_token_ids(tokenizer, final_logits, option_keys):
    """Map each option key to the single-token spelling the model prefers.

    Among the spellings from ``candidate_token_ids`` (e.g. ``"A"`` vs ``" A"``),
    the id with the highest final logit at the decision position wins. This
    mirrors ``choose_token_id`` but reuses one forward pass for all keys.

    Raises:
        RuntimeError: if an option has no single-token spelling.
    """
    chosen = {}
    for key in option_keys:
        cand_ids = candidate_token_ids(tokenizer, key)
        if not cand_ids:
            raise RuntimeError(f"No token id for option {key}")
        chosen[key] = max(cand_ids, key=lambda tid: float(final_logits[tid]))
    return chosen


def _correct_argmax_curve(hidden_states, final_logits, final_norm, lm_head,
                          option_token_ids, gold):
    """Return one ``(layer, correct_argmax, pred)`` row per layer.

    ``hidden_states`` is the tuple produced with ``output_hidden_states=True``.
    Index 0 (the embedding output, per HF convention) is skipped. For
    intermediate layers, the decision-position state goes through
    ``final_norm`` and then ``lm_head`` (logit lens).

    For the last layer, the model's own output logits are used instead. In
    common Hugging Face implementations (e.g. Llama, GPT-2), the last hidden
    state is already normalized, so applying ``final_norm`` again would
    double-normalize it.
    """
    num_layers = len(hidden_states) - 1
    rows = []
    # no_grad: avoid building an autograd graph for every vocab-sized logit vector.
    with torch.no_grad():
        for layer in range(1, num_layers + 1):
            if layer == num_layers:
                logits = final_logits
            else:
                logits = lm_head(final_norm(hidden_states[layer][0, -1]))
            option_logits = {
                key: float(logits[token_id]) for key, token_id in option_token_ids.items()
            }
            pred = max(option_logits, key=option_logits.get)
            rows.append((layer, int(pred == gold), pred))
    return rows


def _save_plot(df, png_path, title, dpi):
    """Plot ``correct_argmax`` against ``layer`` and write the figure to ``png_path``."""
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(df["layer"], df["correct_argmax"], marker="o")
    ax.set_ylim(-0.05, 1.05)
    ax.set_yticks([0, 1])
    ax.set_xlabel("Layer")
    ax.set_ylabel("Correct is argmax (0/1)")
    ax.set_title(title)
    ax.grid(True, linestyle="--", alpha=0.5)
    fig.tight_layout()
    fig.savefig(png_path, dpi=dpi)
    plt.close(fig)  # pyplot keeps figures alive until explicitly closed


def main():
    """Run the single-sample logit-lens analysis.

    Loads the model, the tokenizer and the ``--sample_index``-th example of the
    evaluation split, then runs one forward pass over the prompt with hidden
    states. For every layer it records whether the gold option's token has the
    highest logit among the option tokens.

    Writes ``correct_argmax_per_layer.csv`` (columns ``layer``,
    ``correct_argmax``, ``pred``) and ``correct_argmax_vs_layer.png`` into
    ``--output_dir``.

    Raises:
        IndexError: if ``--sample_index`` is outside the split.
        RuntimeError: if the prompt, options or gold answer cannot be parsed.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.dtype]
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=dtype,
    ).to(device).eval()

    dataloader = DataLoader(
        get_sft_dataset(
            name=args.sft_dataset,
            tokenizer=tokenizer,
            max_length=args.max_length,
            split=args.eval_split,
            seed=42,
        ),
        batch_size=1,
        shuffle=False,
        collate_fn=collate_sft,
    )

    batch = _select_batch(dataloader, args.sample_index)
    batch = {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }
    input_ids = batch["input_ids"][0]
    attn = batch["attention_mask"][0]
    labels = batch["labels"][0]

    # Prompt = attended, unsupervised tokens; its last token is the decision
    # position (assumes the prompt precedes the supervised answer).
    prompt_mask = (labels == -100) & (attn == 1)
    prompt_ids = input_ids[prompt_mask]
    if prompt_ids.numel() == 0:
        raise RuntimeError("Selected sample has no prompt tokens")

    prompt_text = tokenizer.decode(prompt_ids.tolist(), skip_special_tokens=False)
    option_keys = extract_option_keys(prompt_text)
    if not option_keys:
        raise RuntimeError(f"No answer options found in prompt: {prompt_text!r}")
    gold = extract_gold_key(tokenizer, input_ids, labels, attn)
    if gold not in option_keys:
        raise RuntimeError(f"Gold answer {gold!r} is not among parsed options {option_keys}")

    final_norm = get_final_norm(model)
    lm_head = get_lm_head(model)

    prompt_batch = prompt_ids.unsqueeze(0)
    with torch.no_grad():
        out = model(
            input_ids=prompt_batch,
            attention_mask=torch.ones_like(prompt_batch),
            output_hidden_states=True,
        )
    final_logits = out.logits[0, -1]

    option_token_ids = _pick_option_token_ids(tokenizer, final_logits, option_keys)
    rows = _correct_argmax_curve(
        out.hidden_states, final_logits, final_norm, lm_head, option_token_ids, gold
    )

    df = pd.DataFrame(rows, columns=["layer", "correct_argmax", "pred"])
    csv_path = os.path.join(args.output_dir, "correct_argmax_per_layer.csv")
    df.to_csv(csv_path, index=False)

    png_path = os.path.join(args.output_dir, "correct_argmax_vs_layer.png")
    _save_plot(
        df,
        png_path,
        f"Sample {args.sample_index}: correct-argmax vs layer (gold={gold})",
        args.dpi,
    )

    print(f"[Saved] {csv_path}")
    print(f"[Saved] {png_path}")
    print("[Done]")


if __name__ == "__main__":
    # Script entry point: main() parses the CLI arguments and runs the analysis.
    main()
