import scienceplots
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def get_agg_stats(df, topic, bin_types=('mean', 'median')):
    """Aggregate years per (topic, model) and bin those aggregates per model.

    The caller's DataFrame is left untouched; all work is done on a filtered
    copy.

    Args:
        df: DataFrame containing a ``model`` column, a ``year`` column and the
            column named by ``topic``. Rows whose ``year`` equals the empty
            string are dropped and the remaining years are cast to ``int``.
        topic: Name of the column that is grouped together with ``model``.
        bin_types: Iterable of statistic names to bin on. Each entry ``b``
            must correspond to a ``{b}_year`` column of the aggregate
            (``'mean'`` and ``'median'`` by default).

    Returns:
        A tuple ``(agg_df, binned_df)``:

        * ``agg_df`` has one row per (topic, model) with the columns
          ``over_cutoff``, ``mean_year``, ``median_year``, ``std_year``,
          ``min_year``, ``max_year``, ``over_cutoff_perc`` and one
          ``{b}_year_bin`` column per entry of ``bin_types``.
        * ``binned_df`` maps each entry of ``bin_types`` to a DataFrame with
          one row per (model, ``{b}_year_bin``) holding ``count_in_bin``,
          the mean/median of ``std_year``, ``min_year`` and ``max_year``
          within the bin, and ``perc_in_bin``.
    """
    # Raw model identifiers -> display names; anything else passes through.
    model_display_names = {
        'gpt_3.5': 'GPT3.5',
        'gpt_4': 'GPT4',
        'LLAMA2_7B': 'Llama2-7B',
        'LLAMA2_13B': 'Llama2-13B',
        'LLAMA2_70B': 'Llama2-70B',
    }

    # Filter first, then take an explicit copy: the column assignments below
    # must neither leak into the caller's frame nor write to a slice view.
    df = df[df['year'] != ''].copy()
    df['model'] = df['model'].replace(model_display_names)
    df['year'] = df['year'].astype(int)
    # 1 when the year lies strictly after 2023, else 0.
    df['over_cutoff'] = (df['year'] > 2023).astype(int)

    # Named aggregation yields flat column names directly, so no positional
    # renaming of MultiIndex columns is needed.
    agg_df = (
        df.groupby([topic, 'model'])
        .agg(
            over_cutoff=('over_cutoff', 'sum'),
            mean_year=('year', 'mean'),
            median_year=('year', 'median'),
            std_year=('year', 'std'),
            min_year=('year', 'min'),
            max_year=('year', 'max'),
        )
        .reset_index()
    )
    # NOTE(review): the divisor 10 presumably is the number of samples per
    # (topic, model) group - confirm against the data-generation code.
    agg_df['over_cutoff_perc'] = agg_df['over_cutoff'] / 10

    binned_df = {}
    for b in bin_types:
        bin_col = f'{b}_year_bin'
        # Round to one decimal, then truncate to the integer year.
        agg_df[bin_col] = agg_df[f'{b}_year'].round(1).astype(int)

        binned = (
            agg_df.groupby(['model', bin_col])
            .agg(
                count_in_bin=(topic, 'count'),
                mean_std=('std_year', 'mean'),
                median_std=('std_year', 'median'),
                mean_min=('min_year', 'mean'),
                median_min=('min_year', 'median'),
                mean_max=('max_year', 'mean'),
                median_max=('max_year', 'median'),
            )
            .reset_index()
        )

        # The std columns stay float; everything else is reported as
        # (truncated) integers.
        int_cols = [bin_col, 'count_in_bin', 'mean_min', 'median_min',
                    'mean_max', 'median_max']
        binned = binned.astype({col: int for col in int_cols})

        # NOTE(review): the divisor 100 presumably is the number of topics
        # per model - confirm against the dataset.
        binned['perc_in_bin'] = binned['count_in_bin'] / 100
        binned_df[b] = binned

    return agg_df, binned_df

def plot_pdf(agg_df, avg_metric, title, xmin, xmax):
    """Plot one Gaussian-KDE density curve per model and show the figure.

    Args:
        agg_df: DataFrame with a ``model`` column and a
            ``{avg_metric}_year`` column.
        avg_metric: Prefix selecting the ``{avg_metric}_year`` column whose
            distribution is plotted.
        title: Title of the figure.
        xmin: Lower bound of the x axis and of the KDE evaluation grid.
        xmax: Upper bound of the x axis and of the KDE evaluation grid.

    Side effects:
        Applies the ``science`` and ``bright`` matplotlib styles globally
        (presumably supplied by the ``scienceplots`` import at the top of the
        file) and opens the figure with ``plt.show()``.
    """
    models = ['Llama2-7B', 'Llama2-13B', 'Llama2-70B', 'GPT3.5', 'GPT4']
    colors = ['#BBBBBB', '#66CCEE', '#4477AA', '#CCBB44', '#228833']
    column = f'{avg_metric}_year'

    plt.style.use(['science', 'bright'])

    fig, ax = plt.subplots(figsize=(8, 6))

    # +1 so both endpoints are included with one sample per whole unit.
    x_range = np.linspace(xmin, xmax, int(xmax - xmin) + 1)

    for model, color in zip(models, colors):
        values = agg_df.loc[agg_df['model'] == model, column]

        # A KDE needs at least two observations; skip models that are absent
        # from agg_df rather than aborting the whole figure.
        if len(values) < 2:
            continue

        kde = gaussian_kde(values)
        # Raw f-string keeps the LaTeX backslashes without invalid escapes.
        label = rf"{model} ($\mu$={values.mean():.0f}, $\sigma$={values.std():.0f})"
        ax.plot(x_range, kde.evaluate(x_range), linewidth=2, color=color, label=label)

    # Build the legend once, after every curve has been added.
    ax.legend()
    # Set the limits on the axes; assigning to ``plt.xlim`` would only
    # overwrite the pyplot function and leave the limits unchanged.
    ax.set_xlim(xmin, xmax)
    ax.set_title(title)
    fig.tight_layout()
    plt.show()