import torch
from torch import nn

from .functional import init_truncnormal_zero_bias


class LinearProjection(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        ndim: None | int = None,
        bias: bool = True,
        optional: bool = False,
        init_weights: str = "torch",
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.ndim = ndim
        self.bias = bias
        self.init_weights = init_weights

        self.proj: nn.Module
        if optional and input_dim == output_dim:
            self.proj = nn.Identity()
        elif ndim is None:
            self.proj = nn.Linear(input_dim, output_dim, bias=bias)
        elif ndim == 1:
            self.proj = nn.Conv1d(input_dim, output_dim, kernel_size=1, bias=bias)
        elif ndim == 2:
            self.proj = nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=bias)
        elif ndim == 3:
            self.proj = nn.Conv3d(input_dim, output_dim, kernel_size=1, bias=bias)
        else:
            raise NotImplementedError

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.init_weights == "torch":
            pass
        elif self.init_weights in ["truncnormal", "truncnormal002"]:
            init_truncnormal_zero_bias(self.proj)
        else:
            raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
