import torch
from torch import Tensor
from torch.optim import Optimizer
from typing import List, Optional


class CSGO(Optimizer):
    """
    Implements clipped stochastic gradient descent with a carried-over remainder.

    Each step normalises the gradient by a running RMS of past gradients and
    clips the normalised gradient to ``[-clip, clip]``. The clipped part is
    applied immediately. The clipped-off remainder is added to a per-parameter
    accumulator; every step a ``friction`` fraction of that accumulator is
    applied and removed, and the accumulator is then multiplied by ``decay``.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate. May be left as ``None`` here only if every param
            group supplies its own ``lr``; otherwise ``step`` raises.
        friction: fraction of the accumulated remainder released per step.
        clip: bound on the magnitude of the normalised gradient applied directly.
        decay: multiplicative factor applied to the accumulator after each step.
        verbose: if True, print per-parameter buffer statistics on every step
            (debugging aid, off by default).
    """

    def __init__(self, params, lr=None, friction=0.01, clip=4.00, decay=0.99,
                 verbose=False):
        if lr is not None and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if clip < 0.0:
            raise ValueError("Invalid clip value: {}".format(clip))
        if friction < 0.0:
            raise ValueError("Invalid friction value: {}".format(friction))
        if decay < 0.0:
            raise ValueError("Invalid decay value: {}".format(decay))

        defaults = dict(lr=lr, clip=clip, friction=friction, decay=decay,
                        verbose=verbose)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Groups restored from state saved without a `verbose` entry stay quiet.
        for group in self.param_groups:
            group.setdefault('verbose', False)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns the
                loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            ValueError: if a parameter group has no learning rate.
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group['lr'] is None:
                raise ValueError("Invalid learning rate: None (set lr for every param group)")

            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            square_averages_list = []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError("Sparse gradient not supported.")

                params_with_grad.append(p)
                d_p_list.append(p.grad)

                state = self.state[p]
                # Missing buffers are passed as None; sgd() creates them on first use.
                momentum_buffer_list.append(state.get('momentum_buffer'))
                square_averages_list.append(state.get('square_buffer'))

            if not params_with_grad:
                continue

            sgd(params_with_grad,
                d_p_list,
                momentum_buffer_list,
                square_averages_list,
                lr=group['lr'],
                clip=group['clip'],
                friction=group['friction'],
                decay=group['decay'],
            )

            # sgd() fills the None slots with new buffers; persist them in state.
            verbose = group.get('verbose', False)
            for p, momentum_buffer, square_buffer in zip(params_with_grad, momentum_buffer_list, square_averages_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer
                state['square_buffer'] = square_buffer
                if verbose:
                    scaling = 1/(square_buffer+1e-5)
                    print(f"{str(p.shape):<40} {float(momentum_buffer.mean()):<10.3f} {float(momentum_buffer.std()):<10.3f} {float(scaling.mean()):<10.2f} {float(scaling.std()):<10.2f}")

        return loss


def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        square_avgs: List[Optional[Tensor]],
        lr: float,
        clip: float,
        friction: float,
        decay: float,
    ):
    r"""Functional API for the CSGO update rule.

    Updates every tensor in ``params`` in place. ``None`` entries in
    ``momentum_buffer_list`` and ``square_avgs`` are replaced with newly
    created buffers so the caller can store them for the next step.

    This is not plain :class:`~torch.optim.SGD`: gradients are RMS-normalised
    and clipped before being applied. All work is delegated to
    :func:`_single_tensor_sgd`.
    """
    hyperparams = dict(lr=lr, clip=clip, friction=friction, decay=decay)
    _single_tensor_sgd(params, d_p_list, momentum_buffer_list, square_avgs,
                       **hyperparams)


def _single_tensor_sgd(params: List[Tensor],
                       d_p_list: List[Tensor],
                       momentum_buffer_list: List[Optional[Tensor]],
                       square_buffer_list: List[Tensor],
                       lr: float,
                       clip: float,
                       friction: float = 0.01,
                       decay: float = 1.0,
                       ):

    for i, param in enumerate(params):

        square_avg = square_buffer_list[i]
        if square_avg is None:
            square_avg = torch.zeros_like(param, requires_grad=False)
            square_buffer_list[i] = square_avg

        correction_step = d_p_list[i]

        beta = 0.99
        eps = 1e-5
        square_avg.mul_(beta).addcmul_(correction_step, correction_step, value=1-beta)
        avg = square_avg.sqrt().add_(eps)

        correction_step /= avg

        g_head = torch.clip(correction_step, -clip, clip)
        g_tail = correction_step - g_head

        accumulator = momentum_buffer_list[i]

        if accumulator is None:
            accumulator = torch.clone(g_tail).detach()
            momentum_buffer_list[i] = accumulator
        else:
            accumulator += g_tail

        # clipped correction step.
        param -= lr * g_head

        # would be better to transfer up to some fixed length? not 0.01
        # we already know next momentum step, so apply it now
        delta = accumulator * friction
        param -= delta * lr
        accumulator -= delta
        accumulator *= decay
