#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 12 09:15:13 2021

@author: pooya
"""

######################## set parameters
# Base name of the input file; the script loads '<dataset_name>_mat.mat'.
dataset_name = 'pyrim'

# Passed unchanged to SRNN_init as its K_y argument for every run.
K_y = 2
# Values of K_x swept by the experiment loop (one train/save cycle each).
KK_x = [
    2, 4, 5, 6, 7, 8, 9, 10,
    16, 32,
]
######################## load libraries

import os

import numpy as np
import scipy.io as sio
from scipy.io import savemat
from sklearn import preprocessing

from model import NN_model
from model import SRNN_init
from model import SRNN_Model
######################## load dataset
mat_contents = sio.loadmat(dataset_name + '_mat.mat')

# Train and test splits stored in the .mat file. loadmat does not squeeze
# unit dimensions by default, so targets come back as 2-D arrays.
Xtr, Ytr = (mat_contents[name] for name in ('Xtr', 'Ytr'))
X_test, Y_test = (mat_contents[name] for name in ('X_test', 'Y_test'))

# Disabled alternatives kept from earlier experiments:
#   - reshaping Ytr / Y_test into explicit (N, 1) column vectors;
#   - carving a validation split (Xv, Yv) out of Xtr with a random
#     permutation (not used: no validation set is evaluated below);
#   - standardizing the features with a scaler fitted on the training set:
#scaler = preprocessing.StandardScaler().fit(Xtr)
#Xtr = scaler.transform(Xtr)
#X_test = scaler.transform(X_test)
######################## run init
# Results for each K_x are written to '<output_dir>K=<K_x>.npy' and
# '<output_dir>K=<K_x>.mat' (np.save / savemat append the extension).
output_dir = ('/home/pooya/Desktop/Projects_codes/Project10_(Reg_SRNN)/Experiments/'
              + dataset_name + '/SRNN/')
# Create the directory up front so a missing folder fails (or is fixed)
# before training, not after the first expensive run.
os.makedirs(output_dir, exist_ok=True)


def _mean_squared_error(Y_true, Y_pred):
    """Return the summed squared error per sample, averaged over samples.

    Squared differences are summed over axis 1 (the output columns), then
    over all rows, and the total is divided by the number of rows.

    Args:
        Y_true: 2-D target array of shape (N, K).
        Y_pred: predictions; either the same shape as Y_true or a flat
            vector with the same number of elements.

    Returns:
        A numpy float scalar.

    Raises:
        ValueError: if a flat Y_pred cannot be reshaped to Y_true's shape.
    """
    Y_pred = np.asarray(Y_pred)
    if Y_pred.ndim == 1:
        # A flat prediction vector would broadcast against an (N, 1) target
        # into an (N, N) matrix and silently inflate the error.
        Y_pred = Y_pred.reshape(Y_true.shape)
    squared = np.power(Y_true - Y_pred, 2)
    return np.sum(np.sum(squared, axis=1)) / Y_true.shape[0]


for K_x in KK_x:
    # Initial centres and targets from the project's SRNN_init helper.
    X_C, Y_C = SRNN_init(Xtr, Ytr, K_y, K_x)

    # Baseline: an NN_model built directly from the initial centres.
    NN = NN_model(X_C, Y_C)
    Y_predict = NN.predict(Xtr)
    Error_train_init = _mean_squared_error(Ytr, Y_predict)
    print(Error_train_init)
    Y_predict = NN.predict(X_test)
    Error_test_init = _mean_squared_error(Y_test, Y_predict)
    print(Error_test_init)

    # Snapshot of the initial centres (copies, in case SRNN_Model updates
    # X_C / Y_C in place -- not verifiable from this file).
    init_ = [np.copy(X_C), np.copy(Y_C)]

    SRNN = SRNN_Model(X_C, Y_C)

    # NOTE(review): iter=-8 is negative, while every schedule stage below
    # uses iter=8. Confirm against SRNN_Model.fit whether a negative count
    # is intentional or a typo that skips training.
    SRNN.fit(Xtr, Ytr, lr=1e-4, X_test=X_test, Y_test=Y_test, iter=-8)
    # Optional further stages: the same call with lr = 1e-5, 1e-6, 1e-7,
    # 1e-8 and 1e-9, each with iter=8.

    print('SRNN Model:')
    Y_predict = SRNN.predict(Xtr)
    Error_train = _mean_squared_error(Ytr, Y_predict)
    print(Error_train)
    Y_predict = SRNN.predict(X_test)
    Error_test = _mean_squared_error(Y_test, Y_predict)
    print(Error_test)

    loc = output_dir + 'K=' + str(K_x)
    # The list holds the model object, so np.save stores a pickled object
    # array; reading it back needs np.load(..., allow_pickle=True).
    file_to_save = [SRNN, Error_train, Error_test, Error_train_init, Error_test_init]
    np.save(loc, file_to_save)
    savemat(loc, {'Error_train': Error_train, 'Error_test': Error_test,
                  'Error_test_init': Error_test_init,
                  'Error_train_init': Error_train_init,
                  'X_C': X_C, 'Y_C': Y_C, 'K_y': K_y})