import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import logging

# Directory handed to load_data() by the script entry point; it is searched
# recursively for ``*.json`` result files.
RESULTS_DIR = "./exp_dir"
# Directory the generated PDF/PNG figures are written to (created on demand
# by plot_split_time_analysis).
OUTPUT_DIR = "./plots/time_analysis"
# When truthy, load_data() keeps only records whose ``dataset_name`` equals
# this value; ``None`` keeps every dataset.
DATASET_FILTER = None

# Process-wide matplotlib styling: serif font family (candidate fonts listed
# in order) and explicit font sizes for labels, ticks, legend and titles.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif'],
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 12,
    'figure.titlesize': 18,
    # Automatic layout is disabled here; plot_split_time_analysis calls
    # tight_layout/subplots_adjust itself.
    'figure.autolayout': False,
})

# Module-level logger, plus root logging configuration at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def get_model_display_name(row):
    """Translate a result record into the label shown on the plots.

    ``row`` must expose ``.get``; the keys ``'type'``, ``'interpolation'``
    and ``'kernel'`` are read.  A record whose ``'type'`` is not one of the
    recognised values is labelled with the string form of that value.
    """
    model_type = row.get('type')
    interpolation = row.get('interpolation')
    kernel = row.get('kernel', 'N/A')

    # Baselines are named after their interpolation scheme.
    if model_type == 'baseline':
        for scheme, label in (('cubic', 'Cubic'), ('linear', 'Linear')):
            if interpolation == scheme:
                return label
        return f'Baseline ({interpolation})'

    # Types whose label does not depend on any other field.
    fixed_labels = (
        ('odernn', 'ODE-RNN'),
        ('grud', 'GRU-D'),
        ('kernel', 'Gaussian CDE'),
        ('gp', 'GP CDE'),
    )
    for type_key, label in fixed_labels:
        if model_type == type_key:
            return label

    # Types labelled by kernel: 'gp' selects the GP variant, any other
    # kernel value (including a missing one) selects the Gaussian variant.
    for type_key, prefix in (('q-former', 'MV-CDE'), ('conv', 'MVC-CDE')):
        if model_type == type_key:
            variant = 'GP' if kernel == 'gp' else 'Gaussian'
            return f'{prefix} ({variant})'

    return f"{model_type}"

def load_data(base_dir):
    """Collect timing records from every ``*.json`` file below ``base_dir``.

    A file contributes one row when its top-level JSON value is an object
    containing the key ``'trajectories_fit_time'`` and, if the module-level
    ``DATASET_FILTER`` is set, its ``'dataset_name'`` equals that filter.
    Files that cannot be read or parsed are skipped with a logged warning
    instead of aborting the scan.

    Args:
        base_dir: Directory to search recursively.

    Returns:
        A ``pandas.DataFrame`` with the columns ``dataset``, ``fit``,
        ``train``, ``eval``, ``type``, ``interpolation``, ``kernel`` and
        ``display_name``.  The frame is empty when ``base_dir`` does not
        exist or no file qualifies.
    """
    root = Path(base_dir)
    if not root.exists():
        return pd.DataFrame()

    # Same logger object as the module-level one (loggers are keyed by name).
    log = logging.getLogger(__name__)

    rows = []
    # Sorted so the row order does not depend on filesystem traversal order.
    for json_path in sorted(root.rglob("*.json")):
        # Only the read/parse step is guarded: I/O errors, undecodable bytes
        # and malformed JSON (the latter two are ValueError subclasses) are
        # expected for stray files; anything else should surface.
        try:
            with open(json_path, 'r', encoding='utf-8') as handle:
                record = json.load(handle)
        except (OSError, ValueError) as exc:
            log.warning("Skipping unreadable result file %s: %s", json_path, exc)
            continue

        # Result files are JSON objects; other top-level values are ignored.
        if not isinstance(record, dict):
            log.debug("Skipping %s: top-level JSON value is not an object", json_path)
            continue

        if 'trajectories_fit_time' not in record:
            continue
        if DATASET_FILTER and record.get('dataset_name') != DATASET_FILTER:
            continue

        row = {
            'dataset': record.get('dataset_name'),
            'fit': record.get('trajectories_fit_time', 0.0),
            'train': record.get('training_time', 0.0),
            'eval': record.get('evaluation_time', 0.0),
            'type': record.get('type'),
            'interpolation': record.get('interpolation'),
            'kernel': record.get('kernel'),
        }
        row['display_name'] = get_model_display_name(row)
        rows.append(row)

    return pd.DataFrame(rows)

def _ordered_models(available):
    """Return the names in ``available`` in presentation order.

    Names from the preferred list come first (in that list's order); any
    other names follow in the order they appear in ``available``.
    """
    preferred = [
        'Linear', 'Cubic', 'ODE-RNN', 'GRU-D',
        'Gaussian CDE', 'GP CDE',
        'MV-CDE (Gaussian)', 'MV-CDE (GP)',
        'MVC-CDE (Gaussian)', 'MVC-CDE (GP)'
    ]
    known = [name for name in preferred if name in available]
    extras = [name for name in available if name not in known]
    return known + extras


def _annotate_bar_tops(ax, bars, tops, fmt):
    """Write each value of ``tops`` (formatted with ``fmt``) above its bar."""
    for rect, top in zip(bars, tops):
        ax.text(rect.get_x() + rect.get_width() / 2., top,
                f'{top:{fmt}}s', ha='center', va='bottom', fontsize=10)


def plot_split_time_analysis(df, dataset_name):
    """Draw and save the two-panel timing figure for one dataset.

    The rows of ``df`` are averaged per ``display_name``.  The left panel
    shows the mean ``train`` time; the right panel stacks the mean ``eval``
    time on top of the mean ``fit`` time.  The figure is written to
    ``OUTPUT_DIR`` as ``time_split_<dataset_name>.pdf`` and as a ``.png``
    with the same stem.

    Args:
        df: DataFrame with the columns ``display_name``, ``fit``, ``train``
            and ``eval``.  Nothing is drawn when it is empty.
        dataset_name: Used as the figure title and in the output file names.

    Returns:
        None.
    """
    if df.empty:
        return

    df_agg = df.groupby('display_name')[['fit', 'train', 'eval']].mean()
    df_agg = df_agg.reindex(_ordered_models(df_agg.index))

    models = df_agg.index
    fit = df_agg['fit'].values
    train = df_agg['train'].values
    eval_t = df_agg['eval'].values

    c_train = '#1f77b4'  # Blue
    c_fit = '#ff7f0e'    # Orange
    c_eval = '#2ca02c'   # Green

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5.5), gridspec_kw={'width_ratios': [1, 1]})
    # The figure is closed in ``finally`` so a failure while drawing or
    # saving does not leave it open.
    try:
        fig.suptitle(f"{dataset_name}", fontsize=18, y=0.96, fontweight='bold')

        # Left panel: training time.
        bars_train = ax1.bar(models, train, color=c_train, alpha=0.8, edgecolor='black', width=0.65)
        ax1.set_title("Dominant Component: Pure Training Time", fontweight='bold', pad=10)
        ax1.set_ylabel("Time (seconds)", fontweight='bold')
        ax1.grid(axis='y', linestyle='--', alpha=0.4, color='gray')
        _annotate_bar_tops(ax1, bars_train, train, '.0f')
        plt.setp(ax1.get_xticklabels(), rotation=30, ha="right")

        # Right panel: evaluation time stacked on trajectory-fit time.
        ax2.bar(models, fit, color=c_fit, label='Trajectory Fit', alpha=0.85, edgecolor='black', width=0.65)
        bars_eval = ax2.bar(models, eval_t, bottom=fit, color=c_eval, label='Evaluation (Inference)', alpha=0.85, edgecolor='black', width=0.65)
        ax2.set_title("Overhead: Fit & Inference Time", fontweight='bold', pad=10)
        ax2.set_ylabel("")
        ax2.grid(axis='y', linestyle='--', alpha=0.4, color='gray')
        ax2.legend(frameon=True, fancybox=False, edgecolor='black')
        # Labels sit at the top of the stack, i.e. at fit + eval.
        _annotate_bar_tops(ax2, bars_eval, fit + eval_t, '.2f')
        plt.setp(ax2.get_xticklabels(), rotation=30, ha="right")

        fig.tight_layout()
        fig.subplots_adjust(top=0.85, bottom=0.22, wspace=0.15)

        os.makedirs(OUTPUT_DIR, exist_ok=True)
        save_path = os.path.join(OUTPUT_DIR, f"time_split_{dataset_name}.pdf")
        fig.savefig(save_path, format='pdf', bbox_inches='tight')

        # Swap only the final extension; a plain str.replace would also
        # rewrite any other '.pdf' occurring in the directory or dataset name.
        png_path = os.path.splitext(save_path)[0] + '.png'
        fig.savefig(png_path, format='png', dpi=300, bbox_inches='tight')
        print(f"Saved split plot to {save_path}")
    finally:
        plt.close(fig)

# Script entry point: load every result record, then emit one figure per
# distinct dataset name found in the data.
if __name__ == "__main__":
    print("Loading data...")
    df = load_data(RESULTS_DIR)

    if df.empty:
        print("No data found containing 'trajectories_fit_time'. Ensure you re-ran experiments with new code.")
    else:
        for ds in df['dataset'].unique():
            print(f"Processing {ds}...")
            plot_split_time_analysis(df[df['dataset'] == ds], ds)