import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
from scipy import stats

# ==========================================
# Config
# ==========================================
# All paths below are relative, so they are resolved against the current
# working directory at run time.

# Sharpness results file path (JSON lines, read with pandas lines=True).
# NOTE(review): the "1.0" in the file name presumably matches the rho used
# for this sharpness run — confirm before switching files.
SHARPNESS_RESULTS_FILE = "sharpness/results/sharpness_results_1.0.jsonl"

# Generalization gap results file path (JSON lines, read with lines=True).
GEN_RESULTS_FILE = "model-soups/generalization_results.jsonl"

# Hparam file path: a JSON object keyed by model name; it is used to look up
# each model's "mix" value.
HPARAM_FILE = "model-soups/hparam.json"

# ==========================================
# Data Processing
# ==========================================

# Checkpoint file extensions stripped from model names when building merge
# keys. Only these are removed, so dotted run names such as "soup_0.5" stay
# intact (os.path.splitext would turn "soup_0.5" into "soup_0").
_CHECKPOINT_SUFFIXES = (".pt", ".pth", ".ckpt", ".bin", ".safetensors")


def _to_merge_key(name):
    """Return ``name`` as a string with one trailing checkpoint suffix removed."""
    text = str(name)
    for suffix in _CHECKPOINT_SUFFIXES:
        if text.endswith(suffix):
            return text[: -len(suffix)]
    return text


def load_and_merge_data(gen_file, sharp_file, hparam_file):
    """Load generalization, sharpness and hparam results and join them per model.

    Args:
        gen_file: Path to a JSON-lines file with a ``model_name`` column.
        sharp_file: Path to a JSON-lines file with a ``model_name`` column.
        hparam_file: Path to a JSON object mapping model name -> hparam dict.

    Returns:
        A DataFrame with one row per model found in all three sources. It is
        keyed by a ``merge_key`` column, and its ``mix`` column has missing
        values filled with 0. A ``gap`` column is added as
        ``train_acc - test_acc`` when absent and computable. Returns ``None``
        when a file is missing or the hparam file cannot be loaded.
    """
    # 1. Check if files exist
    for path in (gen_file, sharp_file, hparam_file):
        if not os.path.exists(path):
            print(f"Error: file not found: {path}")
            return None

    print(f"Loading data from:\n 1. {gen_file}\n 2. {sharp_file}\n 3. {hparam_file}")

    # 2. Read jsonl data
    df_gen = pd.read_json(gen_file, lines=True)
    df_sharp = pd.read_json(sharp_file, lines=True)

    # 3. Read hparam.json into a DataFrame whose 'merge_key' column holds the
    # top-level JSON keys (model names).
    try:
        with open(hparam_file, 'r', encoding='utf-8') as fh:
            hparam_data = json.load(fh)
        df_hparam = pd.DataFrame.from_dict(hparam_data, orient='index')
        df_hparam = df_hparam.reset_index().rename(columns={'index': 'merge_key'})
    except Exception as e:  # best-effort: report and bail out instead of crashing
        print(f"Error loading hparam file: {e}")
        return None

    # 4. Standardize model names as merge key.
    # Sharpness/gen names may carry a checkpoint suffix (e.g. ".pt"). The hparam
    # keys get the same normalization so all three sources agree.
    df_sharp['merge_key'] = df_sharp['model_name'].map(_to_merge_key)
    df_gen['merge_key'] = df_gen['model_name'].map(_to_merge_key)
    df_hparam['merge_key'] = df_hparam['merge_key'].map(_to_merge_key)

    # 5. Merge DataFrames (inner joins: only models present everywhere survive)
    # Step A: Merge Gen and Sharpness
    df_merged_metrics = pd.merge(df_gen, df_sharp, on='merge_key', suffixes=('_gen', '_sharp'))
    # Step B: Merge Hparams (to get mix parameter)
    df_final = pd.merge(df_merged_metrics, df_hparam, on='merge_key', how='inner')

    # 6. Handle mix column (models without a mix entry are treated as mix=0)
    if 'mix' in df_final.columns:
        df_final['mix'] = df_final['mix'].fillna(0)
    else:
        print("Warning: 'mix' column not found in hparam. Assuming mix=0 for all.")
        df_final['mix'] = 0

    # 7. Ensure 'gap' column exists, deriving it from accuracies when possible
    if 'gap' not in df_final.columns:
        if 'train_acc' in df_final.columns and 'test_acc' in df_final.columns:
            df_final['gap'] = df_final['train_acc'] - df_final['test_acc']
        else:
            print("Warning: 'gap' column missing and cannot be calculated.")

    print(f"Merged data: {len(df_final)} models matched.")
    return df_final


def plot_single_metric(df, x_col, x_label, file_suffix, group_label,
                       output_dir="sharpness/results"):
    """Scatter-plot the generalization gap against one metric and save it as PNG.

    The plot has a linear regression fit, per-point model-name annotations and
    a box showing the Pearson correlation.

    Args:
        df: Merged results. They must contain ``x_col``, ``gap`` and
            ``merge_key``. ``rho`` is optional and used only for labelling.
        x_col: Column plotted on the x-axis.
        x_label: Human-readable name of ``x_col`` for the title and axis.
        file_suffix: Metric tag embedded in the output file name.
        group_label: Group name shown in the plot and embedded in the file
            name. Labels containing "No" or "Zero" are drawn in blue, all
            others in green.
        output_dir: Directory the PNG is written to; created if needed.

    Returns:
        None. A message is printed and nothing is drawn when there is no
        data, a required column is missing, or fewer than 2 complete rows
        remain.
    """
    if df is None or len(df) == 0:
        print(f"[{group_label}] No data to plot for {x_label}.")
        return

    # Both axes are required; without 'gap' the dropna() below raises KeyError.
    for col in (x_col, 'gap'):
        if col not in df.columns:
            print(f"Error: Column {col} not found in data.")
            return

    # Validate the data before touching any global plotting state.
    clean_df = df.dropna(subset=[x_col, 'gap'])
    if len(clean_df) < 2:
        print(f"[{group_label}] Not enough data points for correlation.")
        return

    rho_val = df['rho'].iloc[0] if 'rho' in df.columns else 'Unknown'
    r_val, p_val = stats.pearsonr(clean_df[x_col], clean_df['gap'])

    # The no-mixup group is labelled e.g. "Mixup_Zero", which has no "No" in
    # it. Check both tags so that group really gets a distinct colour.
    is_no_mix_group = 'No' in group_label or 'Zero' in group_label
    plot_color = 'blue' if is_no_mix_group else 'green'

    sns.set(style="whitegrid")
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.regplot(
            data=clean_df,
            x=x_col,
            y='gap',
            ax=ax,
            scatter_kws={'alpha': 0.6, 's': 60, 'edgecolor': 'w'},
            line_kws={'color': 'red', 'label': f'Linear Fit (r={r_val:.3f})'},
            color=plot_color,
        )

        # Annotate each point with the last path component of its model name.
        for _, row in clean_df.iterrows():
            label_text = str(row['merge_key']).rsplit('/', 1)[-1]
            ax.annotate(
                label_text,
                xy=(row[x_col], row['gap']),
                xytext=(5, 5),
                textcoords='offset points',
                fontsize=9,
                color='black',
                alpha=0.7,
            )

        ax.set_title(f'[{group_label}] Gap vs {x_label}\n(rho = {rho_val})', fontsize=16)
        ax.set_xlabel(x_label, fontsize=12)
        ax.set_ylabel('Generalization Gap', fontsize=12)

        # Statistics text box in the top-left corner (axes coordinates).
        text_str = '\n'.join((
            f'Group: {group_label}',
            f'Pearson r = {r_val:.3f}',
            f'p-value = {p_val:.2e}',
            f'Sample N = {len(clean_df)}',
            f'Rho = {rho_val}',
        ))
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
        ax.text(0.05, 0.95, text_str, transform=ax.transAxes, fontsize=12,
                verticalalignment='top', bbox=props)

        ax.legend(loc='lower right')
        fig.tight_layout()

        file_name = f"correlation_gap_vs_{file_suffix}_{group_label}_rho_{rho_val}.png"
        output_filename = os.path.join(output_dir, file_name) if output_dir else file_name
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        fig.savefig(output_filename, dpi=300)
        print(f"Plot saved to {output_filename}")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

# ==========================================
# Main Program
# ==========================================

if __name__ == "__main__":
    # Join generalization, sharpness and hparam results into one table.
    merged_df = load_and_merge_data(GEN_RESULTS_FILE, SHARPNESS_RESULTS_FILE, HPARAM_FILE)

    if merged_df is not None:
        # Split models by mixup strength: strictly positive vs exactly zero.
        mix_values = merged_df['mix']
        df_mix_active = merged_df[mix_values > 0].copy()
        df_mix_zero = merged_df[mix_values == 0].copy()

        print("\nData Split:")
        print(f"  - With Mixup (>0): {len(df_mix_active)} models")
        print(f"  - Without Mixup (=0): {len(df_mix_zero)} models")

        # (column, axis label, file-name tag) for every sharpness measure.
        metrics_to_plot = [
            ('sharpness_sam', 'Sharpness (SAM)', 'sam'),
            ('sharpness_adaptive', 'Adaptive Sharpness', 'adaptive'),
        ]
        # Plot order within each metric: mixup group first, then no-mixup.
        plot_groups = [
            ("Mixup_Active", df_mix_active),
            ("Mixup_Zero", df_mix_zero),
        ]

        for metric_col, metric_label, metric_tag in metrics_to_plot:
            for group_name, group_df in plot_groups:
                plot_single_metric(
                    group_df,
                    x_col=metric_col,
                    x_label=metric_label,
                    file_suffix=metric_tag,
                    group_label=group_name,
                )

        print("\nAll tasks completed.")