import torch

from typing import List, Optional, Dict
from mmengine.registry import OPTIMIZERS, OPTIM_WRAPPERS
from mmengine.optim.optimizer import OptimWrapper


def register_custom_optimizers() -> List[str]:
    """Register the project's custom optimizers in the ``OPTIMIZERS`` registry.

    Every class is registered under its own ``__name__``, which is the key
    ``register_module(module=...)`` uses when no explicit name is given.
    When the name reported by this function differs from ``__name__`` (the
    ``COCOB`` class is reported as ``'COCOB_Backprop'``), the reported name
    is registered as an additional alias so that every string in the
    returned list actually resolves in the registry.

    Returns:
        List[str]: The reported optimizer names, in registration order.

    Raises:
        ImportError: If the ``optimizers`` package or one of the listed
            classes cannot be imported. The import happens before any
            registration, so nothing is registered in that case.
    """
    # Kept as a function-local import (as before) so that merely parsing
    # this module does not require the ``optimizers`` package.
    from optimizers import (
        COCOB,
        DAdaptAdam,
        DAdaptSGD,
        DoG,
        LDoG,
        Prodigy,
        PSDASGD,
        PSSps,
        Sps,
    )

    # (class, reported name) pairs; the order fixes the order of the result.
    optimizer_table = (
        (Sps, 'Sps'),
        (DoG, 'DoG'),
        (DAdaptSGD, 'DAdaptSGD'),
        (COCOB, 'COCOB_Backprop'),
        (DAdaptAdam, 'DAdaptAdam'),
        (LDoG, 'LDoG'),
        (Prodigy, 'Prodigy'),
        (PSSps, 'PSSps'),
        (PSDASGD, 'PSDASGD'),
    )

    custom_optimizers = []
    for optimizer_cls, reported_name in optimizer_table:
        # Always keep the class's own name as a key (the previous behavior);
        # add the reported name only when it is a genuinely different key,
        # because registering the same key twice is rejected.
        registry_names = [optimizer_cls.__name__]
        if reported_name != optimizer_cls.__name__:
            registry_names.append(reported_name)
        OPTIMIZERS.register_module(name=registry_names, module=optimizer_cls)
        custom_optimizers.append(reported_name)

    return custom_optimizers

# Registration runs once, as a side effect of importing this module; the
# names returned by the registration function are kept for later reference.
CUSTOM_OPTIMIZERS = register_custom_optimizers()


class PassLossOptimWrapper(OptimWrapper):
    """Optimizer wrapper whose ``update_params`` hands the loss to ``step``.

    The overridden :meth:`update_params` calls ``self.step`` with an extra
    ``loss`` keyword argument, so whatever ``step`` forwards its keyword
    arguments to must accept ``loss``.
    """

    def update_params(  # type: ignore
            self,
            loss: torch.Tensor,
            step_kwargs: Optional[Dict] = None,
            zero_kwargs: Optional[Dict] = None) -> None:
        """Back-propagate ``loss`` and, when an update is due, step with it.

        Args:
            loss (torch.Tensor): A tensor for back propagation.
            step_kwargs (dict): Extra keyword arguments for ``self.step``.
                Defaults to None, meaning no extra arguments. It must not
                contain a ``loss`` key, because ``loss`` is always passed
                explicitly and a duplicate keyword raises ``TypeError``.
            zero_kwargs (dict): Keyword arguments for ``self.zero_grad``.
                Defaults to None, meaning no extra arguments.
        """
        step_args = {} if step_kwargs is None else step_kwargs
        zero_args = {} if zero_kwargs is None else zero_kwargs

        scaled_loss = self.scale_loss(loss)
        self.backward(scaled_loss)

        # `should_update()` (inherited) decides whether this call should
        # trigger a parameter update; otherwise only the backward pass runs.
        if not self.should_update():
            return

        # The value handed to `step` is the result of `scale_loss`, not the
        # tensor originally passed in by the caller.
        self.step(loss=scaled_loss, **step_args)
        self.zero_grad(**zero_args)


def register_custom_optim_wrapper() -> List[str]:
    """Register the custom optimizer wrapper in ``OPTIM_WRAPPERS``.

    Returns:
        List[str]: The names of the wrappers registered by this call;
            currently the single entry ``'PassLossOptimWrapper'``.
    """
    OPTIM_WRAPPERS.register_module(module=PassLossOptimWrapper)
    return ['PassLossOptimWrapper']
    
# Registration runs once, as a side effect of importing this module; the
# names returned by the registration function are kept for later reference.
CUSTOM_OPTIM_WRAPPERS = register_custom_optim_wrapper()
