""" MLP module w/ dropout and configurable activation layer

Hacked together by / Copyright 2020 Ross Wightman
"""
from functools import partial

import torch
from torch import nn as nn

from .grn import GlobalResponseNorm
from .helpers import to_2tuple


class InhibitoryMlp(nn.Module):
    """ Inhibitory enhanced Mlp based on MLP as used in Vision Transformer

    The second projection produces ``out_features - k`` features, where
    ``k = int(out_features * inhibitory_ratio)``. The last ``k`` of those
    features are then negated and appended along the feature axis, so the
    module emits ``out_features`` features in total and the final ``k`` outputs
    are exact sign-flipped copies of the ``k`` outputs preceding them.

    Args:
        in_features: Number of input features (channels when ``use_conv``).
        hidden_features: Width of the hidden layer; falls back to ``in_features``.
        out_features: Number of output features; falls back to ``in_features``.
        act_layer: Callable building the activation module (called with no args).
        norm_layer: Optional callable building a norm module from ``hidden_features``;
            ``nn.Identity`` is used when it is None.
        bias: Bias flag(s) for fc1 / fc2, expanded with ``to_2tuple``.
        drop: Dropout probability(ies) for the two dropout layers, expanded
            with ``to_2tuple``.
        use_conv: Use 1x1 ``nn.Conv2d`` layers instead of ``nn.Linear``. The
            feature axis is then dim 1 instead of the last dim.
        inhibitory_ratio: Fraction of ``out_features`` produced by negation.
            Must lie in ``[0, 0.5]``.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            inhibitory_ratio=0.125,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        assert inhibitory_ratio >= 0, "inhibitory_ratio must be non-negative"
        assert inhibitory_ratio <= 0.5, "inhibitory connection now only implemented for less than half of features"

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()

        # Number of output features that are generated by negating existing ones.
        # May legitimately be 0 for small out_features or a ratio of 0.
        self.inhibitory = int(out_features * inhibitory_ratio)
        # Axis holding the features: Conv2d puts channels on dim 1, Linear on the last dim.
        self.channel_dim = 1 if use_conv else -1
        self.fc2 = linear_layer(hidden_features, out_features - self.inhibitory, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)

        # Guard the empty case explicitly: a `-0:` slice would select *every*
        # feature and double the output width instead of appending nothing.
        if self.inhibitory > 0:
            dim = self.channel_dim
            tail = x.narrow(dim, x.size(dim) - self.inhibitory, self.inhibitory)
            x = torch.cat([x, -tail], dim=dim)

        x = self.drop2(x)
        return x
