
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class MCLinear(nn.Module):
    """Multi-channel linear layer.

    Maps an input of shape ``(batch, in_channel, input_dim)`` to an output of
    shape ``(batch, out_channel, out_dim)`` by contracting over both the
    channel axis and the feature axis::

        y[b, p, m] = sum_{q, n} x[b, q, n] * weight[q, p, n, m] + bias[p, m]

    Args:
        input_dim: Feature length of each input channel.
        out_dim: Feature length of each output channel.
        in_channel: Number of input channels.
        out_channel: Number of output channels.
    """

    def __init__(self, input_dim=38, out_dim=64, in_channel=3, out_channel=3):
        super().__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.in_channel = in_channel
        self.out_channel = out_channel
        # Values are filled in by reset_parameters(); no need to draw randn first.
        self.weight = nn.Parameter(torch.empty(in_channel, out_channel, input_dim, out_dim))
        self.bias = nn.Parameter(torch.empty(out_channel, out_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialize ``weight`` and ``bias`` like ``nn.Linear`` does.

        Every output element sums over ``in_channel * input_dim`` inputs, so
        that is the true fan-in. ``init._calculate_fan_in_and_fan_out`` cannot
        be used here: for a 4-D tensor it treats dim 1 (``out_channel``) as the
        input feature maps and dims 2.. as a receptive field, giving a wildly
        too-large fan-in. ``kaiming_uniform_(a=sqrt(5))`` with the correct
        fan-in is exactly ``uniform(-1/sqrt(fan_in), 1/sqrt(fan_in))``.
        """
        fan_in = self.in_channel * self.input_dim
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.weight, -bound, bound)
        init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(batch, in_channel, input_dim)``."""
        return torch.einsum('bqn,qpnm->bpm', x, self.weight) + self.bias


class MLPBlock(nn.Module):
    """Two stacked ``MCLinear`` layers with LayerNorm, GELU and dropout.

    ``self.net`` applies, in order:
      1. ``MCLinear`` keeping the shape: ``input_dim -> input_dim`` and
         ``input_channel -> input_channel``;
      2. ``LayerNorm`` over the last axis (size ``input_dim``), GELU, Dropout;
      3. ``MCLinear`` mapping ``input_dim -> out_dim`` and
         ``input_channel -> out_channel``;
      4. ``LayerNorm`` over the last axis (size ``out_dim``), Dropout.
    ``forward`` then applies one more GELU to the result.

    Args:
        input_dim: Feature length of the input.
        input_channel: Channel count of the input.
        out_dim: Feature length of the output.
        out_channel: Channel count of the output.
        dropout: Dropout probability used after each normalization stage.
    """

    def __init__(self, input_dim, input_channel, out_dim, out_channel, dropout=0.):
        super().__init__()
        # The stages are built in the same order they appear in self.net, so
        # the submodule indices (net.0 ... net.6) stay stable for checkpoints.
        hidden_stage = [
            MCLinear(input_dim, input_dim, input_channel, input_channel),
            nn.LayerNorm([input_dim]),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        output_stage = [
            MCLinear(input_dim, out_dim, input_channel, out_channel),
            nn.LayerNorm([out_dim]),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*hidden_stage, *output_stage)

    def forward(self, x):
        """Run ``self.net`` on ``x`` and apply a final GELU."""
        return F.gelu(self.net(x))


class UnetMLP(nn.Module):
    """U-Net-style MLP built from multi-channel linear blocks.

    ``forward`` takes ``x`` of shape ``(batch, input_dim)`` and returns a
    tensor of shape ``(batch, input_dim)``. The pipeline is:

    1. ``mc_fc_head`` lifts the input to ``(batch, 3, hidden_dim)``.
    2. ``unet`` runs a three-level encoder/decoder. On the way down the feature
       length is halved by max-pooling (``hidden_dim`` -> ``hidden_dim // 2``
       -> ``hidden_dim // 4``). On the way up the ``up*`` blocks restore it.
    3. The result is averaged over the channel axis.
    4. It is projected back to ``input_dim`` by ``output_fc``.

    Skip connections are additive. Each is scaled elementwise by a learnable
    ``gamma`` tensor, initialized to ones.

    Args:
        input_dim: Number of input (and output) features.
        hidden_dim: Feature length at the top U-Net level. The deepest level
            uses ``hidden_dim // 4``, so this presumably needs to be >= 4.
        hidden_channel: Channel count inside the U-Net blocks.
        dropout: Dropout probability passed to every ``MLPBlock``.
    """

    def __init__(self, input_dim=38, hidden_dim=64, hidden_channel=8, dropout=0.2):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim

        # (batch, 1, input_dim) -> (batch, 3, hidden_dim). The LayerNorm size
        # must follow hidden_dim. It was previously hard-coded to 64, which
        # crashed for any other hidden_dim.
        self.mc_fc_head = nn.Sequential(
            MCLinear(input_dim=input_dim, out_dim=hidden_dim, in_channel=1, out_channel=3),
            nn.LayerNorm([hidden_dim]),
            nn.GELU(),
        )

        # Encoder: each level keeps the feature length it receives.
        self.down1 = MLPBlock(input_dim=hidden_dim, input_channel=3,
                              out_dim=hidden_dim, out_channel=hidden_channel, dropout=dropout)
        self.down2 = MLPBlock(input_dim=hidden_dim // 2, input_channel=hidden_channel,
                              out_dim=hidden_dim // 2, out_channel=hidden_channel, dropout=dropout)
        self.down3 = MLPBlock(input_dim=hidden_dim // 4, input_channel=hidden_channel,
                              out_dim=hidden_dim // 4, out_channel=hidden_channel, dropout=dropout)

        # Decoder: up3/up2 each upsample the feature length to the next level.
        self.up3 = MLPBlock(input_dim=hidden_dim // 4, input_channel=hidden_channel,
                            out_dim=hidden_dim // 2, out_channel=hidden_channel, dropout=dropout)
        self.up2 = MLPBlock(input_dim=hidden_dim // 2, input_channel=hidden_channel,
                            out_dim=hidden_dim, out_channel=hidden_channel, dropout=dropout)
        self.up1 = MLPBlock(input_dim=hidden_dim, input_channel=hidden_channel,
                            out_dim=hidden_dim, out_channel=hidden_channel, dropout=dropout)

        self.output_fc = nn.Linear(hidden_dim, input_dim)

        # Learnable per-(channel, feature) scales for the two skip connections;
        # shapes match x1 (hidden_dim) and x2 (hidden_dim // 2) respectively.
        self.gamma1 = nn.Parameter(torch.ones(hidden_channel, hidden_dim))
        self.gamma2 = nn.Parameter(torch.ones(hidden_channel, hidden_dim // 2))

    def unet(self, x):
        """Encoder/decoder over ``x`` of shape ``(batch, 3, hidden_dim)``.

        Returns a tensor of shape ``(batch, hidden_channel, hidden_dim)``.
        """
        # Encoder; max_pool1d (kernel 2, stride 2) halves the last axis.
        x1 = self.down1(x)
        x2 = self.down2(F.max_pool1d(x1, 2))
        x3 = self.down3(F.max_pool1d(x2, 2))

        # Decoder with additive, gamma-scaled skip connections.
        x_cat2 = self.up3(x3) + self.gamma2 * x2
        x_cat1 = self.up2(x_cat2) + self.gamma1 * x1
        return self.up1(x_cat1)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, input_dim)`` to ``(batch, input_dim)``."""
        x = x.unsqueeze(1)          # add a single channel axis for MCLinear
        x = self.mc_fc_head(x)
        x = self.unet(x)
        x = torch.mean(x, dim=1)    # average over channels
        return self.output_fc(x)


if __name__ == "__main__":
    # Smoke test: push a batch of 256 random 38-feature samples through the
    # default model and report the output shape.
    model = UnetMLP()
    sample = torch.randn(256, 38)
    output = model(sample)
    print(output.shape)