#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Plot per-layer mean probability of the correct option (gold) for multiple tasks
in ONE figure (Logit Lens).

For each task:
- Build prompt (ARC/BoolQ/PIQA/Wino/Hella)
- For each layer l: hidden_states[l] -> (optional norm) -> lm_head -> last token logits
- Softmax over candidate option tokens only (A/B or A/B/C/D)
- Record p_gold(l)
Aggregate:
- mean_p_gold[l] = average over samples

Outputs:
- mean_p_gold_multi_tasks.png
- mean_p_gold_multi_tasks.json (optional)
"""

import os
import json
import argparse
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from matplotlib.ticker import MaxNLocator



# -------------------------
# Style (match your ref style)
# -------------------------
def set_global_style(font_size: int):
    """Apply the shared font sizes to matplotlib's global ``rcParams``.

    ``font_size`` is the base size for body text and axis labels; titles are
    two points larger, while tick labels and legend text are one point smaller.
    """
    base = font_size
    sizes = {
        "font.size": base,
        "axes.titlesize": base + 2,
        "axes.labelsize": base,
        "xtick.labelsize": base - 1,
        "ytick.labelsize": base - 1,
        "legend.fontsize": base - 1,
    }
    plt.rcParams.update(sizes)

# Qualitative colormap used to give each plotted task its own line color.
tab20 = plt.cm.tab20


# -------------------------
# Prompt builders (same as your margin code)
# -------------------------
def build_prompt_arc(question: str, choice_texts: List[str], choice_labels: List[str]) -> str:
    """Render an ARC-style multiple-choice question as an instruction prompt.

    Each option is printed as ``"<label>. <text>"`` on its own line; labels and
    texts are paired with ``zip``, so any surplus on either side is ignored.
    The prompt ends with ``"### Answer:"`` and no trailing newline.
    """
    option_block = "\n".join(
        "{}. {}".format(label, text) for label, text in zip(choice_labels, choice_texts)
    )
    sections = [
        "### Task:\nChoose the best answer to the following question.",
        f"### Question:\n{question}",
        f"### Options:\n{option_block}",
        "### Answer:",
    ]
    return "\n\n".join(sections)

def build_prompt_boolq(passage: str, question: str) -> str:
    """Render a BoolQ passage/question pair as a two-option (A=Yes, B=No) prompt.

    The prompt ends with ``"### Answer:"`` and no trailing newline.
    """
    sections = (
        "### Task:\nRead the following passage and only answer the Yes/No question based on it.",
        f"### Passage:\n{passage}",
        f"### Question:\n{question}",
        "### Options:\nA. Yes\nB. No",
        "### Answer:",
    )
    return "\n\n".join(sections)

def build_prompt_piqa(goal: str, choices: List[str]) -> str:
    """Render a PIQA goal and its two candidate solutions as an A/B prompt.

    Args:
        goal: the goal text.
        choices: exactly two solution texts; ``choices[0]`` becomes option A and
            ``choices[1]`` option B.

    Returns:
        The prompt string, ending with ``"### Answer:"`` (no trailing newline).

    Raises:
        ValueError: if ``choices`` does not contain exactly two entries. (An
            explicit check instead of ``assert``, which is stripped under ``-O``.)
    """
    if len(choices) != 2:
        raise ValueError(f"PIQA expects exactly 2 choices, got {len(choices)}")
    sections = (
        "### Task:\nChoose the most physically plausible solution to achieve the goal.",
        f"### Goal:\n{goal}",
        f"### Options:\nA. {choices[0]}\nB. {choices[1]}",
        "### Answer:",
    )
    return "\n\n".join(sections)

def build_prompt_winogrande(sentence: str, option1: str, option2: str) -> str:
    """Render a Winogrande fill-in-the-blank item as an A/B prompt.

    ``option1`` is shown as A and ``option2`` as B; the prompt ends with
    ``"### Answer:"`` and no trailing newline.
    """
    parts = [
        "### Task:\nChoose the correct option to fill in the blank (\"_\") in the sentence.",
        f"### Sentence:\n{sentence}",
        "### Options:\n" + f"A. {option1}" + "\n" + f"B. {option2}",
        "### Answer:",
    ]
    return "\n\n".join(parts)

def build_prompt_hellaswag(ctx: str, endings: List[str]) -> str:
    """Render a HellaSwag context and its candidate endings as an A-D prompt.

    Endings are labelled A, B, C, D in order; because they are paired with
    ``zip``, any ending beyond the fourth is dropped.
    """
    lines = []
    for letter, ending in zip("ABCD", endings):
        lines.append(f"{letter}. {ending}")
    sections = [
        "### Task:\nChoose the most plausible continuation of the following context.",
        f"### Context:\n{ctx}",
        "### Options:\n" + "\n".join(lines),
        "### Answer:",
    ]
    return "\n\n".join(sections)


# -------------------------
# Utilities
# -------------------------
def normalize_gold(gold_raw: Any, letters: List[str]) -> str:
    """Map a raw gold label onto one of the candidate ``letters``.

    Accepted forms:
      * a letter already in ``letters`` (case-insensitive, surrounding
        whitespace ignored), e.g. ``"b"`` -> ``"B"``;
      * a non-negative integer / digit string ``n``: read as a 0-based index
        when ``n < len(letters)``; ``n == len(letters)`` is read as a 1-based
        index of the last option (e.g. ``2`` -> ``"B"`` for two letters,
        ``4`` -> ``"D"`` for four).

    Values valid under both conventions resolve 0-based (``"1"`` -> ``"B"``).
    The result is always an element of ``letters`` (the previous version
    hard-coded A-D for non-binary sets and could return e.g. ``"D"`` for a
    three-letter set).

    Raises:
        ValueError: if ``gold_raw`` cannot be mapped into ``letters``.
    """
    g = str(gold_raw).strip()
    g_up = g.upper()
    if g_up in letters:
        return g_up

    # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
    if g.isdecimal():
        n = int(g)
        k = len(letters)
        if n < k:
            return letters[n]
        if k > 0 and n == k:
            return letters[-1]

    raise ValueError(f"Cannot normalize gold '{gold_raw}' to letters={letters}")

def get_num_layers_llama(model: AutoModelForCausalLM) -> int:
    """Return the number of decoder layers of a Llama-style causal LM.

    Counts ``model.model.layers`` when present; otherwise falls back to
    ``model.config.num_hidden_layers``.

    Raises:
        RuntimeError: if neither source is available.
    """
    backbone = getattr(model, "model", None)
    if backbone is not None and hasattr(backbone, "layers"):
        return len(backbone.layers)

    cfg = getattr(model, "config", None)
    if cfg is not None and hasattr(cfg, "num_hidden_layers"):
        return int(cfg.num_hidden_layers)

    raise RuntimeError("Cannot determine number of layers from model.")

def _single_token_id(tokenizer: AutoTokenizer, text: str) -> Optional[int]:
    """Return the token id of ``text`` if it encodes (without special tokens) to
    exactly one token; otherwise return ``None``."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return token_ids[0] if len(token_ids) == 1 else None

def choose_letter_token_ids(tokenizer: AutoTokenizer, letters: List[str]) -> Dict[str, int]:
    """Resolve each option letter to a single vocabulary token id.

    The space-prefixed form (``" A"``) is tried first, then the bare letter
    (``"A"``); the first variant that encodes to exactly one token wins.

    Raises:
        ValueError: if neither variant of some letter is a single token.
    """
    mapping: Dict[str, int] = {}
    for letter in letters:
        # Generator keeps the evaluation lazy: the bare letter is only encoded
        # when the space-prefixed variant is not a single token.
        candidates = (_single_token_id(tokenizer, variant) for variant in (" " + letter, letter))
        token_id = next((tid for tid in candidates if tid is not None), None)
        if token_id is None:
            raise ValueError(
                f"Cannot find single-token id for letter '{letter}' or ' {letter}'. "
                f"Try a different template or use option-text scoring."
            )
        mapping[letter] = token_id
    return mapping


def _boolq_answer_is_yes(answer: Any) -> bool:
    """Interpret a BoolQ ``answer`` value (bool/int, or a string like "true"/"no").

    Strings are parsed explicitly because ``bool("False")`` is ``True``.

    Raises:
        ValueError: for an unrecognised string value.
    """
    if isinstance(answer, str):
        text = answer.strip().lower()
        if text in ("true", "yes", "1"):
            return True
        if text in ("false", "no", "0"):
            return False
        raise ValueError(f"Cannot interpret BoolQ answer {answer!r}")
    return bool(answer)


def iter_samples(task: str, parquet_path: str, split: str, limit: Optional[int], seed: int):
    """Yield ``(sample_id, prompt, gold_letter, letters)`` for one task's parquet file.

    Args:
        task: case-insensitive task name: arc_easy / arceasy / arc_challenge /
            arcchallenge, boolq, piqa, winogrande / wino, hellaswag / hellas.
        parquet_path: file passed to ``load_dataset("parquet", data_files=...)``.
        split: dataset split to load.
        limit: if not None and smaller than the dataset, a seeded random subset
            of this many rows is used.
        seed: seed for the subset permutation.

    Yields:
        ``sample_id`` (may be None), the prompt string, the gold option letter,
        and the list of candidate letters the prompt shows.

    Raises:
        ValueError: for an unsupported task (on first iteration, before the
            dataset is loaded) or for a gold label that cannot be normalized.
    """
    task = task.lower()
    arc_tasks = ("arc_easy", "arceasy", "arc_challenge", "arcchallenge")
    supported = arc_tasks + ("boolq", "piqa", "winogrande", "wino", "hellaswag", "hellas")
    if task not in supported:
        raise ValueError(f"Unsupported task: {task}")

    ds = load_dataset("parquet", data_files=parquet_path, split=split)

    if limit is not None and limit < len(ds):
        g = torch.Generator().manual_seed(seed)
        perm = torch.randperm(len(ds), generator=g).tolist()
        ds = ds.select(perm[:limit])

    for item in ds:
        if task in arc_tasks:
            # ARC rows may have a variable number of choices and may label them
            # with digits ("1".."4") instead of letters. Re-label options by
            # position as A, B, C, ... so the prompt shows exactly the letters that
            # are scored, and locate answerKey by its position in the original
            # label list (a digit key is a label, not a 0-based index).
            raw_labels = [str(lab).strip() for lab in item["choices"]["label"]]
            letters = [chr(ord("A") + i) for i in range(len(raw_labels))]
            answer_key = str(item["answerKey"]).strip()
            if answer_key in raw_labels:
                gold = letters[raw_labels.index(answer_key)]
            else:
                gold = normalize_gold(answer_key, letters)
            prompt = build_prompt_arc(
                question=item["question"],
                choice_texts=item["choices"]["text"],
                choice_labels=letters,
            )
            sid = item.get("id", None)
            yield sid, prompt, gold, letters

        elif task == "boolq":
            letters = ["A", "B"]
            prompt = build_prompt_boolq(item["passage"], item["question"])
            gold = "A" if _boolq_answer_is_yes(item["answer"]) else "B"
            sid = item.get("id", None)
            yield sid, prompt, gold, letters

        elif task == "piqa":
            letters = ["A", "B"]
            prompt = build_prompt_piqa(item["question"], item["choices"])
            if "answer_index" in item and item["answer_index"] is not None:
                gold = "A" if int(item["answer_index"]) == 0 else "B"
            else:
                gold = str(item["answer"]).strip().upper()
            gold = normalize_gold(gold, letters)
            sid = item.get("id", None)
            yield sid, prompt, gold, letters

        elif task in ("winogrande", "wino"):
            letters = ["A", "B"]
            prompt = build_prompt_winogrande(item["sentence"], item["option1"], item["option2"])
            ans = str(item["answer"]).strip()
            # Winogrande answers are "1"/"2" (1-based), so map them explicitly.
            if ans in ("1", "2"):
                gold = "A" if ans == "1" else "B"
            else:
                gold = normalize_gold(ans, letters)
            sid = item.get("id", None)
            yield sid, prompt, gold, letters

        else:  # hellaswag / hellas
            letters = ["A", "B", "C", "D"]
            prompt = build_prompt_hellaswag(item["ctx"], item["endings"])
            lab = str(item["label"]).strip()
            gold = normalize_gold(lab, letters)
            sid = item.get("ind", item.get("id", None))
            yield sid, prompt, gold, letters


def _truncate_keep_tail(
    input_ids: torch.Tensor, max_length: int, bos_token_id: Optional[int]
) -> torch.Tensor:
    """Trim a ``[1, S]`` id tensor to at most ``max_length`` tokens, dropping from the LEFT.

    The prompt ends with the ``### Answer:`` cue whose next-token distribution is
    read, so the tail must survive truncation. A leading BOS token is preserved.
    A non-positive ``max_length`` disables truncation.
    """
    if max_length is None or max_length <= 0 or input_ids.shape[1] <= max_length:
        return input_ids
    starts_with_bos = (
        bos_token_id is not None
        and max_length > 1
        and int(input_ids[0, 0].item()) == int(bos_token_id)
    )
    if starts_with_bos:
        return torch.cat([input_ids[:, :1], input_ids[:, -(max_length - 1):]], dim=1)
    return input_ids[:, -max_length:]


@torch.no_grad()
def per_layer_gold_prob_curve(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    prompt: str,
    gold: str,
    letters: List[str],
    letter_token_ids: Dict[str, int],
    device: torch.device,
    max_length: int,
) -> Tuple[np.ndarray, int]:
    """Logit-lens curve of P(gold) over decoder layers for a single prompt.

    For each layer l in 1..L-1, the hidden state of the LAST prompt token is
    passed through the final norm (when ``model.model.norm`` exists) and
    ``lm_head``. For layer L the model's own output logits are used: HF
    Llama-style models already return the last hidden state post-norm, so
    re-applying the norm would distort the final-layer value. P(gold) is a
    softmax restricted to the candidate letter tokens.

    Prompts longer than ``max_length`` tokens are truncated from the left so the
    trailing answer cue is kept.

    Returns:
      p_gold_curve: float32 array of shape [L] for layers 1..L
      L: number of decoder layers

    Raises:
      RuntimeError: if the model returns fewer than L + 1 hidden states.
      ValueError: if ``gold`` is not one of ``letters``.
      KeyError: if a letter is missing from ``letter_token_ids``.
    """
    input_ids = tok(prompt, return_tensors="pt", padding=False)["input_ids"]
    input_ids = _truncate_keep_tail(input_ids, max_length, tok.bos_token_id).to(device)
    attention_mask = torch.ones_like(input_ids)  # single unpadded sequence

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        use_cache=False,
    )
    hidden_states = outputs.hidden_states
    L = get_num_layers_llama(model)

    if len(hidden_states) < L + 1:
        raise RuntimeError(
            f"Model returned {len(hidden_states)} hidden states; expected at least {L + 1}."
        )
    if len(hidden_states) != L + 1:
        hidden_states = hidden_states[-(L + 1):]

    cand_letters = list(letters)
    cand_ids = torch.tensor([letter_token_ids[x] for x in cand_letters], device=device, dtype=torch.long)
    gold_pos = cand_letters.index(gold)

    final_norm = getattr(getattr(model, "model", None), "norm", None)

    p_list = []
    for l in range(1, L + 1):
        if l == L:
            last = outputs.logits[0, -1, :]                   # [V], true model output
        else:
            # Norm and lm_head act per position, so project only the last token.
            hs = hidden_states[l][:, -1:, :]                  # [B,1,H]
            if final_norm is not None:
                hs = final_norm(hs)
            last = model.lm_head(hs)[0, -1, :]                # [V]
        cand_logits = last.index_select(0, cand_ids)          # [K]
        probs = torch.softmax(cand_logits.float(), dim=-1)
        p_list.append(float(probs[gold_pos].item()))

    return np.array(p_list, dtype=np.float32), L


def compute_mean_curve_for_task(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    task: str,
    parquet: str,
    split: str,
    limit: Optional[int],
    seed: int,
    device: torch.device,
    max_length: int,
    verbose_every: int = 50,
) -> Tuple[np.ndarray, int, int]:
    """Average the per-layer P(gold) curve over all samples of one task.

    Letter token ids are resolved once per distinct candidate-letter set and
    reused. Per-sample curves are truncated to the shortest length seen before
    averaging, so stacking never fails on mismatched lengths.

    Returns:
      mean_curve: [L] mean P(gold) per layer
      L: length of mean_curve
      N: number of samples averaged

    Raises:
      RuntimeError: if the task yields no samples.
    """
    token_ids_by_letters: Dict[Tuple[str, ...], Dict[str, int]] = {}
    curves: List[np.ndarray] = []

    for _sid, prompt, gold, letters in iter_samples(task, parquet, split, limit, seed):
        key = tuple(letters)
        letter_token_ids = token_ids_by_letters.get(key)
        if letter_token_ids is None:
            letter_token_ids = choose_letter_token_ids(tok, letters)
            token_ids_by_letters[key] = letter_token_ids
            print(f"[TokenID][{task}] letters={letters} ids={letter_token_ids}")

        p_curve, _num_layers = per_layer_gold_prob_curve(
            model=model,
            tok=tok,
            prompt=prompt,
            gold=gold,
            letters=letters,
            letter_token_ids=letter_token_ids,
            device=device,
            max_length=max_length,
        )
        curves.append(p_curve)

        done = len(curves)
        if verbose_every > 0 and done % verbose_every == 0:
            print(f"[Progress][{task}] {done} samples")

    total = len(curves)
    if total == 0:
        raise RuntimeError(f"[{task}] No samples processed. Check parquet/split/task.")

    # Align every curve (not just later ones) to the shortest, so np.stack works.
    L_common = min(len(c) for c in curves)
    mat = np.stack([c[:L_common] for c in curves], axis=0)  # [N, L_common]
    mean_curve = mat.mean(axis=0)
    return mean_curve, int(mean_curve.shape[0]), total


def plot_multi_task_curves(out_dir: str, curves: Dict[str, np.ndarray], title: str, dpi: int):
    """Plot every task's mean P(gold) curve on one shared axis and save a PNG.

    All curves are truncated to the shortest one, so every task shares the
    x-axis of layer ids 1..L_common. The PNG is written to
    ``<out_dir>/mean_p_gold_multi_tasks.png`` (``out_dir`` is created if needed).

    Returns:
        (out_png, L_common): path of the saved PNG and the number of layers plotted.

    Raises:
        ValueError: if ``curves`` is empty (checked before anything is created).
    """
    if not curves:
        raise ValueError("plot_multi_task_curves: `curves` is empty; nothing to plot.")

    os.makedirs(out_dir, exist_ok=True)

    L_common = int(min(len(v) for v in curves.values()))
    x = np.arange(1, L_common + 1)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

        for i, (task, curve) in enumerate(curves.items()):
            y = curve[:L_common]
            col = tab20(i % 20)
            ax.plot(x, y, linewidth=2.8, color=col, label=task)
            ax.scatter(x, y, s=38, color=col, edgecolor="white", linewidth=0.6, zorder=5)

        ax.set_xlabel("Layer ID")
        ax.set_ylabel("Mean Probability of Gold")
        ax.set_title(title)
        ax.set_ylim(0.0, 1.0)
        ax.legend()

        fig.tight_layout()
        out_png = os.path.join(out_dir, "mean_p_gold_multi_tasks.png")
        fig.savefig(out_png, dpi=dpi)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return out_png, L_common


def parse_task_parquet_items(items: List[str]) -> List[Tuple[str, str]]:
    """Parse repeated ``--task_parquet`` values into ``(task, parquet_path)`` pairs.

    Each item looks like ``"arc_challenge=/path/a.parquet"``. Only the first
    ``=`` separates the task from the path, and both sides are stripped of
    whitespace.

    Raises:
        ValueError: if an item has no ``=`` or either side is empty.
    """
    parsed: List[Tuple[str, str]] = []
    for raw in items:
        name, sep, location = raw.partition("=")
        if not sep:
            raise ValueError(f"--task_parquet item must be 'task=parquet_path', got: {raw}")
        name, location = name.strip(), location.strip()
        if not name or not location:
            raise ValueError(f"Bad --task_parquet item: {raw}")
        parsed.append((name, location))
    return parsed


def main():
    """CLI entry point.

    Parses arguments, loads the model and tokenizer, computes one mean P(gold)
    curve per ``--task_parquet`` item, saves the combined plot and optionally
    writes a JSON file with the curves.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, required=True, help="HF model path (dense or pruned).")

    # multi tasks
    ap.add_argument(
        "--task_parquet",
        type=str,
        action="append",
        required=True,
        help="Repeatable. Format: task_name=/path/to/data.parquet"
    )

    ap.add_argument("--split", type=str, default="train")
    ap.add_argument(
        "--limit",
        type=int,
        default=500,
        help="Max samples per task (seeded random subset). <= 0 means use all samples.",
    )
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--max_length", type=int, default=1024)
    ap.add_argument("--device", type=str, default="cuda:0")
    ap.add_argument("--out_dir", type=str, required=True)

    # style (match your ref)
    ap.add_argument("--font_size", type=int, default=14)
    ap.add_argument("--dpi", type=int, default=250)

    ap.add_argument("--save_json", action="store_true", help="Also save curves to json.")
    ap.add_argument("--title", type=str, default=None, help="Optional custom plot title.")
    args = ap.parse_args()

    set_global_style(args.font_size)
    device = torch.device(args.device)

    try:
        task_items = parse_task_parquet_items(args.task_parquet)
    except ValueError as e:
        ap.error(str(e))
    print("[Tasks]", task_items)

    # Curves are keyed by task name, so a repeated name would silently overwrite
    # an earlier (expensive) run. Reject it before loading the model.
    seen = set()
    duplicates = []
    for name, _ in task_items:
        if name in seen and name not in duplicates:
            duplicates.append(name)
        seen.add(name)
    if duplicates:
        ap.error(f"duplicate task name(s) in --task_parquet: {', '.join(duplicates)}")

    # A non-positive --limit means "use every sample".
    limit = args.limit if args.limit > 0 else None

    tok = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float16 if "cuda" in args.device else torch.float32,
        device_map=None,
    ).to(device).eval()

    curves = {}
    meta = {}

    for task, parquet in task_items:
        print("=" * 80)
        print(f"[Run] task={task} parquet={parquet}")
        mean_curve, L, N = compute_mean_curve_for_task(
            model=model,
            tok=tok,
            task=task,
            parquet=parquet,
            split=args.split,
            limit=limit,
            seed=args.seed,
            device=device,
            max_length=args.max_length,
            verbose_every=50,
        )
        curves[task] = mean_curve
        meta[task] = {"L": int(L), "N": int(N), "parquet": parquet}
        print(f"[Done] task={task} L={L} N={N}")

    title = args.title if args.title is not None else "Mean P(gold) vs Layer"

    out_png, L_common = plot_multi_task_curves(
        out_dir=args.out_dir,
        curves=curves,
        title=title,
        dpi=args.dpi,
    )

    if args.save_json:
        out_json = os.path.join(args.out_dir, "mean_p_gold_multi_tasks.json")
        payload = {
            "model": args.model,
            "split": args.split,
            "limit": args.limit,
            "seed": args.seed,
            "max_length": args.max_length,
            "device": args.device,
            "L_common": int(L_common),
            "tasks": meta,
            "curves": {k: [float(x) for x in v.tolist()] for k, v in curves.items()},
        }
        with open(out_json, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)
        print("json:", out_json)

    print("png:", out_png)
    print("[Saved]", args.out_dir)


if __name__ == "__main__":
    main()
