from functools import partial
from itertools import repeat
import collections.abc

from torch import nn as nn


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse

# Pair normalizer: a scalar (or str) becomes ``(x, x)``; a non-string iterable
# is converted to a tuple as-is (its length is not validated).
to_2tuple = _ntuple(2)


class Mlp(nn.Module):
    """Two-layer MLP with an optional output activation.

    Derived from the MLP used in Vision Transformer, MLP-Mixer and related
    networks. The layer order is
    ``fc1 -> act -> drop1 -> norm -> fc2 -> drop2 -> [final_activation]``.

    Input layout (features/channels are always on axis 1):

    * ``use_conv=False`` (default): ``forward`` permutes a 5-D tensor with
      ``(0, 2, 3, 4, 1)`` so that the ``nn.Linear`` layers act on what was
      axis 1, and permutes back with ``(0, 4, 1, 2, 3)`` before returning.
    * ``use_conv=True``: the layers are 1x1 ``nn.Conv2d`` modules, which
      already act on axis 1, so the tensor is passed through without any
      permutation (NCHW input as required by ``nn.Conv2d``).

    Args:
        in_features: Size of the input feature/channel axis.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the output feature/channel axis; falls back
            to ``in_features`` when falsy.
        act_layer: Zero-argument callable building the hidden activation.
        norm_layer: Optional callable invoked as ``norm_layer(hidden_features)``;
            ``nn.Identity`` is used when ``None``.
        bias: Bool, or a pair ``(fc1_bias, fc2_bias)``.
        drop: Dropout probability, or a pair ``(drop1, drop2)``.
        use_conv: Use 1x1 ``nn.Conv2d`` instead of ``nn.Linear``.
        final_activation: One of ``"sigmoid"``, ``"softmax"``,
            ``"sigmoid_symm"`` (sigmoid rescaled with ``x * 2 - 1``) or
            ``"tanh"``. Any other value disables the final activation.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            final_activation=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        # Kept so forward() can apply the extra rescaling for "sigmoid_symm".
        self.final_activation_str = final_activation

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

        # The final activation sees channels on the last axis in the linear
        # path (forward() permutes them there) but on axis 1 in the conv path,
        # so softmax must normalize over the matching axis.
        softmax_dim = 1 if use_conv else -1

        if final_activation == "sigmoid":
            self.final_activation = nn.Sigmoid()
        elif final_activation == "softmax":
            self.final_activation = nn.Softmax(dim=softmax_dim)
        elif final_activation == "sigmoid_symm":        # identify encoder / decoder when save model
            self.final_activation = nn.Sequential(
                nn.Sigmoid(),
                nn.Identity()
            )
        elif final_activation == "tanh":
            self.final_activation = nn.Tanh()
        else:
            # Unknown / None: no activation is applied to the output.
            self.final_activation = None

    def forward(self, x):
        """Apply the MLP along axis 1 of ``x`` and return the same layout.

        In the linear path ``x`` must be 5-D (the permutes are fixed to five
        axes); in the conv path ``x`` is fed to ``nn.Conv2d`` unchanged.
        """
        # Derived from the layer type rather than a stored flag so that
        # modules pickled before this check existed keep working.
        needs_channels_last = not isinstance(self.fc1, nn.Conv2d)

        if needs_channels_last:
            # nn.Linear acts on the last axis: move axis 1 to the end.
            x = x.permute(0, 2, 3, 4, 1)

        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)

        if self.final_activation is not None:
            x = self.final_activation(x)
            if self.final_activation_str == "sigmoid_symm":
                # Rescale the sigmoid output from (0, 1) to (-1, 1).
                x = x * 2.0 - 1.0

        if needs_channels_last:
            # Restore the original axis order (features back on axis 1).
            x = x.permute(0, 4, 1, 2, 3)
        return x