import csv
import os
import random
import types
import warnings

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from tqdm import tqdm

from transformers import AutoTokenizer, GPTNeoXTokenizerFast
from transformer_lens import HookedTransformer
from transformer_lens.loading_from_pretrained import (
    OFFICIAL_MODEL_NAMES, MODEL_ALIASES, make_model_alias_map
)

warnings.filterwarnings("ignore", message="torch_dtype is deprecated")

# ========== Configuration: only this section needs editing ==========
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Probe-sentence generation and repetition parameters
NUM_SENTENCES = 100  # number of random probe sentences
NUM_REPEATS = 2      # how many times each sentence is tiled to form the model input
RANDOM_SEED = 42     # seed passed to the sentence generator

# ===== Core configuration: every configuration combination to evaluate =====
# 'model'    : model size (printed with an "M" suffix, e.g. 14 -> "14M")
# 'layer'    : layer whose attention pattern is inspected
# 'head'     : list of head indices to score in that layer
# 'max_step' : largest training-step checkpoint to evaluate
# 'topk'     : included in the experiment name
ALL_CONFIGS = [
    {
        'model': 14,
        'layer': 3,
        'head': [1, 3],
        'max_step': 3000,
        'topk': 100000,
    },
]

# ===== Category configuration: checkpoint category templates =====
def get_categories_for_config(config):
    """Build the checkpoint category templates and training steps for one config.

    Args:
        config: One entry of ``ALL_CONFIGS``. Reads ``'model'`` and
            ``'max_step'``; the optional ``'step_interval'`` key (default 100)
            sets the spacing between evaluated checkpoints.

    Returns:
        ``(category_groups, steps)``:
        - ``category_groups`` maps a group name (currently only ``'augment'``)
          to ``{category_name: {"root": str, "pattern": str}}``, where
          ``pattern`` contains a ``{step}`` placeholder that is filled in with
          ``str.format(step=...)`` and joined onto ``root``.
        - ``steps`` is ``[interval, 2*interval, ...]`` up to and including
          ``max_step`` when it is a multiple of the interval.

    Raises:
        ValueError: if ``step_interval`` is 0 (from ``range``).
    """
    model_size = config['model']
    max_step = config['max_step']
    step_interval = config.get('step_interval', 100)

    steps = list(range(step_interval, max_step + 1, step_interval))

    categories_augment = {
        f"original{model_size}": {
            "root": "",
            # Use the platform separator instead of a hard-coded backslash so the
            # checkpoint path also resolves on POSIX systems (identical on Windows).
            "pattern": os.path.join("nomask", "nomask{step}"),
        },
    }

    return {
        'augment': categories_augment,
    }, steps

# Output directory; each experiment writes to its own subdirectory, created automatically.
# An empty string places experiment directories relative to the current working directory.
OUTPUT_BASE = ""

# ========== Code below should not need changes ==========


def ensure_dir(path: str):
    """Create directory *path* (including missing parents) if absent.

    Returns:
        *path*, unchanged, so the call can be used inline.
    """
    os.makedirs(name=path, exist_ok=True)
    return path


def generate_experiment_name(config, category_type):
    """Derive the experiment name used for the output directory and file names.

    Format: ``model{model}_{category_type}_L{layer}_H{h1_h2_...}_topk{topk}``.
    """
    size = config['model']
    layer_idx = config['layer']
    head_part = "_".join(str(h) for h in config['head'])
    pieces = [
        f"model{size}",
        f"{category_type}",
        f"L{layer_idx}",
        f"H{head_part}",
        f"topk{config['topk']}",
    ]
    return "_".join(pieces)


def generate_random_sentences(num_sentences=100, seed=42, min_len=8, max_len=20):
    """Generate reproducible pseudo-sentences from a small fixed vocabulary.

    Each sentence has between ``min_len`` and ``max_len`` words (inclusive),
    drawn uniformly with replacement. The first word is capitalized and a
    trailing period is appended.

    A private ``random.Random(seed)`` instance is used. The output is therefore
    identical to seeding the global ``random`` module with ``seed``, but the
    caller's global RNG state (``random`` / ``numpy.random``) is left untouched.

    Args:
        num_sentences: Number of sentences to produce (<= 0 gives ``[]``).
        seed: Seed for the private RNG.
        min_len: Minimum words per sentence (must be >= 1).
        max_len: Maximum words per sentence (must be >= ``min_len``).

    Returns:
        list[str] of generated sentences.

    Raises:
        ValueError: if ``min_len < 1`` or ``min_len > max_len``.
    """
    if min_len < 1 or min_len > max_len:
        raise ValueError(f"invalid sentence length range: [{min_len}, {max_len}]")

    rng = random.Random(seed)

    all_words = [
        "apple", "book", "car", "dog", "elephant", "flower", "guitar", "hat",
        "ice", "juice", "kite", "lamp", "moon", "notebook", "ocean", "pencil",
        "queen", "robot", "star", "table", "umbrella", "violin", "window", "box",
        "cloud", "door", "egg", "fire", "grass", "hill", "island", "jar",
        "jumps", "runs", "sleeps", "eats", "drinks", "flies", "swims", "walks",
        "big", "small", "happy", "sad", "red", "blue", "green", "yellow",
        "quickly", "slowly", "happily", "sadly", "loudly", "quietly", "carefully",
        "on", "in", "under", "over", "beside", "behind", "near", "and", "but",
        "he", "she", "it", "they", "we", "I", "you", "him", "her", "them"
    ]

    sentences = []
    for _ in range(num_sentences):
        length = rng.randint(min_len, max_len)
        words = [rng.choice(all_words) for _ in range(length)]
        words[0] = words[0].capitalize()
        sentences.append(" ".join(words) + ".")
    return sentences


def bind_tokenizer_methods(model, tokenizer):
    """Attach ``tokenizer`` to ``model`` and override its tokenization helpers.

    Replaces ``model.to_tokens`` and ``model.to_str_tokens`` with bound methods
    that call the given tokenizer's ``encode``/``decode`` directly.

    Args:
        model: Object exposing ``cfg.device`` (e.g. a HookedTransformer).
        tokenizer: Tokenizer providing ``encode(text, return_tensors=..., add_special_tokens=...)``
            and ``decode(ids, skip_special_tokens=...)`` (Hugging Face style).
    """
    model.tokenizer = tokenizer

    def to_tokens_tinystories(self, text, prepend_bos: bool = False, move_to_device: bool = True) -> torch.Tensor:
        # prepend_bos is forwarded as add_special_tokens; whether that inserts a
        # BOS token depends on the tokenizer's own configuration.
        if isinstance(text, str):
            tokens = self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=prepend_bos)
        else:
            # NOTE: a list of texts is concatenated into ONE sequence along the
            # last dim, not stacked into a batch.
            tokens = torch.cat(
                [self.tokenizer.encode(t, return_tensors="pt", add_special_tokens=prepend_bos) for t in text],
                dim=-1,
            )
        if move_to_device:
            tokens = tokens.to(self.cfg.device)
        return tokens

    def to_str_tokens_tinystories(self, tokens, prepend_bos: bool = False) -> list:
        """Decode each token id to its string.

        Accepts raw text, a 1-D id sequence, or a 2-D batch. A single-row batch
        yields ``list[str]``; a multi-row batch yields one ``list[str]`` per row.
        """
        if isinstance(tokens, str):
            tokens = self.to_tokens(tokens, prepend_bos=prepend_bos, move_to_device=False)
        tokens = torch.as_tensor(tokens)
        decode = self.tokenizer.decode

        def decode_ids(ids):
            return [decode([token_id], skip_special_tokens=False) for token_id in ids]

        if tokens.dim() == 2 and tokens.shape[0] != 1:
            # Previously crashed: decode() received nested lists for batch > 1.
            return [decode_ids(row) for row in tokens.tolist()]
        return decode_ids(tokens.reshape(-1).tolist())

    model.to_tokens = types.MethodType(to_tokens_tinystories, model)
    model.to_str_tokens = types.MethodType(to_str_tokens_tinystories, model)


def compute_head_score_offset1_from_cache(cache, layer: int, head: int, block_len: int, num_repeats: int) -> float:
    """Compute one head's induction score from its cached attention pattern.

    The input sequence is assumed to be a ``block_len``-token sequence tiled
    ``num_repeats`` times. For every repeat r >= 1, the query rows of repeat r
    are paired with the key columns of repeat r - 1. The (i, i + 1) diagonal of
    that sub-block is averaged; it is the attention from token i to the token
    that followed its previous occurrence. The per-repeat means are then averaged.

    Args:
        cache: Mapping (e.g. a TransformerLens ActivationCache) holding
            ``blocks.{layer}.attn.hook_pattern``. Only batch element 0 is read,
            and it is indexed as ``[head, query_pos, key_pos]``.
        layer: Layer index used to build the hook name.
        head: Head index; must satisfy ``0 <= head < n_heads``.
        block_len: Token length of one repeat.
        num_repeats: Number of repeats in the sequence.

    Returns:
        The score as a Python float. Returns NaN if the pattern is missing, the
        head is out of range (including negative), or no repeat yields a
        non-empty diagonal.
    """
    key = f"blocks.{layer}.attn.hook_pattern"
    if key not in cache:
        return float("nan")

    patterns = cache[key][0]  # batch element 0 -> per-head patterns
    # Reject negative indices explicitly: Python indexing would otherwise wrap
    # around and silently return another head's pattern.
    if not 0 <= head < patterns.shape[0]:
        return float("nan")
    pat = patterns[head]
    n_query, n_key = pat.shape[0], pat.shape[1]

    scores = []
    for repeat_idx in range(1, num_repeats):
        # Queries from repeat r, keys from the preceding repeat r - 1.
        q_lo = repeat_idx * block_len
        q_hi = min((repeat_idx + 1) * block_len, n_query)
        k_lo = (repeat_idx - 1) * block_len
        k_hi = min(repeat_idx * block_len, n_key)
        if q_hi - q_lo <= 0 or k_hi - k_lo <= 0:
            continue
        att_block = pat[q_lo:q_hi, k_lo:k_hi]
        if att_block.shape[0] <= 1 or att_block.shape[1] <= 0:
            continue
        # Entry (i, i + 1): query token i attending to the token that followed
        # the earlier occurrence of the same token.
        diag = att_block.diagonal(offset=1, dim1=-2, dim2=-1)
        if diag.numel() > 0:
            scores.append(diag.mean())

    if not scores:
        return float("nan")
    return torch.stack(scores).mean().item()


def register_aliases(alias_to_path: dict):
    """Register local checkpoint paths as TransformerLens model names.

    Despite its name, ``alias_to_path`` maps *path -> alias* (the ``alias_map``
    produced by ``build_model_paths``). Every path not yet present is appended
    to ``OFFICIAL_MODEL_NAMES``. Each path's entry in ``MODEL_ALIASES`` is set
    to ``[alias]``, replacing any previous aliases.
    """
    # Dict keys are unique, so checking membership before any append is
    # equivalent to checking inside the loop.
    unseen_paths = [p for p in alias_to_path if p not in OFFICIAL_MODEL_NAMES]
    OFFICIAL_MODEL_NAMES.extend(unseen_paths)
    MODEL_ALIASES.update({p: [a] for p, a in alias_to_path.items()})
    # NOTE(review): make_model_alias_map() appears to return a freshly built map
    # rather than update module state, so its result is discarded here; the call
    # is kept to preserve the original behavior.
    make_model_alias_map()


def build_model_paths(categories, steps):
    """Expand category templates into concrete checkpoint paths.

    Args:
        categories: ``{category_name: {"root": str, "pattern": str}}``. Each
            ``pattern`` is formatted with ``step=`` and joined onto ``root``.
        steps: Training steps to expand for every category.

    Returns:
        ``(cat_models, alias_map)`` where
        ``cat_models = {category_name: [(step, alias, path), ...]}`` and
        ``alias_map = {path: alias}``; aliases look like ``"{category}_{step}"``.
    """
    cat_models = {name: [] for name in categories}
    alias_map = {}

    for name, spec in categories.items():
        base_dir = spec["root"]
        template = spec["pattern"]
        bucket = cat_models[name]
        for step in steps:
            model_dir = os.path.join(base_dir, template.format(step=step))
            model_alias = f"{name}_{step}"
            bucket.append((step, model_alias, model_dir))
            alias_map[model_dir] = model_alias

    return cat_models, alias_map


def _free_cuda_cache():
    """Release cached CUDA allocator blocks when a GPU is available (no-op otherwise)."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def evaluate_model_on_heads(model_alias: str, model_path: str, sentences: list[str],
                            target_layer: int, heads: list[int]) -> dict:
    """Score the requested attention heads of one checkpoint on repeated sentences.

    Each sentence is tokenized without BOS, tiled ``NUM_REPEATS`` times and run
    through the model, caching only ``blocks.{target_layer}.attn.hook_pattern``.
    Per-sentence head scores come from ``compute_head_score_offset1_from_cache``.
    Failures on individual sentences are logged and skipped.

    Args:
        model_alias: Model name registered with TransformerLens.
        model_path: Checkpoint directory; the tokenizer is loaded from here.
        sentences: Probe sentences.
        target_layer: Layer index to inspect.
        heads: Head indices to score.

    Returns:
        Dict mapping every head in ``heads`` to ``(mean, std, n_used)``.
        ``std`` is the population std (ddof=0) and ``n_used`` is the number of
        sentences that produced a non-NaN score. Heads that could not be scored
        (layer or head out of range, or no usable sentence) map to
        ``(nan, nan, 0)``.
    """
    nan_stats = (float("nan"), float("nan"), 0)
    results = {h: nan_stats for h in heads}

    tokenizer = GPTNeoXTokenizerFast.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = HookedTransformer.from_pretrained_no_processing(
        model_alias,
        device=DEVICE,
        dtype=torch.float32,
        n_devices=1,
        tokenizer=tokenizer,
    )
    try:
        model.eval()

        # Negative layers are rejected too: they would build a hook name such as
        # "blocks.-1..." that never exists and silently yield all-NaN results.
        if not 0 <= target_layer < model.cfg.n_layers:
            print(f"  [跳过] 模型 {model_alias}: 层超界")
            return results

        valid_heads = [h for h in heads if 0 <= h < model.cfg.n_heads]
        skipped_heads = [h for h in heads if h not in valid_heads]
        if skipped_heads:
            print(f"  [警告] 模型 {model_alias}: 头超界 {skipped_heads}")
        per_head_scores = {h: [] for h in valid_heads}
        hook_name = f"blocks.{target_layer}.attn.hook_pattern"

        for sent in sentences:
            try:
                single_tokens = model.to_tokens(sent, prepend_bos=False)
                block_len = single_tokens.shape[1]
                # At least 2 tokens are needed for a non-empty offset-1 diagonal.
                if block_len < 2:
                    continue
                tokens = single_tokens.repeat(1, NUM_REPEATS)
                with torch.no_grad():
                    _, cache = model.run_with_cache(
                        tokens,
                        return_type="logits",
                        names_filter=lambda name: name == hook_name,
                    )
                for h in per_head_scores:
                    score = compute_head_score_offset1_from_cache(cache, target_layer, h, block_len, NUM_REPEATS)
                    if not np.isnan(score):
                        per_head_scores[h].append(score)
                del cache
                _free_cuda_cache()
            except Exception as e:
                # Best effort: one bad sentence must not abort the whole checkpoint.
                print(f"    [警告] 句子处理失败: {e}")

        for h, scores in per_head_scores.items():
            if scores:
                arr = np.asarray(scores, dtype=np.float64)
                results[h] = (float(arr.mean()), float(arr.std(ddof=0)), int(arr.size))
        return results
    finally:
        # Always drop the model and release GPU memory, including on early
        # return or an unexpected exception.
        del model
        _free_cuda_cache()


def plot_results_for_head(results_by_cat: dict, output_dir: str, exp_name: str, layer: int, head: int):
    """Plot mean induction score vs. training step for one head, one line per category.

    Args:
        results_by_cat: ``{category: [(step, mean, std, n_used), ...]}``. Points
            with a NaN mean are skipped; std and n_used are not plotted.
        output_dir: Directory for the PNG (created if missing).
        exp_name: Experiment name used in the title and file name.
        layer: Layer index used in the title and file name.
        head: Head index used in the title and file name.

    The image is written to ``{output_dir}/{exp_name}_L{layer}H{head}.png``.
    """
    ensure_dir(output_dir)
    # Set the style before creating the axes so it applies to them.
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        cats = list(results_by_cat.keys())
        palette = sns.color_palette("tab10", n_colors=len(cats))

        plotted = False
        for color, cat in zip(palette, cats):
            # m == m is False only for NaN.
            valid = [(s, m) for (s, m, _sd, _n) in results_by_cat.get(cat, []) if m == m]
            if not valid:
                continue
            steps = [s for s, _ in valid]
            means = np.array([m for _, m in valid], dtype=np.float64)
            ax.plot(steps, means, marker="o", color=color, label=cat)
            plotted = True

        ax.set_title(f"Induction Score (diagonal offset=1) - {exp_name} - L{layer}H{head}")
        ax.set_xlabel("Training Step")
        ax.set_ylabel("Mean Score")

        # Thin the x ticks to roughly 10, drawn from every step seen (NaN points included).
        all_steps = sorted({entry[0] for series in results_by_cat.values() for entry in series})
        if all_steps:
            ax.set_xticks(all_steps[::max(1, len(all_steps) // 10)])

        # Only add a legend if something was drawn (avoids the "no artists" warning).
        if plotted:
            ax.legend(title="Category")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        save_path = os.path.join(output_dir, f"{exp_name}_L{layer}H{head}.png")
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"✓ 已保存图像: {save_path}")


def save_csv_for_head(results_by_cat: dict, output_dir: str, exp_name: str, layer: int, head: int):
    """Write one head's results to ``{output_dir}/{exp_name}_L{layer}H{head}.csv``.

    Columns: ``category, step, mean_score, std_score, num_sentences_used``.
    Scores are written with 6 decimals and NaN scores as empty cells. Fields are
    quoted by the ``csv`` module, so category names containing commas or quotes
    stay in a single column.

    Args:
        results_by_cat: ``{category: [(step, mean, std, n_used), ...]}``.
        output_dir: Target directory (created if missing).
        exp_name: Experiment name used in the file name.
        layer: Layer index used in the file name.
        head: Head index used in the file name.
    """
    ensure_dir(output_dir)

    def fmt_score(value):
        # value != value is True only for NaN, which becomes an empty cell.
        return "" if value != value else f"{value:.6f}"

    csv_path = os.path.join(output_dir, f"{exp_name}_L{layer}H{head}.csv")
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["category", "step", "mean_score", "std_score", "num_sentences_used"])
        for cat, series in results_by_cat.items():
            for step, mean, std, n in series:
                writer.writerow([cat, step, fmt_score(mean), fmt_score(std), n])
    print(f"✓ 已保存CSV: {csv_path}")


def run_single_experiment(config, category_type, categories, steps, sentences):
    """Evaluate every checkpoint of one category group and save per-head results.

    For each category and each step, the checkpoint directory is scored with
    ``evaluate_model_on_heads``. Missing directories and failed evaluations are
    recorded as ``(step, nan, nan, 0)``, so every head gets exactly one entry per
    (category, step). Results are sorted by step and written as one PNG and one
    CSV per head under ``os.path.join(OUTPUT_BASE, <experiment name>)``.

    Args:
        config: One entry of ``ALL_CONFIGS`` (reads ``'layer'``, ``'head'``,
            ``'model'``, plus whatever ``generate_experiment_name`` reads).
        category_type: Group name (e.g. ``'augment'``); part of the experiment name.
        categories: ``{category_name: {"root": ..., "pattern": ...}}``.
        steps: Training steps to evaluate.
        sentences: Probe sentences.
    """
    target_layer = config['layer']
    target_heads = config['head']
    # Materialize once: steps is iterated for every category and for the banner.
    steps = list(steps)

    exp_name = generate_experiment_name(config, category_type)
    output_dir = os.path.join(OUTPUT_BASE, exp_name)
    # Report the steps actually evaluated instead of assuming they start at 0.
    step_range = f"{min(steps)}-{max(steps)}" if steps else "-"

    print("\n" + "=" * 80)
    print(f"开始实验: {exp_name}")
    print("=" * 80)
    print(f"模型大小: {config['model']}M")
    print(f"类别类型: {category_type}")
    print(f"目标层: {target_layer}, 目标头: {target_heads}")
    print(f"训练步数范围: {step_range}")
    print(f"输出目录: {output_dir}")
    print("=" * 80)

    ensure_dir(output_dir)

    print("\n组织模型路径...")
    cat_models, alias_map = build_model_paths(categories, steps)
    num_models = sum(len(v) for v in cat_models.values())
    print(f"✓ 共计 {num_models} 个模型")

    register_aliases(alias_map)
    print("✓ 已注册模型别名")

    # all_results[head][category] -> [(step, mean, std, n_used), ...]
    all_results = {h: {cat: [] for cat in categories} for h in target_heads}
    nan_stats = (float("nan"), float("nan"), 0)

    for cat in categories:
        print("\n" + "-" * 80)
        print(f"评估类别: {cat}")
        print("-" * 80)
        for step, alias, path in tqdm(cat_models[cat], desc=f"{cat}", ncols=80):
            head_stats = {}
            if not os.path.isdir(path):
                print(f"  [跳过] 路径不存在: {path}")
            else:
                try:
                    head_stats = evaluate_model_on_heads(alias, path, sentences, target_layer, target_heads)
                except Exception as e:
                    # Best effort: a broken checkpoint is recorded as NaN, not fatal.
                    print(f"  [错误] step={step}: {e}")
                    head_stats = {}
                else:
                    info = " | ".join(
                        f"H{h}: {'NaN' if m != m else format(m, '.4f')}"
                        for h, (m, _s, _n) in head_stats.items()
                    )
                    print(f"  步 {step}: {info}")
            # Append exactly once per head per step, whatever the outcome, so a
            # failure can never leave duplicate or missing entries.
            for h in target_heads:
                all_results[h][cat].append((step, *head_stats.get(h, nan_stats)))

    print("\n保存结果...")
    for h in target_heads:
        for series in all_results[h].values():
            series.sort(key=lambda entry: entry[0])
        plot_results_for_head(all_results[h], output_dir, exp_name, target_layer, h)
        save_csv_for_head(all_results[h], output_dir, exp_name, target_layer, h)

    print("\n" + "=" * 80)
    print(f"实验完成: {exp_name}")
    print(f"结果已保存至: {output_dir}")
    print("=" * 80)


def main():
    """Run every configured experiment and print progress banners.

    For each entry of ``ALL_CONFIGS``, ``get_categories_for_config`` supplies the
    category groups and training steps. Each returned group is run as one
    experiment via ``run_single_experiment``. All experiments share one set of
    probe sentences.
    """
    # Resolve category groups up front so the banner and progress counter
    # reflect the experiments that will actually run.
    plans = [(config, *get_categories_for_config(config)) for config in ALL_CONFIGS]
    total_experiments = sum(len(groups) for _, groups, _ in plans)
    group_names = list(dict.fromkeys(name for _, groups, _ in plans for name in groups))

    print("=" * 80)
    print("Induction Score 追踪程序 - 批量配置版本")
    print("=" * 80)
    print(f"总配置数: {len(ALL_CONFIGS)}")
    print(f"每个配置测试的类别组合: {', '.join(group_names)}")
    print(f"总实验数: {total_experiments}")
    print(f"测试句子数: {NUM_SENTENCES}")
    print(f"重复次数: {NUM_REPEATS}")
    print("=" * 80)

    # Probe sentences shared by every experiment.
    print("\n生成测试句子...")
    sentences = generate_random_sentences(NUM_SENTENCES, seed=RANDOM_SEED)
    print(f"✓ 生成了 {len(sentences)} 个句子")
    if sentences:
        print(f"  示例: {sentences[0]}")

    current_exp = 0
    for config_idx, (config, groups, steps) in enumerate(plans, 1):
        print(f"\n{'#' * 80}")
        print(f"配置 {config_idx}/{len(ALL_CONFIGS)}: 模型大小 {config['model']}M")
        print(f"{'#' * 80}")

        for category_type, categories in groups.items():
            current_exp += 1
            print(f"\n[实验 {current_exp}/{total_experiments}]")
            run_single_experiment(config, category_type, categories, steps, sentences)

    print("\n" + "=" * 80)
    print("所有实验完成！")
    print(f"总共完成 {total_experiments} 个实验")
    print(f"结果保存在: {OUTPUT_BASE}")
    print("=" * 80)


if __name__ == "__main__":
    main()