from functools import partial
import torch.nn as nn
from timm.layers.helpers import to_2tuple

from . import modifiers


class Linear(nn.Linear):
    """``torch.nn.Linear`` whose ``forward`` is wrapped by ``modifiers.conditional_forward``.

    Construction is passed straight through to ``nn.Linear``. Note that the
    first argument is named ``n_features`` here rather than torch's
    ``in_features``; keyword callers must use ``n_features``.

    What the "conditional" wrapping actually does is defined by
    ``modifiers.conditional_forward``, which is not visible in this module.
    """

    def __init__(self, n_features, out_features, bias=True, device=None, dtype=None):
        # Forward every argument unchanged to nn.Linear.
        super().__init__(
            n_features,
            out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    @modifiers.conditional_forward
    def forward(self, x):
        """Delegate to ``nn.Linear.forward`` (subject to the decorator)."""
        return super().forward(x)

    def __repr__(self):
        """Return the standard torch repr prefixed with ``"Conditional"``."""
        return "Conditional" + super().__repr__()


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Layout: ``fc1 -> act -> drop1 -> fc2 -> drop2``, with ``forward`` wrapped
    by ``modifiers.conditional_forward`` (defined outside this module).

    Args:
        in_features: Input width of ``fc1``.
        hidden_features: Output width of ``fc1`` / input width of ``fc2``;
            a falsy value (e.g. ``None``) falls back to ``in_features``.
        out_features: Output width of ``fc2``; a falsy value falls back to
            ``in_features``.
        act_layer: Callable invoked with no arguments to build the activation.
        bias: Normalised via ``to_2tuple``; element 0 is the bias flag for
            ``fc1`` and element 1 the flag for ``fc2``.
        drop: Normalised via ``to_2tuple``; element 0 is the dropout
            probability after the activation, element 1 after ``fc2``.
        use_conv: If true, both projections are 1x1 ``nn.Conv2d`` layers
            instead of ``nn.Linear``.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        # Falsy widths default to the input width.
        width_out = out_features or in_features
        width_hidden = hidden_features or in_features
        fc_bias = to_2tuple(bias)
        p_drop = to_2tuple(drop)

        if use_conv:
            # A 1x1 convolution acts as a per-position linear projection.
            make_proj = partial(nn.Conv2d, kernel_size=1)
        else:
            make_proj = nn.Linear

        # Construction/assignment order is kept as fc1, act, drop1, fc2, drop2:
        # it fixes both random weight initialisation order and submodule order.
        self.fc1 = make_proj(in_features, width_hidden, bias=fc_bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(p_drop[0])
        self.fc2 = make_proj(width_hidden, width_out, bias=fc_bias[1])
        self.drop2 = nn.Dropout(p_drop[1])

    @modifiers.conditional_forward
    def forward(self, x):
        """Run ``x`` through fc1, act, drop1, fc2 and drop2 in sequence."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))

    def __repr__(self):
        """Return the standard torch repr prefixed with ``"Conditional"``."""
        return "Conditional" + super().__repr__()
