import torch
from torch.utils.data import Dataset
import numpy as np

class CustomEEGDataset(Dataset):
    """Map-style dataset over EEG recordings stored as ``.npy`` arrays.

    The EEG array is loaded eagerly with ``np.load`` and samples are taken
    along its third axis (``eegs[:, :, index]``); labels are taken along the
    first axis of the labels array. Each sample is clipped to
    ``[-clip_value, clip_value]``, has NaNs replaced by 0, is divided by
    ``scale``, and is then passed through the optional ``transform``.

    NOTE(review): the meaning of the first two axes is not established here;
    an earlier comment suggested a per-sample shape of (8, 2000), i.e.
    presumably (channels, time) -- confirm against the data producer.
    """

    def __init__(self, data_path, labels_path, transform=None,
                 clip_value=1024, scale=32.0):
        """Load EEG data and labels from disk.

        Args:
            data_path: Path to a ``.npy`` file holding the EEG array; samples
                must lie along axis 2.
            labels_path: Path to a ``.npy`` file holding one label per sample
                along axis 0.
            transform: Optional callable applied to each normalized sample
                (a float NumPy array) before tensor conversion.
            clip_value: Symmetric bound used to clip raw sample values.
            scale: Divisor applied after clipping and NaN replacement.

        Raises:
            ValueError: If the EEG array has fewer than 3 dimensions or the
                number of labels does not match the number of samples.
        """
        self.eegs = np.load(data_path)
        self.labels = np.load(labels_path)
        self.transform = transform
        self.clip_value = clip_value
        self.scale = scale

        # Fail fast on malformed inputs instead of raising IndexError (or
        # silently ignoring surplus labels) later, inside a DataLoader worker.
        if self.eegs.ndim < 3:
            raise ValueError(
                "EEG array must have at least 3 dimensions (samples on axis 2), "
                f"got shape {self.eegs.shape}"
            )
        n_samples = self.eegs.shape[2]
        if self.labels.ndim == 0 or self.labels.shape[0] != n_samples:
            raise ValueError(
                f"Expected {n_samples} labels to match the EEG samples, "
                f"got labels with shape {self.labels.shape}"
            )

    def __len__(self):
        """Return the number of samples (size of the EEG array's axis 2)."""
        return self.eegs.shape[2]

    def __getitem__(self, index):
        """Return the normalized sample and its label at ``index``.

        Returns:
            dict with ``'input'``: float32 tensor built from
            ``eegs[:, :, index]`` after clipping, NaN replacement, scaling
            and the optional transform; and ``'label'``: int64 tensor built
            from ``labels[index]``.
        """
        sample = self.eegs[:, :, index]

        # np.clip leaves NaN untouched, so NaNs are zeroed afterwards; any
        # +/-inf values have already been clipped to the bounds by then.
        sample = np.clip(sample, -self.clip_value, self.clip_value)
        sample = np.nan_to_num(sample, nan=0) / self.scale

        if self.transform is not None:
            sample = self.transform(sample)

        label = self.labels[index]

        # as_tensor also accepts a tensor returned by the transform, avoiding
        # the copy-construct warning that torch.tensor would emit for it.
        X = torch.as_tensor(sample, dtype=torch.float32)
        y = torch.tensor(label, dtype=torch.long)  # assumes integer class labels

        return {'input': X, 'label': y}
