import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# compute denominator
def number_of_trees(max_depth, num_features):
    """Evaluate the tree-count recurrence used as the ratio's denominator.

    The recurrence is::

        T(d, p) = 2                            if d <= 0
        T(d, p) = 2 + p * T(d - 1, p - 1) ** 2  otherwise

    Args:
        max_depth: Depth ``d`` of the recurrence; any value <= 0 yields 2.
        num_features: Feature count ``p`` used as the multiplier at the top
            level; each level below uses one fewer.

    Returns:
        The value of ``T(max_depth, num_features)``.
    """
    # Walk from the top level downwards, recording the multiplier (the
    # feature count) that applies at each level.
    multipliers = []
    remaining_depth, features_left = max_depth, num_features
    while remaining_depth > 0:
        multipliers.append(features_left)
        remaining_depth -= 1
        features_left -= 1

    # Fold the recurrence bottom-up, starting from the depth-0 base value.
    count = 2
    for multiplier in reversed(multipliers):
        count = 2 + multiplier * count ** 2
    return count

# Load the pickled numerators; `res` is later indexed as res[dataset][depth_id].
# NOTE: pickle.load can execute arbitrary code while deserialising, so only
# point this at a 'tree_farms_sets' file from a trusted source.
with open('tree_farms_sets', mode='rb') as results_file:
    res = pickle.load(results_file)

# Compute the number of features for each dataset. The keys of `res` are
# passed straight to pd.read_csv, so they must be CSV paths (or anything else
# read_csv accepts).
num_features = {}

for data in res.keys():
    # Only the column count is needed, so read just the header row instead of
    # parsing the whole file.
    df = pd.read_csv(data, nrows=0)
    # One column is excluded from the feature count (presumably the label
    # column -- confirm against the datasets).
    num_features[data] = df.shape[1] - 1
    print(data, num_features[data])
    
    
# Plot larger figure: log of (numerator / number_of_trees) against tree depth
# for depths 1-7, one thin black line per dataset with a marker at each depth.
plt.figure()
depth_arr = [1, 2, 3, 4, 5, 6, 7]
for key in res.keys():
    # Opacity depends only on the dataset's feature count, so compute it once
    # per dataset. The linear map sends 10 features to 0.3 and 22 features to 1.
    alpha = (int(num_features[key]) - 10) / 12 * (1 - 0.3) + 0.3
    # Matplotlib raises ValueError for alpha outside [0, 1]; clamp so datasets
    # whose feature count falls far outside the 10..22 range still plot.
    alpha = min(1.0, max(0.0, alpha))

    y = []
    for depth_id, depth in enumerate(depth_arr):
        # res[key][depth_id] is the numerator stored for this depth.
        y_val = np.log(res[key][depth_id] / number_of_trees(depth, num_features[key]))
        y.append(y_val)

        # Marker colour encodes the depth (tab10 colour at index depth_id).
        plt.scatter(depth, y_val, s=100, alpha=alpha, zorder=2, color=mpl.cm.tab10(depth_id))
    plt.plot(depth_arr, y, alpha=alpha, c='k', linewidth=0.5)

plt.xlabel("Tree depth", size=18)
plt.ylabel("log Rashomon ratio, %", size=18)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.savefig('7.png', bbox_inches='tight')

# Plot smaller figure: same quantity as the larger figure, restricted to
# depths 1-3, with bigger markers and a narrower opacity range.
plt.figure()
depth_arr = [1, 2, 3]
for key in res.keys():
    # Opacity depends only on the dataset's feature count, so compute it once
    # per dataset. The linear map sends 10 features to 0.2 and 22 features to 0.8.
    alpha = (int(num_features[key]) - 10) / 12 * (0.8 - 0.2) + 0.2
    # Matplotlib raises ValueError for alpha outside [0, 1]; clamp so datasets
    # whose feature count falls far outside the 10..22 range still plot.
    alpha = min(1.0, max(0.0, alpha))

    y = []
    for depth_id, depth in enumerate(depth_arr):
        # res[key][depth_id] is the numerator stored for this depth.
        y_val = np.log(res[key][depth_id] / number_of_trees(depth, num_features[key]))
        y.append(y_val)

        # Marker colour encodes the depth (tab10 colour at index depth_id).
        plt.scatter(depth, y_val, s=200, alpha=alpha, zorder=2, color=mpl.cm.tab10(depth_id))
    plt.plot(depth_arr, y, alpha=alpha, c='k', linewidth=0.5)

plt.xlabel("Tree depth", size=18)
plt.ylabel("log Rashomon ratio, %", size=18)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
# Limit the x axis to at most three tick intervals, matching the three depths.
plt.locator_params(axis='x', nbins=3)
plt.savefig('3.png', bbox_inches='tight')