from numbers import Number
from typing import Callable, Dict, Tuple, List, Any, Optional
import os

from PIL.Image import Image

from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
from torchvision.datasets.folder import default_loader

from cv_lib.classification.data.imagenet import MEAN, STD


class QueryDataset(Dataset):
    """
    Dataset of query images enumerated by a plain-text list file.

    The list file holds one image path per line; every path is joined onto
    `root` when the image is loaded. Each item is the (optionally augmented
    and resized) image converted to a normalized tensor, paired with a small
    info dict identifying the image.
    """

    def __init__(
        self,
        root: str,
        img_list_fp: str,
        resize: Optional[Tuple[int, int]] = (224, 224),
        augmentations: Optional[Callable[[Image, Dict[str, Any]], Tuple[Image, Dict[str, Any]]]] = None,
        dataset_mean: List[float] = MEAN,
        dataset_std: List[float] = STD,
        img_channels: int = 3,
    ):
        """
        Args:
            root: directory the listed image paths are joined onto (`~` is expanded).
            img_list_fp: text file with one image path per line (`~` is expanded).
            resize: target size as (h, w). A list is converted to a tuple, a
                single number `s` is expanded to `(s, s)`, and `None` disables
                resizing.
            augmentations: optional callable applied as
                `img, target = augmentations(img, target)` before resizing.
            dataset_mean: per-channel mean passed to the normalization step.
            dataset_std: per-channel std passed to the normalization step.
            img_channels: stored on the instance; not read anywhere else in
                this class.
        """
        self.root = os.path.expanduser(root)
        if isinstance(resize, list):
            resize = tuple(resize)
        elif isinstance(resize, Number):
            resize = (resize, resize)
        self.resize = resize
        self.augmentations = augmentations

        self.dataset_mean: List[float] = dataset_mean
        self.dataset_std: List[float] = dataset_std
        self.img_channels: int = img_channels

        self._init_dataset(img_list_fp)

    def __len__(self) -> int:
        """Return the number of images read from the list file."""
        return len(self.img_list)

    def _init_dataset(self, img_list_fp: str):
        """Read the image list file into `self.img_list`, one entry per non-blank line."""
        # Expand `~` here as well, consistent with how `root` is handled.
        img_list_fp = os.path.expanduser(img_list_fp)
        with open(img_list_fp) as file:
            entries = [line.strip("\r\n") for line in file]
        # Skip blank lines (e.g. a trailing empty line): an empty entry would
        # resolve to `root` itself and fail only later, when the item is loaded.
        self.img_list: List[str] = [entry for entry in entries if entry.strip()]

    def _img_path(self, index: int) -> str:
        """Return the path of the image at `index`, joined onto `root`."""
        return os.path.join(self.root, self.img_list[index])

    def get_image(self, index: int) -> Image:
        """Load the image at `index` with torchvision's `default_loader`."""
        return default_loader(self._img_path(index))

    def get_img_info(self, index: int) -> Dict[str, Any]:
        """
        Return the info dict for the image at `index`:
            {
                image_id: the index itself
                img_fp: path of the image file
            }
        """
        return dict(
            image_id=index,
            img_fp=self._img_path(index)
        )

    def __getitem__(self, index: int) -> Tuple[Tensor, Dict[str, Any]]:
        """
        Return image and target where target is a dictionary e.g.
            target: {
                image_id: int, the index of the image
                img_fp: path of the image file
                *OTHER_INFO*: whatever the augmentations add or change
            }
        """
        img = self.get_image(index)
        target: Dict[str, Any] = self.get_img_info(index)

        # Augmentations run first, so the resize below fixes the final size.
        if self.augmentations is not None:
            img, target = self.augmentations(img, target)

        if self.resize is not None:
            img = TF.resize(img, self.resize)
        img = TF.to_tensor(img)
        # `img` is the tensor just produced by to_tensor, so normalizing it
        # in place avoids an extra copy.
        img = TF.normalize(img, self.dataset_mean, self.dataset_std, inplace=True)
        return img, target

