import torch.nn as nn
import torch
import torch.nn.functional as F


class GMF(nn.Module):
    """Cross-combine two feature inputs by exchanging their ``inv`` parts.

    Each input is decomposed by its own ``ElementSplit`` into an ``inv`` part
    (``min(dims)`` features) and a ``spec`` part (the input's own width). The
    ``inv`` parts are swapped between the two inputs, each swapped ``inv`` is
    concatenated with the other input's ``spec`` part, and a per-input linear
    layer maps the result back to that input's width.

    Args:
        dims: pair ``(dimA, dimB)`` of feature widths of ``x1`` and ``x2``.
        multiple: expansion factor forwarded to both ``ElementSplit`` modules.
        boundary: split fraction forwarded to both ``ElementSplit`` modules.
    """

    def __init__(self, dims, multiple=4, boundary=0.5):
        super(GMF, self).__init__()
        shared = min(dims)  # width of the exchanged ``inv`` features
        self.modalA = ElementSplit(dims[0], shared, multiple, boundary)
        self.modalB = ElementSplit(dims[1], shared, multiple, boundary)
        # Each fused vector is [other input's inv (shared) | own spec (dims[i])].
        self.P_reconA = nn.Linear(dims[0] + shared, dims[0])
        self.P_reconB = nn.Linear(dims[1] + shared, dims[1])

    def forward(self, x1, x2):
        """Split, swap and re-project the two inputs.

        Args:
            x1: features of width ``dims[0]`` on the last axis.
            x2: features of width ``dims[1]`` on the last axis.

        Returns:
            ``(fused1, fused2, x1_re, x2_re)``:
            - ``fused1`` is ``[x2_inv | x1_spec]``.
            - ``fused2`` is ``[x1_inv | x2_spec]``.
            - ``x1_re`` / ``x2_re`` are those fused vectors projected back to
              ``dims[0]`` / ``dims[1]``. These are presumably compared against
              the original inputs by a reconstruction loss; confirm with the
              training code.
        """
        x1_inv, x1_spec = self.modalA(x1)
        x2_inv, x2_spec = self.modalB(x2)
        # Concatenate on the last (feature) axis so that any leading batch
        # dimensions are preserved; identical to dim=1 for 2-D inputs.
        fused1 = torch.cat([x2_inv, x1_spec], dim=-1)
        fused2 = torch.cat([x1_inv, x2_spec], dim=-1)
        x1_re = self.P_reconA(fused1)
        x2_re = self.P_reconB(fused2)
        return fused1, fused2, x1_re, x2_re


class ElementSplit(nn.Module):
    """Expand a feature vector, split it, and project each part separately.

    The input (width ``dim`` on its last axis) is linearly expanded to
    ``int(dim * multiple)`` features. The expanded features are then split at
    ``int(boundary * multiple * dim)``:
    - The first part is projected to ``min_len`` features (``inv``).
    - The remaining features are projected back to ``dim`` features (``spec``).

    Args:
        dim: width of the input features.
        min_len: output width of the ``inv`` projection.
        multiple: expansion factor of the intermediate representation.
        boundary: fraction of the expanded features routed to ``inv``.

    Raises:
        ValueError: if the split leaves either branch with no features.
    """

    def __init__(self, dim, min_len, multiple=4, boundary=0.5):
        super(ElementSplit, self).__init__()
        self.dislen = int(dim * multiple)
        self.boundary = int(boundary * multiple * dim)
        if not 0 < self.boundary < self.dislen:
            raise ValueError(
                f"boundary={boundary} splits {self.dislen} expanded features "
                f"at index {self.boundary}; both branches must be non-empty"
            )
        self.dim = dim
        # Use the integer width so a float `multiple` (e.g. 4.0) still builds
        # a valid layer and matches the slice bound used in forward().
        self.P_dis = nn.Linear(dim, self.dislen)
        self.P_con_inv = nn.Linear(self.boundary, min_len)
        # The spec branch sees the features *after* the boundary, so its input
        # width is dislen - boundary. That equals `boundary` only for an even
        # 50/50 split, so it must not be hard-coded to `self.boundary`.
        self.P_cov_spec = nn.Linear(self.dislen - self.boundary, dim)

    def forward(self, x):
        """Return ``(x_inv, x_spec)``.

        Args:
            x: features of width ``dim`` on the last axis.

        Returns:
            ``(x_inv, x_spec)`` with last-axis widths ``min_len`` and ``dim``.
            Any leading dimensions of ``x`` are kept unchanged.
        """
        x = self.P_dis(x)
        # Slice the last (feature) axis so inputs with extra leading dims work.
        x_inv = self.P_con_inv(x[..., :self.boundary])
        x_spec = self.P_cov_spec(x[..., self.boundary:self.dislen])
        return x_inv, x_spec


class ReconstructionLoss(nn.Module):
    """Mean-squared-error loss between a reconstruction and its target.

    Args:
        reduction: forwarded to ``F.mse_loss`` (``'mean'``, ``'sum'`` or
            ``'none'``). Defaults to ``'mean'``, the previous fixed behaviour.
    """

    def __init__(self, reduction='mean'):
        super(ReconstructionLoss, self).__init__()
        self.reduction = reduction

    def forward(self, x_recon, x_original):
        """Return ``F.mse_loss(x_recon, x_original)`` with the configured reduction."""
        return F.mse_loss(x_recon, x_original, reduction=self.reduction)


if __name__ == "__main__":
    # Quick smoke benchmark: per-call latency plus a thop op/parameter count.
    import time

    from thop import profile

    inputs1 = torch.randn(1, 512)
    inputs2 = torch.randn(1, 512)
    net = GMF(dims=[512, 512])
    net.eval()

    # Time 11 forward passes; the first one includes one-off warm-up cost.
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # and no_grad() keeps autograd bookkeeping out of the measurement.
    with torch.no_grad():
        for _ in range(11):
            since = time.perf_counter()
            net(inputs1, inputs2)  # call the module, not .forward, so hooks run
            print(time.perf_counter() - since)

    # thop.profile counts multiply-accumulates (MACs); 1 MAC = 2 FLOPs.
    macs, params = profile(net, inputs=(inputs1, inputs2,))
    print('FLOPs = ' + str(2 * macs / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')
