import torch
from my_utils.utils import load_pkl, DATA_ROOT,DATA_ROOT_TSC
from torch.utils.data import Dataset
from os.path import join
import numpy as np
from scipy import signal

class MyDataset(Dataset):
    """Sliding-window forecasting dataset built from a pickled collection of series.

    Every sample is a window of ``context_length + prediction_length``
    consecutive points taken with stride 1; ``__getitem__`` splits it into a
    ``(context, target)`` pair.
    """

    def __init__(self,dataset_name,mode,context_length,prediction_length,count_max,use_filter,window,order) -> None:
        '''
        dataset_name: str, prefix of the pickle file under DATA_ROOT
        mode: string, train or test
        context_length: int, number of input points per sample
        prediction_length: int, number of target points per sample
        count_max: int, maximum total number of windows kept across ALL series
        use_filter: bool, load ``{dataset_name}_{mode}_window_{window}_order_{order}.pkl``
            instead of ``{dataset_name}_{mode}.pkl``
        window, order: only used here to build the filtered file name
        '''
        super().__init__()
        data_dir = DATA_ROOT
        if use_filter:
            file_name = f'{dataset_name}_{mode}_window_{window}_order_{order}.pkl'
        else:
            file_name = f'{dataset_name}_{mode}.pkl'
        # NOTE(review): assumes the pickle holds an iterable of numpy arrays whose
        # first axis is time (1-D per series, given normalize expects [B,T]) — confirm.
        data = load_pkl(join(data_dir, file_name))

        self.data = []
        self.context_length = context_length
        self.prediction_length = prediction_length
        length = context_length + prediction_length
        count = 0
        for item in data:
            if count >= count_max:
                # The cap is global, so no later series can contribute; stop scanning.
                break
            start = 0
            while start + length <= item.shape[0] and count < count_max:
                self.data.append(torch.from_numpy(item[start:start+length]).float())
                start += 1
                count += 1
        self.length = len(self.data)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``(context, target)``: the first ``context_length`` points and the rest."""
        return self.data[index][0:self.context_length], self.data[index][self.context_length:]

    def normalize(self,context,target,eps=1e-12):
        """
        Calculate mean and std only on context

        Instance normalization

        context.shape = [B,T]

        target.shape = [B,T]

        return:

            context_normalized, shape = [B,T]

            target_normalized, shape = [B,T]

            mean, shape = [B,1]
            
            std, shape =[B,1]
        """
        assert len(context.shape) == 2 and context.shape[-1] == self.context_length
        assert len(target.shape) == 2 and target.shape[-1] == self.prediction_length

        # Statistics come from the context only, so the target is not leaked into them.
        mean = context.mean(dim=1,keepdim=True)
        std = context.std(dim=1,keepdim=True)

        xy = torch.cat([context, target],dim=1)
        # eps keeps constant series (std == 0) finite.
        xy_n = (xy - mean) / (std+eps)
        context_normalized = xy_n[:,0:self.context_length]
        target_normalized = xy_n[:,self.context_length:]
        return context_normalized, target_normalized, mean, std

    def de_normalize(self,context_normalized,target_normalized,prediction_normalized,mean,std,eps=1e-12):
        """
        Invert ``normalize``.

        Input:
            context_normalized.shape = [B,T_context]
            target_normalized.shape = [B,T_prediction]
            prediction_normalized.shape = [B,T_prediction]
            mean, std: broadcastable against [B,T], e.g. the [B,1] tensors
                returned by ``normalize``
            eps: must match the eps used in ``normalize`` for an exact inverse
        Output:
            context, target, prediction in the original scale (same shapes as inputs)
        """
        assert context_normalized.shape[-1] == self.context_length
        assert target_normalized.shape[-1] == self.prediction_length
        assert prediction_normalized.shape[-1] == self.prediction_length

        context = context_normalized * (std+eps) + mean
        target = target_normalized * (std+eps) + mean
        prediction = prediction_normalized * (std+eps) + mean
        return context, target, prediction

    def get_mask(self,context,length_mask,step=1):
        """
        Build one mask per shift of a contiguous zero-span of ``length_mask`` points.

        Mask ``i`` zeroes positions ``[i*step, i*step + length_mask)``.

        mask = 1 means preservation

        mask = 0 means drop corresponding value

        Input:

            context: shape = [T], T == context_length

            length_mask: int, 0 <= length_mask <= context_length

            step: int >= 1, shift between consecutive masks

        Output:

            context: shape = [num_mask, T] (a copy repeated per mask)

            mask: shape = [num_mask, T], dtype float32

        Raises:

            ValueError if step < 1 or length_mask is outside [0, context_length]
        """
        assert len(context.shape) == 1 and context.shape[0] == self.context_length
        if step < 1:
            raise ValueError(f'step must be >= 1, got {step}')
        if not 0 <= length_mask <= self.context_length:
            raise ValueError(
                f'length_mask must be in [0, {self.context_length}], got {length_mask}'
            )
        num_mask = (self.context_length - length_mask) // step + 1
        context = context.unsqueeze(0).repeat(num_mask, 1)
        mask = torch.ones(context.shape, dtype=torch.float32)
        for mask_index in range(num_mask):
            start = mask_index * step
            mask[mask_index, start:start + length_mask] = 0
        return context, mask



class DatasetTSC(Dataset):
    """Time-series classification dataset loaded from a whitespace-delimited text file.

    The file ``{data_root}/{dataset_name}/{dataset_name}_{MODE}.txt`` is read
    with ``np.loadtxt``: column 0 is the class label, the remaining columns are
    one series per row.
    """

    def __init__(
        self,
        data_root = DATA_ROOT_TSC,
        dataset_name = None,
        mode = 'train',
        normalization_method = 'instance',
        use_filter = True,
        window = 15,
        order = 5
    ) -> None:
        '''
        Dataset for Time Series Classification

        data_root: directory containing one sub-directory per dataset
        dataset_name: dataset name (required despite the None default)
        mode: train or test (case-insensitive; upper-cased for the file name)
        normalization_method: 'instance' (per-series z-score) or 'none'
        use_filter: apply a Savitzky-Golay filter to every series before normalization
        window, order: Savitzky-Golay window length and polynomial order

        Raises ValueError if dataset_name is None, NotImplementedError for an
        unsupported normalization_method.
        '''
        super().__init__()
        if dataset_name is None:
            raise ValueError('dataset_name must be provided')
        # e.g. Univariate_arff/DiatomSizeReduction/DiatomSizeReduction_TRAIN.txt
        mode = mode.upper()
        path = join(data_root, dataset_name, f'{dataset_name}_{mode}.txt')
        data = np.loadtxt(path)
        label = data[:, 0].astype(np.int64)
        context = data[:, 1:].astype(np.float32)

        if use_filter:
            # Filter all rows in one call along the time axis (float32 is preserved).
            context = signal.savgol_filter(context, window, order, axis=-1)

        # Remap raw label values to contiguous ids 0..K-1, in sorted label order.
        label_unique, label = np.unique(label, return_inverse=True)

        self.label = torch.from_numpy(label).to(torch.int64)
        self.context = torch.from_numpy(context).to(torch.float32)
        self.length = label.shape[0]

        self.context_length = self.context.shape[-1]
        self.num_class = len(label_unique)

        if normalization_method == 'instance':
            # Per-series z-score; the eps keeps constant series finite.
            mean = self.context.mean(dim=-1, keepdim=True)
            std = self.context.std(dim=-1, keepdim=True)
            self.context = (self.context - mean) / (std + 1e-12)
        elif normalization_method != 'none':
            # 'global' and any other value are not implemented.
            raise NotImplementedError(normalization_method)

    def get_info(self):
        '''
        return:
            ctx_len: number of points per series
            n_class: number of distinct labels
        '''
        return self.context_length, self.num_class

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        '''
        Return (series, label): the (optionally filtered/normalized) series and
        its remapped integer label.
        '''
        return self.context[index], self.label[index]