#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.inspection import DecisionBoundaryDisplay

from python.class_lib_beta import G_tree_class

import pickle
import numpy as np
from sklearn import tree
import matplotlib.pyplot as plt
from matplotlib import colormaps 
import matplotlib
matplotlib.use('agg')
from time import time
from tqdm import tqdm

import numpy as np

## Load the MNIST training split once. NpzFile is a context manager, so the
## archive's file handle is closed as soon as both arrays have been read.
with np.load('raw_data/mnist.npz') as mnist:
    X = mnist['x_train']
    y = mnist['y_train']

#problem = '47'
problem = '08'

## Select a binary sub-problem: `fs` keeps only the two digits, `y` becomes
## True for the second digit, and `plotinds` are row indices into the
## filtered X of the examples to visualise.
if problem == '47':
    ## Fours and sevens.
    fs = np.logical_or(y == 4, y == 7)
    y = y == 7
    plotinds = [0, 2]
elif problem == '08':
    ## Zeros and eights.
    fs = np.logical_or(y == 0, y == 8)
    y = y == 8
    plotinds = [4, 6]
else:
    raise ValueError(f"Unknown problem {problem!r}; expected '47' or '08'.")
y = y[fs]
X = X[fs, :, :]
## Flatten each image into one feature row and divide by 255 (assumes 8-bit
## pixel intensities, so features end up in [0, 1]).
X = X.reshape([X.shape[0], -1]).astype(float)
X /= 255
P = X.shape[1]  # number of features (pixels) per example

## Fit a random forest of depth-limited trees on the binary problem.
## fit() returns the estimator itself, so construction and fitting chain.
max_depth = 4
rf = RandomForestClassifier(
    max_depth=max_depth,
    random_state=123,
).fit(X, y)

## Integrated gradients of the forest along the straight-line path from an
## all-zero reference image to each selected example, then plot them.
xref = np.zeros(P)
Np = 500  # number of points used to discretise the path integral
alphas = np.linspace(0, 1, num=Np)[:, np.newaxis]


def _path_box_gradients(lb, ub, G, path):
    """Return, for every row of ``path``, the gradient row of the box containing it.

    ``lb``/``ub`` are the per-box lower/upper bounds and ``G`` the per-box
    gradient rows produced by ``G_tree_class`` (assumed to be shaped
    (n_boxes, P) -- TODO confirm against class_lib_beta). Bounds are treated
    as closed, and every path point must fall in exactly one box.
    """
    lb = np.asarray(lb)
    ub = np.asarray(ub)
    # inbox[i, b] is True when path point i lies inside box b.
    inbox = np.all(
        (lb[np.newaxis] <= path[:, np.newaxis]) & (ub[np.newaxis] >= path[:, np.newaxis]),
        axis=2,
    )
    if not np.all(inbox.sum(axis=1) == 1):
        raise RuntimeError("each path point must lie in exactly one tree box")
    return np.asarray(G)[np.argmax(inbox, axis=1), :]


## The per-tree boxes and gradients do not depend on the example, so compute
## them once instead of once per plotted image.
tree_boxes = [G_tree_class(dt, P) for dt in tqdm(rf.estimators_, desc='trees')]

ncols = 2 * len(plotinds)  # two panels (image, attribution) per example
plt.figure(figsize=[6 * len(plotinds), 3])
for k, isel in enumerate(plotinds):
    x_img = X[isel, :]
    xdiff = x_img - xref
    path = x_img[np.newaxis, :] * alphas + (1 - alphas) * xref

    Gest = np.zeros([path.shape[0], P])
    for lb, ub, _var, G in tree_boxes:
        Gest += _path_box_gradients(lb, ub, G, path)
    # RandomForestClassifier averages its trees' predictions, so average the
    # per-tree gradients as well (this only rescales the plotted attributions).
    Gest /= len(tree_boxes)

    # Riemann-mean approximation of the path integral, scaled by (x - xref).
    pathint = np.mean(Gest, axis=0)
    IG = xdiff * pathint

    img = x_img.reshape([28, 28])

    ## Left panel of the pair: the example itself.
    plt.subplot(1, ncols, 2 * k + 1)
    plt.imshow(img, cmap='Greys')  # 'Greys' works on all matplotlib versions
    ax = plt.gca()
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticks([])

    ## Right panel: attributions drawn over the example. Zero-attribution
    ## pixels are masked so the digit underneath shows through.
    # The first example's attributions are sign-flipped before plotting
    # (presumably so both examples' evidence shares a colour -- TODO confirm).
    sgn = -1. if k == 0 else 1.
    plt.subplot(1, ncols, 2 * k + 2)
    plt.imshow(img, cmap='Greys')
    masked = np.ma.masked_where(IG == 0, IG)
    plt.imshow(sgn * masked.reshape([28, 28]), cmap='autumn')
    ax = plt.gca()
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticks([])

plt.tight_layout()
plt.savefig("ig" + problem + ".pdf")
plt.close()

