from typing import List, Optional

import torch
from torch import Tensor
from torch.optim import Optimizer

class bort(Optimizer):
    """
    Implements restricted stochastic gradient descent.

    This is SGD (optional momentum, dampening, Nesterov momentum and weight
    decay) whose gradient is augmented, per parameter, with
    ``weight_constraint * (W W^T W - W)``, where ``W`` is the parameter
    flattened to ``(size(0), -1)``. The extra term is zero when
    ``W^T W = I``.

    Supported PyTorch modules: nn.Linear / nn.Conv2d.
    :math: s.t. U^T U = I

    Args:
        params (iterable): parameters
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (default: 0)
        weight_constraint (float, optional): weight constraint (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Raises:
        ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is negative,
            or if ``nesterov`` is set without a positive momentum and zero
            dampening.
    """
    def __init__(
        self,
        params,
        lr: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        weight_constraint: float = 0.0,
        dampening: float = 0.0,
        nesterov: bool = False
    ):
        # Reject hyper-parameters that can only come from a caller mistake,
        # mirroring the checks performed by torch.optim.SGD.
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(lr=lr, momentum=momentum,
                        weight_decay=weight_decay,
                        weight_constraint=weight_constraint,
                        dampening=dampening, nesterov=nesterov)
        super().__init__(params, defaults)

    def __setstate__(self, state: dict):
        """Restore optimizer state, defaulting ``nesterov`` to False for
        parameter groups that do not carry that key."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): a closure that re-evaluates the
                model and returns the loss. It is called with gradients
                enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []

            # Collect only the parameters that received a gradient, together
            # with their (possibly not yet created) momentum buffers.
            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                momentum_buffer_list.append(self.state[p].get('momentum_buffer'))

            if not params_with_grad:
                # Nothing to update in this group.
                continue

            restricted_sgd(params_with_grad,
                           d_p_list,
                           momentum_buffer_list,
                           weight_decay=group['weight_decay'],
                           weight_constraint=group['weight_constraint'],
                           momentum=group['momentum'],
                           lr=group['lr'],
                           dampening=group['dampening'],
                           nesterov=group['nesterov'])

            # update momentum_buffers in state (the update call fills in
            # buffers that were None on entry)
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                self.state[p]['momentum_buffer'] = momentum_buffer

        # Hand the closure's loss back to the caller, as Optimizer.step
        # implementations are expected to do.
        return loss


def restricted_sgd(
    params: List[Tensor],
    d_p_list: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    *,
    weight_decay: float,
    weight_constraint: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool
) -> None:
    """Apply one restricted SGD update to ``params`` in place.

    For each parameter the effective gradient is built from ``d_p_list[i]``
    plus, when enabled, ``weight_constraint * (W W^T W - W)`` (with ``W`` the
    parameter flattened to ``(size(0), -1)``) and ``weight_decay * param``.
    Momentum / Nesterov momentum is then applied and the parameter is moved
    by ``-lr`` times the result.

    Args:
        params: tensors to update; each is modified in place.
        d_p_list: gradients, aligned index-by-index with ``params``. These
            tensors are not modified.
        momentum_buffer_list: per-parameter momentum buffers, aligned with
            ``params``. A ``None`` entry is replaced by a freshly created
            buffer when ``momentum != 0``; existing buffers are updated in
            place.
        weight_decay: coefficient of the ``param`` term added to the gradient.
        weight_constraint: coefficient of the ``W W^T W - W`` term.
        momentum: momentum factor; ``0`` disables momentum.
        lr: learning rate.
        dampening: dampening applied to the gradient when updating an
            existing momentum buffer.
        nesterov: use Nesterov momentum instead of the plain buffer.

    Note:
        Parameters are updated with in-place ops; ``bort.step`` invokes this
        function under ``torch.no_grad()``.
    """
    for i, param in enumerate(params):

        d_p = d_p_list[i]
        if weight_constraint != 0:
            # NOTE: restricted weight update!
            # ``reshape`` rather than ``view``: it is identical for contiguous
            # parameters but also handles non-contiguous ones (e.g. Conv2d
            # weights in channels_last memory format), where ``view`` raises.
            flat = param.reshape(param.size(0), -1)
            flat_t = flat.t()
            # (W^T W W^T - W^T)^T == W W^T W - W, shaped back like the param.
            penalty = (flat_t @ flat @ flat_t - flat_t).t().reshape(param.size())
            d_p = d_p.add(penalty, alpha=weight_constraint)
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                # First step for this parameter: seed the buffer with the
                # current (undampened) gradient.
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)
