"""
Adapted from https://github.com/pytorch/vision/blob/main/torchvision/datasets/flickr.py
Thanks to the authors of torchvision
"""
from collections import defaultdict
import glob
import os
from collections import defaultdict
from html.parser import HTMLParser
from typing import Any, Callable, Dict, List, Optional, Tuple

from PIL import Image
from torchvision.datasets import VisionDataset

class Flickr(VisionDataset):
    """Image/caption dataset backed by a Flickr-style CSV annotation file.

    The first line of the annotation file is skipped as a header. Every other
    non-empty line must have the form ``<image name>.jpg,<caption>``. Captions
    are grouped by image name, so each dataset item is one image together with
    all the captions listed for it.

    Args:
        root: Directory that the image names in ``ann_file`` are joined onto.
        ann_file: Path to the annotation file; a leading ``~`` is expanded.
        transform: Optional callable applied to the loaded RGB PIL image.
        target_transform: Optional callable applied to the list of captions.

    Raises:
        ValueError: If a non-empty annotation line does not contain ``.jpg,``.
    """

    # Marks the end of the image name; everything after it is the caption.
    _SEPARATOR = ".jpg,"

    def __init__(
        self,
        root: str,
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        captions_by_image = defaultdict(list)
        # Open the expanded path so that "~/..." annotation paths work.
        with open(self.ann_file) as fd:
            fd.readline()  # skip the header line
            for line_number, raw_line in enumerate(fd, start=2):
                line = raw_line.strip()
                if not line:
                    continue
                # Captions may contain commas (or even ".jpg,"), so split only
                # on the first separator: the image name always comes first.
                stem, separator, caption = line.partition(self._SEPARATOR)
                if not separator:
                    raise ValueError(
                        f"{self.ann_file}:{line_number}: expected "
                        f"'<image>.jpg,<caption>', got {line!r}"
                    )
                captions_by_image[stem + ".jpg"].append(caption)

        # List of (image name, captions) pairs, in order of first appearance.
        self.data: List[Tuple[str, List[str]]] = list(captions_by_image.items())

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_name, captions = self.data[index]

        # Image
        img = Image.open(os.path.join(self.root, img_name)).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = captions
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of distinct images found in the annotation file."""
        return len(self.data)