# ==========================================
# 1. 环境安装与依赖导入
# ==========================================
!pip install -q torch transformers datasets accelerate evaluate wandb matplotlib pandas

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW
from datasets import load_dataset
import wandb
import random
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import gc # 用于垃圾回收

# ==========================================
# 2. 全局配置与工具函数
# ==========================================


MANUAL_CONFIG = {
    # True: sample during generation (do_sample=True); False: greedy decoding (do_sample=False).
    "USE_SAMPLING": True,

    # Optional label. The decoding mode is always appended to it, e.g. "experiment1"
    # becomes "experiment1_greedy" or "experiment1_sample", so the tag can never
    # disagree with how the run actually generated text. Leave it empty to use
    # the mode name alone.
    "CUSTOM_LABEL": "my_test",
}

# --- Derive RUN_TAG from the switches above ---
# Suffix that reflects the decoding mode really in use.
real_mode_suffix = "sample" if MANUAL_CONFIG["USE_SAMPLING"] else "greedy"

# Custom label (if any) + mode suffix; otherwise the mode suffix alone.
_label = MANUAL_CONFIG["CUSTOM_LABEL"]
MANUAL_CONFIG["RUN_TAG"] = f"{_label}_{real_mode_suffix}" if _label else real_mode_suffix

print(f"当前实验配置: 最终Tag=[{MANUAL_CONFIG['RUN_TAG']}], 采样模式=[{MANUAL_CONFIG['USE_SAMPLING']}]")
# ============================================================


# Real-data fractions (lambda) to compare: the share of each round's training
# mix drawn from real WikiText rather than model-generated text.
lambda_list = [0.0, 0.1]

# Every lambda is repeated once per seed.
seeds = [2025, 2026, 2027]

base_config = dict(
    model_id="distilgpt2",
    iterations=20,          # generate -> train -> evaluate rounds per run
    samples_per_iter=64,    # texts generated per round (= size of the training mix)
    batch_size=8,           # training mini-batch size
    lr=5e-5,
    seq_len=128,            # generation length and truncation length for train/validation
    device="cuda" if torch.cuda.is_available() else "cpu",
    # Carry the manual switches along so they are logged with every run.
    use_sampling=MANUAL_CONFIG["USE_SAMPLING"],
    run_tag=MANUAL_CONFIG["RUN_TAG"],
)

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs so a run is reproducible.

    When CUDA is available, every GPU is seeded as well and cuDNN is switched
    to deterministic, non-benchmarking mode.

    :param seed: integer seed shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# --- Data preparation (loaded once, shared by every run) ---
print("Loading dataset...")
_wikitext_args = ("wikitext", "wikitext-2-raw-v1")
dataset = load_dataset(*_wikitext_args, split="train")
val_data = load_dataset(*_wikitext_args, split="validation")

tokenizer = AutoTokenizer.from_pretrained(base_config["model_id"])
# Reuse EOS as the pad token so batched/padded tokenization works.
tokenizer.pad_token = tokenizer.eos_token

# Validation tokens: the first 50 raw validation lines joined by blank lines,
# then truncated to seq_len tokens.
# NOTE: perplexity is therefore measured on only base_config["seq_len"] tokens.
_val_text = "\n\n".join(val_data['text'][:50])
val_encodings = tokenizer(
    _val_text,
    return_tensors="pt",
    max_length=base_config["seq_len"],
    truncation=True,
).input_ids.to(base_config["device"])

def get_real_data(n):
    """Randomly sample ``n`` real texts from the WikiText training split.

    Rows are drawn uniformly with replacement. A row of 100 characters or
    fewer is rejected and redrawn; accepted rows are clipped to their first
    500 characters.

    NOTE(review): loops forever if no row is longer than 100 characters.

    :param n: number of texts to return (``n <= 0`` returns an empty list).
    :return: list of ``n`` strings.
    """
    last_index = len(dataset) - 1
    picked = []
    while len(picked) < n:
        row_text = dataset[random.randint(0, last_index)]['text']
        if len(row_text) <= 100:
            continue
        picked.append(row_text[:500])
    return picked

def evaluate_model(model):
    """Return the model's perplexity on the cached validation tokens.

    The validation ids are fed as both inputs and labels, and perplexity is
    ``exp`` of the loss the model returns. Gradients are disabled, and the
    model is left in eval mode.

    :param model: causal LM accepting ``labels=`` and returning ``.loss``.
    :return: perplexity as a Python float.
    """
    model.eval()
    with torch.no_grad():
        mean_nll = model(val_encodings, labels=val_encodings).loss
    return mean_nll.exp().item()

def generate_synthetic_data(model, n_samples, use_sampling, batch_size=8):
    """Generate ``n_samples`` synthetic texts with the current model.

    Each sample starts from a randomly chosen one-word prompt and is extended
    by up to ``base_config["seq_len"]`` new tokens. The decoded text includes
    the prompt; special tokens are stripped.

    :param model: causal LM used for generation (switched to eval mode).
    :param n_samples: number of texts to return; ``0`` returns ``[]``.
    :param use_sampling: True -> top-k sampling (k=20, temperature 0.7);
        False -> greedy decoding.
    :param batch_size: number of prompts per ``generate`` call (>= 1).
    :return: list of exactly ``n_samples`` decoded strings.
    :raises ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    model.eval()
    prompts = ["The", "In", "It", "A", "Once", "However", "Despite"]

    # Pass the sampling knobs only when sampling. In greedy mode they are unused,
    # and passing them explicitly (even as None) just overrides the model's
    # generation config.
    gen_kwargs = {
        "max_new_tokens": base_config["seq_len"],
        "do_sample": use_sampling,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if use_sampling:
        gen_kwargs.update(top_k=20, temperature=0.7)

    synthetic_texts = []
    # Decoder-only models must be LEFT-padded for batched generation; with right
    # padding, new tokens are appended after the pad tokens of shorter prompts.
    # The tokenizer is shared with TextDataset (which pads on the right), so
    # restore the original side afterwards.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        while len(synthetic_texts) < n_samples:
            # Only generate as many sequences as are still needed.
            current_bs = min(batch_size, n_samples - len(synthetic_texts))
            batch_prompts = [random.choice(prompts) for _ in range(current_bs)]
            inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True).to(base_config["device"])

            with torch.no_grad():
                outputs = model.generate(**inputs, **gen_kwargs)
            synthetic_texts.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    finally:
        tokenizer.padding_side = original_padding_side

    return synthetic_texts

class TextDataset(Dataset):
    """Fixed-length causal-LM training examples built from raw texts.

    Every text is tokenized with the shared ``tokenizer``, truncated and padded
    to ``base_config["seq_len"]`` tokens. ``__getitem__`` returns a dict with
    ``input_ids``, ``attention_mask`` and ``labels``. Padded positions are set to
    ``IGNORE_INDEX`` in the labels so they do not contribute to the loss.
    """

    # Label value ignored by the Hugging Face causal-LM cross-entropy loss.
    IGNORE_INDEX = -100

    def __init__(self, texts):
        self.encodings = tokenizer(texts, truncation=True, padding="max_length",
                                   max_length=base_config["seq_len"], return_tensors="pt")

    def __len__(self):
        return len(self.encodings.input_ids)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        # pad_token == eos_token in this script. Without masking, every padded
        # slot would be trained as "predict EOS after EOS", skewing the loss
        # (and the model) toward padding instead of the actual text.
        item['labels'] = item['input_ids'].masked_fill(item['attention_mask'] == 0, self.IGNORE_INDEX)
        return item

# Local store of every (run, iteration) record, used for the final CSV and plots.
all_results = []

# ==========================================
# 3. Main experiment loop
# ==========================================


def _bigram_diversity(texts):
    """Return (#unique whitespace bigrams) / (#whitespace tokens) over ``texts``; 0 if empty.

    NOTE: texts are joined with a space before splitting, so bigrams may span two
    samples, and the denominator is the token count (not the bigram count).
    Kept exactly as in earlier runs so results stay comparable.
    """
    tokens = " ".join(texts).split()
    if not tokens:
        return 0
    return len(set(zip(tokens, tokens[1:]))) / len(tokens)


def _train_one_round(model, optimizer, texts, config):
    """Fine-tune ``model`` for one pass over ``texts`` (no-op when ``texts`` is empty).

    Kept as a function so per-batch tensors (logits, loss) are released on
    return instead of lingering as module-level loop variables between runs.
    """
    if not texts:
        return
    loader = DataLoader(TextDataset(texts), batch_size=config["batch_size"], shuffle=True)
    device = config["device"]

    model.train()
    for batch in loader:
        optimizer.zero_grad()
        outputs = model(
            batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            labels=batch['labels'].to(device),
        )
        outputs.loss.backward()
        optimizer.step()


# Log in to WandB
wandb.login()

print(f"开始实验... 总共需要运行 {len(lambda_list) * len(seeds)} 次实验。")
print(f"当前 Tag: {base_config['run_tag']}")

for current_lambda in lambda_list:
    for seed in seeds:

        # --- A. Per-run setup ---
        seed_everything(seed)

        # Drop the previous run's model/optimizer BEFORE loading a new one, then
        # collect garbage first so the CUDA caching allocator actually has freed
        # blocks to release. (The last run's model stays alive after the loop.)
        model = optimizer = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        config = base_config.copy()
        config["lambda_val"] = current_lambda
        config["seed"] = seed

        # Prefix run/group names with the tag so runs are easy to spot in WandB.
        run_name = f"{config['run_tag']}_L{current_lambda}_S{seed}"
        group_name = f"{config['run_tag']}_lambda_{current_lambda}"

        print(f"\n>>> Running: {run_name}")

        wandb.init(
            project="erbp-rebuttal-llm-v4",
            config=config,
            name=run_name,
            group=group_name,
            tags=[config['run_tag']],  # tag also stored in WandB run metadata
            reinit=True
        )

        model = AutoModelForCausalLM.from_pretrained(config["model_id"]).to(config["device"])
        optimizer = AdamW(model.parameters(), lr=config["lr"])

        ppl_history = []
        div_history = []

        # --- B. Iterative generate -> mix -> train -> evaluate ---
        progress_bar = tqdm(range(config["iterations"]), desc=f"Tag={config['run_tag']}, L={current_lambda}, S={seed}")

        for t in progress_bar:

            # 1. Generate synthetic data with the current model.
            gen_texts = generate_synthetic_data(model, config["samples_per_iter"], use_sampling=config["use_sampling"])
            diversity_score = _bigram_diversity(gen_texts)

            # 2. Mix: int(samples_per_iter * lambda) real texts, the rest synthetic.
            n_real = int(config["samples_per_iter"] * config["lambda_val"])
            n_gen = config["samples_per_iter"] - n_real

            train_texts = gen_texts[:n_gen]
            if n_real > 0:
                train_texts += get_real_data(n_real)

            random.shuffle(train_texts)

            # 3. Train
            _train_one_round(model, optimizer, train_texts, config)

            # 4. Evaluate
            ppl = evaluate_model(model)
            ppl_history.append(ppl)
            div_history.append(diversity_score)

            # Log to WandB (the tag goes with every step as a safety net).
            wandb.log({
                "perplexity": ppl,
                "diversity": diversity_score,
                "iteration": t,
                "tag": config['run_tag']
            })

            # Local record, including the tag, for the CSV and plots.
            all_results.append({
                "tag": config['run_tag'],
                "lambda": current_lambda,
                "seed": seed,
                "iteration": t,
                "perplexity": ppl,
                "diversity": diversity_score,
                "use_sampling": config["use_sampling"]
            })

            progress_bar.set_postfix({"PPL": f"{ppl:.2f}", "Div": f"{diversity_score:.2f}"})

        wandb.finish()

# ==========================================
# 4. Results visualization
# ==========================================
print("\n实验结束，正在生成汇总图表...")

# One row per (lambda, seed, iteration) record collected during the runs.
df = pd.DataFrame(all_results)

# Preview the first rows to confirm the tag column is present.
print("数据预览 (前5行):")
print(df.head())

# Persist raw results; the run tag in the filename keeps experiments apart.
csv_filename = f"experiment_results_{base_config['run_tag']}.csv"
df.to_csv(csv_filename, index=False)
print(f"数据已保存至: {csv_filename}")

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: perplexity; right panel: diversity. Each curve is the mean over
# seeds for one lambda, with a +/- 1 standard-deviation band.
_panels = [("perplexity", "Perplexity"), ("diversity", "Diversity")]
for ax, (metric, pretty) in zip(axes, _panels):
    for lam in lambda_list:
        per_iter = (
            df[df["lambda"] == lam]
            .groupby("iteration")[metric]
            .agg(['mean', 'std'])
            .reset_index()
        )
        mean, std = per_iter["mean"], per_iter["std"]
        ax.plot(per_iter["iteration"], mean, label=f"Lambda={lam}")
        ax.fill_between(per_iter["iteration"], mean - std, mean + std, alpha=0.2)

    ax.set_title(f"{pretty} ({base_config['run_tag']})")
    ax.set_xlabel("Iteration")
    ax.set_ylabel(pretty)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()