import torch.nn as nn
import torch
import torch.nn.functional as F


class PMR(nn.Module):
    """Two-modality split / swap / reconstruct module.

    Each modality is passed through its own ``ElementSplit``, which yields an
    ``inv`` part of width ``min(dims)`` and a ``spec`` part of the modality's
    own width (presumably modality-invariant / modality-specific features).
    The ``inv`` parts are swapped between the two modalities, concatenated with
    the other modality's ``spec`` part, and a Linear layer maps each mixture
    back to that modality's width.

    Args:
        dims: per-modality feature widths; ``dims[0]`` for ``x1`` and
            ``dims[1]`` for ``x2``. The shared ``inv`` width is ``min(dims)``.
        multiple: expansion factor forwarded to ``ElementSplit``.
        boundary: split point forwarded to ``ElementSplit``.
    """

    def __init__(self, dims, multiple=4, boundary=0.5):
        super(PMR, self).__init__()
        shared = min(dims)
        self.modalA = ElementSplit(dims[0], shared, multiple, boundary)
        self.modalB = ElementSplit(dims[1], shared, multiple, boundary)
        self.P_reconA = nn.Linear(dims[0] + shared, dims[0])
        self.P_reconB = nn.Linear(dims[1] + shared, dims[1])

    def forward(self, x1, x2):
        """Return ``(x1_mix, x2_mix, x1_re, x2_re)``.

        ``x1_mix`` is ``[x2_inv, x1_spec]`` (width ``min(dims) + dims[0]``),
        ``x2_mix`` is ``[x1_inv, x2_spec]`` (width ``min(dims) + dims[1]``),
        and ``x*_re`` are their Linear projections back to ``dims[0]`` /
        ``dims[1]``.
        """
        x1_inv, x1_spec = self.modalA(x1)
        x2_inv, x2_spec = self.modalB(x2)
        # Concatenate on the last (feature) axis, the same axis nn.Linear
        # projects; for 2-D (batch, features) inputs this equals dim=1.
        x1_mix = torch.cat([x2_inv, x1_spec], dim=-1)
        x2_mix = torch.cat([x1_inv, x2_spec], dim=-1)
        x1_re = self.P_reconA(x1_mix)
        x2_re = self.P_reconB(x2_mix)
        return x1_mix, x2_mix, x1_re, x2_re

class PMR_MM(nn.Module):
    """Three-modality variant of ``PMR`` with a cyclic swap of ``inv`` parts.

    Modality 1 receives modality 3's ``inv`` part, modality 2 receives
    modality 1's, and modality 3 receives modality 2's. Each mixture is
    ``[other_inv, own_spec]`` and is projected back to the modality's width.

    Args:
        dims: per-modality feature widths for ``x1``, ``x2`` and ``x3``.
            The shared ``inv`` width is ``min(dims)``.
        multiple: expansion factor forwarded to ``ElementSplit``.
        boundary: split point forwarded to ``ElementSplit``.
    """

    def __init__(self, dims, multiple=4, boundary=0.5):
        super(PMR_MM, self).__init__()
        shared = min(dims)
        self.modalA = ElementSplit(dims[0], shared, multiple, boundary)
        self.modalB = ElementSplit(dims[1], shared, multiple, boundary)
        self.modalC = ElementSplit(dims[2], shared, multiple, boundary)
        self.P_reconA = nn.Linear(dims[0] + shared, dims[0])
        self.P_reconB = nn.Linear(dims[1] + shared, dims[1])
        self.P_reconC = nn.Linear(dims[2] + shared, dims[2])

    def forward(self, x1, x2, x3):
        """Return ``(x1_mix, x2_mix, x3_mix, x1_re, x2_re, x3_re)``.

        ``xi_mix`` has width ``min(dims) + dims[i-1]``; ``xi_re`` is its
        Linear projection back to ``dims[i-1]``.
        """
        x1_inv, x1_spec = self.modalA(x1)
        x2_inv, x2_spec = self.modalB(x2)
        x3_inv, x3_spec = self.modalC(x3)
        # Cyclic swap 3 -> 1 -> 2 -> 3; concatenate on the last (feature)
        # axis, which equals dim=1 for 2-D (batch, features) inputs.
        x1_mix = torch.cat([x3_inv, x1_spec], dim=-1)
        x2_mix = torch.cat([x1_inv, x2_spec], dim=-1)
        x3_mix = torch.cat([x2_inv, x3_spec], dim=-1)
        x1_re = self.P_reconA(x1_mix)
        x2_re = self.P_reconB(x2_mix)
        x3_re = self.P_reconC(x3_mix)
        return x1_mix, x2_mix, x3_mix, x1_re, x2_re, x3_re
    
class ElementSplit(nn.Module):
    """Expand features, then split them into an ``inv`` and a ``spec`` part.

    ``P_dis`` projects the last axis from ``dim`` to ``dislen = int(dim *
    multiple)`` features. The first ``boundary`` expanded features go through
    ``P_con_inv`` (to ``min_len`` features); the remaining
    ``dislen - boundary`` go through ``P_cov_spec`` (back to ``dim``).

    Args:
        dim: input feature width (size of the input's last axis).
        min_len: output width of the ``inv`` projection.
        multiple: expansion factor applied to ``dim``.
        boundary: split point in the expanded features. Values ``> 1`` are an
            absolute feature count; values ``<= 1`` are a fraction of the
            expanded width (so ``1`` means "all features", not one feature).

    Raises:
        ValueError: if the resolved split point is outside ``[0, dislen]``.
    """

    def __init__(self, dim, min_len, multiple=4, boundary=0.5):
        super(ElementSplit, self).__init__()
        if boundary > 1.:
            # Absolute index; cast so e.g. 64.0 can still be used for slicing
            # and as an nn.Linear size.
            self.boundary = int(boundary)
        else:
            # Fraction of the expanded width.
            self.boundary = int(boundary * multiple * dim)
        self.dislen = int(dim * multiple)
        if not 0 <= self.boundary <= self.dislen:
            raise ValueError(
                f"boundary resolves to {self.boundary}, outside [0, {self.dislen}]"
            )
        self.dim = dim
        # Use the integer width so a non-integer `multiple` cannot hand
        # nn.Linear a float size, and so P_dis's output matches `dislen`.
        self.P_dis = nn.Linear(dim, self.dislen)
        self.P_con_inv = nn.Linear(self.boundary, min_len)
        self.P_cov_spec = nn.Linear(self.dislen - self.boundary, dim)

    def forward(self, x):
        """Return ``(x_inv, x_spec)`` with last-axis widths ``min_len`` and ``dim``."""
        expanded = self.P_dis(x)
        # Slice the last (feature) axis so leading batch axes are preserved;
        # for 2-D (batch, features) input this is the same as x[:, ...].
        x_inv = self.P_con_inv(expanded[..., :self.boundary])
        x_spec = self.P_cov_spec(expanded[..., self.boundary:self.dislen])
        return x_inv, x_spec


class ReconstructionLoss(nn.Module):
    """Mean-squared error between a reconstruction and its reference tensor.

    Stateless: ``nn.Module.__init__`` is inherited unchanged.
    """

    def forward(self, x_recon, x_original):
        """Return ``mse_loss(x_recon, x_original)`` with the default ``'mean'`` reduction."""
        loss = F.mse_loss(input=x_recon, target=x_original)
        return loss


if __name__ == "__main__":
    import time

    from thop import profile

    def _benchmark(net, inputs, n_runs=11):
        """Time `n_runs` forward passes of `net` on `inputs`, then print thop stats.

        `inputs` is the tuple of positional tensors passed to the module.
        """
        for _ in range(n_runs):
            since = time.perf_counter()
            # Call the module (not .forward) so nn.Module hooks run as usual.
            outputs = net(*inputs)
            print(*(out.shape for out in outputs))
            print(time.perf_counter() - since)
        flops, params = profile(net, inputs=inputs)
        # Doubled presumably because thop reports multiply-accumulates (MACs);
        # confirm against the installed thop version.
        print('FLOPs = ' + str(2 * flops / 1000 ** 3) + 'G')
        print('Params = ' + str(params / 1000 ** 2) + 'M')

    # Two-modality model.
    inputs = (torch.randn(1, 4096), torch.randn(1, 128))
    _benchmark(PMR(dims=[4096, 128], boundary=0.125), inputs)

    # Three-modality model: disabled by default. An explicit flag is used
    # because an `assert False` guard is stripped under `python -O`.
    run_multimodal_demo = False
    if run_multimodal_demo:
        inputs = (torch.randn(1, 4096), torch.randn(1, 128), torch.randn(1, 768))
        _benchmark(PMR_MM(dims=[4096, 128, 768], boundary=0.5), inputs)
