#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os

import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import colormaps
from sklearn import tree
from sklearn.neighbors import KNeighborsRegressor
#from sklearn.ensemble import RandomForestRegressor

from python.loadnhanes import _load
from python.width_lib import C_tree, G_tree

# Fetch the feature table and the response from the project loader.
# X is treated as a pandas DataFrame throughout (it is indexed with .loc below).
X, y = _load()

# Flag every column that contains at least one NaN, keep the complement,
# and coerce what is left to float.
has_nan = np.isnan(X).any(axis=0)
nonans = ~has_nan
X = X.loc[:, nonans].astype(float)

# Quick-look scatter of the response against the 'age' column.
# NOTE: "temp.pdf" is a scratch file; the histogram written later in this
# script uses the same path and therefore replaces this plot.
fig, ax = plt.subplots()
ax.scatter(X['age'], y, alpha=0.2)
fig.savefig("temp.pdf")
plt.close(fig)

# Min-max scale every feature to [0, 1]; the synthetic points drawn further
# down are uniform on the unit cube, so the data must live on the same scale.
col_min = X.min(axis=0)
col_span = X.max(axis=0) - col_min
# A constant column has zero span, and 0/0 would fill it with NaN. Dividing by
# 1 instead maps such a column to all zeros and leaves every other column
# exactly as the plain (x - min) / (max - min) formula would.
col_span = col_span.where(col_span != 0, 1.0)
X = (X - col_min) / col_span

# Both regression trees share the same hyperparameters.
# (A shallower variant with max_depth=4 was used at some point.)
tree_kwargs = dict(max_depth=15, min_samples_leaf=10)

# Tree fitted on the single column at position 1, plus its in-sample predictions.
# NOTE(review): the naming implies that column is age -- confirm against the
# column order produced by the loader.
age_only = X.iloc[:, 1:2]
dt_age = tree.DecisionTreeRegressor(**tree_kwargs).fit(age_only, y)
age_pred = dt_age.predict(age_only)

# Tree fitted on the full feature table; this is the model analysed below.
dt = tree.DecisionTreeRegressor(**tree_kwargs).fit(X, y)

# Residual of the response after removing the single-feature tree's prediction.
yres = y - age_pred

# Rescale the residuals to [0, 1] using their OWN range. They are later fed to
# a matplotlib colormap, which expects floats in [0, 1]; shifting/scaling by
# the range of y (as was done before) leaves the residuals outside that
# interval and the colours get clipped.
res_lo, res_hi = np.min(yres), np.max(yres)
ynorm = (yres - res_lo) / (res_hi - res_lo)

# Scratch histogram of the normalised residuals (overwrites "temp.pdf").
fig = plt.figure()
plt.hist(ynorm)
plt.savefig("temp.pdf")
plt.close()

# Matrix derived from the fitted tree by the project helper; one row/column
# per feature. NOTE(review): eigh assumes it is symmetric -- confirm in C_tree.
Ch = C_tree(dt, X.shape[1])

# eigh returns eigenvalues in ascending order, so reverse them for a
# largest-first spectrum and read the leading eigenvectors off the last columns.
eigvals_asc, eigvecs = np.linalg.eigh(Ch)
evals = eigvals_asc[::-1]
v1, v2, v3 = (eigvecs[:, -rank] for rank in (1, 2, 3))

# Projection basis: second and third eigenvectors, one per row, labelled with
# the feature names. (Using v1 and v2 instead was the earlier choice.)
V = pd.DataFrame([v2, v3], columns=X.columns)

# Coordinates of the observed data in that two-dimensional basis.
Z = X.dot(V.T)

# Draw synthetic inputs uniformly on the unit cube and evaluate the tree there.
Nrand = 1000000
Xrand = np.random.uniform(size=[Nrand, X.shape[1]])
pred = dt.predict(Xrand)

# Put the predictions on the response's min-max scale.
y_lo, y_hi = np.min(y), np.max(y)
prednorm = (pred - y_lo) / (y_hi - y_lo)

# Project the synthetic inputs onto the same basis as the data; the result is
# used as a DataFrame (.iloc) below.
Zrand = Xrand @ V.T

# Plot only a random subset of the projected points.
nsub = 15000
sub = np.random.choice(Nrand, nsub, replace=False)
Zplot = Zrand.iloc[sub]

# Discard subset points that fall outside the bounding box of the projected
# data in either coordinate (bounds inclusive).
onegood = Zplot.iloc[:, 0].between(Z.iloc[:, 0].min(), Z.iloc[:, 0].max())
twogood = Zplot.iloc[:, 1].between(Z.iloc[:, 1].min(), Z.iloc[:, 1].max())
isgood = onegood & twogood
Zplot = Zplot.loc[isgood, :]

# Smooth the predictions in the projected plane with a 100-nearest-neighbour
# average fitted on all synthetic points, then stretch the smoothed values at
# the plotted points to [0, 1] for colouring.
knn = KNeighborsRegressor(n_neighbors=100).fit(Zrand, prednorm)
knnpred = knn.predict(Zplot)
knn_lo, knn_hi = np.min(knnpred), np.max(knnpred)
knnpred = (knnpred - knn_lo) / (knn_hi - knn_lo)

# Three side-by-side panels: spectrum, projected data, projected predictions.
fig, (ax_spec, ax_data, ax_pred) = plt.subplots(1, 3, figsize=[9, 3])
cool = colormaps['cool']

# Panel 1: eigenvalues with index startat..gotil-1 (the largest, index 0, is
# skipped), relative to the biggest of the ones shown, on a log scale.
startat = 1
gotil = 10
shown_idx = np.arange(startat, gotil)
rel_evals = evals[startat:gotil] / np.max(evals[startat:])
ax_spec.scatter(shown_idx, rel_evals)
ax_spec.set_yscale('log')
ax_spec.set_ylabel("Relative Magnitude")
ax_spec.set_xlabel("Index")
ax_spec.set_title("Spectrum")

# Panel 2: observed data in the projected plane, coloured by ynorm.
ax_data.scatter(Z.iloc[:, 0], Z.iloc[:, 1], c=cool(ynorm))
ax_data.set_title("Data Projection")
ax_data.set_xlabel("First Component")
ax_data.set_ylabel("Second Component")

# Panel 3: sampled synthetic points, coloured by the KNN-smoothed prediction.
ax_pred.set_title("Prediction Projection")
ax_pred.scatter(Zplot.iloc[:, 0], Zplot.iloc[:, 1], c=cool(knnpred))
ax_pred.set_xlabel("First Component")
ax_pred.set_ylabel("Second Component")

fig.tight_layout()
fig.savefig('eig.png')
plt.close(fig)

# Report, for each of the two projection directions (rows of V), the top_K
# features with the largest absolute loading. As bare expressions these were
# silently discarded when the file ran as a script, so print them instead.
top_K = 4
for component in (0, 1):
    print(np.abs(V.iloc[component, :]).sort_values(ascending=False)[:top_K])


## Vectors to LaTeX table.
# One row per leading eigenvector (v1, v2, v3), one column per feature.
vdf = pd.DataFrame([v1, v2, v3], columns=X.columns)

# Keep the top_K features with the largest summed squared loading across the
# three vectors, ordered from largest to smallest.
top_K = 10
top_vals = np.sum(np.square(vdf), axis=0).sort_values(ascending=False).index[:top_K]
sdf = vdf.loc[:, top_vals]
sdf.index = [1, 2, 3]

# Shorten the column labels to the part before the first underscore.
sdf.columns = [x.split('_')[0] for x in sdf.columns]

# Two-decimal formatting; note that sdf is a Styler from here on.
sdf = sdf.style.format(precision=2)

# Create the output directory if needed so the script does not fail at the
# very last step after all the expensive computation has already run.
os.makedirs("tables", exist_ok=True)
with open("tables/eig.tex", 'w', encoding="utf-8") as f:
    sdf.to_latex(f)

