import torch.utils.data as data
from PIL import Image
import os
import json
from torchvision import transforms

import numpy as np

def default_loader(path):
    """Open the image at ``path`` and return an RGB copy of it.

    The source image is used as a context manager so that the underlying
    file is released as soon as the converted copy exists, rather than
    whenever the garbage collector gets around to it.
    """
    with Image.open(path) as img:
        # convert() loads the pixel data and returns a new image, so the
        # result stays valid after the source image is closed.
        return img.convert('RGB')

def load_taxonomy(ann_data, tax_levels, classes):
    """Encode every taxonomic level as integer ids and gather them per class.

    Args:
        ann_data: annotation dictionary; its optional 'categories' entry is a
            list of records that hold one value per name in ``tax_levels``.
        tax_levels: names of the taxonomic levels to encode.
        classes: class ids for which the per-level ids are collected.

    Returns:
        A pair ``(taxonomy, classes_taxonomic)``. ``taxonomy[level]`` maps the
        position of a category in ``ann_data['categories']`` to the integer id
        of its value at that level. ``classes_taxonomic[c]`` lists those ids
        for class ``c``, ordered like ``tax_levels``.
    """
    if 'categories' in ann_data:
        categories = ann_data['categories']
        taxonomy = {}
        for level in tax_levels:
            level_values = [cat[level] for cat in categories]
            # the inverse index of np.unique is the rank of each value among
            # the sorted distinct values, which serves as its integer id
            level_ids = np.unique(level_values, return_inverse=True)[1]
            taxonomy[level] = dict(enumerate(level_ids))
    else:
        # no category information: every level only knows class 0 -> id 0
        taxonomy = {level: {0: 0} for level in tax_levels}

    # one list of taxonomic ids per distinct class
    classes_taxonomic = {
        cls: [taxonomy[level][cls] for level in tax_levels]
        for cls in np.unique(classes)
    }

    return taxonomy, classes_taxonomic


class INAT(data.Dataset):
    """Image classification dataset described by a JSON annotation file.

    The annotation file must contain an 'images' list whose entries carry
    'file_name' and 'id'. An optional 'annotations' list supplies one
    'category_id' per entry, and an optional 'categories' list supplies the
    taxonomy consumed by ``load_taxonomy``.

    Each item is an ``(image_tensor, species_id)`` pair. Image ids and the
    per-class taxonomic ids are not returned but remain available through
    ``self.ids`` and ``self.classes_taxonomic``.
    """

    def __init__(self, root, ann_file, is_train=True):
        """Read the annotations and set up the image transforms.

        Args:
            root: directory that the 'file_name' entries are relative to.
            ann_file: path of the JSON annotation file.
            is_train: when True, random crop/flip/colour augmentations are
                applied; otherwise a deterministic resize + centre crop.
        """
        # load annotations
        print('Loading annotations from: ' + os.path.basename(ann_file))
        with open(ann_file) as data_file:
            ann_data = json.load(data_file)

        # set up the filenames and annotations
        self.imgs = [aa['file_name'] for aa in ann_data['images']]
        self.ids = [aa['id'] for aa in ann_data['images']]

        # if we don't have class labels, set them all to 0
        # NOTE(review): labels are matched to images purely by position, so
        # this assumes 'annotations' is listed in the same order as 'images'
        # -- confirm against the annotation files in use.
        if 'annotations' in ann_data.keys():
            self.classes = [aa['category_id'] for aa in ann_data['annotations']]
        else:
            self.classes = [0]*len(self.imgs)

        # load taxonomy
        self.tax_levels = ['id', 'genus', 'family', 'order', 'class', 'phylum', 'kingdom']
        # original author's note, presumably the number of distinct values
        # at each of the levels above:
        #8142, 4412,    1120,     273,     57,      25,       6
        self.taxonomy, self.classes_taxonomic = load_taxonomy(ann_data, self.tax_levels, self.classes)

        # print out some stats
        print('\t' + str(len(self.imgs)) + ' images')
        print('\t' + str(len(set(self.classes))) + ' classes')

        self.root = root
        self.is_train = is_train
        self.loader = default_loader

        # augmentation params
        self.im_size = [224, 224]  # can change this to train on higher res
        self.mu_data = [0.485, 0.456, 0.406]   # per-channel mean for Normalize
        self.std_data = [0.229, 0.224, 0.225]  # per-channel std for Normalize
        self.brightness = 0.4
        self.contrast = 0.4
        self.saturation = 0.4
        self.hue = 0.25

        # augmentations
        self.resize = transforms.Resize(256)
        self.center_crop = transforms.CenterCrop((self.im_size[0], self.im_size[1]))
        self.scale_aug = transforms.RandomResizedCrop(size=self.im_size[0])
        self.flip_aug = transforms.RandomHorizontalFlip()
        self.color_aug = transforms.ColorJitter(self.brightness, self.contrast, self.saturation, self.hue)
        self.tensor_aug = transforms.ToTensor()
        self.norm_aug = transforms.Normalize(mean=self.mu_data, std=self.std_data)

    def __getitem__(self, index):
        """Return ``(image, species_id)`` for the sample at ``index``."""
        # os.path.join supplies the separator when root has no trailing
        # slash; plain string concatenation produced a broken path then.
        path = os.path.join(self.root, self.imgs[index])
        img = self.loader(path)
        species_id = self.classes[index]

        if self.is_train:
            img = self.scale_aug(img)
            img = self.flip_aug(img)
            img = self.color_aug(img)
        else:
            img = self.resize(img)
            img = self.center_crop(img)

        img = self.tensor_aug(img)
        img = self.norm_aug(img)

        return img, species_id

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.imgs)

def get_objnet_mappings(val_loader, mappings_folder='/datasets/objectnet-1.0/mappings/'):
    """Build index mappings between ObjectNet classes and ImageNet classes.

    Args:
        val_loader: loader whose ``dataset.class_to_idx`` maps ObjectNet
            folder names to ObjectNet class indices.
        mappings_folder: directory holding the four mapping files
            'objectnet_to_imagenet_1k.json', 'folder_to_objectnet_label.json',
            'pytorch_to_imagenet_2012_id.json' and 'imagenet_to_label_2012_v2'.

    Returns:
        A pair ``(i_idx_to_o_idxs, o_idx_to_i_idxs)``. ``o_idx_to_i_idxs``
        maps each ObjectNet index to the list of ImageNet indices it covers.
        ``i_idx_to_o_idxs`` is the inverse and holds a single ObjectNet index
        per ImageNet index; if an ImageNet index appears under several
        ObjectNet classes, the one processed last wins.

    Raises:
        KeyError: if a folder or label named in one mapping file is missing
            from the mapping it is looked up in.
    """
    def _load_json(name):
        # read one JSON mapping file from the mappings folder
        with open(os.path.join(mappings_folder, name)) as file_handle:
            return json.load(file_handle)

    # an ObjectNet label may correspond to several ImageNet labels, stored
    # as a single "; "-separated string -- split it into a list
    o_label_to_i_labels = {
        o_label: all_i_label.split("; ")
        for o_label, all_i_label in _load_json("objectnet_to_imagenet_1k.json").items()
    }

    # ObjectNet label -> ObjectNet index, going through the folder names
    o_folder_to_o_idx = val_loader.dataset.class_to_idx
    o_folder_o_label = _load_json("folder_to_objectnet_label.json")
    o_label_to_o_idx = {
        o_label: o_folder_to_o_idx[o_folder]
        for o_folder, o_label in o_folder_o_label.items()
    }

    # ImageNet label -> ImageNet index, going through the label file's line numbers
    i_idx_to_i_line = _load_json("pytorch_to_imagenet_2012_id.json")
    with open(os.path.join(mappings_folder, "imagenet_to_label_2012_v2")) as file_handle:
        # strip only the line terminator; slicing off the last character
        # would truncate the final label when the file lacks a trailing newline
        i_line_to_i_label = {
            i_line: i_label.rstrip("\n")
            for i_line, i_label in enumerate(file_handle)
        }
    i_label_to_i_idx = {
        i_line_to_i_label[i_line]: int(i_idx)
        for i_idx, i_line in i_idx_to_i_line.items()
    }

    # the final mapping of interest: ObjectNet index -> ImageNet indices
    o_idx_to_i_idxs = {
        o_label_to_o_idx[o_label]: [
            i_label_to_i_idx[i_label] for i_label in i_labels
        ]
        for o_label, i_labels in o_label_to_i_labels.items()
    }

    # and its inverse: ImageNet index -> ObjectNet index
    i_idx_to_o_idxs = {}
    for o_idx, i_idxs in o_idx_to_i_idxs.items():
        for i_idx in i_idxs:
            i_idx_to_o_idxs[int(i_idx)] = int(o_idx)
    return i_idx_to_o_idxs, o_idx_to_i_idxs