"""TCAV for ResNet models."""
import dataclasses
import random
from typing import List, Optional

import numpy as np
from sklearn.linear_model import LogisticRegression
import tensorflow as tf

from em.projects.pi import qqp_components_context as QCC

from em.models import em_models
from em.util import monkey_patching


from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras import layers as keras_layers


###############################################################################


def create_activations_to_logits_model(model_ri: str):
    """Build a Keras model mapping pooled activations to the model's output.

    The pretrained model identified by ``model_ri`` is loaded while
    ``GlobalAveragePooling2D.__call__`` is patched, so that the value returned
    by that layer call can be captured and used as the *input* of the
    returned model. The returned model's output is ``og_model.output``.

    Args:
        model_ri: Identifier handed unchanged to ``em_models.from_pretrained``.

    Returns:
        A Keras ``Model`` named ``'resnet_activations_to_logits_model'`` whose
        input is the captured pooling output and whose output is the loaded
        model's output.

    Raises:
        RuntimeError: If no ``GlobalAveragePooling2D`` call was observed while
            the model was being loaded, so there is nothing to use as input.
    """
    pooled_activations = None

    def override_fn(og_fn, *args, **kwargs):
        # Delegate to the real __call__ and remember what it returned.
        # If the layer class is called more than once during loading, the
        # most recent return value wins.
        nonlocal pooled_activations
        ret = og_fn(*args, **kwargs)
        pooled_activations = ret
        return ret

    mctx = monkey_patching.MonkeyPatcherContext()
    mctx.patch_method(keras_layers.GlobalAveragePooling2D, '__call__', override_fn)

    # NOTE(review): presumably the patch is only active inside this `with`
    # block and is undone on exit -- confirm against monkey_patching.
    with mctx:
        og_model = em_models.from_pretrained(model_ri)

    if pooled_activations is None:
        # Without this check the None would be passed straight to the Model
        # constructor and fail with an error unrelated to the real cause.
        raise RuntimeError(
            'No GlobalAveragePooling2D call was captured while loading '
            f'{model_ri!r}; cannot build an activations-to-logits model.')

    return keras_training.Model(
        pooled_activations,
        og_model.output,
        name='resnet_activations_to_logits_model')

###############################################################################


@dataclasses.dataclass
class LogisticRegressionParams:
    """Settings bundle for a logistic-regression fit."""

    # Iteration cap for the solver.
    # NOTE(review): presumably forwarded to sklearn's
    # LogisticRegression(max_iter=...) -- confirm at the use site.
    max_iter: int = dataclasses.field(default=1000)


@dataclasses.dataclass
class ResnetTcavExperiment:
    """Activations, labels and cached log-prob gradients for a TCAV run.

    On construction, the gradient of every class's log-softmax output with
    respect to the activations is computed for the first
    ``n_scoring_examples`` rows of ``activations`` and stored on
    ``log_prob_gradients``.
    """

    # shape = [n_examples, d_activations]
    activations: np.ndarray

    # shape = [n_examples]
    labels: np.ndarray

    # Must be an activations_to_logits_model.
    model: tf.keras.Model

    # NOTE(review): not read by any method visible in this part of the file;
    # presumably the number of negative examples drawn per concept -- confirm.
    n_negative_examples: int

    # Gradients are computed for activations[:n_scoring_examples] only.
    n_scoring_examples: int

    # Replaced by a default LogisticRegressionParams() in __post_init__ when
    # left as None.
    logistic_regression_params: Optional[LogisticRegressionParams] = None

    ############################################################
    # Identity-based hashing/equality. The tf.function-decorated methods
    # below take `self` as an argument; the original author's note says this
    # was (probably) needed for the instance to work with tf.function.

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other

    ############################################################

    def __post_init__(self):
        """Fill in defaults and precompute the per-example gradients."""
        self.n_examples = self.activations.shape[0]
        if self.logistic_regression_params is None:
            self.logistic_regression_params = LogisticRegressionParams()

        # Gradients of log probs with respect to activations, one entry per
        # scored example.
        # shape = [n_scored, d_activations, n_classes], where n_scored is the
        # number of rows in activations[:n_scoring_examples].
        self.log_prob_gradients = self._compute_log_prob_gradients()

    @tf.function
    def _compute_log_softmax_wrt_activations(self, example_activations):
        """Return the log-softmax of the model output for one example.

        Args:
            example_activations: Activations of a single example, WITHOUT a
                leading batch dimension; one is added here before the model
                call and removed again afterwards.

        Returns:
            ``tf.math.log_softmax`` of the model output for that example.
        """
        # Example activations should NOT have a dummy batch dimension.
        x = example_activations[None, :]
        x = self.model(x, training=False)
        x = tf.squeeze(x, axis=0)

        # NOTE: Unlike the original TCAV paper, we use the log_softmax instead of the logits.
        return tf.math.log_softmax(x)

    @tf.function
    def _compute_log_prob_gradient_for_example(self, example_activations):
        """Return d(log_prob[c]) / d(activations) for every class c.

        Args:
            example_activations: Activations of a single example (no batch
                dimension).

        Returns:
            A tensor with the per-class gradients stacked along the last
            axis, i.e. shaped like ``example_activations`` plus a trailing
            class axis.

        Raises:
            ValueError: If the model has no ``num_labels`` attribute and the
                number of classes cannot be read from the static output
                shape either.
        """
        lp_grads = []

        # persistent=True because tape.gradient is called once per class.
        with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape:
            tape.watch(example_activations)

            log_probs = self._compute_log_softmax_wrt_activations(example_activations)

            # Prefer the model's own `num_labels` when it has one. A plain
            # Keras Model does not define that attribute, so fall back to
            # the static size of the output's last axis.
            n_classes = getattr(self.model, 'num_labels', None)
            if n_classes is None:
                n_classes = log_probs.shape.as_list()[-1]
            if n_classes is None:
                raise ValueError(
                    'Cannot determine the number of classes: the model has no '
                    '`num_labels` attribute and its output size is not '
                    'statically known.')

            for i in range(n_classes):
                log_prob = log_probs[i]
                # Keep the gradient computation itself off the tape.
                with tape.stop_recording():
                    grad = tape.gradient(log_prob, example_activations)
                    lp_grads.append(grad)

        return tf.stack(lp_grads, axis=-1)

    def _compute_log_prob_gradients(self):
        """Compute gradients for the first ``n_scoring_examples`` activations.

        Returns:
            A numpy array with one per-example gradient stacked along axis 0.

        Raises:
            ValueError: If the scoring slice is empty.
        """
        scoring_activations = self.activations[:self.n_scoring_examples]
        if len(scoring_activations) == 0:
            # np.stack([]) would raise an opaque ValueError; say what is wrong.
            raise ValueError(
                'No examples to score: `activations[:n_scoring_examples]` is '
                'empty.')

        grads = [
            self._compute_log_prob_gradient_for_example(tf.cast(x, tf.float32)).numpy()
            for x in scoring_activations
        ]
        return np.stack(grads, axis=0)
