import os
from typing import Tuple, Any

from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import INaturalist


class iNaturalistSubset(Dataset):
    """Subset of an ``INaturalist`` dataset filtered by one category value.

    A sample of the wrapped dataset is kept when
    ``dataset.categories_map[label][category_type] == category_id``.
    The labels of the kept samples are re-numbered to contiguous class
    indices ``0..K-1`` in order of first appearance in ``dataset.index``.

    Attributes set by the constructor:
        transform / target_transform / root: copied from ``dataset``.
        samples: list of ``(image_path, class_index)`` tuples.
        image_paths: the paths from ``samples``, in the same order.
        targets: the class indices from ``samples``, in the same order.
        orig_class_ids: ``orig_class_ids[i]`` is the label in the wrapped
            dataset that was mapped to new class index ``i``.
    """

    def __init__(self, dataset: "INaturalist", category_type: str, category_id: int) -> None:
        """Build the filtered sample list from ``dataset``.

        Args:
            dataset: source dataset; its ``index``, ``categories_map``,
                ``all_categories``, ``root``, ``transform`` and
                ``target_transform`` attributes are read.
            category_type: key looked up in each ``categories_map`` entry.
            category_id: value that entry must equal for a sample to be kept.
        """
        # copy all relevant data from INaturalist
        self.transform = dataset.transform
        self.target_transform = dataset.target_transform
        self.root = dataset.root

        # Original label -> new class index. A dict (rather than comparing
        # with the previously seen label) keeps the mapping one-to-one even
        # if samples of the same label are not adjacent in ``dataset.index``.
        new_class_of = {}
        orig_class_ids = []
        samples = []
        for label, image_path in dataset.index:
            if dataset.categories_map[label][category_type] != category_id:
                continue
            class_index = new_class_of.get(label)
            if class_index is None:
                # First time this label is seen: give it the next free index.
                class_index = len(orig_class_ids)
                new_class_of[label] = class_index
                orig_class_ids.append(label)
            category = dataset.all_categories[label]
            samples.append((os.path.join(dataset.root, category, image_path), class_index))

        self.samples = samples
        self.image_paths = [s[0] for s in samples]
        self.targets = [s[1] for s in samples]
        self.orig_class_ids = orig_class_ids

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load one sample.

        Args:
            index (int): Index into ``self.samples``.

        Returns:
            tuple: (sample, target) where sample is the image converted to
            RGB (after ``transform``, if set) and target is the re-numbered
            class index (after ``target_transform``, if set).
        """
        path, target = self.samples[index]
        with open(path, "rb") as f:
            # convert() is called inside the ``with`` so the image data is
            # read before the file handle is closed.
            sample = Image.open(f).convert("RGB")

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of samples in the subset."""
        return len(self.samples)
