import torch.nn as nn
import torch
import torch.nn.functional as F


class MAP_IVR(nn.Module):
    """Two-input network: three ``ElementMapper`` branches and a linear head.

    ``x1`` is fed through two independent mappers (``motion`` and ``appear``)
    and ``x2`` through a third (``image``); every branch emits a
    ``latent_dim``-wide embedding.  The ``recon`` head maps the concatenated
    motion and image embeddings to a ``dims[1]``-wide vector.  The ``appear``
    embedding is returned to the caller but is not an input to ``recon``.

    Args:
        dims: Indexable with at least two entries; ``dims[0]`` is the feature
            width of ``x1`` and ``dims[1]`` the feature width of ``x2``.
        latent_dim: Output width of each mapper branch.
    """

    def __init__(self, dims, latent_dim=1024):
        super().__init__()
        self.latent_dim = latent_dim

        # Both x1 branches share the same input width but not their weights.
        x1_dim = dims[0]
        self.motion = ElementMapper(x1_dim, latent_dim)
        self.appear = ElementMapper(x1_dim, latent_dim)

        x2_dim = dims[1]
        self.image = ElementMapper(x2_dim, latent_dim)

        # The head consumes two concatenated embeddings, hence 2 * latent_dim.
        self.recon = nn.Linear(2 * latent_dim, x2_dim)

    def forward(self, x1, x2):
        """Embed both inputs and compute the reconstruction output.

        Args:
            x1: Input for the ``motion`` and ``appear`` branches.
            x2: Input for the ``image`` branch.

        Returns:
            Tuple ``(motion, appear, image, recon)`` where the first three are
            the branch embeddings and ``recon`` is the linear head applied to
            ``motion`` and ``image`` concatenated along the last dimension.
        """
        motion_emb = self.motion(x1)
        appear_emb = self.appear(x1)
        image_emb = self.image(x2)

        fused = torch.cat([motion_emb, image_emb], dim=-1)
        return motion_emb, appear_emb, image_emb, self.recon(fused)


class ElementMapper(nn.Module):
    """Two-stage MLP that projects ``dim``-wide features to ``latent_dim``.

    Each stage is ``Linear -> BatchNorm1d -> ReLU``.  The first stage maps
    ``dim`` to ``latent_dim``; the second keeps the width at ``latent_dim``.

    Args:
        dim: Feature width of the input.
        latent_dim: Feature width of the output.
    """

    def __init__(self, dim, latent_dim):
        super().__init__()
        stages = []
        # Only the input width differs between the two stages; layers are
        # created in the same order they run so Sequential indices stay 0..5.
        for in_features in (dim, latent_dim):
            stages += [
                nn.Linear(in_features, latent_dim),
                nn.BatchNorm1d(latent_dim),
                nn.ReLU(),
            ]
        self.mapper = nn.Sequential(*stages)

    def forward(self, x):
        """Return the result of running ``x`` through the stacked stages."""
        return self.mapper(x)


class ReconstructionLoss(nn.Module):
    """Mean-squared-error loss between a reconstruction and its target."""

    def __init__(self):
        super().__init__()

    def forward(self, x_recon, x_original):
        """Return the element-wise squared error averaged over all elements.

        Args:
            x_recon: Reconstructed tensor.
            x_original: Target tensor the reconstruction is compared against.

        Returns:
            Scalar tensor holding the mean squared error.
        """
        # "mean" is the default reduction of F.mse_loss; spelled out here so
        # the averaging behaviour is visible at the call site.
        return F.mse_loss(input=x_recon, target=x_original, reduction="mean")


if __name__ == "__main__":
    # Script entry point: build the network with fixed feature widths, print
    # its structure, and report operation / parameter counts from thop.
    feature_dims = [4096, 128]
    sample_x1 = torch.randn(1, feature_dims[0])
    sample_x2 = torch.randn(1, feature_dims[1])
    model = MAP_IVR(dims=feature_dims)
    print(model)

    # Imported here rather than at module top, so thop is only required when
    # this file is executed as a script.
    from thop import profile

    op_count, param_count = profile(model, inputs=(sample_x1, sample_x2))

    # The raw count is doubled before reporting — presumably because thop
    # counts multiply-accumulates rather than FLOPs; confirm against thop docs.
    giga_flops = 2 * op_count / 1000 ** 3
    mega_params = param_count / 1000 ** 2
    print('FLOPs = ' + str(giga_flops) + 'G')
    print('Params = ' + str(mega_params) + 'M')
