import json
import os
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt

# Directory holding the JSON statistics files and the directory receiving one
# sub-directory of PDF plots per statistics file. Both are resolved relative
# to the working directory the script is started from; the script no longer
# changes the working directory while it runs.
RESULT_DIR = "../result_RF"
PLOT_DIR = "../plot_rf"

fichiers = os.listdir(RESULT_DIR)
# To process only some files, replace the listing above with an explicit list
# of file names, e.g. ["placement_RF_stat.json"].

# NOTE(review): never read or written anywhere in this script - confirm it can go.
global_result = {}

# When False, records whose size column equals 0 are left out of the series
# (presumably runs that hit the timeout - confirm against the producer).
print_timeout = False

# Collects one (file name, data["acc"]) pair per statistics file that has an
# "acc" entry.
list_acc = []

# Column of each per-method record that holds the reason size; column 1 is
# used as the time.
ind_wanted = 0

# Matplotlib format strings, one per series of the general plot.
colors = ['r.', 'gx', 'b*', 'yd', 'm+', 'kX', 'g.', 'bv', 'r<']

# JSON key of each per-method record list -> name of the series in the plots.
METHOD_LABELS = {
    "minimal majoritary": "Minimal majoritary",
    "direct": "Direct",
    "sufficient": "Sufficient",
    "majoritary": "Majoritary",
    "lime": "Lime",
}

# JSON key of each fixed-budget series -> time (in seconds) plotted for it.
BUDGET_SECONDS = {"10s": 10, "60s": 60, "600s": 600}


def _split_records(records, filter_col, size_col, keep_all):
    """Return parallel (sizes, times) lists extracted from *records*.

    A record is kept when *keep_all* is true or its *filter_col* entry is
    non-zero. The size is read from *size_col* and the time from column 1.
    """
    sizes = []
    times = []
    for record in records:
        if keep_all or record[filter_col] != 0:
            sizes.append(record[size_col])
            times.append(record[1])
    return sizes, times


def _diagonal_ticks(x_sizes, y_sizes):
    """Return the integer ticks shared by both axes of a comparison plot."""
    mini = int(min(np.min(x_sizes), np.min(y_sizes))) - 1
    maxi = int(max(np.max(x_sizes), np.max(y_sizes))) + 3
    # Roughly ten ticks, but never a zero step.
    step = max(1, (maxi - mini) // 10)
    return list(range(mini, max(2, maxi), step))


def _plot_size_comparison(len_values, x_key, y_key, x_label, y_label, out_path):
    """Scatter one size series against another, with the y = x diagonal.

    The plot is skipped (with a message) instead of raising when a series is
    missing, empty, or when the two series have different lengths.
    """
    x_sizes = len_values.get(x_key)
    y_sizes = len_values.get(y_key)
    if (x_sizes is None or y_sizes is None or len(x_sizes) == 0
            or len(x_sizes) != len(y_sizes)):
        tqdm.write(
            f"{out_path}: skipped ({x_key!r} and {y_key!r} series are "
            "missing, empty or of different lengths)"
        )
        return
    # NOTE(review): each series is filtered independently when print_timeout
    # is False, so points can be misaligned even if the lengths match - confirm
    # that records are index-aligned across methods before relying on this plot.
    fig = plt.figure()
    plt.plot(x_sizes, y_sizes, "r+")
    values = _diagonal_ticks(x_sizes, y_sizes)
    plt.plot(values, values)
    plt.xticks(values)
    plt.yticks(values)
    plt.grid()
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.savefig(out_path)
    # Figures are not released automatically; close them to keep memory flat
    # across the whole batch.
    plt.close(fig)


for f in tqdm(fichiers):

    with open(os.path.join(RESULT_DIR, f), encoding="utf-8") as json_data:
        data_dict = json.load(json_data)

    # Plots of "<stem>.<ext>" go to PLOT_DIR/<stem>/ (created on demand).
    stem = f.split(".")[0]
    out_dir = os.path.join(PLOT_DIR, stem)
    os.makedirs(out_dir, exist_ok=True)

    # Dataset name taken from the file name; not used by the plots below
    # (they carry no title).
    dataset = stem.split("_")[0]
    if dataset == "ad":
        dataset = "ad_data"

    # Series name -> list of sizes / list of times, in the order the keys
    # appear in the JSON file (this order drives colors and legend order).
    len_values = {}
    time_values = {}

    for k, payload in data_dict.items():
        if k == "acc":
            list_acc.append((f, payload))
        elif k in BUDGET_SECONDS:
            # Fixed-budget series: the payload is used as the sizes and every
            # point is plotted at the budget itself.
            len_values[k] = payload
            time_values[k] = [BUDGET_SECONDS[k]] * len(payload)
        elif k in METHOD_LABELS:
            # "lime" always reads its size from column 0; the other methods
            # read it from column ind_wanted.
            size_col = 0 if k == "lime" else ind_wanted
            sizes, times = _split_records(
                payload, ind_wanted, size_col, print_timeout
            )
            len_values[METHOD_LABELS[k]] = sizes
            time_values[METHOD_LABELS[k]] = times

    # Global plot: size against time (log x axis) for every per-method series
    # except "Minimal majoritary".
    fig = plt.figure()
    labels = [
        k for k in len_values
        if k[-1] != "s" and k != "Minimal majoritary"
    ]
    for idx, k in enumerate(labels):
        # Cycle through the formats rather than running off the end.
        plt.semilogx(time_values[k], len_values[k], colors[idx % len(colors)])
    plt.xlabel("time (s)")
    plt.ylabel("Size of reason")
    plt.grid()
    plt.legend(labels)
    plt.savefig(os.path.join(out_dir, 'general_plot.pdf'))
    plt.close(fig)

    # Majoritary vs sufficient sizes.
    _plot_size_comparison(
        len_values, "Majoritary", "Sufficient",
        "Size of majoritary reason",
        "Size of sufficient reason",
        os.path.join(out_dir, 'MSR_VS_SR_Length.pdf'),
    )

    # 10 s series vs sufficient sizes, only when the file has a 10 s series.
    if "10s" in data_dict:
        _plot_size_comparison(
            len_values, "10s", "Sufficient",
            "Size of an approximation of a minimal majoritary reason (after 10s)",
            "Size of sufficient reason",
            os.path.join(out_dir, 'MSR_VS_10s_Length.pdf'),
        )

    # Box plot: per-method series other than Lime / Direct / Minimal
    # majoritary, followed by the fixed-budget series whose first value is
    # non-zero.
    labs = [
        k for k in len_values
        if k not in ("Lime", "Direct", "Minimal majoritary") and k[-1] != "s"
    ]
    labs += [
        k for k in len_values
        if k[-1] == "s" and len(len_values[k]) > 0 and len_values[k][0] != 0
    ]
    if labs:
        fig = plt.figure()
        plt.boxplot([len_values[key] for key in labs])
        # Boxes sit at positions 1..N; labelling them through xticks avoids
        # the boxplot keyword that was renamed in recent matplotlib releases.
        plt.xticks(list(range(1, len(labs) + 1)), labs)
        plt.grid()
        plt.ylabel("Size of reason")
        plt.savefig(os.path.join(out_dir, 'Boxplot.pdf'))
        plt.close(fig)
    else:
        tqdm.write(f"{out_dir}: Boxplot.pdf skipped (no series to show)")
        