import argparse
import importlib
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

from easytsf.data.dataset import data_provider
from easytsf.util.metrics import eval_metrics
from tqdm import tqdm

def load_config(dataset_name, exp_type, **kwargs):
    """Build the run configuration for a dataset / experiment-type pair.

    Entries are merged in increasing priority:
      1. ``config.data_config.<dataset_name>_conf`` (dataset settings),
      2. ``config.base_conf.<exp_type>.exp_conf`` (experiment settings),
      3. keyword overrides passed in ``kwargs``.

    Args:
        dataset_name (str): dataset name; looked up as ``<dataset_name>_conf``
            in ``config.data_config``.
        exp_type (str): module name under ``config.base_conf``.
        **kwargs: entries that override the merged configuration.

    Returns:
        dict: a new dict; the source config dicts themselves are not modified.

    Raises:
        ModuleNotFoundError: if ``exp_type`` does not name an importable module.
        AttributeError: if ``config.data_config`` has no ``<dataset_name>_conf``.
    """
    # Load exp_conf
    exp_conf = importlib.import_module('config.base_conf.{}'.format(exp_type)).exp_conf

    # Load data_conf. Use getattr rather than eval: dataset_name comes from the
    # command line and must never be evaluated as Python code.
    data_conf_module = importlib.import_module('config.data_config')
    data_conf = getattr(data_conf_module, '{}_conf'.format(dataset_name))

    # Later sources win: experiment settings override dataset settings, and
    # explicit keyword overrides win over both.
    final_config = {**data_conf, **exp_conf}
    final_config.update(kwargs)
    return final_config


class HI:
    """
    Historical Inertia (HI) baseline: forecast the next ``pred_len`` steps as a
    copy of the most recent ``pred_len`` observed steps.

    Paper: Historical Inertia: A Neglected but Powerful Baseline for Long Sequence Time-series Forecasting
    Link: https://arxiv.org/abs/2103.16349
    """

    def __init__(self, hist_len: int, pred_len: int, reverse=False):
        """
        Args:
            hist_len (int): input time series length
            pred_len (int): prediction time series length; must satisfy 1 <= pred_len <= hist_len
            reverse (bool, optional): if True, reverse the prediction along the time axis.
                Defaults to False.
        """
        # pred_len == 0 would turn the slice ``[-0:]`` into ``[0:]`` and silently
        # return the whole history instead of an empty forecast.
        assert pred_len >= 1, "HI model requires output length >= 1"
        assert hist_len >= pred_len, "HI model requires input length >= output length"
        self.hist_len = hist_len
        self.pred_len = pred_len
        self.reverse = reverse

    def forward(self, history_data: torch.Tensor) -> torch.Tensor:
        """Forward function of HI.

        Args:
            history_data (torch.Tensor): shape = [..., L_hist, N]. Time is the
                second-to-last axis; any leading (e.g. batch) dims are kept.

        Returns:
            torch.Tensor: model prediction [..., L_pred, N].
        """
        # Index the time axis explicitly (dim -2). Slicing dim 0 cuts the batch
        # axis when callers pass batched [B, L_hist, N] tensors.
        prediction = history_data[..., -self.pred_len:, :]
        if self.reverse:
            prediction = prediction.flip(dims=[-2])
        return prediction


def inverse_transform(data, mean, std):
    """Map standardized values back to the original scale.

    Computes ``data * std + mean`` with ordinary broadcasting between the operands.
    """
    rescaled = data * std
    return rescaled + mean


if __name__ == '__main__':
    # Evaluate the HI baseline on a dataset's test split and report metrics in
    # the original (de-normalized) scale.
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset_name", default='ShenZhenETF', type=str,
                        help="dataset name (selects <name>_conf in config.data_config)")
    parser.add_argument("-e", "--exp_type", default='traffic_forecasting', type=str,
                        help="experiment config module under config.base_conf")
    parser.add_argument("--reverse", default=0, type=int,
                        help="non-zero to reverse the HI prediction")
    parser.add_argument("--data_root", default="E:\\time-series forecasting\\easytsf\\dataset", type=str, help="data root")
    args = parser.parse_args()

    config = load_config(args.dataset_name, args.exp_type, data_root=args.data_root, reverse=args.reverse)
    test_set = data_provider(config, mode='test')
    test_loader = DataLoader(test_set, batch_size=1, num_workers=2, shuffle=False)
    model = HI(hist_len=config['hist_len'], pred_len=config['pred_len'], reverse=config['reverse'])

    # np.load on an .npz returns an NpzFile that keeps the archive open until
    # closed; torch.tensor copies the arrays, so it is safe to close right away.
    scaler_path = os.path.join(args.data_root, args.dataset_name, 'var_scaler_info.npz')
    with np.load(scaler_path) as stat:
        mean = torch.tensor(stat['mean']).float()
        std = torch.tensor(stat['std']).float()

    prediction, label = [], []
    for batch in tqdm(test_loader):
        # [..., 0] keeps only the first feature channel.
        # NOTE(review): assumes batch[0]/batch[1] are (B, L, N, C) and that
        # mean/std broadcast against (B, L, N) — confirm against data_provider.
        raw_x = inverse_transform(batch[0].float()[..., 0], mean, std)
        raw_y = inverse_transform(batch[1].float()[..., 0], mean, std)
        prediction.append(model.forward(raw_x))
        label.append(raw_y)
    prediction = torch.cat(prediction)
    label = torch.cat(label)
    mae, rmse, mape, wape = eval_metrics(prediction, label, null_val=0.0)

    # Report every metric eval_metrics returns (wape was previously dropped).
    print('mae: {}, rmse: {}, mape: {}, wape: {}'.format(mae, rmse, mape, wape))
