import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import kl_loss


def positional_encoding(tensor):
    """
    Add sinusoidal positional encoding to the input tensor.

    Even feature columns receive ``sin(k * w_i)`` and odd columns ``cos(k * w_i)``
    with ``w_i = 1 / 10000 ** (2 * i / feature_size)`` and ``k`` the time index.
    Odd ``feature_size`` is supported: the last cosine column is dropped.

    The table is built directly on ``tensor.device`` (no host-to-device copy)
    in the default float precision and, for floating-point inputs, cast to the
    input dtype so half/double inputs keep their dtype.

    Parameters:
    - tensor (torch.Tensor): Input tensor with shape [batch, time_len, feature_size].

    Returns:
    - torch.Tensor: New tensor with added positional encodings (the input is
      not modified).

    Raises:
    - ValueError: If ``tensor`` is not 3-dimensional.
    """
    _, time_len, feature_size = tensor.size()
    device = tensor.device

    # One angular frequency per (sin, cos) column pair.
    inv_freq = 1 / 10000 ** (torch.arange(0, feature_size, 2, device=device) / feature_size)
    positions = torch.arange(time_len, device=device, dtype=inv_freq.dtype)
    angles = torch.outer(positions, inv_freq)  # [time_len, ceil(feature_size / 2)]

    encoding = torch.zeros((time_len, feature_size), device=device)
    encoding[:, 0::2] = angles.sin()
    # With an odd feature_size there is one fewer odd column than even columns.
    encoding[:, 1::2] = angles.cos()[:, :feature_size // 2]

    # Compute in full precision, then match the input so the sum is not
    # silently promoted (e.g. float16 input -> float32 output).
    if tensor.is_floating_point():
        encoding = encoding.to(tensor.dtype)

    # Broadcasts over the batch dimension.
    return tensor + encoding


class InterModalAttention(nn.TransformerEncoderLayer):
    """Transformer encoder layer whose attention queries come from a second input.

    Parameters and sub-modules are those of ``nn.TransformerEncoderLayer``.
    Only the attention call differs: queries are taken from ``src_2`` while
    keys, values and the residual connections come from ``src``.

    NOTE(review): when ``src`` has a single time step the attention softmax
    runs over one key, so the attention output is the same for every query
    and ``src_2`` has no influence on the result.
    """

    def forward(
            self,
            src, src_2,
            src_mask=None,
            src_key_padding_mask=None,
            is_causal=False):
        r"""Pass the inputs through the cross-attending encoder layer.

        Args:
            src: sequence providing keys, values and the residual stream (required).
            src_2: sequence providing the attention queries (required). The
                attention output has ``src_2``'s sequence length and is added
                to ``src``, so the two must be broadcast-compatible.
            src_mask: the attention mask (optional).
            src_key_padding_mask: the mask for the ``src`` keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``src mask``.
                Default: ``False``.
                Warning:
                ``is_causal`` provides a hint that ``src_mask`` is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.

        Returns:
            The transformed ``src`` stream.

        Shape:
            see the docs in Transformer class.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        # The stock layer's fused inference kernel
        # (torch._transformer_encoder_layer_fwd) implements plain self-attention
        # on ``src`` only. Taking it would silently ignore ``src_2`` in
        # eval / no-grad mode and make inference disagree with training, so
        # this layer always runs the explicit path below.

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = src
        y = src_2
        if self.norm_first:
            # Pre-norm: both streams are normalised with norm1 before attention.
            x = x + self._sa_block(self.norm1(x), self.norm1(y), src_mask, src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, y, src_mask, src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x

    # cross-attention block: query = y, key = value = x
    def _sa_block(self, x, y, attn_mask, key_padding_mask, is_causal=False):
        x = self.self_attn(y, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)


def reparameterization(mu, log_var):
    """Sample ``mu + std * eps`` with ``std = exp(0.5 * log_var)``.

    ``eps`` is standard-normal noise drawn with the shape, dtype and device
    of ``log_var``; the result broadcasts ``mu`` against it.
    """
    std = (log_var * 0.5).exp()
    noise = torch.randn_like(std)
    return mu + std * noise


class HOGNet(nn.Module):
    """Convolutional reducer for flattened 4464-value feature vectors.

    Each vector is batch-normalised, reshaped to ``[12, 12, 31]`` and moved to
    channels-first layout by swapping axes 1 and 3, then passed through two
    unpadded, bias-free 3x3 convolutions (31 -> 4 -> 2 channels) with a
    LeakyReLU in between. The result is ``[N, 2, 8, 8]``, i.e. 128 values per
    input row.
    """

    def __init__(self):
        super().__init__()
        self.norm1 = nn.BatchNorm1d(num_features=4464)
        self.conv1 = nn.Conv2d(in_channels=31, out_channels=4, kernel_size=3,
                               stride=1, padding=0, bias=False)
        self.act = nn.LeakyReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=3,
                               stride=1, padding=0, bias=False)

    def forward(self, x):
        """Map ``[N, 4464]`` rows to ``[N, 2, 8, 8]`` feature maps."""
        # transpose(1, 3) puts the 31-wide axis first; note that it also swaps
        # the two 12-wide axes relative to their order after the view.
        grid = self.norm1(x).view(-1, 12, 12, 31).transpose(1, 3)
        hidden = self.act(self.conv1(grid))
        return self.conv2(hidden)


class Net(nn.Module):
    """Three-stream (x_v / x_a / x_t) fusion model with a sampled latent and a score head.

    Pipeline, as implemented in ``forward``:

    1. Feature-group MLPs (and ``HOGNet`` for the trailing 4464 values of
       ``x_v``) embed every time step into 128 dimensions; group embeddings of
       one stream are summed.
    2. A shared "homo" encoder plus per-stream "hetero" and "noise" encoders
       process each positionally-encoded sequence.
    3. The first time step of ``homo + hetero - noise`` of each stream is fed
       pairwise through ``InterModalAttention`` to build ``mu`` and ``sigma``.
    4. A latent is sampled, scored through a sigmoid head, and decoded back
       towards the first-step homo / hetero features for an L1 loss.

    Attribute names and construction order are significant: they define the
    ``state_dict`` keys and the order in which random initialisation is drawn.
    """

    @staticmethod
    def _mlp(in_dim, hidden_dim):
        """LayerNorm + bias-free two-layer LeakyReLU MLP with 128 outputs."""
        return nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.LeakyReLU(inplace=True),
            nn.Linear(hidden_dim, 128, bias=False),
        )

    @staticmethod
    def _encoder_stack():
        """Eight-layer transformer encoder (d_model=128, 4 heads, batch first)."""
        layer = nn.TransformerEncoderLayer(d_model=128, nhead=4, batch_first=True)
        return nn.TransformerEncoder(layer, num_layers=8)

    def __init__(self):
        super().__init__()

        # Per-group embedders for the x_v stream.
        self.pose_squeeze = self._mlp(6, 128)
        self.gaze_squeeze = self._mlp(6, 128)
        self.au_squeeze = self._mlp(35, 128)
        self.lm_squeeze = self._mlp(136, 256)
        self.hog_squeeze = HOGNet()

        # Per-group embedders for the x_a stream.
        self.mfcc_squeeze = self._mlp(39, 256)
        self.egemaps_squeeze = self._mlp(88, 256)
        self.covarep_squeeze = self._mlp(81, 256)

        # Embedder for the x_t stream.
        self.bert_squeeze = self._mlp(768, 256)

        # One encoder shared by all streams, then one per stream.
        self.transformer_encoder_homo = self._encoder_stack()
        self.transformer_encoder_hetero_v = self._encoder_stack()
        self.transformer_encoder_hetero_a = self._encoder_stack()
        self.transformer_encoder_hetero_t = self._encoder_stack()

        self.transformer_noise_v = self._encoder_stack()
        self.transformer_noise_a = self._encoder_stack()
        self.transformer_noise_t = self._encoder_stack()

        self.transformer_intermodal = InterModalAttention(d_model=128, nhead=4, batch_first=True)

        # Score head; the final Sigmoid keeps the score in (0, 1).
        self.score = nn.Sequential(
            nn.Linear(128, 1024),
            nn.LeakyReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

        # Learnable per-stream scales that expand the latent to three rows.
        self.vec_homo = nn.Parameter(torch.ones(1, 3, 1))
        self.vec_hete = nn.Parameter(torch.ones(1, 3, 1))

        self.decoder_homo_a = self._mlp(128, 256)
        self.decoder_homo_v = self._mlp(128, 256)
        self.decoder_homo_t = self._mlp(128, 256)
        self.decoder_hete_a = self._mlp(128, 256)
        self.decoder_hete_v = self._mlp(128, 256)
        self.decoder_hete_t = self._mlp(128, 256)

        self.mae = nn.L1Loss()

    def forward(self, x_v, x_a, x_t):
        """Run the full model.

        Args:
            x_v: ``[b, t_v, 4647]``; the last axis is split as ``[:6]``,
                ``[6:12]``, ``[12:47]``, ``[47:183]`` and ``[183:]`` (4464
                values) for the pose / gaze / au / lm / hog embedders.
            x_a: ``[b, t_a, 208]``; split as ``[:39]``, ``[39:127]`` and
                ``[127:]`` for the mfcc / egemaps / covarep embedders.
            x_t: ``[b, t_t, 768]``.

        Returns:
            dict with ``'score'`` (``[b]``), the losses ``'l_kl'`` (from
            ``kl_loss``) and ``'l_rec'`` (mean of six L1 terms), and the
            first-time-step homo / hete / noise features of every stream
            (``[b, 128]`` each).
        """

        def squeeze(module, feats, width):
            # Flatten [b, t, width] to rows, embed, and restore [b, t, 128].
            rows = module(feats.view(-1, width))
            return rows.view(feats.shape[0], feats.shape[1], -1)

        def time_mean(seq):
            # Mean over the time axis, kept as [b, 1, 128] for broadcasting.
            return torch.mean(seq, dim=1, keepdim=True)

        # ---- per-step embeddings -------------------------------------------
        vis = (squeeze(self.pose_squeeze, x_v[:, :, :6], 6)
               + squeeze(self.gaze_squeeze, x_v[:, :, 6:12], 6)
               + squeeze(self.au_squeeze, x_v[:, :, 12:47], 35)
               + squeeze(self.lm_squeeze, x_v[:, :, 47:183], 136)
               + squeeze(self.hog_squeeze, x_v[:, :, 183:], 4464))

        aud = (squeeze(self.mfcc_squeeze, x_a[:, :, :39], 39)
               + squeeze(self.egemaps_squeeze, x_a[:, :, 39:127], 88)
               + squeeze(self.covarep_squeeze, x_a[:, :, 127:], 81))

        txt = squeeze(self.bert_squeeze, x_t, 768)

        # Inject temporal order; later stages read the first time step as the
        # summary of each stream.
        vis = positional_encoding(vis)
        aud = positional_encoding(aud)
        txt = positional_encoding(txt)

        # ---- shared ("homo") encoder, with the other streams' means added ----
        shared = self.transformer_encoder_homo
        x_v_homo = shared(vis + time_mean(aud) + time_mean(txt)) + vis  # [b, t_v, 128]
        x_a_homo = shared(aud + time_mean(vis) + time_mean(txt)) + aud  # [b, t_a, 128]
        x_t_homo = shared(txt + time_mean(vis) + time_mean(aud)) + txt  # [b, t_t, 128]

        # ---- per-stream ("hetero") encoders, also fed the homo output --------
        x_v_hete = self.transformer_encoder_hetero_v(
            vis + time_mean(aud) + time_mean(txt) + x_v_homo) + vis
        x_a_hete = self.transformer_encoder_hetero_a(
            aud + time_mean(vis) + time_mean(txt) + x_a_homo) + aud
        x_t_hete = self.transformer_encoder_hetero_t(
            txt + time_mean(vis) + time_mean(aud) + x_t_homo) + txt

        # ---- per-stream noise encoders ---------------------------------------
        x_v_noise = self.transformer_noise_v(vis)
        x_a_noise = self.transformer_noise_a(aud)
        x_t_noise = self.transformer_noise_t(txt)

        # ---- cross-modal inputs: first step of homo + hetero - noise ---------
        fused_v = x_v_homo[:, 0:1, :] + x_v_hete[:, 0:1, :] - x_v_noise[:, 0:1, :]
        fused_a = x_a_homo[:, 0:1, :] + x_a_hete[:, 0:1, :] - x_a_noise[:, 0:1, :]
        fused_t = x_t_homo[:, 0:1, :] + x_t_hete[:, 0:1, :] - x_t_noise[:, 0:1, :]

        # ---- latent statistics ------------------------------------------------
        # ``sigma`` is consumed as a log-variance by ``reparameterization``.
        # NOTE(review): every fused tensor has one time step, so the attention
        # in InterModalAttention sees a single key and its output does not
        # depend on the second argument; mu and sigma are then the same three
        # terms summed in a different order (differing only through dropout
        # in training) -- confirm this is intended.
        cross = self.transformer_intermodal
        mu = cross(fused_t, fused_v) + cross(fused_a, fused_t) + cross(fused_v, fused_a)
        sigma = cross(fused_v, fused_t) + cross(fused_t, fused_a) + cross(fused_a, fused_v)
        latents = reparameterization(mu, sigma)  # [b, 1, 128]

        vae_kl = kl_loss(mu.squeeze(1), sigma.squeeze(1))

        s = self.score(latents).view(-1)  # [b]

        # [1, 3, 1] @ [b, 1, 128] -> [b, 3, 128]: one scaled copy per stream.
        x_homo_dec = torch.matmul(self.vec_homo, latents)
        x_hete_dec = torch.matmul(self.vec_hete, latents)

        # ---- residual decoders (row 0 = v, 1 = a, 2 = t) ----------------------
        x_v_hete_dec = self.decoder_hete_v(x_hete_dec[:, 0, :]) + x_hete_dec[:, 0, :]
        x_a_hete_dec = self.decoder_hete_a(x_hete_dec[:, 1, :]) + x_hete_dec[:, 1, :]
        x_t_hete_dec = self.decoder_hete_t(x_hete_dec[:, 2, :]) + x_hete_dec[:, 2, :]

        x_v_homo_dec = self.decoder_homo_v(x_homo_dec[:, 0, :]) + x_homo_dec[:, 0, :]
        x_a_homo_dec = self.decoder_homo_a(x_homo_dec[:, 1, :]) + x_homo_dec[:, 1, :]
        x_t_homo_dec = self.decoder_homo_t(x_homo_dec[:, 2, :]) + x_homo_dec[:, 2, :]

        # ---- reconstruction loss: mean of the six L1 terms --------------------
        l1 = self.mae
        vae_diff = (l1(x_v_hete_dec, x_v_hete[:, 0, :])
                    + l1(x_a_hete_dec, x_a_hete[:, 0, :])
                    + l1(x_t_hete_dec, x_t_hete[:, 0, :])
                    + l1(x_v_homo_dec, x_v_homo[:, 0, :])
                    + l1(x_a_homo_dec, x_a_homo[:, 0, :])
                    + l1(x_t_homo_dec, x_t_homo[:, 0, :])) / 6.

        return {
            'score': s,
            'l_kl': vae_kl,
            'l_rec': vae_diff,
            'x_v_homo': x_v_homo[:, 0, :],
            'x_a_homo': x_a_homo[:, 0, :],
            'x_t_homo': x_t_homo[:, 0, :],
            'x_v_hete': x_v_hete[:, 0, :],
            'x_a_hete': x_a_hete[:, 0, :],
            'x_t_hete': x_t_hete[:, 0, :],
            'x_v_noise': x_v_noise[:, 0, :],
            'x_a_noise': x_a_noise[:, 0, :],
            'x_t_noise': x_t_noise[:, 0, :],
        }


class NetTransformer(nn.Module):
    """Variant of the fusion model whose feature embedders are small transformers.

    Differences visible in this class compared with a plain-MLP embedder
    design: positional encoding is applied to the raw inputs before embedding,
    each feature group is embedded by LayerNorm + Linear + a 4-layer
    transformer encoder, cross-stream context is the first time step of the
    other streams, a third ("n") decoder bank is added, and ``forward`` returns
    the decoded tensors and latent statistics instead of losses.

    Attribute names and construction order are significant: they define the
    ``state_dict`` keys and the order in which random initialisation is drawn.
    """

    @staticmethod
    def _encoder_stack(depth):
        """Transformer encoder with ``depth`` layers (d_model=128, 4 heads, batch first)."""
        return nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=128, nhead=4, batch_first=True),
            num_layers=depth)

    @staticmethod
    def _squeezer(in_dim):
        """LayerNorm + bias-free projection to 128 dims + 4-layer encoder."""
        return nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, 128, bias=False),
            NetTransformer._encoder_stack(4),
        )

    @staticmethod
    def _decoder():
        """LayerNorm + bias-free three-layer LeakyReLU MLP (128 -> 256 -> 256 -> 128)."""
        return nn.Sequential(
            nn.LayerNorm(128),
            nn.Linear(128, 256, bias=False),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 256, bias=False),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 128, bias=False),
        )

    def __init__(self):
        super().__init__()

        # Per-group embedders for the x_v stream.
        self.pose_squeeze = self._squeezer(6)
        self.gaze_squeeze = self._squeezer(6)
        self.au_squeeze = self._squeezer(35)
        self.lm_squeeze = self._squeezer(136)
        self.hog_squeeze = HOGNet()

        # Per-group embedders for the x_a stream.
        self.mfcc_squeeze = self._squeezer(39)
        self.egemaps_squeeze = self._squeezer(88)
        self.covarep_squeeze = self._squeezer(81)

        # Embedder for the x_t stream.
        self.bert_squeeze = self._squeezer(768)

        # One encoder shared by all streams, then one per stream.
        self.transformer_encoder_homo = self._encoder_stack(8)
        self.transformer_encoder_hetero_v = self._encoder_stack(8)
        self.transformer_encoder_hetero_a = self._encoder_stack(8)
        self.transformer_encoder_hetero_t = self._encoder_stack(8)

        self.transformer_noise_v = self._encoder_stack(8)
        self.transformer_noise_a = self._encoder_stack(8)
        self.transformer_noise_t = self._encoder_stack(8)

        self.transformer_intermodal = InterModalAttention(d_model=128, nhead=4, batch_first=True)

        # Score head; the final Sigmoid keeps the score in (0, 1).
        self.score = nn.Sequential(
            nn.Linear(128, 1024),
            nn.LeakyReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

        # Learnable per-stream scales that expand the latent to three rows.
        self.vec_homo = nn.Parameter(torch.ones(1, 3, 1))
        self.vec_hete = nn.Parameter(torch.ones(1, 3, 1))
        self.vec_noise = nn.Parameter(torch.ones(1, 3, 1))

        self.decoder_homo_a = self._decoder()
        self.decoder_homo_v = self._decoder()
        self.decoder_homo_t = self._decoder()
        self.decoder_hete_a = self._decoder()
        self.decoder_hete_v = self._decoder()
        self.decoder_hete_t = self._decoder()
        self.decoder_n_a = self._decoder()
        self.decoder_n_v = self._decoder()
        self.decoder_n_t = self._decoder()

        # Not used inside forward; kept as a registered attribute.
        self.mae = nn.L1Loss()

    def forward(self, x_v, x_a, x_t):
        """Run the full model.

        Args:
            x_v: ``[b, t_v, 4647]``; the last axis is split as ``[:6]``,
                ``[6:12]``, ``[12:47]``, ``[47:183]`` and ``[183:]`` (4464
                values) for the pose / gaze / au / lm / hog embedders.
            x_a: ``[b, t_a, 208]``; split as ``[:39]``, ``[39:127]`` and
                ``[127:]`` for the mfcc / egemaps / covarep embedders.
            x_t: ``[b, t_t, 768]``.

        Returns:
            dict with ``'score'`` (``[b]``), the first-time-step homo / hete /
            noise features of every stream, the nine decoded tensors
            (``*_dec``), and ``'mu'``, ``'sigma'``, ``'latents'`` -- all
            ``[b, 128]`` except the score.
        """

        def lead(seq):
            # First time step, kept as [b, 1, 128] so it broadcasts over time.
            return seq[:, 0:1, :]

        # The hog branch reads the raw tail of x_v, before any positional
        # encoding is added.
        n_batch, n_frames = x_v.shape[0], x_v.shape[1]
        hog_emb = self.hog_squeeze(x_v[:, :, 183:].view(-1, 4464)).view(n_batch, n_frames, -1)

        # Here the positional encoding is applied to the raw feature vectors.
        enc_v = positional_encoding(x_v)
        enc_a = positional_encoding(x_a)
        enc_t = positional_encoding(x_t)

        # ---- per-step embeddings ([b, t, 128] each) ---------------------------
        vis = (self.pose_squeeze(enc_v[:, :, :6])
               + self.gaze_squeeze(enc_v[:, :, 6:12])
               + self.au_squeeze(enc_v[:, :, 12:47])
               + self.lm_squeeze(enc_v[:, :, 47:183])
               + hog_emb)

        aud = (self.mfcc_squeeze(enc_a[:, :, :39])
               + self.egemaps_squeeze(enc_a[:, :, 39:127])
               + self.covarep_squeeze(enc_a[:, :, 127:]))

        txt = self.bert_squeeze(enc_t)

        # ---- shared ("homo") encoder, with the other streams' first step -----
        shared = self.transformer_encoder_homo
        x_v_homo = shared(vis + lead(aud) + lead(txt)) + vis  # [b, t_v, 128]
        x_a_homo = shared(aud + lead(vis) + lead(txt)) + aud  # [b, t_a, 128]
        x_t_homo = shared(txt + lead(aud) + lead(vis)) + txt  # [b, t_t, 128]

        # ---- per-stream ("hetero") encoders, also fed the homo output --------
        x_v_hete = self.transformer_encoder_hetero_v(
            vis + lead(aud) + lead(txt) + x_v_homo) + vis
        x_a_hete = self.transformer_encoder_hetero_a(
            aud + lead(vis) + lead(txt) + x_a_homo) + aud
        x_t_hete = self.transformer_encoder_hetero_t(
            txt + lead(aud) + lead(vis) + x_t_homo) + txt

        # ---- per-stream noise encoders ---------------------------------------
        x_v_noise = self.transformer_noise_v(vis)
        x_a_noise = self.transformer_noise_a(aud)
        x_t_noise = self.transformer_noise_t(txt)

        # ---- cross-modal inputs: first step of homo + hetero - noise ---------
        fused_v = x_v_homo[:, 0:1, :] + x_v_hete[:, 0:1, :] - x_v_noise[:, 0:1, :]
        fused_a = x_a_homo[:, 0:1, :] + x_a_hete[:, 0:1, :] - x_a_noise[:, 0:1, :]
        fused_t = x_t_homo[:, 0:1, :] + x_t_hete[:, 0:1, :] - x_t_noise[:, 0:1, :]

        # ---- latent statistics ------------------------------------------------
        # ``sigma`` is consumed as a log-variance by ``reparameterization``.
        # NOTE(review): every fused tensor has one time step, so the attention
        # in InterModalAttention sees a single key and its output does not
        # depend on the second argument; mu and sigma are then the same three
        # terms summed in a different order (differing only through dropout
        # in training) -- confirm this is intended.
        cross = self.transformer_intermodal
        mu = cross(fused_t, fused_v) + cross(fused_a, fused_t) + cross(fused_v, fused_a)
        sigma = cross(fused_v, fused_t) + cross(fused_t, fused_a) + cross(fused_a, fused_v)
        latents = reparameterization(mu, sigma)  # [b, 1, 128]

        s = self.score(latents).view(-1)  # [b]

        # [1, 3, 1] @ [b, 1, 128] -> [b, 3, 128]: one scaled copy per stream.
        x_homo_dec = torch.matmul(self.vec_homo, latents)
        x_hete_dec = torch.matmul(self.vec_hete, latents)
        x_n_dec = torch.matmul(self.vec_noise, latents)

        # ---- residual decoders (row 0 = v, 1 = a, 2 = t) ----------------------
        x_v_hete_dec = self.decoder_hete_v(x_hete_dec[:, 0, :]) + x_hete_dec[:, 0, :]
        x_a_hete_dec = self.decoder_hete_a(x_hete_dec[:, 1, :]) + x_hete_dec[:, 1, :]
        x_t_hete_dec = self.decoder_hete_t(x_hete_dec[:, 2, :]) + x_hete_dec[:, 2, :]

        x_v_homo_dec = self.decoder_homo_v(x_homo_dec[:, 0, :]) + x_homo_dec[:, 0, :]
        x_a_homo_dec = self.decoder_homo_a(x_homo_dec[:, 1, :]) + x_homo_dec[:, 1, :]
        x_t_homo_dec = self.decoder_homo_t(x_homo_dec[:, 2, :]) + x_homo_dec[:, 2, :]

        x_v_n_dec = self.decoder_n_v(x_n_dec[:, 0, :]) + x_n_dec[:, 0, :]
        x_a_n_dec = self.decoder_n_a(x_n_dec[:, 1, :]) + x_n_dec[:, 1, :]
        x_t_n_dec = self.decoder_n_t(x_n_dec[:, 2, :]) + x_n_dec[:, 2, :]

        return {
            'score': s,
            'x_v_homo': x_v_homo[:, 0, :],
            'x_a_homo': x_a_homo[:, 0, :],
            'x_t_homo': x_t_homo[:, 0, :],
            'x_v_hete': x_v_hete[:, 0, :],
            'x_a_hete': x_a_hete[:, 0, :],
            'x_t_hete': x_t_hete[:, 0, :],
            'x_v_noise': x_v_noise[:, 0, :],
            'x_a_noise': x_a_noise[:, 0, :],
            'x_t_noise': x_t_noise[:, 0, :],
            'x_v_hete_dec': x_v_hete_dec,
            'x_a_hete_dec': x_a_hete_dec,
            'x_t_hete_dec': x_t_hete_dec,
            'x_v_homo_dec': x_v_homo_dec,
            'x_a_homo_dec': x_a_homo_dec,
            'x_t_homo_dec': x_t_homo_dec,
            'x_v_n_dec': x_v_n_dec,
            'x_a_n_dec': x_a_n_dec,
            'x_t_n_dec': x_t_n_dec,
            'mu': mu.squeeze(1),
            'sigma': sigma.squeeze(1),
            'latents': latents.squeeze(1),
        }


if __name__ == '__main__':
    # Smoke test: push all-zero inputs of the expected widths through
    # NetTransformer and print the resulting per-sample scores.
    batch = 2
    video_steps = 37
    audio_steps = int(video_steps * 3.33)  # audio sequence is ~3.33x longer

    model = NetTransformer()
    audio = torch.zeros((batch, audio_steps, 208))
    video = torch.zeros((batch, video_steps, 4647))
    text = torch.zeros((batch, 15, 768))

    result = model(video, audio, text)
    print(result['score'])
