#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 21 11:05:23 2020

@author: pooya
"""

################## import files
import scipy.io as sio
from scipy.io import savemat
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import AdaBoostRegressor
from sklearn.ensemble import BaggingRegressor
from sklearn.cluster import KMeans
from sklearn.linear_model import RidgeCV
from sklearn.linear_model import LinearRegression
from model import SRNN_init
from model import NN_model
from model_RBFs import RBF_lin_reg as RBF
import numpy as np
from scipy.io import savemat
import os  # stdlib; builds and creates the output directory for the results file

################## set parameters

dataset_name = 'pyrim'

# K is the ensemble size for RF / AdaBoost / Bagging, the leaf budget (2**depth <= K)
# for the single tree, and is passed straight to SRNN_init and RBF_lin_reg
# (presumably their number of centres -- confirm in model.py / model_RBFs.py).
K = 32

# Random seed for the randomised scikit-learn estimators. None keeps scikit-learn's
# default (non-reproducible) behaviour; set an int to make runs repeatable.
seed = None

# Results are written to <out_dir>/BC=<K>.mat.
out_dir = os.path.join('/home/pooya/Desktop/Projects_codes/Project10_(Reg_SRNN)/Experiments',
                       dataset_name, 'Other_models')


def mse(Y_true, Y_pred):
    """Return the squared error per sample, summed over output columns.

    Both arrays are reshaped to ``(n_samples, -1)``, where ``n_samples`` is the
    length of ``Y_pred``'s first axis. This lets a 1-D prediction vector line up
    element-wise with a column- or row-vector target. ``scipy.io.loadmat`` always
    returns 2-D arrays, and without the reshape they would silently broadcast to
    an ``(n, n)`` matrix.

    Parameters
    ----------
    Y_true : array_like
        Ground-truth targets holding ``n_samples * n_outputs`` values.
    Y_pred : array_like
        Predictions whose first axis indexes samples.

    Returns
    -------
    float
        ``sum((Y_true - Y_pred) ** 2) / n_samples``.

    Raises
    ------
    ValueError
        If there are no predictions, or the two arrays cannot be aligned.
    """
    Y_pred = np.asarray(Y_pred, dtype=float)
    if Y_pred.ndim == 0 or Y_pred.shape[0] == 0:
        raise ValueError('mse needs at least one prediction')
    n_samples = Y_pred.shape[0]
    Y_pred = Y_pred.reshape(n_samples, -1)
    # numpy raises ValueError here if the element counts are not compatible.
    Y_true = np.asarray(Y_true, dtype=float).reshape(n_samples, -1)
    if Y_true.shape != Y_pred.shape:
        raise ValueError(f'cannot align targets {Y_true.shape} with predictions {Y_pred.shape}')
    return float(np.sum((Y_true - Y_pred) ** 2) / n_samples)


def evaluate(model, X, Y):
    """Return ``mse(Y, model.predict(X))`` for an already fitted ``model``."""
    return mse(Y, model.predict(X))


if __name__ == '__main__':
    ################## load data
    mat_contents = sio.loadmat(dataset_name + '_mat.mat')
    Xtr = mat_contents['Xtr']
    Ytr = mat_contents['Ytr']
    X_test = mat_contents['X_test']
    Y_test = mat_contents['Y_test']

    ################## scikit-learn baselines
    # A depth-d binary tree has at most 2**d leaves, so floor(log2(K)) keeps the
    # single tree within K leaves. It must be an int: recent scikit-learn versions
    # reject a float max_depth such as the raw np.log2(K).
    tree_depth = max(1, int(np.log2(K)))

    # Each entry is (progress message, label for the printed errors,
    # key stem for the saved .mat file, unfitted estimator).
    # AdaBoost/Bagging get their base tree positionally. The keyword was renamed
    # base_estimator -> estimator in scikit-learn 1.2 and the old name was removed
    # in 1.4, but it is the first parameter in both old and new releases.
    sklearn_models = [
        ('training tree', 'tree', 'tree',
         DecisionTreeRegressor(max_depth=tree_depth, random_state=seed)),
        ('training RF', 'RF', 'RF',
         RandomForestRegressor(max_depth=3, n_estimators=K, max_features=None,
                               random_state=seed)),
        ('training regboost', 'regboost', 'boost',
         AdaBoostRegressor(DecisionTreeRegressor(max_depth=3), n_estimators=K,
                           random_state=seed)),
        ('training Bagging', 'Bag', 'bag',
         BaggingRegressor(DecisionTreeRegressor(max_depth=3), n_estimators=K,
                          max_features=1.0, random_state=seed)),
        ('training ridgeCV', 'ridgeCV', 'ridge', RidgeCV()),
        ('training lin-reg', 'lin-reg', 'lin', LinearRegression()),
    ]

    # Maps the original variable names (Err_<stem>_{test,train}_no, ...) to errors.
    errors = {}
    for progress_msg, label, stem, estimator in sklearn_models:
        print(progress_msg)
        # The estimators are single-output, so targets are flattened for fitting.
        estimator.fit(Xtr, Ytr.ravel())
        errors[f'Err_{stem}_test_no'] = evaluate(estimator, X_test, Y_test)
        print(label + ' test', errors[f'Err_{stem}_test_no'])
        errors[f'Err_{stem}_train_no'] = evaluate(estimator, Xtr, Ytr)
        print(label + ' train', errors[f'Err_{stem}_train_no'])

    ################## k-means initialisation + nearest-neighbour model
    print('training k-means')
    X_C, Y_C = SRNN_init(Xtr, Ytr, 1, K)
    NN = NN_model(X_C, Y_C)
    errors['Err_Kmeans_test'] = evaluate(NN, X_test, Y_test)
    print('k-means test', errors['Err_Kmeans_test'])
    errors['Err_Kmeans_train'] = evaluate(NN, Xtr, Ytr)
    print('k-means train', errors['Err_Kmeans_train'])

    ################## RBF model
    print('training RBF-SVM')
    rbf_model = RBF(K)
    # NOTE(review): the held-out test set is passed as the validation set. If
    # RBF_lin_reg uses (Xv, Yv) for model selection, as the names suggest, the RBF
    # test error below is optimistically biased. This is kept as-is so earlier
    # results can be reproduced; pass a split of the training data instead for an
    # unbiased comparison.
    rbf_model.fit(Xtr, Ytr, Xv=X_test, Yv=Y_test)
    errors['Err_RBF_test_no'] = evaluate(rbf_model, X_test, Y_test)
    print('RBF-SVM test', errors['Err_RBF_test_no'])
    errors['Err_RBF_train_no'] = evaluate(rbf_model, Xtr, Ytr)
    print('RBF-SVM train', errors['Err_RBF_train_no'])

    ################## save
    # Create the directory first; otherwise a missing folder would discard every
    # trained result at the very last step.
    os.makedirs(out_dir, exist_ok=True)
    # savemat appends '.mat' to this extension-less name (appendmat=True default).
    savemat(os.path.join(out_dir, 'BC=' + str(K)), errors)