import os
import types
import warnings
import random
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from transformers import AutoTokenizer, GPTNeoXTokenizerFast
from transformer_lens import HookedTransformer
from transformer_lens.loading_from_pretrained import (
    OFFICIAL_MODEL_NAMES, MODEL_ALIASES, make_model_alias_map
)

# Suppress the "torch_dtype is deprecated" warning.
warnings.filterwarnings("ignore", message="torch_dtype is deprecated")

# ========== [Configuration] Edit only this section ==========
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Layer and head indices (zero-based)
TARGET_LAYER = 3
TARGET_HEADS = [3]

# Induction task parameters
NUM_SAMPLES = 100  # number of samples evaluated per model
RANDOM_SEED = 42
PREFIX_LEN = 5     # length of the random prefix
PATTERN_LEN = 3    # A B C
GAP_LEN = 3        # number of random tokens between the pattern and the repeated query

# ===== [Core configuration] Every configuration combination to test =====
ALL_CONFIGS = [
    {
        # Model size tag: printed with an "M" suffix and embedded in the
        # experiment and category names.
        'model': 14,
        # Last training step to include; steps run from 1100 up to this value
        # in increments of 100.
        'max_step': 2000,
    }
]

# ===== [Category configuration] Category templates for a given config =====
def get_categories_for_config(config):
    """Build the category table and the checkpoint steps for one configuration.

    Args:
        config: Mapping with the keys ``'model'`` (model size tag used in the
            category name) and ``'max_step'`` (last training step, inclusive).

    Returns:
        A tuple ``(categories_by_type, steps)`` where ``categories_by_type`` is
        ``{'augment': {category_name: {"root": ..., "pattern": ...}}}`` and
        ``steps`` is the list ``1100, 1200, ..., max_step``. ``pattern`` is a
        ``str.format`` template containing a ``{step}`` placeholder.
    """
    model_size = config['model']
    max_step = config['max_step']

    # Training steps to evaluate (empty when max_step < 1100).
    steps = list(range(1100, max_step + 1, 100))

    # Build the relative checkpoint path with os.path.join instead of a
    # hard-coded backslash: on Windows the result is unchanged, while on POSIX
    # systems a literal "\" would be treated as part of the directory name.
    categories_augment = {
        f"randomorder{model_size}": {
            "root": "",
            "pattern": os.path.join("randomorder", "randomorder{step}"),
        },
    }

    return {'augment': categories_augment}, steps

# Output directory. Each experiment is written to
# os.path.join(OUTPUT_BASE, <experiment name>); with the empty string set here
# that resolves to a path relative to the current working directory.
OUTPUT_BASE = r""

# ========== Nothing below this line needs to be edited ==========

def ensure_dir(path: str):
    """Create ``path`` (including missing parents) if needed and return it.

    An already existing directory is not an error.
    """
    os.makedirs(name=path, exist_ok=True)
    return path


def generate_experiment_name(config, category_type):
    """Derive the experiment name from the config and the module-level targets.

    The name combines ``config['model']``, ``category_type``, ``TARGET_LAYER``
    and the ``TARGET_HEADS`` joined with underscores.
    """
    head_tag = "_".join(str(head) for head in TARGET_HEADS)
    size_tag = config['model']
    return f"model{size_tag}_{category_type}_ablation_L{TARGET_LAYER}_H{head_tag}"


def bind_tokenizer_methods(model, tokenizer):
    """Attach ``tokenizer`` to ``model`` and install tokenizer-backed helpers.

    Sets ``model.tokenizer`` and binds two per-instance methods on ``model``:

    * ``to_tokens(text, prepend_bos=False, move_to_device=True)`` encodes a
      string, or an iterable of strings whose encodings are concatenated along
      the last dimension, and optionally moves the result to
      ``model.cfg.device``.
    * ``to_str_tokens(tokens, prepend_bos=False)`` decodes each token id
      individually into a list of strings.
    """
    model.tokenizer = tokenizer

    def to_tokens_tinystories(self, text: str, prepend_bos: bool = False,
                              move_to_device: bool = True) -> torch.Tensor:
        def _encode(piece):
            # prepend_bos is forwarded as the tokenizer's add_special_tokens flag.
            return self.tokenizer.encode(
                piece, return_tensors="pt", add_special_tokens=prepend_bos
            )

        if isinstance(text, str):
            ids = _encode(text)
        else:
            ids = torch.cat([_encode(piece) for piece in text], dim=-1)
        return ids.to(self.cfg.device) if move_to_device else ids

    def to_str_tokens_tinystories(self, tokens: torch.Tensor,
                                  prepend_bos: bool = False) -> list[str]:
        # prepend_bos is accepted but not used by this implementation.
        flat = tokens.squeeze(0) if tokens.dim() == 2 else tokens
        pieces = []
        for token_id in flat.tolist():
            pieces.append(
                self.tokenizer.decode([token_id], skip_special_tokens=False)
            )
        return pieces

    bindings = (
        ("to_tokens", to_tokens_tinystories),
        ("to_str_tokens", to_str_tokens_tinystories),
    )
    for attr_name, func in bindings:
        setattr(model, attr_name, types.MethodType(func, model))


def register_aliases(alias_to_path: dict):
    """Register local model paths so transformer_lens accepts their aliases.

    NOTE: despite the parameter name, the mapping is iterated as
    ``{path: alias}``. Each path is appended to ``OFFICIAL_MODEL_NAMES`` when
    missing and ``MODEL_ALIASES[path]`` is set to ``[alias]``.
    """
    for model_path in alias_to_path:
        alias = alias_to_path[model_path]
        already_known = model_path in OFFICIAL_MODEL_NAMES
        if not already_known:
            OFFICIAL_MODEL_NAMES.append(model_path)
        MODEL_ALIASES[model_path] = [alias]
    # The return value is discarded here.
    # NOTE(review): confirm against transformer_lens whether this call is needed.
    make_model_alias_map()


def build_model_paths(categories, steps):
    """Expand every category template into concrete checkpoint entries.

    Args:
        categories: ``{category_name: {"root": str, "pattern": str}}`` where
            ``pattern`` contains a ``{step}`` placeholder.
        steps: Iterable of training steps.

    Returns:
        ``(cat_models, alias_map)``: ``cat_models`` maps each category name to
        a list of ``(step, alias, path)`` tuples and ``alias_map`` maps each
        path to its alias ``"<category>_<step>"``.
    """
    cat_models = {name: [] for name in categories}
    alias_map = {}

    for cat_name, spec in categories.items():
        root_dir = spec["root"]
        template = spec["pattern"]
        entries = cat_models[cat_name]

        for step in steps:
            path = os.path.join(root_dir, template.format(step=step))
            alias = f"{cat_name}_{step}"
            entries.append((step, alias, path))
            alias_map[path] = alias

    return cat_models, alias_map


def is_candidate_word(tokenizer, tid: int) -> bool:
    """Return True when token ``tid`` decodes to a space-prefixed lowercase word.

    A token qualifies when it is not one of the tokenizer's special ids, its
    decoded text starts with a single space, and the remainder consists of at
    least two characters that all lie in ``a``-``z``.
    """
    special_ids = getattr(tokenizer, "all_special_ids", [])
    if tid in special_ids:
        return False

    text = tokenizer.decode([tid], skip_special_tokens=False)
    if not text.startswith(" "):
        return False

    body = text[1:]
    if len(body) < 2:
        return False
    return not any(ch < "a" or ch > "z" for ch in body)


def build_token_pool(tokenizer):
    """Collect every token id below ``tokenizer.vocab_size`` that is a candidate word.

    Ids for which the candidate check raises are skipped.
    """
    def _accepted(tid):
        try:
            return is_candidate_word(tokenizer, tid)
        except Exception:
            return False

    return [tid for tid in range(tokenizer.vocab_size) if _accepted(tid)]


def build_induction_sequence(pool, rng, prefix_len, pattern_len, gap_len):
    """Build one induction prompt: ``[prefix] + [pattern] + [gap] + pattern[:2]``.

    The pattern tokens are pairwise distinct; prefix and gap tokens are drawn
    independently from ``pool`` and may repeat or coincide with pattern tokens.

    Args:
        pool: Sequence of candidate token ids.
        rng: ``random.Random``-like object providing ``choice``.
        prefix_len: Number of random tokens before the pattern.
        pattern_len: Number of pattern tokens (must be at least 3, because the
            query is the first two pattern tokens and the target is the third).
        gap_len: Number of random tokens between the pattern and the query.

    Returns:
        ``(seq, target_pos, target_token)`` where ``target_pos`` is the index of
        the last token in ``seq`` and ``target_token`` is ``pattern[2]``.

    Raises:
        ValueError: If ``pattern_len < 3`` or ``pool`` holds fewer than
            ``pattern_len`` distinct tokens.
    """
    if pattern_len < 3:
        raise ValueError(
            f"pattern_len must be >= 3 (query A B, target C), got {pattern_len}"
        )
    # Without enough distinct tokens the rejection loop below could never
    # produce a duplicate-free pattern and would spin forever.
    distinct = len(set(pool))
    if distinct < pattern_len:
        raise ValueError(
            f"pool has {distinct} distinct tokens, need at least {pattern_len}"
        )

    # The draw order (prefix, pattern with retries, gap) is kept fixed so a
    # given seed keeps producing the same sequences.
    prefix = [rng.choice(pool) for _ in range(prefix_len)]
    pattern = [rng.choice(pool) for _ in range(pattern_len)]
    # Redraw the whole pattern until its tokens are pairwise distinct.
    while len(set(pattern)) != len(pattern):
        pattern = [rng.choice(pool) for _ in range(pattern_len)]

    gap = [rng.choice(pool) for _ in range(gap_len)]
    query = pattern[:2]  # A B

    seq = prefix + pattern + gap + query
    target_pos = len(seq) - 1  # the prediction is read at the last position
    target_token = pattern[2]  # expected continuation C

    return seq, target_pos, target_token


def evaluate_model_induction(
    model_alias: str,
    model_path: str,
    num_samples: int,
    target_layer: int,
    heads: list[int],
    seed: int
) -> tuple[float, dict]:
    """Evaluate one checkpoint on the induction task with per-head ablation.

    Each sample is a prompt ``[prefix] A B C [gap] A B`` whose expected next
    token is ``C``. For every requested head in ``target_layer`` the head's
    ``hook_z`` output is zeroed and the contribution is defined as

        delta = logit_C(clean) - logit_C(head ablated)

    read at the last position of the prompt.

    Args:
        model_alias: Alias passed to ``HookedTransformer.from_pretrained_no_processing``.
        model_path: Path the tokenizer is loaded from.
        num_samples: Number of prompts to generate.
        target_layer: Zero-based layer whose heads are ablated.
        heads: Zero-based head indices to ablate, one at a time.
        seed: Seed of the ``random.Random`` used to build the prompts.

    Returns:
        ``(prob_mean, contrib_means)``: the mean softmax probability of ``C``
        under the clean model, and ``{head: mean delta}`` with one entry per
        element of ``heads``. Values are NaN when nothing could be measured:
        the layer does not exist, the head index is not below
        ``model.cfg.n_heads``, the token pool is too small, or every sample
        failed.
    """
    nan = float("nan")

    # Load tokenizer and model.
    tokenizer = GPTNeoXTokenizerFast.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = HookedTransformer.from_pretrained_no_processing(
        model_alias,
        device=DEVICE,
        dtype=torch.float32,
        n_devices=1,
        tokenizer=tokenizer
    )
    try:
        model.eval()
        bind_tokenizer_methods(model, tokenizer)

        if target_layer >= model.cfg.n_layers:
            return nan, {h: nan for h in heads}

        valid_heads = [h for h in heads if h < model.cfg.n_heads]

        # Candidate tokens the prompts are sampled from.
        pool = build_token_pool(tokenizer)
        if len(pool) < PREFIX_LEN + PATTERN_LEN + GAP_LEN:
            return nan, {h: nan for h in heads}

        rng = random.Random(seed)
        hook_key = f"blocks.{target_layer}.attn.hook_z"

        probs = []
        contrib_accum = {h: [] for h in valid_heads}
        failed = 0
        first_error = None

        for _ in range(num_samples):
            try:
                seq, target_pos, target_token = build_induction_sequence(
                    pool, rng, PREFIX_LEN, PATTERN_LEN, GAP_LEN
                )
                x = torch.tensor([seq], dtype=torch.long, device=model.cfg.device)

                # -------- 1) clean forward pass: baseline --------
                with torch.no_grad():
                    logits_clean = model(x, return_type="logits")
                logit_C_clean = logits_clean[0, target_pos, target_token]
                prob_C = torch.softmax(
                    logits_clean[0, target_pos], dim=-1
                )[target_token].item()
                probs.append(prob_C)

                # -------- 2) ablate each head in turn --------
                for h in valid_heads:
                    # Bind the head index as a default so the hook does not
                    # depend on the loop variable's later value.
                    def zero_head_z(z, hook, head=h):
                        z = z.clone()
                        z[:, :, head, :] = 0.0
                        return z

                    with torch.no_grad():
                        logits_abl = model.run_with_hooks(
                            x,
                            return_type="logits",
                            fwd_hooks=[(hook_key, zero_head_z)],
                        )
                    logit_C_abl = logits_abl[0, target_pos, target_token]
                    contrib_accum[h].append((logit_C_clean - logit_C_abl).item())
            except Exception as exc:
                # Best effort per sample, but keep track so failures are visible.
                failed += 1
                if first_error is None:
                    first_error = repr(exc)
                continue

        if failed:
            print(
                f"  [警告] {model_alias}: {failed}/{num_samples} 个样本评估失败，"
                f"首个错误: {first_error}"
            )

        prob_mean = float(np.mean(probs)) if probs else nan
        # Heads filtered out of valid_heads have no accumulator; report NaN
        # for them rather than failing the whole model.
        contrib_means = {
            h: (float(np.mean(contrib_accum[h])) if contrib_accum.get(h) else nan)
            for h in heads
        }
        return prob_mean, contrib_means
    finally:
        # Release the model on every exit path, including errors.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def plot_induction_prob(results_by_cat: dict, output_dir: str, exp_name: str):
    """Plot mean P(C) against training step, one line per category, and save a PNG.

    Args:
        results_by_cat: ``{category: [(step, prob), ...]}``; NaN entries are
            left out of the plotted line.
        output_dir: Directory the image is written to (created if missing).
        exp_name: Experiment name used in the title and the file name.
    """
    ensure_dir(output_dir)
    plt.figure(figsize=(12, 7))
    sns.set_style("whitegrid")

    categories = list(results_by_cat.keys())
    colors = dict(
        zip(categories, sns.color_palette("tab10", n_colors=len(categories)))
    )

    for cat in categories:
        xs, ys = [], []
        for step, value in results_by_cat[cat]:
            # NaN is the only value that differs from itself.
            if value == value:
                xs.append(step)
                ys.append(value)
        if not xs:
            continue
        plt.plot(xs, np.array(ys, dtype=np.float64), marker="o",
                 color=colors[cat], label=cat)

    title_head = f'P(C) in Induction Task - {exp_name}\n'
    title_tail = '[prefix] A B C [gap] A B ? → C'
    plt.title(title_head + title_tail)
    plt.xlabel("Training Step")
    plt.ylabel("Mean Probability P(C)")

    # Show roughly ten ticks across all steps present in the input.
    all_steps = sorted(
        {step for series in results_by_cat.values() for (step, _) in series}
    )
    if all_steps:
        stride = max(1, len(all_steps) // 10)
        plt.xticks(all_steps[::stride])

    plt.ylim(0.0, 1.0)
    plt.legend(title="Category")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    save_path = os.path.join(output_dir, f"{exp_name}_induction_prob.png")
    plt.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"✓ 已保存图像: {save_path}")


def plot_logit_contrib_for_head(results_by_cat: dict, output_dir: str,
                                exp_name: str, layer: int, head: int):
    """Plot the mean ablation delta of one head against training step and save a PNG.

    Args:
        results_by_cat: ``{category: [(step, mean_delta), ...]}``; NaN entries
            are left out of the plotted line.
        output_dir: Directory the image is written to (created if missing).
        exp_name: Experiment name used in the title and the file name.
        layer: Layer index shown in the title and the file name.
        head: Head index shown in the title and the file name.
    """
    ensure_dir(output_dir)
    plt.figure(figsize=(12, 7))
    sns.set_style("whitegrid")

    categories = list(results_by_cat.keys())
    colors = dict(
        zip(categories, sns.color_palette("tab10", n_colors=len(categories)))
    )

    for cat in categories:
        xs, ys = [], []
        for step, value in results_by_cat[cat]:
            # NaN is the only value that differs from itself.
            if value == value:
                xs.append(step)
                ys.append(value)
        if not xs:
            continue
        plt.plot(xs, np.array(ys, dtype=np.float64), marker="o",
                 color=colors[cat], label=cat)

    title_head = f'Head Ablation Δ logit(C) in Induction Task\n'
    title_tail = f'{exp_name} - L{layer}H{head}'
    plt.title(title_head + title_tail)
    plt.xlabel("Training Step")
    plt.ylabel("Mean Δ logit(C) (clean - ablated)")

    # Show roughly ten ticks across all steps present in the input.
    all_steps = sorted(
        {step for series in results_by_cat.values() for (step, _) in series}
    )
    if all_steps:
        stride = max(1, len(all_steps) // 10)
        plt.xticks(all_steps[::stride])

    plt.legend(title="Category")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    save_path = os.path.join(output_dir, f"{exp_name}_L{layer}_H{head}_contrib.png")
    plt.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"✓ 已保存图像: {save_path}")


def save_csv(prob_results: dict, contrib_results: dict, output_dir: str, exp_name: str):
    """Write the P(C) series and one delta-logit series per target head as CSV.

    Args:
        prob_results: ``{category: [(step, prob), ...]}``.
        contrib_results: ``{head: {category: [(step, delta), ...]}}``; it is
            indexed with every head in ``TARGET_HEADS``.
        output_dir: Directory the files are written to (created if missing).
        exp_name: Prefix of the generated file names.

    NaN values are written as empty cells.
    """
    ensure_dir(output_dir)

    def _cell(value):
        # NaN is the only value that differs from itself.
        return "" if value != value else f"{value:.6f}"

    def _write(path, header, series_by_cat):
        lines = [header]
        for cat, series in series_by_cat.items():
            for (step, value) in series:
                lines.append(f"{cat},{step},{_cell(value)}")
        with open(path, "w", encoding="utf-8") as handle:
            handle.write("\n".join(lines))
        print(f"✓ 已保存CSV: {path}")

    # P(C) results
    _write(
        os.path.join(output_dir, f"{exp_name}_prob.csv"),
        "category,step,prob_C",
        prob_results,
    )

    # Per-head contribution results
    for h in TARGET_HEADS:
        _write(
            os.path.join(output_dir, f"{exp_name}_L{TARGET_LAYER}_H{h}_contrib.csv"),
            "category,step,delta_logit",
            contrib_results[h],
        )


def run_single_experiment(config, category_type, categories, steps):
    """Run one experiment: evaluate every checkpoint, then plot and save results.

    Args:
        config: Configuration dict; ``config['model']`` is printed and used in
            the experiment name.
        category_type: Label of the category group, used in the experiment name.
        categories: ``{category_name: {"root": ..., "pattern": ...}}``.
        steps: Training steps to evaluate for each category.

    Checkpoints whose directory is missing, or whose evaluation raises, are
    recorded as NaN so that every step still appears in the outputs.
    """
    exp_name = generate_experiment_name(config, category_type)
    output_dir = os.path.join(OUTPUT_BASE, exp_name)
    rule = "=" * 80
    thin_rule = "-" * 80

    print("\n" + rule)
    print(f"开始实验: {exp_name}")
    print(rule)
    print(f"模型大小: {config['model']}M")
    print(f"类别类型: {category_type}")
    print(f"目标层: {TARGET_LAYER}, 目标头: {TARGET_HEADS}")
    print(f"任务: [prefix({PREFIX_LEN})] + [A B C] + [gap({GAP_LEN})] + [A B ?] → 预测 C")
    print(f"输出目录: {output_dir}")
    print(rule)

    ensure_dir(output_dir)

    # Resolve checkpoint paths.
    print("\n组织模型路径...")
    cat_models, alias_map = build_model_paths(categories, steps)
    num_models = sum(len(entries) for entries in cat_models.values())
    print(f"✓ 共计 {num_models} 个模型")

    # Register aliases with transformer_lens.
    register_aliases(alias_map)
    print("✓ 已注册模型别名")

    # Result containers.
    induction_prob_results = {cat: [] for cat in categories.keys()}
    head_contrib_results = {
        h: {cat: [] for cat in categories.keys()} for h in TARGET_HEADS
    }

    def _record_nan(cat, step):
        # Placeholder row so the step still shows up in the outputs.
        induction_prob_results[cat].append((step, float("nan")))
        for h in TARGET_HEADS:
            head_contrib_results[h][cat].append((step, float("nan")))

    def _fmt(value):
        return 'NaN' if value != value else f'{value:.4f}'

    # Evaluation.
    for cat in categories.keys():
        print("\n" + thin_rule)
        print(f"评估类别: {cat}")
        print(thin_rule)
        for step, alias, path in tqdm(cat_models[cat], desc=f"{cat}", ncols=80):
            try:
                if not os.path.isdir(path):
                    print(f"  [跳过] 路径不存在: {path}")
                    _record_nan(cat, step)
                    continue

                prob_mean, contrib_means = evaluate_model_induction(
                    alias, path, NUM_SAMPLES, TARGET_LAYER, TARGET_HEADS, RANDOM_SEED
                )
                induction_prob_results[cat].append((step, prob_mean))
                for h in TARGET_HEADS:
                    head_contrib_results[h][cat].append(
                        (step, contrib_means.get(h, float("nan")))
                    )

                # Short per-checkpoint summary.
                contrib_info = " | ".join(
                    [f"H{h}: {_fmt(contrib_means.get(h))}" for h in TARGET_HEADS]
                )
                prob_info = f"P(C): {_fmt(prob_mean)}"
                print(f"  步 {step}: {prob_info} | {contrib_info}")
            except Exception as e:
                print(f"  [错误] step={step}: {e}")
                _record_nan(cat, step)
                continue

    # Sort by step, then plot and save.
    print("\n保存结果...")

    def _by_step(entry):
        return entry[0]

    for series in induction_prob_results.values():
        series.sort(key=_by_step)
    for h in TARGET_HEADS:
        for series in head_contrib_results[h].values():
            series.sort(key=_by_step)

    plot_induction_prob(induction_prob_results, output_dir, exp_name)
    for h in TARGET_HEADS:
        plot_logit_contrib_for_head(
            head_contrib_results[h], output_dir, exp_name, TARGET_LAYER, h
        )
    save_csv(induction_prob_results, head_contrib_results, output_dir, exp_name)

    print("\n" + rule)
    print(f"实验完成: {exp_name}")
    print(f"结果已保存至: {output_dir}")
    print(rule)


def main():
    """Print the run summary and execute the 'augment' experiment for every config."""
    rule = "=" * 80
    hash_rule = '#' * 80

    print(rule)
    print("Induction 任务 + Head Ablation 评估程序")
    print(rule)
    print("任务: [prefix] A B C [gap] A B ? → 预测 C")
    print(f"目标层: {TARGET_LAYER}, 目标头: {TARGET_HEADS}")
    print(f"总配置数: {len(ALL_CONFIGS)}")
    print(f"测试样本数: {NUM_SAMPLES}")
    print(rule)

    # Iterate over every configuration.
    total_experiments = len(ALL_CONFIGS)

    for config_idx, config in enumerate(ALL_CONFIGS, start=1):
        print("\n" + hash_rule)
        print(f"配置 {config_idx}/{len(ALL_CONFIGS)}: 模型大小 {config['model']}M")
        print(hash_rule)

        # Categories and steps for this configuration.
        all_categories, steps = get_categories_for_config(config)

        # Only the 'augment' category group is evaluated.
        run_single_experiment(config, 'augment', all_categories['augment'], steps)

    print("\n" + rule)
    print("所有实验完成！")
    print(f"总共完成 {total_experiments} 个实验")
    print(f"结果保存在: {OUTPUT_BASE}")
    print(rule)


if __name__ == "__main__":
    main()
