from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch import Tensor

from torch.nn import Module, Linear, ReLU, Softmax, Sequential

class MLP(Module):
    """Implement a basic multilayer perceptron."""
    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            layers: int,
            width: int | list[int],
            *,
            is_classification: bool=True,
            add_bias: bool=True,
            logits_only=True,
    ) -> None:
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.layers = layers
        self.width = width
        self.is_class = is_classification
        self.add_bias = add_bias

        if isinstance(self.width, int):
            self.width = [self.width]

        modules = self._build()

        modules.append(
            Linear(
                in_features=self.width[-1],
                out_features=self.out_dim,
                bias=self.add_bias
            )
        )

        if is_classification and not logits_only:
            modules.append(Softmax(dim=1))

        self.net = Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        if self.is_class:
            x = x.view(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3])

        return self.net(x)

    def _build(self) -> list[Module]:

        modules = []
        modules.append(
            Linear(
                in_features=self.in_dim,
                out_features=self.width[0],
                bias=self.add_bias,
            ),
        )
        modules.append(ReLU())
        
        for layer in range(self.layers-1):
            modules.append(
                Linear(
                    in_features=self.width[layer],
                    out_features=self.width[layer+1],
                    bias=self.add_bias,
                )
            )
            modules.append(ReLU())

        return modules