import pandas as pd
import numpy as np
import pickle 

# Dataset identifiers. Each identifier is also the suffix of the pickled
# results file loaded below ('./noise_trees_<identifier>').
data_list = [
    'car_evaluation',
    'monks2',
    'monks1',
    'monks3',
    'bar7',
    'compas',
    'fico',
    'bcw_bin',
    'carryout_takeaway',
    'restaurant_20',
    'bar',
    'coffee_house',
]

# Display labels, matched to data_list by position.
data_names = [
    'Car Eval',
    'Monks 2',
    'Monks 1',
    'Monks 3',
    'Bar-7',
    'Compas',
    'FICO',
    'BCW',
    'Takeaway',
    'Restaurant',
    'Bar',
    'Coffee',
]


def _load_noise_results(dataset):
    """Unpickle and return the object stored in './noise_trees_<dataset>'.

    NOTE: pickle.load can execute arbitrary code; only load trusted files.
    """
    with open('./noise_trees_' + dataset, 'rb') as handle:
        return pickle.load(handle)


# Map each dataset identifier to its unpickled results, in data_list order.
dic_res = {dataset: _load_noise_results(dataset) for dataset in data_list}
    
# Plot each metric in its own figure, with datasets split across two stacked panels.

import matplotlib as mpl
import matplotlib.pyplot as plt

titles = ['Number of trees', "Number of patterns", "Diversity"]

# Result index used only to split datasets between the two panels. An earlier,
# commented-out title for it was r"Average accuracy over $R_{set}$", so it
# presumably holds accuracies -- confirm against the code that wrote the pickles.
SPLIT_ITEM = 3
# Datasets whose mean of row 0 of item SPLIT_ITEM exceeds this go in the top panel.
SPLIT_THRESHOLD = 0.7

# x-axis values: row i of each result array is plotted at rho_array[i]
# (assumes the experiments used this same 40-point grid -- TODO confirm).
rho_array = np.linspace(0, 0.25, 40)

# data_names labels data_list by position; a length mismatch would mislabel
# or silently drop curves.
if len(data_list) != len(data_names):
    raise ValueError("data_list and data_names must have the same length")


def _valid_prefix(values):
    """Return the rows of a 2-D array before the first row containing -1.

    -1 is treated as a "no result" marker, so plotting stops at the first
    noise level where any entry is -1. Arrays without -1 are returned whole.
    """
    sentinel_hits = np.argwhere(values == -1)
    # argwhere lists hits in row-major order, so the first hit has the smallest row.
    stop = sentinel_hits[0][0] if len(sentinel_hits) > 0 else values.shape[0]
    return values[:stop]


for item_id, title in enumerate(titles):
    # One figure per metric, with two vertically stacked panels.
    fig, ax = plt.subplots(2, 1, figsize=(6, 9))

    for file_id, (file, name) in enumerate(zip(data_list, data_names)):
        to_plot = _valid_prefix(dic_res[file][item_id])
        rho = rho_array[:to_plot.shape[0]]
        # Mean and spread across columns (presumably repeated runs -- confirm).
        mean = np.mean(to_plot, 1)
        std = np.std(to_plot, 1)

        if np.mean(dic_res[file][SPLIT_ITEM][0]) > SPLIT_THRESHOLD:
            panel = ax[0]
        else:
            panel = ax[1]

        color = mpl.cm.tab20(file_id)
        panel.plot(rho, mean, color=color, label=name)
        panel.fill_between(rho, mean - std, mean + std, color=color, alpha=0.1)

    for panel in ax:
        # Leave panels that received no dataset untouched.
        if not panel.lines:
            continue
        panel.tick_params(axis='both', labelsize=20)
        # Legends are drawn only on the last figure, once per used panel.
        if item_id == len(titles) - 1:
            panel.legend(loc=(1.05, 0.2), fontsize=20, labelspacing=0.1)

    # Shared y-label placed left of both panels; x-label on the bottom panel.
    fig.text(-0.05, 0.5, title, va='center', rotation='vertical', size=26)
    ax[1].set_xlabel(r"Label noise, $\rho$", size=26)

    fig.savefig(f'{item_id}_split.pdf', bbox_inches='tight', dpi=200)
    # Release the figure; otherwise every iteration's figure stays open.
    plt.close(fig)