import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Set
from collections import defaultdict
import os

# ================ 配置：PKL文件路径 ================
PKL_PATHS = {
    # Induction group (3 heads)
    'induction_head1': r"influence_copytarget_qkvo124_L3H3.pkl",
    'induction_head2': r"influence_copytarget_qkvo124_L3H1.pkl",
    'induction_head3': r"influence_copytarget_qkvo124_L3H0.pkl",

    # Non-induction group (3 heads)
    'non_induction_head1': r"influence_copytarget_qkvo124_L3H2.pkl",
    'non_induction_head2': r"influence_copytarget_qkvo124_L4H3.pkl",
    'non_induction_head3': r"influence_copytarget_qkvo124_L2H3.pkl",
}

# Group definitions: the keys of PKL_PATHS that belong to each group, and the
# combined order used for matrix rows/columns.
INDUCTION_HEADS = ['induction_head1', 'induction_head2', 'induction_head3']
NON_INDUCTION_HEADS = ['non_induction_head1', 'non_induction_head2', 'non_induction_head3']
ALL_HEADS = INDUCTION_HEADS + NON_INDUCTION_HEADS

# Output directory. Assembled with os.path.join so the path separators match
# the running platform: a literal backslash-separated string is only a nested
# path on Windows and would name a single odd directory elsewhere. On Windows
# the resulting string is identical to the previous hard-coded value.
OUTPUT_DIR = os.path.join("influence function", "allpkls", "6x14m", "overlap_analysis")
# Created at import time so every later writer can assume the directory exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ================ 数据加载 ================
def load_positive_samples(pkl_path: str) -> Set[int]:
    """Collect the ``sample_index`` values of the positive influencers in a pickle file.

    The unpickled object is queried with ``.get('positive_influencers', [])``,
    so a missing key yields an empty set. Each entry must support
    ``entry['sample_index']``.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        The distinct ``sample_index`` values found (duplicates collapse).
    """
    # NOTE(review): pickle.load executes arbitrary code embedded in the file;
    # presumably these files are produced locally -- confirm before pointing
    # this at files from an untrusted source.
    with open(pkl_path, 'rb') as handle:
        payload = pickle.load(handle)

    indices = set()
    for entry in payload.get('positive_influencers', []):
        indices.add(entry['sample_index'])
    return indices

def load_all_heads_data() -> Dict[str, Set[int]]:
    """Load the positive sample indices for every head listed in ``PKL_PATHS``.

    Loading is best-effort: a head whose file cannot be loaded is reported on
    stdout and mapped to an empty set, so the result always has one entry per
    key of ``PKL_PATHS``.

    Returns:
        Mapping from head name to its set of positive sample indices.
    """
    banner = "=" * 70
    loaded: Dict[str, Set[int]] = {}

    print(banner)
    print("加载数据...")
    print(banner)

    for name in PKL_PATHS:
        try:
            head_samples = load_positive_samples(PKL_PATHS[name])
            loaded[name] = head_samples
            print(f"✓ {name}: {len(head_samples)} 个正向样本")
        except Exception as err:
            # Keep going with an empty set so one bad file does not abort the run.
            print(f"✗ {name}: 加载失败 - {err}")
            loaded[name] = set()

    return loaded

# ================ 计算重合度矩阵 ================
def compute_overlap_matrix(data: Dict[str, Set[int]], head_order: List[str]) -> np.ndarray:
    """Build the pairwise intersection-size matrix for the given heads.

    Args:
        data: Mapping from head name to its set of sample indices.
        head_order: Head names defining the row/column order.

    Returns:
        An integer ``(n, n)`` array where entry ``[i, j]`` is the number of
        samples shared by ``head_order[i]`` and ``head_order[j]``; the diagonal
        holds each head's own set size.
    """
    size = len(head_order)
    counts = np.zeros((size, size), dtype=int)

    for row, row_head in enumerate(head_order):
        for col, col_head in enumerate(head_order):
            row_set = data[row_head]
            col_set = data[col_head]
            # Diagonal: a head compared with itself is just its own size.
            counts[row][col] = len(row_set) if row == col else len(row_set & col_set)

    return counts

def compute_overlap_ratio_matrix(data: Dict[str, Set[int]], head_order: List[str]) -> np.ndarray:
    """Build the pairwise overlap-ratio matrix (intersection / size of the row head).

    The matrix is generally asymmetric because each row is normalised by its
    own head's set size.

    Args:
        data: Mapping from head name to its set of sample indices.
        head_order: Head names defining the row/column order.

    Returns:
        A float ``(n, n)`` array. The diagonal is always ``1.0``; an
        off-diagonal entry is ``0.0`` when the row head's set is empty.
    """
    size = len(head_order)
    ratios = np.zeros((size, size))

    for row, row_head in enumerate(head_order):
        for col, col_head in enumerate(head_order):
            row_set = data[row_head]
            col_set = data[col_head]

            if row == col:
                ratios[row][col] = 1.0
                continue

            shared = len(row_set & col_set)
            row_size = len(row_set)
            # Guard the division: an empty row head shares nothing with anyone.
            if row_size > 0:
                ratios[row][col] = shared / row_size
            else:
                ratios[row][col] = 0.0

    return ratios

# ================ 生成热力图 ================
def shorten_head_labels(head_labels: List[str]) -> List[str]:
    """Return compact axis labels for head names.

    ``induction_headN`` becomes ``IND-N`` and ``non_induction_headN`` becomes
    ``NON-N``; any other label is returned unchanged.

    Args:
        head_labels: Full head names.

    Returns:
        A new list with one shortened label per input label, in the same order.
    """
    short_labels = []
    for label in head_labels:
        # The non-induction pattern must be tested first: 'non_induction_head'
        # contains 'induction_head', so testing the shorter pattern first would
        # turn 'non_induction_head1' into 'non_IND-1'.
        if 'non_induction_head' in label:
            short_labels.append(label.replace('non_induction_head', 'NON-'))
        elif 'induction_head' in label:
            short_labels.append(label.replace('induction_head', 'IND-'))
        else:
            short_labels.append(label)
    return short_labels


def _render_heatmap(matrix: np.ndarray, head_labels: List[str], output_path: str, *,
                    title: str, saved_message: str, **heatmap_kwargs):
    """Draw one annotated heatmap, save it to ``output_path`` and report the path on stdout."""
    display_labels = shorten_head_labels(head_labels)

    fig = plt.figure(figsize=(12, 10))
    try:
        # NOTE(review): the colorbar label passed via cbar_kws is Chinese but,
        # unlike the title, gets no explicit font -- it may need a CJK-capable
        # default font to render; confirm on the target machine.
        sns.heatmap(matrix,
                    annot=True,  # print the value inside each cell
                    xticklabels=display_labels,
                    yticklabels=display_labels,
                    linewidths=0.5,
                    linecolor='gray',
                    **heatmap_kwargs)

        plt.title(title, fontsize=16, pad=20, fontproperties='SimHei')
        plt.xlabel('Head', fontsize=12, fontproperties='SimHei')
        plt.ylabel('Head', fontsize=12, fontproperties='SimHei')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()

        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"{saved_message}: {output_path}")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


def plot_overlap_heatmap(overlap_matrix: np.ndarray, head_labels: List[str], output_path: str):
    """Plot the overlap heatmap (intersection sizes) and save it as an image.

    Args:
        overlap_matrix: Integer matrix of pairwise intersection sizes.
        head_labels: Head names in matrix row/column order; shortened for display.
        output_path: Destination image path.
    """
    _render_heatmap(overlap_matrix, head_labels, output_path,
                    title='正向样本重合度热力图 (交集大小)',
                    saved_message="✓ 重合度热力图（绝对数量）已保存到",
                    fmt='d',  # integer annotations
                    cmap='Blues',
                    cbar_kws={'label': '重合样本数量'})


def plot_ratio_heatmap(ratio_matrix: np.ndarray, head_labels: List[str], output_path: str):
    """Plot the overlap-ratio heatmap (intersection / row head size) and save it as an image.

    Args:
        ratio_matrix: Float matrix of pairwise overlap ratios.
        head_labels: Head names in matrix row/column order; shortened for display.
        output_path: Destination image path.
    """
    _render_heatmap(ratio_matrix, head_labels, output_path,
                    title='正向样本重合比例热力图 (交集/行head大小)',
                    saved_message="✓ 重合比例热力图已保存到",
                    fmt='.3f',  # three decimal places
                    cmap='Greens',
                    vmin=0, vmax=1,  # fixed colour scale for ratios
                    cbar_kws={'label': '重合比例'})

# ================ 生成文字报告 ================
def _within_group_pairs(heads):
    """Return every unordered pair of heads, in the order of a nested i < j loop."""
    return [(first, second) for pos, first in enumerate(heads) for second in heads[pos + 1:]]


def _pair_stats(h1, h2, overlap_matrix, ratio_matrix, head_order):
    """Look up (overlap count, ratio relative to h1, ratio relative to h2) for two heads."""
    idx1 = head_order.index(h1)
    idx2 = head_order.index(h2)
    return overlap_matrix[idx1][idx2], ratio_matrix[idx1][idx2], ratio_matrix[idx2][idx1]


def _write_group_section(f, title, heads, data, overlap_matrix, ratio_matrix, head_order):
    """Write one within-group section: the group-wide intersection plus every pair's overlap."""
    f.write(f"{title}\n")
    f.write("-"*70 + "\n")
    group_sets = [data[h] for h in heads]
    # set.intersection(*[]) raises, so an empty group is treated as sharing nothing.
    common = set.intersection(*group_sets) if group_sets else set()
    f.write(f"三个head的交集: {len(common)} 个样本\n\n")

    for h1, h2 in _within_group_pairs(heads):
        intersection, ratio1, ratio2 = _pair_stats(h1, h2, overlap_matrix, ratio_matrix, head_order)
        f.write(f"{h1} ↔ {h2}:\n")
        f.write(f"  交集大小: {intersection}\n")
        f.write(f"  占{h1}比例: {ratio1:.4f} ({intersection}/{len(data[h1])})\n")
        f.write(f"  占{h2}比例: {ratio2:.4f} ({intersection}/{len(data[h2])})\n\n")


def _write_best_pair(f, label, pairs, overlap_matrix, head_order):
    """Write the first pair with the largest overlap (a zero overlap still counts)."""
    best_pair = None
    best_overlap = 0
    for h1, h2 in pairs:
        overlap = overlap_matrix[head_order.index(h1)][head_order.index(h2)]
        # Accept the very first pair unconditionally so that an all-zero group
        # still yields a pair instead of leaving best_pair as None.
        if best_pair is None or overlap > best_overlap:
            best_pair = (h1, h2)
            best_overlap = overlap

    if best_pair is None:
        # No pair exists at all (fewer than two heads to compare).
        f.write(f"{label}: N/A\n")
        return
    f.write(f"{label}: {best_pair[0]} ↔ {best_pair[1]}, "
            f"重合{best_overlap}个样本\n")


def generate_text_report(data: Dict[str, Set[int]], overlap_matrix: np.ndarray,
                        ratio_matrix: np.ndarray, head_order: List[str], output_path: str,
                        induction_heads: List[str] = None,
                        non_induction_heads: List[str] = None):
    """Write a detailed text report of the positive-sample overlap analysis.

    The report has five sections: per-head sample counts, within-group overlap
    for the induction group, within-group overlap for the non-induction group,
    cross-group overlap, and the most-overlapping pair of each kind.

    Args:
        data: Mapping from head name to its set of positive sample indices.
        overlap_matrix: Pairwise intersection sizes, indexed by ``head_order``.
        ratio_matrix: Pairwise overlap ratios (row-normalised), indexed by
            ``head_order``.
        head_order: Head names in matrix row/column order.
        output_path: Destination text file, written as UTF-8.
        induction_heads: Heads of the induction group; defaults to the module
            constant ``INDUCTION_HEADS`` when ``None``.
        non_induction_heads: Heads of the non-induction group; defaults to the
            module constant ``NON_INDUCTION_HEADS`` when ``None``.

    Raises:
        KeyError: If a group head is missing from ``data``.
        ValueError: If a group head is missing from ``head_order``.
    """
    if induction_heads is None:
        induction_heads = INDUCTION_HEADS
    if non_induction_heads is None:
        non_induction_heads = NON_INDUCTION_HEADS

    cross_pairs = [(ind_h, non_ind_h)
                   for ind_h in induction_heads
                   for non_ind_h in non_induction_heads]

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("="*70 + "\n")
        f.write("正向样本重合分析报告（基于重合度）\n")
        f.write("="*70 + "\n\n")

        # 1. Basic statistics
        f.write("【1】基本统计\n")
        f.write("-"*70 + "\n")
        for head_name in head_order:
            f.write(f"{head_name}: {len(data[head_name])} 个正向样本\n")
        f.write("\n")

        # 2. Overlap within the induction group
        _write_group_section(f, "【2】Induction 组内重合分析", induction_heads,
                             data, overlap_matrix, ratio_matrix, head_order)

        # 3. Overlap within the non-induction group
        _write_group_section(f, "【3】非Induction 组内重合分析", non_induction_heads,
                             data, overlap_matrix, ratio_matrix, head_order)

        # 4. Overlap across the two groups
        f.write("【4】Induction 与 非Induction 跨组重合分析\n")
        f.write("-"*70 + "\n")

        # set().union(*sets) also works when a group is empty.
        ind_union = set().union(*(data[h] for h in induction_heads))
        non_ind_union = set().union(*(data[h] for h in non_induction_heads))
        cross_intersection = ind_union & non_ind_union

        f.write(f"Induction组并集: {len(ind_union)} 个样本\n")
        f.write(f"非Induction组并集: {len(non_ind_union)} 个样本\n")
        f.write(f"两组交集: {len(cross_intersection)} 个样本\n\n")

        # Every induction / non-induction head pair
        f.write("各Head对的跨组重合情况:\n")
        for ind_h, non_ind_h in cross_pairs:
            intersection, ratio_ind, ratio_non = _pair_stats(
                ind_h, non_ind_h, overlap_matrix, ratio_matrix, head_order)
            f.write(f"  {ind_h} ↔ {non_ind_h}: {intersection} 个样本 "
                    f"(占IND {ratio_ind:.3f}, 占NON {ratio_non:.3f})\n")
        f.write("\n")

        # 5. Key findings: the most-overlapping pair of each kind
        f.write("【5】关键发现\n")
        f.write("-"*70 + "\n")
        _write_best_pair(f, "Induction组内重合最多的pair",
                         _within_group_pairs(induction_heads), overlap_matrix, head_order)
        _write_best_pair(f, "非Induction组内重合最多的pair",
                         _within_group_pairs(non_induction_heads), overlap_matrix, head_order)
        _write_best_pair(f, "跨组重合最多的pair", cross_pairs, overlap_matrix, head_order)

        f.write("\n")
        f.write("="*70 + "\n")
        f.write("报告生成完毕\n")
        f.write("="*70 + "\n")

    print(f"✓ 文字报告已保存到: {output_path}")

# ================ 主函数 ================
if __name__ == "__main__":
    # All output locations live under OUTPUT_DIR; resolve them up front.
    absolute_png = os.path.join(OUTPUT_DIR, "overlap_heatmap_absolute.png")
    ratio_png = os.path.join(OUTPUT_DIR, "overlap_heatmap_ratio.png")
    report_txt = os.path.join(OUTPUT_DIR, "overlap_analysis_report.txt")
    absolute_npy = os.path.join(OUTPUT_DIR, "overlap_matrix_absolute.npy")
    ratio_npy = os.path.join(OUTPUT_DIR, "overlap_matrix_ratio.npy")
    rule = "=" * 70

    # Step 1: load the positive sample sets of every head.
    head_samples = load_all_heads_data()

    # Step 2: pairwise overlap counts and row-normalised ratios.
    print("\n计算重合度矩阵...")
    counts = compute_overlap_matrix(head_samples, ALL_HEADS)
    ratios = compute_overlap_ratio_matrix(head_samples, ALL_HEADS)

    # Steps 3-4: one heatmap per matrix.
    plot_overlap_heatmap(counts, ALL_HEADS, absolute_png)
    plot_ratio_heatmap(ratios, ALL_HEADS, ratio_png)

    # Step 5: text report.
    generate_text_report(head_samples, counts, ratios, ALL_HEADS, report_txt)

    # Step 6: persist the raw matrices for later reuse.
    np.save(absolute_npy, counts)
    np.save(ratio_npy, ratios)
    print(f"✓ 重合度矩阵（绝对值）已保存到: {absolute_npy}")
    print(f"✓ 重合度矩阵（比例）已保存到: {ratio_npy}")

    print("\n" + rule)
    print("所有分析完成！")
    print(f"输出目录: {OUTPUT_DIR}")
    print(rule)
