"""TCAV for BERT models."""
import dataclasses
import random
import os
from typing import List, Optional

import h5py
import numpy as np
from sklearn.linear_model import LogisticRegression
import tensorflow as tf
from transformers import PreTrainedTokenizer, TFBertForSequenceClassification


from em.projects.pi import qqp_components_context as QCC
from em.util import hdf5_util

###############################################################################

# Steps:
# - learn a concept activation vector (CAV) for each concept
# - compute directional derivatives of the class log-probs along the CAV


@dataclasses.dataclass
class LogisticRegressionParams:
    """Hyperparameters for the sklearn LogisticRegression that fits each CAV."""

    # Iteration cap forwarded to LogisticRegression(max_iter=...).
    max_iter: int = dataclasses.field(default=1000)


@dataclasses.dataclass
class BertTcavExperiment:
    """Shared state for running TCAV on a TFBertForSequenceClassification model.

    Holds per-example activations, their labels and the model. On construction it
    precomputes the gradient of every class log-probability with respect to the
    activations. This is done for the first ``n_scoring_examples`` examples and is
    stored in ``self.log_prob_gradients``.
    """

    # shape = [n_examples, d_activations (should be 768)]
    # NOTE(review): presumably the [CLS] hidden state before the pooler, since
    # it is fed straight into model.bert.pooler.dense below. Confirm.
    activations: np.ndarray

    # shape = [n_examples]
    # Compared against class ids 0..model.num_labels-1 when scoring.
    labels: np.ndarray

    model: TFBertForSequenceClassification

    # Number of negative (non-concept) examples sampled for each CAV run.
    n_negative_examples: int

    # Only the first n_scoring_examples examples get gradients and are scored.
    n_scoring_examples: int

    # Replaced by LogisticRegressionParams() in __post_init__ when None.
    logistic_regression_params: Optional[LogisticRegressionParams] = None

    ############################################################
    # Identity-based hashing and equality.
    # A plain @dataclass (eq=True, not frozen) sets __hash__ to None, which
    # makes instances unhashable. The author believed tf.function needs
    # hashable instances when it decorates methods (presumably for per-instance
    # caching). Confirm.
    # Defining __eq__ in the class body also stops @dataclass from generating
    # a field-wise __eq__ that would compare numpy arrays.

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other

    ############################################################

    def __post_init__(self):
        """Fill in defaults and precompute the log-prob gradients."""
        self.n_examples = self.activations.shape[0]
        if self.logistic_regression_params is None:
            self.logistic_regression_params = LogisticRegressionParams()

        # Gradients of log probs with respect to activations, computed only for
        # the first n_scoring_examples examples (the slice caps this at n_examples).
        # shape = [min(n_scoring_examples, n_examples), d_activations, n_classes]
        # NOTE: this makes one tf.function call per scored example at construction.
        self.log_prob_gradients = self._compute_log_prob_gradients()

    @tf.function
    def _compute_log_softmax_wrt_activations(self, example_activations):
        """Map one example's activations to class log-probabilities.

        Applies only ``model.bert.pooler.dense`` and ``model.classifier`` (both
        with training=False), followed by log_softmax.

        Args:
            example_activations: 1-D tensor of shape [d_activations].

        Returns:
            1-D tensor of per-class log-probabilities.
        """
        # Example activations should NOT have a dummy batch dimension.
        # Add a batch axis of size 1 for the Keras layers.
        x = example_activations[None, :]

        x = self.model.bert.pooler.dense(x, training=False)
        x = self.model.classifier(x, training=False)

        # Drop the batch axis added above.
        x = tf.squeeze(x, axis=0)

        # NOTE: Unlike the original TCAV paper, we use the log_softmax instead of the logits.
        return tf.math.log_softmax(x)

    @tf.function
    def _compute_log_prob_gradient_for_example(self, example_activations):
        """Compute d log p(class i) / d activations for every class i.

        Args:
            example_activations: 1-D float32 tensor of shape [d_activations].

        Returns:
            Tensor of shape [d_activations, num_labels]. Column i is the gradient
            of the class-i log-probability.
        """
        lp_grads = []

        # persistent=True allows one tape.gradient call per class.
        # watch_accessed_variables=False plus watch() means only the input
        # activations are differentiated, not the model variables.
        with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape:
            tape.watch(example_activations)

            log_probs = self._compute_log_softmax_wrt_activations(example_activations)

            for i in range(self.model.num_labels):
                # The indexing is recorded so gradients flow from log_probs[i].
                log_prob = log_probs[i]
                # Do not record the gradient computation itself on the tape.
                with tape.stop_recording():
                    grad = tape.gradient(log_prob, example_activations)
                    lp_grads.append(grad)

        return tf.stack(lp_grads, axis=-1)

    def _compute_log_prob_gradients(self):
        """Stack log-prob gradients for the first n_scoring_examples examples.

        Returns:
            np.ndarray of shape
            [min(n_scoring_examples, n_examples), d_activations, num_labels].

        Raises:
            ValueError: from np.stack when there are no examples to score.
        """
        grads = [
            self._compute_log_prob_gradient_for_example(tf.cast(x, tf.float32)).numpy()
            for x in self.activations[:self.n_scoring_examples]
        ]
        return np.stack(grads, axis=0)

###############################################################################


@dataclasses.dataclass
class TcavScores:
    """Bundle of TCAV scores together with the labels of the examples they score."""

    # Labels of the scored examples (presumably class ids; confirm with producer).
    labels: np.ndarray
    # Scores aligned row-for-row with ``labels``.
    scores: np.ndarray


# @dataclasses.dataclass
# class BertTcavForComponent:
#     exp: BertTcavExperiment

#     concept_example_indices: np.ndarray

#     def __post_init__(self):
#         self.logistic_regression_params = self.exp.logistic_regression_params
#         self.n_scoring_examples = self.exp.n_scoring_examples
#         self.activations = self.exp.activations
#         self.labels = self.exp.labels
#         self.model = self.exp.model

#         self.n_concept_examples = self.concept_example_indices.shape[0]

#         self.negative_example_indices = self._make_negative_example_indices()

#         self.cav = tf.Variable(tf.zeros([self.exp.activations.shape[-1]], dtype=tf.float32))

#     def _make_negative_example_indices(self):
#         choices = set(range(self.exp.n_examples)) - set(self.concept_example_indices)
#         return np.array(random.sample(choices, self.exp.n_negative_examples), dtype=np.int32)

#     ############################################################
#     # I think needed to be hashable to work with tf.function
    
#     def __hash__(self):
#         return id(self)

#     def __eq__(self, other):
#         return self is other

#     ############################################################

#     def learn_cav(self):
#         all_ex_inds = np.concatenate([self.concept_example_indices, self.negative_example_indices], axis=0)

#         labels = np.zeros([all_ex_inds.shape[0]], dtype=np.int32)
#         labels[:self.concept_example_indices.shape[0]] = 1

#         examples = self.activations[all_ex_inds]

#         clf = LogisticRegression(max_iter=self.logistic_regression_params.max_iter, fit_intercept=False).fit(examples, labels)
#         self.cav.assign(clf.coef_.reshape([-1]))

#     @tf.function
#     def _compute_log_softmax_wrt_activations(self, example_activations):
#         # Example activations should NOT have a dummy batch dimension.
#         x = example_activations[None, :]

#         x = self.model.bert.pooler.dense(x, training=False)
#         x = self.model.classifier(x, training=False)

#         x = tf.squeeze(x, axis=0)

#         # NOTE: Unlike the original TCAV paper, we use the log_softmax instead of the logits.
#         return tf.math.log_softmax(x)

#     @tf.function
#     def _compute_score_for_example(self, example_activations):
#         scores = []

#         with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape:
#             tape.watch(example_activations)

#             log_probs = self._compute_log_softmax_wrt_activations(example_activations)

#             for i in range(self.model.num_labels):
#                 log_prob = log_probs[i]
#                 with tape.stop_recording():
#                     grad = tape.gradient(log_prob, example_activations)
#                     scores.append(tf.einsum('i,i->', grad, self.cav))

#         return tf.stack(scores, axis=-1)

#     def compute_per_example_scores(self):
#         labels = self.labels[:self.n_scoring_examples]

#         scores = []
#         for i in range(self.n_scoring_examples):
#             scores.append(self._compute_score_for_example(tf.cast(self.activations[i], tf.float32)).numpy())

#         scores = np.stack(scores, axis=0)

#         return TcavScores(scores=scores, labels=labels)


###############################################################################


@dataclasses.dataclass
class BertTcavForComponent2:
    """Repeated-run TCAV for a single concept ("component").

    Each of the ``n_runs`` runs works as follows:

    - Draw a fresh random set of negative examples from the non-concept examples.
    - Fit a logistic regression (no intercept) that separates concept activations
      from negative activations. Its coefficient vector is the run's CAV.
    - For each class, score the run as the fraction of scoring examples of that
      class whose log-prob gradient has a positive dot product with the CAV.

    Usage: call ``learn_cavs()`` first, then ``compute_per_run_scores()``.
    """

    # Shared experiment state: activations, labels, model, precomputed gradients.
    # This is a string annotation, so the class can be created before the
    # experiment class is resolvable.
    exp: "BertTcavExperiment"

    # Indices into ``exp.activations`` of the positive (concept) examples.
    concept_example_indices: np.ndarray

    # Number of independent (negative sample, CAV) runs.
    n_runs: int

    def __post_init__(self):
        # Mirror frequently used experiment attributes.
        self.logistic_regression_params = self.exp.logistic_regression_params
        self.n_scoring_examples = self.exp.n_scoring_examples
        self.activations = self.exp.activations
        self.labels = self.exp.labels
        self.model = self.exp.model

        self.n_concept_examples = self.concept_example_indices.shape[0]

        # Gradients are precomputed on the experiment, not per component.
        self.log_prob_gradients = self.exp.log_prob_gradients

        # Both are populated by learn_cavs().
        self.negative_example_indices = None  # shape [n_runs, n_negative_examples]
        self.cavs = None  # shape [n_runs, d_activations]

    def _make_negative_example_indices(self):
        """Sample ``exp.n_negative_examples`` distinct non-concept example indices.

        Returns:
            int32 array of shape [n_negative_examples].

        Raises:
            ValueError: if there are fewer non-concept examples than requested.
        """
        concept = set(np.asarray(self.concept_example_indices).tolist())
        # random.sample() requires a sequence. Passing a set was deprecated in
        # Python 3.9 and raises TypeError since 3.11. An ordered list also makes
        # the draw reproducible under a fixed random seed.
        choices = [i for i in range(self.exp.n_examples) if i not in concept]
        return np.array(random.sample(choices, self.exp.n_negative_examples), dtype=np.int32)

    def _learn_cav(self, run_index: int):
        """Fit the CAV for run ``run_index`` and return it as a 1-D coefficient vector."""
        negative_example_indices = self.negative_example_indices[run_index]
        all_ex_inds = np.concatenate([self.concept_example_indices, negative_example_indices], axis=0)

        # Concept examples come first and are labelled 1. Negatives are labelled 0.
        labels = np.zeros([all_ex_inds.shape[0]], dtype=np.int32)
        labels[:self.concept_example_indices.shape[0]] = 1

        examples = self.activations[all_ex_inds]

        clf = LogisticRegression(
            max_iter=self.logistic_regression_params.max_iter, fit_intercept=False).fit(examples, labels)
        # Binary classifier: coef_ has shape [1, d_activations]. Flatten it to 1-D.
        return clf.coef_.reshape([-1])

    def learn_cavs(self):
        """Sample negatives and fit one CAV per run.

        Sets ``self.negative_example_indices`` ([n_runs, n_negative_examples]) and
        ``self.cavs`` ([n_runs, d_activations]).
        """
        self.negative_example_indices = np.stack([
            self._make_negative_example_indices()
            for _ in range(self.n_runs)
        ], axis=0)

        self.cavs = np.stack([
            self._learn_cav(i)
            for i in range(self.n_runs)
        ], axis=0)

    def compute_per_run_scores(self):
        """Compute the TCAV score of every run, per class.

        Returns:
            float32 array of shape [n_runs, num_labels]. Entry [r, c] is the
            fraction of scoring examples with label ``c`` whose directional
            derivative of log p(c) along run ``r``'s CAV is strictly positive.
            The entry is NaN when no scoring example has label ``c``.

        Raises:
            ValueError: if ``learn_cavs()`` has not been called yet.
        """
        if self.cavs is None:
            raise ValueError('learn_cavs() must be called before compute_per_run_scores().')

        # The same scoring examples are used for every run, so compute these once.
        lp_grads = self.log_prob_gradients[:self.n_scoring_examples]
        labels = self.labels[:self.n_scoring_examples]
        label_masks = [labels == label for label in range(self.model.num_labels)]

        scores = []
        for i in range(self.n_runs):
            cav = self.cavs[i]

            # Directional derivatives: [n_scoring, d, n_classes] . [d] -> [n_scoring, n_classes].
            per_ex_scores = np.einsum('ijk,j->ik', lp_grads, cav)

            run_ret = []
            for label, mask in enumerate(label_masks):
                label_scores = per_ex_scores[mask, label]
                n_label = label_scores.shape[0]
                if n_label == 0:
                    # The fraction is undefined for a class with no scoring
                    # examples. Report NaN instead of dividing by zero.
                    run_ret.append(float('nan'))
                else:
                    run_ret.append(float((label_scores > 0).sum()) / float(n_label))

            scores.append(run_ret)

        return np.array(scores, dtype=np.float32)

###############################################################################


def load_run_scores(filepath: str):
    """Read per-component run scores from an HDF5 file.

    The file must contain ``data/component_indices`` and, for each index ``i``
    listed there, a dataset ``data/component_{i}_scores``.

    Args:
        filepath: Path to the HDF5 file. ``~`` is expanded.

    Returns:
        Dict mapping each stored component index to the array returned by
        ``hdf5_util.load_h5_ds`` for its scores dataset. These are presumably
        the [n_runs, n_classes] arrays from ``compute_per_run_scores``; confirm
        against the writer.
    """
    path = os.path.expanduser(filepath)
    scores_by_component = {}
    with h5py.File(path, "r") as h5_file:
        indices = hdf5_util.load_h5_ds(h5_file['data/component_indices'])
        for component_index in indices:
            dataset = h5_file[f'data/component_{component_index}_scores']
            scores_by_component[component_index] = hdf5_util.load_h5_ds(dataset)
    return scores_by_component