import os
import torch
import numpy as np
from torch.utils.data import Dataset
from sktime.datasets import load_from_tsfile

def Reg(tensor, dim=-1):
    """Min-max normalise ``tensor`` in place along ``dim`` and return it.

    Every slice along ``dim`` is shifted by its own minimum and divided by its
    own range, so its values end up in ``[0, 1]``. Slices whose values are all
    equal (zero range) become all zeros instead of NaN.

    Args:
        tensor: Floating-point tensor; it is modified in place.
        dim: Axis to normalise over. Defaults to the last axis, which for 2-D
            ``(n_series, length)`` input is axis 1 (the previous behaviour).
            For 3-D ``(n, channels, length)`` input (sktime's numpy3d layout)
            this normalises each channel over time instead of normalising
            across channels at each time step.

    Returns:
        The same ``tensor`` object, normalised.
    """
    lo = tensor.min(dim=dim, keepdim=True).values
    hi = tensor.max(dim=dim, keepdim=True).values
    span = hi - lo
    # A zero range would divide by zero and produce NaN. Dividing by 1 instead
    # maps such a slice to all zeros, since (x - min) is already 0 there.
    span = torch.where(span == 0, torch.ones_like(span), span)
    tensor.sub_(lo).div_(span)
    return tensor

class TSDataset(Dataset):
    """Univariate time-series dataset read from a UCR-style ``.tsv`` file.

    The file ``<path>/<datasetname>/<datasetname>_<type>.tsv`` is read line by
    line. Each non-blank line is tab-separated: the first field is the class
    label and the remaining fields are the series values. The series are
    min-max normalised with ``Reg``. Labels are shifted to start at 0: if the
    minimum label is negative they are mapped with ``(label + 1) / 2`` (so
    ``{-1, 1}`` becomes ``{0, 1}``), and if the minimum label is 1 they are
    decremented by one.

    Args:
        datasetname: Dataset directory name and file prefix.
        type: Split suffix used in the file name, e.g. ``'TRAIN'`` or
            ``'TEST'``.
        path: Root directory that contains the dataset directories.

    Raises:
        ValueError: If the file contains no samples.
    """

    def __init__(self, datasetname, type = 'TRAIN', path = 'Data/UCR'):
        self.path = os.path.join(path, datasetname)
        file_name = os.path.join(self.path, datasetname + '_' + type + '.tsv')
        labels = []
        rows = []
        with open(file_name) as f:
            for line in f:
                # rstrip drops the newline (and any trailing tab / '\r'), so
                # the last field parses. Fully blank lines, such as an extra
                # trailing newline, are skipped instead of crashing float('').
                line = line.rstrip()
                if not line:
                    continue
                fields = line.split('\t')
                labels.append(float(fields[0]))
                rows.append([float(t) for t in fields[1:]])
        if not labels:
            raise ValueError(f"no samples found in {file_name}")
        self.label = torch.tensor(labels)
        self.data = Reg(torch.tensor(rows))
        if self.label.min() < 0:
            # e.g. {-1, 1} -> {0, 1}
            self.label = (self.label + 1) / 2
        elif self.label.min() == 1:
            # 1-based labels -> 0-based
            self.label -= 1

    def __len__(self):
        """Return the number of series in the split."""
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(series, label)`` for sample ``idx``."""
        return self.data[idx], self.label[idx]

class TSMDataset(Dataset):
    """Multivariate time-series dataset loaded from a UEA-style ``.ts`` file.

    The file ``<path>/<datasetname>/<datasetname>_<type>.ts`` is loaded with
    sktime's ``load_from_tsfile`` as a 3-D numpy array (``numpy3d``). It is
    converted to a float32 tensor and normalised with ``Reg``. The raw labels
    are replaced by 0-based class ids: the index of each label in the sorted
    array of distinct labels, which is stored as ``self.classes``.

    NOTE(review): class ids are computed from this split's labels only. If a
    split is missing a class that another split contains, the ids of the two
    splits will not line up.

    Args:
        datasetname: Dataset directory name and file prefix.
        type: Split suffix used in the file name, e.g. ``'TRAIN'`` or
            ``'TEST'``.
        path: Root directory that contains the dataset directories.
    """

    def __init__(self, datasetname, type = 'TRAIN', path = 'Data/UEA'):
        self.path = os.path.join(path, datasetname)
        data, labels = load_from_tsfile(
            os.path.join(self.path, datasetname + '_' + type + '.ts'),
            return_data_type="numpy3d",
        )
        # np.unique sorts the distinct labels, so the inverse indices are
        # already 0-based class ids and no further label shifting is needed.
        # (The previous "min < 0" / "min == 1" adjustments could never fire and
        # crashed on an empty split.)
        self.classes, inverse = np.unique(labels, return_inverse=True)
        self.label = torch.tensor(np.asarray(inverse).reshape(-1))
        self.data = Reg(torch.tensor(data, dtype=torch.float32))

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(sample, class_id)`` for index ``idx``."""
        return self.data[idx], self.label[idx]