#%% import and function module

import numpy as np
import sys
import os
import time
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Project root containing the `Methods` and `Utils` packages.
# Leave empty to run from the current working directory; os.chdir('') would
# raise FileNotFoundError, so the change of directory is skipped in that case.
run_path = ''
if run_path and os.path.abspath('.') != os.path.abspath(run_path):
    os.chdir(run_path)
# Make the project-local packages importable (an empty entry means the cwd).
sys.path.append(run_path)

from Methods.Preprocessing import Pre_Processing
import Utils.DFT as DFT
import Utils.Optimize as Opt

#%% Validation module

def _blend_local_target(x_obs, x_global, a, weights, sep_len, predict_len):
    """Blend an observed window with the global forecast into a local fitting target.

    The returned array has the same length as ``x_obs``. With
    ``fit_end = len(x_obs) - predict_len`` it is laid out as:

    * ``[:sep_len]``        -> ``a * x_obs + (1 - a) * x_global``
    * ``[sep_len:fit_end]`` -> ``weights * x_obs + (1 - weights) * x_global``
    * ``[fit_end:]``        -> ``x_global`` (the forecast horizon)

    Neither input array is modified.
    """
    fit_end = len(x_obs) - predict_len
    head = a * x_obs[:sep_len] + (1 - a) * x_global[:sep_len]
    middle = (weights * x_obs[sep_len:fit_end]
              + (1 - weights) * x_global[sep_len:fit_end])
    tail = x_global[fit_end:]
    return np.concatenate((head, middle, tail))


def SI_Validation(dataset, PP, data_len, train_len_new, predict_len_new, x_global_predict,
                  basis_need, Fourier_basis, train_rate_new):
    """Roll a window over the validation segment and score local and net forecasts.

    For every window start ``i`` in ``range(0, PP.vaild_len - data_len + 1,
    PP.valid_step)``, the window covers rows
    ``[PP.train_len + i, PP.train_len + i + data_len)``. Its last
    ``predict_len_new`` points are the forecast horizon. Processing has
    two stages.

    1. Local stage, per variable:
       * Build a fitting target by blending the observed series with
         ``x_global_predict``.
       * Derive extra basis rows with ``DFT.SI_foriour``/``DFT.SI_basis``
         from the non-horizon part.
       * Fit ``Opt.SI_optimize`` on the stacked bases.
       * Accumulate the horizon MAE against the observations.
    2. Net stage, per variable:
       * Refit each local prediction from the other variables' local
         predictions with ``Opt.SI_optimize``.
       * Accumulate the horizon MAE.

    Args:
        dataset: Unused; kept for call compatibility.
        PP: Preprocessing/config object. The attributes read are ``t``, ``x``,
            ``train_len``, ``vaild_len``, ``valid_div``, ``valid_step``,
            ``variables_len``, ``a``, ``b``, ``c``, ``sep``, ``valid_opti``
            and ``opti_iter``. ``PP.x`` is indexed as ``[time, variable]``.
            (Presumably a ``Pre_Processing`` instance; confirm against the
            caller.)
        data_len: Length of each rolling window.
        train_len_new: Only used to derive ``max_k_new = train_len_new // PP.valid_div``.
        predict_len_new: Number of trailing window points that are scored. Must be >= 1.
        x_global_predict: Global forecast indexed as ``[time, variable]``.
        basis_need: Shared basis indexed as ``[row, time]``.
        Fourier_basis: Per-variable bases; ``Fourier_basis[j]`` is indexed as ``[row, time]``.
        train_rate_new: ``train_rate`` passed to ``Opt.SI_optimize`` for the local fit.

    Returns:
        tuple ``(max_k_new, valid_net_mae, valid_mae)``. Both MAE arrays have
        shape ``(PP.variables_len,)`` and are averaged over all windows.

    Raises:
        ValueError: If the configuration yields no validation window.
    """
    n_vars = PP.variables_len
    window_starts = range(0, PP.vaild_len - data_len + 1, PP.valid_step)
    n_windows = len(window_starts)
    if n_windows == 0:
        raise ValueError(
            'no validation windows: PP.vaild_len={} must be >= data_len={} '
            'and PP.valid_step={} must be positive'.format(
                PP.vaild_len, data_len, PP.valid_step))

    max_k_new = train_len_new // PP.valid_div
    # First index of the forecast horizon inside a window.
    fit_end = data_len - predict_len_new

    # Blending parameters are identical for every window and variable, so compute them once.
    # The weights go from PP.a at index sep_len to PP.c just before the horizon,
    # interpolated linearly in w**(1/b) space.
    sep_len = int(PP.sep * data_len)
    weights = np.linspace(PP.a ** (1 / PP.b), PP.c ** (1 / PP.b),
                          data_len - sep_len - predict_len_new) ** PP.b

    valid_mae = np.zeros(n_vars)
    valid_net_mae = np.zeros(n_vars)

    for i in window_starts:
        win = slice(PP.train_len + i, PP.train_len + i + data_len)
        t_new = PP.t[win]
        x_obs_step = PP.x[win, :]
        x_local_predict_step = np.zeros((n_vars, data_len))

        # Local predict validation
        for j in range(n_vars):
            x_obs = x_obs_step[:, j]
            target = _blend_local_target(x_obs, x_global_predict[win, j], PP.a,
                                         weights, sep_len, predict_len_new)

            basis_need_new = DFT.SI_basis(
                t_new,
                DFT.SI_foriour(t_new[:fit_end], target[:fit_end], max_k=max_k_new),
                max_k=max_k_new)
            stacked_basis = np.vstack((basis_need[:, win],
                                       basis_need_new,
                                       Fourier_basis[j][:, win]))
            x_local = Opt.SI_optimize(stacked_basis, target, train_rate=train_rate_new,
                                      def_alpha=PP.valid_opti, it=PP.opti_iter)
            x_local_predict_step[j, :] = x_local
            valid_mae[j] += mean_absolute_error(x_obs[fit_end:], x_local[fit_end:])

        # Net improve validation: explain each variable from the others' local predictions.
        for j in range(n_vars):
            others = x_local_predict_step[np.arange(n_vars) != j, :]
            x_net = Opt.SI_optimize(others, x_local_predict_step[j, :], train_rate=1,
                                    def_alpha=PP.valid_opti, it=PP.opti_iter)
            valid_net_mae[j] += mean_absolute_error(x_obs_step[fit_end:, j], x_net[fit_end:])

    valid_mae /= n_windows
    valid_net_mae /= n_windows
    return max_k_new, valid_net_mae, valid_mae

