import torch
from torch.utils.data import Dataset
from abc import ABC
from typing import Any, Union, Dict

class BaseDataset(Dataset, ABC):
    """Root class shared by the dataset hierarchy in this module.

    Combines ``torch.utils.data.Dataset`` with ``abc.ABC``. The constructor
    does no work of its own: every positional and keyword argument is handed
    to the next class in the MRO unchanged.

    NOTE(review): if no class later in the MRO accepts these arguments, they
    likely end up at ``object.__init__``, which rejects extra arguments with
    ``TypeError`` — confirm this is the intended contract.
    """

    def __init__(self, *args, **kwargs):
        # Explicit two-argument super(); equivalent to the zero-argument form.
        super(BaseDataset, self).__init__(*args, **kwargs)



class CustomMaskDataset(BaseDataset, ABC):
    """Intermediate abstract base that adds no attributes or methods.

    Judging by its name, this is presumably the parent of datasets that
    provide custom masks (TODO confirm). Construction arguments are passed
    straight up the MRO.
    """

    def __init__(self, *args, **kwargs):
        super(CustomMaskDataset, self).__init__(*args, **kwargs)


class CustomMask3DDataset(CustomMaskDataset, ABC):
    """Abstract specialisation of ``CustomMaskDataset`` with no extra behavior.

    The "3D" in the name suggests volumetric data (TODO confirm). Construction
    arguments are forwarded unchanged to the parent classes.
    """

    def __init__(self, *args, **kwargs):
        super(CustomMask3DDataset, self).__init__(*args, **kwargs)


class Base3DDataset(BaseDataset, ABC):
    """Abstract base for datasets in which every sample is a 3D volume.

    Batching is best done with the smart collate function, which copes with
    batches that contain ``None`` entries or dictionaries.

    NOTE: ``__getitem__`` is not marked ``@abstractmethod``, so this class can
    still be instantiated; the missing implementation only surfaces as a
    ``NotImplementedError`` when a sample is requested.
    """

    def __init__(self, *args, **kwargs):
        super(Base3DDataset, self).__init__(*args, **kwargs)

    def __getitem__(self, index) -> tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]
    ]:
        """Fetch one sample; a concrete subclass must supply the implementation.

        The returned 5-tuple is, in order:

        1. sample index/indices within the dataset (``torch.Tensor``)
        2. image data (``torch.Tensor``)
        3. predicted label (``torch.Tensor``)
        4. true label (``torch.Tensor``)
        5. extra metadata (``dict``)

        Raises:
            NotImplementedError: always, at this level of the hierarchy.
        """
        msg = "This method should be implemented by subclasses."
        raise NotImplementedError(msg)


# Shorthand for "a torch.Tensor or None" (an optional tensor).
OptTensor = Union[torch.Tensor, type(None)]

class DatasetWithRegionsInAnOrgan(Base3DDataset, ABC):
    """Abstract base for region-based 3D segmentation datasets.

    Intended for data that decomposes into four parts:

    * the image itself (e.g. CT, MRI),
    * a mask of the main organ (e.g. lung, brain),
    * a multi-label segmentation of regions of interest (e.g. tumor, lesion),
    * extra metadata (e.g. patient ID, scan date).

    Batching is best done with the smart collate function, which copes with
    batches that contain ``None`` entries or dictionaries.

    NOTE: ``__getitem__`` is not marked ``@abstractmethod``, so this class can
    still be instantiated; calling ``__getitem__`` here raises
    ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs):
        super(DatasetWithRegionsInAnOrgan, self).__init__(*args, **kwargs)

    def __getitem__(self, index) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        OptTensor,
        OptTensor,
        Dict[str, Any],
    ]:
        """Fetch one sample; a concrete subclass must supply the implementation.

        The returned 7-tuple is, in order:

        1. sample index/indices within the dataset (``torch.Tensor``)
        2. image data (``torch.Tensor``)
        3. predicted label (``torch.Tensor``)
        4. true label (``torch.Tensor``)
        5. main organ mask (``torch.Tensor`` or ``None``)
        6. region-of-interest segmentation (``torch.Tensor`` or ``None``)
        7. extra metadata (``dict``)

        Raises:
            NotImplementedError: always, at this level of the hierarchy.
        """
        msg = "This method should be implemented by subclasses."
        raise NotImplementedError(msg)
