from .dataset_basic import BasicDataset
from utils import GlobalConfig
import numpy as np
import os
import pandas as pd
import torch
from dataclasses import dataclass
from sklearn.preprocessing import StandardScaler

@dataclass()
class M4Dataset:
    """
    In-memory view of one split (training or test) of the cached M4 dataset.

    The first four fields are columns of ``M4-info.csv``; ``values`` is
    whatever object the cache file of the selected split deserializes to.
    """
    ids: np.ndarray          # ``M4id`` column of M4-info.csv
    groups: np.ndarray       # ``SP`` column of M4-info.csv
    frequencies: np.ndarray  # ``Frequency`` column of M4-info.csv
    horizons: np.ndarray     # ``Horizon`` column of M4-info.csv
    values: np.ndarray       # content of training.npz / test.npz

    @staticmethod
    def load(training: bool = True, dataset_file: str = '../dataset/M4') -> 'M4Dataset':
        """
        load cached dataset

        Declared as a static method so that it can be called both on the
        class (``M4Dataset.load(...)``) and on an instance.

        :param training: Load training part if training is True, test part otherwise
        :param dataset_file: Directory containing ``M4-info.csv``, ``training.npz``
                             and ``test.npz``
        :return: M4Dataset holding the info columns and the values of the chosen split
        """
        info_file = os.path.join(dataset_file, 'M4-info.csv')
        cache_name = 'training.npz' if training else 'test.npz'
        cache_file = os.path.join(dataset_file, cache_name)

        m4_info = pd.read_csv(info_file)
        # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects and can
        # execute code; only point this loader at cache files you trust.
        values = np.load(cache_file, allow_pickle=True)
        return M4Dataset(ids=m4_info.M4id.values,
                         groups=m4_info.SP.values,
                         frequencies=m4_info.Frequency.values,
                         horizons=m4_info.Horizon.values,
                         values=values)


@dataclass()
class M4Meta:
    """
    Constant lookup tables for the six M4 seasonal patterns.

    ``horizons`` and ``frequencies`` are parallel to ``seasonal_patterns``;
    the ``*_map`` dictionaries expose the same numbers keyed by pattern name.
    """
    seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly']
    horizons = [6, 8, 18, 13, 14, 48]
    frequencies = [1, 4, 12, 1, 1, 24]

    # pattern name -> prediction length (differs per pattern)
    horizons_map = dict(zip(seasonal_patterns, horizons))

    # pattern name -> frequency value
    frequency_map = dict(zip(seasonal_patterns, frequencies))

    # pattern name -> history-size factor
    history_size = dict(zip(seasonal_patterns, [1.5, 1.5, 1.5, 10, 10, 10]))


class Dataset_M4(BasicDataset):
    """
    M4 short-term forecasting dataset restricted to one seasonal pattern.

    Every sample is a univariate window cut from a single M4 series. The
    prediction length is fixed by the seasonal pattern
    (``M4Meta.horizons_map``); ``label_len`` and ``seq_len`` default to
    ``pred_len`` and ``2 * pred_len`` when the config does not provide them.

    Return (of ``_load_data``):
    X: [B, 1, seq_len]
    Y: [B, 1, pred_len]
    F: [B, seq_len, 1] # encoder time mark (all ones)
    mask: [B, label_len+pred_len, 1] # decoder time mark (all ones)
    """

    def __init__(self, config:GlobalConfig, flag='TRAIN'):
        """
        Read the dataset options from ``config.args`` and delegate to the base class.

        :param config: global configuration; ``config.args.dataset_root`` is
                       required, every other option has a default
        :param flag: split selector forwarded to the base class
        """
        args = config.args
        self.root_path = args.dataset_root
        self.data_path = getattr(args, "dataset", "M4")
        self.seasonal_patterns = getattr(args, "seasonal_patterns", "Monthly")
        assert self.seasonal_patterns in M4Meta.seasonal_patterns, \
            f"seasonal_patterns should be {M4Meta.seasonal_patterns}, got {self.seasonal_patterns}"

        self.features = getattr(args, "features", "S") # "S" | "M" | "MS"
        self.target = getattr(args, "target", "OT")
        self.inverse = getattr(args, "inverse", False)
        self.timeenc = 1 if getattr(args, "embed", "timeF") == "timeF" else 0

        self.pred_len = M4Meta.horizons_map[self.seasonal_patterns]
        self.label_len = getattr(args, "label_len", self.pred_len)
        self.seq_len = getattr(args, "seq_len", self.pred_len * 2)

        # Single assignment; this class never fits a scaler, so _scaler stays None.
        self.scale = getattr(args, "scale", False)
        self._scaler = None

        # Number of random windows drawn from each training series.
        self.samples_per_series = getattr(args, "m4_samples_per_series", 1)

        # Attributes above must exist before the base constructor runs
        # (presumably it calls _load_data — TODO confirm against BasicDataset).
        super().__init__(config, flag)

    @staticmethod
    def _right_aligned(window, length):
        """Return a zero-filled float32 [length, 1] array whose tail holds ``window``.

        The caller guarantees ``len(window) <= length``.
        """
        padded = np.zeros((length, 1), dtype=np.float32)
        n_obs = len(window)
        if n_obs:  # padded[-0:] would address the whole array, so skip empty windows
            padded[-n_obs:, 0] = window
        return padded

    def _load_data(self, root_path, dataset_name_unused, flag):
        """
        Build the window tensors for the requested split.

        "TRAIN" reads the training cache and draws ``samples_per_series``
        random cut points per series; "VALI" and "TEST" both read the test
        cache and produce one window per series taken from its end.

        :param root_path: unused; the location comes from ``self.root_path``
                          and ``self.data_path``
        :param dataset_name_unused: unused
        :param flag: one of "TRAIN", "VALI", "TEST"
        :return: tuple ``(X, Y, F, mask)`` of float tensors, shapes as listed
                 in the class docstring
        :raises ValueError: if no window could be built for the selected pattern
        """
        assert flag in {"TRAIN", "VALI", "TEST"}

        is_train = flag == "TRAIN"
        m4 = M4Dataset.load(training=is_train,
                            dataset_file=os.path.join(self.root_path, self.data_path))

        # Keep only the series of the configured seasonal pattern and drop NaNs.
        selected = (m4.groups == self.seasonal_patterns)
        series = [v[~np.isnan(v)] for v in m4.values[selected]]  # list of 1D np.array
        ids = np.array(list(m4.ids[selected]))

        seq_len, label_len, pred_len = self.seq_len, self.label_len, self.pred_len
        dec_len = label_len + pred_len
        windows_X, windows_Y, marks_enc, marks_dec, id_list = [], [], [], [], []

        def add_sample(sid, enc_window, dec_window):
            # Encoder/decoder windows are zero-padded on the left; marks are all ones.
            dec = self._right_aligned(dec_window, dec_len)
            windows_X.append(self._right_aligned(enc_window, seq_len))  # [Lx, 1]
            windows_Y.append(dec[-pred_len:, :])  # [Ly, 1]
            marks_enc.append(np.ones((seq_len, 1), dtype=np.float32))  # [Lx, 1]
            marks_dec.append(np.ones((dec_len, 1), dtype=np.float32))  # [Llabel+Ly, 1]
            id_list.append(sid)

        if is_train:
            history_size = M4Meta.history_size[self.seasonal_patterns]
            window_sampling_limit = int(history_size * pred_len)
            # NOTE(review): relies on self.config being set by the base class
            # before this method runs — confirm against BasicDataset.
            rng = np.random.default_rng(getattr(self.config.args, "seed", None))

            for sid, ts in zip(ids, series):
                n_obs = len(ts)
                if n_obs < 1:
                    continue

                # The cut point is the index of the first predicted value. It
                # must leave label_len points before it and pred_len points
                # from it on, and is drawn from the most recent
                # window_sampling_limit positions. The bounds do not depend on
                # the draw, so they are computed once per series.
                low = max(label_len, 1, n_obs - window_sampling_limit)
                high = n_obs - pred_len
                if high < low:  # high == low still leaves exactly one valid cut
                    continue

                for _ in range(self.samples_per_series):
                    # Generator.integers excludes its upper bound: cut is in [low, high].
                    cut = int(rng.integers(low=low, high=high + 1))
                    add_sample(sid,
                               ts[max(0, cut - seq_len):cut],
                               ts[cut - label_len:cut + pred_len])
        else:
            # NOTE(review): X (last seq_len points) and Y (last pred_len points)
            # are both taken from the tail of the same series and therefore
            # overlap — confirm this is the intended evaluation protocol.
            for sid, ts in zip(ids, series):
                n_obs = len(ts)
                if n_obs < 1:  # nothing to window; mirrors the training branch
                    continue
                # Right-aligning the decoder window keeps Y equal to the last
                # observations even when the series is shorter than dec_len.
                add_sample(sid,
                           ts[max(0, n_obs - seq_len):n_obs],
                           ts[max(0, n_obs - dec_len):n_obs])

        if not windows_X:
            raise ValueError(
                f"no usable M4 windows for seasonal pattern "
                f"{self.seasonal_patterns!r} (flag={flag!r})")

        # [B, L, 1] -> [B, 1, L] for the value tensors; marks keep [B, L, 1].
        X = torch.from_numpy(np.stack(windows_X)).float().permute(0, 2, 1).contiguous()
        Y = torch.from_numpy(np.stack(windows_Y)).float().permute(0, 2, 1).contiguous()
        F = torch.from_numpy(np.stack(marks_enc)).float()
        dec_marks = torch.from_numpy(np.stack(marks_dec)).float()

        self.n_channels = 1
        self.seq_len = seq_len
        self.pred_len = pred_len
        self._ids = id_list  # series id of every sample, in sample order

        return X, Y, F, dec_marks

    def inverse_transform(self, data_np):
        """Return ``data_np`` unchanged."""
        return data_np


# Registry of short-term forecasting datasets keyed by name; both "M4" and
# the "Default" entry map to Dataset_M4.
AVAILABLE_SHORT_TERM_FORECASTING_DATASETS = dict.fromkeys(("M4", "Default"), Dataset_M4)
