import torch
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import types
import warnings
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES, MODEL_ALIASES, make_model_alias_map
import random

warnings.filterwarnings("ignore", message="torch_dtype is deprecated")

# ========== [Configuration] edit only this section ==========
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Test parameters: how many random sentences each checkpoint is scored on,
# and the seed used to generate them.
NUM_SENTENCES = 100
RANDOM_SEED = 42

# ===== [Core configuration] every configuration combination to test =====
# Each entry tracks the heads zip(layer, head) across training steps
# 0..max_step; `model`, `topk` and `insert` appear in the experiment name.
ALL_CONFIGS = [
    dict(
        model=14,
        layer=[2, 2],  # layer index of each tracked head
        head=[2, 1],   # head index, paired element-wise with `layer`
        max_step=1300,
        topk=100000,
        insert=500,
    ),
]

# ===== 【类别配置】定义所有可能的类别模板 =====
def get_categories_for_config(config):
    """Build the category templates and checkpoint steps for one config.

    Args:
        config: One ALL_CONFIGS entry. Reads 'model' and 'max_step', plus an
            optional 'step_interval' (default 100, the spacing between
            evaluated checkpoints).

    Returns:
        (categories, steps) where ``categories`` maps a category type
        ('mask') to ``{category_name: {"root": str, "pattern": str}}``.
        ``pattern`` is a ``str.format`` template with a ``{step}`` field,
        resolved relative to ``root``. ``steps`` is
        ``list(range(0, max_step + 1, step_interval))``.
    """
    model_size = config['model']
    max_step = config['max_step']
    step_interval = config.get('step_interval', 100)

    # Checkpoint steps to evaluate, e.g. [0, 100, ..., 1300] by default.
    steps = list(range(0, max_step + 1, step_interval))

    # Only the 'mask' group is currently defined. The checkpoint sub-path is
    # joined with the OS separator: a literal backslash is only a path
    # separator on Windows, where os.path.join yields the identical string.
    categories_mask = {
        f"randommask{model_size}": {
            "root": "",
            "pattern": os.path.join("delranprev", "delranprev{step}"),
        },
    }

    return {
        'mask': categories_mask,
    }, steps

# Output base directory. Each experiment writes into its own subdirectory,
# os.path.join(OUTPUT_BASE, <experiment name>), which the program creates.
# An empty string places those subdirectories in the current working directory.
OUTPUT_BASE = ""

# ========== 以下代码无需修改 ==========

def ensure_dir(path: str):
    """Create directory *path* (including parents) if needed and return it.

    An empty path denotes the current working directory, which always
    exists, so nothing is created; ``os.makedirs("")`` would otherwise raise
    FileNotFoundError.
    """
    if path:
        os.makedirs(path, exist_ok=True)
    return path


def generate_experiment_name(config, category_type):
    """Derive the experiment name from a config and a category type.

    The name joins, with underscores: the model size, the category type,
    one "L<layer>H<head>" tag per tracked head, and the topk/insert
    settings, e.g. "model14_mask_L2H2_L2H1_topk100000_insert500".
    """
    head_tags = "_".join(
        f"L{lyr}H{hd}" for lyr, hd in zip(config['layer'], config['head'])
    )
    parts = (
        f"model{config['model']}",
        f"{category_type}",
        head_tags,
        f"topk{config['topk']}",
        f"insert{config['insert']}",
    )
    return "_".join(parts)


def generate_random_sentences(num_sentences=100, seed=42):
    """Return ``num_sentences`` reproducible pseudo-random test sentences.

    Seeds both Python's global ``random`` module and NumPy's global RNG with
    *seed* (a side effect visible to the caller). Each sentence is 8-20 words
    drawn with replacement from a fixed vocabulary; the first word is
    capitalized and a period is appended.
    """
    random.seed(seed)
    np.random.seed(seed)

    nouns = [
        "apple", "book", "car", "dog", "elephant", "flower", "guitar", "hat",
        "ice", "juice", "kite", "lamp", "moon", "notebook", "ocean", "pencil",
        "queen", "robot", "star", "table", "umbrella", "violin", "window", "box",
        "cloud", "door", "egg", "fire", "grass", "hill", "island", "jar",
    ]
    verbs = ["jumps", "runs", "sleeps", "eats", "drinks", "flies", "swims", "walks"]
    adjectives = ["big", "small", "happy", "sad", "red", "blue", "green", "yellow"]
    adverbs = ["quickly", "slowly", "happily", "sadly", "loudly", "quietly", "carefully"]
    function_words = ["on", "in", "under", "over", "beside", "behind", "near", "and", "but"]
    pronouns = ["he", "she", "it", "they", "we", "I", "you", "him", "her", "them"]
    # Concatenation order matters: random.choice indexes into this list, so
    # reordering it would change the generated sentences for a given seed.
    vocabulary = nouns + verbs + adjectives + adverbs + function_words + pronouns

    def make_sentence():
        word_count = random.randint(8, 20)
        first, *rest = [random.choice(vocabulary) for _ in range(word_count)]
        return " ".join([first.capitalize(), *rest]) + "."

    return [make_sentence() for _ in range(num_sentences)]


def compute_previous_token_scores(cache, model, seq_len: int):
    """Score every attention head by its attention to the immediately preceding token.

    For each layer, takes the cached ``blocks.{layer}.attn.hook_pattern``
    for the first batch element and, per head, averages the weights on the
    first sub-diagonal of the last two dims (query i attending to key i-1).

    Returns:
        Tensor of shape (n_layers, n_heads) on ``model.cfg.device``. Rows stay
        zero when ``seq_len`` <= 1 or the sub-diagonal is empty.
    """
    cfg = model.cfg
    out = torch.zeros(cfg.n_layers, cfg.n_heads, device=cfg.device)

    for layer_idx in range(cfg.n_layers):
        first_example = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0]
        if seq_len <= 1:
            continue
        sub_diag = torch.diagonal(first_example, offset=-1, dim1=-2, dim2=-1)
        if sub_diag.numel():
            out[layer_idx] = sub_diag.mean(dim=-1)

    return out


def bind_tokenizer_methods(model, tokenizer):
    """Attach *tokenizer* to *model* and override its tokenization helpers.

    ``model.to_tokens`` and ``model.to_str_tokens`` are replaced by bound
    methods that call the tokenizer directly:

    * ``to_tokens(text, prepend_bos=False, move_to_device=True)`` encodes a
      string; for a sequence of strings the encodings are concatenated along
      the last dim into a single sequence (not a padded batch).
      ``prepend_bos`` is forwarded as ``add_special_tokens``.
    * ``to_str_tokens(tokens, prepend_bos=False)`` decodes each id on its
      own; a 2-D input has dim 0 squeezed first. ``prepend_bos`` is ignored.
    """
    model.tokenizer = tokenizer

    def _encode(self, text: str, prepend_bos: bool = False, move_to_device: bool = True):
        def encode_one(piece):
            return self.tokenizer.encode(piece, return_tensors="pt", add_special_tokens=prepend_bos)

        if isinstance(text, str):
            ids = encode_one(text)
        else:
            ids = torch.cat([encode_one(piece) for piece in text], dim=-1)
        return ids.to(self.cfg.device) if move_to_device else ids

    def _decode_each(self, tokens: torch.Tensor, prepend_bos: bool = False):
        flat = tokens.squeeze(0) if tokens.dim() == 2 else tokens
        return [
            self.tokenizer.decode([tid], skip_special_tokens=False)
            for tid in flat.tolist()
        ]

    model.to_tokens = types.MethodType(_encode, model)
    model.to_str_tokens = types.MethodType(_decode_each, model)


def register_aliases(alias_map: dict):
    """Make local checkpoint paths loadable by alias through TransformerLens.

    For each path -> alias entry, appends the path to OFFICIAL_MODEL_NAMES
    (only if not already present) and sets MODEL_ALIASES[path] to a
    one-element list holding the alias, replacing any previous aliases for
    that path. Finally calls make_model_alias_map(); its return value is not
    kept here.
    """
    official = OFFICIAL_MODEL_NAMES
    for model_path in alias_map:
        if model_path not in official:
            official.append(model_path)
        MODEL_ALIASES[model_path] = [alias_map[model_path]]
    make_model_alias_map()


def build_model_paths(categories, steps):
    """Expand category templates into concrete checkpoint paths and aliases.

    Args:
        categories: ``{category_name: {"root": str, "pattern": str}}`` where
            ``pattern`` is formatted with ``step=<step>`` and joined onto ``root``.
        steps: Training steps to expand for every category.

    Returns:
        (cat_models, alias_map): ``cat_models`` maps each category name to a
        list of ``(step, alias, path)`` tuples in the order of *steps*;
        ``alias_map`` maps each path to its alias ``"<category>_<step>"``.
    """
    cat_models = {}
    alias_map = {}

    for name, spec in categories.items():
        root, template = spec["root"], spec["pattern"]
        entries = []
        for step in steps:
            full_path = os.path.join(root, template.format(step=step))
            model_alias = f"{name}_{step}"
            entries.append((step, model_alias, full_path))
            alias_map[full_path] = model_alias
        cat_models[name] = entries

    return cat_models, alias_map


def evaluate_model_on_heads(model_alias: str, model_path: str, sentences: list[str],
                            target_heads: list[dict]) -> dict:
    """Load one checkpoint and measure previous-token scores for the given heads.

    Args:
        model_alias: Name passed to HookedTransformer.from_pretrained_no_processing
            (an alias registered for this checkpoint).
        model_path: Local checkpoint directory the tokenizer is loaded from.
        sentences: Test sentences. Each is tokenized without BOS and scored on
            its own; sentences shorter than 2 tokens are skipped.
        target_heads: Dicts with integer 'layer' and 'head' indices.

    Returns:
        ``{"L{layer}H{head}": score}``, where score is the per-sentence
        previous-token score averaged over all successfully scored sentences.
        It is NaN when no sentence could be scored or the indices are outside
        the model.

    Per-sentence failures are skipped (best effort), but are counted and
    reported instead of being swallowed silently.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = HookedTransformer.from_pretrained_no_processing(
        model_alias,
        device=DEVICE,
        dtype=torch.float32,
        n_devices=1,
        tokenizer=tokenizer,
    )
    model.eval()

    all_scores = []
    n_failed = 0
    first_error = None
    try:
        for sentence in sentences:
            try:
                tokens = model.to_tokens(sentence, prepend_bos=False)
                seq_len = tokens.shape[1]
                if seq_len < 2:
                    # A 1-token sequence has no (query, previous key) pair.
                    continue

                with torch.no_grad():
                    # Only attention patterns are needed; filter the cache to save memory.
                    _, cache = model.run_with_cache(
                        tokens,
                        return_type="logits",
                        names_filter=lambda name: "pattern" in name,
                    )

                scores = compute_previous_token_scores(cache, model, seq_len)
                all_scores.append(scores.cpu())

                del cache
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                n_failed += 1
                if first_error is None:
                    first_error = e
    finally:
        # Free the model even if the loop raised something not caught above.
        del model, tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if n_failed:
        tqdm.write(
            f"  [警告] {model_alias}: {n_failed}/{len(sentences)} 个句子评估失败, "
            f"首个错误: {first_error!r}"
        )

    averaged_scores = torch.stack(all_scores).mean(dim=0) if all_scores else None

    results = {}
    for head_info in target_heads:
        layer = head_info['layer']
        head = head_info['head']
        key = f"L{layer}H{head}"
        # Negative indices are rejected rather than silently counting from the end.
        in_range = (
            averaged_scores is not None
            and 0 <= layer < averaged_scores.shape[0]
            and 0 <= head < averaged_scores.shape[1]
        )
        results[key] = averaged_scores[layer, head].item() if in_range else float('nan')

    return results


def plot_results_for_head(results_by_cat: dict, output_dir: str, exp_name: str,
                          layer: int, head: int):
    """Plot one head's previous-token score against training step and save a PNG.

    Args:
        results_by_cat: Maps category name -> list of (step, score) pairs.
            NaN scores (missing or failed checkpoints) are dropped before plotting.
        output_dir: Directory for the image; created if missing.
        exp_name: Experiment name, used in the title and the file name.
        layer: Layer of the tracked head (title and file name).
        head: Head index of the tracked head (title and file name).

    Writes ``{exp_name}_L{layer}H{head}_tracking.png`` into *output_dir*.
    """
    ensure_dir(output_dir)

    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A994E']

    fig, ax = plt.subplots(figsize=(14, 8))
    try:
        plotted_any = False
        for idx, (cat, series) in enumerate(results_by_cat.items()):
            valid = [(s, score) for (s, score) in series if score == score]  # NaN != NaN
            if not valid:
                continue
            steps = [s for (s, _) in valid]
            scores = [score for (_, score) in valid]

            ax.plot(steps, scores,
                    marker='o',
                    color=colors[idx % len(colors)],
                    linewidth=2.5,
                    markersize=8,
                    alpha=0.8,
                    label=cat)
            plotted_any = True

        ax.set_xlabel('Training Step', fontsize=14, fontweight='bold')
        ax.set_ylabel('Previous Token Score', fontsize=14, fontweight='bold')
        ax.set_title(f'Previous Token Head Score - {exp_name} - Layer {layer}, Head {head}',
                     fontsize=16, fontweight='bold', pad=20)

        # Show at most ~10 x-axis ticks, spread evenly over all evaluated steps.
        all_steps = sorted({s for pairs in results_by_cat.values() for (s, _) in pairs})
        if all_steps:
            ticks = all_steps[::max(1, len(all_steps) // 10)]
            ax.set_xticks(ticks)
            ax.set_xticklabels(ticks, rotation=45, ha='right')

        # A legend with no labelled lines is an empty box and triggers a warning.
        if plotted_any:
            ax.legend(title="Category")
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        filename = f"{exp_name}_L{layer}H{head}_tracking.png"
        save_path = os.path.join(output_dir, filename)
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
    finally:
        # Close this specific figure even if drawing or saving failed.
        plt.close(fig)
    print(f"✓ 已保存图像: {save_path}")


def save_results_txt(results_by_cat: dict, output_dir: str, exp_name: str,
                     target_heads: list[dict], categories: dict):
    """Write a plain-text score table per tracked head to ``{exp_name}_results.txt``.

    Args:
        results_by_cat: ``{head_key: {category: [(step, score), ...]}}`` with
            head keys of the form "L{layer}H{head}".
        output_dir: Directory for the file; created if missing.
        exp_name: Used in the file name.
        target_heads: Dicts with 'layer'/'head'; one table is written per head.
        categories: Category templates; only their keys (the column order) are used.

    The header reports the configured NUM_SENTENCES. Each table has one row
    per step seen in any category; missing or NaN scores are written as "NaN".
    """
    ensure_dir(output_dir)

    filename = f"{exp_name}_results.txt"
    results_path = os.path.join(output_dir, filename)

    # Columns are 15 wide as before, widened for long category names so
    # adjacent columns never run together.
    col_width = max([15] + [len(str(cat)) + 1 for cat in categories])

    with open(results_path, "w", encoding="utf-8") as f:
        f.write("Previous Token Head 追踪结果\n")
        f.write("=" * 80 + "\n")
        f.write(f"测试句子数: {NUM_SENTENCES}\n")
        f.write(f"追踪的heads: {target_heads}\n")
        f.write(f"类别: {list(categories.keys())}\n\n")

        for head_info in target_heads:
            layer = head_info['layer']
            head = head_info['head']
            key = f"L{layer}H{head}"

            f.write(f"\n{key}\n")
            f.write("-" * 80 + "\n")
            f.write(f"{'Step':<10}")
            for cat in categories:
                f.write(f"{cat:<{col_width}}")
            f.write("\n")
            f.write("-" * 80 + "\n")

            if key not in results_by_cat:
                f.write(f"[错误] 未找到 {key} 的数据\n")
                continue

            head_results = results_by_cat[key]

            # Build each category's step -> score lookup once per head instead
            # of once per (step, category) cell.
            score_maps = {cat: dict(head_results.get(cat, [])) for cat in categories}

            # Union of all steps recorded for this head, in ascending order.
            all_steps = sorted({
                step for cat_data in head_results.values()
                for (step, _) in cat_data
            })

            for step in all_steps:
                row = [f"{step:<10}"]
                for cat in categories:
                    score = score_maps[cat].get(step, float('nan'))
                    if score == score:  # not NaN
                        row.append(f"{score:<{col_width}.4f}")
                    else:
                        row.append(f"{'NaN':<{col_width}}")
                f.write("".join(row) + "\n")

    print(f"✓ 已保存数值结果: {results_path}")


def run_single_experiment(config, category_type, categories, steps, sentences):
    """Evaluate every checkpoint of one category group, then plot and save results.

    Args:
        config: One ALL_CONFIGS entry. 'layer' and 'head' are zipped into the
            tracked heads; the other keys feed the experiment name and printout.
        category_type: Label of the category group (e.g. 'mask').
        categories: ``{category_name: {"root": ..., "pattern": ...}}`` templates.
        steps: Training steps evaluated for every category.
        sentences: Test sentences passed to evaluate_model_on_heads.

    Checkpoints whose directory is missing, or whose evaluation raises, are
    recorded as NaN so one bad checkpoint does not abort the experiment.
    """
    target_heads = [{'layer': layer, 'head': head} for layer, head in zip(config['layer'], config['head'])]
    head_keys = [f"L{h['layer']}H{h['head']}" for h in target_heads]

    exp_name = generate_experiment_name(config, category_type)
    output_dir = os.path.join(OUTPUT_BASE, exp_name)

    print("\n" + "=" * 80)
    print(f"开始实验: {exp_name}")
    print("=" * 80)
    print(f"模型大小: {config['model']}M")
    print(f"类别类型: {category_type}")
    print(f"追踪的heads: {target_heads}")
    print(f"训练步数范围: 0-{config['max_step']}")
    print(f"输出目录: {output_dir}")
    print("=" * 80)

    ensure_dir(output_dir)

    print("\n组织模型路径...")
    cat_models, alias_map = build_model_paths(categories, steps)
    num_models = sum(len(v) for v in cat_models.values())
    print(f"✓ 共计 {num_models} 个模型")

    register_aliases(alias_map)
    print("✓ 已注册模型别名")

    # all_results[head_key][category] -> list of (step, score)
    all_results = {key: {cat: [] for cat in categories} for key in head_keys}

    def record_nan(cat, step):
        """Record a NaN score at *step* for every tracked head in *cat*."""
        for key in head_keys:
            all_results[key][cat].append((step, float('nan')))

    for cat in categories:
        print("\n" + "-" * 80)
        print(f"评估类别: {cat}")
        print("-" * 80)
        # Inside the progress loop, messages go through tqdm.write so they do
        # not break the progress bar.
        for step, alias, path in tqdm(cat_models[cat], desc=f"{cat}", ncols=80):
            if not os.path.isdir(path):
                tqdm.write(f"  [跳过] 路径不存在: {path}")
                record_nan(cat, step)
                continue

            # Only the evaluation is guarded, so a step can never receive both
            # real scores and a fallback NaN entry.
            try:
                head_stats = evaluate_model_on_heads(alias, path, sentences, target_heads)
            except Exception as e:
                tqdm.write(f"  [错误] step={step}: {e}")
                record_nan(cat, step)
                continue

            for key in head_keys:
                all_results[key][cat].append((step, head_stats.get(key, float('nan'))))

            info = " | ".join(
                f"{k}: {('NaN' if v != v else f'{v:.4f}')}"
                for k, v in head_stats.items()
            )
            tqdm.write(f"  步 {step}: {info}")

    print("\n保存结果...")
    for head_info, key in zip(target_heads, head_keys):
        for series in all_results[key].values():
            series.sort(key=lambda x: x[0])
        plot_results_for_head(all_results[key], output_dir, exp_name,
                              head_info['layer'], head_info['head'])

    save_results_txt(all_results, output_dir, exp_name, target_heads, categories)

    print("\n" + "=" * 80)
    print(f"实验完成: {exp_name}")
    print(f"结果已保存至: {output_dir}")
    print("=" * 80)


def main():
    """Run one experiment per (config, category group) pair.

    Category groups come from get_categories_for_config, so the experiment
    counts printed here always match the experiments actually run. All
    experiments share one set of test sentences generated from RANDOM_SEED.
    """
    # Resolve every config's category groups up front so the totals below
    # reflect what will run, instead of assuming two groups per config.
    plans = [(config, *get_categories_for_config(config)) for config in ALL_CONFIGS]
    total_experiments = sum(len(all_categories) for _, all_categories, _ in plans)
    category_types = list(dict.fromkeys(
        ct for _, all_categories, _ in plans for ct in all_categories
    ))

    print("=" * 80)
    print("Previous Token Head 追踪程序 - 批量配置版本")
    print("=" * 80)
    print(f"总配置数: {len(ALL_CONFIGS)}")
    print(f"测试的类别组合: {category_types}")
    print(f"总实验数: {total_experiments}")
    print(f"测试句子数: {NUM_SENTENCES}")
    print("=" * 80)

    print("\n生成测试句子...")
    sentences = generate_random_sentences(NUM_SENTENCES, seed=RANDOM_SEED)
    print(f"✓ 生成了 {len(sentences)} 个句子")
    if sentences:  # NUM_SENTENCES may be 0
        print(f"  示例: {sentences[0]}")

    current_exp = 0
    for config_idx, (config, all_categories, steps) in enumerate(plans, 1):
        print(f"\n{'#' * 80}")
        print(f"配置 {config_idx}/{len(ALL_CONFIGS)}: 模型大小 {config['model']}M")
        print(f"{'#' * 80}")

        for category_type, categories in all_categories.items():
            current_exp += 1
            print(f"\n[实验 {current_exp}/{total_experiments}]")
            run_single_experiment(config, category_type, categories, steps, sentences)

    print("\n" + "=" * 80)
    print("所有实验完成！")
    print(f"总共完成 {total_experiments} 个实验")
    print(f"结果保存在: {OUTPUT_BASE}")
    print("=" * 80)


if __name__ == "__main__":
    # Run all configured experiments only when executed as a script, not on import.
    main()