#!/usr/bin/env python
# coding: utf-8

# **Packages**

# Imports
import numpy as np
from pysat.pb import PBEnc
from tqdm import tqdm

import my_tree as mt
from sklearn.model_selection import train_test_split, LeaveOneGroupOut
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import os
import functools
import operator
from random import shuffle
import copy
import matplotlib.pyplot as plt


# Path (without the ".csv" extension) of the dataset explored by this script.
file = "./dataset/mnist38"

# Loaded tables, keyed by dataset path.
database = {file: pd.read_csv(f"{file}.csv")}

# Row count and column count of each loaded table, in load order.
n_rows, n_cols = database[file].shape
instances = [n_rows]
attributes = [n_cols]

def getImages(x_test, pos, data, *, treatdata = True):
    """Build two 28x28 masks from a collection of signed pixel indices.

    A signed index ``k`` addresses the cell whose first-axis index is
    ``(abs(k) - 1) // 28`` and whose second-axis index is
    ``(abs(k) - 1) % 28``. Non-negative indices are written to plane 0 and
    non-positive indices to plane 1 (so an index of 0 lands in both).

    Args:
        x_test: Not used by this function.
        pos: Not used by this function.
        data: When ``treatdata`` is true, an iterable of keys that are each
            converted through the module-level ``my_tree``; otherwise the
            signed indices themselves (any iterable, or a dict mapping a
            signed index to the value to store).
        treatdata: Keyword-only. Whether ``data`` must first be converted to
            signed indices.

    Returns:
        A ``(2, 28, 28)`` float array. Cells take the dict value when the
        indices come from a dict, and 100 otherwise.
    """
    if treatdata:
        implicant = []
        for literal in data:
            # Field 0 of the detail is scaled by +1 when field 2 is "+", and
            # by -1 otherwise (presumably feature id and polarity; verify
            # against my_tree).
            detail = my_tree.unbinarized_instance([literal], need_detail=True)[0]
            sign = 1 if detail[2] == "+" else -1
            implicant.append(detail[0] * sign)
        print("sufficient: sz=", len(implicant), "is ", implicant)
    else:
        implicant = data

    weighted = isinstance(implicant, dict)
    images = np.zeros((2, 28, 28))
    for key in implicant:
        row, col = divmod(abs(key) - 1, 28)
        value = implicant[key] if weighted else 100
        if key >= 0:
            images[0, row, col] = value
        if key <= 0:
            images[1, row, col] = value
    return images

def display_images(x_test, pos, implicant, title):
    """Draw a four-panel figure for one sample and its signed pixel indices.

    The panels show, from left to right: the sample itself, the mask of
    non-negative indices (blue), the mask of non-positive indices (red), and
    an overlay of both masks with the sample on top.

    Args:
        x_test: Indexable collection of samples; ``x_test[pos]`` must be
            reshapeable to 28x28.
        pos: Index of the sample to draw.
        implicant: Signed pixel indices, either an iterable or a dict mapping
            a signed index to the value to plot. Without a dict, every listed
            cell is set to 100.
        title: Prefix used in the titles of the last three panels.
    """
    layers = np.zeros((3, 28, 28))
    _, axes = plt.subplots(1, 4, figsize=(20, 20))
    layers[0] = np.reshape(x_test[pos], (28, 28))

    # Index k addresses cell ((|k| - 1) // 28, (|k| - 1) % 28); its sign
    # selects the layer (an index of 0 is written to both).
    weighted = isinstance(implicant, dict)
    for key in implicant:
        row, col = divmod(abs(key) - 1, 28)
        value = implicant[key] if weighted else 100
        if key >= 0:
            layers[1, row, col] = value
        if key <= 0:
            layers[2, row, col] = value

    # NOTE(review): "ligh on" looks like a typo for "light on"; the string is
    # kept unchanged here.
    titles = ("Correct", f"{title} ligh on", f"{title} light off", f"{title} map both")
    sample, lit, unlit = layers
    panel_sample, panel_on, panel_off, panel_both = axes.flat

    panel_sample.imshow(sample)
    panel_sample.title.set_text(titles[0])

    panel_on.imshow(lit, cmap='Blues', interpolation="nearest")
    panel_on.title.set_text(titles[1])

    panel_off.imshow(unlit, cmap='Reds', interpolation="nearest")
    panel_off.title.set_text(titles[2])

    # Overlay: both masks, then the sample drawn faintly on top.
    panel_both.imshow(lit, cmap='Blues', interpolation="nearest")
    panel_both.imshow(unlit, cmap='Reds', alpha=0.4, interpolation="nearest")
    panel_both.imshow(sample, alpha=0.2)
    panel_both.title.set_text(titles[3])




def display_heat_map(h, label):
    """Plot a heat map for the module-level sample ``x_test[pos]``.

    Every key of ``h`` is converted to a signed pixel index through the
    module-level ``my_tree``, and the resulting mapping is handed to
    ``display_images``.

    Args:
        h: Mapping from a key understood by ``my_tree.unbinarized_instance``
            to the value to plot.
        label: Title prefix forwarded to ``display_images``.
    """
    def signed_pixel(literal):
        # Field 0 of the detail, negated unless field 2 is "+".
        detail = my_tree.unbinarized_instance([literal], need_detail=True)[0]
        sign = 1 if detail[2] == "+" else -1
        return detail[0] * sign

    display_images(x_test, pos, {signed_pixel(lit): h[lit] for lit in h}, label)


def tree_cross_validation(X, Y, cv=10, max_depth=None, groups=None, nb_tree=None):
    """Train one decision tree per fold with leave-one-group-out validation.

    Args:
        X: Indexable collection of feature rows; rows are fetched as ``X[i]``.
        Y: Indexable collection of labels, with one entry per row of ``X``.
        cv: Number of folds used to build a random fold assignment when
            ``groups`` is None.
        max_depth: Forwarded to ``DecisionTreeClassifier``.
        groups: Optional fold label for every instance. When given it is used
            as is, and the number of folds is the number of distinct labels.
        nb_tree: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(score, trees, groups)`` where ``score`` is the mean test
        accuracy in percent over all folds, ``trees`` holds one
        ``(fitted_tree, train_indices, test_indices)`` tuple per fold, and
        ``groups`` is the fold assignment that was used.
    """
    if groups is None:
        # `fold_size` instances per fold, the remaining ones spread over the
        # first folds, then shuffled so folds are random.
        fold_size, remainder = divmod(len(Y), cv)
        groups = [fold for fold in range(1, cv + 1) for _ in range(fold_size)]
        groups += list(range(1, remainder + 1))
        shuffle(groups)

    splitter = LeaveOneGroupOut()
    score = 0
    trees = []
    for ind_train, ind_test in splitter.split(X, Y, groups=groups):
        x_train = [X[i] for i in ind_train]
        y_train = [Y[i] for i in ind_train]
        x_test = [X[i] for i in ind_test]
        y_test = [Y[i] for i in ind_test]
        dt = DecisionTreeClassifier(max_depth=max_depth)
        dt.fit(x_train, y_train)
        y_predict = dt.predict(x_test)
        score += (np.sum(y_predict == y_test) / len(y_test)) * 100
        # A new estimator is built on every iteration, so it can be stored
        # directly without copying it.
        trees.append((dt, ind_train, ind_test))

    # Average over the folds actually evaluated: this differs from `cv` when
    # the caller supplies `groups` or when there are fewer instances than
    # folds.
    score /= len(trees)
    return score, trees, groups


# Experiment settings.
random_state = None
cv = 4
nb_tree = 1
ratio_train_test = 0.1

dataset = file
print(f"Work on {dataset}")

rec = {}

# Work on a copy so the table cached in `database` keeps its original columns.
data = database[file].copy()
size = data.shape

# Rename every column to its position (as a string) so the last column can be
# addressed without knowing its original name.
other_names = {name: str(position) for position, name in enumerate(data.columns)}
data = data.rename(columns=other_names)

# The last column holds the labels; every other column is a feature.
label_column = str(size[1] - 1)
labels = data[label_column].to_numpy()
data = data.drop(columns=[label_column]).to_numpy()

score, trees, groups = tree_cross_validation(data, labels, cv=cv, max_depth=None)

# Explain the tree trained on the second fold; each entry of `trees` is a
# (fitted_tree, train_indices, test_indices) tuple.
selected_tree = trees[1][0]
test_indices = trees[1][2]

my_tree = mt.decision_tree()
my_tree.from_DecisionTreeClassifier(selected_tree)

# Test samples and labels of that fold.
x_test = data[test_indices]
y_test = labels[test_indices]

p = 4
nb_echantillon = 10

# Pick the first correctly classified sample among the first `nb_echantillon`
# test samples. The range is clamped so a fold smaller than `nb_echantillon`
# does not raise an IndexError.
pos = -1
for j in range(min(nb_echantillon, len(x_test))):
    if (selected_tree.predict(x_test[j].reshape(1, -1)) == y_test[j])[0]:
        print(y_test[j])
        pos = j
        break
else:
    # No match: pos stays at -1, so the rest of the script works on the last
    # test sample. Say so instead of doing it silently.
    print(
        f"Warning: no correctly classified instance among the first "
        f"{nb_echantillon} test samples; using the last test sample instead"
    )


def _signed_pixel(literal):
    """Convert one key understood by ``my_tree`` into a signed pixel index.

    Field 0 of the detail returned by ``my_tree.unbinarized_instance`` is
    kept as is when field 2 is "+", and negated otherwise.
    """
    detail = my_tree.unbinarized_instance([literal], need_detail=True)[0]
    return detail[0] * (1 if detail[2] == "+" else -1)


# Heat map computed by my_tree for the selected sample, drawn in its own figure.
heat = my_tree.heatmap_sufficients(x_test[pos])
print("nb suff=", max(heat.values()))
display_heat_map(heat, "Heat")

# Same heat map, re-keyed by signed pixel index for the summary figure below.
heat2 = {_signed_pixel(key): heat[key] for key in heat}

# Summary figure: sample | heat map | necessary and relevant features.
fig2, axes2 = plt.subplots(1, 3, figsize=(20, 20))
instance_panel, heat_panel, necessary_panel = axes2.flat

instance2 = np.zeros((2, 28, 28))
instance2[0] = np.reshape(x_test[pos], (28, 28))
instance_panel.imshow(instance2[0])
instance_panel.title.set_text("Instance")

s = getImages(x_test, pos, heat2, treatdata=False)
heat_panel.imshow(s[0], cmap='Blues', interpolation="nearest")
heat_panel.imshow(s[1], cmap='Reds', alpha=0.4, interpolation="nearest")
heat_panel.title.set_text("HeatMap")

necessary, relevant = my_tree.find_necessary_features(x_test[pos])
print("nb necessary: ", len(necessary), " necessary: ", necessary)
print("nb relevant: ", len(relevant), "relevant: ", relevant)

# Necessary features are plotted at 100 and relevant ones at 20. Relevant
# features are written last, so they win when a pixel appears in both.
nec = {}
for features, weight in ((necessary, 100), (relevant, 20)):
    for key in features:
        nec[_signed_pixel(key)] = weight

s = getImages(x_test, pos, nec, treatdata=False)
necessary_panel.imshow(s[1], cmap='Reds', interpolation="nearest")
necessary_panel.imshow(s[0], cmap='Blues', alpha=0.4, interpolation="nearest")
necessary_panel.title.set_text("Necessary + relevant")
######################################################################################

# Figure: sample | first enumerated sufficient reason | the one a quarter of
# the way through the enumeration.
fig, axes = plt.subplots(1, 3, figsize=(20, 20))
instance_panel, *sufficient_panels = axes.flat

instance = np.zeros((2, 28, 28))
instance[0] = np.reshape(x_test[pos], (28, 28))
instance_panel.imshow(instance[0])
instance_panel.title.set_text("Instance")

# Two sufficient reasons taken from the enumeration (the call is made with
# max=500; presumably a cap on how many are enumerated, to be confirmed
# against my_tree).
sufficients = my_tree.enumerate_all_sufficient(x_test[pos], max=500)
shown_ranks = (0, len(sufficients) // 4)
for panel, rank in zip(sufficient_panels, shown_ranks):
    s = getImages(x_test, pos, sufficients[rank])
    panel.imshow(s[0], cmap='Blues', interpolation="nearest")
    panel.imshow(s[1], cmap='Reds', alpha=0.4, interpolation="nearest")
    panel.title.set_text("Sufficient")

plt.show()
