from collections import OrderedDict

import numpy as np

import theano
import theano.tensor as T
from lasagne import utils
from lasagne.updates import get_or_compute_grads, apply_momentum

def tos(loss_or_grads, params, learning_rate=0.001, mu=0.001):
    """TOS updates

    Builds one three operator splitting step per parameter, using the step
    size ``learning_rate / sqrt(t)`` where ``t`` counts how many times the
    updates have been applied. Each step takes a gradient step, applies
    elementwise soft-thresholding with threshold ``mu * step_size`` and
    then shrinks groups of entries:

    * 2-D and 4-D parameters: one group per index along axis 0; a group is
      scaled by ``1 - threshold / norm`` when its Euclidean norm exceeds
      ``mu * step_size * sqrt(group size)`` and set to zero otherwise;
    * every other parameter: each entry is its own group with threshold
      ``mu * step_size``.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        Learning rate
    mu : float or symbolic scalar
        regularization parameter.

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression.
        It also holds the updates of the auxiliary shared variables created
        here (one per parameter, plus the shared step counter).

    References
    ----------
    .. [1] Yurtsever, Gu, and Sra (2021):
           Three Operator Splitting with Subgradients,
           Stochastic Gradients, and Adaptive Learning Rates.
           Accepted to NeurIPS 2021.
    """
    all_grads = get_or_compute_grads(loss_or_grads, params)
    updates = OrderedDict()

    # Step counter shared by all parameters; it drives the step size decay.
    t_prev = theano.shared(utils.floatX(0.))
    t = t_prev + T.constant(1)
    a_t = learning_rate / T.sqrt(t)
    mu_t = mu * a_t

    for param, g_t in zip(params, all_grads):
        value = param.get_value(borrow=True)
        ndim = value.ndim
        # Auxiliary splitting variable, same shape and dtype as the parameter.
        y_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)

        # Gradient step followed by elementwise soft-thresholding.
        x_t = 2.0 * param - y_prev - a_t * g_t
        x_t = T.sgn(x_t) * T.maximum(T.abs_(x_t) - mu_t, 0.0)
        y_t = y_prev - param + x_t

        if ndim in (2, 4):
            # One group per index along axis 0; the threshold is scaled by
            # the square root of the number of entries in a group.
            group_size = int(np.prod(value.shape[1:]))
            alpha_t = mu_t * T.sqrt(group_size)
            group_norms = T.sqrt(T.sum(y_t ** 2, axis=tuple(range(1, ndim))))
            factor = T.switch(T.gt(group_norms, alpha_t),
                              1.0 - alpha_t / group_norms,
                              0.0)  # (n,)
            # Broadcast the per-group factor over the remaining axes instead
            # of materialising a tiled copy.
            factor = factor.dimshuffle(0, *(['x'] * (ndim - 1)))
        else:
            # Every entry is its own group. Comparing against a scalar zero
            # keeps this branch valid for any rank, including 0-d.
            # NOTE(review): entries with y_t <= alpha_t, which includes all
            # negative entries, are set to zero; a symmetric shrinkage would
            # compare |y_t| instead. Confirm this is intended.
            alpha_t = mu_t
            factor = T.switch(T.gt(y_t - alpha_t, 0.0),
                              1.0 - alpha_t / y_t,
                              0.0)

        updates[y_prev] = y_t
        updates[param] = y_t * factor

    updates[t_prev] = t
    return updates

def adaptos(loss_or_grads, params, learning_rate=0.001, mu=0.001, alpha=0.001, epsilon=1e-6):
    """AdapTOS updates

    Builds one three operator splitting step per parameter with an adaptive
    step size: every parameter keeps a scalar accumulator of the squared
    norm of its gradients, and its step size is
    ``learning_rate / sqrt(accumulator + epsilon)``. Each step takes a
    gradient step, applies elementwise soft-thresholding with threshold
    ``mu * step_size`` and then shrinks groups of entries:

    * 2-D and 4-D parameters: one group per index along axis 0; a group is
      scaled by ``1 - threshold / norm`` when its Euclidean norm exceeds
      ``mu * step_size * sqrt(group size)`` and set to zero otherwise;
    * every other parameter: each entry is its own group with threshold
      ``mu * step_size``.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        Learning rate
    mu : float or symbolic scalar
        regularization parameter.
    alpha : float or symbolic scalar
        Not used by this function; accepted so existing callers keep
        working.
    epsilon : float or symbolic scalar
        Constant for numerical stability, added to the accumulated squared
        gradient norm before the square root so that the step size stays
        finite when the gradients are zero.

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression.
        It also holds the updates of the auxiliary shared variables created
        here (two per parameter).

    References
    ----------
    .. [1] Yurtsever, Gu, and Sra (2021):
           Three Operator Splitting with Subgradients,
           Stochastic Gradients, and Adaptive Learning Rates.
           Accepted to NeurIPS 2021.
    """
    all_grads = get_or_compute_grads(loss_or_grads, params)
    updates = OrderedDict()

    for param, g_t in zip(params, all_grads):
        value = param.get_value(borrow=True)
        ndim = value.ndim
        # Auxiliary splitting variable, same shape and dtype as the parameter.
        y_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)

        # Scalar accumulator of squared gradient norms: a single step size
        # is shared by all entries of this parameter.
        accu = theano.shared(utils.floatX(0.))
        accu_new = accu + T.sum(g_t ** 2)
        updates[accu] = accu_new
        # epsilon avoids a division by zero (and the resulting NaNs) when
        # the accumulated gradient norm is still zero.
        a_t = learning_rate / T.sqrt(accu_new + epsilon)
        mu_t = mu * a_t

        # Gradient step followed by elementwise soft-thresholding.
        x_t = 2.0 * param - y_prev - a_t * g_t
        x_t = T.sgn(x_t) * T.maximum(T.abs_(x_t) - mu_t, 0.0)
        y_t = y_prev - param + x_t

        if ndim in (2, 4):  # fc layer (2-D) or conv layer (4-D)
            # One group per index along axis 0; the threshold is scaled by
            # the square root of the number of entries in a group.
            group_size = int(np.prod(value.shape[1:]))
            alpha_t = mu_t * T.sqrt(group_size)
            group_norms = T.sqrt(T.sum(y_t ** 2, axis=tuple(range(1, ndim))))
            factor = T.switch(T.gt(group_norms, alpha_t),
                              1.0 - alpha_t / group_norms,
                              0.0)  # (n,)
            # Broadcast the per-group factor over the remaining axes instead
            # of materialising a tiled copy.
            factor = factor.dimshuffle(0, *(['x'] * (ndim - 1)))
        else:
            # Every entry is its own group. Comparing against a scalar zero
            # keeps this branch valid for any rank, including 0-d.
            # NOTE(review): entries with y_t <= alpha_t, which includes all
            # negative entries, are set to zero; a symmetric shrinkage would
            # compare |y_t| instead. Confirm this is intended.
            alpha_t = mu_t
            factor = T.switch(T.gt(y_t - alpha_t, 0.0),
                              1.0 - alpha_t / y_t,
                              0.0)

        updates[y_prev] = y_t
        updates[param] = y_t * factor

    return updates

def adaptos_with_momentum(loss_or_grads, params, learning_rate=0.001, mu=0.001, alpha=0.001, epsilon=1e-6, momentum=0.9):
    """AdapTOS updates with momentum

    Generates the update expressions of :func:`adaptos` and passes them
    through ``apply_momentum``.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        Forwarded to :func:`adaptos`.
    mu : float or symbolic scalar
        Forwarded to :func:`adaptos`.
    alpha : float or symbolic scalar
        Forwarded to :func:`adaptos`.
    epsilon : float or symbolic scalar
        Forwarded to :func:`adaptos`.
    momentum : float or symbolic scalar
        Forwarded to ``apply_momentum``.

    Returns
    -------
    OrderedDict
        The dictionary returned by ``apply_momentum``.
    """
    updates = adaptos(loss_or_grads, params, learning_rate=learning_rate,
                      mu=mu, alpha=alpha, epsilon=epsilon)
    # The update dictionary also contains the optimizer's auxiliary shared
    # variables; without an explicit `params`, apply_momentum would add
    # momentum to those as well.
    return apply_momentum(updates, params=params, momentum=momentum)

def tos_with_momentum(loss_or_grads, params, learning_rate=0.001, mu=0.001, momentum=0.9):
    """TOS updates with momentum

    Generates the update expressions of :func:`tos` and passes them through
    ``apply_momentum``.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        Forwarded to :func:`tos`.
    mu : float or symbolic scalar
        Forwarded to :func:`tos`.
    momentum : float or symbolic scalar
        Forwarded to ``apply_momentum``.

    Returns
    -------
    OrderedDict
        The dictionary returned by ``apply_momentum``.
    """
    # Forward the caller's hyper-parameters rather than fixed constants.
    updates = tos(loss_or_grads, params, learning_rate=learning_rate, mu=mu)
    # The update dictionary also contains the optimizer's auxiliary shared
    # variables (including the step counter); without an explicit `params`,
    # apply_momentum would add momentum to those as well.
    return apply_momentum(updates, params=params, momentum=momentum)