"""Information about ImageNet class labels.

Most of this information is taken from ImageNet-X.
https://github.com/facebookresearch/imagenetx/
"""
import csv
import dataclasses
import os
from typing import List, Tuple

import numpy as np
import tensorflow as tf


###############################################################################

# _LABELS_CSV_PATH = os.path.join(__file__, 'assets', 'imagenet_labels.csv')
# _METACLASSES100_CSV_PATH = os.path.join(__file__, 'assets', 'imagenet_x_metaclasses100.csv')


# ###############################################################################

# class _GlobalState:
#     def __init__(self, labels_csv_path: str, metaclasses100_csv_path: str):
#         self.labels_csv_path = labels_csv_path
#         self.metaclasses100_csv_path = metaclasses100_csv_path

#     def _read_in_labels_csv(self):
#         names = []
#         wid_to_name = {}
#         with open(self.labels_csv_path, 'r') as f:   
#             reader = csv.reader(f)
#             for i, (wid, name) in reader:
#                 names.append(name)
#                 wid_to_name[wid] = name


# _state = _GlobalState(
#     labels_csv_path=_LABELS_CSV_PATH,
#     metaclasses100_csv_path=_METACLASSES100_CSV_PATH,
# )


# ###############################################################################

# def label_to_str(label: int):
#     pass
# ###############################################################################

# @dataclasses.dataclass
# class ImageNetClassPrediction:
#     name: str
#     description: str
#     score: float


def logits_to_top_classes(
    logits: np.ndarray, k: int
) -> List[List[Tuple[str, float]]]:
    """Return the top-``k`` ImageNet classes for each example in a batch.

    Args:
        logits: Unnormalized class scores. Must be a 2D array with one row
            per example; Keras' ``decode_predictions`` expects 1000 columns
            (one per ImageNet class).
        k: Number of top classes to report per example.

    Returns:
        One list per example. Each inner list holds ``k``
        ``(description, score)`` pairs in the order produced by
        ``decode_predictions``, where ``score`` is the softmax probability
        of the class.
    """
    # Normalize over the class axis so the reported scores are probabilities.
    probabilities = tf.math.softmax(logits, axis=-1).numpy()
    decoded = tf.keras.applications.resnet50.decode_predictions(
        probabilities, top=k
    )
    # Each decoded entry is (class_name, description, score); the class_name
    # (WordNet id) is dropped.
    return [
        [(description, score) for _, description, score in example]
        for example in decoded
    ]


def labels_to_classes(labels: np.ndarray) -> List[str]:
    """Translate integer ImageNet labels into class descriptions.

    A one-hot "prediction" row is built for every label, and Keras'
    ``decode_predictions`` is asked for the single top class of each row,
    which is the class the label refers to.

    Args:
        labels: Sequence of class indices, one per example.

    Returns:
        The class description for each label, in input order.
    """
    num_examples = len(labels)
    one_hot = np.zeros([num_examples, 1000])
    for row in range(num_examples):
        # Mark the labelled class as the only non-zero score in this row.
        one_hot[row, labels[row]] = 1

    decoded = tf.keras.applications.resnet50.decode_predictions(one_hot, 1)

    descriptions = []
    for top_entries in decoded:
        # Each entry is (class_name, description, score); keep description.
        descriptions.append(top_entries[0][1])
    return descriptions
