"""Common code for merging, ablation, and analysis of HANS lone nmfs."""
import dataclasses
import os
from typing import Optional, Sequence

import numpy as np
import tensorflow as tf
from transformers import PreTrainedTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util import flat_pack

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.ll import hans_labeling
from em.projects.ll import hans_labeling_analysis as hla
from em.projects.ll import hans_util


# Type aliases used in annotations below.
SparseTensor = tf.sparse.SparseTensor


@dataclasses.dataclass
class HansLoneModelContext:
    """Per-model state: a loaded model plus its per-example Fishers and NMF decomposition.

    Normally created through `HansLoneComponentContext.make_model_context`. That
    context supplies the shared tokenizer, the `from_pt` flag and the example
    indicators through `components_context`.
    """

    components_context: 'HansLoneComponentContext'

    model_name: str
    pef: per_example.PerExampleFlatFishers
    nmf: nmf_common.SparseNmfDecomposition

    def __post_init__(self):
        # Shared settings come from the parent components context.
        self.tokenizer = self.components_context.tokenizer
        self.from_pt = self.components_context.from_pt

        self.model = self.load_model()
        self.container = self._make_container()

        # The last axis of W indexes NMF components (see the W[:, i] usage below).
        self.n_components = self.nmf.W.shape[-1]
        self.variables = self.model.trainable_variables

    ############################################################

    def sort_example_indices_for_component(self, component_index: int) -> np.ndarray:
        """Returns example indices sorted by their coefficient for one component.

        Args:
            component_index: column of `self.nmf.W` to sort by.

        Returns:
            Example indices in DESCENDING order of `W[:, component_index]`.
        """
        # Negate so that argsort's ascending order becomes descending.
        return np.argsort(-self.nmf.W[:, component_index])

    ############################################################

    def compute_tuning_info(self, selection_params: ncf.SelectionParameters):
        """Delegates to `hla.compute_hans_tuning_info` using this model's container
        and the shared indicators of the components context."""
        return hla.compute_hans_tuning_info(
            self.container, self.components_context.indicators, selection_params)

    def make_sparse_fisher_vector_for_components(self, component_indices: Sequence[int]) -> SparseTensor:
        """Combines a subset of NMF components into one flat sparse Fisher vector.

        Each selected row of the sparse H factor is scaled by the sum over examples
        of its denormalized W coefficient. The scaled rows are then added together.

        Args:
            component_indices: indices of the NMF components to combine. They are
                sorted internally, so the order does not matter.

        Returns:
            A `tf.sparse.SparseTensor` holding the weighted sum of the selected H rows.

        Raises:
            ValueError: if `component_indices` is empty.
        """
        component_indices = np.array(sorted(component_indices), dtype=np.int32)
        # Fail early with a clear message instead of an IndexError on `terms[0]`.
        if component_indices.size == 0:
            raise ValueError('At least one component index is required.')

        # The container is built with exactly one NMF (see `_make_container`).
        nmf, = self.container.nmfs
        full_spH = nmf.get_full_sparse_H()
        spH = [full_spH[i] for i in component_indices]

        # W.shape = [n_examples, len(component_indices)]
        W = hla._get_denormalized_coefficients(self.container, 0, component_indices)
        coeffs = W.sum(axis=0)

        terms = [tf.cast(c, tf.float32) * h for c, h in zip(coeffs, spH)]
        ret = terms[0]
        for term in terms[1:]:
            ret = tf.sparse.add(ret, term)
        return ret

    def make_fisher_for_components(self, component_indices: Sequence[int]):
        """Dense version of `make_sparse_fisher_vector_for_components`.

        The flat vector is densified and decoded with a `FlatPacker` built from the
        shapes of `self.variables`. Presumably the result has one tensor per
        trainable variable; TODO confirm against `flat_pack.FlatPacker.decode_tf`.
        """
        spF = self.make_sparse_fisher_vector_for_components(component_indices)
        packer = flat_pack.FlatPacker([v.shape for v in self.variables])
        return packer.decode_tf(tf.sparse.to_dense(spF))

    ############################################################

    def load_model(self):
        """Loads `self.model_name` as a TF sequence classifier and compiles it.

        The model is compiled with sparse categorical cross-entropy on logits and a
        sparse categorical accuracy metric.
        """
        model = TFAutoModelForSequenceClassification.from_pretrained(
            self.model_name, from_pt=self.from_pt)
        model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
        )
        return model

    ############################################################

    def _make_container(self):
        """Wraps the PEF and NMF in an analysis container with HANS fix-ups applied."""
        container = am.PefNmfAnalysisContainer(
            pef=self.pef,
            nmfs=[self.nmf],
            tokenizer=self.tokenizer,
            shift_labels=False,
        )
        hans_util.fix_up_hans_container(container)
        return container

###############################################################################


@dataclasses.dataclass
class HansLoneComponentContext:
    """Shared state for analysing several models on the same HANS example set.

    The examples and their indicators are loaded once. One `HansLoneModelContext`
    is created and cached per key via `make_model_context`. The `*_pattern` fields
    are `str.format` templates, filled in with the keyword arguments passed to
    `make_model_context`.
    """

    model_name_pattern: str

    pef_filepath_pattern: str
    nmf_filepath_pattern: str

    # Selects the example set / dataset variant that includes flipped examples.
    with_flipped: bool

    tokenizer: PreTrainedTokenizer

    split: str = "validation"

    # Forwarded to `from_pretrained` when model contexts load their models.
    from_pt: bool = True

    def __post_init__(self):
        self.examples = self._load_examples()
        self.indicators = self._compute_indicators()

        # Maps caller-chosen keys to `HansLoneModelContext` instances.
        self.model_contexts = {}

    ############################################################

    def make_model_context(self, /, key: str, **pattern_kwargs):
        """Loads the Fishers, NMF and model for one setting and caches the context.

        Args:
            key: name under which the new context is stored in `self.model_contexts`.
            **pattern_kwargs: values substituted into `model_name_pattern`,
                `pef_filepath_pattern` and `nmf_filepath_pattern`.

        Returns:
            The newly created `HansLoneModelContext`.

        Raises:
            ValueError: if a context with the same `key` already exists.
        """
        if key in self.model_contexts:
            raise ValueError(f'A model context with key {key!r} already exists.')

        pef = per_example.PerExampleFlatFishers.load(
            self.pef_filepath_pattern.format(**pattern_kwargs),
            n_examples=None,
            # This leads to the Fishers not being loaded, which ends up being much faster.
            start_fisher_index=0,
            end_fisher_index=0,
        )

        nmf = nmf_common.SparseNmfDecomposition.load(self.nmf_filepath_pattern.format(**pattern_kwargs))
        nmf.normalize_components_to_unit_norm()

        self.model_contexts[key] = HansLoneModelContext(
            components_context=self,
            pef=pef,
            nmf=nmf,
            model_name=self.model_name_pattern.format(**pattern_kwargs),
        )
        return self.model_contexts[key]

    def get_evaluation_context(
        self,
        sequence_length: int = 64,
        batch_size: int = 128,
        og_logits: Optional[np.ndarray] = None,
    ) -> 'HansLoneEvaluationContext':
        """Builds an evaluation context that shares this context's settings."""
        return HansLoneEvaluationContext(
            components_context=self,
            sequence_length=sequence_length,
            batch_size=batch_size,
            og_logits=og_logits,
        )

    ############################################################

    def _load_examples(self):
        """Loads the HANS examples for `self.split`.

        With `with_flipped`, uses the "lone with flipped" example set. Otherwise it
        keeps HANS examples whose heuristic is 'lexical_overlap' and whose label is 1.
        """
        if self.with_flipped:
            return hans_util.get_hans_lone_with_flipped_examples(self.split)
        else:
            # NOTE(review): 5000 looks like a cap on the number of examples, and
            # label 1 presumably means non-entailment (matching the 'ne'
            # indicator). Confirm both against hans_util.
            return hans_util.get_first_hans_examples(
                self.split,
                5000,
                lambda ds: ds.filter(lambda x: x['heuristic'] == 'lexical_overlap' and x['label'] == 1)
            )

    def _compute_indicators(self):
        """Computes the 'ne' indicators for the loaded examples (see hans_labeling)."""
        return hans_labeling.compute_full_indicator(self.examples, 'ne')

###############################################################################


class EvaluationResultsMixIn:
    """Accuracy, loss and KL helpers over a fixed set of evaluation outputs.

    Requires the attributes:
        - labels: integer class labels, one per example.
        - logits: model logits, one row per example.
        - og_logits: reference logits laid out like `logits`, or None. When it
          is None, the KL methods raise ValueError.
    Attributes set up by `set_up_derived_attributes`:
        - predictions: argmax of `logits` over the last axis.
        - correct_predictions: boolean mask `predictions == labels`.
    """

    def set_up_derived_attributes(self):
        """Computes `predictions` and `correct_predictions` from logits and labels."""
        self.predictions = np.argmax(self.logits, axis=-1)
        self.correct_predictions = self.predictions == self.labels

    @staticmethod
    def _as_index_array(example_indices: Sequence[int]) -> np.ndarray:
        """Converts example indices to a sorted int32 array for fancy indexing."""
        return np.array(sorted(example_indices), dtype=np.int32)

    def _require_og_logits(self) -> np.ndarray:
        """Returns `self.og_logits`, raising ValueError if it was not provided."""
        if self.og_logits is None:
            raise ValueError('og_logits must be provided to compute KL divergence.')
        return self.og_logits

    ############################################################

    def acc(self) -> float:
        """Fraction of examples predicted correctly."""
        return self.correct_predictions.astype(np.float64).mean()

    def acc_for_examples(self, example_indices: Sequence[int]) -> float:
        """Fraction of the given examples predicted correctly."""
        example_indices = self._as_index_array(example_indices)
        return self.correct_predictions[example_indices].astype(np.float64).mean()

    ############################################################

    def _loss(self, labels, logits) -> float:
        """Mean sparse categorical cross-entropy computed from logits."""
        return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True).numpy().mean()

    def loss(self) -> float:
        """Mean cross-entropy over all examples."""
        return self._loss(self.labels, self.logits)

    def loss_for_examples(self, example_indices: Sequence[int]) -> float:
        """Mean cross-entropy over the given examples."""
        example_indices = self._as_index_array(example_indices)
        return self._loss(self.labels[example_indices], self.logits[example_indices])

    ############################################################

    def _kl(self, og_logits, logits) -> float:
        """Mean per-example KL divergence between softmax(logits) and softmax(og_logits).

        `tf.keras.losses.kl_divergence(y_true, y_pred)` computes
        sum(y_true * log(y_true / y_pred)). It is called with y_true=softmax(logits)
        and y_pred=softmax(og_logits), i.e. KL(current || original).

        Raises:
            ValueError: if `og_logits` is None.
        """
        # An explicit check (not an assert) so it still runs under `python -O`.
        if og_logits is None:
            raise ValueError('og_logits must be provided to compute KL divergence.')
        # NOTE(review): the direction is KL(current || original); confirm it is intended.
        return tf.keras.losses.kl_divergence(tf.math.softmax(logits), tf.math.softmax(og_logits)).numpy().mean()

    def kl(self) -> float:
        """Mean KL divergence over all examples (see `_kl`)."""
        return self._kl(self._require_og_logits(), self.logits)

    def kl_for_examples(self, example_indices: Sequence[int]) -> float:
        """Mean KL divergence over the given examples (see `_kl`)."""
        # Validate before indexing; otherwise a None og_logits raises a TypeError.
        og_logits = self._require_og_logits()
        example_indices = self._as_index_array(example_indices)
        return self._kl(og_logits[example_indices], self.logits[example_indices])


@dataclasses.dataclass
class HansLoneEvaluationResults(EvaluationResultsMixIn):
    """Evaluation outputs of one model on the HANS lone examples.

    `labels` and `logits` are the per-example arrays gathered by
    `HansLoneEvaluationContext.evaluate`. When `og_logits` is given, it holds the
    reference logits used by the KL helpers of the mix-in.
    """

    components_context: HansLoneComponentContext

    labels: np.ndarray
    logits: np.ndarray

    og_logits: Optional[np.ndarray] = None

    def __post_init__(self):
        # Reuse the indicators computed once by the components context.
        self.indicators = self.components_context.indicators
        # Populate `predictions` and `correct_predictions`.
        self.set_up_derived_attributes()


@dataclasses.dataclass
class HansLoneEvaluationContext:
    """Runs models over the HANS lexical-overlap evaluation dataset.

    Tokenizer, split and the with-flipped choice are inherited from
    `components_context`. `og_logits`, if given, is attached to every result as
    the KL reference.
    """

    components_context: HansLoneComponentContext

    sequence_length: int = 64
    batch_size: int = 128

    og_logits: Optional[np.ndarray] = None

    def __post_init__(self):
        parent = self.components_context
        self.with_flipped = parent.with_flipped
        self.tokenizer = parent.tokenizer
        self.split = parent.split

        # The flipped variant lives under a suffixed dataset name.
        suffix = '_with_flipped' if self.with_flipped else ''
        self.ds = em_datasets.load(
            'hans/lexical_overlap_ne' + suffix,
            split=self.split,
            sequence_length=self.sequence_length,
            tokenizer=self.tokenizer,
        )

    ############################################################

    def evaluate(self, model) -> HansLoneEvaluationResults:
        """Runs `model` (in inference mode) over the dataset in batches.

        Each batch's logits are passed through `hans_util.fix_up_hans_logits`.
        Labels and logits from all batches are concatenated.
        """
        label_batches = []
        logit_batches = []
        for inputs, batch_labels in self.ds.batch(self.batch_size):
            label_batches.append(batch_labels.numpy())
            raw_logits = model(inputs, training=False).logits.numpy()
            logit_batches.append(hans_util.fix_up_hans_logits(raw_logits))

        all_labels = np.concatenate(label_batches, axis=0)
        all_logits = np.concatenate(logit_batches, axis=0)
        return HansLoneEvaluationResults(
            components_context=self.components_context,
            labels=all_labels,
            logits=all_logits,
            og_logits=self.og_logits,
        )
