"""Common stuff for BERT selective ablation and similar stuff."""
import dataclasses
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm

from em.fishers import diagonal
from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util import flat_pack

from em.experimental import selective_ablation1


###############################################################################


@dataclasses.dataclass
class SelectiveAblationUtility:
    """State shared by selective-ablation sweeps over NMF components.

    Holds the decomposition, the Fisher data, the model variables and the
    evaluation dataset that `compute_ablation_selectivity_for_component`
    needs, so a sweep for any single component can be started with only the
    component-specific arguments.
    """

    # NMF decomposition; its `W` and `H` attributes are read by this class.
    decomp: nmf_common.NmfDecomposition
    # Per-example Fisher data, forwarded to the dataset-building helper.
    pe_fishers_data: per_example.PerExampleFlatFishers
    # Fishers paired with `original_variables` in the merge; must have the
    # same length as `original_variables` (checked in `__post_init__`).
    batch_fisher: Sequence[tf.Tensor]

    # Subset of original model's variables that will participate in the merge.
    original_variables: Sequence[tf.Variable]

    # Must be compiled with accuracy then cross-entropy metrics.
    output_model: TFAutoModelForSequenceClassification
    # Variables handed to the merge as its output; must have the same length
    # as `original_variables` (checked in `__post_init__`).
    # Presumably these belong to `output_model` -- TODO confirm.
    output_variables: Sequence[tf.Variable]

    # Used to decode a row of `decomp.H` into per-variable tensors.
    flat_packer: flat_pack.FlatPacker

    # Forwarded to the helper that builds the component-specific dataset.
    tokenizer: AutoTokenizer

    # Dataset passed directly to `output_model.evaluate`.
    # NOTE(review): it is not batched here, so it is presumably expected to
    # arrive already batched -- confirm against callers.
    full_eval_ds: tf.data.Dataset

    # Not read by any method visible in this class.
    permutation: Optional[np.ndarray] = None

    # Batch size applied to the component-specific evaluation dataset.
    eval_batch_size: int = 256

    def __post_init__(self):
        """Check that the per-variable sequences line up one-to-one."""
        assert len(self.output_variables) == len(self.original_variables)
        assert len(self.batch_fisher) == len(self.original_variables)
        # A similar length check on `self.pe_fishers_data.fishers` is
        # currently disabled.

    def _get_component_variables(self, component_variables):
        """Resolve a `component_variables` mode name into a list of tensors.

        Args:
            component_variables: Mode name. Only 'negative_og' is recognized.

        Returns:
            For 'negative_og', a list with the negation of every entry of
            `self.original_variables`.

        Raises:
            ValueError: If the mode name is not recognized.
        """
        if component_variables == 'negative_og':
            return [-variable for variable in self.original_variables]

        raise ValueError('Unrecognized value for component_variables')

    def compute_ablation_selectivity_for_component(
        self,
        component_index: int,
        n_component_examples_eval: int,
        coefficients_set,
        fisher_floor=1e-8,
        normalize_fishers=True,
        component_variables='negative_og'
    ):
        """Sweep merge coefficients for one component and record metrics.

        For every item produced by the merge generator, `output_model` is
        evaluated on the full evaluation dataset and on a dataset built from
        the component's top examples.

        Args:
            component_index: Index into `decomp.H` selecting the component;
                also passed as `component` to the dataset helper.
            n_component_examples_eval: Passed as `n_examples` to the helper
                that builds the component-specific dataset.
            coefficients_set: Sized collection of coefficients forwarded to
                the merge generator; its `len` sets the progress-bar total.
            fisher_floor: Forwarded unchanged to the merge generator.
            normalize_fishers: Forwarded unchanged to the merge generator.
            component_variables: Mode name resolved through
                `_get_component_variables`.

        Returns:
            A tuple `(first_coeffs, (full_accs, full_losses),
            (comp_accs, comp_losses))` of NumPy arrays with one entry per
            generator step. `first_coeffs` holds element 0 of each yielded
            coefficients item.

        Raises:
            ValueError: If `component_variables` is not a recognized mode.
        """
        # Resolve the mode first so a bad value fails before any heavy work.
        merge_variables = self._get_component_variables(component_variables)

        # The component's row of H plays the role of its "fishers" in the merge.
        component_row = self.decomp.H[component_index]
        component_fishers = self.flat_packer.decode_tf(component_row)

        # Dataset of the examples most associated with this component.
        comp_eval_ds = selective_ablation1.make_dataset_for_top_component_examples(
            W=self.decomp.W,
            component=component_index,
            n_examples=n_component_examples_eval,
            tokenizer=self.tokenizer,
            pe_fishers_data=self.pe_fishers_data,
        ).batch(self.eval_batch_size)

        # One merge per entry of the coefficients set.
        # NOTE(review): the model is re-evaluated after every yield, so the
        # generator presumably writes each merge into `output_variables`
        # in place -- confirm in selective_ablation1.
        merged_iter = selective_ablation1.generate_merged_for_coeffs_set(
            output_variables=self.output_variables,
            variables_to_merge=[self.original_variables, merge_variables],
            fishers=[self.batch_fisher, component_fishers],
            coefficients_set=coefficients_set,
            fisher_floor=fisher_floor,
            normalize_fishers=normalize_fishers,
        )

        def _acc_and_loss(dataset):
            # The first value returned by evaluate is discarded; the next two
            # are recorded as accuracy and loss.
            _, acc, loss = self.output_model.evaluate(dataset, verbose=0)
            return acc, loss

        # Each row: (first coeff, full acc, full loss, comp acc, comp loss).
        rows = []
        for coefficients in tqdm(merged_iter, total=len(coefficients_set)):
            full_acc, full_loss = _acc_and_loss(self.full_eval_ds)
            comp_acc, comp_loss = _acc_and_loss(comp_eval_ds)
            rows.append(
                (coefficients[0], full_acc, full_loss, comp_acc, comp_loss)
            )

        def _column(position):
            return np.array([row[position] for row in rows])

        first_coeffs = _column(0)
        full_accs, full_losses = _column(1), _column(2)
        comp_accs, comp_losses = _column(3), _column(4)

        return first_coeffs, (full_accs, full_losses), (comp_accs, comp_losses)


def _RENAME_make_plots(
    component_index: int,
    decomp,
    pe_fishers_data,
    tokenizer,
    full_eval_ds: tf.data.Dataset,
    n_component_examples_eval: int,
    permutation: Optional[np.ndarray] = None,
    eval_batch_size: int = 256,
):
    """Build the batched evaluation dataset for one component's top examples.

    NOTE(review): this function looks unfinished. The name is a placeholder,
    and in the code visible here the dataset is built but never used or
    returned, and `full_eval_ds` and `permutation` are not read.

    Args:
        component_index: Passed as `component` to the dataset helper.
        decomp: Object whose `W` attribute is passed to the dataset helper.
        pe_fishers_data: Forwarded unchanged to the dataset helper.
        tokenizer: Forwarded unchanged to the dataset helper.
        full_eval_ds: Not used in the visible code.
        n_component_examples_eval: Passed as `n_examples` to the dataset
            helper.
        permutation: Not used in the visible code.
        eval_batch_size: Batch size applied to the component dataset.
    """
    # Dataset of the examples selected for this component by the helper.
    comp_eval_ds = selective_ablation1.make_dataset_for_top_component_examples(
        W=decomp.W,
        component=component_index,
        n_examples=n_component_examples_eval,
        tokenizer=tokenizer,
        pe_fishers_data=pe_fishers_data,
    )
    comp_eval_ds = comp_eval_ds.batch(eval_batch_size)
 
