import json
from pathlib import Path
from typing import Union

import PIL
import PIL.Image
from torch.utils.data import Dataset

class CIRHSDataset(Dataset):
    """
    Composed Image Retrieval on High-quality Synthetic triplets (CIRHS).

    Each item is a ``(reference_img, target_img, relative_caption, triplet_id)``
    tuple built from one entry of a JSON annotation file stored under
    ``<data_path>/cirhs``.
    """

    def __init__(self, data_path: Union[str, Path] = 'cir_datasets', preprocess: callable = None,
                 annotation: Union[str, Path] = 'annotation.json'):
        """
        Args:
            data_path (Union[str, Path]): root directory containing the ``cirhs`` sub-directory
            preprocess (callable): function applied to each loaded PIL image; if None, the
                loaded PIL image itself is returned
            annotation (Union[str, Path]): path of the CIRHS annotation JSON, relative to
                ``<data_path>/cirhs``

        Raises:
            OSError: if the annotation file cannot be opened or contains no entries.
        """
        # Set dataset paths and configurations
        data_path = Path(data_path) / 'cirhs'
        self.preprocess = preprocess
        self.data_path = data_path

        # Load annotation information; JSON is UTF-8, so don't rely on the locale default.
        with open(data_path / annotation, "r", encoding="utf-8") as f:
            self.annotations = json.load(f)

        if not self.annotations:
            raise IOError("The training data contains nothing")

        print("CIRHSDataset is initialized!")

    def _load_image(self, relative_path: Union[str, Path]):
        """Open an image relative to the dataset root, fully decode it, and preprocess it."""
        image_path = self.data_path / relative_path
        # The context manager guarantees the file handle is released even if decoding
        # or preprocessing raises. load() forces decoding while the file is still open,
        # so the returned image (or anything preprocess keeps referencing) stays usable.
        with PIL.Image.open(image_path) as image:
            image.load()
            if self.preprocess is None:
                return image
            return self.preprocess(image)

    def __getitem__(self, index) -> Union[tuple, None]:
        """
        Return the triplet stored at ``index``.

        Returns:
            ``(reference_img, target_img, relative_caption, triplet_id)``, or None when the
            entry could not be loaded. In that case the error is printed and swallowed;
            callers (e.g. a DataLoader collate_fn) must tolerate None entries.
            NOTE(review): confirm the training collate filters them.

        Raises:
            IndexError: if ``index`` is out of range. This is deliberately not swallowed.
                ``Dataset`` defines no ``__iter__``, so iterating the dataset directly uses
                Python's legacy ``__getitem__`` protocol, which stops only on IndexError.
        """
        # Kept outside the try so an out-of-range index propagates (see docstring).
        annotation = self.annotations[index]
        try:
            # Read the triplet id. The original loader used the key 'triplet_id:' (with a
            # trailing colon), presumably how the annotation JSON spells it — confirm. The
            # plain 'triplet_id' spelling is accepted as a fallback.
            if 'triplet_id:' in annotation:
                tid = annotation['triplet_id:']
            else:
                tid = annotation['triplet_id']

            # Get the relative caption describing the reference -> target modification
            relative_caption = annotation['relative_caption']

            # Load the reference and target images
            reference_img = self._load_image(annotation['reference_image_path'])
            target_img = self._load_image(annotation['target_image_path'])

            return reference_img, target_img, relative_caption, tid
        except Exception as e:
            # Best-effort: report the failing sample and return None instead of raising.
            print(f"Exception while loading CIRHS sample {index}: {e}")
            return None

    def __len__(self):
        """
        Returns the number of annotated triplets in the dataset.
        """
        return len(self.annotations)