import torch.nn as nn

from . import register_component
from .utils import CONV_TYPES

@register_component("ProjectionLayer")
class ProjectionLayer(nn.Module):
    """
    Projection layer. Projects in_channels to out_channels with a single
    convolution.

    The convolution class is looked up as ``CONV_TYPES[dimension]``.
    NOTE(review): presumably these are ``nn.Conv1d`` / ``nn.Conv2d`` /
    ``nn.Conv3d`` — confirm in ``.utils``. No ``padding`` argument is
    passed, so if they are PyTorch conv classes (default ``padding=0``),
    a ``kernel_size`` greater than 1 reduces the spatial size of the output.

    Args:
        dimension: Dimension for convolution operations (1, 2, or 3); must
            be a key of ``CONV_TYPES``.
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Kernel size of the convolution
        bias: Whether to include bias in convolution

    Raises:
        ValueError: If ``dimension`` is not a key of ``CONV_TYPES``.
    """

    def __init__(
        self,
        dimension: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        bias: bool = True,
    ):
        super().__init__()

        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an opaque KeyError below.
        if dimension not in CONV_TYPES:
            raise ValueError(
                f"Dimension must be 1, 2, or 3, got {dimension!r}"
            )
        conv_cls = CONV_TYPES[dimension]

        self.conv = conv_cls(
            in_channels,
            out_channels,
            kernel_size,
            bias=bias,
        )

    def forward(self, x):
        """Apply the projection convolution to ``x`` and return the result."""
        return self.conv(x)