import math
import random
import json
from typing import List, Tuple, Dict
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from generate import blockdiffusion_ppl, wavefront_ppl


def set_seed(seed: int):
    """Seed the global RNGs of Python, NumPy and PyTorch with ``seed``.

    When CUDA is available, every CUDA device's generator is seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def make_prompt_continuation_pairs(dataset_split,
                                   tokenizer,
                                   n_samples: int,
                                   prompt_len: int,
                                   cont_len: int,
                                   seed: int = 42,
                                   max_tries: int = 20000) -> List[Tuple[List[int], List[int]]]:
    """Sample ``n_samples`` (prompt_ids, continuation_ids) pairs from a text dataset.

    Documents are visited in a seeded random order. Each non-empty document's
    ``"text"`` field is tokenized without special tokens. If it has at least
    ``prompt_len + cont_len`` tokens, a contiguous window of that length is cut
    at a random offset. The window is split into a ``prompt_len``-token prompt
    followed by a ``cont_len``-token continuation.

    Args:
        dataset_split: Indexable sequence whose items support ``item["text"]``
            (e.g. a Hugging Face dataset split).
        tokenizer: Callable as ``tokenizer(text, add_special_tokens=False)``,
            returning a mapping with an ``"input_ids"`` list.
        n_samples: Number of pairs to return.
        prompt_len: Prompt length in tokens.
        cont_len: Continuation length in tokens.
        seed: Seed for the local RNG that controls document order and window offsets.
        max_tries: Maximum number of documents to inspect. Documents skipped
            for being empty or too short also count.

    Returns:
        A list of exactly ``n_samples`` ``(prompt_ids, continuation_ids)`` tuples.

    Raises:
        RuntimeError: If fewer than ``n_samples`` pairs could be collected.
    """
    rng = random.Random(seed)
    window = prompt_len + cont_len
    order = list(range(len(dataset_split)))
    rng.shuffle(order)

    pairs: List[Tuple[List[int], List[int]]] = []
    # Inspect at most `max_tries` documents (a negative value inspects none).
    for doc_idx in order[:max(max_tries, 0)]:
        if len(pairs) >= n_samples:
            break
        text = dataset_split[doc_idx]["text"]
        if not text or not text.strip():
            continue
        enc = tokenizer(text, add_special_tokens=False)["input_ids"]
        if len(enc) < window:
            continue
        # Uniformly random start so the window lies fully inside the document.
        start = rng.randint(0, len(enc) - window)
        seg = enc[start:start + window]
        pairs.append((seg[:prompt_len], seg[prompt_len:]))

    if len(pairs) < n_samples:
        raise RuntimeError(f"只能抽到 {len(pairs)} 个样本（请求 {n_samples}），请调整 prompt/continuation 长度或扩充数据集来源。")
    return pairs[:n_samples]


def evaluate_pairs(model,
                   tokenizer,
                   pairs: List[Tuple[List[int], List[int]]],
                   device: torch.device,
                   method_fn,
                   method_name: str,
                   batch: int = 1,
                   verbose: bool = False,
                   **method_kwargs):
    """Score every (prompt_ids, continuation_ids) pair with ``method_fn``.

    Each pair is turned into two ``(1, length)`` long tensors on ``device``.
    ``method_fn`` is then called with keyword arguments ``model``, ``prompt``,
    ``continuation``, ``tokenizer`` plus ``**method_kwargs``, under
    ``torch.no_grad()``. It must return a 2-item result, unpacked as
    ``(ppl, avg_nll)``.

    Args:
        batch: Currently unused; pairs are always scored one at a time.
        verbose: Show a tqdm progress bar labelled ``"Eval <method_name>"``.

    Returns:
        A list of ``(ppl, avg_nll)`` tuples, in the same order as ``pairs``.
    """
    model.eval()
    scores = []
    progress = tqdm(pairs, desc=f"Eval {method_name}", disable=not verbose)
    with torch.no_grad():
        for prompt_ids, cont_ids in progress:
            prompt_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
            cont_tensor = torch.tensor([cont_ids], dtype=torch.long, device=device)
            ppl, avg_nll = method_fn(model=model,
                                     prompt=prompt_tensor,
                                     continuation=cont_tensor,
                                     tokenizer=tokenizer,
                                     **method_kwargs)
            scores.append((ppl, avg_nll))
    return scores


def paired_bootstrap_ci(diffs: List[float], B: int = 2000, alpha: float = 0.05, seed: int = 0):
    """Percentile-bootstrap confidence interval for the mean of paired differences.

    ``diffs`` holds per-sample differences between two methods scored on the
    same inputs (e.g. ``PPL_block - PPL_wavefront``). The differences are
    resampled with replacement ``B`` times. The ``alpha/2`` and ``1 - alpha/2``
    quantiles of the resampled means form the interval.

    Args:
        diffs: Per-sample paired differences; must be non-empty.
        B: Number of bootstrap resamples (>= 1).
        alpha: Two-sided significance level in (0, 1); 0.05 gives a 95% CI.
        seed: Seed for the resampling RNG, so repeated calls are reproducible.

    Returns:
        ``(mean_diff, lower, upper)`` as Python floats.

    Raises:
        ValueError: If ``diffs`` is empty, ``B < 1``, or ``alpha`` is not in (0, 1).
    """
    arr = np.asarray(diffs, dtype=float)
    n = arr.size
    if n == 0:
        raise ValueError("diffs must contain at least one value")
    if B < 1:
        raise ValueError(f"B must be >= 1, got {B}")
    if not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must be in (0, 1), got {alpha}")

    rng = np.random.RandomState(seed)
    # One resample per draw keeps memory at O(n) instead of O(B * n).
    boot_means = np.fromiter(
        (arr[rng.randint(0, n, size=n)].mean() for _ in range(B)),
        dtype=float,
        count=B,
    )
    # np.quantile interpolates between order statistics. The bounds are therefore
    # symmetric in rank and always stay inside the array, whatever alpha and B are.
    lower, upper = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(arr.mean()), float(lower), float(upper)


def _sample_std_or_none(values: np.ndarray):
    """Return the sample standard deviation (ddof=1) as a float, or None if there are fewer than two values."""
    # With a single value, ddof=1 divides by zero: NumPy warns and returns NaN,
    # and json.dump would then write a non-standard `NaN` token.
    if values.size < 2:
        return None
    return float(values.std(ddof=1))


def _format_optional(value, spec: str) -> str:
    """Format ``value`` with the format spec ``spec``, rendering None as ``"n/a"``."""
    return "n/a" if value is None else format(value, spec)


def evaluate_on_wikitext103(model_name: str = 'GSAI-ML/LLaDA-8B-Instruct',
                            n_samples: int = 200,
                            prompt_len: int = 128,
                            cont_len: int = 256,
                            seed: int = 1234,
                            device_name: str = None,
                            blockdiff_kwargs: Dict = None,
                            wavefront_kwargs: Dict = None,
                            dataset_split: str = "test"):
    """Compare BlockDiffusion-style and Wavefront-style perplexity on WikiText-103.

    Steps:
    1. Load ``model_name`` (bfloat16, ``trust_remote_code=True``) and its tokenizer.
    2. Sample ``n_samples`` prompt/continuation pairs from the requested
       ``wikitext-103-v1`` split.
    3. Score every pair with ``blockdiffusion_ppl`` and with ``wavefront_ppl``.
    4. Report per-method mean/std PPL and paired statistics on the per-sample
       difference ``PPL_block - PPL_wavefront``.

    The result dict is also written to
    ``ppl_comparison_<model>_<n_samples>samples_seed<seed>.json`` in the
    current working directory.

    Args:
        model_name: Hugging Face model id passed to ``AutoModel``/``AutoTokenizer``.
        n_samples: Number of prompt/continuation pairs to evaluate (>= 1).
        prompt_len: Prompt length in tokens.
        cont_len: Continuation length in tokens.
        seed: Seed for the global RNGs and for pair sampling.
        device_name: Torch device string. Defaults to ``"cuda"`` when
            available, otherwise ``"cpu"``.
        blockdiff_kwargs: Extra keyword arguments for ``blockdiffusion_ppl``.
            If None, built-in defaults are used.
        wavefront_kwargs: Extra keyword arguments for ``wavefront_ppl``.
            If None, built-in defaults are used.
        dataset_split: WikiText-103 split name.

    Returns:
        The result dict that is saved to JSON. ``std_ppl`` is None with fewer
        than two samples. ``paired_ttest_pvalue`` is None in three cases:
        scipy is unavailable, there are fewer than two samples, or the test
        yields a non-finite p-value.

    Raises:
        ValueError: If ``n_samples < 1``.
        RuntimeError: If not enough pairs can be sampled from the dataset.
    """
    if n_samples < 1:
        # Fail before loading a large model: no statistic is defined for zero samples.
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if blockdiff_kwargs is None:
        blockdiff_kwargs = dict(steps=1024, block_length=8, temperature=0.0, cfg_scale=0.0, remasking="low_confidence")
    if wavefront_kwargs is None:
        wavefront_kwargs = dict(steps=1024, r=2, K=1, F_max=8, temperature=0.0, cfg_scale=0.0)

    set_seed(seed)
    device = torch.device(device_name or ("cuda" if torch.cuda.is_available() else "cpu"))

    print("加载 tokenizer 与模型：", model_name)
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    ds = load_dataset("Salesforce/wikitext", "wikitext-103-v1", split=dataset_split)
    print(f"数据集加载完成，样本数：{len(ds)}")

    print("抽样 prompt+continuation pairs ...")
    pairs = make_prompt_continuation_pairs(ds, tokenizer, n_samples=n_samples,
                                           prompt_len=prompt_len, cont_len=cont_len, seed=seed)

    print("评估 BlockDiffusion-style PPL ...")
    bd_results = evaluate_pairs(model, tokenizer, pairs, device, blockdiffusion_ppl,
                                method_name="BlockDiffusion", verbose=True, **blockdiff_kwargs)

    print("评估 Wavefront-style PPL ...")
    wf_results = evaluate_pairs(model, tokenizer, pairs, device, wavefront_ppl,
                                method_name="Wavefront", verbose=True, **wavefront_kwargs)

    # Aggregate per-sample scores; each result item is (ppl, avg_nll).
    bd_ppls = np.array([r[0] for r in bd_results])
    wf_ppls = np.array([r[0] for r in wf_results])
    bd_nlls = np.array([r[1] for r in bd_results])
    wf_nlls = np.array([r[1] for r in wf_results])

    diff = bd_ppls - wf_ppls  # positive -> BlockDiffusion has higher perplexity

    mean_bd = float(bd_ppls.mean())
    mean_wf = float(wf_ppls.mean())
    std_bd = _sample_std_or_none(bd_ppls)
    std_wf = _sample_std_or_none(wf_ppls)

    mean_diff, ci_lower, ci_upper = paired_bootstrap_ci(diff.tolist(), B=2000, alpha=0.05)

    # Optional paired t-test. It is skipped without scipy or with fewer than two
    # samples, where the statistic is undefined.
    p_value = None
    if diff.size > 1:
        try:
            from scipy import stats
        except ImportError:
            stats = None
        if stats is not None:
            _, raw_p = stats.ttest_rel(bd_ppls, wf_ppls)
            if np.isfinite(raw_p):
                p_value = float(raw_p)

    per_sample = [
        {"index": i, "bd_ppl": float(bd_ppl), "wf_ppl": float(wf_ppl),
         "bd_nll": float(bd_nll), "wf_nll": float(wf_nll)}
        for i, (bd_ppl, wf_ppl, bd_nll, wf_nll) in enumerate(zip(bd_ppls, wf_ppls, bd_nlls, wf_nlls))
    ]

    result = {
        "model": model_name,
        "n_samples": n_samples,
        "prompt_len": prompt_len,
        "cont_len": cont_len,
        "seed": seed,
        "blockdiffusion": {"mean_ppl": mean_bd, "std_ppl": std_bd},
        "wavefront": {"mean_ppl": mean_wf, "std_ppl": std_wf},
        "paired_diff_mean": mean_diff,
        "paired_diff_95ci": [ci_lower, ci_upper],
        "paired_ttest_pvalue": p_value,
        "per_sample": per_sample,
    }

    # Persist results next to the current working directory.
    out_file = f"ppl_comparison_{model_name.replace('/', '_')}_{n_samples}samples_seed{seed}.json"
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, ensure_ascii=False)

    print("结果保存到", out_file)
    print("Summary:")
    print(f" BlockDiffusion mean PPL: {mean_bd:.4f} ± {_format_optional(std_bd, '.4f')}")
    print(f" Wavefront     mean PPL: {mean_wf:.4f} ± {_format_optional(std_wf, '.4f')}")
    print(f" Paired mean diff (BD - WF): {mean_diff:.6f}, 95% CI = [{ci_lower:.6f}, {ci_upper:.6f}]")
    if p_value is not None:
        print(f" Paired t-test p-value: {p_value:.4e}")

    return result


if __name__ == "__main__":
    # Smoke-run configuration: one WikiText-103 sample, made of a 128-token
    # prompt and a 256-token continuation. Spread statistics such as the
    # sample std need at least two samples to be meaningful.
    run_config = dict(
        model_name="GSAI-ML/LLaDA-8B-Instruct",
        n_samples=1,
        prompt_len=128,
        cont_len=256,
        seed=2025,
    )
    res = evaluate_on_wikitext103(**run_config)
