import os
import uuid
import torch
import shutil
from torcheeg.datasets import BCICIV2aDataset
from typing import Dict, Any

class CustomBCICIV2aDataset(BCICIV2aDataset):
    """BCICIV2aDataset variant that yields ``{'input': tensor, 'label': int}`` samples.

    Each sample returned by the parent dataset is converted to a float
    tensor, trimmed at the start of the time axis, and extended with an
    all-zero channel (see ``__getitem__``).

    Cache handling (important side effects):
        * When ``io_path`` is not given, a per-process unique directory under
          ``.torcheeg_cache`` is used.
        * If ``io_path`` already exists it is deleted before the parent class
          is initialised, so the cache is always rebuilt.
        * The cache directory is deleted again when the object is destroyed
          (``__del__``) -- this also applies to a caller-supplied ``io_path``.
    """

    # Number of leading time samples dropped from every trial. Per the
    # original author's note this is 470 / 250 Hz = 1.88 s out of a
    # 7 s * 250 Hz = 1750-sample trial, which still keeps the cue and the
    # motor-imagery response.
    LEADING_SAMPLES_TO_DROP = 470

    # Number of all-zero channels appended after the recorded ones
    # (22 channels available, 23 needed downstream).
    EXTRA_ZERO_CHANNELS = 1

    @staticmethod
    def _new_cache_path():
        """Return a fresh cache path that is unique to this process and call."""
        unique_id = f"{os.getpid()}_{uuid.uuid4().hex}"
        return os.path.join('.torcheeg_cache', f'process_{unique_id}')

    def __init__(self, skip_trial_with_artifacts=True, io_path=None, *args, **kwargs):
        """Prepare a clean cache directory and initialise the parent dataset.

        Args:
            skip_trial_with_artifacts: Forwarded unchanged to the parent class.
            io_path: Cache directory handed to the parent class. ``None``
                selects a freshly generated unique path. An existing path is
                removed first; if removal fails, a unique path is used instead.
            *args: Extra positional arguments forwarded to the parent class.
            **kwargs: Extra keyword arguments forwarded to the parent class.
        """
        if io_path is None:
            io_path = self._new_cache_path()

        # Start from a clean slate: drop any stale cache at this location.
        if os.path.exists(io_path):
            try:
                shutil.rmtree(io_path)
            except OSError:
                # Could not remove it (in use, permissions, not a directory):
                # fall back to a path nobody else is using.
                io_path = self._new_cache_path()

        # Create the parent of the path that is actually going to be used.
        # A bare name such as 'cache' has an empty dirname, and
        # os.makedirs('') raises, so only create a non-empty parent.
        parent_dir = os.path.dirname(io_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        super().__init__(
            *args,
            skip_trial_with_artifacts=skip_trial_with_artifacts,
            io_path=io_path,
            **kwargs,
        )

    def __getitem__(self, index):
        """Return the sample at ``index`` as a model-ready dictionary.

        Args:
            index: Sample index, forwarded to the parent ``__getitem__``.

        Returns:
            A dict with:
                ``'input'``: float tensor; assumed layout is
                    (channels, time) -- the first ``LEADING_SAMPLES_TO_DROP``
                    time samples are removed and ``EXTRA_ZERO_CHANNELS``
                    zero-filled channels are appended.
                ``'label'``: the parent's ``metadata['label']`` minus one.

        Raises:
            AssertionError: If the shifted label is outside ``0..3``.
        """
        eeg_signal, metadata = super().__getitem__(index)

        # as_tensor accepts NumPy arrays (same as from_numpy) and also
        # tensors, in case a parent transform already converted the signal.
        eeg_tensor = torch.as_tensor(eeg_signal).float()

        # Drop the leading part of the time axis.
        eeg_tensor = eeg_tensor[:, self.LEADING_SAMPLES_TO_DROP:]

        # Zero-pad at the end of the channel axis. The pad tuple is ordered
        # from the last dimension backwards: (time_left, time_right,
        # channel_left, channel_right).
        eeg_tensor = torch.nn.functional.pad(
            eeg_tensor, (0, 0, 0, self.EXTRA_ZERO_CHANNELS), "constant", 0.0
        )

        # Shift the label from the 1-4 range to the 0-3 range.
        label = metadata['label'] - 1
        assert 0 <= label <= 3, f"label {label} outside expected range 0..3"

        return {'input': eeg_tensor, 'label': label}

    def __del__(self):
        """Best-effort removal of the cache directory when the object dies."""
        cache_dir = getattr(self, 'io_path', None)
        # io_path is missing if __init__ failed before the parent set it.
        if not cache_dir or not os.path.exists(cache_dir):
            return
        try:
            shutil.rmtree(cache_dir)
        except OSError:
            # Cleanup is best-effort; a failure here must not raise from a
            # finalizer.
            pass
