import functools
import torch.nn as nn

from .build import HEAD_REGISTRY


class MLP(nn.Module):
    """Stack of fully-connected blocks: Linear -> [BatchNorm1d] -> act -> [Dropout].

    Args:
        in_features: Size of the input feature dimension fed to the first
            ``nn.Linear``.
        hidden_layers: Output size of each ``nn.Linear``, in order. A single
            int is treated as a one-element list; any other iterable is
            copied into a list. Must contain at least one size.
        activation: ``"relu"`` or ``"leaky_relu"`` (both constructed with
            ``inplace=True``).
        bn: If True, insert ``nn.BatchNorm1d`` after every ``nn.Linear``.
        dropout: If > 0, append ``nn.Dropout(dropout)`` after every
            activation.

    Attributes:
        out_features: Size of the last hidden layer, i.e. the module's output
            feature size.
        mlp: ``nn.Sequential`` containing all layers.

    Raises:
        ValueError: If ``hidden_layers`` is empty.
        NotImplementedError: If ``activation`` is not supported.
    """

    def __init__(
        self,
        in_features=2048,
        hidden_layers=(),
        activation="relu",
        bn=True,
        dropout=0.0,
    ):
        super().__init__()
        # Accept a bare int or any iterable of ints; copying into a list
        # never aliases (or mutates) the caller's container.
        if isinstance(hidden_layers, int):
            hidden_layers = [hidden_layers]
        else:
            hidden_layers = list(hidden_layers)

        # Explicit check instead of `assert`, which is stripped under -O.
        if not hidden_layers:
            raise ValueError("hidden_layers must contain at least one layer size")
        self.out_features = hidden_layers[-1]

        # Name -> activation factory; each call builds a fresh module.
        activations = {
            "relu": functools.partial(nn.ReLU, inplace=True),
            "leaky_relu": functools.partial(nn.LeakyReLU, inplace=True),
        }
        act_fn = activations.get(activation)
        if act_fn is None:
            raise NotImplementedError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(activations)}"
            )

        # Layer order is kept fixed so nn.Sequential indices (and therefore
        # state_dict keys of existing checkpoints) stay the same.
        layers = []
        for hidden_dim in hidden_layers:
            layers.append(nn.Linear(in_features, hidden_dim))
            if bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(act_fn())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_features = hidden_dim

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x``.

        Presumably ``x`` is ``(batch, in_features)``; the first Linear acts
        on the last dimension and BatchNorm1d normalizes dimension 1 — TODO
        confirm the expected input rank with callers.
        """
        return self.mlp(x)


@HEAD_REGISTRY.register()
def mlp(**kwargs):
    """Registry entry point that builds an ``MLP`` head.

    All keyword arguments are forwarded unchanged to ``MLP.__init__``.
    """
    head = MLP(**kwargs)
    return head
