#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec  7 17:27:18 2020

@author: pooya
"""

######################## set parameters
# Base name of the dataset file; the script loads '<dataset_name>_mat.mat'.
dataset_name = 'cadata'

# Integer hyper-parameters forwarded to model.SRNN_init(Xtr, Ytr, K_y, K_x).
# NOTE(review): their meaning is defined by SRNN_init -- confirm there.
K_y, K_x = 3, 9
######################## load libraries

import scipy.io as sio
from model import SRNN_init
from model import NN_model
from model import SRNN_Model
import numpy as np

######################## load dataset
# The .mat file must contain the keys 'Xtr', 'Ytr', 'X_test' and 'Y_test'.
# With loadmat's defaults these come back as 2-D arrays. They are used as-is;
# the targets are deliberately not flattened to 1-D.
# No separate validation split is loaded; only train and test data are used.
mat_contents = sio.loadmat(dataset_name + '_mat.mat')
Xtr, Ytr = mat_contents['Xtr'], mat_contents['Ytr']
X_test, Y_test = mat_contents['X_test'], mat_contents['Y_test']


######################## run init


def _mean_squared_error(Y_true, Y_pred):
    """Return the squared residual summed over output columns, averaged over rows.

    For single-output targets this is the ordinary mean squared error.

    Args:
        Y_true: target array; its first axis indexes samples.
        Y_pred: predictions for the same samples.

    Raises:
        ValueError: if ``Y_pred`` cannot be aligned with ``Y_true`` without
            reordering elements.
    """
    Y_true = np.asarray(Y_true)
    Y_pred = np.asarray(Y_pred)
    if Y_pred.shape != Y_true.shape:
        # A 1-D (N,) prediction subtracted from (N, 1) targets would broadcast
        # to an (N, N) residual and silently yield a meaningless error.
        # Only allow shapes that differ by singleton axes; reshaping those
        # keeps the element order intact.
        if np.squeeze(Y_pred).shape != np.squeeze(Y_true).shape:
            raise ValueError(
                "prediction shape %s does not match target shape %s"
                % (Y_pred.shape, Y_true.shape)
            )
        Y_pred = Y_pred.reshape(Y_true.shape)
    return np.sum(np.square(Y_true - Y_pred)) / Y_true.shape[0]


X_C, Y_C = SRNN_init(Xtr, Ytr, K_y, K_x)

# Baseline: a plain NN_model built from the (X_C, Y_C) returned by SRNN_init.
NN = NN_model(X_C, Y_C)
Y_predict = NN.predict(Xtr)
Error_train = _mean_squared_error(Ytr, Y_predict)
print(Error_train)
Y_predict = NN.predict(X_test)
Error_test = _mean_squared_error(Y_test, Y_predict)
print(Error_test)

# Fine-tune the SRNN model with a step-wise decaying learning rate:
# 10 iterations at each rate, from 1e-4 down to 1e-9.
SRNN_LEARNING_RATES = (1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9)
SRNN = SRNN_Model(X_C, Y_C)
for lr in SRNN_LEARNING_RATES:
    SRNN.fit(Xtr, Ytr, lr=lr, X_test=X_test, Y_test=Y_test, iter=10)

print('SRNN Model:')
Y_predict = SRNN.predict(Xtr)
Error_train = _mean_squared_error(Ytr, Y_predict)
print(Error_train)
Y_predict = SRNN.predict(X_test)
Error_test = _mean_squared_error(Y_test, Y_predict)
print(Error_test)