import os
import glob
import json
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Global matplotlib defaults (fonts, text sizes, figure DPI) for every chart below.
matplotlib.rcParams.update({
    'font.family': 'serif',
    # NOTE(review): Arial is normally classed as a sans-serif face; listing it
    # under 'font.serif' may be deliberate, but confirm this is the intended font.
    'font.serif': ['Arial'],
    'font.size': 16,
    'axes.labelsize': 18,
    'axes.titlesize': 12,
    'legend.fontsize': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'figure.dpi': 300,
})

# NOTE(review): the seaborn style is applied AFTER the rcParams update above, so
# any rcParams key that the 'whitegrid' style also sets ends up with seaborn's
# value -- confirm the font settings above survive this call.
sns.set_style('whitegrid')
# One colour per method; the keys are the same ones used to index `runs`.
palette = {
    'ours': '#f79691',   
    'base': '#9793c6'   
}

# Directory searched for run*_ours_stats.npz / run*_base_stats.npz archives.
data_dir = './experimental_result_data' 
# Directory the charts are written to; created here if it does not exist yet.
charts_dir = './charts' 
os.makedirs(charts_dir, exist_ok=True)

def synthesize_runs(num_runs=5, epochs=100, num_layers=6, seed=0):
    """Fabricate placeholder training statistics for the two compared methods.

    Used as a stand-in when no experiment files can be loaded. All numbers
    come from a ``np.random.RandomState(seed)`` stream, so a given argument
    combination always yields the same data.

    Args:
        num_runs: Number of records generated per method.
        epochs: Length of every per-epoch series.
        num_layers: Number of per-layer values stored for each epoch.
        seed: Seed of the random stream.

    Returns:
        ``{'ours': [...], 'base': [...]}`` where each list holds ``num_runs``
        dicts with the keys ``epoch``, ``train_loss``, ``val_loss``,
        ``layer_grad_norms``, ``layer_grad_vars``, ``grad_flat_cosine_prev``,
        ``param_update_norms``, ``curvature_vHv_mean``, ``curvature_vHv_max``
        and ``final_rel_l2``.
    """
    rng = np.random.RandomState(seed)
    synthetic = {'ours': [], 'base': []}

    def per_epoch_layers(draw, steps):
        """Return one list of per-layer values for every epoch index in `steps`."""
        rows = []
        for step in steps:
            rows.append(list(draw(step)))
        return rows

    def as_record(steps, train, val, norms, variances, cosine, updates,
                  curv_mean, curv_max, rel_l2):
        """Pack the generated series into the dict layout the plots consume."""
        return {
            'epoch': steps,
            'train_loss': train,
            'val_loss': val,
            'layer_grad_norms': norms,
            'layer_grad_vars': variances,
            'grad_flat_cosine_prev': cosine,
            'param_update_norms': updates,
            'curvature_vHv_mean': curv_mean,
            'curvature_vHv_max': curv_max,
            'final_rel_l2': rel_l2
        }

    for _ in range(num_runs):
        epoch = np.arange(epochs)

        # The sequence of rng draws below determines the generated numbers;
        # keep the statements in this order.
        base_amplitude = 0.8 + 0.3 * rng.rand()
        base_train = np.exp(-epoch/60.0) * base_amplitude + 0.02 * rng.randn(epochs)
        ours_shrink = 0.8 - 0.15 * rng.rand()
        ours_train = base_train * ours_shrink + 0.01 * rng.randn(epochs)
        base_val = base_train * (1.05 + 0.05 * rng.randn(epochs))
        ours_val = ours_train * (1.02 + 0.04 * rng.randn(epochs))

        # Per-layer gradient norms: one fresh draw of num_layers values per epoch.
        base_norms = per_epoch_layers(
            lambda e: (0.1 + 0.4 * rng.rand(num_layers)) * (1.0 + 0.3 * np.exp(-e/40.0)),
            epoch)
        ours_norms = per_epoch_layers(
            lambda e: (0.12 + 0.35 * rng.rand(num_layers)) * (1.0 + 0.2 * np.exp(-e/45.0)),
            epoch)

        # Per-layer gradient variances, generated the same way.
        base_vars = per_epoch_layers(
            lambda e: 0.1 * rng.rand(num_layers) * (1.0 + 0.5 * np.exp(-e/30.0)),
            epoch)
        ours_vars = per_epoch_layers(
            lambda e: 0.08 * rng.rand(num_layers) * (1.0 + 0.3 * np.exp(-e/35.0)),
            epoch)

        # Cosine values are clipped to the valid [-1, 1] range.
        base_cos = np.clip(0.6 + 0.3 * np.exp(-epoch/50.0) + 0.05 * rng.randn(epochs), -1, 1)
        ours_cos = np.clip(0.7 + 0.25 * np.exp(-epoch/60.0) + 0.03 * rng.randn(epochs), -1, 1)

        base_updates = 0.05 * np.exp(-epoch/80.0) + 0.02 * rng.rand(epochs)
        ours_updates = 0.06 * np.exp(-epoch/90.0) + 0.01 * rng.rand(epochs)

        # Absolute value keeps the mean curvature proxy non-negative.
        base_curv_mean = np.abs(0.5 * np.exp(-epoch/70.0) + 0.05 * rng.randn(epochs))
        ours_curv_mean = np.abs(0.35 * np.exp(-epoch/80.0) + 0.03 * rng.randn(epochs))
        base_curv_max = base_curv_mean * (1.5 + 0.5 * rng.rand(epochs))
        ours_curv_max = ours_curv_mean * (1.4 + 0.4 * rng.rand(epochs))

        base_rel_l2 = 0.08 + 0.02 * rng.rand()
        ours_rel_l2 = base_rel_l2 * (0.7 + 0.1 * rng.rand())

        synthetic['base'].append(as_record(
            epoch, base_train, base_val, base_norms, base_vars, base_cos,
            base_updates, base_curv_mean, base_curv_max, base_rel_l2))
        synthetic['ours'].append(as_record(
            epoch, ours_train, ours_val, ours_norms, ours_vars, ours_cos,
            ours_updates, ours_curv_mean, ours_curv_max, ours_rel_l2))
    return synthetic


def _read_run_stats(path):
    """Read one ``*_stats.npz`` archive into a plain dict of run statistics.

    Args:
        path: Path of the ``.npz`` archive.

    Returns:
        Dict with the array entries ``epoch``, ``train_loss``, ``val_loss``,
        ``grad_flat_cosine_prev``, ``param_update_norms``,
        ``curvature_vHv_mean`` and ``curvature_vHv_max``, the list entries
        ``layer_grad_norms`` and ``layer_grad_vars``, and the float
        ``final_rel_l2`` (``nan`` when the archive has no such entry).

    Raises:
        KeyError: If a required entry is missing from the archive.
        Exception: Whatever ``np.load`` raises for an unreadable archive.
    """
    # NOTE(security): allow_pickle=True lets numpy unpickle object arrays, which
    # can execute arbitrary code. Only load archives you produced yourself.
    # The `with` block closes the underlying zip file once everything is read.
    with np.load(path, allow_pickle=True) as archive:
        # Membership test instead of `.get()`: NpzFile only offers the Mapping
        # `.get()` API on newer NumPy releases; on older ones the call raised
        # AttributeError and every real run was silently discarded.
        if 'final_rel_l2' in archive.files:
            final_rel_l2 = float(archive['final_rel_l2'])
        else:
            final_rel_l2 = float(np.nan)
        return {
            'epoch': np.array(archive['epoch']),
            'train_loss': np.array(archive['train_loss']),
            'val_loss': np.array(archive['val_loss']),
            'layer_grad_norms': list(archive['layer_grad_norms']),
            'layer_grad_vars': list(archive['layer_grad_vars']),
            'grad_flat_cosine_prev': np.array(archive['grad_flat_cosine_prev']),
            'param_update_norms': np.array(archive['param_update_norms']),
            'curvature_vHv_mean': np.array(archive['curvature_vHv_mean']),
            'curvature_vHv_max': np.array(archive['curvature_vHv_max']),
            'final_rel_l2': final_rel_l2,
        }


def load_experiment_data(data_dir):
    """Load per-run statistics for both methods, or synthesize them.

    Looks in ``data_dir`` for ``run*_ours_stats.npz`` and
    ``run*_base_stats.npz``, sorts both file lists and pairs them by position
    (surplus files of the longer list are ignored). A pair is kept only when
    BOTH archives load, so the two result lists always have the same length.

    Falls back to ``synthesize_runs()`` (after printing a warning) when the
    directory is missing, when either file list is empty, or when no pair
    could be loaded.

    Args:
        data_dir: Directory containing the ``.npz`` archives.

    Returns:
        ``{'ours': [...], 'base': [...]}`` with one statistics dict per run.
    """
    if not os.path.isdir(data_dir):
        print(f"Warning: data directory not found at {data_dir}. Synthesizing representative data for plotting.")
        return synthesize_runs()

    # Escape the directory so glob metacharacters in its name are taken literally.
    search_root = glob.escape(data_dir)
    ours_files = sorted(glob.glob(os.path.join(search_root, 'run*_ours_stats.npz')))
    base_files = sorted(glob.glob(os.path.join(search_root, 'run*_base_stats.npz')))

    if not ours_files or not base_files:
        print("Warning: experiment files not found in data_dir. Synthesizing representative data for plotting.")
        return synthesize_runs()

    runs = {'ours': [], 'base': []}
    for i, (ours_path, base_path) in enumerate(zip(ours_files, base_files)):
        try:
            ours_stats = _read_run_stats(ours_path)
            base_stats = _read_run_stats(base_path)
        except Exception as e:
            # Best effort by design: report the broken pair and keep going.
            print(f"Warning: failed to load file pair index {i}: {e}")
            continue
        runs['ours'].append(ours_stats)
        runs['base'].append(base_stats)

    if not runs['ours'] or not runs['base']:
        print("Warning: no valid runs loaded. Synthesizing representative data.")
        return synthesize_runs()
    return runs

# Load the run statistics once at import time; the plotting calls at the bottom
# of the file all receive this dict. load_experiment_data substitutes
# synthetic data when no usable files are found.
runs = load_experiment_data(data_dir)

def aggregate_time_series(runs_list, key):
    """Compute the across-run mean and standard deviation of one series.

    Runs that do not contain ``key`` are skipped. Series of different
    lengths are truncated to the shortest one before stacking.

    Args:
        runs_list: Iterable of per-run dicts.
        key: Name of the series to aggregate.

    Returns:
        Tuple ``(mean, std, stacked)``: ``stacked`` has one row per
        contributing run, ``mean`` and ``std`` are taken over the run axis
        (``std`` is numpy's default population standard deviation).

    Raises:
        ValueError: If no run in ``runs_list`` contains ``key``.
    """
    series = [np.asarray(run[key]) for run in runs_list if key in run]
    if not series:
        # Fail with a message naming the key rather than an opaque min() error.
        raise ValueError(f"no run provides a '{key}' series to aggregate")
    shortest = min(s.shape[0] for s in series)
    stacked = np.stack([s[:shortest] for s in series], axis=0)
    return stacked.mean(axis=0), stacked.std(axis=0), stacked

def plot_per_layer_grad_distributions(runs, charts_dir, fname):
    """Save paired box plots of the per-layer gradient norms of both methods.

    For each layer, the values of every epoch of every run are pooled into
    one box per method; the 'ours' and 'base' boxes of a layer are drawn
    side by side on a logarithmic y axis.

    Args:
        runs: ``{'ours': [...], 'base': [...]}`` dict of per-run records, each
            providing ``layer_grad_norms`` (one sequence of per-layer values
            per epoch).
        charts_dir: Directory the figure is written to.
        fname: Output file name inside ``charts_dir``.
    """
    # The layer count is taken from the first epoch of the first 'ours' run.
    num_layers = len(runs['ours'][0]['layer_grad_norms'][0])

    def gather(method):
        """Pool each layer's values over all runs and epochs of `method`."""
        pooled = [[] for _ in range(num_layers)]
        for run in runs[method]:
            for per_layer in run['layer_grad_norms']:
                for layer, value in enumerate(per_layer):
                    pooled[layer].append(value)
        return pooled

    data_ours = gather('ours')
    data_base = gather('base')

    # Layer groups are 2.0 apart; the 'base' box sits 0.6 right of the 'ours' box.
    positions_ours = np.arange(num_layers) * 2.0
    # (palette key, pooled data, x positions, box alpha, median line colour)
    box_groups = (
        ('ours', data_ours, positions_ours, 0.7, '#08306B'),
        ('base', data_base, positions_ours + 0.6, 0.6, '#7F0000'),
    )

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        legend_boxes = []
        for method, data, positions, alpha, median_color in box_groups:
            bp = ax.boxplot(data, positions=positions, widths=0.5,
                            patch_artist=True, showfliers=False)
            for patch in bp['boxes']:
                patch.set_facecolor(palette[method])
                patch.set_alpha(alpha)
            for median in bp['medians']:
                median.set_color(median_color)
            legend_boxes.append(bp['boxes'][0])

        # One tick per layer, centred between the two boxes of that layer.
        ax.set_xticks(positions_ours + 0.3)
        ax.set_xticklabels([f'L{l+1}' for l in range(num_layers)])
        ax.set_xlabel('Parameter Tensor / Layer')
        ax.set_ylabel('Gradient Norm')
        # NOTE(review): fixed y range, presumably tuned to the real data -- confirm.
        ax.set_ylim(1e-4, 5e4)
        ax.set_yscale('log')
        ax.legend(legend_boxes, ['Ours', 'DIMON'], loc='upper right', fontsize=16, ncol=2)
        ax.grid(True, which='both', linestyle='--', linewidth=0.3)
        fig.tight_layout()
        # Save through the figure object rather than pyplot's "current figure".
        fig.savefig(
            os.path.join(charts_dir, fname),
            dpi=300,
            bbox_inches='tight',
            pad_inches=0,
        )
    finally:
        # Release the figure from pyplot even when drawing or saving fails.
        plt.close(fig)
def _draw_mean_band(ax, runs_list, key, color, linestyle, label, linewidth,
                    band_alpha=None):
    """Plot the across-run mean of `key` on `ax`, optionally with a +/- 1 std band.

    The x axis is sized from this series itself, so series of different
    lengths (different methods or keys) can share one axes without a shape
    mismatch.

    Args:
        ax: Matplotlib axes to draw on.
        runs_list: Per-run records of one method.
        key: Name of the series to aggregate and plot.
        color: Line and band colour.
        linestyle: Matplotlib line style of the mean curve.
        label: Legend label of the mean curve.
        linewidth: Width of the mean curve.
        band_alpha: Opacity of the mean +/- std band; ``None`` draws no band.
    """
    mean, std, _ = aggregate_time_series(runs_list, key)
    x = np.arange(len(mean))
    ax.plot(x, mean, linestyle=linestyle, color=color, label=label, linewidth=linewidth)
    if band_alpha is not None:
        ax.fill_between(x, mean - std, mean + std, color=color, alpha=band_alpha)


def _save_figure(fig, charts_dir, fname):
    """Tighten the layout of `fig` and write it to ``charts_dir/fname``."""
    fig.tight_layout()
    # Save through the figure object rather than pyplot's "current figure".
    fig.savefig(
        os.path.join(charts_dir, fname),
        dpi=300,
        bbox_inches='tight',
        pad_inches=0,
    )


def plot_loss_and_updates(runs, charts_dir, fname):
    """Save a two-panel figure: losses on top, parameter update norms below.

    The top panel shows mean train loss (with a std band) and mean validation
    loss for both methods; the bottom panel shows the mean parameter update
    norm (with a std band) for both methods. The panels share the x axis.

    Args:
        runs: ``{'ours': [...], 'base': [...]}`` dict of per-run records
            providing ``train_loss``, ``val_loss`` and ``param_update_norms``.
        charts_dir: Directory the figure is written to.
        fname: Output file name inside ``charts_dir``.
    """
    fig, axs = plt.subplots(2, 1, figsize=(10, 6), sharex=True,
                            gridspec_kw={'height_ratios': [1, 1]})
    try:
        ax = axs[0]
        # NOTE(review): fixed y range, presumably tuned to the real data -- confirm.
        ax.set_ylim(0, 250)
        _draw_mean_band(ax, runs['base'], 'train_loss', palette['base'], '--',
                        'DIMON Train loss', 1.6, band_alpha=0.18)
        _draw_mean_band(ax, runs['base'], 'val_loss', palette['base'], ':',
                        'DIMON Val loss', 1.4)
        _draw_mean_band(ax, runs['ours'], 'train_loss', palette['ours'], '-',
                        'Ours Train loss', 1.8, band_alpha=0.18)
        _draw_mean_band(ax, runs['ours'], 'val_loss', palette['ours'], '-.',
                        'Ours Val loss', 1.4)
        ax.set_ylabel('Weighted MSE Loss')
        ax.legend(loc='upper right', ncol=2, fontsize=16)
        ax.grid(True, linestyle='--', linewidth=0.3)

        ax = axs[1]
        _draw_mean_band(ax, runs['base'], 'param_update_norms', palette['base'], '--',
                        'DIMON', 1.6, band_alpha=0.14)
        _draw_mean_band(ax, runs['ours'], 'param_update_norms', palette['ours'], '-',
                        'Ours', 1.8, band_alpha=0.14)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Parameter Update Norm')
        ax.legend(loc='upper right', fontsize=16, ncol=2)
        ax.grid(True, linestyle='--', linewidth=0.3)

        _save_figure(fig, charts_dir, fname)
    finally:
        # Release the figure from pyplot even when drawing or saving fails.
        plt.close(fig)


def plot_curvature_metrics(runs, charts_dir, fname):
    """Save a figure of the mean and max curvature series of both methods.

    The mean curvature is drawn with a std band, the max curvature as a
    plain line.

    Args:
        runs: ``{'ours': [...], 'base': [...]}`` dict of per-run records
            providing ``curvature_vHv_mean`` and ``curvature_vHv_max``.
        charts_dir: Directory the figure is written to.
        fname: Output file name inside ``charts_dir``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        _draw_mean_band(ax, runs['base'], 'curvature_vHv_mean', palette['base'], '--',
                        'DIMON Mean curvature', 1.6, band_alpha=0.14)
        _draw_mean_band(ax, runs['base'], 'curvature_vHv_max', palette['base'], ':',
                        'DIMON Max curvature', 1.2)
        _draw_mean_band(ax, runs['ours'], 'curvature_vHv_mean', palette['ours'], '-',
                        'Ours Mean curvature', 1.8, band_alpha=0.14)
        _draw_mean_band(ax, runs['ours'], 'curvature_vHv_max', palette['ours'], '-.',
                        'Ours Max curvature', 1.2)

        # NOTE(review): fixed y range, presumably tuned to the real data -- confirm.
        ax.set_ylim(-5, 50*1.3)

        ax.set_xlabel('Epoch')
        ax.set_ylabel('Curvature proxy')
        ax.legend(loc='upper right', fontsize=16, ncol=2)
        ax.grid(True, linestyle='--', linewidth=0.3)

        _save_figure(fig, charts_dir, fname)
    finally:
        # Release the figure from pyplot even when drawing or saving fails.
        plt.close(fig)


def plot_gradient_cosine(runs, charts_dir, fname):
    """Save a figure of the mean ``grad_flat_cosine_prev`` series of both methods.

    Each method is drawn as a mean curve with a std band.

    Args:
        runs: ``{'ours': [...], 'base': [...]}`` dict of per-run records
            providing ``grad_flat_cosine_prev``.
        charts_dir: Directory the figure is written to.
        fname: Output file name inside ``charts_dir``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        _draw_mean_band(ax, runs['base'], 'grad_flat_cosine_prev', palette['base'], '--',
                        'DIMON', 1.6, band_alpha=0.14)
        _draw_mean_band(ax, runs['ours'], 'grad_flat_cosine_prev', palette['ours'], '-',
                        'Ours', 1.8, band_alpha=0.14)

        ax.set_xlabel('Epoch')
        ax.set_ylabel('Cosine Similarity')
        # Upper limit sits above 1 to leave headroom for the legend.
        ax.set_ylim(-0.2, 1.02*1.4)
        ax.legend(loc='upper right', fontsize=16, ncol=2)
        ax.grid(True, linestyle='--', linewidth=0.3)

        _save_figure(fig, charts_dir, fname)
    finally:
        # Release the figure from pyplot even when drawing or saving fails.
        plt.close(fig)


# Render every chart into charts_dir under its fixed output file name.
for make_chart, chart_name in (
    (plot_per_layer_grad_distributions, 'chart_000.pdf'),
    (plot_loss_and_updates, 'chart_001.pdf'),
    (plot_curvature_metrics, 'chart_002.pdf'),
    (plot_gradient_cosine, 'chart_003.pdf'),
):
    make_chart(runs, charts_dir, chart_name)
