"""Some stuff I'm experimenting with for selective NMF-based ablation."""
import dataclasses
from typing import Sequence

import cvxpy as cp
import numpy as np
import tensorflow as tf

from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.util import hf_util


###############################################################################

def make_dataset_for_top_component_examples(
    W: np.ndarray,
    component: int,
    n_examples: int,
    tokenizer,
    pe_fishers_data: per_example.PerExampleFlatFishers,
) -> tf.data.Dataset:
    """Build a dataset from the examples with the largest weight on a component.

    The `n_examples` largest entries of `W[:, component]` are located with
    `tf.math.top_k`, and the corresponding entries of
    `pe_fishers_data.input_ids` and `pe_fishers_data.labels` are gathered.

    Args:
        W: Array indexed as `W[:, component]`. Presumably one row per example
            in `pe_fishers_data` and one column per component -- TODO confirm.
        component: Column of `W` to rank the examples by.
        n_examples: Number of top-ranked examples to keep.
        tokenizer: Object exposing a `pad_token_type_id` attribute.
        pe_fishers_data: Source of the `input_ids` and `labels` arrays that
            are indexed by the selected positions.

    Returns:
        A dataset of `(features, label)` pairs, where `features` is a dict
        holding int32 `'input_ids'` and `'token_type_ids'` tensors.
    """
    top_k_result = tf.math.top_k(W[:, component], k=n_examples)
    top_rows = top_k_result.indices.numpy()

    selected_ids = pe_fishers_data.input_ids[top_rows]
    # Every position is assigned the tokenizer's pad token type id. The
    # original author was not certain this is the right way to fill
    # token_type_ids, so treat it as an assumption.
    type_ids = tokenizer.pad_token_type_id * np.ones_like(selected_ids)
    selected_labels = pe_fishers_data.labels[top_rows]

    features = {
        'input_ids': tf.cast(selected_ids, tf.int32),
        'token_type_ids': tf.cast(type_ids, tf.int32),
    }
    features_ds = tf.data.Dataset.from_tensor_slices(features)
    targets_ds = tf.data.Dataset.from_tensor_slices(selected_labels)
    return tf.data.Dataset.zip((features_ds, targets_ds))


###############################################################################

def generate_merged_for_coeffs_set(
    output_variables: Sequence[tf.Variable],
    variables_to_merge: Sequence[Sequence[tf.Tensor]],
    fishers: Sequence[Sequence[tf.Tensor]],
    coefficients_set: Sequence[Sequence[float]],
    # NOTE: The default fisher_floor is lower here than in the original
    # function in the merging module.
    fisher_floor: float = 1e-8,
    normalize_fishers: bool = True,
    *,
    favor_target_model: bool = True,
):
    """Run a merge for each set of coefficients, yielding after each one.

    This is a generator: nothing (including input validation) happens until
    it is first iterated. For every entry of `coefficients_set`,
    `merging._merge_with_coeffs` is called with `output_variables` as its
    first argument (presumably the merge result is written into them -- TODO
    confirm against the merging module), and then that entry is yielded.

    Args:
        output_variables: Variables handed to `merging._merge_with_coeffs`
            as the merge destination.
        variables_to_merge: One sequence of tensors per model to merge. Each
            must have the same length as `output_variables`.
        fishers: One sequence of tensors per model, each with the same length
            as `output_variables`, or `None`. `None` is forwarded unchanged
            to `merging._merge_with_coeffs` and disables normalization.
        coefficients_set: The coefficient sequences to merge with, in order.
        fisher_floor: Forwarded to `merging._merge_with_coeffs`.
        normalize_fishers: If true and `fishers` is not `None`, a
            normalization constant is computed for each entry of `fishers`
            with `merging._l2_norm_of_fisher` and forwarded to the merge.
        favor_target_model: Forwarded to `merging._merge_with_coeffs`.

    Yields:
        Each element of `coefficients_set`, right after the merge for it.

    Raises:
        ValueError: If the variable lists (and fishers, when given) do not
            all contain the same number of variables.
    """
    # Make sure that all of the variable lists and fishers contain exactly
    # the same number of variables. `fishers` may be None, in which case it
    # takes no part in the check.
    variable_counts = {len(output_variables)}
    variable_counts.update(len(v) for v in variables_to_merge)
    if fishers is not None:
        variable_counts.update(len(f) for f in fishers)
    if len(variable_counts) != 1:
        raise ValueError(
            'output_variables, variables_to_merge and fishers must all '
            'contain the same number of variables; got lengths '
            f'{sorted(variable_counts)}'
        )

    # The normalization constants do not depend on the coefficients, so they
    # are computed once up front and reused for every merge.
    if normalize_fishers and fishers is not None:
        norm_constants = [merging._l2_norm_of_fisher(f) for f in fishers]
    else:
        norm_constants = None

    for coefficients in coefficients_set:
        merging._merge_with_coeffs(
            output_variables,
            variables_to_merge,
            coefficients=coefficients,
            fishers=fishers,
            fisher_floor=fisher_floor,
            favor_target_model=favor_target_model,
            normalization_constants=norm_constants,
        )
        yield coefficients


###############################################################################

@dataclasses.dataclass
class RiemannianAblator:
    """Linear program over the simplex that trades off two flat Fishers.

    The decision variable `sq_delta` is constrained to be non-negative and to
    sum to one. The objective maximizes
    `sq_delta @ (alpha * component_fisher - (1 - alpha) * full_fisher)`,
    where `alpha` is a cvxpy parameter set on each call to
    `solve_for_sq_delta`.
    """

    # shape = [n_merge_parameters]
    flat_full_fisher: np.ndarray

    # shape = [n_merge_parameters]
    flat_component_fisher: np.ndarray

    # Whether or not to normalize the passed fishers to have unit norm.
    normalize_fishers: bool = False

    def __post_init__(self):
        """Validate the inputs and build the cvxpy variable, parameter and problem."""
        assert self.flat_full_fisher.shape == self.flat_component_fisher.shape

        component_fisher = self.flat_component_fisher
        self.dtype = component_fisher.dtype
        # Unpacking into a one-element tuple requires the fishers to be 1-D.
        (self.n_params,) = component_fisher.shape

        if self.normalize_fishers:
            self._normalize_fishers()

        self.sq_delta = cp.Variable([self.n_params], name='sq_delta')
        self.alpha = cp.Parameter([], name='alpha')

        self._set_up_problem()

    def _normalize_fishers(self, eps=1e-12):
        """Rescale both fishers in place by their norm; `eps` guards against zero."""
        full_norm = np.linalg.norm(self.flat_full_fisher)
        self.flat_full_fisher = self.flat_full_fisher / (full_norm + eps)

        component_norm = np.linalg.norm(self.flat_component_fisher)
        self.flat_component_fisher = self.flat_component_fisher / (component_norm + eps)

    def _set_up_problem(self):
        """Create `self.objective` and `self.prob` from the current fishers."""
        # Per-parameter payoff: reward component Fisher, penalize full Fisher,
        # with the balance between the two controlled by alpha.
        payoff = (
            self.alpha * self.flat_component_fisher
            - (1 - self.alpha) * self.flat_full_fisher
        )
        self.objective = cp.Maximize(self.sq_delta @ payoff)

        # Together these constraints keep sq_delta on the n_params-simplex.
        non_negative = self.sq_delta >= 0
        sums_to_one = cp.sum(self.sq_delta) == 1

        self.prob = cp.Problem(self.objective, [non_negative, sums_to_one])

    def solve_for_sq_delta(self, alpha: float, **ecos_kwargs):
        """Solve the problem for a given `alpha` and return `sq_delta`'s value.

        With alpha = 0 the objective only contains the full Fisher term
        (the component Fisher is ignored); with alpha = 1 it only contains
        the component Fisher term (the full Fisher is ignored).

        Args:
            alpha: Trade-off weight; must lie in [0, 1].
            **ecos_kwargs: Extra keyword arguments forwarded to
                `self.prob.solve` alongside `solver=cp.ECOS`.

        Returns:
            The value of `sq_delta` after solving, or `None` when the value
            returned by `solve` is not a float.

        Raises:
            ValueError: If `alpha` is outside [0, 1].
        """
        in_unit_interval = 0 <= alpha <= 1
        if not in_unit_interval:
            raise ValueError('alpha must be in range [0, 1]')

        self.alpha.value = alpha
        result = self.prob.solve(warm_start=True, solver=cp.ECOS, **ecos_kwargs)

        # A non-float result is treated as a failed solve. The original
        # author notes that solve() may return a status string in that case;
        # TODO confirm against the installed cvxpy version.
        if isinstance(result, float):
            return self.sq_delta.value
        return None

    def randomly_unsquare_delta(self, sq_delta: np.ndarray):
        """Return element-wise square roots of `sq_delta` with random signs.

        More of a utility function; it does not read any instance state.
        Signs are drawn from NumPy's global random state.
        """
        random_signs = np.sign(np.random.normal(size=sq_delta.shape))
        # The absolute value is taken first because the solver can
        # apparently return very low magnitude negative values.
        magnitudes = np.sqrt(np.abs(sq_delta))
        return random_signs * magnitudes
