import torch
from torch.optim.optimizer import Optimizer, required

class AGDMizer(Optimizer):
    r"""Implements my own version of accelerated GD for convex functions.

    Two acceleration modes are available, selected per parameter group:

    * ``momentum != 0``: fixed-size momentum acceleration, i.e.
      ``y = x - lr * grad(x); x = y + momentum * (y - yprev); yprev = y``.
    * ``momentum == 0``: interpolation between a proximal step and a mirror
      step, with interpolation weight ``tau = 2 / (step + 2)``.

    Parameters
    ----------
        params (iterable): iterable of parameters to optimize or dicts
                defining parameter groups
        lr (float, required): learning rate
        momentum (float, optional): if positive, fixed size momentum will
                be used for acceleration and interpolating acceleration will
                not be used (default 0)
        prox_oper (function, optional): if specified, this function will be used
                to compute the proximal step. It is expected to take arguments of
                the form (current point, gradient, step size). It is expected to
                perform modifications in-place and not return a value
                (default: plain gradient step)
        mirror_oper (function, optional): if specified, this function will be used
                to compute the mirror step. Similar format of (current point,
                gradient, step size) and it is expected to perform operations in place
                (default: plain gradient step)

    Raises
    ------
        ValueError: if ``lr`` or ``momentum`` is negative.

    """

    def __init__(self, params, lr=required, momentum=0.0,
                 prox_oper=None, mirror_oper=None):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum: {}".format(momentum))

        # Both operators fall back to an ordinary in-place gradient step.
        if prox_oper is None:
            prox_oper = AGDMizer._gradient_step
        if mirror_oper is None:
            mirror_oper = AGDMizer._gradient_step

        defaults = dict(lr=lr, momentum=momentum,
                        prox_oper=prox_oper, mirror_oper=mirror_oper)
        super().__init__(params, defaults)

    @staticmethod
    def _gradient_step(t, grad_f, alpha):
        """Default prox/mirror operator: in-place ``t <- t - alpha * grad_f``."""
        t.add_(grad_f, alpha=-alpha)

    def __setstate__(self, state):
        """Restore optimizer state, filling in group options that are missing."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('prox_oper', AGDMizer._gradient_step)
            group.setdefault('mirror_oper', AGDMizer._gradient_step)
            group.setdefault('momentum', 0.0)

    def step(self, closure=None):
        r"""Performs a single optimization step.

        Parameters without a gradient are skipped.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # The closure typically calls backward(), so make sure autograd
            # is on even if the caller invoked step() under no_grad.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            p_op = group['prox_oper']
            m_op = group['mirror_oper']
            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "AGDMizer does not support sparse gradients, yet.")

                param_state = self.state[p]
                # Per parameter specification of step allows per parameter restarting
                if 'step' not in param_state:
                    param_state['step'] = 0

                # Stores the past auxiliary step. Either the mirror step or momentum
                if 'm_buffer' not in param_state:
                    param_state['m_buffer'] = torch.zeros_like(
                        p, memory_format=torch.preserve_format)

                m_buffer = param_state['m_buffer']
                actual_param = p.data

                # On the first step (or right after a restart) initialize the
                # auxiliary variable for the active mode.
                if param_state['step'] < 1:
                    if momentum != 0.0:
                        # The buffer holds momentum * (y - yprev), which is
                        # zero before any step has been taken.
                        m_buffer.zero_()
                    else:
                        # The mirror iterate starts at the current point. Copy
                        # from the detached data so the state buffer never
                        # becomes part of an autograd graph.
                        m_buffer.copy_(actual_param)

                # Perform current step
                step = param_state['step'] = param_state['step'] + 1
                if momentum != 0.0:
                    # Perform fixed, constant size momentum acceleration
                    m_buffer.add_(grad, alpha=-lr)
                    m_buffer.mul_(momentum)
                    # m_buffer stores momentum * (m_buffer + gradient step)
                    actual_param.add_(m_buffer)
                    actual_param.add_(grad, alpha=-lr)

                    # Functionally equivalent to
                    # y = x - step_size * grad(x)
                    # x = y + momentum * (y - yprev)
                    # yprev = y

                else:
                    # Perform prox-mirror interpolation-based acceleration
                    # Mirror step is calculated using past tau
                    m_op(m_buffer, grad, lr * (step + 1.0) / 2.0)
                    # The linear combination of proximal and mirror step is computed
                    tau = 2.0 / (step + 2.0)

                    # Proximal step and update the iterate with the linear combination
                    # of proximal and mirror steps
                    p_op(actual_param, grad, lr)
                    actual_param.mul_(1 - tau)
                    actual_param.add_(m_buffer, alpha=tau)

        return loss

    def set_final_iterate(self, closure):
        """Convenience function to set the final iterate (extra prox step).
           Will zero gradients and expects a closure to evaluate the loss.
           The closure must return a value on which ``backward()`` can be called.
        """
        self.zero_grad()
        closure().backward()
        self.exec_prox_step()

    def exec_prox_step(self, lr=None):
        """Convenience function to perform prox step on every parameter.

        Arguments:
            lr (float, optional): step size to use; defaults to each group's
                own learning rate.

        Parameters without a gradient are skipped.
        """
        for group in self.param_groups:
            step_size = group['lr'] if lr is None else lr
            for param in group['params']:
                if param.grad is None:
                    continue
                group['prox_oper'](param.data, param.grad.data, step_size)

    def exec_mirror_step(self, lr=None):
        """Convenience function to perform mirror step on every parameter.

        The mirror operator is applied directly to the parameter data.

        Arguments:
            lr (float, optional): step size to use; defaults to each group's
                own learning rate.

        Parameters without a gradient are skipped.
        """
        for group in self.param_groups:
            step_size = group['lr'] if lr is None else lr
            for param in group['params']:
                if param.grad is None:
                    continue
                group['mirror_oper'](param.data, param.grad.data, step_size)

    def reset_all_steps(self):
        """Convenience function to reset all step for all parameters.

        The next call to ``step`` re-initializes each parameter's auxiliary
        buffer, which restarts the acceleration.
        """
        for group in self.param_groups:
            for param in group['params']:
                self.state[param]['step'] = 0
