import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
from scipy import stats

# ==========================================
# Config
# ==========================================
# Generalization-gap results. Read by load_and_merge_data as JSON Lines
# (one JSON record per line); each record must carry a 'model_name' field.
GEN_RESULTS_FILE = "generalization_results.jsonl"

# Weighted-degree results, also JSON Lines with a 'model_name' field.
# NOTE(review): the file name suggests softmax=False, norm=True, PCA dim 10,
# train split -- confirm against the script that produced it.
WD_RESULTS_FILE = "results_wd/wd_softmax_False_norm_True_results_pca_10_train.jsonl" 

# Hyperparameters as one standard JSON object whose top-level keys are model
# names and whose values are per-model settings (e.g. 'mix').
HPARAM_FILE = "hparam.json"

# ==========================================
# Data Processing
# ==========================================

def load_and_merge_data(gen_file, wd_file, hparam_file):
    """Load the three input files and join them into one per-model table.

    Args:
        gen_file: Path to a JSON Lines file with a 'model_name' field per record
            (generalization results).
        wd_file: Path to a JSON Lines file with a 'model_name' field per record
            (weighted-degree results).
        hparam_file: Path to a standard JSON file shaped like
            ``{"<model name>": {"mix": ..., ...}, ...}``.

    Returns:
        A DataFrame with one row per model present in all three inputs. It has
        a 'merge_key' column (model name without checkpoint suffix) and a 'mix'
        column in which missing values are replaced by 0. Columns that exist in
        both result files get the suffixes '_gen' and '_wd'. Returns None if
        any input file does not exist.
    """
    # 1. Report every missing input at once so the user can fix them in one go.
    missing = [path for path in (gen_file, wd_file, hparam_file)
               if not os.path.exists(path)]
    if missing:
        for path in missing:
            print(f"Error: file not found: {path}")
        return None

    print("Loading data...")

    # 2. Read the JSON Lines result files.
    df_gen = pd.read_json(gen_file, lines=True)
    df_wd = pd.read_json(wd_file, lines=True)

    # 3. Read hparam.json; orient='index' turns {"model_1": {...}} into rows.
    with open(hparam_file, 'r', encoding='utf-8') as f:
        hparam_data = json.load(f)
    # Naming the index before reset_index avoids relying on the implicit
    # 'index' column name, which changes if a hyperparameter is called 'index'.
    df_hparam = (
        pd.DataFrame.from_dict(hparam_data, orient='index')
        .rename_axis('merge_key')
        .reset_index()
    )
    df_hparam['merge_key'] = df_hparam['merge_key'].astype(str)

    # 4. Standardize model names so all three sources share the same key.
    # Weighted-degree names carry a file extension, which is dropped.
    df_wd['merge_key'] = df_wd['model_name'].astype(str).map(
        lambda name: os.path.splitext(name)[0]
    )
    # Generalization names may or may not carry '.pt'. Strip only that exact
    # suffix so names that merely contain a dot (e.g. 'model_lr0.01') survive.
    df_gen['merge_key'] = df_gen['model_name'].astype(str).str.replace(
        r'\.pt$', '', regex=True
    )

    # 5. Merge: first the two result tables, then the hyperparameters.
    df_merged = pd.merge(df_gen, df_wd, on='merge_key', suffixes=('_gen', '_wd'))
    df_final = pd.merge(df_merged, df_hparam, on='merge_key', how='inner')

    # The inner join silently drops models without hyperparameters; say so.
    unmatched = sorted(set(df_merged['merge_key']) - set(df_hparam['merge_key']))
    if unmatched:
        print(f"Warning: {len(unmatched)} model(s) without hyperparameters "
              f"were dropped: {unmatched}")

    # Models whose hyperparameter entry has no 'mix' field are treated as mix = 0.
    if 'mix' in df_final.columns:
        df_final['mix'] = df_final['mix'].fillna(0)
    else:
        # No model defines 'mix' at all: create the column so callers can filter.
        df_final['mix'] = 0

    print(f"Merged data: {len(df_final)} models matched with hyperparameters.")
    return df_final



def _first_present(df, candidates, default='Unknown'):
    """Return the first-row value of the first column in candidates that df has."""
    for column in candidates:
        if column in df.columns:
            return df[column].iloc[0]
    return default


def plot_correlation(df, condition_name):
    """Plot 'gap' against 'abd' with a linear fit and save the figure as a PNG.

    The plot is written to
    ``correlation_{condition_name}_pca_{pca_dim}_{split}_logits.png`` in the
    current working directory. Nothing is plotted (a message is printed
    instead) when df is None or empty, lacks the 'abd' or 'gap' column, or has
    fewer than two rows with both values present.

    Args:
        df: DataFrame with numeric 'abd' and 'gap' columns. Optional columns:
            'merge_key' (point labels), 'pca_dim' / 'pca_dim_wd' and
            'split' / 'split_wd' (title and file name metadata).
        condition_name: Label used in the title, the stats box and the file name.

    Returns:
        None.
    """
    if df is None or len(df) == 0:
        print(f"No data to plot for condition: {condition_name}")
        return

    missing_cols = [col for col in ('abd', 'gap') if col not in df.columns]
    if missing_cols:
        print(f"Cannot plot {condition_name}: missing column(s) {missing_cols}")
        return

    # Validate before touching matplotlib so no figure is created just to be
    # thrown away.
    valid_df = df.dropna(subset=['abd', 'gap'])
    if len(valid_df) < 2:
        print(f"Not enough data points for regression in {condition_name}")
        return

    # -------------------------------------------
    # 1. Extract metadata
    # -------------------------------------------
    # A column present in both merged result files carries a '_wd' suffix, so
    # check the suffixed name as a fallback.
    pca_dim = _first_present(df, ('pca_dim', 'pca_dim_wd'))
    split_val = _first_present(df, ('split', 'split_wd'))

    r_val, p_val = stats.pearsonr(valid_df['abd'], valid_df['gap'])

    # -------------------------------------------
    # 2. Plot
    # -------------------------------------------
    sns.set(style="whitegrid")
    fig = plt.figure(figsize=(10, 8))
    try:
        ax = fig.add_subplot(111)

        # Scatter and regression line
        sns.regplot(
            data=valid_df,
            x='abd',
            y='gap',
            scatter_kws={'alpha': 0.6, 's': 60, 'edgecolor': 'w'},
            line_kws={'color': 'red', 'label': f'Linear Fit (r={r_val:.3f})'},
            ax=ax,
        )

        # Label each point with the part of its key after the last underscore
        # (e.g. 'model_12' -> '12'); fall back to the row index if no key column.
        if 'merge_key' in valid_df.columns:
            keys = valid_df['merge_key']
        else:
            keys = valid_df.index
        for key, x_val, y_val in zip(keys, valid_df['abd'], valid_df['gap']):
            ax.annotate(
                str(key).rsplit('_', 1)[-1],
                xy=(x_val, y_val),
                xytext=(5, 5),
                textcoords='offset points',
                fontsize=9,
                color='black',
                alpha=0.8,
            )

        # -------------------------------------------
        # 3. Title, stats box and output file
        # -------------------------------------------
        ax.set_title(
            f'[{condition_name}] Gap vs Weighted Degree\n'
            f'(PCA: {pca_dim}, Split: {split_val})',
            fontsize=15,
        )
        ax.set_xlabel('Weighted Degree', fontsize=12)
        ax.set_ylabel('Generalization Gap', fontsize=12)

        text_str = '\n'.join((
            f'Condition: {condition_name}',
            f'Pearson r = {r_val:.3f}',
            f'p-value = {p_val:.2e}',
            f'Sample N = {len(valid_df)}',
        ))
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
        ax.text(0.05, 0.95, text_str, transform=ax.transAxes, fontsize=12,
                verticalalignment='top', bbox=props)

        ax.legend(loc='lower right')
        fig.tight_layout()

        output_img = f"correlation_{condition_name}_pca_{pca_dim}_{split_val}_logits.png"
        fig.savefig(output_img, dpi=300)
        print(f"Plot saved to {output_img}")
    finally:
        # Close this specific figure even if plotting or saving failed, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

if __name__ == "__main__":
    # Build the combined per-model table; None signals a missing input file.
    full_df = load_and_merge_data(GEN_RESULTS_FILE, WD_RESULTS_FILE, HPARAM_FILE)

    if full_df is not None:
        # Partition the models by whether mixup was used. Models without a
        # 'mix' hyperparameter were given mix = 0 during loading, so they land
        # in the "Mixup_Zero" group.
        mix_column = full_df['mix']
        groups = {
            "Mixup_Active": full_df[mix_column > 0].copy(),
            "Mixup_Zero": full_df[mix_column == 0].copy(),
        }

        print(f"Models with Mixup > 0: {len(groups['Mixup_Active'])}")
        print(f"Models with Mixup == 0: {len(groups['Mixup_Zero'])}")

        # One correlation plot per group, in the order declared above.
        for group_name, group_df in groups.items():
            plot_correlation(group_df, condition_name=group_name)

        print("Done.")