# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.utils import digit_version
from torch import Tensor

from mmdet.registry import MODELS

# Register the stock ``nn.Linear`` in the ``MODELS`` registry under the name
# 'Linear'. NOTE(review): presumably this lets configs select a plain linear
# layer by type name, interchangeably with 'NormedLinear' -- confirm against
# the configs that use it.
MODELS.register_module('Linear', module=nn.Linear)


@MODELS.register_module(name='NormedLinear')
class NormedLinear(nn.Linear):
    """Normalized Linear Layer.

    Each weight row and each input feature vector is divided by its norm
    (raised to ``power``, plus ``eps``) before the linear projection. The
    normalized input is additionally multiplied by ``tempearture``. The
    bias, if present, is added unchanged.

    Args:
        *args: Positional arguments forwarded to ``nn.Linear``.
        tempearture (float, optional): Temperature term multiplied onto
            the normalized input. The (misspelled) keyword is part of the
            public interface and is kept as is. Defaults to 20.
        power (float, optional): Power term applied to the norms.
            Defaults to 1.0.
        eps (float, optional): Small value added to the divisor to
            keep numerical stability. Defaults to 1e-6.
        **kwargs: Keyword arguments forwarded to ``nn.Linear``.
    """

    def __init__(self,
                 *args,
                 tempearture: float = 20,
                 power: float = 1.0,
                 eps: float = 1e-6,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.tempearture = tempearture
        self.power = power
        self.eps = eps
        # Overwrite the default ``nn.Linear`` initialization.
        self.init_weights()

    def init_weights(self) -> None:
        """Initialize the weights.

        The weight is drawn from a normal distribution (mean 0, std 0.01)
        and the bias, when it exists, is set to zero.
        """
        nn.init.normal_(self.weight, mean=0, std=0.01)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for `NormedLinear`.

        Args:
            x (Tensor): Input whose last dimension holds the features.

        Returns:
            Tensor: Result of ``F.linear`` applied to the normalized,
            temperature-scaled input with the normalized weight.
        """
        # One norm per weight row, i.e. per output feature.
        weight_ = self.weight / (
            self.weight.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
        # ``F.linear`` contracts the last axis of ``x``, so that is the axis
        # to normalize over. ``dim=-1`` equals ``dim=1`` for 2-D input and
        # also handles 1-D input (where ``dim=1`` is out of range) and
        # inputs with extra leading dimensions.
        x_ = x / (x.norm(dim=-1, keepdim=True).pow(self.power) + self.eps)
        x_ = x_ * self.tempearture

        return F.linear(x_, weight_, self.bias)


@MODELS.register_module(name='NormedConv2d')
class NormedConv2d(nn.Conv2d):
    """Normalized Conv2d Layer.

    The convolution weight and the input are divided by their norm
    (raised to ``power``, plus ``eps``) before the convolution. The
    normalized input is additionally multiplied by ``tempearture``.

    Args:
        *args: Positional arguments forwarded to ``nn.Conv2d``.
        tempearture (float, optional): Temperature term multiplied onto
            the normalized input. The (misspelled) keyword is part of the
            public interface and is kept as is. Defaults to 20.
        power (float, optional): Power term applied to the norms.
            Defaults to 1.0.
        eps (float, optional): Small value added to the divisor to
            keep numerical stability. Defaults to 1e-6.
        norm_over_kernel (bool, optional): If True, compute one weight
            norm per output channel over all remaining weight dimensions.
            If False, compute the weight norm over dimension 1 only.
            Defaults to False.
        **kwargs: Keyword arguments forwarded to ``nn.Conv2d``.
    """

    def __init__(self,
                 *args,
                 tempearture: float = 20,
                 power: float = 1.0,
                 eps: float = 1e-6,
                 norm_over_kernel: bool = False,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.tempearture = tempearture
        self.power = power
        self.norm_over_kernel = norm_over_kernel
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for `NormedConv2d`.

        Args:
            x (Tensor): Input feature map; it is normalized over
                dimension 1.

        Returns:
            Tensor: Convolution of the normalized, temperature-scaled
            input with the normalized weight.
        """
        if not self.norm_over_kernel:
            weight_norm = self.weight.norm(dim=1, keepdim=True)
        else:
            # ``reshape`` instead of ``view``: the weight may be
            # non-contiguous (e.g. after converting the module to the
            # channels_last memory format), in which case ``view`` raises.
            # The norm does not depend on element order, so the value is
            # the same.
            weight_norm = self.weight.reshape(self.weight.size(0), -1).norm(
                dim=1, keepdim=True)[..., None, None]
        weight_ = self.weight / (weight_norm.pow(self.power) + self.eps)

        x_ = x / (x.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
        x_ = x_ * self.tempearture

        # Dispatch to whichever convolution helper the parent class offers.
        # NOTE(review): ``conv2d_forward`` is presumably the helper name in
        # older torch releases -- confirm before removing this branch.
        if hasattr(self, 'conv2d_forward'):
            x_ = self.conv2d_forward(x_, weight_)
        else:
            # The bias is passed explicitly only for torch >= 1.8.
            if digit_version(torch.__version__) >= digit_version('1.8'):
                x_ = self._conv_forward(x_, weight_, self.bias)
            else:
                x_ = self._conv_forward(x_, weight_)
        return x_
