#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.inspection import DecisionBoundaryDisplay
from python.width_lib import C_tree
from python.class_lib_beta import C_tree_class
from python.pred_bakeoff_settings import estimator, max_depth, fname, cp

import pandas as pd
import numpy as np
from sklearn import tree
from time import time
from tqdm import tqdm

# Load shared simulation settings into this module's namespace. Code below
# relies on it defining at least `datasets_to_use` and `data_dir`.
# A `with` block guarantees the file handle is closed.
with open("python/sim_settings.py") as settings_file:
    exec(settings_file.read())

np.random.seed(123)

# Number of cross-validation folds (the commented-out original value was 100).
n_folds = 10
print("Few folds!")

# Feature transforms compared per fold: identity, PCA, the matrix built from a
# fitted tree/forest via C_tree/C_tree_class ('asm'), and a random rotation.
trans = ['id', 'pca', 'asm', 'rand']

D = len(datasets_to_use)

# Per-row result accumulators (one entry per dataset x fold x transform).
errs = []
dataset = []
folds = []
trs = []

# One summary per (normG, precweigh) combination, plus the settings used.
resdfs = []
normGs = []
precweighs = []

obs_per_leaf = 5

# Model used to build the 'asm' matrix: 'tree' or 'forest'.
as_estimator = 'tree'

# factorial=True sweeps all combinations of normG and precweigh.
factorial = False
if factorial:
    ngs = [True, False]
    pws = [True, False]
else:
    ngs = [False]
    pws = [False]

for normG in tqdm(ngs, leave=False):
    for precweigh in tqdm(pws, leave=False):
        # Remaining C_tree / C_tree_class options; only normG and precweigh
        # are varied by the factorial sweep.
        weighwidth = False
        weighsamp = False

        # Reset the per-row accumulators for every (normG, precweigh)
        # combination; otherwise each later resdf in a factorial sweep would
        # also contain the rows of all earlier combinations.
        errs = []
        dataset = []
        folds = []
        trs = []

        for ds in tqdm(datasets_to_use, leave=False):
            outname = data_dir + ds + '.csv'
            df = pd.read_csv(outname)

            # Drop rows with any missing value (original note: "For infra.").
            df = df.dropna()
            print(ds)
            print(df.shape)

            if cp:
                # Binarize the response (last column) at its median.
                df.iloc[:, -1] = (df.iloc[:, -1] > np.median(df.iloc[:, -1])).astype(int)

            N = df.shape[0]
            inds = np.arange(N)
            np.random.shuffle(inds)
            test_inds = np.array_split(inds, n_folds)
            for f in tqdm(range(n_folds)):
                test_ind = test_inds[f]
                train_ind = np.setdiff1d(np.arange(N), test_ind)

                # Last column is the response, all others are features.
                X_test = np.array(df.iloc[test_ind, :-1])
                y_test = np.array(df.iloc[test_ind, -1])
                X_train = np.array(df.iloc[train_ind, :-1])
                y_train = np.array(df.iloc[train_ind, -1])

                P = X_train.shape[1]
                # Number of projected directions appended to the raw features.
                keep_K = int(np.ceil(np.sqrt(P)))

                for tr in tqdm(trans, leave=False):
                    # Every non-identity T has its columns ordered so that the
                    # most important directions come LAST; the selection
                    # T[:, -keep_K:] below relies on that ordering.
                    if tr == 'id':
                        T = np.eye(P)
                    elif tr == 'rand':
                        # [0] (not .Q) works on both tuple and named-tuple returns.
                        T = np.linalg.qr(np.random.normal(size=[P, P]))[0]
                    elif tr == 'pca':
                        mux = np.mean(X_train, axis=0)
                        sigx = np.std(X_train, axis=0)
                        Xn_train = (X_train - mux[np.newaxis, :]) / (1e-8 + sigx[np.newaxis, :])
                        Vt = np.linalg.svd(Xn_train, full_matrices=False)[2]
                        # svd orders components by DEcreasing singular value,
                        # unlike eigh (used for 'asm'), which is ascending.
                        # Reverse so the leading principal components are last.
                        T = Vt.T[:, ::-1]
                        # NOTE(review): T is fit on standardized data but applied
                        # to raw X below (same for 'asm'); confirm intended.
                    elif tr == 'asm':
                        xmin = np.min(X_train, axis=0)
                        xspan = np.max(X_train, axis=0) - xmin
                        # Constant columns would give 0/0 = NaN; map them to 0.
                        xspan = np.where(xspan > 0, xspan, 1.0)
                        Xn_train = (X_train - xmin[np.newaxis, :]) / xspan[np.newaxis, :]

                        if as_estimator == 'forest':
                            if cp:
                                raise ValueError("as_estimator='forest' does not support cp")
                            rf = RandomForestRegressor(max_depth=max_depth, min_samples_leaf=obs_per_leaf, n_jobs=-1)
                            rf.fit(Xn_train, y_train)
                            # Average the per-tree matrices over the forest.
                            n_trees = len(rf.estimators_)
                            Ch0 = np.zeros([P, P])
                            for dt in tqdm(rf.estimators_, leave=False):
                                Ct = C_tree(dt, P, normG=normG, weighwidth=weighwidth, weighsamp=weighsamp, mode='3', precweigh=precweigh)
                                if np.any(np.isnan(Ct)):
                                    raise ValueError(f"C_tree returned NaN (dataset {ds}, fold {f})")
                                Ch0 += Ct / n_trees
                        elif as_estimator == 'tree':
                            if cp:
                                dt = tree.DecisionTreeClassifier(max_depth=max_depth, random_state=0, min_samples_leaf=obs_per_leaf)
                                dt.fit(Xn_train, y_train)
                                Ch0 = C_tree_class(dt, P, normG=normG, weighwidth=weighwidth, weighsamp=weighsamp, mode='3', precweigh=precweigh)
                            else:
                                dt = tree.DecisionTreeRegressor(max_depth=max_depth, random_state=0, min_samples_leaf=obs_per_leaf)
                                dt.fit(Xn_train, y_train)
                                Ch0 = C_tree(dt, P, normG=normG, weighwidth=weighwidth, weighsamp=weighsamp, mode='3', precweigh=precweigh)
                        else:
                            raise ValueError(f"unknown as_estimator: {as_estimator!r}")

                        # eigh returns eigenvalues in ascending order, so the
                        # dominant eigenvectors of Ch0 are the last columns.
                        T = np.linalg.eigh(Ch0)[1]
                    else:
                        raise ValueError(f"bad tr: {tr!r}")

                    if tr == 'id':
                        Xt_train = X_train
                        Xt_test = X_test
                    else:
                        # Augment the raw features with keep_K projections.
                        proj = T[:, -keep_K:]
                        Xt_train = np.concatenate([X_train, X_train @ proj], axis=1)
                        Xt_test = np.concatenate([X_test, X_test @ proj], axis=1)

                    if estimator == 'forest':
                        if cp:
                            mod = RandomForestClassifier(max_depth=max_depth, min_samples_leaf=obs_per_leaf)
                        else:
                            mod = RandomForestRegressor(max_depth=max_depth, min_samples_leaf=obs_per_leaf)
                    elif estimator == 'tree':
                        if cp:
                            mod = tree.DecisionTreeClassifier(max_depth=max_depth, random_state=0, min_samples_leaf=obs_per_leaf)
                        else:
                            mod = tree.DecisionTreeRegressor(max_depth=max_depth, random_state=0, min_samples_leaf=obs_per_leaf)
                    else:
                        raise ValueError(f"unknown estimator: {estimator!r}")
                    mod.fit(Xt_train, y_train)
                    pred = mod.predict(Xt_test)

                    # With cp (0/1 labels) this is sqrt(misclassification rate).
                    rmse = np.sqrt(np.mean(np.square(pred - y_test)))
                    errs.append(rmse)
                    dataset.append(ds)
                    folds.append(f)
                    trs.append(tr)

        # A dict-built frame keeps per-column dtypes. The transposed
        # list-of-lists form made every column object dtype.
        resdf = pd.DataFrame({'Trans': trs, 'Dataset': dataset, 'Fold': folds, 'RMSE': errs})

        resdfs.append(resdf.groupby(['Dataset', 'Trans']).mean())
        normGs.append(normG)
        precweighs.append(precweigh)
        #weighsamps.append(weighsamp)

import os

# Write the per-fold results of the (single) run. In factorial mode only the
# in-memory `resdfs` summaries are produced; nothing is written here.
if not factorial:
    # Create the output directory if needed so to_csv does not fail on it.
    os.makedirs('sim_out', exist_ok=True)
    resdf.to_csv(os.path.join('sim_out', fname + '.csv'), index=False)

