# Plots the number of trees, the number of patterns, and the diversity for different values of the Rashomon parameter.
# Run this script after the characteristics have been computed; it expects the computed result files to be in the current folder.

import pandas as pd
import numpy as np
import pickle 
import matplotlib as mpl
import matplotlib.pyplot as plt

# Rashomon parameter values whose precomputed results are plotted.
theta_array = [0.03, 0.05, 0.07, 0.09, 0.11]


def _load_result(theta):
    """Unpickle and return the result stored in ./monks3_<theta>.

    pickle can execute arbitrary code while loading; these files are expected
    to be the locally computed results, never untrusted external input.
    """
    path = './monks3_' + str(theta)
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Maps each theta value to its unpickled result object.
dic_res = {theta: _load_result(theta) for theta in theta_array}

titles = ['Number of trees', "Number of patterns", "Diversity"]

# Label-noise grid that the result rows are plotted against (row i <-> rho_array[i]).
# NOTE(review): assumes the characteristics were computed on this same 20-point
# grid over [0, 0.25]; if a result has more rows, the x/y lengths below disagree.
rho_array = np.linspace(0, 0.25, 20)


def _valid_prefix(values):
    """Return the rows of *values* before the first row that contains -1.

    -1 is treated as a sentinel; presumably it marks noise levels for which no
    result was computed (TODO confirm against the computing script). When no
    -1 is present, *values* is returned unchanged.
    """
    sentinel_positions = np.argwhere(values == -1)
    if len(sentinel_positions) == 0:
        return values
    # argwhere lists positions in row-major order, so [0][0] is the first
    # row index holding a sentinel.
    return values[:sentinel_positions[0][0]]


for item_id, title in enumerate(titles):
    plt.figure()
    for theta in theta_array:
        # Each result is indexed by item_id; each item holds one row per rho
        # value, and the columns are averaged over (axis 1).
        to_plot = _valid_prefix(dic_res[theta][item_id])
        n_rows = to_plot.shape[0]

        mean = np.mean(to_plot, 1)
        std = np.std(to_plot, 1)
        plt.plot(rho_array[:n_rows], mean, label=r'$\theta=$' + str(theta))
        # Shaded band of +/- one standard deviation around the mean.
        plt.fill_between(rho_array[:n_rows], mean - std, mean + std, alpha=0.1)

    # Only the last figure gets a legend; drawing it once after all curves are
    # plotted gives the same result as re-issuing it on every iteration.
    if item_id == len(titles) - 1:
        plt.legend(loc=(1.05, 0), fontsize=16, labelspacing=0.1)

    plt.ylabel(title, size=20)
    plt.xlabel(r"Label noise, $\rho$", size=20)

    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)

    plt.savefig(str(item_id) + '_combined.png', bbox_inches='tight')
    # Release the figure; otherwise every figure stays open until the script exits.
    plt.close()