#!/usr/bin/env python3

"""JSON dataset: support CUB, NABrids, Flower, Dogs and Cars"""

import json
import os
from typing import Union

import torch.utils.data
import torchvision as tv


def read_json(filename: str) -> Union[list, dict]:
    """Load a JSON file and return its decoded top-level value.

    Args:
        filename: path of the JSON file to read.

    Returns:
        The parsed JSON content.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    # The file is opened in binary mode so that ``json`` detects the text
    # encoding (UTF-8/16/32) itself.  ``json.load`` no longer accepts an
    # ``encoding`` keyword (it was ignored for years and removed in
    # Python 3.9), so passing it would raise ``TypeError``.
    with open(filename, "rb") as fin:
        return json.load(fin)


class JSONDataset(torch.utils.data.Dataset):
    """Image dataset whose samples are listed in a JSON annotation file.

    The annotation file is consumed as a mapping from image file name to a
    class id.  Class ids are re-indexed to contiguous integers following
    their sorted order.  Subclasses supply the image directory through
    ``get_image_dir``.
    """

    def __init__(self, data_dir, percentage=1.0, split="train", transform=None):
        """Record the configuration and build the image database.

        Args:
            data_dir: root directory holding the annotation files.
            percentage: value stored as ``data_percentage``; when below 1.0
                on a split containing "train" it selects a different
                annotation file (see ``get_anno``).
            split: split name, used to pick the annotation file.
            transform: optional callable applied to every loaded image.
        """
        self.data_dir = data_dir
        self._split = split
        self.data_percentage = percentage
        self._construct_imdb()
        self.transform = transform

    def get_anno(self):
        """Read and return the annotations for the configured split.

        The file is ``<split>.json`` inside ``data_dir``, except for a split
        containing "train" with ``data_percentage`` below 1.0, which reads
        ``<split>_<percentage>.json`` instead.
        """
        use_subset = "train" in self._split and self.data_percentage < 1.0
        if use_subset:
            anno_file = "{}_{}.json".format(self._split, self.data_percentage)
        else:
            anno_file = "{}.json".format(self._split)
        anno_path = os.path.join(self.data_dir, anno_file)
        assert os.path.exists(anno_path), "{} dir not found".format(anno_path)
        return read_json(anno_path)

    def get_image_dir(self):
        """Return the directory that contains the images (subclass hook)."""
        raise NotImplementedError()

    def _construct_imdb(self):
        """Build the class-id mapping and the list of image records."""
        image_root = self.get_image_dir()
        assert os.path.exists(image_root), "{} dir not found".format(image_root)

        annotations = self.get_anno()

        # Original class ids, sorted, and their contiguous replacements.
        self._class_ids = sorted(set(annotations.values()))
        self._class_id_cont_id = {
            raw_id: position for position, raw_id in enumerate(self._class_ids)
        }

        # One record per annotated image: full path plus contiguous label.
        to_contiguous = self._class_id_cont_id
        self._imdb = [
            {
                "im_path": os.path.join(image_root, file_name),
                "class": to_contiguous[raw_id],
            }
            for file_name, raw_id in annotations.items()
        ]

    def __getitem__(self, index):
        """Return ``(image, label)`` for the record at ``index``."""
        record = self._imdb[index]
        image = tv.datasets.folder.default_loader(record["im_path"])
        target = record["class"]
        if self.transform is None:
            return image, target
        return self.transform(image), target

    def __len__(self):
        """Return the number of image records."""
        return len(self._imdb)


class CUB200Dataset(JSONDataset):
    """CUB_200 dataset; images are read from ``<data_dir>/images``."""

    def __init__(self, data_dir, percentage=1.0, split="train", transforms=None):
        """Forward all arguments to ``JSONDataset``.

        ``transforms`` is handed to the base class as its ``transform``
        argument.
        """
        super().__init__(
            data_dir, percentage=percentage, split=split, transform=transforms
        )

    def get_image_dir(self):
        """Return the ``images`` sub-directory of ``data_dir``."""
        image_root = os.path.join(self.data_dir, "images")
        return image_root


class StanfordDogs(JSONDataset):
    """Stanford Dogs dataset; images are read from ``<data_dir>/Images``."""

    def __init__(self, data_dir, percentage=1.0, split="train", transforms=None):
        """Forward all arguments to ``JSONDataset``.

        ``transforms`` is handed to the base class as its ``transform``
        argument.
        """
        super().__init__(
            data_dir, percentage=percentage, split=split, transform=transforms
        )

    def get_image_dir(self):
        """Return the ``Images`` sub-directory of ``data_dir``."""
        image_root = os.path.join(self.data_dir, "Images")
        return image_root


class NabirdsDataset(JSONDataset):
    """NABirds dataset; images are read from ``<data_dir>/images``."""

    def __init__(self, data_dir, percentage=1.0, split="train", transforms=None):
        """Forward all arguments to ``JSONDataset``.

        ``transforms`` is handed to the base class as its ``transform``
        argument.
        """
        super().__init__(
            data_dir, percentage=percentage, split=split, transform=transforms
        )

    def get_image_dir(self):
        """Return the ``images`` sub-directory of ``data_dir``."""
        image_root = os.path.join(self.data_dir, "images")
        return image_root

class OxfordFlowers(JSONDataset):
    """Oxford Flowers dataset; images are read directly from ``data_dir``."""

    def __init__(self, data_dir, percentage=1.0, split="train", transforms=None):
        """Forward all arguments to ``JSONDataset``.

        ``transforms`` is handed to the base class as its ``transform``
        argument.
        """
        super().__init__(
            data_dir, percentage=percentage, split=split, transform=transforms
        )

    def get_image_dir(self):
        """Return ``data_dir`` itself; image names are joined onto it as-is."""
        image_root = self.data_dir
        return image_root

