import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

# =================================================================================
# 분석을 위한 설정 변수
# =================================================================================
TARGET_BLOCKS_ATTN = range(6) # 분석할 블록 범위
OUTPUT_DIR = '/home/user/regcache/outlier_prediction_result/output_indices/'
MLP_OUTPUT_FILE = os.path.join(OUTPUT_DIR, 'mlp_max_token_indices.csv')
NUM_HEADS = 12 # 모델의 어텐션 헤드 개수
TOP_K = 5 # 분석할 상위 순위 개수
# [MODIFIED] 분석할 파일 이름 패턴 변경
ATTN_FILENAME_PATTERN = 'head_{head_num}_attn_div_val_max_index.csv'
# =================================================================================


def analyze_top_k_matching():
    """
    MLP의 max norm 인덱스와 Attention의 Top-K min norm 인덱스 간의
    순위별 일치율을 분석하고 시각화합니다.
    """
    print(f"📊 Top-{TOP_K} 순위별 인덱스 일치율 분석 시작...")

    # --- 1. MLP 인덱스 데이터 로드 ---
    try:
        mlp_indices = pd.read_csv(MLP_OUTPUT_FILE, header=None)[0]
        print(f"✅ MLP 인덱스 로드 완료. 총 {len(mlp_indices)}개 샘플.")
    except (FileNotFoundError, pd.errors.EmptyDataError) as e:
        print(f"❌ 에러: MLP 인덱스 파일 '{MLP_OUTPUT_FILE}'을 로드할 수 없습니다. ({e})")
        return

    # --- 2. Attention Top-K 인덱스 로드 ---
    attn_indices = {}
    min_samples = len(mlp_indices)
    
    for block_num in TARGET_BLOCKS_ATTN:
        attn_indices[block_num] = {}
        block_dir = os.path.join(OUTPUT_DIR, f'block_{block_num}')
        if not os.path.exists(block_dir):
            print(f"⚠️ 경고: 블록 {block_num}의 디렉토리가 없습니다. 건너뜁니다.")
            continue
            
        for head_num in range(NUM_HEADS):
            head_file = os.path.join(block_dir, ATTN_FILENAME_PATTERN.format(head_num=head_num))
            try:
                # [MODIFIED] Top-K 인덱스가 담긴 CSV를 DataFrame으로 로드
                head_data = pd.read_csv(head_file, header=None)
                if head_data.shape[1] < TOP_K:
                    print(f"⚠️ 경고: {head_file} 파일의 열 개수가 {TOP_K}보다 작습니다. 건너뜁니다.")
                    continue
                attn_indices[block_num][head_num] = head_data
                min_samples = min(min_samples, len(head_data))
            except (FileNotFoundError, pd.errors.EmptyDataError):
                # 파일이 없거나 비어있는 경우는 건너뛰기
                continue

    if not any(attn_indices.values()):
        print("❌ 에러: 분석할 Attention 인덱스 데이터가 없습니다.")
        return
        
    # 모든 데이터 길이를 최소 샘플 수에 맞춤
    mlp_indices = mlp_indices[:min_samples]
    for block_num in attn_indices:
        for head_num in attn_indices[block_num]:
            attn_indices[block_num][head_num] = attn_indices[block_num][head_num].iloc[:min_samples, :]
            
    print(f"➡️ 모든 데이터 길이를 {min_samples}개 샘플로 통일하여 분석합니다.")

    # --- 3. 순위별 일치율 계산 ---
    # 각 순위별로 일치율을 저장할 딕셔너리. Key: 순위, Value: DataFrame
    match_rates_by_rank = {f'Top-{i+1}': pd.DataFrame(index=list(TARGET_BLOCKS_ATTN), columns=range(NUM_HEADS), dtype=float) for i in range(TOP_K)}

    for block_num, heads_data in attn_indices.items():
        for head_num, top_k_df in heads_data.items():
            for rank in range(TOP_K):
                # [MODIFIED] 현재 순위(rank)에 해당하는 Attention 인덱스 컬럼을 가져옴
                attn_indices_at_rank = top_k_df[rank]
                
                # MLP 인덱스와 현재 순위의 인덱스 비교
                matches = (mlp_indices == attn_indices_at_rank)
                match_rate = matches.mean()
                
                # 결과 저장
                match_rates_by_rank[f'Top-{rank+1}'].loc[block_num, head_num] = match_rate

    # --- 4. 분석 결과 출력 ---
    print("\n--- 분석 결과 ---")
    block_avg_by_rank = pd.DataFrame(index=list(TARGET_BLOCKS_ATTN))
    
    for rank_str, rates_df in match_rates_by_rank.items():
        print(f"\n각 블록/헤드별 [{rank_str}] 일치율 (%):")
        print((rates_df.dropna(how='all', axis=1) * 100).round(2)) # 데이터가 있는 헤드만 출력
        block_avg_by_rank[rank_str] = rates_df.mean(axis=1)
        
    print("\n--- 요약: 각 블록의 순위별 평균 일치율 (%) ---")
    print((block_avg_by_rank * 100).round(2))

    # --- 5. 시각화 ---
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(15, 8))
    
    # [MODIFIED] 그룹 막대그래프 생성
    block_avg_by_rank.plot(kind='bar', ax=ax, width=0.8)
    
    ax.set_title(f'MLP Max-Norm Index vs. Attention Top-{TOP_K} Min-Norm Indices Matching Rate', fontsize=16)
    ax.set_xlabel('Transformer Block', fontsize=12)
    ax.set_ylabel('Average Match Rate', fontsize=12)
    ax.tick_params(axis='x', rotation=0)
    ax.legend(title='Rank')
    ax.grid(axis='y', linestyle='--')
    
    # Y축 범위를 보기 좋게 조절
    ax.set_ylim(0, max(0.1, block_avg_by_rank.max().max() * 1.2))

    plt.tight_layout()
    
    # 시각화 결과 저장 및 출력
    save_path = os.path.join(OUTPUT_DIR, 'top_k_match_rate_visualization.png')
    plt.savefig(save_path)
    print(f"\n📈 시각화 결과가 '{save_path}'에 저장되었습니다.")
    plt.show()


if __name__ == '__main__':
    analyze_top_k_matching()