"""Code for finding perturbed parameters."""
import dataclasses
from typing import Optional

# import cvxpy as cp
import numpy as np
import tensorflow as tf


class PerturbationFinder:
    """Closed-form solver for a linear objective under a diagonal-Fisher constraint.

    Solves:
        argmax_z  g^t z
        subject to
            f^t z**2 <= delta

    (``z**2`` is elementwise.) Stationarity of the Lagrangian gives ``z``
    proportional to ``g / f`` elementwise; ``solve`` scales that direction so
    the constraint holds with equality.
    """

    # TODO: Maybe change some of the stuff here to use tensorflow instead of numpy.

    def __init__(self, f: np.ndarray, g: np.ndarray):
        assert f.shape == g.shape
        self.n_params = f.shape[0]

        # A diagonal fisher.
        self.f = f

        # The vector corresponding to the rank 1 PSD component.
        self.g = g

    def solve(self, delta: float, min_fisher_value: float = 1e-6):
        """Return the perturbation ``z`` scaled so that ``f^t z**2 == delta``.

        Args:
            delta: Budget for the constraint value ``f^t z**2``.
            min_fisher_value: Floor applied to ``f`` before dividing, so that
                near-zero Fisher entries do not cause division by zero.

        Returns:
            ``z = c * g / max(f, min_fisher_value)``, where the scalar ``c`` makes
            ``compute_constraint_value(z)`` equal to ``delta``. If that direction
            is all zeros (``g`` is all zeros), the objective is identically zero.
            In that case the zero vector is returned, instead of NaNs from 0/0.
        """
        z = self.g / np.maximum(self.f, min_fisher_value)
        if not np.any(z):
            # Nothing to normalize; the zero perturbation is feasible and optimal.
            return z
        # Scale in place so the direction's dtype is preserved.
        z *= np.sqrt(delta / self.compute_constraint_value(z))
        return z

    def compute_constraint_value(self, z: np.ndarray):
        """Return ``f^t z**2``, the Fisher-weighted squared norm of ``z``."""
        return np.sum(self.f * z**2)

    def compute_objective_value(self, z: np.ndarray):
        """Return the linear objective ``g^t z``."""
        return np.sum(self.g * z)


class RegularizedPerturbationFinder:
    """Solves an L2-regularized linear maximization under a diagonal-Fisher constraint.

    Corresponds to:
        argmax_x  x^t g - alpha * x^t x
        subject to
            f^t x**2 <= delta

    (``x**2`` is elementwise.) Stationarity of the Lagrangian with multiplier
    ``mu >= 0`` gives ``x(mu) = g / (2 * (mu * f + alpha))``. There are two cases:

    - The constraint is inactive and ``mu = 0``.
    - ``mu > 0`` is found by binary search so that ``G(mu) = f^t x(mu)**2``
      equals ``delta``.

    NOTE(review): the derivation assumes ``alpha > 0`` (concave objective) and
    ``f >= 0``; neither is checked here.
    """

    def __init__(self, f: np.ndarray, g: np.ndarray, max_iters: int, epsilon: float = 1e-6):
        assert f.shape == g.shape
        self.n_params = f.shape[0]

        # A diagonal fisher.
        self.f = tf.cast(f, tf.float32)

        # The vector corresponding to the rank 1 PSD component.
        self.g = tf.cast(g, tf.float32)

        # Absolute tolerance on |G(mu) - delta| used to stop the binary search.
        self.epsilon = epsilon

        # Maximum number of iterations for each of the two search loops
        # (bracketing mu, then bisecting).
        self.max_iters = max_iters

        # Lower/upper bounds of the bracket on the Lagrange multiplier mu,
        # updated in place during the search.
        self.mu1 = tf.Variable(0.0, dtype=tf.float32)
        self.mu2 = tf.Variable(0.0, dtype=tf.float32)

    # @tf.function
    def _compute_G_mu(self, mu: tf.Tensor, alpha: tf.Tensor):
        """Return G(mu) = f^t x(mu)**2 with x(mu) = g / (2 * (mu * f + alpha))."""
        x = self.g / (2 * (mu * self.f + alpha))
        return tf.tensordot(x**2, self.f, 1)

    # @tf.function
    def _solve(self, delta: tf.Tensor, alpha: tf.Tensor):
        """Return the constrained maximizer, shaped like ``self.g``."""
        # mu = 0 gives the unconstrained maximizer. Because the objective is
        # concave, it has the largest objective of any candidate, so it is the
        # answer exactly when it is feasible. Comparing objective values alone
        # would always select it, even when it violates the constraint.
        x_mu0 = self.g / (2.0 * alpha)
        if tf.tensordot(x_mu0**2, self.f, 1) <= delta:
            return x_mu0

        # The constraint is active: find mu > 0 with G(mu) = delta.
        #
        # This binary-search method is based on the fact that dG/dmu < 0 for
        # all mu > 0.
        self.mu1.assign(0.0)
        self.mu2.assign(1.0)

        # Find an upper bound on mu by doubling until G(mu2) <= delta.
        # NOTE(review): if max_iters doublings are not enough, the bracket
        # [mu1, mu2] may not contain the root and the result can be infeasible.
        for _ in tf.range(self.max_iters):
            if self._compute_G_mu(self.mu2, alpha) > delta:
                self.mu2.assign(2.0 * self.mu2)
            else:
                break

        # Bisect [mu1, mu2]. If G is still too large, mu is too small.
        for _ in tf.range(self.max_iters):
            mu_mid = (self.mu1 + self.mu2) / 2.0
            G_mid = self._compute_G_mu(mu_mid, alpha)
            if tf.abs(G_mid - delta) <= self.epsilon:
                break
            elif G_mid > delta:
                self.mu1.assign(mu_mid)
            else:
                self.mu2.assign(mu_mid)

        mu = (self.mu1 + self.mu2) / 2.0
        return self.g / (2 * (mu * self.f + alpha))

    def solve(self, delta: float, alpha: float):
        """Return the maximizer for constraint budget ``delta`` and regularizer ``alpha``.

        Both arguments are cast to float32 before solving.
        """
        return self._solve(tf.cast(delta, tf.float32), tf.cast(alpha, tf.float32))


class PerturbationFinder2:
    """Closed-form minimizer of a Fisher- and L2-regularized linear objective.

    Solves the following unconstrained optimization problem:
        argmin_x -x^t g + alpha * f^t x**2 + beta * x^t x

    (``x**2`` is elementwise.) Setting the gradient
    ``-g + 2 * alpha * f * x + 2 * beta * x`` to zero gives the elementwise
    solution ``x = g / (2 * alpha * f + 2 * beta)``.
    """

    # TODO: Maybe change some of the stuff here to use tensorflow instead of numpy.

    def __init__(self, f: np.ndarray, g: np.ndarray):
        assert f.shape == g.shape
        # Length of the leading axis of the inputs.
        self.n_params = f.shape[0]
        # A diagonal fisher, stored as a vector.
        self.f = f
        # The vector corresponding to the rank 1 PSD component.
        self.g = g

    def solve(self, alpha: float, beta: float = 0.0):
        """Return ``g / (2 * alpha * f + 2 * beta)``, the stationary point.

        No guard is applied: entries where the denominator is zero divide by zero.
        """
        denominator = 2 * alpha * self.f + 2 * beta
        return self.g / denominator

    def compute_constraint_value(self, z: np.ndarray):
        """Return the Fisher-weighted squared norm ``f^t z**2``."""
        weighted = self.f * z**2
        return np.sum(weighted)

    def compute_objective_value(self, z: np.ndarray):
        """Return the linear term ``g^t z`` only (not the regularized objective)."""
        products = self.g * z
        return np.sum(products)

# class PerturbationFinder:

#     def __init__(self, f: np.ndarray, g: np.ndarray):
#         assert f.shape == g.shape
#         self.n_params = f.shape[0]

#         # A diagonal fisher.
#         self.f = f

#         # The vector corresponding to the rank 1 PSD component.
#         self.g = g

#         # Controls the size of the perturbation.
#         self.delta = cp.Parameter([], name='delta')

#         # The perturbation. This should be added/subtracted to the model parameters.
#         self.z = cp.Variable([self.n_params], name='z')

#         self._set_up_problem()

#     def _set_up_problem(self):
#         self.obj = self.z[None, :] @ self.g[:, None]
#         self.objective = cp.Maximize(self.obj)

#         self.constraint_value = (self.z**2)[None, :] @ self.f[:, None]
#         constraint = self.constraint_value <= self.delta

#         self.prob = cp.Problem(self.objective, [constraint])

#     def solve(self, delta: float):
#         self.delta.value = delta
#         self.prob.solve(warm_start=True, solver=cp.ECOS)


class PerturbationFinder3:
    """Offsets along g, reweighted per parameter by a blend of two diagonal Fishers.

    The result is an elementwise weighted average of two targets:
    - ``0``, weighted by ``lmbda * f``;
    - ``multiplier * g``, weighted by ``(1 - lmbda) * gg``.

    Here ``gg`` is ``g * g`` normalized to unit L2 norm.
    """

    def __init__(self, f: np.ndarray, g: np.ndarray):
        assert f.shape == g.shape
        self.n_params = f.shape[0]

        # A diagonal fisher.
        self.f = f

        # The vector corresponding to the rank 1 PSD component.
        self.g = g

        # Something like a diagonal fisher associated with g: the elementwise
        # square of g, normalized to unit L2 norm.
        # Integer g is promoted to float first. An in-place divide would fail
        # on an integer array, and g**4 (inside the norm) can overflow integers.
        g_float = g if np.issubdtype(g.dtype, np.inexact) else g.astype(np.float64)
        gg = g_float * g_float
        norm = np.sqrt(np.sum(gg**2))
        # An all-zero g has nothing to normalize; keep the zeros rather than
        # producing 0/0 = NaN. solve() floors gg at min_fisher_value anyway.
        self.gg = gg / norm if norm > 0 else gg

    def solve(self, multiplier: float, lmbda: Optional[float] = None, min_fisher_value: float = 1e-9):
        """Return the reweighted offset.

        Args:
            multiplier: Scale applied to ``g`` to form the offset target.
            lmbda: Weight in [0, 1] associated with the component's fisher ``f``;
                ``1 - lmbda`` weights ``gg``. ``None`` is treated specially and
                returns ``multiplier * g`` directly. That is equivalent, up to
                floating-point rounding, to ``lmbda == 0``.
            min_fisher_value: Floor applied to ``f`` and ``gg`` so the
                denominator below is never zero.

        Returns:
            Elementwise ``(c2 * multiplier * g) / (c1 + c2)``, where
            ``c1 = lmbda * f`` and ``c2 = (1 - lmbda) * gg`` (both floored).

        Raises:
            AssertionError: if ``lmbda`` is not None and lies outside [0, 1].
        """
        if lmbda is None:
            return multiplier * self.g

        assert 0 <= lmbda <= 1

        f = np.maximum(self.f, min_fisher_value)
        gg = np.maximum(self.gg, min_fisher_value)

        # The first target (th1) is 0, so it contributes nothing to the numerator.
        c1 = lmbda * f

        th2 = multiplier * self.g
        c2 = (1 - lmbda) * gg

        return (c2 * th2) / (c1 + c2)
