

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class MCLinear(nn.Module):
    """Multi-channel linear layer that mixes channels and features at once.

    The forward pass contracts the input over both its channel axis and its
    feature axis against a 4-D weight of shape
    ``(in_channel, out_channel, input_dim, out_dim)`` and then adds a bias of
    shape ``(out_channel, out_dim)``.

    Args:
        input_dim: Size of the input feature axis (last axis of the input).
        out_dim: Size of the output feature axis.
        in_channel: Number of input channels (middle axis of the input).
        out_channel: Number of output channels.
    """

    def __init__(self, input_dim=38, out_dim=64, in_channel=3, out_channel=3):
        super(MCLinear, self).__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.weight = nn.Parameter(torch.empty(in_channel, out_channel, input_dim, out_dim))
        self.bias = nn.Parameter(torch.empty(out_channel, out_dim))

        # Every output element is a sum over the (in_channel, input_dim) axes,
        # so that product is the true fan-in. The generic torch.nn.init fan
        # helper assumes a conv-style (out, in, *kernel) layout and would
        # report out_channel * input_dim * out_dim for this tensor, which
        # makes the initial weights far too small.
        fan_in = in_channel * input_dim
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        # U(-1/sqrt(fan_in), 1/sqrt(fan_in)) is what kaiming_uniform_ with
        # a=sqrt(5) yields for a correctly computed fan-in.
        init.uniform_(self.weight, -bound, bound)
        init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the multi-channel linear map.

        Args:
            x: Tensor with three axes ``(b, in_channel, input_dim)``, as
                required by the einsum subscripts.

        Returns:
            Tensor of shape ``(b, out_channel, out_dim)``.
        """
        # q = input channel, n = input feature (both summed out);
        # p = output channel, m = output feature.
        x = torch.einsum('bqn, qpnm->bpm', x, self.weight) + self.bias
        return x


class MLPResBlock(nn.Module):
    """Two-layer MLP wrapped in a scaled residual connection.

    The inner network is Linear -> GELU -> Dropout -> Linear -> Dropout and
    maps ``input_dim`` features to ``hidden_dim`` and back. Its output is
    multiplied element-wise by a learnable vector ``gamma`` (initialised to
    ones), added to the block input, and passed through a final GELU.

    Args:
        input_dim: Feature size of the block's input and output.
        hidden_dim: Feature size between the two linear layers.
        dropout: Dropout probability used after each linear stage.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.):
        super().__init__()
        # Layers are created in execution order so parameter initialisation
        # consumes the RNG in a fixed, reproducible sequence.
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, input_dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

        # Learnable per-feature scale applied to the residual branch.
        self.gamma = nn.Parameter(torch.ones(input_dim))

    def forward(self, x):
        """Return ``gelu(x + gamma * net(x))``."""
        branch = self.net(x)
        return F.gelu(x + self.gamma * branch)


class ResMLP(nn.Module):
    """Residual MLP that maps ``input_dim`` features back to ``input_dim``.

    Pipeline: a Linear + GELU head lifts the input to ``hidden_dim``, a stack
    of ``MLPResBlock`` modules refines it at that width, and a final Linear
    layer projects back to ``input_dim``.

    Args:
        input_dim: Size of the last axis of the input and of the output.
        hidden_dim: Width used by the head and every residual block.
        dropout: Dropout probability forwarded to each residual block.
        num_blocks: Number of residual blocks in the stack. Defaults to 5,
            the depth that was previously hard-coded.

    Raises:
        ValueError: If ``num_blocks`` is negative.
    """

    def __init__(self, input_dim=38, hidden_dim=64, dropout=0.2, num_blocks=5):
        super(ResMLP, self).__init__()
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be non-negative, got {num_blocks}")

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_blocks = num_blocks

        self.mc_fc_head = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                        nn.GELU())
        # Blocks are built in order so sub-module names (res_layer.0, ...)
        # and parameter-initialisation order match the unrolled definition.
        self.res_layer = nn.Sequential(*[
            MLPResBlock(input_dim=hidden_dim, hidden_dim=hidden_dim, dropout=dropout)
            for _ in range(num_blocks)
        ])
        self.output_fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        """Run head, residual stack and output projection.

        Args:
            x: Tensor whose last axis has size ``input_dim``.

        Returns:
            Tensor with the same leading axes as ``x`` and last axis of size
            ``input_dim``.
        """
        x = self.mc_fc_head(x)
        feature = self.res_layer(x)
        x = self.output_fc(feature)
        return x


if __name__ == "__main__":
    # Smoke test: push one random batch through a default-configured model
    # and print the output shape.
    model = ResMLP()
    batch = torch.randn(256, 38)
    print(model(batch).shape)