from abc import ABC, abstractmethod
import torch
from enum import Enum
from PIL import Image

class DataSplit(Enum):
    """
    Dataset split identifiers.

    ``TRAIN`` selects the training item format in ``DatasetWrapper``; any
    other split is served in the test item format.
    """
    TRAIN = "train"
    VAL = "val"
    TEST = "test"

class DatasetWrapper(ABC, torch.utils.data.Dataset):
    """
    Abstract base class for dataset wrappers.

    Subclasses must implement ``__init__`` and ``__format_annotations__``.
    The concrete helpers below read these attributes, which subclasses are
    expected to set:

    - ``self.ids``: indexable sequence of image ids; ``self.ids[index]`` is
      the id served for ``index`` and ``len(self.ids)`` is the dataset length.
    - ``self.annotations``: indexed by image id; each entry is a sequence of
      per-sub-image annotation dicts, indexed by sub-image position. Each
      dict is read for ``'clean_img_path'`` / ``'bd_img_path'`` (a path of
      None is rejected when loading).
    - ``self.data_split``: a ``DataSplit`` member (or its string value)
      choosing between the train and test item formats.
    """

    @abstractmethod
    def __init__(self):
        """
        Initialize the dataset wrapper.
        This method should be implemented by subclasses to set up the dataset
        (presumably populating ``ids``, ``annotations`` and ``data_split``,
        which the helpers of this class read).
        """
        pass

    @abstractmethod
    def __format_annotations__(self):
        """
        Format annotations for the dataset.
        This method should be implemented by subclasses to process and format the dataset annotations.
        """
        pass

    def __get_annotation__(self, img_id):
        """Return the sequence of sub-image annotations stored for ``img_id``."""
        return self.annotations[img_id]

    def __get_image__(self, img_id, sub_id, get_bd=False):
        """
        Load one sub-image of ``img_id`` as an RGB PIL image.

        Args:
            img_id: key into ``self.annotations``.
            sub_id: position of the sub-image annotation to load.
            get_bd: if True load ``'bd_img_path'``, else ``'clean_img_path'``.

        Returns:
            A new RGB ``PIL.Image.Image`` that no longer depends on the file.

        Raises:
            ValueError: if the selected image path is None.
        """
        annotation = self.__get_annotation__(img_id)
        path_key = 'bd_img_path' if get_bd else 'clean_img_path'
        img_path = annotation[sub_id][path_key]

        if img_path is None:
            raise ValueError('Image path is None')

        # Use a context manager so the file handle is always released
        # (Pillow keeps it open for multi-frame formats otherwise). convert()
        # loads the pixel data and returns a new image, so the result stays
        # valid after the file is closed.
        with Image.open(img_path) as img:
            return img.convert('RGB')

    def __getitem_train__(self, index, get_bd=False):
        """
        Return ``(img, ann, img_id)`` for a training item.

        Only sub-image 0 is used, for both the clean and the bd variant;
        ``get_bd`` just selects which image path is loaded.
        """
        img_id = self.ids[index]
        annotation = self.__get_annotation__(img_id)

        # Training images (clean or bd) have exactly one sub-image.
        img = self.__get_image__(img_id, 0, get_bd)
        return img, annotation[0], img_id

    def __getitem_test__(self, index, get_bd=False):
        """
        Return ``(imgs, anns, img_id)`` for a validation/test item.

        With ``get_bd`` True every sub-image listed for the image is loaded;
        otherwise only sub-image 0 (the clean image). ``imgs`` and ``anns``
        are parallel lists.
        """
        img_id = self.ids[index]
        annotation = self.__get_annotation__(img_id)

        # Test bd images may have many sub-images; the clean variant only
        # ever uses sub-image 0.
        sub_ids = range(len(annotation)) if get_bd else range(1)

        return_imgs = [self.__get_image__(img_id, i, get_bd) for i in sub_ids]
        return_annotations = [annotation[i] for i in sub_ids]
        return return_imgs, return_annotations, img_id

    def __getitem__(self, index, get_bd=False):
        """
        Dispatch to the train or test item format based on ``self.data_split``.

        Note: plain indexing (``dataset[index]``, as used by a DataLoader)
        cannot pass ``get_bd`` and therefore always yields clean images; call
        this method directly to request bd images.
        """
        # Accept the enum member or its plain string value, so a subclass
        # that stores "train" is not silently served the test-format output.
        if self.data_split in (DataSplit.TRAIN, DataSplit.TRAIN.value):
            return self.__getitem_train__(index, get_bd)
        return self.__getitem_test__(index, get_bd)

    def __len__(self):
        """Return the number of image ids in the dataset."""
        return len(self.ids)