#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import DecisionBoundaryDisplay
from python.pred_bakeoff_settings import estimator, max_depth, fname

import pandas as pd
import numpy as np
from sklearn import tree
import matplotlib.pyplot as plt
from matplotlib import colormaps 
from time import time
from tqdm import tqdm

# Load the results table produced for the run named by `fname`.
results_path = 'sim_out/' + fname + '.csv'
resdf = pd.read_csv(results_path)

# Distinct dataset and transform labels, each in sorted order.
datasets = sorted(set(resdf['Dataset']))
tt = sorted(set(resdf['Trans']))

# Largest fold id plus one (presumably fold ids are 0-based -- confirm
# against the code that writes the CSV).
folds = max(resdf['Fold']) + 1

# Half the number of datasets, rounded up.
ncol = (len(datasets) + 1) // 2

# Display labels for the transform codes.
pretty_names = {
    'asm': 'TBAS',
    'id': 'Id',
    'pca': 'PCA',
    'rand': 'Rand',
}

# One box-plot panel per dataset on a 2 x ncol grid; each panel shows the
# RMSE values of every transform in `tt`, side by side.
fig = plt.figure(figsize=[8, 3])

# The tick positions and labels are the same for every panel.
tick_positions = np.arange(1, len(tt) + 1)
tick_labels = [pretty_names[t] for t in tt]

for panel, ds_name in enumerate(datasets, start=1):
    plt.subplot(2, ncol, panel)
    ds_rows = resdf[resdf['Dataset'] == ds_name]
    rmse_by_trans = [ds_rows.loc[ds_rows['Trans'] == t, 'RMSE'] for t in tt]
    plt.boxplot(rmse_by_trans)
    plt.xticks(tick_positions, tick_labels, fontdict={'size': 7})
    plt.title(ds_name)

plt.tight_layout()
plt.savefig(fname + ".pdf")
plt.close()

# Build a LaTeX table of mean RMSE with one row per transform and one column
# per dataset. In each column the smallest mean is bolded, together with
# every entry whose interval (mean +/- 2 * sd / sqrt(folds)) overlaps the
# interval of that smallest entry.
mean = pd.pivot_table(resdf, values='RMSE', index='Trans', columns='Dataset', aggfunc='mean')
sd = pd.pivot_table(resdf, values='RMSE', index='Trans', columns='Dataset', aggfunc='std')

# pivot_table drops all-NaN rows/columns by default (e.g. the std of a single
# observation), so realign sd with mean before indexing by position below.
sd = sd.reindex(index=mean.index, columns=mean.columns)

lb = mean - 2*sd/np.sqrt(folds)
ub = mean + 2*sd/np.sqrt(folds)

# Keep the numeric means for locating the minimum: taking argmin after the
# conversion to strings either fails or compares lexicographically.
mean_num = mean
mean = np.round(mean_num, 3).astype(str)

isbold = np.zeros(mean.shape, dtype=bool)
for i, v in enumerate(mean_num.columns):
    # Row position of the smallest mean in this column, ignoring missing cells.
    whomin = int(np.nanargmin(mean_num[v].to_numpy()))

    # Two intervals overlap when each one starts before the other one ends.
    overlap = np.logical_and(ub.iloc[whomin, i] >= lb.iloc[:, i],
                             lb.iloc[whomin, i] <= ub.iloc[:, i])
    isbold[:, i] = overlap.to_numpy()
    # The minimum itself is always bold, even when its interval is NaN.
    isbold[whomin, i] = True

    for j in np.flatnonzero(isbold[:, i]):
        mean.iloc[j, i] = "\\textbf{" + mean.iloc[j, i] + "}"

with open("tables/"+fname+".tex", 'w') as f:
    # escape=False keeps the \textbf{...} markup intact on pandas versions
    # where LaTeX escaping is enabled by default.
    mean.to_latex(f, escape=False)
