import json
import os
import random
from enum import Enum
from typing import Callable, Dict, List, Optional, Union
from .decoders import ImageDataDecoder, TargetDecoder
from .extended import ExtendedVisionDataset
class _Split(Enum):
    """Dataset partitions that can be requested when loading captions.

    The member values are the plain strings accepted wherever a split is
    looked up by value (``_Split("train")`` / ``_Split("val")``).
    """

    # Training partition.
    TRAIN = "train"
    # Validation partition.
    VAL = "val"
def read_images_and_captions(root: str, split: _Split) -> List[Dict]:
    """Load a COCO captions annotation file and group the captions per image.

    Args:
        root: Directory under which the annotation and image folders live
            (``annotations_trainval2014``/``train2014`` for the training
            split, ``annotations_trainval2017``/``val2017`` otherwise).
        split: The partition to load, either a ``_Split`` member or its
            string value.

    Returns:
        One dict per image listed in the annotation file, in file order,
        with the keys ``"id"`` (the image id), ``"image"`` (path of the
        image file below ``root``) and ``"captions"`` (list of caption
        strings; empty when the file holds no caption for that image).

    Raises:
        ValueError: If ``split`` is not a valid ``_Split`` value.
        OSError: If the annotation file cannot be opened.
        KeyError: If the annotation file lacks the expected keys or a
            caption refers to an image id that is not listed under
            ``"images"``.
    """
    if _Split(split) == _Split.TRAIN:
        annotations_relpath = (
            "annotations_trainval2014/annotations/captions_train2014.json"
        )
        image_dir = os.path.join(root, "train2014/train2014")
    else:
        # The annotation file must belong to the same partition as the image
        # directory: its "file_name" entries are resolved inside `image_dir`.
        annotations_relpath = (
            "annotations_trainval2017/annotations/captions_val2017.json"
        )
        image_dir = os.path.join(root, "val2017/val2017")
    annotations_full_path = os.path.join(root, annotations_relpath)

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(annotations_full_path, encoding="utf-8") as f:
        all_annotations = json.load(f)

    # Keyed by image id so captions can be attached in a single pass; dict
    # insertion order keeps the records in annotation-file order.
    data = {}
    for image in all_annotations["images"]:
        image_id = image["id"]
        data[image_id] = {
            # Set eagerly so that images without any caption still carry
            # their id instead of None.
            "id": image_id,
            "image": os.path.join(image_dir, image["file_name"]),
            "captions": [],
        }

    for annotation in all_annotations["annotations"]:
        data[annotation["image_id"]]["captions"].append(annotation["caption"])

    return list(data.values())
class CocoCaptions(ExtendedVisionDataset):
    """Image/caption dataset backed by the records of a COCO captions file.

    Each sample is described by a dict holding an image path and the list of
    captions attached to that image; the records are loaded once, at
    construction time, through ``read_images_and_captions``.
    """

    # Public alias for the split type; a single-member Union is just _Split.
    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "CocoCaptions.Split",
        root: Optional[str] = None,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Set up the base dataset and read the caption records for ``split``.

        Args:
            split: Partition to load.
            root: Directory that contains the COCO folders.
            transforms: Forwarded unchanged to the base class.
            transform: Forwarded unchanged to the base class.
            target_transform: Forwarded unchanged to the base class.
        """
        base_kwargs = dict(
            root=root,
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
            image_decoder=ImageDataDecoder,
            target_decoder=TargetDecoder,
        )
        super().__init__(**base_kwargs)
        self.image_captions = read_images_and_captions(root, split)

    def get_image_relpath(self, index: int) -> str:
        """Return the image path stored in the record at ``index``."""
        record = self.image_captions[index]
        return record["image"]

    def get_image_data(self, index: int) -> bytes:
        """Return the raw bytes of the image file for the record at ``index``."""
        with open(self.get_image_relpath(index), mode="rb") as stream:
            return stream.read()

    def get_target(self, index: int) -> str:
        """Return one caption of the record at ``index``, chosen at random.

        The pick uses the global ``random`` module state, so repeated calls
        may return different captions for the same index.
        """
        candidates = self.image_captions[index]["captions"]
        return random.choice(candidates)

    def __len__(self) -> int:
        """Return the number of loaded image records."""
        return len(self.image_captions)
