#%% import and function module

import numpy as np
import sys
import os
import time
from sklearn.metrics import mean_absolute_error

# Project root that contains the local ``Methods`` and ``Utils`` packages.
# NOTE(review): machine-specific absolute path; adjust per environment.
run_path = 'D://研究生学习//论文部分//Long-Short Term Identification//ICLR//程序//try_code'
# Compare normalised paths: os.path.abspath returns OS-native separators, so a
# raw string comparison against the '//'-separated literal could never match.
if (os.path.normcase(os.path.abspath('.'))
        != os.path.normcase(os.path.normpath(run_path))):
    os.chdir(run_path)
# Avoid accumulating duplicate sys.path entries when the cell is re-run
# in the same interpreter session.
if run_path not in sys.path:
    sys.path.append(run_path)

from Methods.Preprocessing import Pre_Processing
import Utils.DFT as DFT
import Utils.Optimize as Opt

#%% Global Predict

def _rolling_fourier_components(t, x_col, window_len, end, max_k):
    """Pool the first ``max_k`` DFT.SI_foriour outputs over rolling windows.

    Window starts are ``0, window_len, 2*window_len, ...`` while the start is
    below ``end``; each window is ``t[j:j + window_len]`` paired with
    ``x_col[j:j + window_len]``. The last window may therefore extend past
    ``end`` (or be shortened by the array length).

    Values from all windows are collected into a set, so a value produced by
    several windows appears only once. The result is returned as a 1-D
    ``np.array`` in set-iteration order.
    """
    components = set()
    for j in range(0, end, window_len):
        num = DFT.SI_foriour(t[j:j + window_len], x_col[j:j + window_len],
                             max_k=max_k)
        # Indexing (rather than slicing) keeps the original behaviour of
        # raising if SI_foriour returns fewer than max_k values.
        components.update(num[k] for k in range(max_k))
    return np.array(list(components))


def Global_Predict(dataset, PP, predict_len_new):
    """Fit a global Fourier-basis model per variable and build rolling bases.

    For every column ``i`` of ``PP.x``:

    1. ``DFT.SI_foriour`` is applied to the first ``PP.foriour_len_global``
       samples with ``max_k=PP.max_k_global``. Presumably this returns the
       dominant Fourier components; confirm against ``Utils.DFT``.
    2. A basis over the full ``PP.t`` is built from them with
       ``DFT.SI_basis``. It is fitted to the column with ``Opt.SI_optimize``.
    3. The fit is scored by mean absolute error on rows
       ``[PP.train_len:PP.predict_point]``.
    4. Windows of length ``predict_len_new`` (step ``predict_len_new``)
       starting below ``PP.predict_point`` are passed to
       ``DFT.SI_foriour(max_k=PP.max_k_roll)``. The distinct results are
       pooled and turned into a basis with ``DFT.SI_basis``.

    Parameters
    ----------
    dataset
        Not used by this function; kept for interface compatibility.
    PP
        Presumably a ``Methods.Preprocessing.Pre_Processing`` instance
        (confirm against the caller). The attributes read are:

        - ``x``: 2-D array, rows indexed like ``t``, one column per variable.
        - ``t``, ``variables_len``, ``max_k_global``,
          ``foriour_len_global``, ``max_k_roll``.
        - ``train_rate``, ``train_len``, ``predict_point``, ``global_opti``,
          ``opti_iter``.
    predict_len_new : int
        Length and step of the rolling windows. It must be non-zero, because
        ``range`` rejects a zero step.

    Returns
    -------
    x_global_predict : ndarray
        Same shape as ``PP.x``. Column ``i`` holds the fitted global
        prediction. The dtype is forced to be inexact.
    basis_need
        Global basis of the LAST variable only, or ``None`` when
        ``PP.variables_len == 0``.
    Fourier_basis : list
        One rolling-window basis per variable.
    SI_Global_MAE : ndarray, shape (PP.variables_len,)
        Per-variable MAE on rows ``[PP.train_len:PP.predict_point]``.
    start_time : float
        ``time.time()`` sampled just before the per-variable loop.
    """
    max_k_global = PP.max_k_global
    foriour_len_global = PP.foriour_len_global
    max_k_roll = PP.max_k_roll

    SI_Global_MAE = np.zeros(PP.variables_len)
    Basis_Global = np.zeros((max_k_global, PP.variables_len))

    # np.zeros_like on an integer array would silently truncate the fitted
    # predictions on assignment, so force an inexact output dtype.
    out_dtype = np.asarray(PP.x).dtype
    if not np.issubdtype(out_dtype, np.inexact):
        out_dtype = np.float64
    x_global_predict = np.zeros_like(PP.x, dtype=out_dtype)

    Fourier_basis = []
    # Bound up front so an empty variable set returns None instead of
    # raising UnboundLocalError at the return statement.
    basis_need = None

    start_time = time.time()

    for i in range(PP.variables_len):
        x_col = PP.x[:, i]

        # Global prediction: components from the leading segment, basis and
        # fit over the whole series.
        Basis_Global[:, i] = DFT.SI_foriour(PP.t[:foriour_len_global],
                                            x_col[:foriour_len_global],
                                            max_k=max_k_global)
        basis_need = DFT.SI_basis(PP.t, Basis_Global[:, i], max_k=max_k_global)
        x_global_pred = Opt.SI_optimize(basis_need, x_col,
                                        train_rate=PP.train_rate,
                                        def_alpha=PP.global_opti,
                                        it=PP.opti_iter)
        SI_Global_MAE[i] = mean_absolute_error(
            x_global_pred[PP.train_len:PP.predict_point],
            x_col[PP.train_len:PP.predict_point])
        x_global_predict[:, i] = x_global_pred

        # Rolling Fourier basis.
        # NOTE(review): these windows also cover [train_len:predict_point],
        # the span scored above. If that span is held-out data, this basis
        # has seen it; confirm this is intended.
        fourier_cov = _rolling_fourier_components(
            PP.t, x_col, predict_len_new, PP.predict_point, max_k_roll)
        Fourier_basis.append(
            DFT.SI_basis(PP.t, fourier_cov, max_k=len(fourier_cov)))

    return x_global_predict, basis_need, Fourier_basis, SI_Global_MAE, start_time

    
    
    