import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# =================================================================================
# 분석을 위한 설정 변수
# ViT 코드의 설정과 동일하게 맞춰주세요.
# =================================================================================
# Attention에서 분석한 블록 번호 범위
TARGET_BLOCKS_ATTN = range(6) # 0, 1, 2, 3, 4번 블록

# 인덱스 파일이 저장된 최상위 디렉토리
OUTPUT_DIR = '/home/user/regcache/outlier_prediction_result/output_indices/'

# MLP 최대 norm 토큰 인덱스가 저장된 파일 경로
MLP_OUTPUT_FILE = os.path.join(OUTPUT_DIR, 'mlp_max_token_indices.csv')

# Attention 블록별 합산 인덱스 파일 이름
ATTN_AGGREGATED_FILENAME = 'aggregated_min_token_indices.csv'
# =================================================================================


def analyze_aggregated_matching():
    """
    MLP의 max norm 인덱스와 Attention의 aggregated min norm 인덱스 간의
    일치율을 블록별로 분석하고 시각화합니다.
    """
    print("📊 통합 인덱스 일치율 분석 시작...")

    # --- 1. MLP 인덱스 데이터 로드 ---
    try:
        # header=None으로 설정하여 첫 줄부터 데이터로 읽기
        mlp_indices = pd.read_csv(MLP_OUTPUT_FILE, header=None)
        # 여러 열이 있을 경우 첫 번째 열만 사용
        mlp_indices = mlp_indices[0]
        print(f"✅ MLP 인덱스 로드 완료. 총 {len(mlp_indices)}개 샘플.")
    except FileNotFoundError:
        print(f"❌ 에러: MLP 인덱스 파일 '{MLP_OUTPUT_FILE}'을 찾을 수 없습니다.")
        return
    except pd.errors.EmptyDataError:
        print(f"❌ 에러: MLP 인덱스 파일 '{MLP_OUTPUT_FILE}'이 비어 있습니다.")
        return

    # --- 2. Attention 블록별 통합 인덱스 로드 ---
    block_indices_data = {}
    min_samples = len(mlp_indices)

    for block_num in TARGET_BLOCKS_ATTN:
        file_path = os.path.join(OUTPUT_DIR, f'block_{block_num}', ATTN_AGGREGATED_FILENAME)
        try:
            block_data = pd.read_csv(file_path, header=None)[0]
            block_indices_data[block_num] = block_data
            min_samples = min(min_samples, len(block_data))
        except FileNotFoundError:
            print(f"⚠️ 경고: 블록 {block_num}의 통합 인덱스 파일 '{file_path}'가 없습니다. 건너뜁니다.")
        except pd.errors.EmptyDataError:
            print(f"⚠️ 경고: 블록 {block_num}의 통합 인덱스 파일 '{file_path}'가 비어 있습니다. 건너뜁니다.")


    if not block_indices_data:
        print("❌ 에러: 분석할 Attention 인덱스 데이터가 없습니다.")
        return

    # 모든 데이터 길이를 최소 샘플 수에 맞춰 통일 (데이터 로깅 중단 등 길이 불일치 대비)
    mlp_indices = mlp_indices[:min_samples]
    for block_num in block_indices_data:
        block_indices_data[block_num] = block_indices_data[block_num][:min_samples]

    print(f"➡️ 모든 데이터 길이를 {min_samples}개 샘플로 통일하여 분석합니다.")

    # --- 3. 블록별 일치율 계산 ---
    block_match_rates = {}
    for block_num, attn_indices in block_indices_data.items():
        # MLP 인덱스와 현재 블록의 인덱스가 같은 경우 True(1), 다른 경우 False(0)
        matches = (mlp_indices == attn_indices)
        # 일치율 계산 (일치하는 샘플 수 / 전체 샘플 수)
        match_rate = matches.mean()
        block_match_rates[block_num] = match_rate

    # 결과를 보기 쉽게 DataFrame으로 변환
    results_df = pd.Series(block_match_rates, name="Match Rate").sort_index()

    print("\n--- 분석 결과 ---")
    print("블록별 일치율 (%):")
    print((results_df * 100).round(2))

    # --- 4. 시각화 ---
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.figure(figsize=(10, 6))

    # Bar Chart 생성
    bar_plot = sns.barplot(x=results_df.index, y=results_df.values, palette="viridis")

    # 각 막대 위에 일치율(%) 텍스트 표시
    for index, value in enumerate(results_df):
        plt.text(index, value + 0.005, f'{value:.2%}', ha='center', va='bottom', fontsize=11)

    plt.title('Matching Rate between MLP Max-Norm and Aggregated Attention Min-Norm Indices', fontsize=15)
    plt.xlabel('Transformer Block', fontsize=12)
    plt.ylabel('Match Rate', fontsize=12)
    plt.ylim(0, max(0.1, results_df.max() * 1.2)) # Y축 범위 자동 조절
    plt.xticks(ticks=results_df.index, labels=[f'Block {i}' for i in results_df.index])

    plt.tight_layout()

    # 시각화 결과 저장 및 출력
    save_path = os.path.join(OUTPUT_DIR, 'aggregated_match_rate_visualization.png')
    plt.savefig(save_path)
    print(f"\n📈 시각화 결과가 '{save_path}'에 저장되었습니다.")
    plt.show()


if __name__ == '__main__':
    analyze_aggregated_matching()