import os
import json
import glob
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from pathlib import Path

# Root directory that is scanned recursively for JSON result files.
RESULTS_DIR = "exp_dir"
# Directory the rendered figure is written to (created on demand).
OUTPUT_DIR = "./plots/noise_robustness"
# Dataset names: matched against path components of each result file and
# used, in this order, as the columns of the figure.
DATASETS = [
    "CharacterTrajectories",
    "SpokenArabicDigits",
    "UWaveGestureLibrary",
]

# Model key (as produced by get_model_key) -> label shown in the legend.
# Only models listed here are plotted; commented-out entries are disabled.
MODEL_DISPLAY_NAMES = dict(
    baseline_cubic="Cubic",
    baseline_linear="Linear",
    # odernn="ODE-RNN",
    # log_ncde="Log-NCDE",
    kernel_gaussian="Gaussian",
    gp_gp="GP",
    # qformer_gaussian="MV Gaussian",
    # qformer_gp="MV GP",
    # conv_gaussian="MVC Gaussian",
    conv_gp="MVC GP",
)

# Ten colours and ten marker glyphs; the i-th displayed model gets the i-th of each.
PALETTE = sns.color_palette(palette="tab10", n_colors=10)
MARKERS = list("os^vDXP*h>")
# Legend label -> matplotlib style kwargs ('color', 'marker'), assigned in the
# insertion order of MODEL_DISPLAY_NAMES.
STYLE_MAP = dict(
    (label, {'color': PALETTE[position], 'marker': MARKERS[position]})
    for position, label in enumerate(MODEL_DISPLAY_NAMES.values())
)

def compare_lists(list1, list2):
    """Return True when two non-empty sequences are element-wise close.

    Values are compared with ``np.allclose`` using ``atol=1e-4`` (plus
    NumPy's default relative tolerance).

    Args:
        list1: First sequence of numbers (typically read from a result file).
        list2: Second sequence of numbers (typically a hard-coded target).

    Returns:
        bool: True only if both inputs are sized, non-empty, of equal length
        and numerically close. Anything that cannot be compared (``None``,
        a bare scalar, non-numeric elements, ragged nesting) yields False
        instead of raising, so callers can use this as a plain predicate.
    """
    try:
        # Equal lengths plus a non-empty first input imply both are non-empty.
        if len(list1) == 0 or len(list1) != len(list2):
            return False
        return bool(np.allclose(list1, list2, atol=1e-4))
    except (TypeError, ValueError):
        # len() on a scalar/None, or allclose on non-numeric content.
        return False

# Bandwidth grid shared by several multi-bandwidth configurations.
_MV_GRID = [0.03, 0.1, 0.4, 1.4]

# dataset -> model key -> (params field, target).
# For 'bws_raw' the target is a list of acceptable bandwidth lists (any match
# is accepted, tried in order); for every other field it is a single number.
_TARGET_CONFIGS = {
    # --- CharacterTrajectories best results ---
    "CharacterTrajectories": {
        'baseline_linear':  ('tol', 0.001),
        'baseline_cubic':   ('tol', 0.001),
        'odernn':           ('tol', 0.001),
        'log_ncde':         ('step_size', 10),
        'kernel_gaussian':  ('bw_raw', 0.05),
        'gp_gp':            ('ls_raw', 0.6),
        'qformer_gaussian': ('bws_raw', [_MV_GRID]),
        'qformer_gp':       ('bws_raw', [_MV_GRID]),
        'conv_gaussian':    ('bws_raw', [_MV_GRID]),
        'conv_gp':          ('bws_raw', [[1.4, 1.4, 1.4, 1.4], _MV_GRID]),
    },
    # --- SpokenArabicDigits best results ---
    "SpokenArabicDigits": {
        'baseline_linear':  ('tol', 0.001),
        'baseline_cubic':   ('tol', 0.001),
        'odernn':           ('tol', 0.001),
        'log_ncde':         ('step_size', 20),
        'kernel_gaussian':  ('bw_raw', 0.05),
        'gp_gp':            ('ls_raw', 0.6),
        'qformer_gaussian': ('bws_raw', [_MV_GRID]),
        'qformer_gp':       ('bws_raw', [[0.1, 0.4, 1.4]]),
        'conv_gaussian':    ('bws_raw', [_MV_GRID]),
        'conv_gp':          ('bws_raw', [[1.4, 1.4, 1.4], [0.1, 0.4, 1.4]]),
    },
    # --- UWaveGestureLibrary best results ---
    "UWaveGestureLibrary": {
        'baseline_linear':  ('tol', 0.0001),
        'baseline_cubic':   ('tol', 0.0001),
        'odernn':           ('tol', 0.01),
        'log_ncde':         ('step_size', 30),
        'kernel_gaussian':  ('bw_raw', 0.05),
        'gp_gp':            ('ls_raw', 1.4),
        'qformer_gaussian': ('bws_raw', [_MV_GRID]),
        'qformer_gp':       ('bws_raw', [[0.05, 0.2, 0.6]]),
        'conv_gaussian':    ('bws_raw', [[0.05, 0.05, 0.05, 0.05]]),
        'conv_gp':          ('bws_raw', [[0.1, 0.2, 0.4], [0.05, 0.2, 0.6]]),
    },
}


def _coerce_number(value, cast):
    """Return ``cast(value)``, or None when the value is missing or not numeric."""
    if value is None:
        return None
    try:
        return cast(value)
    except (TypeError, ValueError, OverflowError):
        return None


def is_target_config(dataset, model_type, params):
    """Tell whether a run uses the selected hyper-parameters for its dataset/model.

    Each (dataset, model) pair in ``_TARGET_CONFIGS`` names exactly one field
    of ``params`` and the value it must have. Only that field is read, so an
    absent, null or malformed value in an unrelated field cannot raise.

    Args:
        dataset: Dataset name, e.g. ``"CharacterTrajectories"``.
        model_type: Model key as returned by ``get_model_key``.
        params: Hyper-parameter mapping of the run.

    Returns:
        bool: True if the run matches the target configuration. Unknown
        datasets or models, and missing or non-numeric values, give False.
    """
    rule = _TARGET_CONFIGS.get(dataset, {}).get(model_type)
    if rule is None:
        return False

    field, target = rule
    value = params.get(field)

    if field == 'step_size':
        # Truncating int() conversion, then exact comparison.
        return _coerce_number(value, int) == target

    if field == 'bws_raw':
        if not isinstance(value, (list, tuple)):
            return False
        return any(compare_lists(value, candidate) for candidate in target)

    # Scalar fields ('tol', 'bw_raw', 'ls_raw'): tolerant float comparison.
    number = _coerce_number(value, float)
    return number is not None and bool(np.isclose(number, target))

def get_model_key(params):
    """Derive the internal model key from a run's hyper-parameters.

    The experiment ``type`` (case-insensitive) selects the model family.
    Families with variants append the variant read from ``params``:
    ``baseline_<interpolation>`` (default ``linear``) and
    ``qformer_<kernel>`` / ``conv_<kernel>`` / ``kernel_<kernel>``
    (default ``gaussian``). ``gp``, ``log_ncde`` and ``odernn`` map to the
    fixed keys ``gp_gp``, ``log_ncde`` and ``odernn``.

    Args:
        params: Hyper-parameter mapping of the run.

    Returns:
        str: The model key, or ``"unknown"`` for an unrecognised type.
    """
    kind = params.get('type', '').lower()

    # family -> (key prefix, params field holding the variant, default variant)
    variant_families = {
        'baseline': ('baseline', 'interpolation', 'linear'),
        'q-former': ('qformer', 'kernel', 'gaussian'),
        'conv': ('conv', 'kernel', 'gaussian'),
        'kernel': ('kernel', 'kernel', 'gaussian'),
    }
    if kind in variant_families:
        prefix, field, fallback = variant_families[kind]
        return f"{prefix}_{params.get(field, fallback)}"

    fixed_keys = {'gp': "gp_gp", 'log_ncde': "log_ncde", 'odernn': "odernn"}
    return fixed_keys.get(kind, "unknown")

def load_filtered_data(root_dir):
    """Collect noise-sweep metrics from all matching JSON result files.

    Every ``*.json`` below ``root_dir`` is considered. A file contributes only
    if one of ``DATASETS`` appears as a component of its path, its model key
    is listed in ``MODEL_DISPLAY_NAMES`` and ``is_target_config`` accepts its
    hyper-parameters (read from ``params``, falling back to ``model_params``).

    Unreadable files, non-object JSON documents and malformed noise entries
    are skipped rather than aborting the scan; a summary line reports how
    many unreadable files / malformed entries were dropped.

    Args:
        root_dir: Directory that is searched recursively.

    Returns:
        Nested defaultdict
        ``dataset -> display name -> noise level (float) -> {'nfe': [...], 'acc': [...]}``
        with one list element per contributing result file.
    """
    data = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: {'nfe': [], 'acc': []})))
    files = glob.glob(os.path.join(root_dir, "**/*.json"), recursive=True)
    print(f"Scanning {len(files)} files...")

    skipped = 0
    for fpath in files:
        path_parts = Path(fpath).parts
        dataset = next((d for d in DATASETS if d in path_parts), None)
        if not dataset:
            continue

        try:
            with open(fpath, 'r', encoding='utf-8') as f:
                content = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            skipped += 1
            continue
        if not isinstance(content, dict):
            continue

        params = content.get('params', content.get('model_params', {}))
        if not isinstance(params, dict) or not isinstance(params.get('type', ''), str):
            continue
        model_key = get_model_key(params)

        # 'grud' runs are excluded explicitly.
        if 'grud' in model_key or params.get('type') == 'grud':
            continue
        # Cheap membership test first; the config check only runs for plotted models.
        if model_key not in MODEL_DISPLAY_NAMES:
            continue
        if not is_target_config(dataset, model_key, params):
            continue

        display_name = MODEL_DISPLAY_NAMES[model_key]
        noise_res = content.get('noise_results', {})
        if not isinstance(noise_res, dict):
            continue

        for noise_val_str, metrics in noise_res.items():
            try:
                noise_val = float(noise_val_str)
                nfe = metrics['nfe']
                acc = metrics['accuracy']
            except (KeyError, TypeError, ValueError):
                skipped += 1
                continue
            # Append both metrics together so the two lists stay aligned.
            bucket = data[dataset][display_name][noise_val]
            bucket['nfe'].append(nfe)
            bucket['acc'].append(acc)

    if skipped:
        print(f"Skipped {skipped} unreadable files or malformed noise entries.")

    return data

def _summarise_noise_levels(level_data):
    """Reduce one model's per-noise-level samples to plottable mean curves.

    Args:
        level_data: Mapping ``noise level -> {'nfe': [...], 'acc': [...]}``.

    Returns:
        ``(x, mean_nfe_rel, mean_err)`` for the plotted noise levels, or None
        when no level is left to plot.
    """
    # The smallest level is not plotted (normally the 0.0 baseline).
    # NOTE(review): without a 0.0 run this also drops the smallest noisy
    # level -- confirm that this is intended.
    plotted = sorted(level_data)[1:]
    if not plotted:
        return None

    # NFE is shown relative to the 0.0 run, or to the first plotted level
    # when no 0.0 run exists.
    base_level = 0.0 if 0.0 in level_data else plotted[0]
    base_nfe_mean = np.mean(np.array(level_data[base_level]['nfe']))
    if base_nfe_mean < 1e-6:
        # Avoid dividing by (almost) zero; fall back to absolute NFE.
        base_nfe_mean = 1.0

    mean_nfe_rel = []
    mean_err = []
    for level in plotted:
        nfes = np.array(level_data[level]['nfe'])
        accs = np.array(level_data[level]['acc'])
        mean_nfe_rel.append(np.mean(nfes) / base_nfe_mean)
        # Accuracy is treated as a percentage and converted to an error rate.
        mean_err.append(np.mean(1.0 - (accs / 100.0)))

    return np.array(plotted), mean_nfe_rel, mean_err


def plot_noise_robustness():
    """Plot relative NFE and error rate against noise level for every dataset.

    Loads the filtered results from ``RESULTS_DIR`` and draws one column per
    entry of ``DATASETS``: the top row shows mean NFE relative to the baseline
    level, the bottom row the mean error rate. A shared legend is placed below
    the axes and the figure is saved as ``noise_robustness.pdf`` in
    ``OUTPUT_DIR``. Prints a message and returns without plotting when no
    data matches.
    """
    data = load_filtered_data(RESULTS_DIR)
    if not data:
        print("No valid data found matching the criteria.")
        return

    sns.set_style("whitegrid")
    n_cols = len(DATASETS)
    # squeeze=False keeps `axes` two-dimensional even for a single dataset.
    fig, axes = plt.subplots(2, n_cols, figsize=(6 * n_cols, 9), sharex=True,
                             squeeze=False)

    fig.suptitle("Noise Robustness Analysis",
                 fontsize=22, fontweight='bold', y=1)

    all_models_found = set()
    for d in data:
        all_models_found.update(data[d].keys())
    sorted_models = sorted(all_models_found)

    for col, dataset in enumerate(DATASETS):
        ax_nfe = axes[0, col]
        ax_err = axes[1, col]

        if dataset not in data:
            ax_nfe.set_visible(False)
            ax_err.set_visible(False)
            continue

        ds_data = data[dataset]
        ax_nfe.set_title(dataset, fontsize=16, fontweight='bold', pad=15)

        for model_name in sorted_models:
            if model_name not in ds_data:
                continue

            summary = _summarise_noise_levels(ds_data[model_name])
            if summary is None:
                continue
            x, mean_nfe_rel, mean_err = summary

            style = STYLE_MAP.get(model_name, {'color': 'gray', 'marker': 'o'})

            # Only the means are drawn; no standard-deviation band is plotted.
            ax_nfe.plot(x, mean_nfe_rel, label=model_name, **style, linewidth=2, markersize=5, alpha=0.7)
            ax_err.plot(x, mean_err, label=model_name, **style, linewidth=2, markersize=5, alpha=0.7)

        ax_nfe.grid(True, which='both', linestyle='--', alpha=0.5)
        ax_err.grid(True, which='both', linestyle='--', alpha=0.5)

        # Optional axis scaling (disabled):
        # ax_nfe.set_yscale('log')
        # ax_nfe.set_xscale('symlog', linthresh=1e-3)
        # ax_err.set_xscale('symlog', linthresh=1e-3)

        if col == 0:
            ax_nfe.set_ylabel("Relative NFE", fontsize=14, fontweight='bold')
            ax_err.set_ylabel("Error Rate", fontsize=14, fontweight='bold')

    fig.supxlabel("Noise Level", fontsize=16, fontweight='bold', y=0.12)

    # Build the legend from proxy handles so every model found appears once,
    # regardless of which datasets it was plotted on.
    handles = []
    labels = []
    for model in sorted_models:
        style = STYLE_MAP.get(model, {'color': 'gray', 'marker': 'o'})
        h = plt.Line2D([0], [0], color=style['color'], marker=style['marker'],
                       linestyle='-', linewidth=2, markersize=8)
        handles.append(h)
        labels.append(model)

    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0.05),
               ncol=5, frameon=False, fontsize=14)

    plt.tight_layout()
    plt.subplots_adjust(top=0.90, bottom=0.20)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    save_path = os.path.join(OUTPUT_DIR, "noise_robustness.pdf")
    plt.savefig(save_path, format='pdf', bbox_inches='tight')
    print(f"Saved plot to {save_path}")

    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Script entry point: build and save the figure only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    plot_noise_robustness()