import torch
from torch.optim.optimizer import Optimizer, required
from math import sqrt

class MDMizer(Optimizer):
    r"""Mirror-descent optimizer that also tracks a running average of iterates.

    Each call to :meth:`step` applies ``mirror_oper`` to every parameter that
    has a gradient, then folds the updated value into a per-parameter running
    mean stored in ``state['m_buffer']``. The mean is seeded with the value the
    parameter had before its first step. After ``k`` steps it is therefore the
    uniform average of ``k + 1`` iterates. :meth:`set_final_iterate` copies
    that average back into the parameters.

    Parameters
    ----------
        params (iterable): iterable of parameters to optimize or dicts
                defining parameter groups
        lr (float, required): learning rate
        mirror_oper (callable, optional): function
                ``mirror_oper(t, grad_f, alpha)`` that updates the tensor ``t``
                IN PLACE from the gradient ``grad_f`` and step size ``alpha``.
                Its return value is ignored. Defaults to the Euclidean step
                ``t -= alpha * grad_f`` (plain gradient descent).
    """

    @staticmethod
    def _default_mirror_oper(t, grad_f, alpha):
        """Euclidean mirror step: ``t -= alpha * grad_f`` in place."""
        # A named function rather than a lambda: it is stored in param_groups,
        # and only named functions survive pickling of ``state_dict()``.
        return t.add_(grad_f, alpha=-alpha)

    def __init__(self, params, lr = required, mirror_oper = None):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))

        if mirror_oper is None:
            mirror_oper = MDMizer._default_mirror_oper

        defaults = dict(lr=lr, mirror_oper=mirror_oper)
        super().__init__(params, defaults)

        # Set by set_total_num_steps(); None until then.
        self.total_num_steps = None

    def __setstate__(self, state):
        super().__setstate__(state)
        # Older pickles may lack 'mirror_oper'; fall back to the default step.
        for group in self.param_groups:
            group.setdefault('mirror_oper', MDMizer._default_mirror_oper)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            m_op = group['mirror_oper']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "MDMizer does not support sparse gradients, yet.")

                param_state = self.state[p]
                # Per-parameter step counter, so reset_all_steps can restart
                # the averaging independently for each parameter.
                param_state.setdefault('step', 0)

                # Running average of iterates. Always copy from a detached
                # view: copying from a parameter that requires grad would make
                # the buffer a non-leaf whose autograd history grows with every
                # in-place update below.
                if 'm_buffer' not in param_state:
                    param_state['m_buffer'] = p.detach().clone(
                        memory_format=torch.preserve_format)
                elif param_state['step'] < 1:
                    # Restarted (step reset to 0): re-seed from current value.
                    param_state['m_buffer'].copy_(p.detach())
                m_buffer = param_state['m_buffer']

                actual_param = p.data
                param_state['step'] += 1
                # Mirror step; m_op is expected to mutate actual_param in place.
                m_op(actual_param, grad, lr)
                # Incremental mean: with tau = 1 / (step + 1) the buffer stays
                # the uniform average of the seed value and every iterate.
                tau = 1.0 / (param_state['step'] + 1.0)
                m_buffer.mul_(1.0 - tau)
                m_buffer.add_(actual_param, alpha=tau)

        return loss

    def set_total_num_steps(self, num_steps = None):
        """Record the total number of steps and rescale the learning rate.

        When ``num_steps`` is given, every group's ``lr`` is multiplied by
        ``1 / sqrt(num_steps + 1)``. The factor is applied to the *current*
        ``lr``, so calling this more than once compounds the scaling.

        Arguments:
            num_steps (int, optional): planned number of steps. ``None`` only
                resets :attr:`total_num_steps` and leaves ``lr`` unchanged.
        """
        self.total_num_steps = num_steps
        if num_steps is not None:
            for group in self.param_groups:
                group['lr'] *= 1.0 / sqrt(num_steps + 1)

    def set_final_iterate(self, closure=None):
        """Overwrite each parameter with its running average of iterates.

        Parameters that have no averaging buffer yet (never updated by
        :meth:`step`) are left unchanged instead of raising ``KeyError``.

        Arguments:
            closure: unused; kept for interface compatibility.
        """
        for group in self.param_groups:
            for param in group['params']:
                # .get avoids inserting empty entries into the defaultdict.
                m_buffer = self.state.get(param, {}).get('m_buffer')
                if m_buffer is not None:
                    param.data.copy_(m_buffer)

    def exec_mirror_step(self, lr = None):
        """Apply one mirror step to every parameter, without averaging.

        Unlike :meth:`step`, this neither advances the step counter nor updates
        the running average. Parameters without a gradient are skipped.

        Arguments:
            lr (float, optional): step size used for every group; defaults to
                each group's own ``lr``.
        """
        for group in self.param_groups:
            step_size = group['lr'] if lr is None else lr
            for param in group['params']:
                if param.grad is None:
                    continue
                group['mirror_oper'](param.data, param.grad.data, step_size)

    def reset_all_steps(self):
        """Reset the step counter of every parameter to 0.

        The next :meth:`step` then re-seeds each running average from the
        parameter's current value.
        """
        for group in self.param_groups:
            for param in group['params']:
                self.state[param]['step'] = 0
