import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from pado.core import PadoModule
from pado.nn.parameter import ParameterModule

__all__ = ["Linear", "Identity"]


class Linear(PadoModule):
    """Fully connected layer that applies ``F.linear(x, weight, bias)``.

    The weight is created with shape ``(out_features, in_features)`` and the
    optional bias with shape ``(out_features,)``; both are wrapped in
    ``ParameterModule``.

    :param in_features:     number of input features
    :param out_features:    number of output features
    :param bias:            if False, no bias is created and none is added
    :param init_type:       keyword-only, case-insensitive weight init scheme:
                            "xavier" (default, for dense FC),
                            "attn" (for QKV) or
                            "proj" (for Proj, if out_dim is very small)
    :raises ValueError:     if ``init_type`` is not one of the schemes above
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 *, init_type="xavier"):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Storage is allocated uninitialized here; the values are filled in by
        # _initialize_parameters() at the end of construction.
        self.weight = ParameterModule(torch.empty(out_features, in_features))
        self.bias = ParameterModule(torch.zeros(out_features)) if bias else None

        init_type = init_type.lower()
        supported = ("xavier", "attn", "proj")
        if init_type in supported:
            self._initialize_parameters(init_type)
        else:
            raise ValueError(f"Linear unsupported init_type {init_type}.")

    def _initialize_parameters(self, init_type: str):
        """
        Fill the weight (and bias, when present) in place.

        :param init_type:   "attn" -> uniform in +-sqrt(1 / in_features),
                            "proj" -> uniform in +-sqrt(3 / in_features),
                            anything else -> Xavier uniform with gain 1.0
        """
        weight_tensor = self.weight.data
        if init_type in ("attn", "proj"):
            # "attn" is meant for QKV, "proj" for Proj / a final FC whose
            # out_dim is very small; they only differ in the numerator.
            numerator = 1.0 if init_type == "attn" else 3.0
            limit = math.sqrt(numerator / self.in_features)
            nn.init.uniform_(weight_tensor, -limit, limit)
        else:
            # Default scheme, for dense FC layers.
            nn.init.xavier_uniform_(weight_tensor, gain=1.0)

        if self.bias is None:
            return
        nn.init.zeros_(self.bias.data)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the linear map to ``x``.

        :param x:       (batch_size, in_features)
        :return:
                        (batch_size, out_features)
        """
        # The ParameterModule wrappers are called to obtain the values that
        # are handed to F.linear.
        weight = self.weight()
        bias = None if self.bias is None else self.bias()
        return F.linear(x, weight, bias)


class Identity(PadoModule):
    """Pass-through module: ``forward`` returns its input unchanged."""

    def __init__(self):
        """Create the module; it holds no parameters or state of its own."""
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x:       any tensor
        :return:
                        the very same object ``x`` (no copy is made)
        """
        return x
