import torch
import torch.nn as nn


class Mlp(nn.Module):
    """Per-residue MLP, masked mean pooling over residues, then a linear classifier.

    Each residue embedding is projected to ``hidden_dim``, passed through
    ``no_hd_layers`` extra Linear+ReLU layers, averaged over the non-padding
    residues of its sequence, and mapped to ``n_class`` logits.
    """

    def __init__(self,
            n_class: int,
            use_lm: bool,
            use_af: bool,
            lm_dim: int,
            af_single_dim: int,
            hidden_dim: int,
            no_hd_layers: int,
            dropout1: float,
            dropout2: float,
    ):
        """
        Args:
            n_class: number of output logits per sequence.
            use_lm: inputs are ``lm_dim``-wide embeddings.
            use_af: inputs are ``af_single_dim``-wide embeddings. Takes
                precedence over ``use_lm`` when both flags are set.
            lm_dim: input width used when ``use_lm`` selects the input.
            af_single_dim: input width used when ``use_af`` is set.
            hidden_dim: width of every hidden layer.
            no_hd_layers: number of extra Linear+ReLU layers after the
                input projection.
            dropout1: dropout probability after the input projection.
            dropout2: dropout probability on the pooled sequence vector.

        Raises:
            ValueError: if neither ``use_lm`` nor ``use_af`` is set, since
                the input width would then be undefined.
        """
        super(Mlp, self).__init__()
        self.n_class = n_class
        # AF wins when both flags are set (matches the original assignment order).
        if use_af:
            self.input_dim = af_single_dim
        elif use_lm:
            self.input_dim = lm_dim
        else:
            raise ValueError("Mlp requires use_lm or use_af to be True")
        self.linear = nn.Linear(self.input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList([LinearWithRelu(hidden_dim) for _ in range(no_hd_layers)])
        self.dropout1 = nn.Dropout(dropout1)
        self.dropout2 = nn.Dropout(dropout2)
        self.classifier = nn.Linear(hidden_dim, n_class)

    def forward(self, embeds: torch.Tensor) -> torch.Tensor:
        """Classify each sequence from its residue embeddings.

        Args:
            embeds: [*, N_res, embed_dim]. Padding residues must be all-zero
                vectors; a residue is treated as real if any feature is non-zero.

        Returns:
            [*, n_class] logits.
        """
        # [*, N_res] True for real residues. `.any` (not `.all`) so that a real
        # residue containing an exactly-zero feature is not mistaken for padding.
        mask = (embeds != 0).any(dim=-1)
        # [*, 1] number of real residues. The clamp keeps an all-padding
        # sequence from producing 0/0 = NaN; its pooled vector is then zero.
        n_mask = mask.sum(dim=-1, keepdim=True).clamp(min=1)  # TODO: move mask to dataset

        # [*, N_res, hd]
        hidden = self.linear(embeds)
        hidden = self.dropout1(hidden)
        hidden = torch.relu(hidden)
        for layer in self.hidden_layers:
            hidden = layer(hidden)

        # [*, hd] mean over real residues only; the mask broadcasts over hd.
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * weights).sum(dim=-2) / n_mask
        pooled = self.dropout2(pooled)
        # [*, N_class]
        return self.classifier(pooled)



class LinearWithRelu(nn.Module):
    """A square ``nn.Linear(dim, dim)`` followed by a ReLU."""

    def __init__(self, dim: int):
        """
        Args:
            dim: input and output feature width.
        """
        super(LinearWithRelu, self).__init__()
        # Attribute name kept as `_layer` so existing state_dict keys still load.
        self._layer = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(linear(x))``; the output has the same shape as ``x``."""
        return torch.relu(self._layer(x))
