#!/usr/bin/env python
# coding: utf-8

# **Packages**
import numpy as np
#from sklearn.tree import export_text
from tqdm import tqdm
import time
import sys

import my_tree as mt
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import LeaveOneGroupOut
#from sklearn.model_selection import cross_val_score # Cross validation 
import os
import functools
import operator
from random import shuffle
import copy
import my_forest as mf
from lime.lime_tabular import LimeTabularExplainer

import dask
from multiprocessing import freeze_support

sys.path.append('build')
#sys.path.append('/usr/lib64/boost169/')
import proto_suff
from timeout import timeout


import json

def data_lime(instance, X_train, rf) :
    """Explain ``instance`` with LIME and measure how many rules outweigh the negatives.

    Parameters
    ----------
    instance : indexable with ``len`` and ``reshape`` (used as a 1-D numpy array)
        The sample to explain.
    X_train : training data handed to ``LimeTabularExplainer``.
    rf : fitted classifier exposing ``classes_``, ``predict`` and ``predict_proba``.

    Returns
    -------
    (dict, int)
        * a dict mapping ``(feature_index, threshold, '+' | '-')`` to a weight;
          a two-sided rule contributes two entries carrying half the weight each;
        * the number of positive-weight entries, taken in decreasing weight
          order, that are consumed until their sum reaches the absolute sum of
          all negative LIME weights.
    """
    lime_explainer = LimeTabularExplainer(X_train, class_names=rf.classes_)
    # NOTE(review): a boolean is passed as ``top_labels`` (True only when the
    # predicted class equals 1) - confirm this is the intended LIME usage.
    predicted_is_one = int(rf.predict(instance.reshape(1,-1))[0]) == 1
    explanation = lime_explainer.explain_instance(
        instance,
        rf.predict_proba,
        num_features=len(instance),
        top_labels=predicted_is_one,
    )
    rules = explanation.as_list()

    # Total (negative) mass of the rules that argue against the explained label.
    worst_score = sum(float(rule[1]) for rule in rules if float(rule[1]) < 0)

    # Rule texts are split on single spaces: three tokens are read as
    # "<feature> <op> <value>" and five tokens as "<low> <op> <feature> <op> <high>".
    # assumes feature names are integer indices - TODO confirm against LIME output
    weights = {}
    for rule in rules :
        tokens = rule[0].split(' ')
        weight = float(rule[1])
        if weight == 0 :
            continue
        if len(tokens) == 3 :
            is_upper_bound = tokens[1][0] == '<'
            feature, threshold = int(tokens[0]), float(tokens[2])
            if is_upper_bound and instance[feature] <= threshold :
                weights[(feature, threshold, '-')] = weight
            elif instance[feature] >= threshold :
                weights[(feature, threshold, '+')] = weight
        elif len(tokens) == 5 :
            feature = int(tokens[2])
            if instance[feature] <= float(tokens[4]) and instance[feature] >= float(tokens[0]) :
                half = weight / 2
                weights[(feature, float(tokens[0]), '+')] = half
                weights[(feature, float(tokens[4]), '-')] = half

    # Walk the weights from largest to smallest, counting positive ones until
    # they compensate the negative mass.
    target = abs(worst_score)
    covered = 0
    needed = 0
    for weight in sorted(weights.values(), reverse=True) :
        if not covered < target :
            break
        if weight > 0 :
            covered += weight
            needed += 1
    return weights, needed

def fusion_dict(list_dict) :
    """Merge several records into one without modifying any of them.

    Values are combined key by key with ``+`` in the order of ``list_dict``:
    lists are therefore concatenated and numbers summed. Afterwards every
    value that is not a list is divided by ``len(list_dict)``, i.e. averaged.

    Parameters
    ----------
    list_dict : sequence of dict
        Records to merge. Every key of the later records must exist in the
        first one, otherwise ``KeyError`` is raised.

    Returns
    -------
    dict
        A new dict holding the merged values; ``{}`` when ``list_dict`` is empty.
    """
    if not list_dict :
        return {}

    # Shallow copy: the first record must not be used as the accumulator.
    merged = dict(list_dict[0])
    for other in list_dict[1:] :
        for key, value in other.items() :
            # ``a + b`` builds a new object, whereas ``a += b`` would extend the
            # caller's list in place.
            merged[key] = merged[key] + value

    count = len(list_dict)
    for key, value in merged.items() :
        if not isinstance(value, list) :
            merged[key] = value / count
    return merged
 

def rf_cross_validation(X, Y, nb_tree, cv=10, max_depth = None, groups = None, nb_forest = None) :
    """Train random forests on leave-one-group-out splits and report mean accuracy.

    Parameters
    ----------
    X, Y : indexable samples and labels of the same length.
    nb_tree : number of trees per forest (``n_estimators``).
    cv : number of groups created when ``groups`` is not supplied.
    max_depth : forwarded to ``RandomForestClassifier``.
    groups : optional group label per sample. When omitted, labels ``1..cv``
        are distributed as evenly as possible and shuffled.
    nb_forest : maximum number of splits to train on; defaults to ``cv``.

    Returns
    -------
    (score, forests, groups)
        ``score`` is the mean test accuracy in percent over the forests that
        were actually trained (``0`` when none was), ``forests`` is a list of
        ``(fitted_forest, train_indices, test_indices)`` tuples, and ``groups``
        is the partition that was used.
    """
    if nb_forest is None :
        nb_forest = cv
    if groups is None :
        per_fold, leftover = divmod(len(Y), cv)
        groups = [fold for fold in range(1, cv + 1) for _ in range(per_fold)]
        # The samples that do not fit evenly go to the first groups.
        groups += list(range(1, leftover + 1))
        shuffle(groups)

    splitter = LeaveOneGroupOut()
    total_accuracy = 0
    forests = []
    for ind_train, ind_test in splitter.split(X, Y, groups=groups) :
        if len(forests) >= nb_forest :
            # Quota reached: no need to enumerate the remaining splits.
            break
        x_train = [X[k] for k in ind_train]
        y_train = [Y[k] for k in ind_train]
        x_test = [X[k] for k in ind_test]
        y_test = [Y[k] for k in ind_test]
        rf = RandomForestClassifier(max_depth=max_depth, n_estimators=nb_tree)
        rf.fit(x_train, y_train)
        y_predict = rf.predict(x_test)
        total_accuracy += (np.sum(y_predict == y_test) / len(y_test)) * 100
        # A new estimator is created for every split, so it can be stored as is.
        forests.append((rf, ind_train, ind_test))

    # Average over what was really trained: there may be fewer splits than
    # ``nb_forest`` (few distinct groups), and possibly none at all.
    score = total_accuracy / len(forests) if forests else 0
    return score, forests, groups


#Set parameter for testing

class _WorkingDirectory :
    """Context manager that enters ``path`` and restores the previous cwd on exit."""

    def __init__(self, path) :
        self._path = path
        self._previous = None

    def __enter__(self) :
        self._previous = os.getcwd()
        os.chdir(self._path)
        return self

    def __exit__(self, exc_type, exc_value, traceback) :
        os.chdir(self._previous)
        return False


def work_on_forest(forest, d , l, acc, name = None, Lime = True) :
    """Compute and time several kinds of explanations for the given forests.

    For each forest, up to 25 test instances are explained with: a 10 s
    approximate minimal reason, a sufficient reason (600 s timeout), a direct
    reason, optionally LIME, and the ``proto_suff`` explainer.

    Parameters
    ----------
    forest : iterable of ``(fitted_forest, train_indices, test_indices)``
        Tuples shaped like the ones returned by ``rf_cross_validation``.
    d, l : samples and labels, indexed with the train/test indices.
    acc : value stored unchanged under ``record["acc"]``.
    name : str, required
        ``"<dataset>_<index>"``. Everything before the last underscore is the
        dataset name; it selects ``../cnf_files/<dataset>`` (relative to the
        current directory), which must already contain the ``<dataset>_rf``,
        ``<dataset>_instances``, ``<dataset>_gcnf`` and ``<dataset>_wcnf``
        sub-directories. ``name`` also prefixes the generated files.
    Lime : when false, LIME is skipped and ``record["lime"]`` stays at zero.

    Returns
    -------
    dict
        The flat record of the last forest processed (``{}`` when ``forest``
        is empty). Callers that need one record per forest should pass
        one-element lists, as ``main`` does.

    Raises
    ------
    ValueError
        If ``name`` is not given.

    Notes
    -----
    The working directory is changed while files are written and is restored
    before returning, including when an exception propagates.
    """
    if name is None :
        raise ValueError("work_on_forest needs a name of the form '<dataset>_<index>'")

    # Only the trailing "_<index>" is removed, so dataset names may themselves
    # contain underscores.
    dataset = name.rsplit('_', 1)[0]

    record = {}

    with _WorkingDirectory(f"../cnf_files/{dataset}") :
        for f in forest :
            rf, ind_train, ind_test = f[0], f[1], f[2]

            record = {}

            # Convert every scikit-learn tree into the project's representation.
            my_forest = []
            for estimator in rf.estimators_ :
                my_tree = mt.decision_tree()
                my_tree.from_DecisionTreeClassifier(estimator)
                my_forest.append(my_tree)
            my_forest = mf.decision_forest(my_forest)
            my_forest.labels = rf.classes_

            X_train = np.array([d[i] for i in ind_train])
            x_test = np.array([d[i] for i in ind_test])
            y_test = np.array([l[i] for i in ind_test])

            # At most 25 test instances are explained per forest.
            nb_echantillon = min(25, len(ind_test))

            # tps_base is the one-off cost of building the explainer; it is
            # added to every per-instance timing of that explainer below.
            with _WorkingDirectory(f"{dataset}_rf") :
                tps_base = -time.time()
                explainer = proto_suff.MinimumMajoritaryExplainer()
                fc = my_forest.compileForest(0, write_file=True, name_file=f"{name}_rf.txt")
            for tree in fc :
                explainer.addTree(tree)
            tps_base += time.time()

            # Each row is [explanation length, elapsed seconds].
            len_suff_reason = np.zeros((nb_echantillon, 2))
            len_dir_reason = np.zeros((nb_echantillon, 2))
            len_suff_proto_reason = np.zeros((nb_echantillon, 2))
            len_lime_s = np.zeros((nb_echantillon, 2))
            len_10s = np.zeros(nb_echantillon)
            # The 60 s and 600 s approximations are currently disabled; their
            # entries are kept (all zeros) so the record layout does not change.
            len_60s = np.zeros(nb_echantillon)
            len_600s = np.zeros(nb_echantillon)

            list_reason = []
            list_instance = []
            list_ok = []

            instances_path = os.path.join(f"{dataset}_instances", f"{name}_instances.txt")
            with open(instances_path, "w") as instances_file :
                for j in tqdm(range(nb_echantillon)) :
                    # Two-digit instance suffix: "-00", "-01", ...
                    name_f = f"{name}-{j:02d}"

                    with _WorkingDirectory(f"{dataset}_wcnf") :
                        approx10s_reason = my_forest.find_approx_proto_min_reason(x_test[j], time_out = 10, name = name_f, writing_mode = True , compute=True)
                    len_10s[j] = len(approx10s_reason)

                    well_classified = bool(my_forest.predict(x_test[j])[0] == y_test[j])
                    label = y_test[j]
                    if isinstance(label, np.generic) :
                        # numpy scalars are turned into built-ins so that the
                        # record can be serialised to JSON.
                        label = label.item()
                    list_ok.append([well_classified, label])

                    list_instance.append(x_test[j].tolist())

                    # On timeout the reason stays empty and its row stays at zero.
                    sufficient_reason = []
                    with _WorkingDirectory(f"{dataset}_gcnf") :
                        try :
                            tps = -time.time()
                            time_out = 600
                            with timeout(seconds=time_out) :
                                sufficient_reason = my_forest.find_sufficient_reason(x_test[j], name = name_f, compute=True)
                            tps += time.time()
                            len_suff_reason[j] = [len(sufficient_reason), tps]
                        except TimeoutError :
                            print("Time out MUS")

                    tps = -time.time()
                    direct_reason = my_forest.find_direct_reason(x_test[j])
                    tps += time.time()
                    len_dir_reason[j] = [len(direct_reason), tps]

                    if Lime :
                        tps = -time.time()
                        len_lime = data_lime(x_test[j], X_train, rf)[1]
                        tps += time.time()
                        len_lime_s[j] = [len_lime, tps]

                    tps = -time.time()
                    x_bin = my_forest.binarized_instance(x_test[j])
                    # One binarised instance per line, values separated by spaces.
                    instances_file.write(str(x_bin)[1:-1].replace(",","") + "\n")
                    suff_proto_reason = explainer.explain(x_bin)
                    tps += time.time() + tps_base
                    len_suff_proto_reason[j] = [len(suff_proto_reason), tps]

                    list_reason.append([direct_reason, sufficient_reason, suff_proto_reason, approx10s_reason])

            record["acc"] = acc
            record["instance"] = list_instance
            record["classified"] = list_ok
            record["len_bin"] = [len(my_forest.bina.keys()), x_test.shape[1]]

            record["lime"] = len_lime_s.tolist()

            record["sufficient"] = len_suff_reason.tolist()
            record["direct"] = len_dir_reason.tolist()
            record["majoritary"] = len_suff_proto_reason.tolist()
            record["10s"] = len_10s.tolist()
            record["60s"] = len_60s.tolist()
            record["600s"] = len_600s.tolist()

            record["reason"] = list_reason

            record["hashmap"] = [str(my_forest.bina)]

    return record

def main() :
    """Run the explanation benchmark for every dataset named on the command line.

    Each argument ``<dataset>`` is read from ``../dataset/<dataset>.csv``; the
    last column holds the labels. The number of trees comes from
    ``info_data_RF.json`` in ``../script``. For every dataset, forests are
    trained by cross-validation, explained through ``work_on_forest`` (one
    dask task per forest), merged with ``fusion_dict`` and written to
    ``../result_RF/<dataset>_RF_stat.json``.

    All paths are relative: the current directory is expected to be a sibling
    of ``dataset``, ``script``, ``cnf_files`` and ``result_RF`` when the
    function starts.
    """
    freeze_support()

    os.chdir("../dataset")
    database = {}
    for arg in sys.argv[1:] :
        csv_name = arg + ".csv"
        database[csv_name.split(".")[0]] = pd.read_csv(csv_name)

    os.chdir("../script")

    with open("info_data_RF.json") as json_data :
        info_nb_tree = json.load(json_data)

    for dataset in database :

        print(f"Work on {dataset}")

        nb_tree = info_nb_tree[dataset]

        cv = 10
        nb_forest = cv
        nb_w = 1

        # The last column is the label; everything else is a feature.
        data = database[dataset].copy()
        label_column = data.columns[-1]
        labels = np.array(data[label_column])
        data = data.drop(columns=[label_column]).values

        score, forests, groups = rf_cross_validation(data, labels, nb_tree, cv=cv, max_depth = 8, nb_forest = nb_forest)

        # Output folders for the generated files. exist_ok lets the script be
        # run again on a dataset that was already processed.
        os.chdir("../cnf_files")
        for kind in ("rf", "instances", "gcnf", "wcnf") :
            os.makedirs(os.path.join(dataset, f"{dataset}_{kind}"), exist_ok=True)
        os.chdir("../script")

        print(score)

        # One task per trained forest. Iterating over ``forests`` itself avoids
        # indexing past its end when fewer than ``cv`` forests were produced.
        list_record = [
            dask.delayed(work_on_forest)([trained], data, labels, score, name = f"{dataset}_{num_rf}")
            for num_rf, trained in enumerate(forests)
        ]

        final_record = dask.delayed(fusion_dict)(list_record)
        final_results = final_record.compute(scheduler="processes", num_workers = nb_w )

        os.chdir("../result_RF")

        with open(f'{dataset}_RF_stat.json', 'w') as file :
            json.dump(final_results, file, indent = 4)
        
# Script entry point: the benchmark runs only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
            
        
