#%% Imports and project modules

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

import sys
import os
import time
from sklearn.exceptions import ConvergenceWarning
from scipy import stats
import warnings
from scipy import interpolate

plt.close('all')

# Project root that holds the `Utils` and `Methods` packages imported below.
# Leave empty to run from the current working directory.
run_path = ''
# Resolve once, before any chdir, so that a relative run_path is not
# re-interpreted against the new working directory when added to sys.path.
project_root = os.path.abspath(run_path)
# os.chdir('') raises FileNotFoundError, so only change directory when a
# path was actually configured and it differs from the current one.
if run_path and os.path.abspath('.') != project_root:
    os.chdir(project_root)
# Avoid piling up duplicate entries when this cell is re-run interactively.
if project_root not in sys.path:
    sys.path.append(project_root)


import Utils.DFT as DFT
import Utils.Optimize as Opt
from Methods.Preprocessing import Pre_Processing
import Methods.GlobalPredict as GP
import Methods.Validation as Valid
import Methods.LocalPredict as LP

# Silence sklearn ConvergenceWarning messages raised by the project modules.
warnings.simplefilter("ignore", category=ConvergenceWarning)


#%% dataset

dataset = "ETTh2"

# Input window length (time steps) fed to the model.
train_len_new = 96

# Forecast horizon as a multiple of the input window. Values used so far:
# 1, 2, 4 and 7.5, i.e. horizons of 96, 192, 384 and 720 steps.
rate = 7.5
predict_len_new = int(train_len_new * rate)

# Total span covered by one window-plus-horizon sample.
data_len = train_len_new + predict_len_new
# Fraction of that span used as input (kept as 1 - horizon share).
train_rate_new = 1 - predict_len_new / data_len

# Stage 1: preprocessing and global prediction (GP.Global_Predict).
PP = Pre_Processing(dataset)
(x_global_predict, basis_need, Fourier_basis,
 SI_Global_MAE, start_time) = GP.Global_Predict(dataset, PP, predict_len_new)

# Stage 2: validation (Valid.SI_Validation); its outputs feed stage 3.
max_k_new, Valid_Net_MAE, Valid_MAE = Valid.SI_Validation(
    dataset,
    PP,
    data_len,
    train_len_new,
    predict_len_new,
    x_global_predict,
    basis_need,
    Fourier_basis,
    train_rate_new,
)

# Stage 3: local prediction (LP.Local_Predict), returning the final metrics.
predict_MAE_all, predict_MSE_all, end_time = LP.Local_Predict(
    dataset, PP, SI_Global_MAE, Valid_MAE, Valid_Net_MAE, data_len,
    x_global_predict, predict_len_new, max_k_new, basis_need,
    Fourier_basis, train_rate_new,
)

print('MSE:', predict_MSE_all, ' MAE:', predict_MAE_all)
# start_time comes from Global_Predict and end_time from Local_Predict;
# presumably both are wall-clock timestamps in seconds — TODO confirm.
print('time:', end_time - start_time)

