#%% import and function module

import numpy as np
import sys
import os
import time
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Project root containing the `Methods` and `Utils` packages imported below.
# Leave empty to keep the current working directory.
run_path = ''
if run_path and os.path.abspath('.') != os.path.abspath(run_path):
    # os.chdir('') raises FileNotFoundError, so only switch when a root is set;
    # comparing absolute paths also avoids a redundant chdir for a relative root.
    os.chdir(run_path)
# Make the project packages importable ('' resolves against the working directory).
sys.path.append(run_path)

from Methods.Preprocessing import Pre_Processing
import Utils.DFT as DFT
import Utils.Optimize as Opt

#%% Local predict

def _blend_with_observed(x_obs, x_glob, a, sep_len, blend_w, fit_end):
    """Build the local fitting target for one variable over one window.

    Samples ``[0, sep_len)`` mix observed and global values with a constant
    weight ``a`` on the observations. Samples ``[sep_len, fit_end)`` use the
    per-sample observation weights ``blend_w``. The forecast horizon
    ``[fit_end, end)`` is copied from the global forecast. The inputs are not
    modified.
    """
    target = np.zeros_like(x_obs, dtype=float)
    # blend_w starts at `a`, so the constant-weight prefix joins it continuously.
    target[:sep_len] = a * x_obs[:sep_len] + (1 - a) * x_glob[:sep_len]
    target[sep_len:fit_end] = (blend_w * x_obs[sep_len:fit_end]
                               + (1 - blend_w) * x_glob[sep_len:fit_end])
    target[fit_end:] = x_glob[fit_end:]
    return target


def Local_Predict(dataset, PP, SI_Global_MAE, Valid_MAE, Valid_Net_MAE, data_len,
                  x_global_predict, predict_len_new, max_k_new, basis_need,
                  Fourier_basis, train_rate_new):
    """Refine the global forecast window by window and score it on the horizon.

    The prediction range is walked in turns. Turn ``i`` steps by
    ``PP.predict_step`` over ``range(PP.predict_len - data_len + 1)`` and
    covers the ``data_len`` samples starting at ``PP.predict_point + i``.
    Within each window, the last ``predict_len_new`` samples are the forecast
    horizon and the samples before them are the fit region.

    Per variable ``j``:

    * If ``SI_Global_MAE / Valid_MAE < PP.local_judge_1`` and
      ``SI_Global_MAE < PP.local_judge_2``, the global forecast is used as-is.
    * Otherwise a target is built by blending the observed series into the
      global forecast over the fit region (``_blend_with_observed``). A
      Fourier basis is derived from that region with ``DFT``, and
      ``Opt.SI_optimize`` is run on the stacked bases (``basis_need``, the new
      basis, ``Fourier_basis[j]``).
    * Non-global variables with ``Valid_Net_MAE < Valid_MAE`` are then
      re-predicted by ``Opt.SI_optimize`` from the other variables' local
      predictions for the same window.

    Args:
        dataset: Unused; kept for call-site compatibility.
        PP: Configuration/data object. Reads ``x`` (indexed ``[time, var]``),
            ``t``, ``predict_point``, ``predict_len``, ``predict_step``,
            ``variables_len``, ``local_judge_1``/``local_judge_2``, the blend
            parameters ``a``, ``b``, ``c``, ``sep`` and the optimizer settings
            ``local_opti``, ``local_opti_2``, ``opti_iter``.
        SI_Global_MAE, Valid_MAE, Valid_Net_MAE: Arrays indexed per variable,
            used to choose the strategy above.
        data_len: Window length in samples.
        x_global_predict: Global forecast indexed ``[time, var]``; not modified.
        predict_len_new: Horizon length; assumed ``0 < predict_len_new < data_len``.
        max_k_new: ``max_k`` passed to the DFT helpers.
        basis_need: Shared basis, sliced on axis 1 by time.
        Fourier_basis: Per-variable bases, each sliced on axis 1 by time.
        train_rate_new: ``train_rate`` for the local ``Opt.SI_optimize`` fit.

    Returns:
        ``(predict_MAE_all, predict_MSE_all, end_time)``:
        ``predict_MAE_all`` and ``predict_MSE_all`` are the horizon MAE and
        MSE, averaged over variables and turns. ``end_time`` is the
        ``time.time()`` reading taken when the computation finishes.

    Raises:
        ValueError: If the prediction range yields no turns
            (``PP.predict_len < data_len``), or if a horizon error is not
            finite (NaN/inf in a prediction or observation).
    """
    is_global = np.asarray(
        (SI_Global_MAE / Valid_MAE < PP.local_judge_1)
        & (SI_Global_MAE < PP.local_judge_2), dtype=bool)
    is_net = np.asarray(Valid_Net_MAE < Valid_MAE, dtype=bool)
    # Network refinement applies only to variables that were fitted locally.
    use_net = is_net & ~is_global

    n_vars = PP.variables_len
    turn_starts = range(0, PP.predict_len - data_len + 1, PP.predict_step)
    n_turns = len(turn_starts)
    if n_turns == 0:
        raise ValueError(
            "no prediction turns: PP.predict_len must be at least data_len")
    # First window index of the forecast horizon.
    fit_end = data_len - predict_len_new

    if not is_global.all():
        # The blend depends only on the window geometry, so compute it once.
        a, b, c = PP.a, PP.b, PP.c
        sep_len = int(PP.sep * data_len)
        # Observation weight over [sep_len, fit_end): runs from a to c, with
        # the curve's shape controlled by the exponent b.
        blend_w = np.linspace(a ** (1 / b), c ** (1 / b), fit_end - sep_len) ** b

    predict_MAE = np.zeros(n_vars)
    predict_MSE = np.zeros(n_vars)

    for start in turn_starts:
        lo = PP.predict_point + start
        hi = lo + data_len
        t_new = PP.t[lo:hi]
        x_obs = PP.x[lo:hi, :]
        x_glob = x_global_predict[lo:hi, :]

        x_local = np.zeros((n_vars, data_len))
        for j in range(n_vars):
            if is_global[j]:
                x_local[j, :] = x_glob[:, j]
                continue
            target = _blend_with_observed(
                x_obs[:, j], x_glob[:, j], a, sep_len, blend_w, fit_end)
            basis_new = DFT.SI_basis(
                t_new,
                DFT.SI_foriour(t_new[:fit_end], target[:fit_end], max_k=max_k_new),
                max_k=max_k_new)
            basis = np.vstack((basis_need[:, lo:hi], basis_new,
                               Fourier_basis[j][:, lo:hi]))
            x_local[j, :] = Opt.SI_optimize(
                basis, target, train_rate=train_rate_new,
                def_alpha=PP.local_opti, it=PP.opti_iter)

        # Every network fit uses the other variables' pre-network predictions,
        # so collect all results first and substitute them afterwards.
        x_net = np.zeros((n_vars, data_len))
        for j in np.flatnonzero(use_net):
            x_net[j, :] = Opt.SI_optimize(
                x_local[np.arange(n_vars) != j, :], x_local[j, :],
                train_rate=1, def_alpha=PP.local_opti_2, it=PP.opti_iter)
        x_local[use_net] = x_net[use_net]

        # Horizon error, shape (n_vars, predict_len_new).
        err = x_local[:, fit_end:] - x_obs[fit_end:, :].T
        if not np.isfinite(err).all():
            raise ValueError("prediction or observation contains NaN or infinity")
        predict_MAE += np.abs(err).mean(axis=1)
        predict_MSE += np.square(err).mean(axis=1)

    predict_MAE /= n_turns
    predict_MSE /= n_turns
    # The mean over variables of per-variable means equals the flat average.
    predict_MAE_all = predict_MAE.mean()
    predict_MSE_all = predict_MSE.mean()
    end_time = time.time()
    return predict_MAE_all, predict_MSE_all, end_time

