R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/imagenet/resnet_explore_01.py

"""
import numpy as np

import tensorflow as tf
import tensorflow_datasets as tfds

from em import datasets as em_datasets


ds = tfds.load('imagenet2012', split='train')
# Take a single example to experiment with. Each element is a dict with keys
# 'file_name', 'image' and 'label'. next(iter(...)) replaces the old
# `for ...: break` idiom and fails immediately on an empty dataset instead of
# leaving `entry` unbound.
entry = next(iter(ds))

# Cast the image to float32 for the resize/preprocessing steps below.
x = tf.cast(entry['image'], tf.float32)
y = entry['label']

###########################################################

# Pretrained ResNet50 with the 1000-way ImageNet classification head.
# classifier_activation=None drops the final softmax, so calling the model
# yields raw logits rather than probabilities.
resnet = tf.keras.applications.resnet50.ResNet50(
    weights='imagenet',
    include_top=True,
    classes=1000,
    classifier_activation=None,
)
# Optional warm-up call, currently disabled. It temporarily raises the TF logger
# level to ERROR so the first forward pass does not flood the console with
# warnings/error messages, then restores the previous level.
# old_level = tf.get_logger().getEffectiveLevel()
# tf.get_logger().setLevel('ERROR')
# resnet(tf.zeros([1, 224, 224, 3]))
# tf.get_logger().setLevel(old_level)

###########################################################
# Build the model input from `x`:
#   1. resize straight to 224x224 (tf.image.resize does not preserve the aspect
#      ratio by default),
#   2. apply tf.keras.applications.resnet50.preprocess_input,
#   3. prepend a batch axis of size 1.
# NOTE: the first forward pass through `resnet` emits a lot of warnings.
# Open question: whether tf.image.resize should use antialias=True here.
x2 = tf.keras.applications.resnet50.preprocess_input(
    tf.image.resize(x, [224, 224])
)[None, ...]

###########################################################

# Forward pass under a GradientTape that records only explicitly watched
# variables (watch_accessed_variables=False): here, resnet's trainable weights.
with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(resnet.trainable_variables)
    # Logits for the batch; with classes=1000 and a batch of one image.
    logits = resnet(x2)
    # Logit vector of the single image in the batch.
    blah = logits[0]

# NOTE(review): `blah` is a vector of logits, not a scalar. For a non-scalar
# target, GradientTape.gradient differentiates the *sum* of its elements, so
# `grads` is d(sum of all logits)/d(weights). If a per-class gradient was
# intended, select one entry (e.g. the top class) as the target instead.
grads = tape.gradient(blah, resnet.trainable_variables)

# Class indices ranked from highest to lowest logit for the single image.
classes_order = np.argsort(np.negative(logits[0].numpy()))

# Softmax probabilities for the batch. np.sort works along the last axis in
# ascending order, so sorted_probs keeps probs' shape and ends with the largest
# probability, which is the reverse of classes_order's ranking.
probs = tf.math.softmax(logits)
sorted_probs = np.sort(probs.numpy(), axis=-1)

###########################################################

# Print Keras' summary of the model's layers and parameter counts.
resnet.summary()

###########################################################

# Load the project's 'imagenet/resnet' dataset through em.datasets.
# NOTE: this rebinds `ds`, replacing the tfds ImageNet dataset loaded at the
# top of the script.
# NOTE(review): tokenizer=None / sequence_length=224 are passed straight to
# em_datasets.load; their meaning for an image dataset is not visible here, so
# confirm against em.datasets.
ds = em_datasets.load('imagenet/resnet', split='train', tokenizer=None, sequence_length=224)
