import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp1d
from matplotlib.font_manager import FontProperties

# Aggregation mode for the plotting loop: True resamples every seed's curve
# onto a shared time grid; False groups the rows by iteration index instead.
interp = True

# Matplotlib line properties for each scheme, keyed by the scheme label used
# in the results CSV. Each entry is unpacked directly into the plot call.
style_dict = {
    'shine-big-rank-opa': dict(color='green', linestyle='-', linewidth=4),
    'shine-big-rank-foa': dict(color='#D02020', linestyle='-', linewidth=5.5),
    'sr1': dict(color='#0000CC', linestyle='-', linewidth=4),
    'saba': dict(color='purple', linestyle='-.', linewidth=1),
    'f2sa': dict(color='C0', linestyle='--', linewidth=1),
    'Bome': dict(color='C9', linestyle='-', linewidth=3.2, alpha=0.6),
    'Bsg1': dict(color='orange', linestyle='dotted', linewidth=4, alpha=0.8),
}

# Human-readable legend text for each scheme label.
legend_names = {
    'shine-big-rank-opa': 'SHINE-OPA',
    'shine-big-rank-foa': 'qNBO(BFGS)',
    'saba': 'SABA',
    'Bsg1': 'BSG1',
    'f2sa': 'F2SA',
    'Bome': 'BOME',
    'sr1': 'qNBO(SR1)',
}

# Schemes to draw, in drawing (and legend) order.
included_schemes = [
    'shine-big-rank-foa',
    'sr1',
    'Bome',
    'saba',
    'f2sa',
    'shine-big-rank-opa',
    'Bsg1',
]


# For each dataset, draw one figure of loss versus running time with one curve
# per scheme (median plus a 10%-90% quantile band) and save it as a PDF.
for dataset in ['real-sim']:
    plt.figure(figsize=(6, 5))

    results_name = f'{dataset}.csv'
    big_df_res = pd.read_csv(results_name, low_memory=False)

    # Schemes present in this results file, collected once per dataset.
    available_schemes = set(big_df_res['scheme_label'].unique())

    # Shared log-spaced grid (1e-2 to 1e3, in the units of the `time` column)
    # onto which every seed's curve is resampled. It does not depend on the
    # scheme, so it is built once per dataset.
    time_grid = np.logspace(-2, 3, 50)

    for scheme_label in included_schemes:
        # Skip schemes without a style or without rows in this file. Doing so
        # up front guarantees that the frame plotted below always belongs to
        # `scheme_label` and is never left over from a previous iteration.
        if scheme_label not in style_dict or scheme_label not in available_schemes:
            continue

        style = style_dict[scheme_label]
        df_scheme = big_df_res[big_df_res['scheme_label'] == scheme_label].copy()
        df_scheme.reset_index(drop=True, inplace=True)
        if df_scheme.empty:
            continue

        # No offset is applied: the plotted column is a copy of `val_loss`.
        df_scheme['val_loss_adjusted'] = df_scheme['val_loss']

        if interp:
            per_seed_curves = []
            for _, seed_df in df_scheme.groupby('seed'):
                # Stable sort by time so the first/last rows really are the
                # earliest/latest measurements; they provide the constant
                # values used outside the measured time range.
                seed_df = seed_df.sort_values('time', kind='mergesort')
                losses = seed_df['val_loss_adjusted']
                resample = interp1d(
                    seed_df['time'],
                    losses,
                    kind='linear',
                    bounds_error=False,
                    fill_value=(losses.iloc[0], losses.iloc[-1]),
                )
                per_seed_curves.append(
                    pd.DataFrame({'t': time_grid, 'v': resample(time_grid)})
                )
            curve = pd.concat(per_seed_curves, ignore_index=True)

            # Aggregate across seeds at every grid point.
            by_time = curve.groupby('t')['v']
            median_curve = by_time.median()
            lower_curve = by_time.quantile(0.1)
            upper_curve = by_time.quantile(0.9)

            plt.semilogy(median_curve.index, median_curve.values,
                         label=scheme_label, **style)
            plt.fill_between(median_curve.index, lower_curve, upper_curve,
                             alpha=0.3, color=style['color'])
        else:
            # Aggregate rows that share the same iteration index.
            grouped = df_scheme.groupby('i_iter')
            median_curve = grouped['time'].median().to_frame(name='time')
            median_curve['val_loss_adjusted'] = grouped['val_loss_adjusted'].median()

            lower_curve = grouped['val_loss_adjusted'].quantile(0.1)
            upper_curve = grouped['val_loss_adjusted'].quantile(0.9)

            plt.semilogy(median_curve['time'], median_curve['val_loss_adjusted'],
                         label=scheme_label, **style)
            plt.fill_between(median_curve['time'], lower_curve, upper_curve,
                             alpha=0.3, color=style['color'])

    plt.xlabel('Running time (s)', fontsize=20)
    plt.ylabel('Test loss', fontsize=20)

    # Dataset-specific x-range; ticks and their font size are set together.
    if dataset == '20news':
        plt.xlim(-0.2, 20)
        plt.xticks(range(0, 21, 5), fontsize=20)
    else:
        plt.xlim(-0.1, 10)
        plt.xticks(range(0, 11, 5), fontsize=20)

    plt.yticks(fontsize=20)
    plt.yscale('log')

    plt.savefig(f'lr{dataset}.pdf', dpi=300, bbox_inches='tight')
    plt.close()

# Render the legend on its own in a separate, axis-free figure and save it as
# a PDF. Entries follow the order of `included_schemes`.
legend_font = FontProperties(weight='heavy', size=20)
plt.figure(figsize=(4, 1))

# Proxy lines carrying each scheme's colour and line style; they hold no data
# and exist only to populate the legend.
handles = [
    plt.Line2D(
        [], [],
        color=style_dict[scheme]['color'],
        linewidth=7,
        linestyle=style_dict[scheme]['linestyle'],
        label=legend_names.get(scheme, scheme),
    )
    for scheme in included_schemes
    if scheme in style_dict
]

# The text size comes from `legend_font`; matplotlib ignores `fontsize` when
# `prop` is supplied, so only `prop` is passed.
plt.legend(handles=handles, ncol=1, prop=legend_font, borderpad=1,
           handlelength=2, handletextpad=2)

plt.axis('off')

plt.savefig('lrlegend1.pdf', dpi=300, bbox_inches='tight')
# Release the figure once it is written, as the per-dataset plots do.
plt.close()