import os
import torch
from torch.utils.data import Dataset
import pickle


class TUABDataset(Dataset):
    """Map-style dataset over a directory of pickled samples.

    Every entry in ``data_path`` is loaded with ``pickle`` and indexed as a
    dict with keys ``"X"`` (input data, converted to a float32 tensor) and
    ``"y"`` (label, returned unchanged).

    Args:
        data_path: Directory whose entries are all pickled sample files.
        transform: Optional callable applied to the float tensor ``X``.
        squeeze: If True, prepend a singleton dimension with ``unsqueeze(0)``
            (e.g. ``(C, T) -> (1, C, T)``). Despite the name, this *adds* a
            dimension.
        scale: If True, min-max scale ``X`` to ``[-1, 1]`` after the
            transform, using the global min/max over all elements of the
            sample.

    Security note: ``pickle.load`` can execute arbitrary code. Only point
    this dataset at trusted data.
    """

    def __init__(self, data_path, transform=None, squeeze=False, scale=False):
        self.data_path = data_path
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so the index -> file mapping is reproducible across machines.
        self.files = sorted(os.listdir(self.data_path))
        self.transform = transform
        self.squeeze = squeeze
        self.scale = scale

    def __len__(self):
        """Return the number of sample files in ``data_path``."""
        return len(self.files)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(X, Y)``.

        ``X`` is a float32 tensor, optionally unsqueezed, transformed, and
        scaled per the constructor flags. ``Y`` is ``sample["y"]`` as stored.
        """
        path = os.path.join(self.data_path, self.files[index])
        # Close the file promptly instead of relying on garbage collection.
        with open(path, "rb") as f:
            sample = pickle.load(f)
        X = sample["X"]
        Y = sample["y"]
        X = torch.as_tensor(X, dtype=torch.float32)
        if self.squeeze:
            # Add a leading singleton dimension, e.g. (C, T) -> (1, C, T).
            X = X.unsqueeze(0)
        if self.transform is not None:
            X = self.transform(X)
        if self.scale:
            # Min-max to [-1, 1] using the global min/max of this sample.
            max_X = X.max()
            min_X = X.min()
            span = max_X - min_X
            if span == 0:
                # A constant sample would compute 0/0 = NaN; map it to 0,
                # the midpoint of [-1, 1]. NaN inputs still propagate below.
                X = torch.zeros_like(X)
            else:
                X = (X - min_X) / span  # [0, 1]
                X = (X - 0.5) * 2  # [-0.5, 0.5] -> [-1, 1]
        return X, Y
    
