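R"""Inspect activations of a pretrained ResNet-50 (outputs of GlobalAveragePooling2D calls).

Usage (the script is meant to be run interactively via `python3 -i`):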
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/activations/resnet_activations_test01.py

"""
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow.python.keras.applications import resnet as keras_resnet
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras import layers as keras_layers

from em.models import em_models
from em.util import monkey_patching


# Let TensorFlow allocate GPU memory on demand instead of reserving the whole
# device at start-up. Per the original note this was needed to avoid a
# "BLAS fail to launch" error (presumably cuBLAS could not obtain memory for
# its workspace once TensorFlow had claimed everything).
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

###############################################################################


def plot(x):
    """Plot every value of `x` sorted from largest to smallest, then show it.

    Args:
      x: A tensor, or a list/tuple of tensors. The tensors in a list/tuple may
        have different shapes. Each one is flattened on its own and the
        results are concatenated into a single 1-D sequence before sorting.
    """
    if isinstance(x, (list, tuple)):
        # Flatten each tensor separately so tensors of different shapes can be
        # joined. Passing the whole sequence to tf.reshape would first try to
        # pack it into one tensor, which requires every element to have the
        # same shape.
        x = tf.concat([tf.reshape(y, [-1]) for y in x], axis=0)
    x = tf.sort(tf.reshape(x, [-1]), direction='DESCENDING')
    plt.plot(x)
    plt.show()


# Shape of a single input image, presumably (height, width, channels).
IMAGE_SIZE = (224, 224, 3)
# Number of random examples pushed through the model. An earlier experiment
# used 4.
BATCH_SIZE = 1

# Random input batch used only to exercise the model (tf.random.uniform
# defaults to values in [0, 1)).
dummy_batch = tf.random.uniform((BATCH_SIZE,) + IMAGE_SIZE)


# NOTE(review): `v` is not referenced by the code in this script. Its shape
# presumably matches the final 7x7x2048 feature map (1 7 7 2048 with the
# batch dimension). It is kept for interactive use; remove it if unneeded.
v = tf.compat.v1.Variable(tf.zeros([7, 7, 2048]))

# Filled by `override_fn` with the result of every patched call made while
# the model is being built.
all_activations = []


def override_fn(og_fn, *args, **kwargs):
    """Call `og_fn` and record what it returns.

    Used as the replacement for the patched method. Each result is appended
    to the module-level `all_activations` list and then returned unchanged,
    so the caller sees the same value the original method would produce.
    """
    outputs = og_fn(*args, **kwargs)
    all_activations.append(outputs)
    return outputs


# Capture the output of every GlobalAveragePooling2D call made while the model
# is being built. GlobalAveragePooling2D.__call__ is registered to be routed
# through `override_fn`, which appends each result to `all_activations`.
# NOTE(review): assumes MonkeyPatcherContext applies registered patches on
# __enter__ and restores the originals on __exit__. Confirm in
# em.util.monkey_patching.
mctx = monkey_patching.MonkeyPatcherContext()
# Alternative experiment (disabled): presumably captures each ResNet block's
# output instead of the pooling output.
# mctx.patch_method(keras_resnet, 'block1', override_fn)
mctx.patch_method(keras_layers.GlobalAveragePooling2D, '__call__', override_fn)

with mctx:
    # Building the model invokes the patched call(s), filling
    # `all_activations` with (presumably symbolic) tensors from its graph.
    model = em_models.from_pretrained("resnet:resnet50_imagenet")
    # model(dummy_batch)


# Wrap the same graph so a single forward pass returns the captured tensors
# (key 'activations') alongside the model's own output (key 'logits').
# NOTE(review): this only works if the captured tensors belong to the graph
# rooted at `model.input`.
model2 = keras_training.Model(model.input, {'activations': all_activations, 'logits': model.output}, name='activations_model')

# Run the random batch; `q` presumably mirrors the output dict above.
q = model2(dummy_batch)

# plot(q['activations'][4])

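# Flattened sizes of 16 activations (3 + 4 + 6 + 3 per stage), presumably the
# block outputs recorded with the `keras_resnet.block1` patch:
# [802816, 802816, 802816, 401408, 401408, 401408, 401408, 200704, 200704, 200704, 200704, 200704, 200704, 100352, 100352, 100352]

# 802_816 = 56 * 56 * 256
# 401_408 = 28 * 28 * 512
# 200_704 = 14 * 14 * 1024
# 100_352 = 7 * 7 * 2048
#
# 5_519_360 values per example (= 3 * 802_816 + 4 * 401_408 + 6 * 200_704 + 3 * 100_352)

# Values per batch (5_519_360 * batch size):
# 64: 353_239_040
# 128: 706_478_080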
