import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

def set_size(width, fraction=1, subplots=(3, 3)):
    """Set figure dimensions to avoid scaling in LaTeX.

    The figure width is ``fraction`` of the document width. The height follows
    from the golden ratio, scaled by the rows/columns ratio of ``subplots`` so
    that each subplot cell keeps a golden-ratio aspect (ignoring spacing).

    Parameters
    ----------
    width: float or string
            Document width in points, or the name of a predefined document
            type: ``'thesis'`` or ``'beamer'``.
    fraction: float, optional
            Fraction of the width which you wish the figure to occupy.
    subplots: array-like, optional
            The number of rows and columns of subplots.

    Returns
    -------
    fig_dim: tuple
            Dimensions of figure in inches, as ``(width, height)``.

    Raises
    ------
    ValueError
            If ``width`` is a string that is not a predefined document type.
    """
    # Text widths (in TeX points) of the predefined document types.
    preset_widths_pt = {'thesis': 426.79135, 'beamer': 307.28987}
    if isinstance(width, str):
        # Fail clearly on an unknown name. Otherwise the string would reach
        # the arithmetic below and raise an opaque TypeError.
        if width not in preset_widths_pt:
            raise ValueError(
                f"unknown document type {width!r}; expected one of "
                f"{sorted(preset_widths_pt)} or a width in points"
            )
        width_pt = preset_widths_pt[width]
    else:
        width_pt = width

    # Width of figure (in pts)
    fig_width_pt = width_pt * fraction
    # Convert from pt to inches (TeX uses 72.27 pt per inch)
    inches_per_pt = 1 / 72.27

    # Golden ratio to set aesthetic figure height
    # https://disq.us/p/2940ij3
    golden_ratio = (5**.5 - 1) / 2

    # Figure width in inches
    fig_width_in = fig_width_pt * inches_per_pt
    # Figure height in inches: golden-ratio cells, rows/cols of the grid
    fig_height_in = fig_width_in * golden_ratio * (subplots[0] / subplots[1])

    return (fig_width_in, fig_height_in)

# LaTeX labels per nuisance model: (correctly specified, misspecified).
# Raw strings keep the backslashes literal; sequences such as "\m" or "\o"
# are not valid Python escapes and warn in non-raw literals.
_NUISANCE_LABELS = {
    'mu': (r'$\mu^\star$', r'$\overline{\mu}$'),
    'gamma': (r'$\gamma^\star$', r'$\overline{\gamma}$'),
    'pi': (r'$\pi^\star$', r'$\overline{\pi}$'),
    'eta': (r'$\eta^\star$', r'$\overline{\eta}$'),
}


def _root_mean_square(values):
    """Return ``sqrt(mean(values ** 2))``."""
    return np.sqrt(np.mean(np.square(values)))


def _scenario_labels(frame, keys):
    """Join the *keys* columns of *frame* row-wise into ``'a,b,c'`` strings."""
    parts = [frame[key].astype(str) for key in keys]
    return parts[0].str.cat(parts[1:], sep=',')


def _match_y_axis(ax, reference):
    """Give *ax* the y ticks and limits of *reference*, hiding its y labels."""
    ax.set_ylabel('')
    ax.set_yticklabels('')
    ax.set_yticks(reference.get_yticks())
    ax.set_ylim(reference.get_ylim())


def _plot_metric_pair(axd, metric, data23, data24, plot, *, ylabel,
                      titles=('', ''), xlabel='Scenario',
                      rotate_xticklabels=True, ylim=None, **plot_kwargs):
    """Plot *metric* per scenario for Assumption 2.3 (left) and 2.4 (right).

    Parameters
    ----------
    axd: dict
            Axes keyed by mosaic label; must contain ``f'{metric}3'`` and
            ``f'{metric}4'``.
    metric: str
            Column of *data23* / *data24* drawn on the y axis.
    data23, data24: pandas.DataFrame
            Replications under Assumptions 2.3 and 2.4, each with a
            ``'scenario'`` column used for the x axis.
    plot: callable
            Seaborn plotting function, e.g. ``sns.boxplot`` or ``sns.barplot``.
    ylabel: str
            Y-axis label of the left panel (the right panel gets none).
    titles: tuple of str, optional
            Titles of the left and right panels.
    xlabel: str, optional
            X-axis label of both panels.
    rotate_xticklabels: bool, optional
            If true, show scenario names rotated 90 degrees; else hide them.
    ylim: tuple, optional
            Y limits of the left panel; ``None`` keeps the automatic limits.
    **plot_kwargs
            Forwarded to *plot*.

    The right panel copies the left panel's y ticks and limits, so both
    panels are drawn on the same scale.
    """
    left, right = axd[f'{metric}3'], axd[f'{metric}4']
    for ax, frame, title in ((left, data23, titles[0]),
                             (right, data24, titles[1])):
        plot(data=frame, x='scenario', y=metric, ax=ax, **plot_kwargs)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        if rotate_xticklabels:
            ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
        else:
            ax.set_xticklabels('')

    left.set_ylabel(ylabel)
    if ylim is not None:
        left.set_ylim(*ylim)
    left.tick_params(pad=-3)

    # Read the left panel's ticks only after its limits are final.
    _match_y_axis(right, left)
    right.tick_params(pad=-3)


if __name__ == '__main__':

    n = 500
    att = 5
    # With n == 2000 the coverage panels go into a separate figure and the
    # bias/RMSE axes use different limits.
    large_sample = n == 2000

    data = pd.read_csv(f'sim_results_{n}.csv')
    data['bias'] = data['estimate'] - att
    # Per-replication absolute error. For a single draw this equals the root
    # squared error, which is what the "RMSE" boxplots show per replication.
    data['rmse'] = np.abs(data['bias'])

    # Keep eta-augmented, non-oracle runs. .copy() gives an independent frame
    # so the column assignments below do not act on a slice.
    data = data[data['augmentation_eta'].eq(True)
                & data['oracle'].eq(False)].copy()

    # Replace each nuisance flag with a LaTeX label: $\mu^\star$ when the model
    # is correct (True), $\overline{\mu}$ otherwise. map() builds a new column
    # instead of writing strings into a bool column in place.
    # NOTE(review): assumes these columns are parsed as booleans — confirm.
    for column, (correct, misspecified) in _NUISANCE_LABELS.items():
        data[column] = data[column].map({True: correct, False: misspecified})

    keys23 = ['mu', 'gamma', 'pi']
    keys24 = keys23 + ['eta']
    # Tolerant comparison: 'ass' comes from CSV-parsed floats.
    data23 = data[np.isclose(data['ass'], 2.3)].copy()
    data24 = data[np.isclose(data['ass'], 2.4)].copy()

    # Per-scenario summaries (not plotted; available for inspection).
    bias23 = data23.groupby(keys23)['bias'].mean().abs()
    bias24 = data24.groupby(keys24)['bias'].mean().abs()

    # Root mean squared error over replications. Averaging the absolute
    # error alone would give the MAE, not the RMSE.
    rmse23 = data23.groupby(keys23)['bias'].agg(_root_mean_square)
    rmse24 = data24.groupby(keys24)['bias'].agg(_root_mean_square)

    coverage23 = data23.groupby(keys23)['coverage'].mean()
    coverage24 = data24.groupby(keys24)['coverage'].mean()

    data23['scenario'] = _scenario_labels(data23, keys23)
    data24['scenario'] = _scenario_labels(data24, keys24)

    sns.set_theme(style="whitegrid", palette="pastel", font_scale=0.6)
    width = 396  # document width in points, passed to set_size
    spacing = {'wspace': 0.05, 'hspace': 0.05}

    #
    # Performance boxplots
    #

    mosaic = [['bias3', 'bias4'], ['rmse3', 'rmse4']]
    if not large_sample:
        mosaic.append(['coverage3', 'coverage4'])
    axd = plt.figure(
        figsize=set_size(width, subplots=(len(mosaic), 2))
    ).subplot_mosaic(mosaic, gridspec_kw=spacing)

    _plot_metric_pair(
        axd, 'bias', data23, data24, sns.boxplot, showfliers=False,
        ylabel='Bias', titles=('Assumption 2.3', 'Assumption 2.4'),
        xlabel='', rotate_xticklabels=False,
        ylim=(-10, 8) if large_sample else (-14, 10),
    )
    _plot_metric_pair(
        axd, 'rmse', data23, data24, sns.boxplot, showfliers=False,
        ylabel='RMSE', rotate_xticklabels=large_sample,
        ylim=None if large_sample else (-0.1, 13),
    )
    if not large_sample:
        _plot_metric_pair(
            axd, 'coverage', data23, data24, sns.barplot,
            ylabel='Coverage Probability', ylim=(0, 1.05),
        )
    plt.savefig(f'performance_boxplots_{n}.pdf', bbox_inches='tight')

    if large_sample:
        axd2 = plt.figure(
            figsize=set_size(width, subplots=(1, 2))
        ).subplot_mosaic([['coverage3', 'coverage4']], gridspec_kw=spacing)

        _plot_metric_pair(
            axd2, 'coverage', data23, data24, sns.barplot,
            ylabel='Coverage Probability',
            titles=('Assumption 2.3', 'Assumption 2.4'), ylim=(0, 1.05),
        )
        plt.savefig(f'performance_boxplots_{n}_coverage.pdf',
                    bbox_inches='tight')


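# NOTE: run this against the raw boolean columns. After the LaTeX relabelling in __main__, mu/gamma/pi/eta hold strings, so these == True / == False tests match no rows.
# a = data.loc[(data['ass'] == 2.3) &(data['gamma'] == True) & (data['pi'] == False) & (data['mu'] == False) &  (data['eta'] == False), :]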

# plt.hist(a['bias'], bins=30)
# plt.title('Assumption 2.3')
# plt.xlabel('Estimate')
# plt.ylabel('Frequency')
# plt.show()

# # drop draws whose bias exceeds the 0.90 quantile (upper tail only)
# q_high = a['bias'].quantile(0.90)
# plt.hist(a[a['bias'] <= q_high]['bias'], bins=30)
# plt.title('Assumption 2.3')
# plt.xlabel('Estimate')
# plt.ylabel('Frequency')
# plt.show()

# a[a['bias'] <= q_high]['bias'].describe()