"""Common code shared by the ablation experiments."""
import dataclasses
import random
import time
from typing import Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorflow as tf
from transformers import TFPreTrainedModel

from em import datasets as em_datasets
from em.fishers import diagonal
from em.merging import merging
from em.util import flat_pack
from em.util import hf_util

from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC


class SoftmaxKlDivergenceLoss(tf.keras.losses.KLDivergence):
    """KL-divergence loss whose two inputs are softmaxed before comparison.

    Both ``y_true`` and ``y_pred`` are passed through a softmax over the last
    axis and the results are handed to the parent ``KLDivergence.call``.
    """

    def call(self, y_true, y_pred):
        """Return the parent KL loss of ``softmax(y_true)`` and ``softmax(y_pred)``."""
        true_probs = tf.math.softmax(y_true, axis=-1)
        pred_probs = tf.math.softmax(y_pred, axis=-1)
        return super().call(true_probs, pred_probs)


@dataclasses.dataclass
class Experiment1:
    """Shared state and helpers for ablation experiments on one model context.

    Bundles a model context (``mc``) with the Fisher of the variables being
    retained, and offers helpers to select examples, compute gradients and
    Fishers over them, and merge the retained variables with a set of
    "ablating" variables into ``output_model``.

    Attributes:
        mc: The model context everything else is derived from.
        retaining_fisher: Fisher tensors paired with ``retaining_variables``
            when merging.
    """
    mc: QCC.QqpModelContext
    retaining_fisher: Sequence[tf.Tensor]

    def __post_init__(self):
        # Cache the pieces of the model context that the helpers use.
        self.special_processing = self.mc.special_processing

        self.tokenizer = self.mc.tokenizer
        self.sign_guider = QMC.SignGuider(mc=self.mc)

        self.pef = self.mc.pef
        self.nmf = self.mc.nmf
        self.W = self.nmf.W

        self.retaining_variables = list(self.mc.variables)

        # A separately loaded model whose variables receive the merge results,
        # so merging does not go through ``mc.model``.
        self.output_model = self.mc.load_model()
        self.output_variables = hf_util.get_all_variables(self.output_model)

        self.predicted_logits = self.mc.container.predicted_logits
        self.predictions = np.argmax(self.predicted_logits, axis=-1)
        self.labels = self.mc.container.labels

        # Boolean masks and index arrays splitting the examples by whether the
        # original prediction matched the label.
        self.corrects_indicator = self.predictions == self.labels
        self.corrects_indices, = np.nonzero(self.corrects_indicator)

        self.incorrects_indicator = self.predictions != self.labels
        self.incorrects_indices, = np.nonzero(self.incorrects_indicator)

        # Lazily built by get_dummy_ablating_fisher().
        self._dummy_ablating_fisher = None

    #################################################################

    def create_eval_ctx(
        self,
        task_ri: str,
        sequence_length: int,
        n_examples: Optional[int] = None,
        split: str = 'validation',
    ):
        """Build an evaluation context for ``mc.model`` on another dataset.

        Args:
            task_ri: Dataset identifier passed to ``em_datasets.load``.
            sequence_length: Sequence length passed to ``em_datasets.load``.
            n_examples: If given, only the first ``n_examples`` elements of
                the dataset are used.
            split: Dataset split to load.

        Returns:
            The result of ``QCC.EvaluationContext2.create_from_ds`` on the
            cached dataset.
        """
        ds = em_datasets.load(
            task_ri,
            split=split,
            sequence_length=sequence_length,
            tokenizer=self.mc.tokenizer,
        )
        if n_examples is not None:
            ds = ds.take(n_examples)

        if self.special_processing == 'HF_MNLI':
            ds = em_datasets.glue.fix_text_attack_mnli_labeling(ds)

        return QCC.EvaluationContext2.create_from_ds(
            ds=ds.cache(),
            model=self.mc.model,
            special_processing=self.special_processing,
        )

    #################################################################

    def get_top_example_indices(self, component_index: int) -> np.ndarray:
        """Return the example indices as sorted by the model context for a component."""
        return self.mc.sort_example_indices_for_component(component_index)

    def apply_sign_guide(self, sign_guide, delta: float):
        """Apply ``sign_guide`` with step ``delta`` to the retaining variables."""
        return self.sign_guider.apply_sign_guide(self.retaining_variables, sign_guide, delta)

    def apply_gradient(self, gradient, delta: float):
        """Apply ``gradient`` with step ``delta`` to the retaining variables."""
        return self.sign_guider.apply_gradient(self.retaining_variables, gradient, delta)

    #################################################################

    def get_dummy_ablating_fisher(self):
        """Return (and cache) an all-ones Fisher shaped like ``retaining_fisher``."""
        if self._dummy_ablating_fisher is None:
            self._dummy_ablating_fisher = [tf.ones_like(f) for f in self.retaining_fisher]
        return self._dummy_ablating_fisher

    def remove_correct_example_indices(self, example_indices: Sequence[int]) -> np.ndarray:
        """Keep only the indices of incorrectly predicted examples, in order."""
        incorrects_indices = set(self.incorrects_indices)
        return np.array([i for i in example_indices if i in incorrects_indices], dtype=np.int32)

    def get_top_incorrect_example_indices(self, component_index: int) -> np.ndarray:
        """Return the component's sorted example indices minus the correct ones."""
        inds = self.get_top_example_indices(component_index)
        return self.remove_correct_example_indices(inds)

    #################################################################

    def get_components_order_by_coeff_magnitude(self, n_examples: int):
        """Return component indices in descending order of top-coefficient sum.

        Sorting ``-W`` ascending along axis 0 puts each column's largest
        entries first; the first ``n_examples`` rows are summed per column and
        the (negated) sums are argsorted ascending, i.e. largest sum first.
        """
        return np.argsort(np.sort(-self.W, axis=0)[:n_examples].sum(axis=0))

    #################################################################

    def random_example_indices(self, n: int):
        """Return ``n`` distinct example indices drawn uniformly at random."""
        return np.random.permutation(np.arange(self.labels.shape[0]))[:n]

    def random_example_indices_by_correctness(self, n_correct: int, n_incorrect: int):
        """Return a shuffled mix of random correct and incorrect example indices.

        Up to ``n_correct`` correctly predicted and up to ``n_incorrect``
        incorrectly predicted examples are drawn; fewer are returned if not
        enough examples of a kind exist.
        """
        corrects = np.random.permutation(np.nonzero(self.corrects_indicator)[0])[:n_correct]
        incorrects = np.random.permutation(np.nonzero(self.incorrects_indicator)[0])[:n_incorrect]
        return np.random.permutation(np.concatenate([corrects, incorrects], axis=0))

    #################################################################

    @staticmethod
    def _make_coeff_ds(
        coefficients: Optional[Sequence[float]],
        example_indices: np.ndarray,
        batch_size: int,
    ):
        """Return a batched dataset of per-example coefficients, or None.

        ``coefficients`` must have the same shape as ``example_indices`` so
        that its batches line up with the batches of examples.
        """
        if coefficients is None:
            return None
        coefficients = np.array(coefficients, dtype=np.float32)
        assert coefficients.shape == example_indices.shape
        return tf.data.Dataset.from_tensor_slices(coefficients).batch(batch_size)

    def compute_loss_gradient(
        self,
        example_indices: Sequence[int],
        coefficients: Optional[Sequence[float]] = None,
        normalize_by_example: bool = False,
        batch_size: int = 32,
    ):
        """Compute the loss gradient over the given examples.

        Args:
            example_indices: Indices of the examples to use.
            coefficients: Optional per-example coefficients, one per index.
            normalize_by_example: Forwarded to the sign guider.
            batch_size: Batch size for both the examples and coefficients.

        Returns:
            Whatever ``sign_guider.compute_loss_gradient`` returns.
        """
        example_indices = np.array(example_indices, dtype=np.int32)
        coeff_ds = self._make_coeff_ds(coefficients, example_indices, batch_size)

        ds = self.sign_guider.get_ds_for_examples(example_indices).batch(batch_size)

        return self.sign_guider.compute_loss_gradient(
            ds,
            coeff_ds,
            normalize_by_example=normalize_by_example,
        )

    def compute_kl_gradient(
        self,
        example_indices: Sequence[int],
        coefficients: Optional[Sequence[float]] = None,
        batch_size: int = 32,
        *,
        allow_recompile: bool = False
    ):
        """Compute the gradient of the KL-divergence with the original predictions.

        The dataset's labels are replaced by the originally predicted logits
        and ``mc.model`` must be compiled with ``SoftmaxKlDivergenceLoss``.

        Args:
            example_indices: Indices of the examples to use.
            coefficients: Optional per-example coefficients, one per index.
            batch_size: Batch size for both the examples and coefficients.
            allow_recompile: If True, ``mc.model`` is recompiled with the
                KL loss when it does not already use it.

        Returns:
            Whatever ``sign_guider.compute_loss_gradient`` returns.

        Raises:
            ValueError: If the model is not compiled with
                ``SoftmaxKlDivergenceLoss`` and ``allow_recompile`` is False.
        """
        if not isinstance(self.mc.model.loss, SoftmaxKlDivergenceLoss):
            if not allow_recompile:
                raise ValueError(
                    'mc.model is not compiled with SoftmaxKlDivergenceLoss; '
                    'pass allow_recompile=True to recompile it.'
                )
            # NOTE(review): this mutates mc.model and is not undone afterwards,
            # so later users of mc.model's loss will presumably see the KL
            # loss -- confirm that is acceptable for callers.
            self.mc.model.compile(
                loss=SoftmaxKlDivergenceLoss(),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
            )

        example_indices = np.array(example_indices, dtype=np.int32)
        coeff_ds = self._make_coeff_ds(coefficients, example_indices, batch_size)

        ds = self.sign_guider.get_ds_for_examples(example_indices)

        # Pair every example with its original logits and use those logits in
        # place of the second element of the example tuple.
        # (batch[0] is presumably the model inputs -- TODO confirm.)
        og_logits = self.predicted_logits[example_indices]
        ds = tf.data.Dataset.zip((ds, tf.data.Dataset.from_tensor_slices(og_logits)))
        ds = ds.map(lambda batch, logits: (batch[0], logits))
        ds = ds.batch(batch_size)

        return self.sign_guider.compute_loss_gradient(
            ds,
            coeff_ds,
        )

    def compute_fisher(
        self,
        example_indices: Sequence[int],
        coefficients: Optional[Sequence[float]] = None,
        normalize_by_example: bool = False,
        batch_size: int = 4,
    ):
        """Compute the Fisher over the given examples.

        Args:
            example_indices: Indices of the examples to use.
            coefficients: Optional per-example coefficients, one per index.
            normalize_by_example: Forwarded to the sign guider.
            batch_size: Batch size for both the examples and coefficients.

        Returns:
            Whatever ``sign_guider.compute_fisher`` returns.
        """
        example_indices = np.array(example_indices, dtype=np.int32)
        coeff_ds = self._make_coeff_ds(coefficients, example_indices, batch_size)

        ds = self.sign_guider.get_ds_for_examples(example_indices).batch(batch_size)

        return self.sign_guider.compute_fisher(
            ds,
            coeff_ds,
            normalize_by_example=normalize_by_example,
        )

    def l2_normalize(self, gradients: Sequence[tf.Tensor]) -> Sequence[tf.Tensor]:
        """Divide every tensor by the joint norm from ``merging._l2_norm_of_fisher``."""
        norm = merging._l2_norm_of_fisher(gradients)
        return [g / norm for g in gradients]

    #################################################################

    def stream_merge(
        self,
        ablating_variables: Sequence[tf.Tensor],
        ablating_fisher: Sequence[tf.Tensor],
        coefficients: Union[int, Sequence[Tuple[float, float]]],
        fisher_floor: float = 1e-9,
    ):
        """Lazily merge retaining and ablating variables for each coefficient pair.

        This is a generator: nothing is computed until it is iterated. For
        every coefficient pair ``c`` it calls ``merging._merge_with_coeffs``
        with ``output_variables`` as the target and then yields ``c``; the
        caller is expected to use ``output_model`` before advancing, since the
        next step merges into the same variables again.

        Args:
            ablating_variables: Variables to merge with ``retaining_variables``.
            ablating_fisher: Fisher paired with ``ablating_variables``.
            coefficients: Either explicit ``(retaining, ablating)`` coefficient
                pairs, or an int passed to
                ``merging.create_pairwise_grid_coeffs`` to generate them.
            fisher_floor: Forwarded to ``merging._merge_with_coeffs``.

        Yields:
            The coefficient pair whose merge has just been performed.
        """
        variables_to_merge = [self.retaining_variables, ablating_variables]
        fishers_to_merge = [self.retaining_fisher, ablating_fisher]

        # Computed once and shared by every merge of this stream.
        norm_constants = [merging._l2_norm_of_fisher(f) for f in fishers_to_merge]

        if isinstance(coefficients, int):
            coefficients = merging.create_pairwise_grid_coeffs(coefficients)

        for c in coefficients:
            merging._merge_with_coeffs(
                self.output_variables,
                variables_to_merge,
                coefficients=c,
                fishers=fishers_to_merge,
                fisher_floor=fisher_floor,
                favor_target_model=True,
                normalization_constants=norm_constants,
            )
            yield c

    def create_model(self, ablating_fisher, sign_guide, delta: float, lmbda: float):
        """Merge once with weights ``(1 - lmbda, lmbda)`` and return ``output_model``.

        The ablating variables are obtained by applying ``sign_guide`` with
        step ``delta`` to the retaining variables. The returned object is the
        shared ``self.output_model``, so a later merge goes through the same
        model again.
        """
        ablating_variables = self.apply_sign_guide(sign_guide, delta)
        gen = self.stream_merge(
            ablating_variables=ablating_variables,
            ablating_fisher=ablating_fisher,
            coefficients=[(1 - lmbda, lmbda)],
        )
        # Advancing the generator once performs the single requested merge.
        next(gen)
        return self.output_model


##########################################################################


@dataclasses.dataclass
class KlRangeTargeter:
    """Searches for a (delta, lmbda) pair whose merged model has a KL in a range.

    Each candidate pair is evaluated by applying ``sign_guide`` with step
    ``delta``, merging with weights ``(1 - lmbda, lmbda)`` into
    ``exp.output_model`` and measuring ``kl_fn`` on that model.

    Attributes:
        exp: Experiment providing the merge machinery and the output model.
        sign_guide: Sign guide applied to the retaining variables.
        kl_fn: Maps the merged model to its KL value.
        kl_range: Target ``(min, max)`` KL range, inclusive.
        delta_mag_range: ``(min, max)`` bounds used when sampling and
            stepping delta.
        lmbda_working_range: ``(min, max)`` bounds used when sampling and
            stepping lmbda.
        backoff_factor: Factor (> 1) by which a step shrinks per attempt.
        backoff_attempts: Number of step sizes tried per search iteration.
        ablate_top_k_params: If set, only the top-k entries of the sparse
            Fisher are kept for the ablating Fisher.
    """

    exp: Experiment1
    sign_guide: Sequence[tf.Tensor]

    kl_fn: Callable[[TFPreTrainedModel], float]

    kl_range: Tuple[float, float]
    delta_mag_range: Tuple[float, float]
    lmbda_working_range: Tuple[float, float] = (0.035, 1 - 0.035)

    backoff_factor: float = 2.0
    backoff_attempts: int = 10

    ablate_top_k_params: Optional[int] = None

    def __post_init__(self):
        assert len(self.kl_range) == 2
        assert 0 < self.min_kl < self.max_kl

        assert len(self.delta_mag_range) == 2
        assert 0 < self.min_delta_mag < self.max_delta_mag

        assert len(self.lmbda_working_range) == 2
        assert 0 < self.min_lmbda < self.max_lmbda < 1

        assert self.backoff_factor > 1

        # The most recently evaluated pair; read by callers after search().
        self._last_delta = None
        self._last_lmbda = None

    #################################################################

    @property
    def min_kl(self):
        return self.kl_range[0]

    @property
    def max_kl(self):
        return self.kl_range[1]

    @property
    def min_delta_mag(self):
        return self.delta_mag_range[0]

    @property
    def max_delta_mag(self):
        return self.delta_mag_range[1]

    @property
    def min_lmbda(self):
        return self.lmbda_working_range[0]

    @property
    def max_lmbda(self):
        return self.lmbda_working_range[1]

    #################################################################

    def _get_random_delta(self):
        """Sample delta log-uniformly from ``delta_mag_range``."""
        log_delta = random.uniform(np.log(self.min_delta_mag), np.log(self.max_delta_mag))
        return np.exp(log_delta)

    def _get_random_lmbda(self):
        """Sample lmbda uniformly from ``lmbda_working_range``."""
        return random.uniform(self.min_lmbda, self.max_lmbda)

    #################################################################

    def _evaluate(self, ablating_fisher, delta, lmbda):
        """Merge with ``(delta, lmbda)`` and classify the resulting KL.

        Also records the pair in ``_last_delta`` / ``_last_lmbda``.

        Returns:
            ``(condition, kl)`` where condition is 0 if the KL is inside the
            target range, -1 if it is below and 1 if it is above.

        Raises:
            ValueError: If the KL cannot be ordered against the range, which
                happens when it is NaN.
        """
        self._last_delta = delta
        self._last_lmbda = lmbda

        ablating_variables = self.exp.apply_sign_guide(self.sign_guide, delta)
        gen = self.exp.stream_merge(
            ablating_variables=ablating_variables,
            ablating_fisher=ablating_fisher,
            coefficients=[(1 - lmbda, lmbda)],
        )
        # Advancing the generator once performs the single requested merge.
        next(gen)
        kl = self.kl_fn(self.exp.output_model)

        if self.min_kl <= kl <= self.max_kl:
            return (0, kl)
        elif kl < self.min_kl:
            return (-1, kl)
        elif self.max_kl < kl:
            return (1, kl)
        else:
            # Every comparison with NaN is False, so a NaN KL lands here.
            raise ValueError(
                f'KL value {kl!r} cannot be ordered against the target range '
                f'[{self.min_kl}, {self.max_kl}]; it is most likely NaN '
                f'(delta={delta}, lmbda={lmbda}).'
            )

    def _kl_step_coeffs_gen(self, delta, lmbda, condition, i):
        """Yield candidate ``(delta, lmbda)`` pairs with progressively smaller steps.

        Odd ``i`` moves delta (in log space), even ``i`` moves lmbda. The
        first candidate lies halfway between the current value and the upper
        bound when the KL is too low (``condition < 0``), or the lower bound
        when it is too high; every further candidate divides the step by
        ``backoff_factor``. ``backoff_attempts`` candidates are produced.
        """
        assert condition != 0

        if i % 2:
            # Step delta.
            og_log_delta = np.log(delta)

            if condition < 0:
                log_diff = (np.log(self.max_delta_mag) - og_log_delta) / 2
            else:
                log_diff = (np.log(self.min_delta_mag) - og_log_delta) / 2

            for _ in range(self.backoff_attempts):
                new_delta = np.exp(og_log_delta + log_diff)
                yield new_delta, lmbda
                log_diff /= self.backoff_factor

        else:
            # Step lmbda.
            og_lmbda = lmbda

            if condition < 0:
                diff = (self.max_lmbda - og_lmbda) / 2
            else:
                diff = (self.min_lmbda - og_lmbda) / 2

            for _ in range(self.backoff_attempts):
                new_lmbda = og_lmbda + diff
                yield delta, new_lmbda
                diff /= self.backoff_factor

    def _get_ablating_fisher(self, component_index: int):
        """Return the ablating Fisher for one component.

        Without ``ablate_top_k_params`` this is the model context's Fisher for
        the component. Otherwise the sparse Fisher vector is restricted to the
        entries whose value is at least the k-th largest value (so ties at the
        threshold are all kept), densified and unpacked into tensors with the
        shapes of ``mc.variables``.
        """
        mc = self.exp.mc
        if self.ablate_top_k_params is None:
            return mc.make_fisher_for_components([component_index])
        else:
            assert self.ablate_top_k_params > 0
            spF = mc.make_sparse_fisher_vector_for_components([component_index])
            # k is capped at the number of stored values.
            top_values, _ = tf.math.top_k(spF.values, k=min(self.ablate_top_k_params, spF.values.shape[0]))
            spF = tf.sparse.retain(spF, spF.values >= top_values[-1])
            packer = flat_pack.FlatPacker([v.shape for v in mc.variables])
            return packer.decode_tf(tf.sparse.to_dense(spF))

    def search(
        self,
        component_index: int,
        max_iters: int,
        init_delta: Optional[float] = None,
        init_lmbda: Optional[float] = None,
    ):
        """Search for a merged model whose KL lies inside ``kl_range``.

        NOTE: This is based on the assumption that the KL is monotonic in both
        delta and lmbda.

        Args:
            component_index: Component whose Fisher is used for ablation.
            max_iters: Maximum number of outer search iterations.
            init_delta: Starting delta; a random one is drawn when this is
                None (or any other falsy value such as 0).
            init_lmbda: Starting lmbda; a random one is drawn when this is
                None (or any other falsy value such as 0).

        Returns:
            ``exp.output_model`` holding the merge that hit the range, or
            None if no evaluated pair landed in the range within
            ``max_iters`` iterations. ``_last_delta`` / ``_last_lmbda`` hold
            the last evaluated pair either way.
        """
        ablating_fisher = self._get_ablating_fisher(component_index)
        delta = init_delta or self._get_random_delta()
        lmbda = init_lmbda or self._get_random_lmbda()

        for i in range(max_iters):
            # NOTE(review): after the first iteration this re-evaluates the
            # pair the inner loop evaluated last; if evaluation is
            # deterministic that result could be reused instead.
            condition, _ = self._evaluate(ablating_fisher, delta, lmbda)
            if condition == 0:
                return self.exp.output_model

            for delta, lmbda in self._kl_step_coeffs_gen(delta, lmbda, condition, i):
                cond, _ = self._evaluate(ablating_fisher, delta, lmbda)
                if cond == 0:
                    return self.exp.output_model
                elif cond == condition:
                    # Still on the same side of the range: accept this step
                    # and continue from here. Otherwise the step overshot, so
                    # the next, smaller step is tried.
                    break

        # The search budget ran out without landing in the target range.
        return None


##########################################################################

@dataclasses.dataclass
class ExperimentHelper1:
    """Runs repeated KL-targeted ablations of one component and stores results.

    Attributes:
        exp: The experiment providing models, merging and example sampling.
        component_index: Component being ablated.
        n_evaluation_examples: Number of random examples used for evaluation.
        kl_target_range: ``(min, max)`` KL range the targeter must hit.
        fixed_sign_guide: If True, the sign guide of the first run is reused
            by later ``do_run`` calls instead of drawing a new one.
        n_kl_range_targeter_examples: Number of random examples used by the
            KL targeter; derived from ``kl_range_targeter_ex_indices`` when
            those are given.
        kl_range_targeter_ex_indices: Explicit examples for the KL targeter.
        ablate_top_k_params: If set, only the top-k entries of the sparse
            Fisher are used for ablation.
    """
    exp: Experiment1

    component_index: int
    n_evaluation_examples: int

    kl_target_range: Tuple[float, float]

    fixed_sign_guide: bool = False

    n_kl_range_targeter_examples: Optional[int] = None
    kl_range_targeter_ex_indices: Optional[np.ndarray] = None

    ablate_top_k_params: Optional[int] = None

    def __post_init__(self):
        # One evaluation result per completed run, in run order.
        self.run_results = []

        self.eval_ctx = self.exp.mc.get_evaluation_context()

        if self.kl_range_targeter_ex_indices is not None:
            # Explicit indices win; a separately given count must agree.
            if self.n_kl_range_targeter_examples is not None:
                assert self.n_kl_range_targeter_examples == len(self.kl_range_targeter_ex_indices)
            self.n_kl_range_targeter_examples = len(self.kl_range_targeter_ex_indices)

        elif self.n_kl_range_targeter_examples is not None:
            self.kl_range_targeter_ex_indices = self.exp.random_example_indices(self.n_kl_range_targeter_examples)

        else:
            raise ValueError(
                'Either kl_range_targeter_ex_indices or '
                'n_kl_range_targeter_examples must be provided.'
            )

        self.evaluation_ex_indices = self.exp.random_example_indices(self.n_evaluation_examples)

        self._last_sign_guide = None

        # Pair found by the most recent do_run(); None until one completes.
        self._last_delta = None
        self._last_lmbda = None

    def get_ablating_fisher(self):
        """Return the ablating Fisher for ``component_index``.

        When ``ablate_top_k_params`` is set, the same top-k sparsification as
        ``KlRangeTargeter._get_ablating_fisher`` is applied, so models built
        outside the targeter use the Fisher that ``do_run`` searched with.
        Without it, the model context's full component Fisher is returned.
        """
        mc = self.exp.mc
        if self.ablate_top_k_params is None:
            return mc.make_fisher_for_components([self.component_index])

        assert self.ablate_top_k_params > 0
        sparse_fisher = mc.make_sparse_fisher_vector_for_components([self.component_index])
        # k is capped at the number of stored values; entries tied with the
        # k-th largest value are all kept.
        k = min(self.ablate_top_k_params, sparse_fisher.values.shape[0])
        top_values, _ = tf.math.top_k(sparse_fisher.values, k=k)
        sparse_fisher = tf.sparse.retain(sparse_fisher, sparse_fisher.values >= top_values[-1])
        packer = flat_pack.FlatPacker([v.shape for v in mc.variables])
        return packer.decode_tf(tf.sparse.to_dense(sparse_fisher))

    def make_sign_guide(self):
        """Draw a random-normal tensor for every retaining variable."""
        return [tf.random.normal(v.shape) for v in self.exp.retaining_variables]

    def do_run(self, kl_targeter_max_iters: int = 25):
        """Search for a model in the KL target range and evaluate it.

        The evaluation result is appended to ``run_results`` and the pair the
        targeter evaluated last is stored in ``_last_delta`` / ``_last_lmbda``.

        Args:
            kl_targeter_max_iters: Iteration budget for the KL search.

        Returns:
            The evaluation result on ``evaluation_ex_indices``.
        """
        def kl_fn(model):
            return self.eval_ctx.evaluate(model, self.kl_range_targeter_ex_indices).kl()

        if not self.fixed_sign_guide or self._last_sign_guide is None:
            self._last_sign_guide = self.make_sign_guide()

        sign_guide = self._last_sign_guide

        targeter = KlRangeTargeter(
            exp=self.exp,
            sign_guide=sign_guide,
            kl_fn=kl_fn,
            kl_range=self.kl_target_range,
            delta_mag_range=(1e-5, 3),
            ablate_top_k_params=self.ablate_top_k_params,
        )

        start = time.time()
        # NOTE(review): search() returns None when it does not reach the
        # target range within the iteration budget; that None is then passed
        # to evaluate() below -- confirm how evaluate() handles it.
        kl_model = targeter.search(self.component_index, max_iters=kl_targeter_max_iters)
        print("KL search time:", time.time() - start)

        start = time.time()
        eval_results = self.eval_ctx.evaluate(kl_model, self.evaluation_ex_indices)
        self.run_results.append(eval_results)
        print("Eval time:", time.time() - start)

        self._last_delta = targeter._last_delta
        self._last_lmbda = targeter._last_lmbda

        return eval_results

    def get_evaluation_examples_coeffs(self, component_index: Optional[int] = None):
        """Return the ``W`` coefficients of the evaluation examples for a component.

        ``component_index`` defaults to this helper's own component.
        """
        if component_index is None:
            component_index = self.component_index
        return self.exp.W[:, component_index][self.evaluation_ex_indices]

    def get_example_kls_for_run(self, run_index: int):
        """Return the per-example KL between a run's logits and the original ones.

        Both sets of logits are softmaxed; the run's distribution is passed as
        ``y_true`` and the original one as ``y_pred``.
        """
        # NOTE(review): SoftmaxKlDivergenceLoss is fed the original logits as
        # y_true, i.e. the opposite argument order -- confirm the direction
        # used here is intended.
        eval_results = self.run_results[run_index]
        return tf.keras.losses.kl_divergence(
            tf.math.softmax(eval_results.logits),
            tf.math.softmax(eval_results.og_logits)
        ).numpy()

    def get_examples_kl_matrix(self):
        """Stack the per-example KLs of all runs along a new last axis.

        The result has one column per entry of ``run_results``.
        """
        # return.shape = [n_evaluation_examples, n_runs]
        return np.stack([
            self.get_example_kls_for_run(i)
            for i in range(len(self.run_results))
        ], axis=-1)

    def get_model_with_different_sign_guide(self):
        """Rebuild the last run's model with a freshly drawn sign guide.

        Keeps everything the same except for the sign guide: the delta and
        lmbda of the last ``do_run`` and the same ablating Fisher are reused.
        Requires a completed ``do_run``.
        """
        assert self._last_delta is not None
        assert self._last_lmbda is not None

        return self.exp.create_model(
            self.get_ablating_fisher(),
            self.make_sign_guide(),
            delta=self._last_delta,
            lmbda=self._last_lmbda,
        )

    def evaluate_with_different_sign_guide(self):
        """Evaluate a model built with a new sign guide and record the result."""
        kl_model = self.get_model_with_different_sign_guide()

        start = time.time()
        eval_results = self.eval_ctx.evaluate(kl_model, self.evaluation_ex_indices)
        self.run_results.append(eval_results)
        print("Eval time:", time.time() - start)

        return eval_results
