import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import glob
import os



if __name__ == '__main__':
    # Reads every per-dataset results CSV in bnn/out/, keeps the rows for the
    # methods listed in `method_mapping`, and draws one scatter point
    # (DAMV_MEAN) with a vertical bar (DAMV_MEAN +/- DAMV_STD) per
    # dataset/method pair. The figure is written to bnn/img/damv.png.

    dataset_order = ['Boston', 'Concrete', 'Energy', 'Kin8nm', 'Naval', 'Combined', 'Protein', 'Wine', 'Yacht', 'Year']
    n_datasets = len(dataset_order)
    dataset_position = {name: pos for pos, name in enumerate(dataset_order)}

    # Raw strings: the LaTeX backslashes (\log, \cdot, \sqrt) are not valid
    # Python escape sequences and must reach matplotlib untouched.
    method_mapping = {
        'SVGD': r'$k_2 = k_1$',
        'h-SVGD log s': r'$k_2 = \log(d) \cdot k_1$',
        'h-SVGD sqrt s': r'$k_2 = \sqrt{d} \cdot k_1$'
    }
    methods = list(method_mapping)
    # Horizontal offset of each method around its dataset tick: -1/6, 0, +1/6.
    method_offset = {m: (i - 1) / 6 for i, m in enumerate(methods)}
    # Error-bar colour per method, looked up by row rather than by position so
    # bars stay correctly coloured when a dataset or method is missing.
    method_colour = dict(zip(methods, ['blue', 'orange', 'green']))
    # Marker size per legend label.
    marker_sizes = dict(zip(method_mapping.values(), [80, 100, 120]))

    os.makedirs('bnn/img', exist_ok=True)
    path = 'bnn/out/*.csv'
    dfs = []
    for f in sorted(glob.glob(path)):
        # basename() instead of stripping a literal 'bnn/out/' prefix, because
        # glob returns OS-specific separators (backslashes on Windows).
        file_name = os.path.basename(f)
        dataset = file_name.replace('.csv', '').replace('.arff', '').capitalize()
        if dataset not in dataset_position:
            raise ValueError(
                'Results file {!r} maps to dataset {!r}, which is not in dataset_order'.format(f, dataset)
            )
        dfs.append(pd.read_csv(f).assign(dataset=dataset, dataset_index=dataset_position[dataset]))
    if not dfs:
        raise FileNotFoundError('No result files found matching {!r}'.format(path))

    # 'Unnamed: 0' is the index column pandas adds when a CSV was saved with
    # its index; tolerate files that were written without it.
    df = pd.concat(dfs, ignore_index=True).drop(columns=['Unnamed: 0'], errors='ignore')
    # Explicit copy so the column assignments below act on an independent
    # frame instead of a filtered view (avoids SettingWithCopyWarning).
    df = df[df.method.isin(methods)].copy()
    df['dataset_method'] = df['dataset_index'] + df['method'].map(method_offset)
    df['method_latex'] = df['method'].map(method_mapping)
    df = df.sort_values(by='dataset_method')

    # Create plot
    plt.rcParams.update({'text.usetex': True})
    sns.set_style('white')
    fig, ax = plt.subplots(figsize=(12, 3))

    # Add error bars spanning DAMV_MEAN -/+ DAMV_STD
    ax.vlines(
        x=df.dataset_method,
        ymin=df.DAMV_MEAN - df.DAMV_STD,
        ymax=df.DAMV_MEAN + df.DAMV_STD,
        linewidth=1.5,
        colors=df['method'].map(method_colour).tolist(),
        alpha=0.5
    )

    # Add mean DAMV for each dataset
    g = sns.scatterplot(
        data=df,
        x='dataset_method',
        y='DAMV_MEAN',
        hue='method_latex',
        style='method_latex',
        size='method_latex',
        sizes=marker_sizes,
        ax=ax
    )
    g.set_xticks(range(n_datasets))
    g.set_xticklabels(dataset_order)

    # Adjust axes and legend
    # Capture the auto-scaled y-limits first so the full-height separators
    # drawn between datasets do not stretch the axis.
    ylim = ax.get_ylim()
    ax.vlines(
        x=[x + 0.5 for x in range(n_datasets)],
        ymin=ylim[0],
        ymax=ylim[1],
        linewidth=1,
        colors='grey',
        alpha=0.5
    )
    ax.set_ylim(ylim)
    ax.set_xlim(-0.5, n_datasets - 0.5)
    ax.set_xlabel('')
    ax.set_ylabel('DAMV')
    # Rebuilding the legend replaces the one seaborn created (dropping its
    # title) and shrinks the legend markers.
    ax.legend(markerscale=0.8)

    # Save figure
    plt.savefig('bnn/img/damv.png', bbox_inches='tight', dpi=300)