#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Trace token-to-option contribution across layers (single GPU, non-distributed)

- No DeepSpeed
- No torch.distributed
- No MPI dependency
"""

import os
import re
import json
import csv
import argparse
from typing import List, Optional, Dict, Any, Tuple

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from data_utils import get_sft_dataset, collate_sft  # keep your interface

torch.set_num_threads(4)

# -----------------------------
# prompt/label parsing (reuse your robust logic)
# -----------------------------
_OPTION_LINE_RE = re.compile(r"^\s*([A-D])\.\s", re.IGNORECASE)
_OPTION_LINE_NUM_RE = re.compile(r"^\s*([1-9])\.\s")
_ANSWER_KEY_RE = re.compile(r"^\s*([A-D]|[1-9])\s*$", re.IGNORECASE)


def set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs with *seed*."""
    import random
    import numpy as np

    # Same seeding order as before: stdlib, NumPy, torch CPU, torch CUDA.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


def extract_option_keys_from_prompt(prompt_text: str):
    """Return the option keys listed in *prompt_text*, in first-seen order, de-duplicated.

    When the prompt contains "### Options:", only the text from that marker up to (not
    including) "### Answer" is scanned; otherwise the whole prompt is scanned. A line
    matching ``_OPTION_LINE_RE`` ("A. ...", letter upper-cased) or
    ``_OPTION_LINE_NUM_RE`` ("1. ...") contributes its key.
    """
    start = prompt_text.find("### Options:")
    region = prompt_text if start < 0 else prompt_text[start:].split("### Answer", 1)[0]

    # dict keys keep insertion order, giving an order-preserving de-duplication.
    seen_keys = {}
    for line in region.splitlines():
        letter_match = _OPTION_LINE_RE.match(line)
        if letter_match:
            seen_keys.setdefault(letter_match.group(1).upper(), None)
            continue
        number_match = _OPTION_LINE_NUM_RE.match(line)
        if number_match:
            seen_keys.setdefault(number_match.group(1), None)
    return list(seen_keys)


def extract_gold_key_from_labels(tokenizer, input_ids_1d, labels_1d, attn_1d):
    """Decode the supervised (answer) tokens of one example and parse its gold option key.

    Args:
        tokenizer: object exposing ``decode(ids, skip_special_tokens=...)``.
        input_ids_1d: unused; kept for interface compatibility.
        labels_1d: 1-D label tensor; prompt/ignored positions hold ``-100``.
        attn_1d: 1-D attention mask; only positions equal to 1 are considered.

    Returns:
        ``(key, cleaned_text)``. *key* is "A".."D", "1".."9", or ``None`` when no key
        can be parsed. *cleaned_text* is the decoded answer with newlines turned into
        spaces and outer whitespace stripped ("" when there are no answer tokens).
    """
    valid_pos = (labels_1d != -100) & (attn_1d == 1)
    if not valid_pos.any():
        return None, ""

    ans_text = tokenizer.decode(labels_1d[valid_pos].tolist(), skip_special_tokens=False)
    cleaned = ans_text.replace("\n", " ").strip()

    # 1) The whole answer is a single key; a lowercase letter is accepted here.
    m = _ANSWER_KEY_RE.match(cleaned)
    if m:
        # .upper() is a no-op for digits.
        return m.group(1).upper(), cleaned

    # 2) A standalone uppercase A-D anywhere in the answer. Deliberately case-sensitive:
    #    with IGNORECASE the English article "a" (e.g. "a man walks ...") would be
    #    misread as option "A".
    m = re.search(r"\b([A-D])\b", cleaned)
    if m:
        return m.group(1), cleaned

    # 3) A standalone digit 1-9.
    m = re.search(r"\b([1-9])\b", cleaned)
    if m:
        return m.group(1), cleaned

    return None, cleaned


def decode_token_clean(tokenizer, token_id: int) -> str:
    """Decode one token id (special tokens kept) and escape each newline as a literal ``\\n``."""
    text = tokenizer.decode([token_id], skip_special_tokens=False)
    return "\\n".join(text.split("\n"))


def is_abcd_token(token_str: str) -> Optional[str]:
    """Return "A", "B", "C" or "D" when *token_str* is only that letter, else ``None``.

    Escaped newlines (literal ``\\n``), whitespace and the characters ``.:;,`` around
    the letter are ignored; the comparison is case-sensitive.
    """
    core = token_str.replace("\\n", "\n").strip().strip(" .:;,\t\r\n")
    return core if core in ("A", "B", "C", "D") else None


# -----------------------------
# args
# -----------------------------
def parse_args():
    """Parse the command-line options for a single-GPU attention trace.

    ``--force_eager_attn`` is enabled by default; pass ``--no_force_eager_attn`` to
    disable it.
    """
    # The text is a description; passed positionally it would have become `prog`
    # (the program name shown in the usage line).
    p = argparse.ArgumentParser(
        description="Trace option contribution across layers (single GPU)"
    )

    p.add_argument("--model_name_or_path", type=str, required=True)
    p.add_argument("--output_dir", type=str, required=True)

    p.add_argument("--sft_dataset", type=str, default="hellaswag")
    p.add_argument("--max_length", type=int, default=512)
    p.add_argument("--num_eval_samples", type=int, default=None)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--num_workers", type=int, default=2)
    p.add_argument("--eval_split", type=str, default="validation",
                   choices=["train", "validation", "valid", "test"])

    p.add_argument("--trace_sample_index", type=int, default=0)
    p.add_argument("--max_new_tokens_trace", type=int, default=64)

    p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
    # `store_true` with default=True could never be switched off. Pair it with a negative
    # flag writing to the same destination; `--force_eager_attn` keeps working as before.
    p.add_argument("--force_eager_attn", dest="force_eager_attn", action="store_true",
                   help="force eager attention (default)")
    p.add_argument("--no_force_eager_attn", dest="force_eager_attn", action="store_false",
                   help="do not force eager attention")
    p.set_defaults(force_eager_attn=True)

    return p.parse_args()


# -----------------------------
# model + data loading
# -----------------------------
def load_model_and_tokenizer(model_path: str, dtype: str = "bf16", force_eager_attn: bool = True):
    """Load a causal LM and its (fast) tokenizer from *model_path*.

    The tokenizer reuses EOS as its pad token when it has none, and pads on the right.
    The model is loaded with the torch dtype named by *dtype* ("bf16", "fp16" or "fp32")
    and ``device_map=None``. With *force_eager_attn*, eager attention is requested both
    on a pre-loaded config (best effort; failures only print a warning) and through the
    ``attn_implementation`` argument.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tokenizer.pad_token is None:
        # No dedicated pad token: fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

    dtype_by_name = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
    torch_dtype = dtype_by_name[dtype]

    extra_kwargs = {}
    attn_impl = None
    if force_eager_attn:
        attn_impl = "eager"
        try:
            config = AutoConfig.from_pretrained(model_path)
            config._attn_implementation = "eager"
        except Exception as e:
            # Best effort: the attn_implementation kwarg below still asks for eager.
            print(f"[Warn] AutoConfig eager attn setup failed: {e}")
        else:
            extra_kwargs["config"] = config

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        device_map=None,
        attn_implementation=attn_impl,
        **extra_kwargs,
    )
    return model, tokenizer


def build_eval_dataloader(args, tokenizer):
    """Build the evaluation dataset and a sequential, batch-size-1 DataLoader over it.

    Dataset selection (name, split, max length, sample count, seed) comes from *args*.

    Returns:
        ``(dataloader, dataset)``.
    """
    dataset = get_sft_dataset(
        name=args.sft_dataset,
        tokenizer=tokenizer,
        max_length=args.max_length,
        seed=args.seed,
        num_samples=args.num_eval_samples,
        split=args.eval_split,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,  # Key: reproducible order, so the loader position is the global dataset index.
        collate_fn=collate_sft,
        num_workers=args.num_workers,
        # Pinned host memory only speeds up host->GPU copies; on a CPU-only run it is
        # pointless (and may trigger a torch warning), so request it only with CUDA.
        pin_memory=torch.cuda.is_available(),
    )
    return dataloader, dataset


# -----------------------------
# attentions -> csv
# -----------------------------
def attentions_to_rows(
    tokenizer,
    attentions: Tuple[torch.Tensor, ...],
    history_token_ids: List[int],
    query_pos_in_layer: int,
    step_tag: str
) -> List[Dict[str, Any]]:
    """Flatten one query position's attention weights into per-(layer, key token) rows.

    For each layer, takes batch element 0 and query row ``query_pos_in_layer``, then
    reduces over heads to a mean and a max weight per attended key position.

    Args:
        tokenizer: used to render every history token as text.
        attentions: one tensor per layer, indexed ``[batch, head, query, key]``.
        history_token_ids: token ids of the key positions, in order.
        query_pos_in_layer: which query row of each layer tensor to read.
        step_tag: label copied into every row.

    Returns:
        One dict per (layer, key position) with keys ``step_tag``, ``layer``,
        ``token_pos``, ``token_id``, ``token_str``, ``attn_mean`` and ``attn_max``.
        Key positions at or beyond ``len(history_token_ids)`` are skipped.
    """
    token_strs = [decode_token_clean(tokenizer, tid) for tid in history_token_ids]
    rows: List[Dict[str, Any]] = []

    for layer_idx, layer_attn in enumerate(attentions):
        # Upcast before reducing: averaging bf16/fp16 weights over heads loses precision.
        wq = layer_attn[0, :, query_pos_in_layer, :].float()  # [H, K]
        # One device->host transfer per layer instead of two `.item()` syncs per key.
        attn_mean = wq.mean(dim=0).cpu().tolist()              # [K]
        attn_max = wq.max(dim=0).values.cpu().tolist()         # [K]

        num_keys = min(len(attn_mean), len(history_token_ids))
        for pos in range(num_keys):
            rows.append({
                "step_tag": step_tag,
                "layer": layer_idx,
                "token_pos": pos,
                "token_id": int(history_token_ids[pos]),
                "token_str": token_strs[pos],
                "attn_mean": float(attn_mean[pos]),
                "attn_max": float(attn_max[pos]),
            })
    return rows


def write_csv(path: str, rows: List[Dict[str, Any]]):
    """Write *rows* to *path* as UTF-8 CSV, using the first row's keys as the header.

    Does nothing (no file is created) when *rows* is empty. Missing parent directories
    are created.
    """
    if not rows:
        return
    parent = os.path.dirname(path)
    # A bare filename has dirname "", and os.makedirs("") raises FileNotFoundError.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)


def get_nth_batch(dataloader, n: int):
    """Return the n-th item (0-based) yielded by iterating *dataloader*.

    Items before index *n* are still produced and discarded, so the cost grows with *n*.

    Raises:
        IndexError: if *n* is negative or the loader yields fewer than ``n + 1`` items.
    """
    import itertools

    if n < 0:
        raise IndexError(f"batch index must be non-negative, got {n}")
    missing = object()  # sentinel, so a legitimately-None item is still returned
    batch = next(itertools.islice(dataloader, n, None), missing)
    if batch is missing:
        # Previously a bare StopIteration escaped here, which is confusing for callers.
        raise IndexError(f"batch index {n} out of range")
    return batch


# -----------------------------
# core trace
# -----------------------------
def _decode_until_option(model, tokenizer, past, first_input_id: int, prompt_len: int,
                         max_steps: int, device):
    """Greedy-decode after the prompt, one token per step, until an option letter is predicted.

    Args:
        model: causal LM called with ``input_ids``/``attention_mask``/``past_key_values``.
        tokenizer: used for the option check and the debug prints.
        past: KV cache returned by the prefill forward pass.
        first_input_id: first token fed (the prefill's greedy prediction).
        prompt_len: number of prompt tokens already held in *past*.
        max_steps: maximum number of decode steps.
        device: device for the per-step input tensors.

    Returns:
        ``(step, option_key, attentions, fed_ids)`` for the first step whose greedy
        prediction ``is_abcd_token`` accepts. *attentions* are that step's per-layer
        weights (a single query row) and *fed_ids* are all tokens fed after the prompt,
        ending with that step's query token. ``None`` if no option letter appears.
    """
    fed_ids: List[int] = []
    curr_input_id = first_input_id

    for step in range(max_steps):
        inp = torch.tensor([[curr_input_id]], device=device, dtype=torch.long)
        # HF convention: with a KV cache the mask spans past + current tokens.
        attn = torch.ones((1, prompt_len + step + 1), dtype=torch.long, device=device)
        # Attentions are requested on every step (one query row per layer is cheap), so
        # the winning step needs no second forward pass. Re-running a step against the
        # "pre-step" cache is unsafe: transformers cache objects such as DynamicCache are
        # updated in place, so the token's keys/values would be appended twice.
        with torch.no_grad():
            out = model(
                input_ids=inp,
                attention_mask=attn,
                past_key_values=past,
                use_cache=True,
                output_attentions=True,
            )
        fed_ids.append(curr_input_id)

        pred_id = int(torch.argmax(out.logits[:, -1, :], dim=-1).item())
        pred_tok_str = decode_token_clean(tokenizer, pred_id)
        pred_key = is_abcd_token(pred_tok_str)
        if pred_key is not None:
            return step, pred_key, out.attentions, fed_ids

        past = out.past_key_values
        curr_input_id = pred_id

        if step < 5:
            print(f"[Trace] step={step} | fed={decode_token_clean(tokenizer, fed_ids[-1])!r} "
                  f"| next_pred={pred_tok_str!r}")

    return None


def run_trace_once(args, model, tokenizer, dataset, dataloader, device):
    """Trace which history tokens the model attends to when it emits an option letter.

    1. Fetch sample ``args.trace_sample_index``; its prompt is every attended position
       whose label is -100.
    2. Prefill the prompt. If the greedy next token is A/B/C/D, trace the last prompt
       position. Otherwise greedy-decode for up to ``args.max_new_tokens_trace`` steps
       and trace the step that first predicts A/B/C/D.
    3. If no option letter appears, fall back to tracing the last prompt position (the
       one that predicted the first generated token).

    Writes ``trace_token_contrib.csv`` and ``trace_summary.json`` into ``args.output_dir``.
    In the summary, ``found`` is True only when an A/B/C/D token was actually predicted;
    ``used_fallback`` marks the fallback path.

    Raises:
        ValueError: if ``args.trace_sample_index`` is outside the dataset.
    """
    model.eval()

    if args.trace_sample_index < 0 or args.trace_sample_index >= len(dataset):
        raise ValueError(f"trace_sample_index out of range: {args.trace_sample_index} (dataset size={len(dataset)})")

    batch = get_nth_batch(dataloader, args.trace_sample_index)
    # Move tensors only; any non-tensor entries from the collate function stay as-is.
    batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}

    input_ids = batch["input_ids"][0]            # [T]
    attention_mask = batch["attention_mask"][0]  # [T]
    labels = batch["labels"][0]                  # [T]

    # Prompt = attended positions without supervision (label -100).
    prompt_mask = (labels == -100) & (attention_mask == 1)
    prompt_ids = input_ids[prompt_mask]          # [Lp]
    if prompt_ids.numel() == 0:
        print("[Trace] Empty prompt_ids; skip.")
        return
    prompt_list = prompt_ids.tolist()
    prompt_len = len(prompt_list)

    prompt_text = tokenizer.decode(prompt_list, skip_special_tokens=False)
    option_keys = extract_option_keys_from_prompt(prompt_text)
    if len(option_keys) == 0:
        option_keys = ["A", "B", "C", "D"]

    gold_key, gold_raw = extract_gold_key_from_labels(tokenizer, input_ids, labels, attention_mask)

    print("=" * 100)
    print(f"[Trace] Dataset={args.sft_dataset} | split={args.eval_split} | index={args.trace_sample_index}")
    print(f"[Trace] Detected option keys: {option_keys}")
    print(f"[Trace] Gold key: {gold_key!r} (raw={gold_raw!r})")
    print("-" * 100)
    print(prompt_text)
    print("=" * 100)

    rows_all: List[Dict[str, Any]] = []
    found_step_tag = None
    used_fallback = False

    # Prefill: one pass over the whole prompt, keeping attentions and the KV cache.
    prefill_ids = prompt_ids.unsqueeze(0)  # [1, Lp]
    prefill_attn = torch.ones_like(prefill_ids, dtype=torch.long, device=device)
    with torch.no_grad():
        out_prefill = model(
            input_ids=prefill_ids,
            attention_mask=prefill_attn,
            use_cache=True,
            output_attentions=True,
        )

    next_id = int(torch.argmax(out_prefill.logits[:, -1, :], dim=-1).item())
    next_tok_str = decode_token_clean(tokenizer, next_id)
    found_pred_key = is_abcd_token(next_tok_str)
    found = found_pred_key is not None

    if found:
        found_step_tag = "prefill_predict_option"
        # Query = last prompt token, whose logits produced the option letter.
        rows_all.extend(attentions_to_rows(
            tokenizer=tokenizer,
            attentions=out_prefill.attentions,
            history_token_ids=prompt_list,
            query_pos_in_layer=prompt_len - 1,
            step_tag=found_step_tag,
        ))
        print(f"[Trace] Predicted option immediately after prompt: {found_pred_key!r}")
        print(f"[Trace] Gold: {gold_key!r}")
    else:
        print(f"[Trace] First generated token after prompt (not option): {next_tok_str!r}")
        hit = _decode_until_option(
            model,
            tokenizer,
            past=out_prefill.past_key_values,
            first_input_id=next_id,
            prompt_len=prompt_len,
            max_steps=args.max_new_tokens_trace,
            device=device,
        )
        if hit is not None:
            step, found_pred_key, step_attentions, fed_ids = hit
            found = True
            found_step_tag = f"decode_step_{step}_predict_option"
            # A decode step has a single query row (index 0) over prompt + fed tokens.
            rows_all.extend(attentions_to_rows(
                tokenizer=tokenizer,
                attentions=step_attentions,
                history_token_ids=prompt_list + fed_ids,
                query_pos_in_layer=0,
                step_tag=found_step_tag,
            ))
            print(f"[Trace] Predicted option at decode step {step}: {found_pred_key!r}")
            print(f"[Trace] Gold: {gold_key!r}")

    # -------------------------
    # Fallback: no A/B/C/D found -> trace the prefill's last prompt position, i.e. the
    # query that produced the FIRST generated token.
    # -------------------------
    if not found:
        used_fallback = True
        found_pred_key = next_tok_str  # raw text of the first generated token (not required to be A/B/C/D)
        found_step_tag = "fallback_first_generated_token"

        rows_all.extend(attentions_to_rows(
            tokenizer=tokenizer,
            attentions=out_prefill.attentions,
            history_token_ids=prompt_list,
            query_pos_in_layer=prompt_len - 1,
            step_tag=found_step_tag,
        ))

        print(f"[Trace][Fallback] No A/B/C/D found within max_new_tokens_trace={args.max_new_tokens_trace}.")
        print(f"[Trace][Fallback] Use FIRST generated token after prompt: {found_pred_key!r}")
        print(f"[Trace][Fallback] Gold: {gold_key!r}")

    # Written after the fallback so the summary describes the rows actually saved.
    os.makedirs(args.output_dir, exist_ok=True)
    summary_path = os.path.join(args.output_dir, "trace_summary.json")
    summary = {
        "dataset": args.sft_dataset,
        "split": args.eval_split,
        "index": args.trace_sample_index,
        "gold_key": gold_key,
        "gold_raw_decoded": gold_raw,
        "pred_key": found_pred_key,
        "found_step_tag": found_step_tag,
        "found": bool(found),
        "used_fallback": used_fallback,
        "num_rows": len(rows_all),
    }
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    csv_path = os.path.join(args.output_dir, "trace_token_contrib.csv")
    write_csv(csv_path, rows_all)
    print(f"[Trace] Saved CSV: {csv_path}")
    print(f"[Trace] Saved summary: {summary_path}")


def main():
    """Script entry point: parse args, load the model/tokenizer and data, run one trace."""
    # NOTE(review): cuDNN is switched off process-wide; the reason is not visible here.
    torch.backends.cudnn.enabled = False

    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    banner = ["====== Trace args (single GPU) ======"]
    banner.extend(f"{key}: {value}" for key, value in vars(args).items())
    banner.append("====================================")
    print("\n".join(banner))

    set_seed(args.seed)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    print(f"[Load] model from {args.model_name_or_path}")
    model, tokenizer = load_model_and_tokenizer(
        args.model_name_or_path,
        dtype=args.dtype,
        force_eager_attn=args.force_eager_attn,
    )
    model.to(device)

    dataloader, dataset = build_eval_dataloader(args, tokenizer)
    print(f"[Data] dataset size: {len(dataset)}")

    run_trace_once(args, model, tokenizer, dataset, dataloader, device)


if __name__ == "__main__":
    # Run the trace only when executed as a script, not on import.
    main()
