import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy import stats

# ==========================================
# Config
# ==========================================
# Input 1: per-model generalization-gap results, read as JSON Lines
# (one record per model; each record carries a `model_name`).
GEN_RESULTS_FILE: str = "generalization_results.jsonl"

# Input 2: per-model weighted-degree results, also JSON Lines.
# NOTE(review): the file name presumably encodes the run configuration
# (softmax off / logits, normalization on, PCA dim 10, validation split) —
# keep it in sync with the "logits+norm" tag used in the output plot name.
WD_RESULTS_FILE: str = "results_wd/wd_softmax_False_norm_True_results_pca_10_val.jsonl"

# ==========================================
# Data Processing
# ==========================================

def load_and_merge_data(gen_file, wd_file, *,
                        model_suffixes=(".pt", ".pth", ".ckpt", ".bin", ".safetensors")):
    """Load generalization-gap and weighted-degree results and join them per model.

    Both inputs are JSON Lines files with a ``model_name`` column. Rows are
    matched on a normalized key (``merge_key``). For each file, a trailing
    checkpoint extension listed in ``model_suffixes`` is removed from
    ``model_name``. As a result, ``"net_lr_0.01.pt"`` and ``"net_lr_0.01"``
    refer to the same model, while dots that belong to the name itself are
    kept. (``os.path.splitext`` would turn ``"net_lr_0.01"`` into ``"net_lr_0"``.)

    Args:
        gen_file: Path to the generalization-gap results (.jsonl).
        wd_file: Path to the weighted-degree results (.jsonl).
        model_suffixes: Checkpoint extensions stripped before matching.

    Returns:
        The inner-joined DataFrame, containing only models present in both
        files. Columns that exist in both inputs (other than ``merge_key``)
        get ``_gen`` / ``_wd`` suffixes. Returns ``None`` if either file does
        not exist.
    """
    for path in (gen_file, wd_file):
        if not os.path.exists(path):
            print(f"Error: results file not found: {path}")
            return None

    # 1. Load data
    print("Loading data...")
    df_gen = pd.read_json(gen_file, lines=True)
    df_wd = pd.read_json(wd_file, lines=True)

    def _normalize(name):
        # Strip only a known checkpoint extension, never an arbitrary ".xxx" tail.
        if not isinstance(name, str):
            return name
        for suffix in model_suffixes:
            if suffix and name.endswith(suffix):
                return name[:-len(suffix)]
        return name

    # 2. Normalize model names on BOTH sides so either file may carry an extension.
    df_gen['merge_key'] = df_gen['model_name'].map(_normalize)
    df_wd['merge_key'] = df_wd['model_name'].map(_normalize)

    # 3. Inner join: keep only models present in both files. Shared columns
    #    (e.g. model_name, split) become <col>_gen / <col>_wd.
    df_merged = pd.merge(df_gen, df_wd, on='merge_key', suffixes=('_gen', '_wd'))

    print(f"Merged data: {len(df_merged)} models matched.")

    # Report what the inner join silently dropped, so mismatched names are noticed.
    matched = set(df_merged['merge_key'])
    only_gen = set(df_gen['merge_key']) - matched
    only_wd = set(df_wd['merge_key']) - matched
    if only_gen or only_wd:
        print(f"Warning: {len(only_gen)} model(s) only in {gen_file}, "
              f"{len(only_wd)} model(s) only in {wd_file}.")
    # Duplicate keys in either file produce one row per pair, inflating N.
    if df_merged['merge_key'].duplicated().any():
        print("Warning: duplicate model names after merging; "
              "some models appear in more than one row.")
    return df_merged

# ==========================================
# Plotting Logic
# ==========================================

def _first_value(df, columns, default='Unknown'):
    """Return the first-row value of the first column in ``columns`` present in ``df``.

    Prints a warning when that column holds more than one distinct value,
    because the plot title and file name would then describe only part of the data.
    """
    for col in columns:
        if col in df.columns:
            distinct = df[col].dropna().unique()
            if len(distinct) > 1:
                print(f"Warning: column '{col}' has multiple values {list(distinct)}; "
                      f"using the first row's value.")
            return df[col].iloc[0]
    return default


def _point_label(merge_key):
    """Short label for one point: the last '_'-separated token of its model key."""
    if isinstance(merge_key, str):
        return merge_key.split('_')[-1]
    # Non-string keys (e.g. NaN) are labeled with their string form.
    return str(merge_key)


def plot_correlation(df):
    """Scatter-plot generalization gap against weighted degree and save the figure.

    Expects the frame produced by ``load_and_merge_data``, with these columns:
    ``abd`` (x axis), ``gap`` (y axis) and ``merge_key`` (point labels).
    Optional ``pca_dim`` / ``split`` columns, possibly suffixed ``_wd`` /
    ``_gen`` by the merge, feed the title and the output file name.

    The plot shows a regression fit with the Pearson r in the legend, a stats
    box, and a label on every point. It is saved to
    ``correlation_plot_pca_<dim>_<split>_logits+norm.png`` (300 dpi) in the
    working directory, then shown and closed.

    Rows missing ``abd`` or ``gap`` are dropped first. The function returns
    early with a message when there is no data or fewer than two usable rows,
    since ``pearsonr`` needs at least two points.
    """
    if df is None or len(df) == 0:
        print("No data to plot.")
        return

    # A single NaN makes pearsonr return NaN, so keep only complete (x, y) pairs.
    df = df.dropna(subset=['abd', 'gap'])
    if len(df) < 2:
        print(f"Not enough data to plot: need at least 2 models with both "
              f"'abd' and 'gap', got {len(df)}.")
        return

    # 1. Metadata for the title/file name. The merge adds _wd/_gen suffixes to
    #    columns present in both inputs, so look for those variants as well.
    pca_dim = _first_value(df, ('pca_dim', 'pca_dim_wd', 'pca_dim_gen'))
    split_val = _first_value(df, ('split', 'split_wd'))

    # 2. Scatter + linear fit.
    sns.set(style="whitegrid")
    fig = plt.figure(figsize=(10, 8))

    r_val, p_val = stats.pearsonr(df['abd'], df['gap'])

    ax = sns.regplot(
        data=df,
        x='abd',
        y='gap',
        scatter_kws={'alpha': 0.6, 's': 60, 'edgecolor': 'w'},
        line_kws={'color': 'red', 'label': f'Linear Fit (r={r_val:.3f})'},
    )

    for key, x, y in zip(df['merge_key'], df['abd'], df['gap']):
        ax.annotate(
            _point_label(key),
            xy=(x, y),
            xytext=(5, 5),
            textcoords='offset points',
            fontsize=9,
            color='black',
            alpha=0.8,
        )

    # 3. Title, axis labels and a stats box in the top-left corner.
    plt.title(f'Generalization Gap vs Weighted Degree\n(PCA Dim: {pca_dim}, Split: {split_val})', fontsize=15)
    plt.xlabel('Weighted Degree', fontsize=12)
    plt.ylabel('Generalization Gap', fontsize=12)

    text_str = '\n'.join((
        f'Pearson r = {r_val:.3f}',
        f'p-value = {p_val:.2e}',
        f'Sample N = {len(df)}',
        f'Split = {split_val}',
    ))
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax.text(0.05, 0.95, text_str, transform=ax.transAxes, fontsize=12,
            verticalalignment='top', bbox=props)

    plt.legend(loc='lower right')
    plt.tight_layout()

    # PCA dim and split in the file name keep different runs from overwriting each other.
    output_img = f"correlation_plot_pca_{pca_dim}_{split_val}_logits+norm.png"
    plt.savefig(output_img, dpi=300)
    print(f"Plot saved to {output_img}")
    plt.show()
    # Release the figure so repeated calls (e.g. with a non-interactive
    # backend) do not accumulate open figures.
    plt.close(fig)

if __name__ == "__main__":
    # Script entry point: join both result files per model, then plot and save.
    plot_correlation(load_and_merge_data(GEN_RESULTS_FILE, WD_RESULTS_FILE))