

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class MCLinear(nn.Module):
    """Multi-channel linear layer.

    Maps an input of shape ``(batch, in_channel, input_dim)`` to an output of
    shape ``(batch, out_channel, out_dim)``. Every (input channel, output
    channel) pair owns its own ``input_dim x out_dim`` matrix; the
    contributions of all input channels are summed and a per-output-channel
    bias is added.

    Args:
        input_dim: Size of the feature axis of the input.
        out_dim: Size of the feature axis of the output.
        in_channel: Number of input channels.
        out_channel: Number of output channels.
    """

    def __init__(self, input_dim=38, out_dim=64, in_channel=3, out_channel=3):
        super(MCLinear, self).__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.in_channel = in_channel
        self.out_channel = out_channel
        # Allocated uninitialised: reset_parameters() overwrites every value.
        self.weight = nn.Parameter(torch.empty(in_channel, out_channel, input_dim, out_dim))
        self.bias = nn.Parameter(torch.empty(out_channel, out_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

        This is the same distribution ``nn.Linear`` obtains from
        ``kaiming_uniform_(a=sqrt(5))``. The fan-in is computed by hand
        because the torch helper assumes a convolution weight layout
        ``(out, in, *kernel)``; applied to this layer's
        ``(in_channel, out_channel, input_dim, out_dim)`` weight it would
        report ``out_channel * input_dim * out_dim``, whereas each output
        element actually sums ``in_channel * input_dim`` products.
        """
        fan_in = self.in_channel * self.input_dim
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.weight, -bound, bound)
        init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the layer.

        Args:
            x: Tensor of shape ``(batch, in_channel, input_dim)``.

        Returns:
            Tensor of shape ``(batch, out_channel, out_dim)``.
        """
        # b: batch, q: input channel, p: output channel,
        # n: input feature, m: output feature. Sums over q and n.
        x = torch.einsum('bqn, qpnm->bpm', x, self.weight) + self.bias
        return x


class MLPResBlock(nn.Module):
    """Residual block made of two multi-channel linear layers.

    The inner network expands ``(input_channel, input_dim)`` to
    ``(hidden_channel, hidden_dim)`` and projects back, with layer
    normalisation over the last axis after each linear layer. The branch
    output is scaled element-wise by the learnable ``gamma`` (initialised to
    ones), added to the input, and passed through GELU.

    Args:
        input_dim: Feature size of the block input and output.
        input_channel: Channel count of the block input and output.
        hidden_dim: Feature size inside the block.
        hidden_channel: Channel count inside the block.
        dropout: Dropout probability used after the activation and after the
            second normalisation.
    """

    def __init__(self, input_dim, input_channel, hidden_dim, hidden_channel, dropout=0.):
        super().__init__()

        # Expansion half: (input_channel, input_dim) -> (hidden_channel, hidden_dim).
        expand = [
            MCLinear(input_dim, hidden_dim, input_channel, hidden_channel),
            nn.LayerNorm([hidden_dim]),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        # Projection half: back to (input_channel, input_dim).
        project = [
            MCLinear(hidden_dim, input_dim, hidden_channel, input_channel),
            nn.LayerNorm([input_dim]),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*expand, *project)

        # Learnable per-element scale applied to the residual branch.
        scale_init = torch.ones(input_channel, input_dim)
        self.gamma = nn.Parameter(scale_init, requires_grad=True)

    def forward(self, x):
        """Return ``gelu(x + gamma * net(x))``."""
        branch = self.net(x)
        return F.gelu(x + self.gamma * branch)


class ResMLP(nn.Module):
    """Residual multi-channel MLP mapping ``(batch, input_dim)`` to ``(batch, input_dim)``.

    Pipeline: a multi-channel head lifts the single-channel input to three
    channels of width ``hidden_dim``; a stack of ``MLPResBlock`` modules
    refines it; the channels are averaged and a final linear layer projects
    back to ``input_dim``.

    Args:
        input_dim: Feature size of the network input and output.
        hidden_dim: Feature size used by the head and the residual blocks.
        hidden_channel: Channel count inside each residual block.
        dropout: Dropout probability passed to each residual block.
        num_blocks: Number of residual blocks in the stack.
    """

    def __init__(self, input_dim=38, hidden_dim=64, hidden_channel=8, dropout=0.2, num_blocks=5):
        super(ResMLP, self).__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim

        # The normalised shape must follow hidden_dim; a fixed 64 here would
        # fail at forward time for any other hidden size.
        self.mc_fc_head = nn.Sequential(
            MCLinear(input_dim=input_dim, out_dim=hidden_dim, in_channel=1, out_channel=3),
            nn.LayerNorm([hidden_dim]),
            nn.GELU(),
        )
        # Identical residual blocks, registered as res_layer.0 .. res_layer.{num_blocks - 1}.
        self.res_layer = nn.Sequential(*[
            MLPResBlock(input_dim=hidden_dim, input_channel=3, hidden_dim=hidden_dim,
                        hidden_channel=hidden_channel, dropout=dropout)
            for _ in range(num_blocks)
        ])
        self.output_fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, input_dim)``.
        """
        # Insert a singleton channel axis: (batch, 1, input_dim).
        x = x.unsqueeze(1)
        x = self.mc_fc_head(x)
        x = self.res_layer(x)
        # Average pooling over the channel axis: (batch, hidden_dim).
        x = torch.mean(x, dim=1)
        x = self.output_fc(x)
        return x


if __name__ == "__main__":
    # Smoke test: push one random batch through a default-configured network
    # and print the shape of the result.
    model = ResMLP()
    batch = torch.randn(256, 38)
    print(model(batch).shape)