#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import DecisionBoundaryDisplay

from python.width_lib import C_tree
from python.activegp_wrapper import C_GP
from python.dasm_lib import train_net
from python.active_bakeoff_setting import setting, fname

import pandas as pd
import numpy as np
from sklearn import tree
import matplotlib.pyplot as plt
from matplotlib import colormaps 
from time import time
from tqdm import tqdm

import psdr

# Seed NumPy's global RNG so the randomly drawn problems below are repeatable.
np.random.seed(123)

# Problem params.
# `setting` is imported from python.active_bakeoff_setting and selects one of
# the two experiment grids below:
#   Ns          -- sample sizes, log-spaced and truncated to integers
#   Ps          -- input dimensions
#   subsparse   -- if True, the true direction is nonzero on only a few
#                  coordinates (see the sampling code in the main loop)
#   estimators  -- names of the methods that are run for every problem
if setting == 'big':
    # Large problems: 10 sample sizes between 1e2 and 1e5.
    Ns = np.logspace(2, 5, num=10).astype(int)
    Ps = [10, 50, 100]
    subsparse = True
    estimators = ['tree', 'dasm']
elif setting == 'small':
    # Setting one: small problems to compare against the other methods;
    # 10 sample sizes between 1e1 and 1e4.
    Ns = np.logspace(1, 4, num=10).astype(int)
    Ps = [2, 3, 4]
    subsparse = False
    estimators = ['tree', 'gp', 'pra', 'dasm']
else:
    # Fail with a message that names the offending value instead of a bare,
    # empty exception.
    raise ValueError(
        "Unknown setting %r; expected 'big' or 'small'." % (setting,)
    )

# Estimators listed in `slow_methods` are only run while N <= slow_method_max;
# for larger N the main loop records NaN for them instead.
slow_method_max = 150
slow_methods = ['gp', 'pra']

## Tree params (passed to the scikit-learn tree / forest constructors)
depth = 15
obs_per_leaf = 5

# Number of independent repetitions for every (N, P) combination.
reps = 20

# Pre-allocate one result row per (N, P, rep, estimator) combination.
# The main loop fills the rows by position with
# [estimator name, IP, Time, N, P].
cols = ['est', 'IP', 'Time', 'N', 'P']
n_rows = len(Ns) * len(Ps) * len(estimators) * reps
# Column 'est' receives strings, so it must not be float64: writing a string
# into a float column relies on implicit upcasting, which pandas has
# deprecated. The remaining columns stay float64, as before.
resdf = pd.DataFrame(np.zeros([n_rows, len(cols)]), columns=cols).astype(
    {'est': object}
)

def _estimate_C(estimator, X, y, P, f):
    """Fit the named estimator on (X, y) and return its matrix estimate.

    The returned matrix is handed to ``np.linalg.eigh`` by the caller, so it
    is expected to be square and symmetric (presumably P x P -- confirm for
    the project-local estimators).

    Parameters
    ----------
    estimator : str
        One of 'tree', 'forest', 'gp', 'pra', 'dasm'.
    X, y : arrays
        Design matrix and responses, as generated in the main loop.
    P : int
        Input dimension.
    f : callable
        The test function that generated ``y``; only used by 'pra'.

    Raises
    ------
    Exception
        If ``estimator`` is not one of the known names.
    """
    if estimator == 'tree':
        dt = tree.DecisionTreeRegressor(max_depth=depth, min_samples_leaf=obs_per_leaf)
        dt.fit(X, y)
        return C_tree(dt, P, mode='3')

    if estimator == 'forest':
        rf = RandomForestRegressor(max_depth=depth, min_samples_leaf=obs_per_leaf)
        rf.fit(X, y)
        # Average the per-tree matrices over the whole ensemble.
        # NOTE(review): C_tree is called without mode='3' here, unlike the
        # single-tree branch -- confirm this difference is intentional.
        T = len(rf.estimators_)
        Ch = np.zeros([P, P])
        for dt in rf.estimators_:
            Ch += C_tree(dt, P) / T
        return Ch

    if estimator == 'gp':
        return C_GP(X, y)

    if estimator == 'pra':
        # PSDR polynomial ridge approximation with a 1-dimensional subspace.
        domain = psdr.BoxDomain(np.zeros(P), np.ones(P))
        # NOTE(review): `func` is never used after construction; kept only in
        # case building it matters to psdr -- confirm and delete if not.
        func = psdr.Function(f, domain, fd_grad=True)
        pra = psdr.PolynomialRidgeApproximation(degree=5, subspace_dimension=1, norm=2, bound='upper')
        pra.fit(X, y)
        U = pra.U
        return U @ U.T

    if estimator == 'dasm':
        params, _costs = train_net(X, y)
        # Normalise the first entry of the first parameter group to unit
        # Euclidean norm and form its outer product.
        # NOTE(review): assumes params[0][0] is a row vector (1 x P) so that
        # w.T @ w is P x P -- confirm against train_net.
        w = params[0][0]
        w = w / np.sqrt(np.sum(np.square(w)))
        return np.array(w.T @ w)

    raise Exception("Unknown estimator.")


# Row counter into the pre-allocated result frame `resdf`.
trial = 0
for N in tqdm(Ns):
    for P in Ps:
        for rep in range(reps):
            # Draw the true direction `a`.
            if subsparse:
                # Only (at most) Rs randomly chosen coordinates are nonzero.
                Rs = 3
                anz = np.random.normal(size=np.minimum(P, Rs))
                inz = np.random.choice(P, Rs, replace=False) if Rs < P else np.arange(P)
                a = np.zeros(P)
                a[inz] = anz
            else:
                a = np.random.normal(size=P)
            # Unit-norm version of the direction.
            an = a / np.sqrt(np.sum(np.square(a)))

            # Ridge test function: a cosine of the projection of the centred
            # input onto `an`. It reads the current `an`, so it must be used
            # within this iteration.
            def f(x):
                z = np.sum(an * (x - 0.5))
                return np.cos(3 * 2 * np.pi * z)

            # N uniform design points in the unit cube, evaluated row by row.
            X = np.random.uniform(size=[N, P])
            y = np.apply_along_axis(f, 1, X)

            for estimator in estimators:
                if estimator in slow_methods and N > slow_method_max:
                    # Too expensive at this sample size: record a NaN
                    # placeholder for both the timing and the score.
                    td = np.nan
                    ip = np.nan
                else:
                    # Only the fit / matrix estimate is timed.
                    tt = time()
                    Ch = _estimate_C(estimator, X, y, P, f)
                    td = time() - tt
                    # eigh returns eigenvalues in ascending order, so the
                    # last column is the leading eigenvector.
                    est = np.linalg.eigh(Ch)[1][:, -1]
                    # Absolute inner product with the true unit direction
                    # (the eigenvector's sign is arbitrary).
                    ip = np.abs(np.sum(est * an))

                resdf.iloc[trial, :] = [estimator, ip, td, N, P]
                trial += 1

resdf.to_csv(fname, index=False)
