"""Model-merging based perturbation method."""
import dataclasses
from typing import Sequence

import tensorflow as tf

from em.merging import merging


##########################################################################

@dataclasses.dataclass
class MmPerturber:
    """Model-merging based perturber.

    Builds an "ablating" copy of the model, ``variables + delta * ablating_shift``.
    It then merges that copy with the original ``variables`` via
    ``merging._merge_with_coeffs``, weighting each side by its Fisher tensors.
    The mixing coefficients are ``(1 - lmbda, lmbda)``.

    All per-variable sequences must have the same length; this is checked
    on construction.
    """

    # Original model variables (the "retaining" side of the merge).
    variables: Sequence[tf.Tensor]
    # Fisher weights paired with ``variables``.
    retaining_fisher: Sequence[tf.Tensor]

    # Per-variable shift direction. The ablating model is
    # ``variables + delta * ablating_shift``.
    ablating_shift: Sequence[tf.Tensor]
    # Fisher weights paired with the ablating model.
    ablating_fisher: Sequence[tf.Tensor]

    # Passed as the first argument to ``merging._merge_with_coeffs``;
    # presumably the merge result is written into these — confirm in merging.
    output_variables: Sequence[tf.Variable]

    # Forwarded to the merge routine as ``fisher_floor``. Note: no trailing
    # comma here — ``1e-9,`` would make the default the tuple ``(1e-9,)``.
    fisher_floor: float = 1e-9

    def __post_init__(self):
        self._assert_lists_same_length()

    ############################################################

    def update_output_variables(self, delta: float, lmbda: float):
        """Merge the original and ablating models into ``output_variables``.

        Args:
          delta: scale applied to ``ablating_shift`` when building the
            ablating model (``variables + delta * ablating_shift``).
          lmbda: merge coefficient. The original model gets ``1 - lmbda``
            and the ablating model gets ``lmbda``.
        """
        ablating_variables = self._compute_ablating_variables(delta, lmbda)

        variables_to_merge = [self.variables, ablating_variables]
        fishers_to_merge = [self.retaining_fisher, self.ablating_fisher]

        # One normalization constant per model, computed from its Fisher via
        # merging._l2_norm_of_fisher (semantics defined in the merging module).
        norm_constants = [merging._l2_norm_of_fisher(f) for f in fishers_to_merge]

        merging._merge_with_coeffs(
            self.output_variables,
            variables_to_merge,
            coefficients=(1 - lmbda, lmbda),
            fishers=fishers_to_merge,
            fisher_floor=self.fisher_floor,
            # NOTE(review): exact effect of this flag is defined in merging;
            # confirm before changing.
            favor_target_model=True,
            normalization_constants=norm_constants,
        )

    def _compute_ablating_variables(self, delta: float, lmbda: float):
        """Return ``[v + delta * s for each (variable, shift) pair]``.

        ``lmbda`` is accepted for signature symmetry but is not used here.
        """
        return [
            v + delta * s
            for v, s in zip(self.variables, self.ablating_shift)
        ]

    ############################################################

    def _assert_lists_same_length(self):
        """Check that every per-variable sequence has the same length.

        Raises:
          AssertionError: if any of the sequences differ in length.
        """
        lengths = {
            "variables": len(self.variables),
            "retaining_fisher": len(self.retaining_fisher),
            "ablating_shift": len(self.ablating_shift),
            "ablating_fisher": len(self.ablating_fisher),
            "output_variables": len(self.output_variables),
        }
        assert len(set(lengths.values())) == 1, (
            f"All per-variable sequences must have equal length, got {lengths}"
        )
