#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import warnings

import numpy as np
from sklearn import tree

def C_tree(dt, P, mode='3', normG = False, weighwidth = False, weighsamp = False, precweigh = False):
    """Build a P x P sum of gradient outer products from a fitted tree.

    Per-leaf gradient estimates ``G`` (one row per leaf) and leaf boxes are
    obtained from ``G_tree``; each row is optionally rescaled, rows with
    non-finite entries are dropped, and ``G.T @ G`` is returned. Because the
    result is quadratic in ``G``, every row scaling below enters the output
    squared.

    Parameters
    ----------
    dt : fitted tree estimator
        Passed to ``G_tree``; ``dt.tree_`` is also read here.
    P : int
        Number of input features.
    mode : str
        Gradient estimator forwarded to ``G_tree`` ('3' or '1+2').
    normG : bool
        Scale each leaf gradient to (approximately) unit Euclidean norm first.
    weighwidth : bool
        Multiply each row by the volume of its leaf box (product of widths).
    weighsamp : bool
        Multiply each row by the leaf's ``n_node_samples``.
    precweigh : bool
        Multiply each row by an inverse-variance weight ``1 / mean(var)``
        (mean across features), normalized so the weights average 1 over
        leaves. Needs a variance from ``G_tree``, i.e. ``mode='3'``.

    Returns
    -------
    ndarray, shape (P, P)

    Raises
    ------
    ValueError
        If ``precweigh`` is requested for a mode that returns no variance.
    """
    lb, ub, var, G = G_tree(dt, P, mode = mode)
    lw = ub - lb  # per-leaf box widths
    n_leaf = G.shape[0]

    dtd = dt.tree_
    isleaf = dtd.feature < 0
    # G_tree emits leaf rows in increasing node index, so this mask aligns.
    if weighsamp:
        samp_weight = dtd.n_node_samples[isleaf]
    else:
        samp_weight = np.ones(n_leaf)

    if precweigh:
        if var is None:
            raise ValueError(
                "precweigh=True needs a variance estimate, which mode={!r} "
                "does not provide; use mode='3'.".format(mode))
        vv = np.mean(var, axis = 1)
        # NOTE(review): var is inf for every feature never split on along a
        # leaf's path, so vv = inf and that leaf gets weight 0. If that holds
        # for every leaf, np.sum(pw) == 0, the normalization yields NaN and all
        # rows are dropped below. Averaging only finite variances may be what
        # was intended -- confirm before relying on precweigh.
        pw = 1 / vv
        pw = pw / np.sum(pw) * pw.shape[0]
    else:
        pw = np.ones(n_leaf)

    if normG:
        # 1e-12 keeps all-zero gradient rows at zero instead of 0/0.
        G = G / (1e-12 + np.sqrt(np.sum(np.square(G), axis = 1))[:, np.newaxis])

    G = G * samp_weight[:, np.newaxis]
    G = G * pw[:, np.newaxis]

    if weighwidth:
        ww = np.prod(lw, axis = 1)
        G = G * ww[:, np.newaxis]

    # Drop rows with NaN/inf (e.g. from the precweigh normalization above or
    # from a gradient estimate that divided by a zero-width split).
    isgood = np.all(np.isfinite(G), axis = 1)
    n_bad = int(np.count_nonzero(~isgood))
    if n_bad:
        warnings.warn(
            "C_tree: dropped {} of {} leaf gradient rows with non-finite "
            "entries.".format(n_bad, n_leaf), RuntimeWarning)
        G = G[isgood, :]

    return G.T @ G

def G_tree(dt, P, mode = '3'):
    """Estimate per-leaf gradients from the node values of a fitted tree.

    Nodes of ``dt.tree_`` are visited in index order. For every node the
    axis-aligned box ``[lb, ub]`` it covers is tracked, starting from the unit
    cube at the root. At each split on feature ``p`` three finite-difference
    estimates of the slope along ``p`` are formed from the parent and child
    node values:

    * est1 -- parent vs. left child,
    * est2 -- right child vs. parent,
    * est3 -- right child vs. left child (with a variance estimate).

    Each assumes the target is locally linear and samples are spread uniformly
    over the box, so a node's mean value sits at its box midpoint. Estimates
    are copied down to both children, so a leaf ends up holding the deepest
    estimate for each feature split on along its root-to-leaf path, and 0
    (variance ``inf``) for features never split on.

    Parameters
    ----------
    dt : fitted tree estimator
        Only ``dt.tree_`` is read (``feature``, ``children_left``,
        ``children_right``, ``threshold``, ``value``, ``impurity``,
        ``n_node_samples``). Each ``value[i]`` must hold a single number
        (single-output regression); otherwise ``.item()`` raises ValueError.
    P : int
        Number of input features.
    mode : {'3', '1+2'}
        '3' returns est3 and its variance; '1+2' returns the mean of est1 and
        est2 with ``var=None``.

    Returns
    -------
    lb, ub : ndarray, shape (n_leaves, P)
        Lower/upper bounds of each leaf box.
    var : ndarray, shape (n_leaves, P), or None
    G : ndarray, shape (n_leaves, P)
        Leaf rows are in increasing node-index order.

    Raises
    ------
    ValueError
        If ``mode`` is not '3' or '1+2'.

    Notes
    -----
    The root box is [0, 1]^P, so inputs are assumed to be scaled into the unit
    cube; thresholds outside it produce meaningless (even negative) widths.
    The single forward pass relies on every parent having a lower node index
    than its children -- presumably true for scikit-learn-built trees; confirm
    for other sources.
    """
    if mode not in ('1+2', '3'):
        raise ValueError(
            "Unknown mode {!r}; expected '1+2' or '3'.".format(mode))

    dtd = dt.tree_
    L = len(dtd.feature)

    lb = np.zeros([L, P])
    ub = np.ones([L, P])
    est1 = np.zeros([L, P])
    est2 = np.zeros([L, P])
    est3 = np.zeros([L, P])
    var3 = np.inf * np.ones([L, P])  # inf means "no estimate along this feature"

    for i in range(L):
        p = dtd.feature[i]
        if p < 0:  # leaf: nothing to split or propagate
            continue
        lc = dtd.children_left[i]
        rc = dtd.children_right[i]
        thresh = dtd.threshold[i]

        # Children inherit the parent's box; the split then narrows feature p.
        lb[lc, :] = lb[i, :]
        ub[lc, :] = ub[i, :]
        lb[rc, :] = lb[i, :]
        ub[rc, :] = ub[i, :]
        ub[lc, p] = lb[rc, p] = thresh

        # Scalar node means (avoids deprecated array->scalar assignment).
        v_me = dtd.value[i].item()
        v_lc = dtd.value[lc].item()
        v_rc = dtd.value[rc].item()
        width = ub[i, p] - lb[i, p]

        # With slope g, mean(parent) - mean(left) = g * (ub - thresh) / 2:
        # half the width of the *other* child, hence the factor 2.
        est1[i, p] = 2 * (v_me - v_lc) / (ub[i, p] - thresh)
        est2[i, p] = 2 * (v_rc - v_me) / (thresh - lb[i, p])
        est3[i, p] = 2 * (v_rc - v_lc) / width

        # Propagate gradient info (now including feature p) to both children.
        est1[lc, :] = est1[rc, :] = est1[i, :]
        est2[lc, :] = est2[rc, :] = est2[i, :]
        est3[lc, :] = est3[rc, :] = est3[i, :]

        # Variance of est3 = (2 / width)^2 * (Var[mean_l] + Var[mean_r]).
        # impurity / (n - 1) is used as the variance of a child's sample mean
        # (assuming impurity is the node MSE). A single-sample child carries no
        # variance information, so it gets inf instead of 0/0 -> NaN.
        n_l = dtd.n_node_samples[lc]
        n_r = dtd.n_node_samples[rc]
        varl = dtd.impurity[lc] / (n_l - 1) if n_l > 1 else np.inf
        varr = dtd.impurity[rc] / (n_r - 1) if n_r > 1 else np.inf
        var3[i, p] = 4. / np.square(width) * (varl + varr)

        var3[lc, :] = var3[rc, :] = var3[i, :]

    usegrad = dtd.feature < 0  # gradients are reported for leaves only

    if mode == '1+2':
        G = (est1 + est2)[usegrad, :] / 2
        var = None
    else:  # mode == '3'
        G = est3[usegrad, :]
        var = var3[usegrad, :]

    return lb[usegrad, :], ub[usegrad, :], var, G

