"""
Code adapted from https://github.com/mlfoundations/wise-ft/blob/master/src/datasets/objectnet.py
Thanks to the authors of wise-ft
"""

import os
import json
from pathlib import Path
import PIL

import numpy as np

import torch
from torchvision import datasets
from torchvision.transforms import Compose

from pathlib import Path

def get_metadata(folder):
    """Load the ObjectNet <-> ImageNet-1k class mappings stored in *folder*.

    Four files are read from *folder*:

    * ``folder_to_objectnet_label.json``: image folder name -> ObjectNet label.
    * ``objectnet_to_imagenet_1k.json``: ObjectNet label -> ``'; '``-separated
      ImageNet label names.
    * ``pytorch_to_imagenet_2012_id.json``: class index -> ImageNet 2012 id.
    * ``imagenet_to_label_2012_v2``: one ImageNet label name per line; the
      (0-based) line number is looked up as an ImageNet 2012 id.

    Args:
        folder: Directory (``str`` or path-like) containing the files above.

    Returns:
        A 4-tuple ``(class_sublist, class_sublist_mask, folder_to_ids,
        classname_map)``:

        * ``class_sublist``: sorted list of the class indices reachable from
          any ObjectNet label.
        * ``class_sublist_mask``: list of 1000 booleans, ``True`` at every
          index that appears in ``class_sublist``.
        * ``folder_to_ids``: image folder name -> list of class indices.
        * ``classname_map``: image folder name -> ObjectNet label.

    Raises:
        FileNotFoundError: If one of the metadata files is missing.
        KeyError: If the mapping files reference a label or id that another
            file does not define.
    """
    metadata = Path(folder)

    with open(metadata / 'folder_to_objectnet_label.json', 'r', encoding='utf-8') as f:
        folder_to_label = json.load(f)
    # Inverted so folders can be looked up by ObjectNet label.
    label_to_folder = {label: name for name, label in folder_to_label.items()}

    with open(metadata / 'objectnet_to_imagenet_1k.json', 'r', encoding='utf-8') as f:
        objectnet_map = json.load(f)

    with open(metadata / 'pytorch_to_imagenet_2012_id.json', 'r', encoding='utf-8') as f:
        index_to_imagenet_id = json.load(f)
    # Inverted: ImageNet 2012 id -> class index (still a JSON string key here).
    imagenet_id_to_index = {v: k for k, v in index_to_imagenet_id.items()}

    with open(metadata / 'imagenet_to_label_2012_v2', 'r', encoding='utf-8') as f:
        # The line number of each label is used as its ImageNet 2012 id.
        imagenet_map = {
            line.strip(): str(imagenet_id_to_index[line_no])
            for line_no, line in enumerate(f)
        }

    folder_to_ids, class_sublist = {}, []
    for objectnet_name, imagenet_names in objectnet_map.items():
        imagenet_ids = [int(imagenet_map[name]) for name in imagenet_names.split('; ')]
        class_sublist.extend(imagenet_ids)
        folder_to_ids[label_to_folder[objectnet_name]] = imagenet_ids

    class_sublist.sort()
    # Membership is tested against a set so the mask is built in one linear pass.
    selected = set(class_sublist)
    class_sublist_mask = [i in selected for i in range(1000)]
    # Re-inverting (rather than reusing the raw JSON) keeps exactly the folders
    # that survived the label -> folder inversion above.
    classname_map = {name: label for label, name in label_to_folder.items()}
    return class_sublist, class_sublist_mask, folder_to_ids, classname_map

class ObjectNetDataset(datasets.ImageFolder):
    """ObjectNet images restricted to the folders listed in the metadata.

    Images are read from ``<root>/objectnet-1.0/images``; the metadata files
    consumed by ``get_metadata`` are read from ``root`` itself. Only images
    whose parent folder appears in ``folders_to_ids`` are kept.

    Labels are the position of the image's folder in the sorted list of kept
    folder names, so ``classes[label]`` is the lower-cased ObjectNet class
    name of that image.

    Attributes set in addition to those of ``ImageFolder``:
        class_sublist_mask: list of booleans as returned by ``get_metadata``.
        folders_to_ids: folder name -> list of class indices.
        classname_map: folder name -> ObjectNet label.
        label_map: folder name -> label returned by ``__getitem__``.
    """

    def __init__(self, root, transform):
        """Index the images under *root* and keep those with a known folder.

        Args:
            root: Dataset directory holding the metadata files and the
                ``objectnet-1.0/images`` sub-directory.
            transform: Callable applied to each loaded image, or ``None``.
        """
        (self._class_sublist,
         self.class_sublist_mask,
         self.folders_to_ids,
         self.classname_map) = get_metadata(root)
        subdir = os.path.join(root, "objectnet-1.0", "images")
        folders = sorted(self.folders_to_ids)
        self.label_map = {name: idx for idx, name in enumerate(folders)}
        super().__init__(subdir, transform=transform)

        # ImageFolder indexed every folder on disk. Keep only the known ones
        # and store the label_map index so samples, targets and classes all
        # agree with what __getitem__ returns.
        kept = []
        for path, _ in self.samples:
            folder = os.path.basename(os.path.dirname(path))
            if folder in self.label_map:
                kept.append((path, self.label_map[folder]))
        self.samples = kept
        self.imgs = self.samples
        self.targets = [label for _, label in self.samples]
        # NOTE: class_to_idx is left as built by ImageFolder (keyed by every
        # folder found on disk); it does not correspond to these labels.
        self.classes = [self.classname_map[name].lower() for name in folders]

    def __len__(self):
        """Return the number of retained images."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(sample, label)`` for the image at *index*.

        ``sample`` is the loaded image after ``transform`` (if any) and
        ``label`` is the ``label_map`` index of the image's parent folder.
        """
        path, label = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label