import torch
import os
import json
from utils.wordnet_helpers import *

# Path to the ImageNet class-index JSON file (despite the name, this is a file, not a
# directory). It is relative, so it is resolved against the current working directory.
INFO_DIR = 'imagenet_info/imagenet_class_index.json'

# Use a context manager so the file handle is closed deterministically instead of being
# left for the garbage collector; JSON text is UTF-8, so don't depend on the locale.
with open(INFO_DIR, encoding='utf-8') as _info_file:
    imagenet_data = json.load(_info_file)
# Entry i is the first field of record str(i) (the JSON keys are stringified indices),
# i.e. the WordNet ID of class index i.
imagenet_wnids = [imagenet_data[str(i)][0] for i in range(len(imagenet_data))]


def _select_key_hyper(candidate_hypers, priorities, hyper_hypo_map):
    """Return the hypernym chosen by the first priority that matches, or None.

    Priorities are tried in ``priorities.values()`` order. Within one priority, a
    candidate qualifies when it is a key of ``hyper_hypo_map`` and its hyponym count
    lies in the inclusive ``interval``. Among qualifying candidates, the smallest count
    (``pick_lowest=True``) or the largest count wins. On ties, the candidate seen last
    wins.
    """
    for priority in priorities.values():
        low, high = priority['interval'][0], priority['interval'][1]
        pick_lowest = priority['pick_lowest']
        # Start at the far end of the interval so any in-range candidate can win.
        best_len = high if pick_lowest else low
        best_hyper = None
        for hyper in candidate_hypers:
            if hyper not in hyper_hypo_map:
                continue
            size = len(hyper_hypo_map[hyper])
            if not low <= size <= high:
                continue
            if (size <= best_len) if pick_lowest else (size >= best_len):
                best_len = size
                best_hyper = hyper
        if best_hyper is not None:
            return best_hyper
    return None


def get_lookuptable(priorities=None, table_type='list', wnids=None):
    """Map every class to the hyponym collection of one selected ("key") hypernym.

    Args:
        priorities: dict whose values look like
            ``{'interval': [low, high], 'pick_lowest': bool}``. They are tried in dict
            iteration order (the keys themselves are not used for ordering). The first
            priority that yields a hypernym for a class wins. The default is a single
            priority accepting hyponym counts 1..1000 and picking the smallest.
        table_type: ``'list'`` returns the dict table. ``'tensor'`` returns a dense
            ``num_classes x num_classes`` 0/1 tensor, where ``num_classes`` is
            ``len(imagenet_wnids)``; row ``hypo`` has ones at the classes in its key
            hypernym's hyponym collection.
        wnids: optional iterable of hypernym keys. If given, only these hypernyms are
            candidates.

    Returns:
        ``(lookuptable, key_hypers)``. ``lookuptable`` maps each key of
        ``imagenet_hypo_hyper_map`` to its key hypernym's hyponym collection (or is the
        tensor form). ``key_hypers`` lists the chosen hypernyms in the same iteration
        order.

    Raises:
        ValueError: if ``table_type`` is unknown, or a class has no candidate hypernym
            satisfying any priority.
        KeyError: if an entry of ``wnids`` is not a key of ``imagenet_hyper_hypo_map``.
    """
    if table_type not in ('list', 'tensor'):
        raise ValueError(f"table_type must be 'list' or 'tensor', got {table_type!r}")

    if wnids is not None:
        hyper_hypo_map = {wnid: imagenet_hyper_hypo_map[wnid] for wnid in wnids}
    else:
        hyper_hypo_map = imagenet_hyper_hypo_map
    if priorities is None:
        priorities = {1: {'interval': [1, 1000], 'pick_lowest': True}}

    lookuptable = {}
    key_hypers = []
    for hypo, candidate_hypers in imagenet_hypo_hyper_map.items():
        best_hyper = _select_key_hyper(candidate_hypers, priorities, hyper_hypo_map)
        if best_hyper is None:
            # Previously this surfaced as an opaque ``KeyError: None``.
            raise ValueError(
                f'No hypernym of class {hypo!r} satisfies any of the given priorities'
            )
        lookuptable[hypo] = hyper_hypo_map[best_hyper]
        key_hypers.append(best_hyper)

    if table_type == 'list':
        return lookuptable, key_hypers

    num_classes = len(imagenet_wnids)
    lookuptable_tensor = torch.zeros(num_classes, num_classes)
    for hypo, hypos in lookuptable.items():
        # The container type of the hyponym collection comes from the wordnet helpers.
        # Convert to a list so it is a valid advanced index (e.g. if it is a set).
        lookuptable_tensor[hypo, list(hypos)] = 1
    return lookuptable_tensor, key_hypers


def display_lookuptable(key_hypers):
    """Print one row per class: its name, its key hypernym, and usage/coarseness stats.

    Args:
        key_hypers: sequence of chosen hypernyms, where position ``i`` is the hypernym
            for class index ``i`` (as returned by ``get_lookuptable``).

    Each row shows how many classes share the same key hypernym ("used by") and the size
    of that hypernym's hyponym collection ("coarsity").
    """
    from collections import Counter

    # Count once up front; calling list.count inside the loop made this O(n^2).
    usage = Counter(key_hypers)
    for hypo, hyper in enumerate(key_hypers):
        hypo_str = f'{get_name(imagenet_wnids[hypo])}'
        hyper_str = f'{get_name(hyper)}'
        # ``:<35`` / ``:<30`` left-justify without truncating, matching the old
        # manual space padding exactly.
        print(
            f'{hypo_str:<35}{hyper_str:<30}'
            f'used by {usage[hyper]:3.0f} classes\t\t'
            f'coarsity: {len(imagenet_hyper_hypo_map[hyper]):3.0f}'
        )


def display_hierarchy():
    """Print every hypernym with its name, hyponym count, and its hyponyms' class names.

    For each hypernym, prints a header line ``"<hyper>: <name> (<count>)"`` followed by
    the list of hyponym class names and a blank line.
    """
    for hyper, hypos in imagenet_hyper_hypo_map.items():
        header = f'{hyper}: {get_name(hyper)} ({len(hypos)})'
        print(header)
        # Hyponyms are class indices into imagenet_wnids.
        member_names = [get_name(imagenet_wnids[idx]) for idx in hypos]
        print(f'{member_names}\n')

# Build the two hierarchy maps from the ordered class wnids. ``get_hyper_hypo_maps`` is
# presumably provided by the star import of utils.wordnet_helpers — TODO confirm.
# Judging from how this module uses them:
#   imagenet_hyper_hypo_map: hypernym -> collection of class indices into imagenet_wnids
#   imagenet_hypo_hyper_map: class key (presumably the int index) -> iterable of
#                            candidate hypernyms
# This runs after the functions above are defined. That is fine because they look these
# globals up at call time, but calling them before this line executes would fail.
imagenet_hyper_hypo_map, imagenet_hypo_hyper_map = get_hyper_hypo_maps(imagenet_wnids)
