# mainly borrowed from https://github.com/ChenFengYe/motion-latent-diffusion/blob/main/mld/models/architectures/mld_vae.py

from functools import reduce
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor, nn
from torch.distributions.distribution import Distribution

from operator.cross_attention import (
    SkipTransformerEncoder,
    SkipTransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoderLayer,
)
from operator.position_encoding import build_position_encoding
from utils.temos_utils import lengths_to_mask


class MldVae(nn.Module):
    """Transformer-based variational autoencoder over motion feature sequences.

    The encoder prepends ``2 * latent_size`` learned tokens to the embedded
    frame sequence; the encoder outputs at those token positions are split
    into the mean and log-variance of a diagonal Normal. The decoder
    cross-attends from positional queries (one per output frame) to a sampled
    latent and projects back to the feature dimension.

    Internally all transformer inputs are sequence-first
    (``[seq, batch, latent_dim]``), while the public API takes and returns
    batch-first features (``[batch, nframes, nfeats]``).
    """

    def __init__(self,
                 nfeats: int,
                 latent_dim: Union[List[int], Tuple[int, int]] = (1, 256),
                 ff_size: int = 1024,
                 num_layers: int = 9,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 arch: str = "all_encoder",
                 normalize_before: bool = False,
                 activation: str = "gelu",
                 position_embedding: str = "learned",
                 mean_std_inv=None,
                 **kwargs) -> None:
        """Build the encoder/decoder stacks and the input/output projections.

        Args:
            nfeats: Size of the last axis of the motion features; used for
                both the input embedding and the output projection.
            latent_dim: Two-element sequence ``(latent_size, latent_dim)``:
                number of latent tokens and transformer model width.
            ff_size: Feed-forward size passed to each transformer layer.
            num_layers: Number of layers in each of the encoder and decoder.
            num_heads: Number of attention heads per layer.
            dropout: Dropout value passed to each transformer layer.
            arch: Stored on the instance as ``self.arch``; not otherwise
                read inside this class.
            normalize_before: Forwarded to the transformer layers.
            activation: Activation name forwarded to the transformer layers.
            position_embedding: Kind of positional encoding requested from
                ``build_position_encoding``.
            mean_std_inv: Stored on the instance as ``self.mean_std_inv``;
                not otherwise read inside this class.
            **kwargs: Accepted and ignored, so callers may pass extra
                configuration keys.
        """
        super().__init__()

        # latent_dim is only unpacked, never mutated.
        self.latent_size, self.latent_dim = latent_dim
        input_feats = nfeats
        output_feats = nfeats
        self.arch = arch
        self.mean_std_inv = mean_std_inv

        # Separate positional encodings for the encoder input and the
        # decoder queries.
        self.query_pos_encoder = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)
        self.query_pos_decoder = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)

        encoder_layer = TransformerEncoderLayer(
            self.latent_dim,
            num_heads,
            ff_size,
            dropout,
            activation,
            normalize_before,
        )
        encoder_norm = nn.LayerNorm(self.latent_dim)
        self.encoder = SkipTransformerEncoder(encoder_layer, num_layers,
                                              encoder_norm)

        decoder_layer = TransformerDecoderLayer(
            self.latent_dim,
            num_heads,
            ff_size,
            dropout,
            activation,
            normalize_before,
        )
        decoder_norm = nn.LayerNorm(self.latent_dim)
        self.decoder = SkipTransformerDecoder(decoder_layer, num_layers,
                                              decoder_norm)

        # First `latent_size` rows yield the mean, the remaining rows the
        # log-variance (see encode_dist2z).
        self.global_motion_token = nn.Parameter(
            torch.randn(self.latent_size * 2, self.latent_dim))

        self.skel_embedding = nn.Linear(input_feats, self.latent_dim)
        self.final_layer = nn.Linear(self.latent_dim, output_feats)

    def encode_dist(self,
                    features: Tensor,
                    lengths: Optional[List[int]] = None) -> Tensor:
        """Encode features into raw distribution parameters.

        Args:
            features: Batch-first motion features ``[bs, nframes, nfeats]``.
            lengths: Valid frame count per sample. When ``None`` every
                sample is treated as using all ``nframes`` frames.

        Returns:
            Encoder outputs at the distribution-token positions, i.e. the
            first ``2 * latent_size`` entries along the sequence axis.
        """
        if lengths is None:
            lengths = [len(feature) for feature in features]

        device = features.device
        bs, nframes, _ = features.shape
        # True marks a valid frame; expected to be a boolean [bs, nframes]
        # tensor (it is inverted with `~` below and in decode).
        mask = lengths_to_mask(lengths, device, max_len=nframes)

        # Embed every frame, then switch to sequence-first layout so the
        # frames can be concatenated with the [tokens, bs, dim] query tokens.
        x = self.skel_embedding(features)
        x = x.permute(1, 0, 2)

        # Every sample in the batch gets its own copy of the learned tokens.
        dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))

        # The distribution tokens are never padding. The mask must be boolean
        # and live on the input device: a float mask would make `~` raise and
        # a CPU mask would not concatenate with a GPU frame mask.
        dist_masks = torch.ones((bs, dist.shape[0]),
                                dtype=torch.bool,
                                device=device)
        aug_mask = torch.cat((dist_masks, mask), 1)

        xseq = torch.cat((dist, x), 0)
        xseq = self.query_pos_encoder(xseq)

        # The padding mask is True where a position must be ignored, hence
        # the inversion of the validity mask.
        dist = self.encoder(xseq,
                            src_key_padding_mask=~aug_mask)[:dist.shape[0]]
        return dist

    def forward(
            self,
            features: Tensor,
            lengths: Optional[List[int]] = None
    ) -> Tuple[Tensor, Tensor, Distribution]:
        """Run a full encode -> sample -> decode pass.

        Args:
            features: Batch-first motion features ``[bs, nframes, nfeats]``.
            lengths: Valid frame count per sample. When ``None`` every
                sample is treated as using all of its frames.

        Returns:
            ``(feats_rst, z, dist)``: the reconstructed features, the sampled
            latent, and the posterior distribution it was drawn from.
        """
        # Resolve the lengths once here: decode() has no features to infer
        # them from and cannot accept None.
        if lengths is None:
            lengths = [len(feature) for feature in features]

        z, dist = self.encode(features, lengths)
        feats_rst = self.decode(z, lengths)
        return feats_rst, z, dist

    def encode_dist2z(self, dist: Tensor) -> Tuple[Tensor, Distribution]:
        """Turn raw distribution parameters into a sampled latent.

        Args:
            dist: Tensor whose first ``latent_size`` entries along axis 0 are
                the mean and whose remaining entries are the log-variance.

        Returns:
            ``(latent, distribution)``: a reparameterized sample and the
            ``torch.distributions.Normal`` it was drawn from.
        """
        mu = dist[0:self.latent_size, ...]
        logvar = dist[self.latent_size:, ...]

        # std = sqrt(exp(logvar))
        std = logvar.exp().pow(0.5)
        posterior = torch.distributions.Normal(mu, std)
        # rsample keeps the sample differentiable w.r.t. mu and std.
        latent = posterior.rsample()
        return latent, posterior

    def encode(
            self,
            features: Tensor,
            lengths: Optional[List[int]] = None
    ) -> Tuple[Tensor, Distribution]:
        """Encode features and sample a latent from the posterior.

        Args:
            features: Batch-first motion features ``[bs, nframes, nfeats]``.
            lengths: Valid frame count per sample; see ``encode_dist``.

        Returns:
            ``(latent, distribution)`` as produced by ``encode_dist2z``.
        """
        dist = self.encode_dist(features, lengths)
        latent, dist = self.encode_dist2z(dist)
        return latent, dist

    def decode(self, z: Tensor, lengths: List[int]) -> Tensor:
        """Decode a latent into motion features.

        Args:
            z: Latent passed to the decoder as cross-attention memory.
            lengths: Valid frame count per sample; determines the number of
                generated frames and which frames are zeroed as padding.

        Returns:
            Batch-first features ``[bs, nframes, nfeats]`` with padded frames
            set to zero.
        """
        mask = lengths_to_mask(lengths, z.device)
        bs, nframes = mask.shape

        # Queries carry no content, only positional information.
        queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)
        queries = self.query_pos_decoder(queries)
        output = self.decoder(
            tgt=queries,
            memory=z,
            tgt_key_padding_mask=~mask,
        )
        # The original code called .squeeze(0) unconditionally, presumably for
        # decoders that return an extra leading singleton axis (TODO confirm).
        # Only squeeze a 4-D result so that a single-frame batch
        # (nframes == 1) keeps its time axis.
        if output.dim() == 4:
            output = output.squeeze(0)

        output = self.final_layer(output)
        # Zero out padded frames; output is [nframes, bs, nfeats] here.
        output[~mask.T] = 0
        feats = output.permute(1, 0, 2)
        return feats
