import os
import uuid
import torch
from torcheeg.datasets import SEEDDataset
from torcheeg import transforms 

class CustomSEEDDataset(SEEDDataset):
    '''
    Wrapper around torcheeg.datasets.SEEDDataset that min-max rescales every
    EEG sample to the [-1, 1] range and shifts emotion labels from
    (-1, 0, 1) to (0, 1, 2).
    '''

    def __init__(self, root_path, num_workers, io_path=None, new_cache_path=None, *args, **kwargs):
        '''
        Initialize CustomSEEDDataset, a wrapper for torcheeg.datasets.SEEDDataset with a custom min-max transform.

        Args:
            root_path (string): path to preprocessed data.
            num_workers (int): number of workers torcheeg uses in parallel (forwarded as `num_worker`).
            io_path (string, optional): on a first run, torcheeg creates cache data.
                    If cache data is already available (i.e. running for the second time),
                    provide the path here to speed up the initialization process. Default None.
            new_cache_path (string, optional): directory in which a fresh, uniquely named cache
                    folder is placed when `io_path` is None. Ignored when `io_path` is given.
                    Default None (the cache folder is created in the current working directory).
            *args: kept only for signature compatibility; must be empty.
            **kwargs: forwarded unchanged to SEEDDataset. They must not repeat root_path,
                    num_worker, io_path, offline_transform, online_transform or label_transform,
                    which this wrapper already sets.

        Raises:
            TypeError: if extra positional arguments are supplied.
        '''
        if args:
            # Forwarding these positionally would bind them to SEEDDataset's leading
            # parameters, which are already passed by keyword below, producing a
            # "got multiple values for argument" error. Fail with a clearer message.
            raise TypeError(
                'CustomSEEDDataset does not accept extra positional arguments; '
                'pass additional SEEDDataset options as keyword arguments.'
            )

        # Set up a unique path for caching if none is provided
        if io_path is None:
            unique_id = f"{os.getpid()}_{uuid.uuid4().hex}"
            cache_root = new_cache_path or ''
            if cache_root:
                # Only the parent directory is created here.
                os.makedirs(cache_root, exist_ok=True)
            io_path = os.path.join(cache_root, f'process_{unique_id}')

        # NOTE(review): io_path itself is deliberately NOT pre-created. torcheeg
        # appears to decide whether to (re)process the raw data by checking whether
        # io_path already exists, so an empty pre-made folder can make it skip
        # processing. torcheeg creates the folder itself when it builds the cache.
        # Confirm against the installed torcheeg version.
        super().__init__(
            root_path=root_path,
            offline_transform=None,
            num_worker=num_workers,
            online_transform=None,
            label_transform=transforms.Select('emotion'),  # Replace if necessary
            io_path=io_path,
            **kwargs
        )

    def __getitem__(self, index):
        '''
        Retrieve the data sample associated with the given index.
        In the process, min-max normalization is applied along the last axis
        and the result is rescaled to [-1, 1].

        Args:
            index (int): index of sample to be retrieved

        Returns:
            dict: {'input': float32 tensor with the same shape as the stored signal,
                   values in [-1, 1]; 'label': the emotion label shifted by +1,
                   i.e. (-1, 0, 1) -> (0, 1, 2)}
        '''
        eeg_signal, label = super().__getitem__(index)

        # as_tensor accepts both numpy arrays and tensors and casts to float32.
        eeg_signal = torch.as_tensor(eeg_signal, dtype=torch.float32)

        # Min and max per channel along the last (time) axis; keepdim allows broadcasting.
        # Assumes samples are (channels, time), as with no offline transform — TODO confirm.
        min_vals = torch.amin(eeg_signal, dim=-1, keepdim=True)
        max_vals = torch.amax(eeg_signal, dim=-1, keepdim=True)

        # Min-max normalization to [0, 1]; the epsilon avoids division by zero
        # for flat channels, which therefore map to 0 here and -1 after rescaling.
        normalized_signal = (eeg_signal - min_vals) / (max_vals - min_vals + 1e-8)

        # Rescale to [-1, 1] range
        rescaled_signal = 2 * normalized_signal - 1

        # Rescale labels from (-1, 0, 1) to (0, 1, 2)
        label = label + 1

        return {'input': rescaled_signal, 'label': label}
