#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import DecisionBoundaryDisplay

from python.width_lib import C_tree, G_tree

import pickle
import numpy as np
from sklearn import tree
import matplotlib.pyplot as plt
from matplotlib import colormaps 
from time import time
from tqdm import tqdm
import jax
import jax.numpy as jnp

from python.width_lib import G_tree

#earth_x = 0.5
#earth_y = 0.5
#earth = np.array([0.5, 0.5])

#f = lambda x: 1./np.sum(np.square(x-earth))
##f = lambda x: jnp.square(x[0])*jnp.exp(x[0]*x[1])+jnp.square(x[1])
##f = lambda x: jnp.square(x[0]-0.5)*jnp.exp((x[0]-0.5)*(x[1]-0.5))+jnp.square(x[1]-0.5)
#f = lambda x: jnp.square(x[0]-0.5)*jnp.exp((x[0]-0.5)*(x[1]-0.5))+jnp.square(x[1]-0.5)
#f = lambda x: jnp.cos(2*2*jnp.pi*jnp.sqrt(jnp.sum(jnp.square(x-0.5))))
#f = lambda x: 1/jnp.square(1+jnp.sum(jnp.square(x-0.8)) + jnp.sum(jnp.square(x-0.2)))
#def f(x):
#    t1 = jnp.cos(2*jnp.pi*jnp.sqrt(jnp.sum(jnp.square(x[1]-0.5))))
#    t2 = jnp.cos(2*2*jnp.pi*jnp.sqrt(jnp.sum(jnp.square(x[0]-0.5))))
#    return t1+t2

#def f(x):
#    t1 = jnp.cos(0.8*2*jnp.pi*jnp.sqrt(jnp.sum(jnp.square(x[1]-0.5))))
#    t2 = jnp.cos(0.8*2*2*jnp.pi*jnp.sqrt(jnp.sum(jnp.square(x[0]-0.5))))
#    return t1+t2

def f(x):
    """Evaluate the scalar test surface at a single point.

    Two cosines are combined, one driven by ``x[1]`` and one by ``x[0]``,
    whose frequencies depend on the L1 distance of ``x`` from 0.5.

    Args:
        x: Array-like point; entries 0 and 1 are read individually and all
            entries contribute to the L1 distance.

    Returns:
        ``log(3 - (t1 + t2))``. Each cosine is at most 1, so the argument of
        the log is at least 1 and the result is non-negative.
    """
    offset = x - 0.5

    # L1 distance from 0.5; it scales the frequency of both cosine terms.
    lam = jnp.sum(jnp.abs(offset))

    # Distance of each coordinate from 0.5, written as sqrt(sum(square(.))).
    dist_y = jnp.sqrt(jnp.sum(jnp.square(offset[1])))
    dist_x = jnp.sqrt(jnp.sum(jnp.square(offset[0])))

    # One term speeds up with lam, the other with (1 - lam).
    wave_y = jnp.cos(lam*2*jnp.pi*dist_y)
    wave_x = jnp.cos((1-lam)*0.8*2*jnp.pi*dist_x)

    return jnp.log(3.-(wave_y+wave_x))

#an = np.array([1., -0.2])
#def f(x):
#    z = jnp.sum(an*(x-0.5))
#    #z = jnp.sum(jnp.square(x-0.5))
#    return jnp.cos(5*2*np.pi*z)*jnp.exp(-z)
# Gradient of the test surface f, obtained by automatic differentiation.
g = jax.grad(f)

# Estimator fitted in the depth sweep below: 'tree' or 'rf'.
#estimator = 'rf'
estimator = 'tree'

# --- Training grid: NN x NN points on the unit square ---------------------
#NN = 10000
NN = 1000
ticks = np.linspace(0, 1, NN)
xx, yy = np.meshgrid(ticks, ticks)
Xest = np.column_stack((xx.ravel(), yy.ravel()))
# vmap evaluates f over every row of the grid in one batched call.
yest = jax.vmap(f)(Xest)

# --- Coarser grid for visualising the function / prediction surface -------
NN = 40
ticks = np.linspace(0, 1, NN)
xx, yy = np.meshgrid(ticks, ticks)
Xplot = np.column_stack((xx.ravel(), yy.ravel()))
act = np.apply_along_axis(f, 1, Xplot)

# --- Coarsest grid: points at which the vector fields are drawn -----------
NN = 10
ticks = np.linspace(0, 1, NN)
xx, yy = np.meshgrid(ticks, ticks)
Xvplot = np.column_stack((xx.ravel(), yy.ravel()))

def _box_index(lb, ub, point):
    """Return the row index of the single box that contains ``point``.

    ``lb`` and ``ub`` are the first two outputs of ``G_tree`` (presumably the
    lower and upper corners of each leaf box -- confirm against width_lib).
    Bounds are treated as inclusive on both sides.
    """
    inbox = np.all(np.logical_and(lb <= point, ub >= point), axis=1)
    # Sanity check on the partition: exactly one box should claim the point.
    assert np.sum(inbox) == 1, "point must lie in exactly one box"
    return np.where(inbox)[0][0]


# ---------------------------------------------------------------------------
# Depth-independent quantities. The true surface and its gradient field do
# not change with the tree depth, so they are computed and drawn once here
# instead of on every iteration of the sweep.
# ---------------------------------------------------------------------------

## Get true vector field
G = np.apply_along_axis(g, 1, Xvplot)
#G = G / np.sqrt(np.sum(np.square(G),axis=1))[:,np.newaxis]

# Rescale the true function values to [0, 1] for the colormap.
actnorm = (act-np.min(act))/(np.max(act)-np.min(act))

plt.figure(figsize=[3,3])
plt.scatter(Xplot[:,0], Xplot[:,1], c=colormaps['cool'](actnorm))
plt.title('Function')
plt.tight_layout()
plt.savefig("true_func.pdf")
plt.close()

plt.figure(figsize=[3,3])
plt.quiver(Xvplot[:,0], Xvplot[:,1], G[:,0], G[:,1])
plt.title('True Vector Field')
plt.tight_layout()
plt.savefig("true_field.pdf")
plt.close()

# ---------------------------------------------------------------------------
# Depth sweep: fit one estimator per depth and save its estimated field.
# ---------------------------------------------------------------------------
#for max_depth in [6,9,12,15]:
#for max_depth in range(1,15):
for max_depth in range(1,10):
    print(max_depth)

    # Fit a tree to those data.
    if estimator=='tree':
        #dt = tree.DecisionTreeRegressor(max_depth = 15, random_state = 0)
        dt = tree.DecisionTreeRegressor(max_depth = max_depth, random_state = 0)
        dt.fit(Xest,yest)
        # Named G_box so it does not overwrite the true gradient field G.
        lb, ub, var, G_box = G_tree(dt, P=2)
        pred = dt.predict(Xplot)

        # Each plotting point takes the gradient row of the box it falls in.
        Gest = np.zeros([Xvplot.shape[0],2])
        for i, xi in enumerate(Xvplot):
            Gest[i,:] = G_box[_box_index(lb, ub, xi),:]
        #Gest = Gest / np.sqrt(np.sum(np.square(Gest),axis=1))[:,np.newaxis]
    elif estimator=='rf':
        # This branch has not been migrated yet: it accumulates gradients on
        # Xplot, while the quiver plot below needs one row per Xvplot point.
        # Raise explicitly (an `assert False` is stripped under `python -O`)
        # and do so before the expensive forest fit.
        raise NotImplementedError(
            "rf estimator: change the gradient loop from Xplot to Xvplot first"
        )

        P = 2
        obs_per_leaf = 5
        rf = RandomForestRegressor(max_depth = 15, min_samples_leaf = obs_per_leaf)
        rf.fit(Xest,yest)
        pred = rf.predict(Xvplot)

        precweigh = False

        ## Try averaging over all nodes it's in with depth-weighting? Or no,
        ## don't need depth weighting if estimating at a single point...
        M = Xplot.shape[0]
        Gest = np.zeros([M,P])
        W = np.zeros([M,P])
        for dt in rf.estimators_:
            print(dt)
            ## Get grads.
            lb, ub, var, G_box = G_tree(dt,P)
            prec = 1/var

            for m in range(M):
                boxind = _box_index(lb, ub, Xplot[m,:])
                if precweigh:
                    Gest[m,:] += G_box[boxind,:] * prec[boxind,:]
                    W[m,:] += prec[boxind,:]
                else:
                    Gest[m,:] += G_box[boxind,:]

        if precweigh:
            # Turn the precision-weighted sums into weighted averages.
            Gest = Gest / W
        #Gest = Gest / np.sqrt(np.sum(np.square(Gest),axis=1))[:,np.newaxis]
    else:
        # Without this, an unknown value only surfaces later as a NameError.
        raise ValueError(f"unknown estimator {estimator!r}; expected 'tree' or 'rf'")

    # NOTE(review): prednorm is not used by any figure below.
    prednorm = (pred-np.min(pred))/(np.max(pred)-np.min(pred))

    plt.figure(figsize=[3,3])
    plt.quiver(Xvplot[:,0], Xvplot[:,1], Gest[:,0], Gest[:,1])
    #plt.title('Tree-Est Vector Field')
    plt.title(f"Depth {max_depth} estimate")

    plt.tight_layout()
    plt.savefig(f"vec_field_{max_depth}.pdf")
    plt.close()

