from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch import Tensor

from torch.nn import Module, Linear, ReLU, Softmax, Sequential
from torch.nn import init
from torch import tensor

class MLPMFP(Module):
    """Single-hidden-layer MLP using mean-field parameterization.

    The network computes ``alpha * ho(relu(ih(x)))``:

    * ``ih`` is a ``Linear(in_dim -> width)`` layer.
    * ``ho`` is a ``Linear(width -> out_dim)`` layer.
    * ``alpha`` is a fixed, non-trainable output scale. It equals
      ``1 / (width + 1)`` when biases are used and ``1 / width`` otherwise.

    Both weight matrices are drawn from N(0, 1) by ``reset_parameters``.

    Args:
        in_dim: Number of input features after flattening (see ``forward``).
        out_dim: Number of output features.
        width: Number of hidden units.
        is_bias: Whether both linear layers carry a bias term.
    """

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            width: int,
            *,
            is_bias: bool = True,
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.width = width

        self.ih = Linear(
            in_features=in_dim,
            out_features=width,
            bias=is_bias,
        )
        self.relu_ih = ReLU()
        self.ho = Linear(
            in_features=width,
            out_features=out_dim,
            bias=is_bias,
        )

        # Mean-field output scale. With biases the denominator is width + 1,
        # presumably counting the bias as one extra hidden unit.
        # A non-persistent buffer follows ``.to(device/dtype)`` together with
        # the parameters, but is not added to ``state_dict``, so checkpoint
        # keys are unchanged and existing checkpoints still load.
        scale = 1 / (width + 1) if is_bias else 1 / width
        self.register_buffer("alpha", tensor(scale), persistent=False)

        self.reset_parameters()

    def forward(
            self,
            x: Tensor,
    ) -> Tensor:
        """Apply the network to ``x``.

        Inputs with more than two dimensions are flattened from dim 1
        onwards. For example, a batch of shape ``(batch, 1, H, W)`` or
        ``(batch, C, H, W)`` becomes ``(batch, C*H*W)``. The flattened
        feature size must equal ``in_dim``.

        Args:
            x: Input tensor. It is either 1-D/2-D with ``in_dim`` as its last
                dimension, or higher-rank with a leading batch dimension.

        Returns:
            The scaled network output, with ``out_dim`` as its last dimension.
        """
        if x.dim() > 2:
            # flatten() (unlike view()) accepts non-contiguous input and any
            # number of trailing dims, including multi-channel inputs.
            x = x.flatten(start_dim=1)
        hidden = self.relu_ih(self.ih(x))
        # Out-of-place scaling: avoids mutating an autograd-tracked tensor.
        return self.ho(hidden) * self.alpha

    def reset_parameters(
            self,
    ) -> None:
        """Re-draw both weight matrices from N(0, 1).

        Only the weights are touched. Biases, if present, keep their current
        values; at construction time those come from ``nn.Linear``'s default
        initialization.
        """
        # NOTE(review): biases are intentionally not re-initialized here.
        # Adding that would change the RNG stream and thus seeded results.
        init.normal_(self.ih.weight)
        init.normal_(self.ho.weight)