import torch
from torch.utils.data import Dataset
import numpy as np
from scipy.signal import find_peaks, peak_prominences, savgol_filter
from gluonts.dataset.split import TrainingDataset

# Reference cycle length: each factor below is BASE_SEASON divided by the number
# of steps in the cycle assumed for a frequency (24 hourly steps per day -> 1.0).
BASE_SEASON = 24.

def get_fixed_factor(freq: str, domain=None) -> float:
    """Return the season-normalisation factor for a frequency alias.

    The factor is ``BASE_SEASON`` divided by the number of steps in the cycle
    assumed for ``freq``:

    * ``'<n>S'`` (seconds): one hour,
    * ``'<n>T'`` (minutes) and ``'<n>H'`` (hours): one day,
    * ``'<n>D'`` (days): one week for human-rhythm domains, one year otherwise,
    * weekly aliases: 365/7 steps; monthly: 12 steps; quarterly and annual: 4.

    A missing multiplier means 1 (``'H'`` == ``'1H'``).

    Args:
        freq: Frequency alias such as ``'H'``, ``'15T'``, ``'W-SUN'`` or ``'Q-DEC'``.
        domain: Dataset domain. Only ``'Transport'``, ``'Healthcare'`` and
            ``'Sales'`` switch daily data to a weekly cycle.

    Returns:
        The normalisation factor.

    Raises:
        NotImplementedError: If ``freq`` is not a supported alias.
    """
    # Only human-rhythm-dependent domains have weekly cycles (nature, for example, does not).
    has_weekly = domain in ('Transport', 'Healthcare', 'Sales')

    # Split plain aliases like '15T' into a numeric multiplier and a unit letter.
    # Anchored aliases such as 'W-SAT', 'Q-OCT' or 'W-WED' have a non-numeric
    # prefix; they must not be parsed as multiples of 'T'/'D' (previously
    # int('W-SA') raised ValueError).
    prefix, unit = freq[:-1], freq[-1:]
    if not prefix:
        n = 1
    elif prefix.isdecimal():
        n = int(prefix)
    else:
        n = None

    if n is not None and unit == 'S':
        # Sub-any-reasonable-cycle group -> normalise to an hourly cycle.
        factor = BASE_SEASON / (3600. / n)
    elif n is not None and unit == 'T':
        # Sub-day group: 24 * 60 minutes per day.
        factor = BASE_SEASON / (24. * 60 / n)
    elif n is not None and unit == 'H':
        factor = BASE_SEASON / (24. / n)
    elif n is not None and unit == 'D':
        # Weekly cycle for human-rhythm domains, otherwise the sub-year group.
        factor = BASE_SEASON / (7 if has_weekly else 365)
        factor *= n
    elif freq == 'W' or 'W-' in freq:
        factor = BASE_SEASON / (365. / 7)
    elif freq == 'M' or 'M-' in freq:
        factor = BASE_SEASON / 12
    elif 'Q' in freq:  # 'Q' or 'Q-<month>'
        factor = BASE_SEASON / 4.0
    elif 'A' in freq:
        # NOTE(review): annual data reuses the quarterly factor; leap years are not considered.
        factor = BASE_SEASON / 4.
    else:
        raise NotImplementedError(f'{freq} not implemented. Add {freq} option to this method')
    return factor

class ValidWrapper(Dataset):
    """Split each series of a validation dataset into an (input, label) pair.

    The last ``target_length`` points of each kept series become the label and
    the points before them become the input.
    """

    def __init__(
            self, 
            gluon_dataset: TrainingDataset,
            target_length: int,
            min_seq: int = 10,
            pretrain_context = 2048,
            max_time_points = 500 * 2048
    ):
        """Wrapper to create Gift Test Dataset from Gift Valid dataset.

        Args:
            gluon_dataset (TrainingDataset): Validation dataset. Each entry is read
                as a mapping with 'item_id', 'start', 'freq' and 'target' keys.
            target_length (int): Number of trailing points per series used as label.
            min_seq (int): A series is kept only if more than ``min_seq`` points
                remain for the input after the label is removed.
            pretrain_context: Caps each series' contribution to the size estimate
                at ``pretrain_context / get_fixed_factor(freq)`` points.
            max_time_points: Approximate point budget; when the size estimate
                exceeds it, only every n-th series is kept.

        Raises:
            ValueError: If the dataset is empty or no series is long enough.
            NotImplementedError: If a kept series has a multi-dimensional target,
                or the frequency of the first entry is unsupported.
        """
        self.input = []
        self.label = []
        series = list(gluon_dataset)
        if not series:
            # Previously series[0] below raised an opaque IndexError.
            raise ValueError("Invalid Dataset")
        # The frequency of the first entry is assumed to hold for the whole dataset.
        scale = get_fixed_factor(series[0]['freq'])
        # NOTE(review): target_length is subtracted once from the total rather than
        # per series; confirm this is intended.
        n_timepoints = sum(min(len(e['target']), pretrain_context / scale) for e in series) - target_length
        # To save time, don't evaluate every series if there is a ton of data.
        every_nth = max(1, int(n_timepoints / max_time_points))

        for entry in series[::every_nth]:
            target = entry['target']
            if len(target) - target_length <= min_seq:
                continue
            if target.ndim != 1:
                raise NotImplementedError("Assumes that to_univariate=True for now.")
            # Explicit split index: target[:-0] would wrongly yield an empty input
            # when target_length == 0.
            split = max(len(target) - target_length, 0)
            self.input.append(self._with_target(entry, target[:split]))
            self.label.append(self._with_target(entry, target[split:]))
        if not self.input:
            raise ValueError("Invalid Dataset")

    @staticmethod
    def _with_target(entry, target):
        """Copy the identifying fields of ``entry`` with ``target`` substituted."""
        return {'item_id': entry['item_id'], 'start': entry['start'], 'freq': entry['freq'], 'target': target}

    def __len__(self):
        return len(self.input)  # = num of series

    def __getitem__(self, idx):
        return [self.input[idx], self.label[idx]]