#!/usr/bin/env python
# coding: utf-8

# **Packages**

# Imports
import numpy as np
from tqdm import tqdm

import my_tree as mt
from sklearn.model_selection import train_test_split, LeaveOneGroupOut
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import os
import functools
import operator
from random import shuffle
import copy
import time
import sys
from collections import OrderedDict

# Loaded datasets, keyed by the CSV path given on the command line.
database = {}

# Row count and column count of each loaded dataset.
instances = []
attributes = []


# The dataset path is the one required command-line argument.
if len(sys.argv) < 2:
    print("Give name of the dataset as argument: for example: python3 generate_data_DT.py dataset/compas.csv")
    # sys.exit is always available; the bare exit() helper is only injected
    # by the site module and may be missing (e.g. with `python -S`).
    sys.exit(1)
file = sys.argv[1]
database[file] = pd.read_csv(file)
instances.append(database[file].shape[0])
attributes.append(database[file].shape[1])



def tree_cross_validation(X, Y, cv=10, max_depth=None, groups=None, nb_tree=None):
    """Train one decision tree per fold of a leave-one-group-out split.

    Args:
        X: Indexable collection of feature rows (``X[k]`` is row ``k``).
        Y: Indexable collection of labels aligned with ``X``.
        cv: Number of folds to build when ``groups`` is not given.
        max_depth: Forwarded to ``DecisionTreeClassifier``.
        groups: Optional fold label for each instance. When omitted, labels
            ``1..cv`` are assigned in near-equal proportions and shuffled.
        nb_tree: Accepted for backward compatibility; it has no effect on
            the result.

    Returns:
        A tuple ``(score, trees, groups)``: ``score`` is the mean test
        accuracy in percent over the folds, ``trees`` is a list of
        ``(fitted_tree, train_indices, test_indices)`` tuples (one per
        fold), and ``groups`` is the fold labelling that was used.
    """
    if groups is None:
        nb_instance = len(Y)
        fold_size, remainder = divmod(nb_instance, cv)
        # fold_size copies of each label 1..cv, then one extra instance for
        # the first `remainder` labels so every instance gets a fold.
        groups = [fold for fold in range(1, cv + 1) for _ in range(fold_size)]
        groups.extend(range(1, remainder + 1))
        shuffle(groups)

    splitter = LeaveOneGroupOut()
    fold_scores = []
    trees = []
    for ind_train, ind_test in splitter.split(X, Y, groups=groups):
        x_train = [X[k] for k in ind_train]
        y_train = [Y[k] for k in ind_train]
        x_test = [X[k] for k in ind_test]
        y_test = [Y[k] for k in ind_test]

        dt = DecisionTreeClassifier(max_depth=max_depth)
        dt.fit(x_train, y_train)

        hits = dt.predict(x_test) == np.asarray(y_test)
        fold_scores.append(np.mean(hits) * 100)
        # A new classifier is created on every iteration, so it can be
        # stored directly without copying.
        trees.append((dt, ind_train, ind_test))

    # Average over the folds that were actually evaluated: caller-supplied
    # groups (or fewer instances than `cv`) can produce a fold count that
    # differs from `cv`.
    score = sum(fold_scores) / len(fold_scores)
    return score, trees, groups


# Experiment settings.
random_state = None
cv = 4  # fold count handed to tree_cross_validation
nb_tree = 1
ratio_train_test = 0.1

# Feature names; analyze_compas looks them up with a 1-based feature index.
features = [
    "Number_of_Priors",
    "score_factor",
    "Age_Above_FourtyFive",
    "Age_Below_TwentyFive",
    "African_American",
    "Asian",
    "Hispanic",
    "Native_American",
    "Other",
    "Female",
    "Misdemeanor",
    "Two_yr_Recidivism",
]


def analyze_compas(instance):
    """Render the conditions of ``instance`` as readable strings.

    Each entry returned by ``my_tree.unbinarized_instance(instance,
    need_detail=True)`` is read by position: item 0 is a 1-based index into
    ``features``, item 1 is a threshold and item 2 is a sign marker where
    ``"-"`` selects the negative form. (NOTE(review): this entry layout is
    inferred from how the values are used here; confirm against my_tree.)

    A threshold of exactly 0.5 yields the bare feature name, prefixed with
    ``"-"`` for the negative form. Any other threshold yields
    ``name<=threshold`` (negative form) or ``name>threshold``.

    Returns:
        A list with one string per entry.
    """
    details = my_tree.unbinarized_instance(instance, need_detail=True)
    readable = []
    for entry in details:
        is_negative = entry[2] == "-"
        if entry[1] != 0.5:
            comparator = "<=" if is_negative else ">"
            readable.append(features[entry[0] - 1] + comparator + str(entry[1]))
            continue
        prefix = "-" if is_negative else ""
        readable.append(prefix + features[entry[0] - 1])
    return readable


# The dataset to process is the one named on the command line.
dataset = file
print(f"Work on {dataset}")

rec = {}

# Work on a copy so the frame stored in `database` is left untouched.
data = database[dataset].copy()
size = data.shape

# Rename every column to its position (as a string) so the last column can
# be addressed by index whatever the CSV header contains.
other_names = {name: str(position) for position, name in enumerate(data.columns)}
data = data.rename(columns=other_names)

# The last column is used as the labels; the other columns are the features.
label_column = str(size[1] - 1)
labels = data[label_column].to_numpy()
data = data.drop(columns=[label_column]).to_numpy()

# Cross-validate, then keep the tree of one fold for the explanations below.
score, trees, groups = tree_cross_validation(data, labels, cv=cv, max_depth=None)
i = 1  # index of the fold whose tree is explained
my_tree = mt.decision_tree()
my_tree.from_DecisionTreeClassifier(trees[i][0])

# Rows and labels of the selected fold's test split.
held_out = trees[i][2]
x_test = np.array([data[row] for row in held_out])
y_test = np.array([labels[row] for row in held_out])

p = 4
nb_echantillon = 10  # number of test instances to explain

# Explain up to `nb_echantillon` instances of the held-out fold. The fold can
# hold fewer instances than that, so the range is bounded by its size.
for j in range(min(nb_echantillon, len(x_test))):
    # Whether the selected tree predicts the true label of this instance.
    print("classified", (trees[i][0].predict(x_test[j].reshape(1, -1)) == y_test[j])[0])
    print("class=", y_test[j])
    direct_reason = my_tree.find_direct_reason(x_test[j])
    print("direct=", direct_reason)

    # The direct reason is passed in as the starting implicant.
    sufficient_reason = my_tree.find_sufficient_reason(x_test[j], implicant=direct_reason)
    print("sufficient=", sufficient_reason, "--", len(sufficient_reason))

    tps = time.time()
    print("\nCompute max 1000 minimal")
    minimal_reason = my_tree.find_min_reason(x_test[j], nb = 1000)
    # Sanity check: every returned minimal reason must have the same size.
    min_size = len(minimal_reason[0])
    for m in minimal_reason:
        assert len(m) == min_size
    print("size minimal:", min_size, "nbmin=", len(minimal_reason), "time : ", time.time() - tps)

    # Disabled calls kept from the original script:
    # s_minimal_reason = my_tree.unredundant_binarized_instance(minimal_reason)
    # t_minimal_reason = my_tree.unbinarized_instance(minimal_reason)

    necessary_features, relevant_features = my_tree.find_necessary_features(x_test[j])
    print("necessary", necessary_features)
    print("relevant", relevant_features)

    print("compute max 10000 sufficients")
    tps = time.time()
    sufficients = my_tree.enumerate_all_sufficient(x_test[j], max=10000)
    print("nb s=", len(sufficients), "time:", time.time() - tps)

    print("heat")
    heat = my_tree.heatmap_sufficients(x_test[j])
    print(heat)
    print('nb sufficients: ', max(heat.values()))


    # The return value is not used here.
    my_tree.heatmap_contrastive(x_test[j])
    print("-------------------------------------------------------------------------------------")
