import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import optimizer


class RGD(optimizer.Optimizer):
    """RGD optimizer (presumably "relativistic gradient descent"), TF1 API.

    Keeps two slots per variable: a velocity ``v_k`` (initialized to zeros)
    and a stored full-step position ``x_k`` (initialized to the variable's
    value). Position updates divide the velocity by
    ``sqrt(delta * ||v||^2 + 1)``, where ``||v||`` is the L2 norm over all
    elements of the variable. This bounds the step length.

    Args:
        learning_rate: Step size multiplying the gradient in the velocity update.
        momentum: Velocity decay factor. The "leapfrog" integrator applies
            ``sqrt(momentum)`` twice per step.
        delta: Weight of the squared velocity norm in the normalizer. With
            ``delta=0`` the position update is not normalized.
        integrator: Either "symplectic_euler" or "leapfrog".
        alpha: Leapfrog only. Blend weight between the current position and
            the stored full-step position ``x_k``. ``alpha=1`` disables blending.
        use_locking: Forwarded to the base optimizer and used for the assigns.
        name: Name passed to the base optimizer.

    Raises:
        ValueError: If ``integrator`` is not a supported value.
    """

    def __init__(
        self,
        learning_rate,
        momentum,
        delta,
        integrator="leapfrog",
        alpha=1,
        use_locking=False,
        name="RGD",
    ):
        super().__init__(use_locking, name)
        self._lr = learning_rate
        self._momentum = momentum
        self._delta = delta
        self._alpha = alpha

        if integrator not in ["symplectic_euler", "leapfrog"]:
            raise ValueError(
                "`integrator` must be either 'symplectic_euler' or 'leapfrog'"
            )
        self._integrator = integrator

    def _prepare(self):
        """Convert the hyperparameters to tensors once per `apply_gradients`."""
        self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
        self._momentum_t = ops.convert_to_tensor(self._momentum, name="momentum")
        self._delta_t = ops.convert_to_tensor(self._delta, name="delta")
        self._alpha_t = ops.convert_to_tensor(self._alpha, name="alpha")

    def _create_slots(self, var_list):
        """Create the velocity slot (zeros) and position slot (copy of var)."""
        for v in var_list:
            self._zeros_slot(v, "v_k", self._name)
            self._get_or_make_slot(v, tf.identity(v), "x_k", self._name)

    def _apply_dense(self, grad, var_ref):
        """Build the update ops for one dense variable.

        Args:
            grad: Gradient of the loss with respect to ``var_ref``.
            var_ref: The variable being optimized.

        Returns:
            A grouped op that runs all the variable and slot assignments.
        """
        v_k_ref = self.get_slot(var_ref, "v_k")
        x_k_ref = self.get_slot(var_ref, "x_k")
        dtype = var_ref.dtype.base_dtype
        locking = self._use_locking

        lr = math_ops.cast(self._lr_t, dtype)
        momentum = math_ops.cast(self._momentum_t, dtype)
        delta = math_ops.cast(self._delta_t, dtype)

        if self._integrator == "symplectic_euler":
            # v_{k+1} = momentum * v_k - lr * g_k
            v_new = momentum * v_k_ref - lr * grad

            # x_{k+1} = x_k + v_{k+1} / sqrt(delta * ||v_{k+1}||^2 + 1)
            norm_factor = tf.sqrt(delta * tf.math.reduce_sum(tf.square(v_new)) + 1)
            var_new = var_ref + v_new / norm_factor

            # The x_k slot is not used by this integrator, so it is not touched.
            return control_flow_ops.group(
                var_ref.assign(var_new, use_locking=locking),
                v_k_ref.assign(v_new, use_locking=locking),
            )

        # integrator == "leapfrog"
        sqrt_momentum = tf.sqrt(momentum)

        # v_{k+1/2} = sqrt(momentum) * v_k - lr * g_k
        v_half = sqrt_momentum * v_k_ref - lr * grad

        # Decide in Python whether blending is needed. A `Tensor != 1` check
        # is identity-based (always True) in legacy graph mode and cannot
        # drive a Python `if` inside tf.function. When alpha is not known to
        # be exactly 1, always blend; the formula is exact for any alpha.
        blend = not (isinstance(self._alpha, (int, float)) and self._alpha == 1)

        # x_{k+1} = alpha * x_{k+1/2} + (1-alpha) * x_k
        #           + v_{k+1/2} / sqrt(delta * ||v_{k+1/2}||^2 + 1)
        position = var_ref
        if blend:
            alpha = math_ops.cast(self._alpha_t, dtype)
            position = alpha * var_ref + (1 - alpha) * x_k_ref
        norm_factor = tf.sqrt(delta * tf.math.reduce_sum(tf.square(v_half)) + 1)
        x_full = position + v_half / norm_factor

        # v_{k+1} = sqrt(momentum) * v_{k+1/2}
        v_new = sqrt_momentum * v_half

        # x_{k+3/2} = x_{k+1} + sqrt(momentum) *
        #       v_{k+1} / sqrt(momentum * delta * ||v_{k+1}||^2 + 1)
        norm_factor = tf.sqrt(
            momentum * delta * tf.math.reduce_sum(tf.square(v_new)) + 1
        )
        var_new = x_full + v_new * sqrt_momentum / norm_factor

        updates = [
            var_ref.assign(var_new, use_locking=locking),
            v_k_ref.assign(v_new, use_locking=locking),
        ]
        if blend:
            # Remember the full-step position; it is only read when blending.
            updates.append(x_k_ref.assign(x_full, use_locking=locking))
        return control_flow_ops.group(*updates)


class PowerDescent(optimizer.Optimizer):
    """Momentum optimizer with a power-law velocity rescaling, TF1 API.

    With ``a = little_a`` and ``A = big_a``, each step computes::

        v_{k+1} = momentum * v_k - lr * g_k
        x_{k+1} = x_k + v_{k+1} * ||v_{k+1}||^(a-2)
                        * (delta * ||v_{k+1}||^a + 1)^(A/a - 1)

    Here ``||.||`` is the L2 norm over all elements of the variable. With
    ``a=2, A=1`` the multiplier reduces to ``1 / sqrt(delta * ||v||^2 + 1)``,
    the same normalizer ``RGD`` uses. With ``a=2, A=2`` it is 1 (plain
    momentum).

    Args:
        learning_rate: Step size multiplying the gradient.
        momentum: Velocity decay factor.
        delta: Weight of ``||v||^a`` inside the power term.
        little_a: Exponent ``a``. It must be non-zero because it appears in
            ``A / a``.
        big_a: Exponent ``A``.
        use_locking: Forwarded to the base optimizer and used for the assigns.
        name: Name passed to the base optimizer.
    """

    def __init__(
        self,
        learning_rate,
        momentum,
        delta,
        little_a,
        big_a,
        use_locking=False,
        name="PowerDescent",
    ):
        super().__init__(use_locking, name)
        self._lr = learning_rate
        self._momentum = momentum
        self._delta = delta
        self._little_a = little_a
        self._big_a = big_a

    def _prepare(self):
        """Convert the hyperparameters to tensors once per `apply_gradients`."""
        self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
        self._momentum_t = ops.convert_to_tensor(self._momentum, name="momentum")
        self._delta_t = ops.convert_to_tensor(self._delta, name="delta")
        self._little_a_t = ops.convert_to_tensor(self._little_a, name="little_a")
        self._big_a_t = ops.convert_to_tensor(self._big_a, name="big_a")

    def _create_slots(self, var_list):
        """Create the velocity slot (zeros) and a position slot (copy of var).

        The ``x_k`` slot is never read by this optimizer. It is kept so the
        optimizer's slot names stay the same as before.
        """
        for v in var_list:
            self._zeros_slot(v, "v_k", self._name)
            self._get_or_make_slot(v, tf.identity(v), "x_k", self._name)

    def _apply_dense(self, grad, var_ref):
        """Build the update ops for one dense variable.

        Args:
            grad: Gradient of the loss with respect to ``var_ref``.
            var_ref: The variable being optimized.

        Returns:
            A grouped op that assigns the new variable value and velocity.
        """
        v_k_ref = self.get_slot(var_ref, "v_k")
        dtype = var_ref.dtype.base_dtype
        locking = self._use_locking

        lr = math_ops.cast(self._lr_t, dtype)
        momentum = math_ops.cast(self._momentum_t, dtype)
        delta = math_ops.cast(self._delta_t, dtype)
        little_a = math_ops.cast(self._little_a_t, dtype)
        big_a = math_ops.cast(self._big_a_t, dtype)

        # v_{k+1} = momentum * v_k - lr * g_k
        v_new = momentum * v_k_ref - lr * grad

        # x_{k+1} = x_k + v_{k+1} * ||v_{k+1}||^(a-2) * (delta * ||v_{k+1}||^a + 1)^(A/a-1)
        # The scale MULTIPLIES the velocity. Dividing by it would turn the
        # a=2, A=1 case into v * sqrt(delta*||v||^2 + 1) instead of the
        # bounded RGD step.
        norm_v = tf.norm(v_new)
        # At v == 0, ||v||^(a-2) is 0 or inf and the product can become NaN.
        # The velocity is all zeros there, so any finite scale gives the
        # correct zero step; substitute norm 1 to keep the scale finite.
        safe_norm = tf.where(norm_v > 0, norm_v, tf.ones_like(norm_v))
        scale = safe_norm ** (little_a - 2) * (delta * safe_norm ** little_a + 1) ** (
            big_a / little_a - 1
        )
        var_new = var_ref + v_new * scale

        return control_flow_ops.group(
            var_ref.assign(var_new, use_locking=locking),
            v_k_ref.assign(v_new, use_locking=locking),
        )