"""Merging using sparse diagonal Fishers."""
from typing import List, Optional, Sequence, Union

import datasets as hfds
import tensorflow as tf
from transformers import TFPreTrainedModel

from em.evaluation import evaluation
from em.fishers import diagonal
from em.fishers import sparse_diagonal
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.util import hf_util

# Type aliases used throughout this module.
# A Fisher is either a dense or a sparse diagonal Fisher. The code below reads
# `.fishers` (one tensor per mergeable variable) from every Fisher, and also
# `.parameters` from the non-base Fishers.
Fisher = Union[diagonal.DiagonalFisher, sparse_diagonal.SparseDiagonalFisher]
# A tensor stored either densely (tf.Tensor) or as a tf.sparse.SparseTensor.
MaybeSparseTensor = Union[tf.Tensor, tf.sparse.SparseTensor]


def _l2_norm_of_fisher(fisher: Fisher) -> tf.Tensor:
    """Return the global L2 norm of every per-variable tensor in ``fisher.fishers``.

    This is sqrt(sum of squared entries) taken over all tensors together. A
    SparseTensor contributes only its stored values; its implicit entries are
    zero and therefore add nothing to the sum of squares.
    """

    def _sum_of_squares(component: MaybeSparseTensor) -> tf.Tensor:
        # Sparse path: square only the stored values, then sum to a scalar.
        if isinstance(component, tf.sparse.SparseTensor):
            return tf.sparse.reduce_sum(tf.sparse.map_values(tf.square, component))
        return tf.reduce_sum(tf.square(component))

    squared_sums = [_sum_of_squares(component) for component in fisher.fishers]
    return tf.sqrt(tf.reduce_sum(squared_sums))


def _to_dense(tensor: MaybeSparseTensor) -> tf.Tensor:
    """Return ``tensor`` densified if it is a SparseTensor, otherwise unchanged."""
    is_sparse = isinstance(tensor, tf.sparse.SparseTensor)
    return tf.sparse.to_dense(tensor) if is_sparse else tensor


def _floor(tensor: MaybeSparseTensor, fisher_floor: float) -> MaybeSparseTensor:
    """Clamp ``tensor`` from below at ``fisher_floor``, preserving sparsity.

    NOTE: for a SparseTensor only the explicitly stored values are clamped. Its
    implicit entries stay zero, and the result is still a SparseTensor.
    """
    if not isinstance(tensor, tf.sparse.SparseTensor):
        return tf.maximum(tensor, fisher_floor)
    # map_values applies tf.maximum to the stored `.values` only.
    return tf.sparse.map_values(tf.maximum, tensor, fisher_floor)


def _merge_with_coeffs(
    output_variables: Sequence[tf.Variable],
    base_variables: Sequence[tf.Variable],
    base_model_fisher: Fisher,
    fishers: Sequence[Fisher],
    coefficients: Sequence[float],
    variable_filter: Optional[tmv.VariableFilter] = None,
    fisher_floor: float = 1e-6,
    favor_target_model=True,
    normalization_constants=None,
):
    """Assign a Fisher-weighted average of the models into ``output_variables``.

    For each variable index ``i`` accepted by ``variable_filter`` this computes,
    element-wise,

        merged = sum_j c_j * F_j * theta_j / sum_j c_j * F_j

    and assigns it to ``output_variables[i]``. Here ``j`` ranges over the base
    model first, then each entry of ``fishers``:

    - ``theta_0`` is ``base_variables[i]``.
    - ``theta_j`` is ``fishers[j].parameters[i]``.
    - ``F_j`` is the corresponding ``.fishers[i]`` tensor, densified.

    The base Fisher is always clamped from below at ``fisher_floor``. The other
    Fishers are clamped only when ``favor_target_model`` is False. When it is
    True, entries where every non-base Fisher is zero reduce to the base
    variable (given a nonzero base coefficient).

    NOTE(review): a sparse non-base Fisher is floored *before* densifying, so
    only its stored entries are floored and its implicit zeros keep zero
    weight. This is presumably intentional when ``parameters`` share the
    Fisher's sparsity pattern — confirm.

    Args:
        output_variables: Variables that receive the merged values.
        base_variables: Variables of the base (target) model.
        base_model_fisher: Fisher of the base model.
        fishers: Fishers (with ``.parameters``) of the other models.
        coefficients: One weight per model, base model first.
        variable_filter: If given, variables it does not match are skipped and
            left untouched.
        fisher_floor: Lower clamp applied to Fisher values.
        favor_target_model: If False, non-base Fishers are floored as well.
        normalization_constants: Optional per-model divisors applied to
            ``coefficients`` (same order, base model first).
    """
    n_models = len(fishers) + 1
    assert len(coefficients) == n_models

    if normalization_constants is not None:
        assert len(normalization_constants) == n_models
        coefficients = [w / n for w, n in zip(coefficients, normalization_constants)]

    base_coeff, *model_coeffs = coefficients

    for idx, (out_var, base_var, base_fisher) in enumerate(
        zip(output_variables, base_variables, base_model_fisher.fishers)
    ):
        if variable_filter is not None and not variable_filter.does_variable_match(out_var):
            continue

        # The floored base Fisher is what keeps the denominator away from zero
        # (assuming a positive base coefficient and non-negative other terms).
        base_weight = base_coeff * _floor(_to_dense(base_fisher), fisher_floor)
        weighted_fishers = [base_weight]
        weighted_params = [base_weight * base_var]

        for model_fisher, model_coeff in zip(fishers, model_coeffs):
            raw_diag = model_fisher.fishers[idx]
            if not favor_target_model:
                # Floor before densifying (see NOTE(review) in the docstring).
                raw_diag = _floor(raw_diag, fisher_floor)
            model_param = _to_dense(model_fisher.parameters[idx])
            weight = model_coeff * _to_dense(raw_diag)
            weighted_fishers.append(weight)
            weighted_params.append(weight * model_param)

        numerator = tf.reduce_sum(weighted_params, axis=0)
        denominator = tf.reduce_sum(weighted_fishers, axis=0)
        out_var.assign(numerator / denominator)


def generate_merged_for_coeffs_set(
    base_model: TFPreTrainedModel,
    base_model_fisher: Fisher,
    fishers: Sequence[Fisher],
    coefficients_set: Sequence[Sequence[float]],
    variable_filter: Optional[tmv.VariableFilter] = None,
    fisher_floor: float = 1e-6,
    favor_target_model=True,
    normalize_fishers=True,
):
    """Lazily yield ``(coefficients, merged_model)`` for each entry of ``coefficients_set``.

    A single clone of ``base_model`` is created. It is overwritten in place for
    every coefficient vector, so each yielded model is the SAME object. Consume
    it (evaluate or save it) before advancing the generator, and do not collect
    the yielded models expecting distinct merges.

    Args:
        base_model: Base (target) model. Its mergeable variables are only read here.
        base_model_fisher: Fisher of the base model.
        fishers: Fishers of the other models; they also carry the parameters.
        coefficients_set: Coefficient vectors to try, each ordered base model
            first and then ``fishers``.
        variable_filter: If given, only matching variables are merged. Other
            variables keep the values cloned from ``base_model``.
        fisher_floor: Lower clamp applied to Fisher values during merging.
        favor_target_model: If False, non-base Fishers are floored as well.
        normalize_fishers: If True, each model's coefficient is divided by the
            global L2 norm of its Fisher.

    Yields:
        ``(coefficients, output_model)`` pairs, one per entry of
        ``coefficients_set``, in order.

    Raises:
        AssertionError: If the variable and Fisher lists differ in length.
    """
    if normalize_fishers:
        norm_constants = [_l2_norm_of_fisher(f) for f in [base_model_fisher, *fishers]]
    else:
        norm_constants = None

    base_variables = hf_util.get_mergeable_variables(base_model)

    output_model = hf_util.clone_model(base_model)
    output_variables = hf_util.get_mergeable_variables(output_model)

    # Every variable list and Fisher list must line up one-to-one. The merge
    # zips these lists together and indexes the Fishers by position, so any
    # length mismatch (including in base_variables, which the original check
    # omitted) would silently truncate the merge or raise IndexError.
    lengths = {
        len(base_variables),
        len(output_variables),
        len(base_model_fisher.fishers),
        *(len(f.fishers) for f in fishers),
        *(len(f.parameters) for f in fishers),
    }
    assert len(lengths) == 1, f"Mismatched variable/Fisher list lengths: {sorted(lengths)}"

    for coefficients in coefficients_set:
        _merge_with_coeffs(
            output_variables=output_variables,
            base_variables=base_variables,
            base_model_fisher=base_model_fisher,
            fishers=fishers,
            coefficients=coefficients,
            variable_filter=variable_filter,
            fisher_floor=fisher_floor,
            favor_target_model=favor_target_model,
            normalization_constants=norm_constants,
        )
        # NOTE: the same output_model object is yielded on every iteration.
        yield coefficients, output_model


def merging_coefficients_search(
    base_model: TFPreTrainedModel,
    base_model_fisher: Fisher,
    fishers: Sequence[Fisher],
    coefficients_set: Sequence[Sequence[float]],
    dataset: tf.data.Dataset,
    metric: hfds.Metric,
    variable_filter: Optional[tmv.VariableFilter] = None,
    fisher_floor: float = 1e-6,
    favor_target_model=True,
    normalize_fishers=True,
    print_results=False,
) -> List[merging.MergeResult]:
    """Score a Fisher-weighted merge for every coefficient vector in ``coefficients_set``.

    Merges come lazily from ``generate_merged_for_coeffs_set``, which reuses a
    single output model across iterations. Each merge is therefore evaluated
    with ``evaluation.evaluate_model(model, dataset, metric)`` before the next
    one is produced.

    Args:
        print_results: If True, each result is printed via
            ``merging.print_merge_result`` as soon as it is computed.
        All other arguments are forwarded to ``generate_merged_for_coeffs_set``
        or used for evaluation.

    Returns:
        One ``merging.MergeResult`` per coefficient vector, in input order.
    """
    candidates = generate_merged_for_coeffs_set(
        base_model=base_model,
        base_model_fisher=base_model_fisher,
        fishers=fishers,
        coefficients_set=coefficients_set,
        variable_filter=variable_filter,
        fisher_floor=fisher_floor,
        favor_target_model=favor_target_model,
        normalize_fishers=normalize_fishers,
    )

    collected: List[merging.MergeResult] = []
    for coeffs, model in candidates:
        outcome = merging.MergeResult(
            coefficients=coeffs,
            score=evaluation.evaluate_model(model, dataset, metric),
        )
        collected.append(outcome)
        if print_results:
            merging.print_merge_result(outcome)
    return collected
