#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import re
import json
import argparse
import random
from typing import List, Dict, Optional

import numpy as np
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# ======================= Configuration =======================
# Number of depth bins each per-question dispersion trajectory is resampled to
# (passed as L_target to resample_trajs_by_depth); also embedded in the output
# directory name "{dataset}_{pct}_{RESAMPLE_BINS}".
RESAMPLE_BINS = 25

# Default Hugging Face dataset ids processed by main() (override with
# --datasets). Every id must be supported by load_qa_dataset, which raises
# ValueError for unsupported names. Blank lines only group entries visually.
DATASETS = [
    "openai/gsm8k",
    "google-research-datasets/mbpp",
    "allenai/openbookqa",
    "allenai/ai2_arc",
    "allenai/math_qa",

    "stanfordnlp/coqa",
    "EleutherAI/hendrycks_math",

    "newfacade/LeetCodeDataset",
]

# Default Hugging Face model ids run for every dataset (override with
# --models). Blank lines only group entries visually.
MODELS = [
   "Qwen/Qwen2.5-1.5B-Instruct",
   "Qwen/Qwen2.5-Math-1.5B-Instruct",
   "Qwen/Qwen2.5-Coder-1.5B-Instruct",
   "meta-llama/Llama-3.2-1B-Instruct",
   "tiiuae/Falcon3-1B-Instruct",

   "Qwen/Qwen2.5-7B-Instruct",
   "Qwen/Qwen2.5-Math-7B-Instruct",
   "Qwen/Qwen2.5-Coder-7B-Instruct",
   "meta-llama/Llama-3.1-8B-Instruct",
   "tiiuae/Falcon3-7B-Instruct",

   "amd/AMD-OLMo-1B-SFT",
   "allenai/OLMo-2-0425-1B-Instruct",
   "google/gemma-3-1b-it",
   "deepseek-ai/deepseek-coder-1.3b-instruct",

   "deepseek-ai/deepseek-llm-7b-chat",
   "deepseek-ai/deepseek-math-7b-instruct",
   "deepseek-ai/deepseek-coder-6.7b-instruct",
   "allenai/OLMo-2-1124-7B-Instruct",
]

# Minimal Jinja chat template that ensure_chat_template installs on tokenizers
# without one. Each system/user/assistant message is rendered as a
# "Role: content" line; messages with any other role produce no output. There
# is no add_generation_prompt handling, so no trailing "Assistant:" turn is
# ever appended.
DEFAULT_CHAT_TEMPLATE = (
    "{% for m in messages %}"
    "{% if m['role'] == 'system' %}System: {{ m['content'] }}\n"
    "{% elif m['role'] == 'user' %}User: {{ m['content'] }}\n"
    "{% elif m['role'] == 'assistant' %}Assistant: {{ m['content'] }}\n"
    "{% endif %}"
    "{% endfor %}"
)

# ======================= Utils =======================

def ensure_dir(p: str) -> None:
    """Create directory *p*, including missing parents; do nothing if it exists."""
    os.makedirs(name=p, exist_ok=True)


def short_name(model_id: str) -> str:
    """Return the last ``/``-separated component of a hub id (``"org/name"`` -> ``"name"``)."""
    _, _, tail = model_id.rpartition("/")
    return tail


def ensure_chat_template(tokenizer):
    """Give *tokenizer* a chat template if it has none (missing or falsy).

    The tokenizer is modified in place by assigning DEFAULT_CHAT_TEMPLATE, and
    the same object is returned for convenient chaining.
    """
    existing = getattr(tokenizer, "chat_template", None)
    if existing:
        return tokenizer
    tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
    return tokenizer

def ds_sanitize(name: str) -> str:
    """Make *name* filesystem-safe.

    Every character that is not an ASCII letter, ASCII digit, ``_``, ``.`` or
    ``-`` is replaced by a single ``_`` (e.g. ``"openai/gsm8k"`` ->
    ``"openai_gsm8k"``).
    """
    return "".join(
        ch if (ch.isascii() and (ch.isalnum() or ch in "_.-")) else "_"
        for ch in name
    )

# ======================= Data Loading =======================

def load_local_representation_questions(dataset: str) -> Optional[List[str]]:
    """Load questions from a local ``data/cluster/representation_{dataset}.json``.

    The file is expected to hold a JSON list of objects. The ``"question"``
    field of each object is collected with surrounding whitespace stripped;
    entries whose question is missing, empty, or whitespace-only are skipped.

    Args:
        dataset: name interpolated verbatim into the file name. The path is
            relative to the current working directory.

    Returns:
        The list of questions, or None when the file does not exist.
    """
    path = os.path.join("data", "cluster", f"representation_{dataset}.json")
    if not os.path.isfile(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    # Strip first and filter afterwards, so whitespace-only questions are
    # dropped instead of leaking through as empty strings.
    stripped = ((rec.get("question") or "").strip() for rec in records)
    return [q for q in stripped if q]


def load_qa_dataset(dataset_name: str, split: str = "train", subset: Optional[str] = None) -> List[str]:
    """Return a flat list of prompts/questions from one split of a HF dataset.

    Supported datasets (default config -> extracted field):
    - openai/gsm8k (config: main) -> question
    - google-research-datasets/mbpp -> text (whitespace-stripped)
    - EleutherAI/hendrycks_math (7 subject configs, merged) -> problem
    - stanfordnlp/coqa -> story (the story alone is used as the prompt)
    - newfacade/LeetCodeDataset -> query
    - allenai/openbookqa -> question_stem
    - allenai/ai2_arc (config: ARC-Challenge) -> question
    - allenai/math_qa -> Problem

    Args:
        dataset_name: one of the Hugging Face hub ids above.
        split: split to load (default "train").
        subset: optional config name that replaces the default config(s),
            e.g. "ARC-Easy" for allenai/ai2_arc, or a single subject such as
            "algebra" for EleutherAI/hendrycks_math. None keeps the defaults.

    Returns:
        The extracted texts, concatenated across configs in order.

    Raises:
        ValueError: if dataset_name is not supported.
    """
    # dataset id -> (configs to load, per-example text extractor). A None
    # config means "the dataset's default configuration".
    recipes = {
        "openai/gsm8k": (["main"], lambda ex: ex["question"]),
        "google-research-datasets/mbpp": ([None], lambda ex: ex.get("text", "").strip()),
        "EleutherAI/hendrycks_math": (
            [
                "algebra",
                "counting_and_probability",
                "geometry",
                "intermediate_algebra",
                "number_theory",
                "prealgebra",
                "precalculus",
            ],
            lambda ex: ex["problem"],
        ),
        "stanfordnlp/coqa": ([None], lambda ex: ex["story"]),
        "newfacade/LeetCodeDataset": ([None], lambda ex: ex["query"]),
        "allenai/openbookqa": ([None], lambda ex: ex["question_stem"]),
        "allenai/ai2_arc": (["ARC-Challenge"], lambda ex: ex["question"]),
        "allenai/math_qa": ([None], lambda ex: ex["Problem"]),
    }
    if dataset_name not in recipes:
        raise ValueError(f"Dataset {dataset_name} not supported")
    configs, extract = recipes[dataset_name]
    if subset is not None:
        configs = [subset]

    # Imported lazily: rejecting an unsupported name (above) or importing this
    # module does not require the `datasets` package.
    from datasets import load_dataset

    merged = len(configs) > 1  # progress messages only when merging configs
    texts: List[str] = []
    for cfg in configs:
        if merged:
            print(f"🔹 Loading subset: {cfg}")
        ds = load_dataset(dataset_name, cfg, split=split)
        texts.extend(extract(ex) for ex in ds)
    if merged:
        print(f"✅ Loaded {len(texts)} total samples from {len(configs)} subsets")
    return texts


def sample_questions_fraction(all_q: List[str], frac: float, seed: int) -> List[str]:
    """Return a reproducible random subset containing ``frac`` of *all_q*.

    The subset size is ``max(1, int(len(all_q) * frac))``, capped by the list
    length, so an empty list yields ``[]``. The index list is shuffled with
    ``random.Random(seed)`` and the first n indices are kept. A given
    (list, frac, seed) therefore always yields the same questions in the
    same order.

    Raises:
        ValueError: if ``frac`` is not in (0, 1].
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0.0 < frac <= 1.0:
        raise ValueError(f"frac must be in (0,1], got {frac!r}")
    n = max(1, int(len(all_q) * frac))
    # Keep the shuffle-then-slice scheme (rather than e.g. rng.sample):
    # switching algorithms would change which questions a given seed selects.
    order = list(range(len(all_q)))
    random.Random(seed).shuffle(order)
    return [all_q[i] for i in order[:n]]

# ======================= Dispersion =======================
EPS = 1e-12  # guards the divisions in layer_dispersion_from_gram against 0 denominators

def layer_dispersion_from_gram(H: np.ndarray, per_q_cap: Optional[int] = None) -> float:
    """Compute a normalized spectral dispersion score for token hidden states.

    The rows of ``H`` (one per token) are mean-centered and the token Gram
    matrix ``G = H Hᵀ`` is formed. Its eigenvalues are normalized into a
    distribution ``p``. The score is ``(1 - Σ p²) / (1 - 1/T)``: one minus the
    Herfindahl–Hirschman index of the spectrum, rescaled by its value for a
    flat spectrum of ``T`` eigenvalues, where ``T`` is the number of rows used.

    Args:
        H: array or torch tensor of shape (T, D). Tensors are converted to a
            float32 CPU NumPy array. The input is never modified.
        per_q_cap: if given and T exceeds it, a fixed-seed (42) random subset
            of ``per_q_cap`` rows is used instead.

    Returns:
        The dispersion score. It is 0.0 for inputs that are not 2-D arrays,
        that have at most one (used) row, or whose rows are all identical
        (numerically zero total variance).
    """
    if isinstance(H, torch.Tensor):
        H = H.detach().to(torch.float32).cpu().numpy()
    if not isinstance(H, np.ndarray) or H.ndim != 2 or H.shape[0] <= 1:
        return 0.0
    if per_q_cap is not None and H.shape[0] > per_q_cap:
        sel = np.random.default_rng(42).choice(H.shape[0], size=per_q_cap, replace=False)
        H = H[sel]
        if H.shape[0] <= 1:
            return 0.0  # a cap of 0 or 1 leaves nothing to disperse
    H = H.astype(np.float32, copy=False)
    # Out-of-place centering: H may alias the caller's array (or, via
    # .numpy(), a CPU float32 tensor's storage), which must not be modified.
    Hc = H - H.mean(axis=0, keepdims=True)
    G = Hc @ Hc.T
    G = 0.5 * (G + G.T)  # enforce exact symmetry for eigvalsh
    evals = np.clip(np.linalg.eigvalsh(G), 0.0, None)  # drop tiny negative round-off
    total = float(evals.sum())
    if total <= EPS:
        # Identical rows give an all-zero spectrum. Report no dispersion instead
        # of dividing ~0 by ~0, which produced scores above 1.
        return 0.0
    p = evals / (total + EPS)
    hhi = float(np.sum(p * p))
    Teff = len(evals)
    return float((1.0 - hhi) / (1.0 - 1.0 / Teff + EPS))


def build_dispersion_from_forward(hidden_states_list: List[List[torch.Tensor]], per_q_cap: Optional[int] = None) -> np.ndarray:
    """Reduce stored per-sample, per-layer hidden states to a dispersion matrix.

    Args:
        hidden_states_list: one entry per sample, each a list of per-layer
            tensors; each tensor is assumed to be (1, T, D) (the leading axis is
            squeezed away before scoring).
        per_q_cap: forwarded to layer_dispersion_from_gram.

    Returns:
        float32 array U of shape (Q, Lmin). Lmin is the shortest layer count
        among the samples that have any hidden states. Samples with an empty
        layer list keep an all-zero row. Returns shape (0, 0) for no samples
        and (Q, 0) when every sample is empty.
    """
    Q = len(hidden_states_list)
    if Q == 0:
        return np.zeros((0, 0), dtype=np.float32)
    # default=0 avoids ValueError from min() when every sample is empty.
    Lmin = min((len(hs) for hs in hidden_states_list if hs), default=0)
    U = np.zeros((Q, Lmin), dtype=np.float32)
    if Lmin < 1:
        return U
    for q, hs in enumerate(hidden_states_list):
        if not hs:
            continue  # indexing an empty sample would raise IndexError
        for ell in range(Lmin):
            H = hs[ell].squeeze(0)  # (1,T,D) -> (T,D)
            U[q, ell] = layer_dispersion_from_gram(H, per_q_cap=per_q_cap)
    return U


def slice_expansion_only(U_raw: np.ndarray) -> np.ndarray:
    """Drop the leading layers before the minimum of the mean dispersion curve.

    Column means are taken across samples. The returned view starts at the
    layer with the lowest mean, but never later than the second-to-last
    layer, so at least two layers remain. Empty inputs and inputs with fewer
    than two layers are returned unchanged.
    """
    if U_raw.size == 0:
        return U_raw
    depth = U_raw.shape[1]
    if depth < 2:
        return U_raw
    start = min(int(np.argmin(U_raw.mean(axis=0))), depth - 2)
    return U_raw[:, start:]


def resample_trajs_by_depth(U: np.ndarray, L_target: int) -> np.ndarray:
    """Resample each row (a depth trajectory) to ``L_target`` bins.

    Rows are linearly interpolated from ``L`` evenly spaced positions on [0, 1]
    onto ``L_target`` evenly spaced positions. A row with a single value
    becomes constant. Every non-empty result is float32 and clipped to
    [0, 1], whether or not resampling was needed.

    Args:
        U: array of shape (Q, L).
        L_target: number of output bins.

    Returns:
        A new (Q, L_target) float32 array; the input is never modified. Empty
        input is returned as-is.
    """
    if U.size == 0:
        return U
    Q, L = U.shape
    if L_target == L:
        # Same post-conditions as the resampled path: fresh float32 copy, clipped.
        out = U.astype(np.float32, copy=True)
    else:
        t0 = np.linspace(0.0, 1.0, num=L, dtype=np.float32)
        t1 = np.linspace(0.0, 1.0, num=L_target, dtype=np.float32)
        out = np.empty((Q, L_target), dtype=np.float32)
        for i, row in enumerate(U):
            out[i] = np.interp(t1, t0, row)
    np.clip(out, 0.0, 1.0, out=out)
    return out

# ======================= Model forward (no generation) =======================
# def forward_hidden_states(model, tokenizer, questions: List[str], device: torch.device, max_length: int) -> List[List[torch.Tensor]]:
#     """Alternative path that stores hidden states explicitly (kept here for reference)."""
#     out_list: List[List[torch.Tensor]] = []
#     for q in questions:
#         messages = [
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user",   "content": q},
#         ]
#         try:
#             text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
#         except Exception:
#             text = f"System: You are a helpful assistant.\nUser: {q}\nAssistant:"
#         enc = tokenizer(
#             text,
#             return_tensors="pt",
#             padding=False,
#             truncation=True,
#             max_length=max_length,
#         )
#         enc = {k: v.to(device) for k, v in enc.items()}
#         with torch.no_grad():
#             out = model(**enc, output_hidden_states=True, return_dict=True, use_cache=False)
#             hs = out.hidden_states  # tuple length = n_layers+1, each (1,T,D)
#         out_list.append(list(hs))
#         del out, hs, enc
#         if torch.cuda.is_available():
#             torch.cuda.empty_cache()
#     return out_list

def forward_dispersion_streaming(
    model, tokenizer, questions, device, max_length, per_q_cap=None
):
    """Run one forward pass per question and reduce every layer to a dispersion score.

    Hidden states are not accumulated. Each layer's activations are reduced
    to a scalar with layer_dispersion_from_gram right after that question's
    forward pass, before the next question is processed.

    Each question is wrapped as a system + user chat and rendered with the
    tokenizer's chat template. If rendering raises, a plain
    "System/User/Assistant" transcript is used instead. The text is then
    tokenized without padding and truncated to ``max_length`` tokens.

    Side effects: may set ``tokenizer.pad_token`` to the EOS token, sets
    ``model.config.use_cache = False`` (best effort), and puts the model in
    eval mode.

    Args:
        model: causal LM returning ``hidden_states`` when called with
            ``output_hidden_states=True``.
        tokenizer: matching tokenizer.
        questions: iterable of question strings.
        device: torch.device the inputs are moved to.
        max_length: truncation length in tokens.
        per_q_cap: forwarded to layer_dispersion_from_gram.

    Returns:
        float32 array of shape (Q, Lmin). Lmin is the smallest number of
        hidden-state tensors returned for any question (presumably
        num_layers + 1 including embeddings; confirm per model). Shape is
        (0, 0) when there are no questions.
    """
    from contextlib import nullcontext

    # Patch pad_token if needed
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    try:
        model.config.use_cache = False
    except Exception:
        pass  # best effort; use_cache=False is also passed to every forward call
    model.eval()

    # Autocast to the weights' half-precision dtype on CUDA; no-op otherwise.
    param_dtype = next(model.parameters()).dtype
    amp_dtype = param_dtype if param_dtype in (torch.bfloat16, torch.float16) else None
    amp_ctx = (
        torch.autocast("cuda", dtype=amp_dtype)
        if (amp_dtype is not None and device.type == "cuda")
        else nullcontext()
    )

    bos = getattr(tokenizer, "bos_token", None)
    rows = []
    with torch.inference_mode():
        for q in questions:
            messages = [
                {"role":"system","content":"You are a helpful assistant."},
                {"role":"user","content":q},
            ]
            try:
                text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            except Exception:
                text = f"System: You are a helpful assistant.\nUser: {q}\nAssistant:"

            # Chat templates that already render the BOS token would get a
            # second one if the tokenizer also added special tokens, changing
            # the token sequence (and so every hidden state).
            # NOTE(review): results saved before this change may contain the
            # duplicated BOS; regenerate them for comparability.
            add_special = not (bos and text.startswith(bos))
            enc = tokenizer(
                text,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=max_length,
                add_special_tokens=add_special,
            )
            enc = {k: v.to(device) for k, v in enc.items()}

            with amp_ctx:
                out = model(**enc, output_hidden_states=True, return_dict=True, use_cache=False)

            # Each hidden state is (1, T, D); score layers one at a time on CPU.
            rows.append(np.array(
                [
                    layer_dispersion_from_gram(h.squeeze(0).detach().cpu(), per_q_cap=per_q_cap)
                    for h in out.hidden_states
                ],
                dtype=np.float32,
            ))

            del out, enc
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Align to (Q, Lmin)
    Lmin = min(len(r) for r in rows) if rows else 0
    U_raw = np.stack([r[:Lmin] for r in rows], axis=0) if Lmin > 0 else np.zeros((0, 0), np.float32)
    return U_raw


# ======================= Driver =======================
def run_once_for_dataset_model(dataset: str, model_id: str, seed: int, args, frac: float) -> Optional[tuple[str, bool]]:
    """Run the pipeline for one (dataset, model, seed, fraction) and save the result.

    Steps:
    1. Load the dataset's train split and sample ``frac`` of it with ``seed``.
    2. Load the tokenizer and model.
    3. Compute per-layer dispersion for every sampled question.
    4. Optionally keep only the "expansion" segment.
    5. Resample each trajectory to RESAMPLE_BINS depth bins.
    6. Save the matrix with ``np.save`` to
       ``{outdir}/{dataset}_{pct}_{RESAMPLE_BINS}/{model}/{model}_seed{seed}.npy``.

    Args:
        dataset: dataset id understood by load_qa_dataset.
        model_id: Hugging Face model id.
        seed: sampling seed (also part of the file name).
        args: parsed CLI namespace; reads outdir, dtype, max_length,
            per_q_cap and segment.
        frac: fraction of the train split to sample.

    Returns:
        ``(save_path, True)`` after writing a new file; ``(save_path, False)``
        if the file already existed (nothing is recomputed); None when the
        dataset yields no questions or processing produced an empty matrix.
    """
    pct = int(round(frac * 100))
    ds_dir = f"{ds_sanitize(dataset)}_{pct}_{RESAMPLE_BINS}"
    model_name = short_name(model_id)
    save_dir = os.path.join(args.outdir, ds_dir, model_name)
    save_path = os.path.join(save_dir, f"{model_name}_seed{seed}.npy")

    if os.path.exists(save_path):
        return save_path, False

    # 1) Load all questions from HF train split, then sample a fraction
    all_q = load_qa_dataset(dataset, split="train")
    if not all_q:
        print(f"[WARN] No questions for dataset={dataset}")
        return None
    questions = sample_questions_fraction(all_q, frac=frac, seed=seed)

    # 2) Tokenizer/model
    # NOTE(review): trust_remote_code=True executes code from the model repo;
    # only pass model ids you trust via --models.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
    except Exception:
        # Retry without forcing the fast tokenizer.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    tokenizer = ensure_chat_template(tokenizer)
    if "gemma" in model_id:
        tokenizer.padding_side = "left"

    torch_dtype = (
        torch.bfloat16 if args.dtype == "bfloat16"
        else (torch.float16 if args.dtype == "float16" else torch.float32)
    )

    # Shared kwargs: ensure hidden states are returned
    common_kwargs = dict(
        output_hidden_states=True,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            **common_kwargs,
        )
        # NOTE(review): assigning this config attribute after loading most
        # likely does not switch the attention kernel (normally chosen via the
        # `attn_implementation=` argument of from_pretrained); confirm against
        # the installed transformers version.
        try:
            model.config.attn_implementation = "flash_attention_2"
        except Exception:
            pass
    except Exception:
        # Second attempt with the same arguments as the first.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            **common_kwargs,
        )

    # 3) Forward -> dispersion (streaming). The finally block drops this frame's
    # references to the model even when the forward pass raises; otherwise the
    # propagating exception's traceback keeps them alive.
    try:
        if getattr(model.config, "pad_token_id", None) is None and tokenizer.pad_token_id is not None:
            model.config.pad_token_id = tokenizer.pad_token_id
        if getattr(model.config, "use_cache", None) is not False:
            model.config.use_cache = False
        model.eval()

        device = next(model.parameters()).device
        U_raw = forward_dispersion_streaming(
            model, tokenizer, questions, device=device, max_length=args.max_length, per_q_cap=args.per_q_cap
        )
    finally:
        del model, tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # 4) Segment and 5) resample to fixed depth bins
    U_seg = slice_expansion_only(U_raw) if args.segment == "expansion" else U_raw
    U_bin = resample_trajs_by_depth(U_seg, L_target=RESAMPLE_BINS)

    if U_bin.size == 0:
        print(f"[WARN] Empty U after processing for {dataset} | {model_id} | seed={seed}")
        return None

    # 6) Save under {dataset}_{PCT}_{RESAMPLE_BINS}. The directory is created
    # only now, so failed runs do not leave empty directories behind.
    ensure_dir(save_dir)
    np.save(save_path, U_bin)
    return save_path, True


def main():
    """Parse CLI arguments and run every (dataset, fraction, model) combination.

    Fractions come from ``--sample-frac`` or, alternatively, ``--sample-pct``
    (percentages, clamped to [0, 100]). Every fraction must lie in (0, 1].
    Invalid values are rejected once, up front, via ``parser.error``. Existing
    result files are skipped. An error for one model is printed and the loop
    continues with the next model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="result", help="save root for NxB matrices")
    parser.add_argument("--segment", type=str, choices=["full","expansion"], default="full")
    parser.add_argument("--seed", type=int, default=42, help="random seed for sampling")
    parser.add_argument("--max-length", type=int, default=1024)
    parser.add_argument("--dtype", type=str, choices=["bfloat16","float16","float32"], default="bfloat16")
    parser.add_argument("--per-q-cap", type=int, default=None, help="optional token cap per sample for dispersion")
    parser.add_argument("--datasets", nargs="*", default=DATASETS)
    parser.add_argument("--models", nargs="*", default=MODELS)
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--sample-frac", type=float, nargs="+", default=[0.30],
        help="fractions of the train set to sample (0<frac<=1); can pass multiple like 0.1 0.3"
    )
    group.add_argument(
        "--sample-pct", type=int, nargs="+", default=None,
        help="percentages of the train set to sample (e.g., 10 30 50)"
    )
    args = parser.parse_args()

    if args.sample_pct is not None:
        args.sample_frac = [max(0.0, min(1.0, p / 100.0)) for p in args.sample_pct]
    # args.sample_frac is always a list here (nargs="+" with a list default).
    # Validate once instead of failing identically for every model later.
    bad = [f for f in args.sample_frac if not 0.0 < f <= 1.0]
    if bad:
        parser.error(f"sample fractions must be in (0, 1]; got {bad}")
    ensure_dir(args.outdir)

    for dataset in args.datasets:
        for frac in args.sample_frac:
            pct = int(round(frac * 100))
            print(f"\n===== DATASET={dataset} | FRACTION={frac:.2f} ({pct}%) | SEED={args.seed} =====")
            pbar = tqdm(args.models, desc=f"[{dataset}] models (seed {args.seed})")
            for m in pbar:
                pbar.set_postfix_str(short_name(m))
                try:
                    result = run_once_for_dataset_model(dataset, m, args.seed, args, frac)
                    if result is None:
                        continue
                    save_path, did_save = result
                    if did_save:
                        print(f"[SAVED] {save_path}")
                    else:
                        print(f"[SKIP] already exists: {save_path}")
                except Exception as e:
                    # Best effort per model: report and move on to the next one.
                    print(f"[ERROR] {dataset} | {m} | seed={args.seed}: {e}")

if __name__ == "__main__":
    main()
