#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations
import os, re, json, argparse, random
from typing import List, Optional, Dict, Any, Tuple

import numpy as np
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ======================= Defaults =======================
# Hugging Face dataset ids processed by default (override with --datasets).
# Every id here must have a branch in load_qa_dataset, which raises
# ValueError for any id it does not recognise.
DATASETS = [
    "google-research-datasets/mbpp",
    "stanfordnlp/coqa",
    "openai/gsm8k",
    "EleutherAI/hendrycks_math",
    "newfacade/LeetCodeDataset",
    "allenai/openbookqa",
    "allenai/ai2_arc",
    "allenai/math_qa",
]
# Hugging Face model ids evaluated by default (override with --models).
# compute_k loads each one with trust_remote_code=True and device_map="auto".
MODELS = [
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-Math-1.5B-Instruct",
    "Qwen/Qwen2.5-Coder-1.5B-Instruct",
    "meta-llama/Llama-3.2-1B-Instruct",
    "tiiuae/Falcon3-1B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-Math-7B-Instruct",
    "Qwen/Qwen2.5-Coder-7B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    "tiiuae/Falcon3-7B-Instruct",
    "amd/AMD-OLMo-1B-SFT",
    "allenai/OLMo-2-0425-1B-Instruct",
    "google/gemma-3-1b-it",
    "deepseek-ai/deepseek-coder-1.3b-instruct",
    "deepseek-ai/deepseek-llm-7b-chat",
    "deepseek-ai/deepseek-math-7b-instruct",
    "deepseek-ai/deepseek-coder-6.7b-instruct",
    "allenai/OLMo-2-1124-7B-Instruct",
]
# Fallback Jinja chat template, installed by ensure_chat_template on tokenizers
# that have no chat_template. It renders each system/user/assistant message as
# a "Role: content" line; messages with any other role are skipped.
DEFAULT_CHAT_TEMPLATE = (
    "{% for m in messages %}"
    "{% if m['role'] == 'system' %}System: {{ m['content'] }}\n"
    "{% elif m['role'] == 'user' %}User: {{ m['content'] }}\n"
    "{% elif m['role'] == 'assistant' %}Assistant: {{ m['content'] }}\n"
    "{% endif %}"
    "{% endfor %}"
)

# ======================= Utils =======================
def ensure_chat_template(tokenizer):
    """Give *tokenizer* a chat template if it lacks one, then return it.

    A missing or empty ``chat_template`` attribute is replaced, in place, by
    ``DEFAULT_CHAT_TEMPLATE``. An existing template is left untouched.
    """
    has_template = bool(getattr(tokenizer, "chat_template", None))
    if not has_template:
        tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
    return tokenizer

def ds_sanitize(name: str) -> str:
    """Make *name* filesystem-friendly.

    Every character other than ASCII letters, digits, ``_``, ``.`` and ``-``
    is replaced by ``_``.
    """
    return re.sub(r"[^A-Za-z0-9_.-]", "_", name)

def short_name(model_id: str) -> str:
    """Return the part of *model_id* after its last ``/`` (the whole id if there is none)."""
    _owner, _sep, tail = model_id.rpartition("/")
    return tail

# ======================= Data =======================
def load_qa_dataset(dataset_name: str, split: str = "train") -> List[str]:
    """Return the prompt/question texts of a supported Hugging Face dataset.

    Each supported id maps to the config name(s) to load (``None`` means the
    call is made without a config argument) and a function that extracts the
    text from one example. For ``EleutherAI/hendrycks_math`` every listed
    subject config is loaded and the problems are concatenated in that order.
    For MBPP the ``text`` field is whitespace-stripped.

    Args:
        dataset_name: Hugging Face dataset id (one of the registry keys below).
        split: split loaded from every config (default ``"train"``).

    Returns:
        A list of text strings, one per example.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    from datasets import load_dataset

    def column(field):
        # Build an extractor that reads a single field from an example.
        return lambda ex: ex[field]

    math_subjects = (
        "algebra",
        "counting_and_probability",
        "geometry",
        "intermediate_algebra",
        "number_theory",
        "prealgebra",
        "precalculus",
    )
    registry = {
        "openai/gsm8k": (("main",), column("question")),
        "google-research-datasets/mbpp": ((None,), lambda ex: ex.get("text", "").strip()),
        "EleutherAI/hendrycks_math": (math_subjects, column("problem")),
        "stanfordnlp/coqa": ((None,), column("story")),
        "newfacade/LeetCodeDataset": ((None,), column("query")),
        "allenai/openbookqa": ((None,), column("question_stem")),
        "allenai/ai2_arc": (("ARC-Challenge",), column("question")),
        "allenai/math_qa": ((None,), column("Problem")),
    }
    if dataset_name not in registry:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    configs, extract = registry[dataset_name]
    texts: List[str] = []
    for cfg in configs:
        if cfg is None:
            ds = load_dataset(dataset_name, split=split)
        else:
            ds = load_dataset(dataset_name, cfg, split=split)
        texts.extend(extract(ex) for ex in ds)
    return texts

def sample_questions_n(all_q: List[str], n: int, seed: int) -> List[str]:
    """Draw a reproducible random subset of *all_q* without replacement.

    The requested size is clamped to ``[1, len(all_q)]``, so ``n <= 0`` still
    yields one item. An empty input yields ``[]``. The full index range is
    shuffled with ``random.Random(seed)`` and the first items are kept, so a
    given seed always produces the same subset in the same order.
    """
    if not all_q:
        return []
    total = len(all_q)
    take = min(max(n, 1), total)
    order = list(range(total))
    random.Random(seed).shuffle(order)
    return [all_q[i] for i in order[:take]]

# ======================= Dispersion & R² =======================
EPS = 1e-12


def layer_dispersion_from_gram(H, per_q_cap: Optional[int] = None) -> float:
    """Compute a normalised spectral-dispersion score for one layer's token rows.

    The rows of ``H`` are mean-centred. The eigenvalues ``lambda_i`` of the
    symmetrised Gram matrix ``H @ H.T`` are normalised into
    ``p_i = lambda_i / sum(lambda)``. The score is ``1 - sum(p_i ** 2)``
    (one minus the Herfindahl-Hirschman index of the spectrum) divided by
    ``1 - 1/T``, where ``T`` is the number of rows used.

    Args:
        H: 2-D NumPy array or torch tensor with one row per token (presumably
            ``(tokens, hidden_dim)`` - confirm with callers). Tensors are
            converted to float32 NumPy arrays on the CPU. ``H`` is not modified.
        per_q_cap: if given and ``H`` has more rows than this, a random subset
            of ``per_q_cap`` rows drawn with a fixed seed (42) is used.

    Returns:
        The score as a Python float. Returns ``0.0`` when ``H`` is not a 2-D
        array/tensor, has fewer than two rows (before or after subsampling),
        or has zero total variance after centring.
    """
    if isinstance(H, torch.Tensor):
        H = H.detach().to(torch.float32).cpu().numpy()
    if not isinstance(H, np.ndarray) or H.ndim != 2 or H.shape[0] <= 1:
        return 0.0
    if per_q_cap is not None and H.shape[0] > per_q_cap:
        # A fresh generator with a fixed seed picks the same rows for every
        # call with the same row count (e.g. every layer of one prompt).
        sel = np.random.default_rng(42).choice(H.shape[0], size=per_q_cap, replace=False)
        H = H[sel]
        if H.shape[0] <= 1:
            return 0.0

    # Centre into a NEW array: for float32 input ``astype(copy=False)`` returns
    # H itself, so an in-place ``-=`` would mutate the caller's array (or fail
    # if that array is read-only).
    X = H.astype(np.float32, copy=False)
    X = X - X.mean(axis=0, keepdims=True)
    G = X @ X.T
    G = 0.5 * (G + G.T)  # enforce exact symmetry for eigvalsh
    evals = np.clip(np.linalg.eigvalsh(G), 0.0, None)  # drop negative round-off

    total = float(evals.sum())
    if total <= EPS:
        # No variance left (identical rows, all-zero input, a single sampled
        # row): report no dispersion rather than letting p == 0 give HHI == 0
        # and a score above 1.
        return 0.0
    p = evals / (total + EPS)
    hhi = float(np.sum(p * p))
    T = len(evals)
    # NOTE(review): centring makes rank(G) <= T - 1, so 1 - HHI is at most
    # 1 - 1/(T - 1); with the 1 - 1/T normaliser below the score cannot reach
    # 1. Confirm whether T - 1 (or the rank) was the intended normaliser.
    return float((1.0 - hhi) / (1.0 - 1.0 / T + EPS))

def forward_dispersion_streaming(model, tokenizer, questions, device, max_length, per_q_cap):
    """Run *model* on each question and score every hidden-state tensor.

    Each question is wrapped in a system+user chat. The tokenizer's chat
    template is used if it works; if ``apply_chat_template`` raises, a plain
    "System:/User:/Assistant:" transcript is used instead. The text is
    tokenized without padding, truncated to *max_length* tokens, and fed
    through the model one prompt at a time with ``output_hidden_states=True``.
    Every returned hidden-state tensor is scored with
    ``layer_dispersion_from_gram``.

    Side effects: sets ``tokenizer.pad_token`` to ``eos_token`` if it is
    missing, sets ``model.config.use_cache = False`` (best effort), and
    switches the model to eval mode.

    Args:
        model: causal LM whose outputs expose ``hidden_states`` (or at least
            ``last_hidden_state``).
        tokenizer: matching tokenizer.
        questions: iterable of prompt strings.
        device: ``torch.device`` that inputs are moved to. Autocast is enabled
            only when ``device.type == "cuda"`` and the first model parameter
            is float16/bfloat16.
        max_length: truncation length in tokens.
        per_q_cap: optional row cap forwarded to ``layer_dispersion_from_gram``.

    Returns:
        float32 array of shape ``(n_scored_prompts, L_min)``, where ``L_min``
        is the smallest number of hidden-state tensors seen across prompts
        (longer rows are truncated to it). Returns ``(0, 0)`` if no prompt
        could be scored.
    """
    from contextlib import nullcontext

    rows = []

    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    try:
        model.config.use_cache = False
    except Exception:
        pass  # best effort: the forward calls below also pass use_cache=False
    model.eval()

    amp_dtype = None
    try:
        p_dtype = next(model.parameters()).dtype
        if p_dtype in (torch.bfloat16, torch.float16):
            amp_dtype = p_dtype
    except StopIteration:
        pass  # parameter-less model: run without autocast
    use_amp = amp_dtype is not None and device.type == "cuda"
    amp_ctx = torch.autocast("cuda", dtype=amp_dtype) if use_amp else nullcontext()

    def _run(enc):
        # Call the module (not .forward) so any registered hooks still run.
        with amp_ctx:
            return model(**enc, output_hidden_states=True, return_dict=True, use_cache=False)

    with torch.inference_mode():
        for q in questions:
            msgs = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": q},
            ]
            try:
                text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
            except Exception:
                text = f"System: You are a helpful assistant.\nUser: {q}\nAssistant:"

            # NOTE(review): the templated text is tokenized with default
            # special-token handling; if a template already emits BOS this may
            # duplicate it - confirm per model.
            enc = tokenizer(text, return_tensors="pt", padding=False, truncation=True, max_length=max_length)
            enc = {k: v.to(device) for k, v in enc.items()}

            out = _run(enc)
            hidden_list = getattr(out, "hidden_states", None)
            if hidden_list is None:
                # Some remote-code models only honour the config flag.
                try:
                    model.config.output_hidden_states = True
                    out = _run(enc)
                    hidden_list = getattr(out, "hidden_states", None)
                except Exception:
                    hidden_list = None

            if hidden_list is None:
                last = getattr(out, "last_hidden_state", None)
                if last is None:
                    continue
                hidden_list = (last,)

            # Upcast on the torch side: Tensor.numpy() rejects bfloat16, so the
            # bare ``.cpu().numpy()`` failed for every --dtype bfloat16 model.
            vals = [
                layer_dispersion_from_gram(
                    h.squeeze(0).detach().to(torch.float32).cpu().numpy(),
                    per_q_cap=per_q_cap,
                )
                for h in hidden_list
            ]
            rows.append(np.asarray(vals, dtype=np.float32))

            del out, enc, hidden_list
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    Lmin = min((len(r) for r in rows), default=0)
    return np.stack([r[:Lmin] for r in rows], axis=0) if Lmin > 0 else np.zeros((0, 0), np.float32)

def slice_expansion_only(U_raw: np.ndarray) -> np.ndarray:
    """Drop the layers before the trough of the mean dispersion curve.

    The column (layer) with the smallest mean across rows marks the start of
    the kept segment. The start is capped so at least two columns remain.
    Empty inputs and inputs with fewer than two columns are returned as-is.
    """
    if U_raw.size == 0 or U_raw.shape[1] < 2:
        return U_raw
    latest_start = U_raw.shape[1] - 2
    trough = int(np.argmin(U_raw.mean(axis=0)))
    return U_raw[:, min(trough, latest_start):]

def _interp_traj(u: np.ndarray, K: int) -> np.ndarray:
    """Interpolate a length-L trajectory to K bins and back to length L."""
    L = len(u)
    if L == 0 or K < 2:
        return u.copy()
    tL = np.linspace(0.0, 1.0, num=L, dtype=np.float32)
    tK = np.linspace(0.0, 1.0, num=K, dtype=np.float32)
    return np.interp(tL, tK, np.interp(tK, tL, u))

def _r2(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute R^2; fall back to cosine-like similarity if variance is ~0."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if y_true.size == 0 or y_pred.size == 0:
        return 1.0
    y_mean = float(y_true.mean())
    sse = float(np.sum((y_true - y_pred)**2))
    sst = float(np.sum((y_true - y_mean)**2))
    if sst <= 1e-12:
        num = float(np.dot(y_true, y_pred))
        den = float(np.linalg.norm(y_true) * np.linalg.norm(y_pred) + 1e-12)
        return num/den if den > 0 else 1.0
    return 1.0 - (sse/(sst + 1e-12))

def _quantile_r2_over_samples(U: np.ndarray, K: int, q: float) -> float:
    """Compute the q-quantile of R^2 across samples after K-bin interpolation."""
    if U.size == 0:
        return 1.0
    r = []
    for u in U:
        r.append(_r2(u, _interp_traj(u, K)))
    return float(np.quantile(np.asarray(r, dtype=np.float64), q)) if r else 1.0

def select_min_K_r2(U, r2_thresh, k_min, k_max, k_step, q):
    """Return the smallest K in ``range(k_min, k_max + 1, k_step)`` that passes the threshold.

    K passes when the q-quantile of per-row reconstruction R^2
    (``_quantile_r2_over_samples``) is finite and at least ``r2_thresh``.

    Returns None without searching if the bounds are not plain ints, the range
    is empty, ``k_step <= 0``, U is empty or has fewer than two columns, or U
    contains no finite value. Also returns None if no K in the range passes.
    """
    bounds_are_ints = all(isinstance(v, int) for v in (k_min, k_max, k_step))
    if not bounds_are_ints or k_max < k_min or k_step <= 0:
        return None
    if U.size == 0 or U.shape[1] < 2 or not np.isfinite(U).any():
        return None

    for K in range(k_min, k_max + 1, k_step):
        score = _quantile_r2_over_samples(U, K, q=q)
        if np.isfinite(score) and score >= r2_thresh:
            return K
    return None

# ======================= Per (dataset, model) =======================
def compute_k(dataset: str, model_id: str, seed: int, args, sample_n: int) -> Optional[int]:
    """Select K for one (dataset, model) pair from ``sample_n`` random dataset examples.

    Steps:
      1. Load the dataset's training questions (always via HF ``datasets``;
         no local representation files) and draw a seeded sample.
      2. Load the tokenizer and model.
      3. Compute per-layer dispersion trajectories with
         ``forward_dispersion_streaming``.
      4. Optionally keep only the expansion segment (``args.segment``).
      5. Return the smallest K passing the R^2 search in ``select_min_K_r2``.

    Args:
        dataset: Hugging Face dataset id understood by ``load_qa_dataset``.
        model_id: Hugging Face model id.
        seed: seed for question sampling.
        args: parsed CLI namespace. Reads dtype, max_length, per_q_cap,
            segment, r2_thresh, kmin, kmax, kstep and q.
        sample_n: number of questions to sample.

    Returns:
        The selected K, or None if there is no data, no trajectories, or no K
        in the search range meets the threshold.
    """
    import gc

    all_q = load_qa_dataset(dataset, split="train")
    if not all_q:
        return None
    questions = sample_questions_n(all_q, n=sample_n, seed=seed)
    if not questions:
        return None

    try:
        tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
    except Exception:
        # Retry with the library's default tokenizer class if the fast one fails.
        tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    tok = ensure_chat_template(tok)
    if "gemma" in model_id:
        tok.padding_side = "left"

    torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16}.get(args.dtype, torch.float32)
    model = None
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            output_hidden_states=True,
            low_cpu_mem_usage=True,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch_dtype,
        )
        if getattr(model.config, "use_cache", None) is not False:
            model.config.use_cache = False
        device = next(model.parameters()).device

        U_raw = forward_dispersion_streaming(
            model, tok, questions, device=device, max_length=args.max_length, per_q_cap=args.per_q_cap
        )
    finally:
        # Release the weights even when loading or the forward pass raises.
        # The caller catches the error and moves on to the next model, which
        # would otherwise start with this model still occupying GPU memory.
        del model, tok
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if U_raw.size == 0:
        return None
    U_seg = slice_expansion_only(U_raw) if args.segment == "expansion" else U_raw
    if U_seg.size == 0:
        return None
    return select_min_K_r2(U_seg, args.r2_thresh, args.kmin, args.kmax, args.kstep, args.q)

# ======================= CLI =======================
def main():
    """Command-line entry point: select K for every (dataset, model) pair and report it.

    For each dataset, ``compute_k`` runs for every model. The script prints one
    ``[K]`` line per model, then the dataset's ``final_k`` (the maximum K over
    the models that produced one). Errors from a single pair are printed and
    recorded as ``None``.

    With ``--save-json``, the accumulated records are written after every
    dataset (so partial results survive a crash) and once more at the end.
    Missing parent directories are created.

    Search settings that can never produce a K (``--kstep <= 0``,
    ``--kmin > --kmax``, ``--q`` outside [0, 1]) are rejected up front with an
    argparse usage error.
    """
    ap = argparse.ArgumentParser(description="Find K (bins) via R^2 threshold and print results")
    ap.add_argument("--datasets", nargs="+", default=DATASETS)
    ap.add_argument("--models", nargs="+", default=MODELS)
    ap.add_argument("--seed", type=int, default=42)

    # Use a fixed sample size instead of a fraction of the dataset.
    ap.add_argument("--sample-n", type=int, default=100, help="Number of random samples to draw per dataset (default: 100)")

    ap.add_argument("--segment", choices=["full","expansion"], default="full")
    ap.add_argument("--dtype", choices=["bfloat16","float16","float32"], default="float16")
    ap.add_argument("--max-length", type=int, default=1024)
    ap.add_argument("--per-q-cap", type=int, default=None)
    # R² search
    ap.add_argument("--r2-thresh", type=float, default=0.95)
    ap.add_argument("--kmin", type=int, default=4)
    ap.add_argument("--kmax", type=int, default=32)
    ap.add_argument("--kstep", type=int, default=1)
    ap.add_argument("--q", type=float, default=0.9)
    ap.add_argument("--save-json", type=str, default=None, help="(optional) Save results to a JSON file")
    args = ap.parse_args()

    # Fail fast: with these settings every pair would silently yield K=None
    # (select_min_K_r2 rejects the range) or raise inside np.quantile, but
    # only after every model had been downloaded and run.
    if args.kstep <= 0:
        ap.error("--kstep must be a positive integer")
    if args.kmin > args.kmax:
        ap.error("--kmin must not exceed --kmax")
    if not 0.0 <= args.q <= 1.0:
        ap.error("--q must lie in [0, 1]")

    sample_n = max(1, args.sample_n)
    results = []
    print(f"=== K selection: n={sample_n} samples, segment={args.segment}, r2>={args.r2_thresh}, k∈[{args.kmin},{args.kmax}] ===")

    def save_results():
        # Create the target directory so a missing folder cannot discard the run.
        os.makedirs(os.path.dirname(os.path.abspath(args.save_json)), exist_ok=True)
        with open(args.save_json, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

    for ds in args.datasets:
        rec = {"dataset": ds, "n": sample_n, "per_model_k": {}, "final_k": None}
        ks = []
        for m in args.models:
            try:
                k = compute_k(ds, m, args.seed, args, sample_n)
            except Exception as e:
                print(f"[ERROR] {ds} | {m}: {e}")
                k = None
            rec["per_model_k"][m] = k
            if isinstance(k, int):
                ks.append(k)
            print(f"[K] dataset={ds} model={m} -> K={k}")
        rec["final_k"] = max(ks) if ks else None
        print(f"[K] dataset={ds} -> final_k (max across models) = {rec['final_k']}")
        results.append(rec)
        if args.save_json:
            save_results()  # checkpoint: runs are long, keep finished datasets

    if args.save_json:
        save_results()
        print(f"[SAVED] {args.save_json}")

if __name__ == "__main__":
    main()
