"""Common code for merging, ablation, and analysis of HANS lone nmfs."""
import dataclasses
import os
from typing import Any, Dict, Optional, Sequence, Tuple

import numpy as np
import tensorflow as tf
from transformers import PreTrainedTokenizer, TFAutoModelForSequenceClassification, TFPreTrainedModel

from em import datasets as em_datasets
from em.datasets import common_processing
from em.fishers import per_example
from em.models import em_models
from em.tools.nmf import nmf_common
from em.util import flat_pack

from em.projects.anli import anli_misc1 as am
from em.projects.ll import hans_labeling_analysis as hla
from em.projects.ll import hans_util
from em.projects.wino import nmf_components_fisher as ncf


# Type aliases.
SparseTensor = tf.sparse.SparseTensor

###############################################################################

# Allowed values for the `special_processing` option (None means no special processing):
#   - 'HANS': the analysis container and evaluated logits are passed through the
#     `hans_util` fix-up helpers.
#   - 'HF_MNLI': outputs are treated as 3-class and the analysis container is built
#     with `shift_labels=True`.
SPECIAL_PROCESSING_TYPES = ('HANS', 'HF_MNLI')

###############################################################################


@dataclasses.dataclass
class QqpComponentContext:
    """Shared configuration for loading per-model NMF component analyses.

    The `*_pattern` fields are `str.format` templates; `make_model_context` fills them in
    with its keyword arguments to locate a model, its per-example Fishers and its NMF.

    Attributes:
        model_name_pattern: Template for the model name/path that `QqpModelContext`
            passes to `TFAutoModelForSequenceClassification.from_pretrained`.
        pef_filepath_pattern: Template for the per-example flat Fishers file.
        nmf_filepath_pattern: Template for the sparse NMF decomposition file.
        tokenizer: Tokenizer shared by every model context created from this object.
        from_pt: Forwarded as `from_pt` when loading models.
        special_processing: None or one of SPECIAL_PROCESSING_TYPES.
    """

    model_name_pattern: str

    pef_filepath_pattern: str
    nmf_filepath_pattern: str

    tokenizer: PreTrainedTokenizer

    from_pt: bool = True

    special_processing: Optional[str] = None

    def __post_init__(self):
        assert self.special_processing is None or self.special_processing in SPECIAL_PROCESSING_TYPES, (
            f'Unknown special_processing {self.special_processing!r}; '
            f'expected None or one of {SPECIAL_PROCESSING_TYPES}.')

        # Model contexts created through `make_model_context`, keyed by the caller-supplied key.
        self.model_contexts: Dict[str, 'QqpModelContext'] = {}

    ############################################################

    @property
    def n_classes(self):
        """Number of output classes: 3 for 'HF_MNLI', otherwise 2."""
        if self.special_processing == 'HF_MNLI':
            return 3
        else:
            return 2

    def make_model_context(self, /, key: str, **pattern_kwargs) -> 'QqpModelContext':
        """Load the data for one model and register a `QqpModelContext` under `key`.

        Args:
            key: Name under which the new context is stored in `self.model_contexts`.
            **pattern_kwargs: Values substituted into the model-name and filepath patterns.

        Returns:
            The newly created `QqpModelContext`.

        Raises:
            ValueError: If a context has already been registered under `key`.
        """
        if key in self.model_contexts:
            raise ValueError(f'A model context already exists for key {key!r}.')

        pef = per_example.PerExampleFlatFishers.load(
            self.pef_filepath_pattern.format(**pattern_kwargs),
            n_examples=None,
            # This leads to the Fishers not being loaded, which ends up being much faster.
            start_fisher_index=0,
            end_fisher_index=0,
        )

        if self.special_processing is None and pef.predicted_logits.shape[-1] == 2:
            # Append a third logit column of -1e9, i.e. effectively zero probability after a
            # softmax. NOTE(review): presumably so code expecting 3-way logits can consume
            # them; `QqpModelContext._make_container` truncates back to `n_classes` columns.
            pef.predicted_logits = np.concatenate([
                pef.predicted_logits,
                -1e9 * np.ones_like(pef.predicted_logits[:, :1])
            ], axis=-1)

        nmf = nmf_common.SparseNmfDecomposition.load(self.nmf_filepath_pattern.format(**pattern_kwargs))
        nmf.normalize_components_to_unit_norm()

        # TODO: Allow us to specify other subsets of example indices.
        # Keep only the leading examples that have a row in the NMF's W matrix.
        # NOTE(review): the opposite mismatch (fewer pef examples than W rows) is not handled.
        n_nmf_examples = nmf.W.shape[0]
        if pef.input_ids.shape[0] > n_nmf_examples:
            pef = pef.create_for_subset(list(range(n_nmf_examples)))

        model_context = QqpModelContext(
            components_context=self,
            pef=pef,
            nmf=nmf,
            model_name=self.model_name_pattern.format(**pattern_kwargs),
        )
        self.model_contexts[key] = model_context
        return model_context

    
###############################################################################


@dataclasses.dataclass
class QqpModelContext:
    """A loaded model together with its per-example Fishers and NMF decomposition.

    Construction loads the model (see `load_model`) and builds an analysis container
    (see `_make_container`). The tokenizer, `from_pt` and `special_processing` settings
    are copied from `components_context`.
    """

    components_context: 'QqpComponentContext'

    model_name: str
    pef: per_example.PerExampleFlatFishers
    nmf: nmf_common.SparseNmfDecomposition

    def __post_init__(self):
        self.special_processing = self.components_context.special_processing

        self.tokenizer = self.components_context.tokenizer
        self.from_pt = self.components_context.from_pt

        self.model = self.load_model()
        self.container = self._make_container()

        # W has one column per NMF component.
        self.n_components = self.nmf.W.shape[-1]
        self.variables = self.model.trainable_variables

    @property
    def n_classes(self):
        """Number of output classes, as configured on the components context."""
        return self.components_context.n_classes

    ############################################################

    def sort_example_indices_for_component(self, component_index: int):
        """Return example indices sorted by their W coefficient for one component."""
        # The sort is in DESCENDING order.
        return np.argsort(-self.nmf.W[:, component_index])

    ############################################################

    def make_sparse_fisher_vector_for_components(self, component_indices: Sequence[int]) -> SparseTensor:
        """Combine the sparse H rows of several NMF components into one sparse vector.

        Each selected component's sparse H row is scaled by the sum over examples of its
        denormalized W coefficients, and the scaled rows are summed.

        Args:
            component_indices: Indices of the components to combine. They are processed
                in sorted order; duplicates are kept.

        Returns:
            The weighted sum as a `SparseTensor`.

        Raises:
            ValueError: If `component_indices` is empty.
        """
        component_indices = np.array(sorted(component_indices), dtype=np.int32)
        if component_indices.size == 0:
            raise ValueError('component_indices must contain at least one component index.')

        nmf, = self.container.nmfs
        full_spH = nmf.get_full_sparse_H()
        spH = [full_spH[i] for i in component_indices]

        # W.shape = [n_examples, len(component_indices)]
        # NOTE(review): the `0` presumably selects the container's only NMF — confirm.
        W = hla._get_denormalized_coefficients(self.container, 0, component_indices)
        coeffs = W.sum(axis=0)

        # Scale each sparse row by rescaling its values directly. This keeps the result a
        # SparseTensor without relying on dense-times-sparse operator overloading.
        terms = [
            tf.sparse.SparseTensor(
                indices=h.indices,
                values=h.values * tf.cast(c, h.values.dtype),
                dense_shape=h.dense_shape,
            )
            for c, h in zip(coeffs, spH)
        ]
        ret = terms[0]
        for term in terms[1:]:
            ret = tf.sparse.add(ret, term)
        return ret

    def make_fisher_for_components(self, component_indices: Sequence[int]):
        """Densify the combined component vector and unpack it with a `FlatPacker`.

        The packer is built from the shapes of `self.variables` (the model's trainable
        variables); presumably the result is one tensor per variable — confirm against
        `flat_pack.FlatPacker.decode_tf`.
        """
        spF = self.make_sparse_fisher_vector_for_components(component_indices)
        packer = flat_pack.FlatPacker([v.shape for v in self.variables])
        return packer.decode_tf(tf.sparse.to_dense(spF))

    ############################################################

    def load_model(self):
        """Load `self.model_name` as a TF sequence classifier and compile it.

        The model is compiled with sparse categorical cross-entropy on logits and a
        sparse categorical accuracy metric.
        """
        model = TFAutoModelForSequenceClassification.from_pretrained(
            self.model_name, from_pt=self.from_pt)
        model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
        )
        return model

    ############################################################

    def _make_container(self):
        """Build the analysis container for this model's pef and NMF.

        HANS containers are passed through `hans_util.fix_up_hans_container`, and the
        container's predicted logits are truncated to `n_classes` columns. Side effect:
        `self.pef.labels` is overwritten with the container's labels.
        """
        container = am.PefNmfAnalysisContainer(
            pef=self.pef,
            nmfs=[self.nmf],
            tokenizer=self.tokenizer,
            shift_labels=self.special_processing == 'HF_MNLI',
        )
        if self.special_processing == 'HANS':
            hans_util.fix_up_hans_container(container)
        container.predicted_logits = container.predicted_logits[:, :self.n_classes]
        self.pef.labels = container.labels
        return container

    ############################################################

    def get_evaluation_context(
        self,
        batch_size: int = 128,
    ) -> 'QqpEvaluationContext':
        """Create a `QqpEvaluationContext` over this model context's examples."""
        return QqpEvaluationContext(
            mc=self,
            batch_size=batch_size,
        )


###############################################################################


class EvaluationResultsMixIn:
    """Accuracy, loss and KL metrics over a set of evaluated examples.

    Requires the attributes:
        - labels: class label per example.
        - logits: model logits, one row per example (classes on the last axis).
        - og_logits: original logits to compare against, same layout as `logits`.
          May be None; the KL and altered-prediction methods then fail with an
          AssertionError.
    Attributes that can be set-up (via `set_up_derived_attributes`):
        - predictions: argmax of `logits` over the last axis.
        - correct_predictions: boolean array, `predictions == labels`.

    The `*_for_examples` methods sort the given indices first (duplicates are kept);
    they only compute means, so the order does not affect the result.
    """

    def set_up_derived_attributes(self):
        """Compute `predictions` and `correct_predictions` from `logits` and `labels`."""
        self.predictions = np.argmax(self.logits, axis=-1)
        self.correct_predictions = self.predictions == self.labels

    ############################################################

    @staticmethod
    def _sorted_indices(example_indices: Sequence[int]) -> np.ndarray:
        """Return `example_indices` as a sorted int32 array (duplicates are kept)."""
        return np.array(sorted(example_indices), dtype=np.int32)

    def _require_og_logits(self) -> np.ndarray:
        """Return `og_logits`, asserting that they were provided.

        Without this check `np.argmax(None)` returns 0 instead of failing, which would
        silently produce meaningless "altered prediction" indices.
        """
        assert self.og_logits is not None, 'og_logits are required for this metric.'
        return self.og_logits

    ############################################################

    def acc(self) -> float:
        """Fraction of examples whose prediction matches the label."""
        return self.correct_predictions.astype(np.float64).mean()

    def acc_for_examples(self, example_indices: Sequence[int]) -> float:
        """Accuracy restricted to `example_indices`."""
        idx = self._sorted_indices(example_indices)
        return self.correct_predictions[idx].astype(np.float64).mean()

    ############################################################

    def _loss(self, labels, logits) -> float:
        """Mean sparse categorical cross-entropy of `logits` against `labels`."""
        return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True).numpy().mean()

    def loss(self) -> float:
        """Mean cross-entropy loss over all examples."""
        return self._loss(self.labels, self.logits)

    def loss_for_examples(self, example_indices: Sequence[int]) -> float:
        """Mean cross-entropy loss restricted to `example_indices`."""
        idx = self._sorted_indices(example_indices)
        return self._loss(self.labels[idx], self.logits[idx])

    ############################################################

    def _kl(self, og_logits, logits) -> float:
        """Mean over examples of KL(softmax(logits) || softmax(og_logits)).

        `tf.keras.losses.kl_divergence(y_true, y_pred)` is called with the new
        distribution as `y_true` and the original one as `y_pred`.
        """
        assert og_logits is not None
        return tf.keras.losses.kl_divergence(tf.math.softmax(logits), tf.math.softmax(og_logits)).numpy().mean()

    def kl(self) -> float:
        """Mean KL divergence between the new and original predictive distributions."""
        return self._kl(self.og_logits, self.logits)

    def kl_for_examples(self, example_indices: Sequence[int]) -> float:
        """Mean KL divergence restricted to `example_indices`."""
        idx = self._sorted_indices(example_indices)
        # Check before indexing so a missing og_logits fails with the intended assertion.
        og_logits = self._require_og_logits()
        return self._kl(og_logits[idx], self.logits[idx])

    ############################################################

    def indices_of_altered_predictions(self) -> Sequence[int]:
        """Indices of examples whose argmax prediction differs from the original one."""
        # TODO: Maybe also add option to order by per-example KL instead of changed prediction
        og_predictions = np.argmax(self._require_og_logits(), axis=-1)
        return np.where(self.predictions != og_predictions)[0]

    def indices_ordered_by_kl(self) -> Sequence[int]:
        """Example indices ordered by per-example KL divergence, largest first."""
        # Indices are ordered in DESCENDING order of KL-divergence (argsort of the negated values).
        og_logits = self._require_og_logits()
        kls = tf.keras.losses.kl_divergence(tf.math.softmax(self.logits), tf.math.softmax(og_logits)).numpy()
        return np.argsort(-kls)


@dataclasses.dataclass
class QqpEvaluationResults(EvaluationResultsMixIn):
    """Labels and logits from one evaluation pass, plus the mix-in's metrics.

    `og_logits` are the original logits that the KL and altered-prediction metrics
    compare against; they default to None when no reference is available.
    """

    labels: np.ndarray
    logits: np.ndarray
    og_logits: Optional[np.ndarray] = dataclasses.field(default=None)

    def __post_init__(self):
        # Derive `predictions` and `correct_predictions` as soon as the fields are set.
        self.set_up_derived_attributes()


@dataclasses.dataclass
class QqpEvaluationContext:
    """Evaluates models on the examples stored in a `QqpModelContext`.

    The original logits (`og_logits`) are taken from the model context's analysis
    container; the examples come from its pef, tokenized with its tokenizer.
    """

    mc: QqpModelContext
    batch_size: int = 128

    def __post_init__(self):
        model_context = self.mc
        self.special_processing = model_context.special_processing
        self.tokenizer = model_context.tokenizer
        self.og_logits = model_context.container.predicted_logits
        self.all_examples = model_context.pef.get_full_examples(self.tokenizer, trim=True)

    ############################################################

    def get_ds(self, example_indices: Optional[np.ndarray] = None) -> tf.data.Dataset:
        """Return an unbatched dataset of all examples, or of `example_indices` in that order."""
        if example_indices is not None:
            selected = slice_examples_dict(self.all_examples, example_indices)
        else:
            selected = self.all_examples
        return tf.data.Dataset.from_tensor_slices(selected)

    def evaluate(self, model, example_indices: Optional[np.ndarray] = None) -> QqpEvaluationResults:
        """Run `model` on the (optionally subset) examples and collect the results.

        Logits are fixed up per batch for 'HANS' and truncated to the number of columns
        of `og_logits`. The returned `og_logits` are subset with the same indices.
        """
        label_batches = []
        logit_batches = []
        for batch_inputs, batch_labels in self.get_ds(example_indices).batch(self.batch_size):
            label_batches.append(batch_labels.numpy())
            batch_logits = em_models.compute_logits(model, batch_inputs, training=False).numpy()
            if self.special_processing == 'HANS':
                batch_logits = hans_util.fix_up_hans_logits(batch_logits)
            logit_batches.append(batch_logits)

        n_og_classes = self.og_logits.shape[-1]
        if example_indices is None:
            reference_logits = self.og_logits
        else:
            reference_logits = self.og_logits[example_indices]

        return QqpEvaluationResults(
            labels=np.concatenate(label_batches, axis=0),
            logits=np.concatenate(logit_batches, axis=0)[:, :n_og_classes],
            og_logits=reference_logits,
        )


@dataclasses.dataclass
class EvaluationContext2:
    """Evaluates models on an in-memory set of examples against original logits.

    Unlike `QqpEvaluationContext`, no model context is needed: the examples and the
    original logits (`og_logits`) are supplied directly or via the `create_from_*`
    constructors.

    Attributes:
        all_examples: `(inputs, labels)` pair of per-example numpy data.
        og_logits: Original logits, one row per example in `all_examples`.
        batch_size: Batch size used when running models.
        special_processing: None or one of SPECIAL_PROCESSING_TYPES.
    """

    all_examples: Tuple[Dict[str, np.ndarray], np.ndarray]
    og_logits: np.ndarray
    batch_size: int = 128

    special_processing: Optional[str] = None

    def __post_init__(self):
        assert self.special_processing is None or self.special_processing in SPECIAL_PROCESSING_TYPES, (
            f'Unknown special_processing {self.special_processing!r}; '
            f'expected None or one of {SPECIAL_PROCESSING_TYPES}.')

    def get_ds(self, example_indices: Optional[np.ndarray] = None) -> tf.data.Dataset:
        """Return an unbatched dataset of all examples, or of `example_indices` in that order."""
        if example_indices is None:
            return tf.data.Dataset.from_tensor_slices(self.all_examples)
        return tf.data.Dataset.from_tensor_slices(
            slice_examples_dict(self.all_examples, example_indices))

    def evaluate(self, model, example_indices: Optional[np.ndarray] = None) -> QqpEvaluationResults:
        """Run `model` on the (optionally subset) examples and collect the results.

        Logits are fixed up per batch for 'HANS' and truncated to the number of columns
        of `og_logits`. The returned `og_logits` are subset with the same indices.
        """
        labels, logits = [], []
        for x, y in self.get_ds(example_indices).batch(self.batch_size):
            labels.append(y.numpy())
            batch_logits = em_models.compute_logits(model, x, training=False).numpy()
            if self.special_processing == 'HANS':
                batch_logits = hans_util.fix_up_hans_logits(batch_logits)
            logits.append(batch_logits)

        return QqpEvaluationResults(
            labels=np.concatenate(labels, axis=0),
            logits=np.concatenate(logits, axis=0)[:, :self.og_logits.shape[-1]],
            og_logits=self.og_logits[example_indices] if example_indices is not None else self.og_logits,
        )

    @staticmethod
    def _materialize(ds: tf.data.Dataset, n_examples: int):
        """Return the first `n_examples` elements of the unbatched `ds` as one numpy batch.

        If `ds` has fewer elements, the returned batch is correspondingly smaller.

        Raises:
            ValueError: If `n_examples` is not positive or `ds` yields nothing.
        """
        if n_examples <= 0:
            raise ValueError(f'Expected at least one example, got n_examples={n_examples}.')
        for batch in ds.take(n_examples).batch(n_examples).as_numpy_iterator():
            return batch
        raise ValueError('The dataset yielded no examples.')

    @classmethod
    def create_from_ds(cls, ds: tf.data.Dataset, model: TFPreTrainedModel, batch_size: int = 128, special_processing: Optional[str] = None):
        """Materialize `ds` in memory and record `model`'s logits on it as `og_logits`.

        The ds should NOT be batched. It should also be finite (`len(ds)` must work).

        Raises:
            ValueError: If `ds` is empty.
        """
        n_examples = len(ds)
        all_examples = cls._materialize(ds, n_examples)

        # Compute the original logits from the materialized examples rather than iterating
        # `ds` a second time: a second pass is not guaranteed to yield the same order (e.g.
        # a `shuffle` with the default `reshuffle_each_iteration=True`), which would
        # misalign `og_logits` with `all_examples`.
        logits = []
        for x, _ in tf.data.Dataset.from_tensor_slices(all_examples).batch(batch_size):
            batch_logits = em_models.compute_logits(model, x, training=False).numpy()
            if special_processing == 'HANS':
                batch_logits = hans_util.fix_up_hans_logits(batch_logits)
            logits.append(batch_logits)

        return cls(
            batch_size=batch_size,
            all_examples=all_examples,
            og_logits=np.concatenate(logits, axis=0),
            special_processing=special_processing,
        )

    @classmethod
    def create_from_pefs(cls, pef: per_example.PerExampleFlatFishers, tokenizer: PreTrainedTokenizer, **kwargs):
        """Build a context from a pef's examples and its stored predicted logits."""
        return cls(
            all_examples=pef.get_full_examples(tokenizer, trim=True),
            og_logits=pef.predicted_logits,
            **kwargs,
        )

    @classmethod
    def create_from_ds_and_logits(cls, ds: tf.data.Dataset, logits: np.ndarray, **kwargs):
        """Build a context from the first `logits.shape[0]` elements of `ds` and `logits`.

        The ds should NOT be batched. Its leading elements must correspond, in order,
        to the rows of `logits`.

        Raises:
            NotImplementedError: If a non-None `special_processing` is passed.
            ValueError: If `logits` is empty or `ds` yields no examples.
        """
        special_processing = kwargs.get('special_processing', None)
        if special_processing is not None:
            raise NotImplementedError('TODO: Support special_processing, might not require any changes.')

        all_examples = cls._materialize(ds, logits.shape[0])
        return cls(
            all_examples=all_examples,
            og_logits=logits,
            **kwargs,
        )


def slice_examples_dict(examples: Tuple[Dict[str, np.ndarray], np.ndarray], example_indices: np.ndarray):
    """Select a subset of examples, together with their labels, by index.

    Args:
        examples: An `(inputs, labels)` pair. `inputs` is either a single numpy array
            or a dict of arrays, each indexed by example along its first axis.
        example_indices: Indices to select, in the order they should appear.

    Returns:
        An `(inputs_subset, labels_subset)` pair. `inputs_subset` is an array when
        `inputs` was a numpy array and a dict otherwise.
    """
    inputs, labels = examples
    if not isinstance(inputs, np.ndarray):
        inputs_subset = {name: values[example_indices] for name, values in inputs.items()}
    else:
        inputs_subset = inputs[example_indices]
    return inputs_subset, labels[example_indices]

