import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy import stats

# ==========================================
# Config
# ==========================================
# Input: JSONL file with one sharpness record per model
# (passed to load_and_merge_data as the sharpness input in __main__).
SHARPNESS_RESULTS_FILE = "sharpness/results/sharpness_results_1.0.jsonl"

# Input: JSONL file with one generalization-gap record per model.
GEN_RESULTS_FILE = "model-soups/generalization_results.jsonl"

# ==========================================
# Data Processing
# ==========================================

def load_and_merge_data(gen_file, sharp_file):
    """Load generalization and sharpness JSONL results and inner-join them per model.

    Both files are read with ``pd.read_json(..., lines=True)`` and must have a
    ``model_name`` column. Rows are matched on ``merge_key``, which is
    ``model_name`` with its file extension stripped (``"model_1.pt"`` ->
    ``"model_1"``). Records written with and without an extension therefore
    still pair up.

    Columns present in both inputs receive ``_gen`` / ``_sharp`` suffixes.
    If that happens to the ``gap`` column, the generalization-side copy
    (``gap_gen``, falling back to ``gap_sharp``) is exposed as ``gap``.

    Args:
        gen_file: Path to the generalization-gap results (JSONL).
        sharp_file: Path to the sharpness results (JSONL).

    Returns:
        The merged DataFrame, or ``None`` if either file does not exist.
    """
    for path in (gen_file, sharp_file):
        if not os.path.exists(path):
            print(f"Error: File not found {path}")
            return None

    print(f"Loading data from:\n 1. {gen_file}\n 2. {sharp_file}")
    df_gen = pd.read_json(gen_file, lines=True)
    df_sharp = pd.read_json(sharp_file, lines=True)

    def _strip_extension(name):
        return os.path.splitext(name)[0]

    df_gen['merge_key'] = df_gen['model_name'].apply(_strip_extension)
    df_sharp['merge_key'] = df_sharp['model_name'].apply(_strip_extension)

    # Duplicate keys make the merge produce a cross product of the rows,
    # which would inflate the "models matched" count and the sample size.
    for source, frame in (('generalization', df_gen), ('sharpness', df_sharp)):
        dupes = frame.loc[frame['merge_key'].duplicated(), 'merge_key'].unique()
        if len(dupes):
            print(f"Warning: duplicate model keys in {source} results: {list(dupes)}")

    df_merged = pd.merge(df_gen, df_sharp, on='merge_key', how='inner',
                         suffixes=('_gen', '_sharp'))

    # The inner join silently drops models present in only one file; report them.
    gen_keys = set(df_gen['merge_key'])
    sharp_keys = set(df_sharp['merge_key'])
    only_gen = sorted(gen_keys - sharp_keys)
    only_sharp = sorted(sharp_keys - gen_keys)
    if only_gen:
        print(f"Warning: {len(only_gen)} model(s) have no sharpness results: {only_gen}")
    if only_sharp:
        print(f"Warning: {len(only_sharp)} model(s) have no generalization results: {only_sharp}")

    if 'gap' not in df_merged.columns:
        # If both inputs carried 'gap', merge renamed it; prefer the
        # generalization-side value since that file is the source of the gap.
        for candidate in ('gap_gen', 'gap_sharp'):
            if candidate in df_merged.columns:
                print(f"Warning: 'gap' column not found in merged data; using '{candidate}'.")
                df_merged['gap'] = df_merged[candidate]
                break
        else:
            print("Warning: 'gap' column not found in merged data and no suffixed alternative exists.")

    print(f"Merged data: {len(df_merged)} models matched.")
    return df_merged


def plot_single_metric(df, x_col, x_label, file_suffix,
                       output_dir="sharpness/results", show=True):
    """Plot the generalization gap against one sharpness metric and save it as a PNG.

    Draws a seaborn regression plot of ``df['gap']`` versus ``df[x_col]``.
    Each point is annotated with the last ``/``-separated component of its
    ``merge_key``. A text box shows the Pearson r, p-value, sample size and
    rho. Rows with NaN in ``x_col`` or ``gap`` are dropped first.

    Args:
        df: Merged results table. It must contain ``x_col``, ``'gap'`` and
            ``'merge_key'`` columns. If a ``'rho'`` column exists, its first
            row's value is used in the title and file name.
        x_col: Column plotted on the x axis.
        x_label: Human-readable name of ``x_col`` for the title and axis label.
        file_suffix: Tag inserted into the output file name.
        output_dir: Directory the PNG is written to; created if missing.
        show: If True, call ``plt.show()`` after saving. If False, close the
            figure instead so repeated calls do not accumulate figures.

    Returns:
        None. Input problems are reported with ``print``, and the function
        returns before any figure is created.
    """
    if df is None or len(df) == 0:
        print("No data to plot.")
        return

    # Check both axes up front so a missing column is reported rather than
    # surfacing as a KeyError from dropna().
    for col in (x_col, 'gap'):
        if col not in df.columns:
            print(f"Error: Column {col} not found in data.")
            return

    clean_df = df.dropna(subset=[x_col, 'gap'])
    if len(clean_df) < 2:
        # pearsonr needs at least two points. Return before opening a figure
        # so no empty figure is left behind.
        print("Not enough data points for correlation.")
        return

    if 'rho' in df.columns:
        rho_val = df['rho'].iloc[0]
        if df['rho'].nunique() > 1:
            print(f"Warning: multiple rho values in data; labelling plot with rho = {rho_val}.")
    else:
        rho_val = 'Unknown'

    r_val, p_val = stats.pearsonr(clean_df[x_col], clean_df['gap'])

    sns.set(style="whitegrid")
    fig, ax = plt.subplots(figsize=(10, 8))

    sns.regplot(
        data=clean_df,
        x=x_col,
        y='gap',
        ax=ax,
        scatter_kws={'alpha': 0.6, 's': 60, 'edgecolor': 'w'},
        line_kws={'color': 'red', 'label': f'Linear Fit (r={r_val:.3f})'},
    )

    # Label each point with the final path component of its key,
    # e.g. "models/fine_tune/model_1" -> "model_1". Non-string keys are used as-is.
    for key, x, y in zip(clean_df['merge_key'], clean_df[x_col], clean_df['gap']):
        label_text = key.split('/')[-1] if isinstance(key, str) else key
        ax.annotate(
            label_text,
            xy=(x, y),
            xytext=(5, 5),
            textcoords='offset points',
            fontsize=9,
            color='black',
            alpha=0.8,
        )

    ax.set_title(f'Generalization Gap vs {x_label}\n(rho = {rho_val})', fontsize=16)
    ax.set_xlabel(f'{x_label}', fontsize=12)
    ax.set_ylabel('Generalization Gap', fontsize=12)

    text_str = '\n'.join((
        f'Pearson r = {r_val:.3f}',
        f'p-value = {p_val:.2e}',
        f'Sample N = {len(clean_df)}',
        f'Rho = {rho_val}',
    ))
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax.text(0.05, 0.95, text_str, transform=ax.transAxes, fontsize=12,
            verticalalignment='top', bbox=props)

    ax.legend(loc='lower right')
    fig.tight_layout()

    # savefig fails if the target directory does not exist yet.
    os.makedirs(output_dir, exist_ok=True)
    output_filename = os.path.join(
        output_dir, f"correlation_gap_vs_{file_suffix}_rho_{rho_val}.png"
    )
    fig.savefig(output_filename, dpi=300)
    print(f"Plot saved to {output_filename}")
    if show:
        plt.show()
    else:
        plt.close(fig)

# ==========================================
# Main Program
# ==========================================

if __name__ == "__main__":
    # Step 1: pair each model's generalization gap with its sharpness scores.
    merged_df = load_and_merge_data(GEN_RESULTS_FILE, SHARPNESS_RESULTS_FILE)

    # Step 2: one gap-vs-sharpness plot per metric, drawn in this order.
    # Each entry is (column in merged data, axis label, output-file suffix).
    plot_specs = (
        ('sharpness_sam', 'Sharpness (SAM)', 'sam'),
        ('sharpness_adaptive', 'Adaptive Sharpness', 'adaptive'),
    )

    if merged_df is not None:
        for column, label, suffix in plot_specs:
            plot_single_metric(
                merged_df,
                x_col=column,
                x_label=label,
                file_suffix=suffix,
            )