import trackexp
from trackexp.utils import get_data, get_metadata, list_experiments
import pandas as pd
import matplotlib.pyplot as plt
import os
# Directory that is scanned for experiment sub-directories.
exp_dir = 'trackexp_out'

# Collect every sub-directory of `exp_dir` whose name starts with 'exp_',
# then sort the names alphabetically so the processing order is stable.
exp_names = []
for entry in os.listdir(exp_dir):
    if not entry.startswith('exp_'):
        continue
    if os.path.isdir(os.path.join(exp_dir, entry)):
        exp_names.append(entry)
exp_names.sort()


import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Split the experiments into two groups according to the 'loss_func' entry of
# their metadata: runs recorded with 'FW' versus every other run.
fw_data = []
other_data = []

for exp_name in exp_names:
    # The "training" data is used as a DataFrame further down (column access
    # by name) -- presumably one row per logged training step; verify against
    # trackexp.
    df = get_data(exp_name, "training")
    md = get_metadata(exp_name)

    bucket = fw_data if md['loss_func'] == 'FW' else other_data
    bucket.append(df)

def create_percentile_bands(data_list, time_col='wallclocktime', value_col='test_acc'):
    """Compute 20th/50th/80th percentile curves of a metric across runs.

    The runs are aligned on a common time grid made of every distinct value
    of ``time_col`` seen in any run. At each grid point, each run contributes
    the ``value_col`` entry of its row whose time is closest to that point
    (ties go to the earliest row), and the percentiles are taken over those
    per-run values.

    Args:
        data_list: Sequence of DataFrames, one per run, each containing the
            columns ``time_col`` and ``value_col``.
        time_col: Name of the column holding the time axis.
        value_col: Name of the column holding the metric to summarise.

    Returns:
        A tuple ``(time_points, p20, p50, p80)`` of equally long 1-D NumPy
        arrays. If ``data_list`` is empty the tuple is
        ``(None, None, None, None)``; if every run has zero rows the four
        arrays are empty.
    """
    if not data_list:
        return None, None, None, None

    # Extract each run's columns from pandas exactly once. Runs without any
    # rows are skipped: they cannot contribute a nearest value, and argmin
    # over an empty array would raise ValueError.
    runs = [
        (df[time_col].to_numpy(), df[value_col].to_numpy())
        for df in data_list
        if len(df) > 0
    ]
    if not runs:
        return np.array([]), np.array([]), np.array([]), np.array([])

    # Sorted union of all time stamps observed in any run.
    time_points = np.unique(np.concatenate([times for times, _ in runs]))

    # sampled[i, j] is run i's value at its row nearest to time_points[j].
    # The nearest-row search stays a per-point loop so memory use is bounded
    # by the size of one run rather than (rows x grid points).
    sampled = np.array([
        [values[np.argmin(np.abs(times - t))] for t in time_points]
        for times, values in runs
    ])

    # One call yields all three percentile curves (percentiles across runs).
    p20, p50, p80 = np.percentile(sampled, [20, 50, 80], axis=0)
    return time_points, p20, p50, p80

# Percentile curves for each group; every entry is None when a group is empty.
fw_times, fw_p20, fw_p50, fw_p80 = create_percentile_bands(fw_data)
other_times, other_p20, other_p50, other_p80 = create_percentile_bands(other_data)

# ---- Figure setup ----
plt.figure(figsize=(4, 2.75))
plt.rcParams.update({'font.size': 12})
plt.grid(True)

# ---- Curves ----
# One row per group, drawn in this order:
# (times, p20, p50, p80, colour, band label, median label, extra line kwargs).
# The FW median is a solid line, the other group's median is dashed.
curve_specs = [
    (fw_times, fw_p20, fw_p50, fw_p80,
     'red', 'FW 20-80% band', 'FW median', {}),
    (other_times, other_p20, other_p50, other_p80,
     'black', 'CE 20-80% band', 'CE median', {'linestyle': '--'}),
]
for times, lower, median, upper, colour, band_label, median_label, line_kwargs in curve_specs:
    if times is None:
        # Group had no experiments: nothing to draw.
        continue
    # Shaded region between the 20th and 80th percentile.
    plt.fill_between(times, lower, upper,
                     alpha=0.3, color=colour, label=band_label)
    # Thick median curve on top of the band.
    plt.plot(times, median,
             color=colour, linewidth=2, label=median_label, **line_kwargs)

# ---- Axes decoration ----
plt.legend()
plt.ylim([96, 98.1])
plt.xlabel('Wall Clock Time')
plt.ylabel('Test Accuracy')
plt.title('MNIST with GDTUO')
plt.grid(True, alpha=0.3)
plt.tight_layout()

# ---- Text summary ----
print(f"FW group: {len(fw_data)} experiments")
print(f"Other group: {len(other_data)} experiments")

# Final-point statistics are only printed for groups that produced at least
# one point on their curve.
fw_has_points = fw_p50 is not None and len(fw_p50) > 0
if fw_has_points:
    print(f"FW final median accuracy: {fw_p50[-1]:.4f}")
    print(f"FW final 20-80% range: {fw_p20[-1]:.4f} - {fw_p80[-1]:.4f}")

other_has_points = other_p50 is not None and len(other_p50) > 0
if other_has_points:
    print(f"Other final median accuracy: {other_p50[-1]:.4f}")
    print(f"Other final 20-80% range: {other_p20[-1]:.4f} - {other_p80[-1]:.4f}")

# Write the finished figure to disk.
plt.savefig('mnist_perf.pdf')
