import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

# =================================================================================
# 분석을 위한 설정 변수
# ViT 코드의 설정과 동일하게 맞춰주세요.
# =================================================================================
TARGET_BLOCKS_ATTN = range(6) # 0, 1, 2, 3, 4번 블록
OUTPUT_DIR = '/home/user/regcache/outlier_prediction_result/output_indices/'
MLP_OUTPUT_FILE = os.path.join(OUTPUT_DIR, 'mlp_max_token_indices.csv')
# Attention 헤드 개수 (모델 아키텍처에 따라 수정 필요)
# 예시: DeiT-Small의 경우 6개, DeiT-Base의 경우 12개
NUM_HEADS = 12 
# =================================================================================


def analyze_index_matching():
    """
    MLP의 max norm 인덱스와 Attention의 min norm 인덱스 간의 일치율을 분석하고 시각화합니다.
    """
    print("📊 인덱스 일치율 분석 시작...")

    # --- 1. 데이터 로드 ---
    try:
        # MLP 인덱스 로드 (header=None으로 설정하여 첫 줄부터 데이터로 읽기)
        mlp_indices = pd.read_csv(MLP_OUTPUT_FILE, header=None)
        # MLP 인덱스는 모든 row에 대해 동일하므로 첫 번째 열만 사용
        mlp_indices = mlp_indices[0] 
        print(f"✅ MLP 인덱스 로드 완료. 총 {len(mlp_indices)}개 샘플.")
    except FileNotFoundError:
        print(f"❌ 에러: MLP 인덱스 파일 '{MLP_OUTPUT_FILE}'을 찾을 수 없습니다.")
        return

    # Attention 인덱스 로드
    attn_indices = {}
    min_samples = len(mlp_indices)
    
    for block_num in TARGET_BLOCKS_ATTN:
        attn_indices[block_num] = {}
        block_dir = os.path.join(OUTPUT_DIR, f'block_{block_num}')
        if not os.path.exists(block_dir):
            print(f"⚠️ 경고: 블록 {block_num}의 디렉토리가 없습니다. 건너뜁니다.")
            continue
            
        for head_num in range(NUM_HEADS):
            head_file = os.path.join(block_dir, f'head_{head_num}_attn_div_val_max_index.csv')
            try:
                # 각 헤드의 인덱스 로드
                head_data = pd.read_csv(head_file, header=None)[0]
                attn_indices[block_num][head_num] = head_data
                min_samples = min(min_samples, len(head_data))
            except FileNotFoundError:
                print(f"⚠️ 경고: 블록 {block_num}, 헤드 {head_num}의 파일이 없습니다. 건너뜁니다.")
                continue

    if not attn_indices:
        print("❌ 에러: 분석할 Attention 인덱스 데이터가 없습니다.")
        return

    # 모든 데이터의 길이를 최소 샘플 수에 맞춤 (데이터 로깅 중단 등으로 길이가 다를 경우 대비)
    mlp_indices = mlp_indices[:min_samples]
    for block_num in attn_indices:
        for head_num in attn_indices[block_num]:
            attn_indices[block_num][head_num] = attn_indices[block_num][head_num][:min_samples]
            
    print(f"➡️ 모든 데이터 길이를 {min_samples}개 샘플로 통일하여 분석합니다.")

    # --- 2. 일치율 계산 ---
    
    # 블록별/헤드별 일치율 저장을 위한 데이터프레임 생성
    head_match_rates = pd.DataFrame(index=list(TARGET_BLOCKS_ATTN), columns=range(NUM_HEADS))

    for block_num, heads_data in attn_indices.items():
        for head_num, head_indices in heads_data.items():
            # MLP 인덱스와 현재 헤드의 인덱스가 같은 경우를 True(1), 다른 경우를 False(0)로 변환
            matches = (mlp_indices == head_indices)
            # 일치율 계산 (일치하는 샘플 수 / 전체 샘플 수)
            match_rate = matches.mean()
            head_match_rates.loc[block_num, head_num] = match_rate

    # 블록별 평균 일치율 계산
    block_avg_match_rates = head_match_rates.mean(axis=1)

    print("\n--- 분석 결과 ---")
    print("각 블록의 헤드별 일치율 (%):")
    print((head_match_rates * 100).round(2))
    print("\n각 블록의 평균 일치율 (%):")
    print((block_avg_match_rates * 100).round(2))
    
    # --- 3. 시각화 ---
    plt.style.use('seaborn-v0_8-whitegrid')
    
    # Figure와 Axes 생성
    fig, axes = plt.subplots(1, 2, figsize=(18, 6), gridspec_kw={'width_ratios': [1, 2]})
    fig.suptitle('MLP Max-Norm Index vs Attention Min-Norm Index Matching Rate', fontsize=16)

    # 시각화 1: 블록별 평균 일치율 (Bar Chart)
    axes[0].set_title('Average Match Rate per Block', fontsize=12)
    block_avg_match_rates.plot(kind='bar', ax=axes[0], color=sns.color_palette("viridis", len(block_avg_match_rates)))
    axes[0].set_xlabel('Transformer Block')
    axes[0].set_ylabel('Average Match Rate')
    axes[0].tick_params(axis='x', rotation=0)
    axes[0].set_ylim(0, max(0.1, block_avg_match_rates.max() * 1.2)) # Y축 범위 자동 조절

    # 시각화 2: 블록/헤드별 일치율 (Heatmap)
    axes[1].set_title('Match Rate per Head in Each Block', fontsize=12)
    sns.heatmap(
        head_match_rates.astype(float), 
        annot=True, 
        fmt=".2f", 
        cmap="viridis", 
        linewidths=.5,
        ax=axes[1]
    )
    axes[1].set_xlabel('Attention Head')
    axes[1].set_ylabel('Transformer Block')
    
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    
    # 시각화 결과 저장 및 출력
    plt.savefig(os.path.join(OUTPUT_DIR, 'match_rate_visualization.png'))
    print(f"\n📈 시각화 결과가 '{os.path.join(OUTPUT_DIR, 'match_rate_visualization.png')}'에 저장되었습니다.")
    plt.show()


if __name__ == '__main__':
    analyze_index_matching()