import torch

from torch.optim import Optimizer
from bitsandbytes.optim.optimizer import Optimizer1State

class LARS(Optimizer1State):
    def __init__(self, params, lr, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, optim_bits=32, args=None,
                 min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
        if momentum == 0:
            raise NotImplementError(f'LARS without momentum is not supported!')
        super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)

class LARS8bit(Optimizer1State):
    def __init__(self, params, lr, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, args=None,
                 min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
        if momentum == 0:
            raise NotImplementError(f'LARS without momentum is not supported!')
        super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
                weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)

class LARS32bit(Optimizer1State):
    def __init__(self, params, lr, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, args=None,
                 min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
        if momentum == 0:
            raise NotImplementError(f'LARS without momentum is not supported!')
        super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
                weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)


class PytorchLARS(Optimizer):
    def __init__(self, params, lr=0.01, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, max_unorm=0.02):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(PytorchLARS, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PytorchLARS, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            max_unorm = group['max_unorm']
            lr = group['lr']

            for p in group['params']:
                if p.grad is None: continue

                state = self.state[p]
                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(param, alpha=weight_decay)

                if momentum != 0:
                    buf = state.get('momentum_buffer', None)

                    if buf is None:
                        buf = torch.clone(d_p).detach()
                        state['momentum_buffer']= buf
                    else:
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

                    if nesterov:
                        update = d_p + buf*momentum
                    else:
                        update = buf

                update_scale = 1.0
                if max_unorm > 0.0:
                    assert p.dtype == torch.float32
                    pnorm = torch.norm(p.detach())
                    unorm = torch.norm(update)
                    if unorm > max_unorm*pnorm:
                        update_scale = max_unorm*pnorm/unorm

                p.add_(update, alpha=-lr*update_scale)

        return loss
