#!/usr/bin/env python
# coding: utf-8

# **Packages**

# Some import
import sys
from pathlib import Path

import numpy as np
from sklearn.tree import export_text
from tqdm import tqdm
from pysat.solvers import Solver
import time
from math import ceil, floor
import random as rd
import matplotlib.pyplot as plt

import my_tree as mt
from sklearn import datasets
from sklearn import tree
from sklearn.model_selection import train_test_split, LeaveOneGroupOut
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
# from sklearn.model_selection import cross_val_score # Cross validation
import os
import functools
import operator
from random import shuffle
import copy
from joblib import dump
import json


def tree_cross_validation(X, Y, cv=10, max_depth=None, groups=None, nb_tree=None):
    """Train decision trees with leave-one-group-out cross-validation.

    Each fold fits a ``DecisionTreeClassifier`` on every group but one and
    measures its accuracy on the held-out group.

    Parameters
    ----------
    X, Y : indexable sequences (list or numpy array) of instances and labels.
    cv : number of groups to build when ``groups`` is None.
    max_depth : forwarded to ``DecisionTreeClassifier``.
    groups : optional per-instance group labels. When None, instances are
        spread over ``cv`` groups of near-equal size, in random order.
    nb_tree : maximum number of folds to train (defaults to ``cv``).

    Returns
    -------
    tuple
        ``(score, trees, groups)``. ``score`` is the mean test accuracy, in
        percent, over the trees actually trained (0.0 if none were trained).
        ``trees`` is a list of ``(classifier, train_indices, test_indices)``.
        ``groups`` holds the group labels that were used.
    """
    nb_instance = len(Y)
    quotient, reste = divmod(nb_instance, cv)
    if nb_tree is None:
        nb_tree = cv
    if groups is None:
        # `quotient` copies of every group id 1..cv, then one extra instance
        # for each of the first `reste` groups, shuffled into random order.
        groups = [g for g in range(1, cv + 1) for _ in range(quotient)]
        groups += list(range(1, reste + 1))
        shuffle(groups)
    loo = LeaveOneGroupOut()
    score = 0
    trees = []
    for ind_train, ind_test in loo.split(X, Y, groups=groups):
        if len(trees) >= nb_tree:
            break
        x_train = [X[x] for x in ind_train]
        y_train = [Y[x] for x in ind_train]
        x_test = [X[x] for x in ind_test]
        y_test = [Y[x] for x in ind_test]
        dt = DecisionTreeClassifier(max_depth=max_depth)
        dt.fit(x_train, y_train)
        y_predict = dt.predict(x_test)
        score += (np.sum(y_predict == y_test) / len(y_test)) * 100
        # `dt` is a fresh object each fold, so no copy is needed.
        trees.append((dt, ind_train, ind_test))
    # Average over the trees actually trained. There can be fewer than
    # `nb_tree`, e.g. when there are fewer groups than requested folds.
    score = score / len(trees) if trees else 0.0
    return score, trees, groups

too_heavy = []
database = {}
instances = []
attributes = []

# Validate the command line before reading sys.argv[1]. Indexing first would
# raise IndexError instead of printing the usage message.
if len(sys.argv) < 2:
    print("Give name of the dataset as argument: for example: python3 generate_data_DT.py dataset/compas.csv")
    sys.exit(1)
file = sys.argv[1]
# Load the CSV and record its number of rows and columns.
database[file] = pd.read_csv(file)
instances.append(database[file].shape[0])
attributes.append(database[file].shape[1])

# Set parameter for testing
random_state = None
cv = 10
# NOTE(review): nb_tree is not forwarded to tree_cross_validation below, which
# then defaults to one tree per fold (`cv` trees). Pass nb_tree=nb_tree if
# only a single tree was intended.
nb_tree = 1
ratio_train_test = 0.1
for dataset in database.keys():

    print(f"Work on {dataset}")

    data = database[dataset].copy()
    size = data.shape
    # Rename the columns to "0", "1", ... so the last column (used as the
    # label) can be addressed by position whatever its original name.
    other_names = {name: str(pos) for pos, name in enumerate(data.columns)}
    data.rename(columns=other_names, inplace=True)
    label_column = str(size[1] - 1)
    labels = data[label_column].to_numpy()
    data = data.drop(columns=[label_column]).to_numpy()

    score, trees, _ = tree_cross_validation(data, labels, cv=cv, max_depth=None)

    tree_data = {"glb_accuracy": score}

    for i, (sk_tree, _, test_indices) in enumerate(trees):
        record = {}

        # Convert the fitted sklearn classifier into a my_tree.decision_tree.
        my_tree = mt.decision_tree()
        my_tree.from_DecisionTreeClassifier(sk_tree)

        record["nb_node"] = int(sk_tree.tree_.node_count)
        record["nb_leaf"] = int(sk_tree.tree_.n_leaves)
        record["max_depth"] = int(sk_tree.tree_.max_depth)

        # For every direct reason of every target, measure how much shorter
        # the sufficient reason derived from it is.
        list_dr = my_tree.list_direct_reason()
        is_reductable = []
        nb_path = 0
        len_reductable = []
        len_all_dr = []
        liste_k = []
        n = len(my_tree.bina)
        for target in list_dr.keys():
            for dr in list_dr[target]:
                nb_path += 1
                # Per the original note the first `dr` argument is useless here;
                # the path itself is passed as the starting implicant.
                sufficient_reason = my_tree.find_sufficient_reason(dr, target=target, implicant=dr)
                # Distinct name so the enclosing `target` loop variable is not
                # overwritten for the remaining paths of this target.
                reduction = len(dr) - len(sufficient_reason)
                if reduction > 0:
                    is_reductable.append(reduction / len(dr))
                    len_reductable.append(len(dr))
                    # Weighted by 2**(n - len(dr)), presumably the number of
                    # binary instances the path covers.
                    liste_k.append(reduction * 2 ** (n - len(dr)))
                len_all_dr.append(len(dr))

        # NOTE: np.mean/np.std of an empty list give nan (with a RuntimeWarning)
        # when no path is reducible.
        record["avg_length_dr_reductable"] = float(np.mean(len_reductable))
        record["avg_legnth_dr"] = float(np.mean(len_all_dr))
        # Exact Python-int arithmetic: the weights can exceed int64 for large n,
        # where np.sum would overflow.
        record["weighted_reduction"] = float((sum(liste_k) / (2 ** n)) * 100)
        record["dr_reductable"] = float((len(is_reductable) / nb_path) * 100)
        record["avg_reduction_%"] = float(np.mean(is_reductable) * 100)
        record["std_reduction_%"] = float(np.std(is_reductable) * 100)

        x_test = np.array([data[idx] for idx in test_indices])
        y_test = np.array([labels[idx] for idx in test_indices])

        # Analyse at most 100 held-out instances per tree.
        nb_instances = min(100, x_test.shape[0])

        # Each row stores [raw length, length after unredundant_binarized_instance,
        # length after unbinarized_instance, elapsed seconds].
        p = 4

        # Containers for the per-instance results.
        classified = []
        len_direct_reason = np.zeros((nb_instances, p))
        len_suff_direct = np.zeros((nb_instances, p))
        len_min_reason = np.zeros((nb_instances, p))

        len_necessary_features = np.zeros((nb_instances, p))
        len_relevant_features = np.zeros((nb_instances, p))
        t_heatmap = np.zeros((nb_instances, 2))
        list_instance = []
        list_ok = []
        list_reason = []
        contrastives = np.zeros((nb_instances, 1))
        positives = np.zeros((nb_instances, 2))
        negatives = np.zeros((nb_instances, 2))
        nb_minimales = np.zeros((nb_instances, 2))

        for j in tqdm(range(nb_instances), desc=f"{dataset} analysis"):
            instance = x_test[j]
            expected = y_test[j]

            # [my_tree prediction is correct, true label]
            list_ok.append([bool(my_tree.predict(instance) == expected), int(expected)])

            list_instance.append(instance.tolist())

            # Correctness according to the original sklearn classifier.
            classified.append(bool((sk_tree.predict(instance.reshape(1, -1)) == expected)[0]))

            tps = -time.time()
            direct_reason = my_tree.find_direct_reason(instance)
            s_direct_reason = my_tree.unredundant_binarized_instance(direct_reason)
            t_direct_reason = my_tree.unbinarized_instance(direct_reason)
            tps += time.time()
            len_direct_reason[j] = [len(direct_reason), len(s_direct_reason), len(t_direct_reason), tps]
            # Reversed in place before being used as the implicant below.
            direct_reason.reverse()

            tps = -time.time()
            sufficient_reason = my_tree.find_sufficient_reason(instance, implicant=direct_reason)
            s_sufficient_reason = my_tree.unredundant_binarized_instance(sufficient_reason)
            t_sufficient_reason = my_tree.unbinarized_instance(sufficient_reason)
            tps += time.time()
            len_suff_direct[j] = [len(sufficient_reason), len(s_sufficient_reason), len(t_sufficient_reason), tps]

            tps = -time.time()
            minimal_reason = my_tree.find_min_reason(instance)
            s_minimal_reason = my_tree.unredundant_binarized_instance(minimal_reason)
            t_minimal_reason = my_tree.unbinarized_instance(minimal_reason)
            tps += time.time()
            len_min_reason[j] = [len(minimal_reason), len(s_minimal_reason), len(t_minimal_reason), tps]
            minimal_reason.reverse()

            tps = -time.time()
            necessary_features, relevant_features = my_tree.find_necessary_features(instance)
            s_necessary_features = my_tree.unredundant_binarized_instance(necessary_features)
            t_necessary_features = my_tree.unbinarized_instance(necessary_features, need_detail=True)
            tps += time.time()
            len_necessary_features[j] = [len(necessary_features), len(s_necessary_features), len(t_necessary_features),
                                         tps]

            # Relevant features come from the same find_necessary_features call,
            # so they reuse its timing.
            s_relevant_features = my_tree.unredundant_binarized_instance(relevant_features)
            t_relevant_features = my_tree.unbinarized_instance(relevant_features, need_detail=True)
            len_relevant_features[j] = [len(relevant_features), len(s_relevant_features), len(t_relevant_features), tps]

            tps = -time.time()
            heatmap = my_tree.heatmap_sufficients(instance)
            tps += time.time()
            t_heatmap[j] = [max(heatmap.values()), tps]

            ct = my_tree.enumarate_all_contrastive(instance)
            contrastives[j] = [len(ct)]

            # Store [reason length, number of positive literals].
            positive = my_tree.reason_with_min_positive(instance)
            positives[j] = [len(positive), len([lit for lit in positive if lit > 0])]

            negative = my_tree.reason_with_min_negative(instance)
            negatives[j] = [len(negative), len([lit for lit in negative if lit > 0])]

            tps = time.time()
            nb_min = my_tree.find_min_reason(instance, nb=1000)
            nb_minimales[j] = [len(nb_min), time.time() - tps]
            list_reason.append(
                [t_direct_reason, t_sufficient_reason, t_minimal_reason, t_necessary_features, t_relevant_features])

        record["acc"] = (np.sum(classified) / len(classified)) * 100

        record["len_bin"] = len(my_tree.bina.keys())
        record["Is well classified"] = classified
        record["direct_r"] = len_direct_reason.tolist()

        record["suff_r_from_d"] = len_suff_direct.tolist()

        record["min_r"] = len_min_reason.tolist()

        record["necessary_f"] = len_necessary_features.tolist()
        record["relevant_f"] = len_relevant_features.tolist()

        record["nb_suff"] = t_heatmap.tolist()

        record["instance"] = list_instance
        record["classified"] = list_ok
        record["reason"] = list_reason
        record["nb_contrastives"] = contrastives.tolist()
        record["negatives"] = negatives.tolist()
        record["positives"] = positives.tolist()
        record["nb minimales"] = nb_minimales.tolist()
        tree_data[f"Tree_{i}"] = record

    # Average the per-tree statistics over the trees actually analysed, which
    # can be fewer than `cv` (e.g. when there are fewer rows than folds).
    summary_fields = {
        "mean_nb_node": "nb_node",
        "mean_nb_leaf": "nb_leaf",
        "mean_max_depth": "max_depth",
        "mean_percent_reductable": "dr_reductable",
        "mean_avg_percent_reduction": "avg_reduction_%",
        "mean_std_percent_reduction": "std_reduction_%",
        "mean_len_bin": "len_bin",
        "mean_avg_length_dr": "avg_legnth_dr",
        "mean_avg_length_dr_reductable": "avg_length_dr_reductable",
        "mean_weighted_reduction": "weighted_reduction",
    }
    tree_records = [value for key, value in tree_data.items() if key.startswith("Tree_")]
    nb_analysed = max(len(tree_records), 1)
    for mean_key, field in summary_fields.items():
        tree_data[mean_key] = sum(rec[field] for rec in tree_records) / nb_analysed

    # Output name comes from this dataset's path ("dataset/compas.csv" ->
    # "compas"). A separate handle name keeps the module-level `file` path
    # intact, and ./json is created if missing.
    d = Path(dataset).name.split(".")[0]
    print(d)
    os.makedirs("./json", exist_ok=True)
    with open(f'./json/{d}_DT_global.json', 'w') as out_file:
        out_file.write(json.dumps(tree_data, indent=4))

    print(f"{dataset} done")
