from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import optimizer
import tensorflow as tf
class DemonSGDM(optimizer.Optimizer):
    """SGD with momentum whose momentum follows the Demon decay schedule.

    At step ``t`` (a counter starting at 1 and incremented once per
    ``apply_gradients`` call) the effective momentum is::

        z      = max(0, (iterations - t) / iterations)
        beta_t = momentum * z / ((1 - momentum) + momentum * z)

    and every variable is updated as::

        m   <- beta_t * m + learning_rate * grad
        var <- var - m

    Args:
        iterations: Total number of iterations T over which the momentum
            decays from ``momentum`` down to 0. Steps beyond T use zero
            momentum.
        learning_rate: Step size (Python number or tensor).
        momentum: Initial momentum value (Python number or tensor).
        use_locking: If True, use locking for the update assignments.
        name: Name prefix for the operations created by this optimizer.
    """

    def __init__(self, iterations, learning_rate=0.1, momentum=0.9,
                 use_locking=False, name="DemonCM"):
        super(DemonSGDM, self).__init__(use_locking, name)
        self._lr = learning_rate
        self._momentum = momentum
        self._iterations = iterations
        # Global step counter t used by the momentum schedule; incremented
        # in _finish after all variable updates of a step have run.
        self.t = tf.Variable(1.0, trainable=False)

        self._lr_t = None
        self._momentum_t = None

    def _prepare(self):
        self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
        self._momentum_t = ops.convert_to_tensor(self._momentum, name="momentum_t")

    def _create_slots(self, var_list):
        # One zero-initialised momentum accumulator "m" per variable.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)

    def _demon_momentum(self, dtype):
        """Return the decayed momentum beta_t for the current step, as `dtype`."""
        momentum_t = math_ops.cast(self._momentum_t, dtype)
        t = math_ops.cast(self.t, dtype)
        total = math_ops.cast(self._iterations, dtype)
        # Fraction of the schedule still remaining. Clamped at 0 so that
        # training past `iterations` yields zero momentum rather than a
        # negative momentum and, eventually, a zero denominator below.
        z = math_ops.maximum((total - t) / total, 0.0)
        return momentum_t * (z / ((1.0 - momentum_t) + momentum_t * z))

    def _apply_momentum_update(self, grad, var):
        """Shared dense Demon-momentum update for ref and resource variables."""
        dtype = var.dtype.base_dtype
        lr_t = math_ops.cast(self._lr_t, dtype)
        cur_momentum = self._demon_momentum(dtype)

        m = self.get_slot(var, "m")
        m_t = m.assign(m * cur_momentum + lr_t * grad,
                       use_locking=self._use_locking)
        var_update = state_ops.assign_sub(var, m_t,
                                          use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t])

    def _apply_dense(self, grad, var):
        # Called by the base Optimizer for reference (non-resource) variables.
        return self._apply_momentum_update(grad, var)

    def _resource_apply_dense(self, grad, handle):
        # Called by the base Optimizer for resource variables.
        return self._apply_momentum_update(grad, handle)

    def _apply_resource_dense(self, grad, var):
        """Backward-compatible alias; the base class calls `_resource_apply_dense`."""
        return self._apply_momentum_update(grad, var)

    def _apply_sparse(self, grad, var):
        raise NotImplementedError("Sparse gradient updates are not supported.")

    def _resource_apply_sparse(self, grad, handle, indices):
        raise NotImplementedError("Sparse gradient updates are not supported.")

    def _finish(self, update_ops, name_scope):
        """Advance the step counter after all variable updates have run.

        Args:
          update_ops: List of `Operation` objects returned by the per-variable
            `_apply_dense` / `_resource_apply_dense` calls.
          name_scope: String. Name to use for the returned operation.
        Returns:
          A grouped operation that applies all updates and increments `t`.
        """
        # The increment must run after every update op so that all variables
        # in this step see the same value of t; without the dependency the
        # graph may execute the increment first.
        with ops.control_dependencies(update_ops):
            t_update = self.t.assign_add(1.0, use_locking=self._use_locking)

        return control_flow_ops.group(*update_ops + [t_update], name=name_scope)
      
