#!/usr/bin/env python
"""human_categories.py

Code to define the class that deals with the specifics
of the 16 categories used in Robert's human and DNN
experiments.

"""

import numpy as np
import os

import helper_200_only.wordnet_functions as wf


def compute_imagenet_indices_for_category(category):
    """Return list of ImageNet indices that correspond to category.

    Every index in range(1000) is converted to a WNID with
    wf.get_WNID_from_index(); an index is kept when
    HumanCategories.get_human_category_from_WNID() maps that WNID to
    'category'.

    parameters:
    - category: one of the 16 human categories, e.g. "dog" or "knife"

    returns:
    - list of ints in ascending order (empty if no WNID maps to category)

    raises:
    - ValueError if 'category' is not one of the 16 human categories.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would make an unknown category silently return [].
    if category not in get_human_object_recognition_categories():
        raise ValueError(
            "{!r} is not one of the 16 human categories".format(category))

    categories = HumanCategories()

    # NOTE(review): this scans 1000 indices, whereas the precomputed
    # HumanCategories.*_indices tables only contain values < 200 -- confirm
    # which label space wf.get_WNID_from_index() uses.
    return [i for i in range(1000)
            if categories.get_human_category_from_WNID(
                wf.get_WNID_from_index(i)) == category]


def get_human_object_recognition_categories():
    """Return the 16 categories that are used for the human experiment.

    More precisely, these are the categories used in Robert's object
    recognition experiment. The names are returned as an alphabetically
    sorted list; a fresh list is built on every call, so callers may
    modify it without affecting later calls.
    """
    entry_level_names = ("knife", "keyboard", "elephant", "bicycle",
                         "airplane", "clock", "oven", "chair", "bear",
                         "boat", "cat", "bottle", "truck", "car", "bird",
                         "dog")
    ordered = list(entry_level_names)
    ordered.sort()
    return ordered


def get_num_human_categories():
    """Return how many categories the object recognition experiment uses."""
    category_names = get_human_object_recognition_categories()
    return len(category_names)


class HumanCategories(object):
    """Lookup tables for the 16 human-experiment categories.

    For every name returned by get_human_object_recognition_categories()
    this class defines two class-level attributes:

    - ``<category>``: list of WordNet IDs (WNIDs) assigned to the category.
    - ``<category>_indices``: list of integer class indices for the
      category. NOTE(review): every value is < 200, presumably indices into
      the 200-class label space this helper package targets -- confirm.

    Categories without any entries (e.g. knife, oven) have empty lists.
    """

    # Note: Some WNIDs may not be part of the ilsvrc2012 database.

    # Those WNIDs were generated with:
    # wordnet_functions.get_ilsvrc2012_training_WNID("knife") etc.
    # Since this takes some time, they were collected here for
    # a massive speed-up in computation time.
    # Excluded categories were then removed manually.

    knife =    []

    keyboard = ['n03085013']

    elephant = ['n02504458']

    bicycle =  []

    airplane = []

    clock =    []

    oven =     []

    chair =    ['n04099969']

    bear =     ['n02132136']

    boat =     ['n03662601']

    cat =      [ "n02123045", "n02123394", "n02124075",
                "n02125311"]

    bottle =   ['n02823428', 'n03937543', 'n03983396',
                'n04560804']

    truck =    ['n03796401', 'n03977966']

    car =      ['n02814533', 'n03100240', 'n04285008']

    bird =     [ 'n01855672', 'n02002724', 'n02056570']

    dog =      ['n02094433', 'n02099601','n02099712',
                'n02106662', 'n02113799']

    airplane_indices = []
    bear_indices = [14]
    bicycle_indices = []
    # NOTE(review): `bird` lists three WNIDs but `bird_indices` has only two
    # entries, while every other non-empty category has matching lengths --
    # confirm whether an index is missing.
    bird_indices = [41, 67]
    boat_indices = [196]
    bottle_indices = [45, 50, 98, 118]
    car_indices = [117, 147, 157]
    cat_indices = [0, 66, 102, 131]
    chair_indices = [3]
    clock_indices = []
    dog_indices = [11, 39, 78, 135, 194]
    elephant_indices = [199]
    keyboard_indices = [26]
    knife_indices = []
    oven_indices = []
    truck_indices = [64, 90]


    def get_human_category_from_WNID(self, wnid):
        """Return the human category that a given WNID belongs to.

        Returns None if wnid is not part of the 16 human categories.

        parameters:
        - wnid: a string containing the wnid of an image, e.g. 'n03658185'

        """
        # Categories are scanned in sorted order; the first WNID list that
        # contains wnid determines the result.
        for category in get_human_object_recognition_categories():
            if wnid in getattr(self, category):
                return category

        return None

    def get_imagenet_indices_for_category(self, category):
        """Return ImageNet indices that correspond to an entry-level category.

        parameters:
        - category: a string, e.g. "dog" or "knife"

        returns:
        - a new list of ints. It is a copy, so callers cannot accidentally
          modify the class-level table shared by all instances.

        raises:
        - ValueError if 'category' is not part of the 16 human categories.
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, after which an unknown category would surface as an
        # unhelpful AttributeError from getattr().
        if category not in get_human_object_recognition_categories():
            raise ValueError(
                "{!r} is not one of the 16 human categories".format(category))

        return list(getattr(self, category + "_indices"))
