import copy
import math
import pickle
import numpy as np
import torch
import torch.nn as nn
from .utils.layer import BasicBlock
from .motion_encoder import *


class WavEncoder(nn.Module):
    """Stack of six ``BasicBlock`` layers turning an audio signal into features.

    Args:
        out_dim: channel count of the final block; the intermediate blocks use
            ``out_dim // 4`` and ``out_dim // 2`` channels.
        audio_in: number of input channels expected by the first block.
    """

    def __init__(self, out_dim, audio_in=1):
        super().__init__()
        self.out_dim = out_dim
        quarter = out_dim // 4
        half = out_dim // 2
        blocks = [
            BasicBlock(audio_in, quarter, 15, 5, first_dilation=1600, downsample=True),
            BasicBlock(quarter, quarter, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(quarter, quarter, 15, 1, first_dilation=7),
            BasicBlock(quarter, half, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(half, half, 15, 1, first_dilation=7),
            BasicBlock(half, out_dim, 15, 3, first_dilation=0, downsample=True),
        ]
        self.feat_extractor = nn.Sequential(*blocks)

    def forward(self, wav_data):
        """Run the block stack and return its output with dims 1 and 2 swapped.

        A 2-D input gets a singleton axis inserted at dim 1; any other input has
        dims 1 and 2 swapped before entering the stack (presumably moving a
        trailing channel axis to dim 1 — TODO confirm against callers).
        """
        if wav_data.dim() != 2:
            stacked_input = wav_data.transpose(1, 2)
        else:
            stacked_input = wav_data.unsqueeze(1)
        features = self.feat_extractor(stacked_input)
        return features.transpose(1, 2)


class MLP(nn.Module):
    """Two-layer perceptron: ``Linear -> LeakyReLU(0.2) -> Linear``.

    Args:
        in_dim: size of the last axis of the input.
        hidden_size: width of the hidden layer.
        out_dim: size of the last axis of the output.
    """

    def __init__(self, in_dim, hidden_size, out_dim):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_size),
            nn.LeakyReLU(0.2, True),
            nn.Linear(hidden_size, out_dim),
        ]
        # Kept as one Sequential under ``mlp`` so state_dict keys stay mlp.0.* / mlp.2.*.
        self.mlp = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the perceptron to ``inputs`` along its last axis."""
        return self.mlp(inputs)


class PeriodicPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding that repeats every ``period`` steps.

    A standard sine/cosine table with ``period`` rows is tiled along the
    sequence axis, so positions ``p`` and ``p + period`` get the same encoding.

    Args:
        d_model: feature size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        period: number of distinct positions in one period.
        max_seq_len: the tiled table holds ``(max_seq_len // period + 1) * period``
            rows, i.e. more than ``max_seq_len`` positions.
    """

    def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=60):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so the
        # cosine half only uses the first d_model // 2 frequencies (identical to
        # the full div_term when d_model is even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, period, d_model)
        repeat_num = (max_seq_len // period) + 1
        pe = pe.repeat(1, repeat_num, 1)  # (1, repeat_num * period, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x`` and apply dropout.

        ``x`` is indexed as ``x[:, t, :]`` with its last axis of size ``d_model``.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class MAGE_Transformer(nn.Module):
    def __init__(self, args):
        """Build every sub-network of the model from the config object ``args``.

        Fields read directly here: ``data_path``, ``t_fix_pre``, ``audio_f``,
        ``motion_f``, ``pose_dims``, ``pose_length``, ``hidden_size`` and
        ``vae_codebook_size``. Pretrained word embeddings come from the
        ``word_embedding_weights`` attribute of the object unpickled from
        ``{args.data_path}weights/vocab.pkl``.

        Args:
            args: configuration namespace. It is deep-copied before the motion
                encoder overrides are applied, so it is not mutated here.
        """
        super(MAGE_Transformer, self).__init__()
        self.args = args
        # NOTE(review): pickle.load can execute arbitrary code; data_path must
        # only ever point at trusted files.
        with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f:
            self.lang_model = pickle.load(f)
            pre_trained_embedding = self.lang_model.word_embedding_weights

        # Word branches: pretrained embedding -> Linear(300, audio_f). 300 is
        # presumably the width of the pretrained word vectors — TODO confirm.
        self.text_pre_encoder_face = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding),
                                                                  freeze=args.t_fix_pre)
        self.text_encoder_face = nn.Linear(300, args.audio_f)
        self.text_pre_encoder_body = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding),
                                                                  freeze=args.t_fix_pre)
        self.text_encoder_body = nn.Linear(300, args.audio_f)

        self.audio_pre_encoder_face = WavEncoder(args.audio_f, audio_in=2)
        self.audio_pre_encoder_body = WavEncoder(args.audio_f, audio_in=2)

        # Per-channel word/audio fusion weights (softmaxed over 2 sources in forward).
        self.at_attn_face = nn.Linear(args.audio_f * 2, args.audio_f * 2)
        self.at_attn_body = nn.Linear(args.audio_f * 2, args.audio_f * 2)

        # Motion encoder gets its own copy of args so these overrides stay local.
        args_top = copy.deepcopy(self.args)
        args_top.vae_layer = 3
        args_top.vae_length = args.motion_f
        args_top.vae_test_dim = args.pose_dims + 3 + 4
        self.motion_encoder = VQEncoderV9(args_top)  # masked motion to latent bs t 333 to bs t 256

        # face decoder
        self.feature2face = nn.Linear(args.audio_f * 2, args.hidden_size)
        self.face2latent = nn.Linear(args.hidden_size, args.vae_codebook_size)
        self.latent2rec = nn.Linear(args.vae_codebook_size, args.vae_codebook_size)
        self.face2latent2 = nn.Linear(args.vae_codebook_size, args.vae_codebook_size)
        self.upper2latent = nn.Linear(args.vae_codebook_size, args.vae_codebook_size)
        self.lower2latent = nn.Linear(args.vae_codebook_size, args.vae_codebook_size)
        self.hands2latent = nn.Linear(args.vae_codebook_size, args.vae_codebook_size)
        self.transformer_de_layer = nn.TransformerDecoderLayer(
            d_model=self.args.hidden_size,
            nhead=4,
            dim_feedforward=self.args.hidden_size * 2,
            batch_first=True
        )
        self.transformer_de_cdbk_layer = nn.TransformerDecoderLayer(
            d_model=args.vae_codebook_size,
            nhead=4,
            dim_feedforward=self.args.hidden_size * 2,
            batch_first=True
        )
        self.transformer_de_layer_token = nn.TransformerDecoderLayer(
            d_model=args.vae_codebook_size,
            nhead=4,
            dim_feedforward=args.vae_codebook_size * 4,
            dropout=0.5,
            batch_first=True,
            layer_norm_eps=1e-04,
        )

        self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=8)
        self.position_embeddings = PeriodicPositionalEncoding(self.args.hidden_size, period=self.args.pose_length,
                                                              max_seq_len=self.args.pose_length)
        self.position_embeddings_out = PeriodicPositionalEncoding(self.args.vae_codebook_size, period=self.args.pose_length,
                                                                  max_seq_len=self.args.pose_length)

        # motion decoder
        self.transformer_en_layer = nn.TransformerEncoderLayer(
            d_model=self.args.hidden_size,
            nhead=4,
            dim_feedforward=self.args.hidden_size * 2,
            batch_first=True
        )
        self.motion_self_encoder = nn.TransformerEncoder(self.transformer_en_layer, num_layers=1)
        self.audio_feature2motion = nn.Linear(args.audio_f, args.hidden_size)
        self.feature2motion = nn.Linear(args.motion_f, args.hidden_size)

        self.bodyhints_face = MLP(args.motion_f, args.hidden_size, args.motion_f)
        self.bodyhints_body = MLP(args.motion_f, args.hidden_size, args.motion_f)
        self.motion2latent_upper = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size)
        self.motion2latent_hands = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size)
        self.motion2latent_lower = MLP(args.hidden_size, args.hidden_size, self.args.hidden_size)
        self.wordhints_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4)

        self.upper_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1)
        self.hands_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1)
        self.lower_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=1)

        self.face_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size)
        self.upper_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size)
        self.hands_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size)
        self.lower_classifier = MLP(self.args.vae_codebook_size, args.hidden_size, self.args.vae_codebook_size)

        self.mask_embeddings = nn.Parameter(torch.zeros(1, 1, self.args.pose_dims + 3 + 4))
        self.motion_down_upper = nn.Linear(args.hidden_size, self.args.vae_codebook_size)
        self.motion_down_hands = nn.Linear(args.hidden_size, self.args.vae_codebook_size)
        self.motion_down_lower = nn.Linear(args.hidden_size, self.args.vae_codebook_size)

        # Multi-scale refinement modules. They must live in nn.ModuleList (not
        # plain Python lists) so their parameters are registered: only then are
        # they returned by parameters() (optimized), included in state_dict()
        # (checkpointed) and moved by .to()/.cuda() with the rest of the model.
        # NOTE: checkpoints saved while these were plain lists lack these keys;
        # load such checkpoints with load_state_dict(..., strict=False).
        num_scales = 4
        cdbk = self.args.vae_codebook_size
        self.mstrans_face = nn.ModuleList(
            [nn.TransformerDecoder(self.transformer_de_cdbk_layer, num_layers=2) for _ in range(num_scales)])
        self.mstrans_body = nn.ModuleList(
            [nn.TransformerDecoder(self.transformer_de_cdbk_layer, num_layers=2) for _ in range(num_scales)])
        # Scale i projects cdbk * 2**i features (2**i frames concatenated in
        # forward) back down to cdbk.
        self.scaler_face = nn.ModuleList([nn.Linear(cdbk * 2 ** i, cdbk) for i in range(num_scales)])
        self.scaler_hands = nn.ModuleList([nn.Linear(cdbk * 2 ** i, cdbk) for i in range(num_scales)])
        self.scaler_upper = nn.ModuleList([nn.Linear(cdbk * 2 ** i, cdbk) for i in range(num_scales)])
        self.scaler_lower = nn.ModuleList([nn.Linear(cdbk * 2 ** i, cdbk) for i in range(num_scales)])

        self.caption_encoder_index = nn.Linear(args.hidden_size, self.args.vae_codebook_size)
        self.caption_encoder_style = nn.Linear(args.hidden_size, self.args.vae_codebook_size)
        self.caption_encoder_styleloss = nn.Linear(4096, 4*self.args.vae_codebook_size)
        self.caption_encoder_body = nn.Linear(4096, args.hidden_size)
        self.caption_encoder_face = nn.Linear(4096, args.hidden_size)
        self.caption_decoder_style = nn.Linear(4*self.args.vae_codebook_size, 4096)
        self.index_emb = nn.Embedding(self.args.vae_codebook_size, self.args.vae_codebook_size)
        self.index_emb_projector = nn.Linear(self.args.vae_codebook_size, self.args.vae_codebook_size)
        self.encode_extractor_face = VQEncoderV11()
        self.encode_extractor_lower = VQEncoderV11()
        self.encode_extractor_upper = VQEncoderV11()
        self.encode_extractor_hands = VQEncoderV11()
        self.leakyrelu = nn.LeakyReLU(0.2)
        self.ratio = nn.Parameter(torch.zeros(self.args.vae_codebook_size, 1))
        self._reset_parameters()

        # Speaker-id embeddings (25 ids). Attribute names kept as-is for
        # checkpoint compatibility.
        self.spearker_encoder_body = nn.Embedding(25, args.hidden_size)
        self.spearker_encoder_face = nn.Embedding(25, args.hidden_size)


    def _reset_parameters(self):
        nn.init.normal_(self.mask_embeddings, 0, self.args.hidden_size ** -0.5)
        nn.init.normal_(self.ratio, 0, self.args.vae_codebook_size ** -0.5)

    def index_pred_mlp(self, cls, latent_list, caption_embedding, tar_index=None):
        index_list = []

        prior_latent = latent_list[-1]
        prior_latent_re = prior_latent.repeat_interleave(2, dim=1)
        for i in range(len(latent_list) - 1, -1, -1):
            if i == len(latent_list) - 1:
                input_latent = prior_latent #torch.cat((prior_latent, caption_embedding[:,:prior_latent.shape[1],:]), dim=2)
                index_list.append(cls(input_latent))
            else:
                prior_latent = latent_list[i] + prior_latent_re
                input_latent = prior_latent #torch.cat((prior_latent, caption_embedding[:, :prior_latent.shape[1], :]), dim=2)
                index_list.append(cls(input_latent))
            prior_latent_re = prior_latent.repeat_interleave(2, dim=1)
        index_list.reverse()

        return index_list#, prior_latent

    def index_pred_mlp_smalltail(self, cls, latent_list, caption_embedding, tar_index=None):
        index_list = []

        prior_latent = latent_list[0]
        #prior_latent_re = prior_latent.repeat_interleave(2, dim=1)
        for i in range(len(latent_list)):
            if i == 0:
                input_latent = prior_latent #torch.cat((prior_latent, caption_embedding[:,:prior_latent.shape[1],:]), dim=2)
                index_list.append(cls(input_latent))
            else:
                prior_latent = latent_list[i] + prior_latent_re
                input_latent = prior_latent #torch.cat((prior_latent, caption_embedding[:, :prior_latent.shape[1], :]), dim=2)
                index_list.append(cls(input_latent))
            prior_latent_re = torch.mean(prior_latent.reshape(prior_latent.shape[0], -1, 2, prior_latent.shape[-1]), dim=2)
        #index_list.reverse()

        return index_list#, prior_latent

    def index_pred(self, cls, latent_list, caption_embedding, tar_index=None):
        index_list = []
        cls_latent_list = []

        if tar_index != None:
            prior_latent = latent_list[-1]
            #prior_latent = torch.cat((prior_latent, caption_embedding[:, :len(prior_latent[0])]), dim=-1)
            #prior_latent_re = prior_latent.repeat_interleave(2, dim=1)
            for i in range(len(latent_list)-1,-1,-1):
                tgt_mask = (1 - torch.triu(torch.ones((len(latent_list[i][0]), len(latent_list[i][0]))), diagonal=0)).bool().cuda()
                #tgt_mask = tgt_mask.cuda()
                if i == len(latent_list)-1:
                    tgt = tar_index[i]#self.index_emb(tar_index[i])#nn.functional.one_hot(tar_index[i], 256).float()
                    style_cond = torch.mean(self.caption_encoder_style(caption_embedding), dim=1).unsqueeze(1)  # self.caption_encoder_style(caption_embedding)
                    sos = style_cond#nn.functional.one_hot(torch.LongTensor([0]).cuda(), 256).float().unsqueeze(1).repeat(tgt.shape[0], 1, 1)
                    tgt = torch.concat([sos,tgt], dim=1)

                    cond_prior_latent = prior_latent #+ style_cond[:, :len(prior_latent[0])]#torch.cat((prior_latent, self.caption_encoder_style(caption_embedding[:, :len(prior_latent[0])])), dim=-1)
                    out_latent = cls(tgt=tgt[:, :-1, :], tgt_mask=tgt_mask, memory=cond_prior_latent, memory_mask=tgt_mask)
                    out_index = self.index_emb_projector(out_latent)
                    #out_index = self.index_emb_projector(cls(tgt=tgt[:, :-1, :], tgt_mask=tgt_mask, memory=cond_prior_latent, memory_mask=tgt_mask))
                    index_list.append(out_index)
                    cls_latent_list.append(out_latent)
                    #index_list.append(cls(cond_prior_latent))
                else:
                    tgt = tar_index[i]#self.index_emb(tar_index[i]) #nn.functional.one_hot(tar_index[i], 256).float()
                    style_cond = torch.mean(self.caption_encoder_style(caption_embedding), dim=1).unsqueeze(1)  # self.caption_encoder_style(caption_embedding)
                    sos = style_cond
                    #sos = nn.functional.one_hot(torch.LongTensor([0]).cuda(), 256).float().unsqueeze(1).repeat(tgt.shape[0], 1, 1)
                    tgt = torch.concat([sos, tgt], dim=1)

                    prior_latent = latent_list[i]
                    cond_prior_latent = prior_latent #+ style_cond[:, :len(prior_latent[0])]#torch.cat((prior_latent, self.caption_encoder_style(caption_embedding[:, :len(prior_latent[0])])), dim=-1)
                    out_latent = cls(tgt=tgt[:, :-1, :], tgt_mask=tgt_mask, memory=cond_prior_latent, memory_mask=tgt_mask)
                    out_index = self.index_emb_projector(out_latent)
                    #out_index = self.index_emb_projector(cls(tgt=tgt[:,:-1,:], tgt_mask=tgt_mask, memory=cond_prior_latent, memory_mask=tgt_mask))
                    index_list.append(out_index)
                    cls_latent_list.append(out_latent)
                    #index_list.append(cls(cond_prior_latent))
                #prior_latent_re = prior_latent.repeat_interleave(2, dim=1)
        else:
            prior_latent = latent_list[-1]
            # prior_latent = torch.cat((prior_latent, caption_embedding[:, :len(prior_latent[0])]), dim=-1)
            #prior_latent_re = prior_latent.repeat_interleave(2, dim=1)
            for i in range(len(latent_list) - 1, -1, -1):
                #tgt_mask = (1 - torch.triu(torch.ones((len(latent_list[i][0]), len(latent_list[i][0]))),
                #                           diagonal=0)).bool().cuda()
                # tgt_mask = tgt_mask.cuda()
                if i == len(latent_list) - 1:
                    style_cond = torch.mean(self.caption_encoder_style(caption_embedding), dim=1).unsqueeze(1)  # self.caption_encoder_style(caption_embedding)
                    tgt = style_cond
                    #tgt = nn.functional.one_hot(torch.LongTensor([0]).cuda(), 256).float().unsqueeze(1)
                    cond_prior_latent = prior_latent #+ self.caption_encoder_style(caption_embedding[:, :len(prior_latent[0])])  # torch.cat((prior_latent, self.caption_encoder_style(caption_embedding[:, :len(prior_latent[0])])), dim=-1)
                    for j in range(len(prior_latent[0])):
                        out = cls(tgt=tgt, memory=cond_prior_latent)
                        tgt = torch.concat([tgt, out[:,-1:,:]], dim=1)
                    out_latent = tgt[:, 1:, :]
                    out_index = self.index_emb_projector(out_latent)
                    index_list.append(out_index)
                    cls_latent_list.append(out_latent)

                    # index_list.append(cls(cond_prior_latent))
                else:
                    prior_latent = latent_list[i]
                    style_cond = torch.mean(self.caption_encoder_style(caption_embedding), dim=1).unsqueeze(1)  # self.caption_encoder_style(caption_embedding)
                    tgt = style_cond#nn.functional.one_hot(torch.LongTensor([0]).cuda(), 256).float().unsqueeze(1)
                    cond_prior_latent = prior_latent #+ self.caption_encoder_style(caption_embedding[:, :len(prior_latent[0])])  # torch.cat((prior_latent, self.caption_encoder_style(caption_embedding[:, :len(prior_latent[0])])), dim=-1)
                    for j in range(len(prior_latent[0])):
                        out = cls(tgt=tgt, memory=cond_prior_latent)
                        tgt = torch.concat([tgt, out[:, -1:, :]], dim=1)
                    out_latent = tgt[:, 1:, :]
                    out_index = self.index_emb_projector(out_latent)
                    index_list.append(out_index)
                    cls_latent_list.append(out_latent)
                    #cond_prior_latent = prior_latent + self.caption_encoder_style(caption_embedding[:, :len(prior_latent[0])])  # torch.cat((prior_latent, self.caption_encoder_style(caption_embedding[:, :len(prior_latent[0])])), dim=-1)
                    #index_list.append(self.token_decoder(tgt=nn.functional.one_hot(torch.LongTensor([0]).cuda(), 256).float().unsqueeze(1), memory=cond_prior_latent))
                    # index_list.append(cls(cond_prior_latent))
                #prior_latent_re = prior_latent.repeat_interleave(2, dim=1)
        index_list.reverse()
        cls_latent_list.reverse()

        return index_list, cls_latent_list#, prior_latent

    def forward(self, in_audio=None, in_word=None, in_caption=None, mask=None, is_test=None, in_motion=None, use_attentions=True,
                use_word=True, in_id=None, tar_data=None):
        face_latent_list = []
        upper_latent_list = []
        lower_latent_list = []
        hands_latent_list = []
        cls_face = []
        cls_upper = []
        cls_lower = []
        cls_hands = []
        in_word_face = self.text_pre_encoder_face(in_word)
        in_word_face = self.text_encoder_face(in_word_face)
        in_word_body = self.text_pre_encoder_body(in_word)
        in_word_body = self.text_encoder_body(in_word_body)
        bs, t, c = in_word_face.shape
        in_audio_face = self.audio_pre_encoder_face(in_audio)
        in_audio_body = self.audio_pre_encoder_body(in_audio)
        if in_audio_face.shape[1] != in_motion.shape[1]:
            diff_length = in_motion.shape[1] - in_audio_face.shape[1]
            if diff_length < 0:
                in_audio_face = in_audio_face[:, :diff_length, :]
                in_audio_body = in_audio_body[:, :diff_length, :]
            else:
                in_audio_face = torch.cat((in_audio_face, in_audio_face[:, -diff_length:]), 1)
                in_audio_body = torch.cat((in_audio_body, in_audio_body[:, -diff_length:]), 1)

        if use_attentions:
            alpha_at_face = torch.cat([in_word_face, in_audio_face], dim=-1).reshape(bs, t, c * 2)
            alpha_at_face = self.at_attn_face(alpha_at_face).reshape(bs, t, c, 2)
            alpha_at_face = alpha_at_face.softmax(dim=-1)
            fusion_face = in_word_face * alpha_at_face[:, :, :, 1] + in_audio_face * alpha_at_face[:, :, :, 0]
            alpha_at_body = torch.cat([in_word_body, in_audio_body], dim=-1).reshape(bs, t, c * 2)
            alpha_at_body = self.at_attn_body(alpha_at_body).reshape(bs, t, c, 2)
            alpha_at_body = alpha_at_body.softmax(dim=-1)
            fusion_body = in_word_body * alpha_at_body[:, :, :, 1] + in_audio_body * alpha_at_body[:, :, :, 0]
        else:
            fusion_face = in_word_face + in_audio_face
            fusion_body = in_word_body + in_audio_body

        # use top 8 frames of in_motion ground truth as hint
        masked_embeddings = self.mask_embeddings.expand_as(in_motion)
        masked_motion = torch.where(mask == 1, masked_embeddings, in_motion)  # bs, t, 256
        body_hint = self.motion_encoder(masked_motion)  # bs t 256
        body_hint = [body_hint[i].repeat_interleave(2**i, dim=1) for i in range(len(body_hint))]
        body_hint = F.normalize(sum(body_hint))
        # speaker condition
        caption_style_embedding_face = self.spearker_encoder_face(in_id).squeeze(2)#self.caption_encoder_face(in_caption)#self.spearker_encoder_face(in_id).squeeze(2)
        caption_style_embedding_body = self.spearker_encoder_face(in_id).squeeze(2)#self.caption_encoder_body(in_caption) #self.spearker_encoder_body(in_id).squeeze(2)
        #caption_style = self.caption_encoder_styleloss(in_caption)
        # decode face
        use_body_hints = True
        if use_body_hints:
            body_hint_face = self.bodyhints_face(body_hint)
            fusion_face = torch.cat([fusion_face, body_hint_face], dim=2)
        a2g_face = self.feature2face(fusion_face)
        face_embeddings = caption_style_embedding_face
        #face_embeddings = a2g_face
        #face_embeddings = self.position_embeddings(face_embeddings)
        a2g_face = a2g_face + face_embeddings
        #decoded_face = self.face_decoder(tgt=face_embeddings, memory=a2g_face)

        # motion spatial encoder
        body_hint_body = self.bodyhints_body(body_hint)
        motion_embeddings = self.feature2motion(body_hint_body)
        motion_embeddings = motion_embeddings + caption_style_embedding_body
        motion_embeddings = self.position_embeddings(motion_embeddings)

        # bi-directional self-attention
        motion_refined_embeddings = self.motion_self_encoder(motion_embeddings)

        # audio to gesture cross-modal attention
        if use_word:
            a2g_motion = self.audio_feature2motion(fusion_body)
            motion_refined_embeddings_in = motion_refined_embeddings + caption_style_embedding_body
            motion_refined_embeddings_in = self.position_embeddings(motion_refined_embeddings_in)
            word_hints = self.wordhints_decoder(tgt=motion_refined_embeddings_in, memory=a2g_motion)
            motion_refined_embeddings = motion_refined_embeddings + word_hints

        # feedforward
        upper_latent = self.motion2latent_upper(motion_refined_embeddings)
        hands_latent = self.motion2latent_hands(motion_refined_embeddings)
        lower_latent = self.motion2latent_lower(motion_refined_embeddings)

        upper_latent_in = upper_latent + caption_style_embedding_body
        upper_latent_in = self.position_embeddings(upper_latent_in)
        hands_latent_in = hands_latent + caption_style_embedding_body
        hands_latent_in = self.position_embeddings(hands_latent_in)
        lower_latent_in = lower_latent + caption_style_embedding_body
        lower_latent_in = self.position_embeddings(lower_latent_in)

        # transformer decoder
        motion_upper = self.upper_decoder(tgt=upper_latent_in, memory=hands_latent + lower_latent)
        motion_hands = self.hands_decoder(tgt=hands_latent_in, memory=upper_latent + lower_latent)
        motion_lower = self.lower_decoder(tgt=lower_latent_in, memory=upper_latent + hands_latent)


        face_latent_list.append(self.face2latent(a2g_face))
        upper_latent_list.append(self.motion_down_upper(upper_latent + motion_upper)) # + motion_upper
        hands_latent_list.append(self.motion_down_hands(hands_latent + motion_hands)) # + motion_hands
        lower_latent_list.append(self.motion_down_lower(lower_latent + motion_lower)) # + motion_lower


        face_latent_list_origin =  [self.scaler_face[i](face_latent_list[0].reshape(face_latent_list[0].shape[0], -1, face_latent_list[0].shape[2]*(2**i))/2**i) for i in range(len(self.scaler_face))]
        upper_latent_list_origin = [self.scaler_upper[i](upper_latent_list[0].reshape(upper_latent_list[0].shape[0], -1, upper_latent_list[0].shape[2]*(2**i))) for i in range(len(self.scaler_upper))]
        hands_latent_list_origin = [self.scaler_hands[i](hands_latent_list[0].reshape(hands_latent_list[0].shape[0], -1, hands_latent_list[0].shape[2]*(2**i))) for i in range(len(self.scaler_hands))]
        lower_latent_list_origin = [self.scaler_lower[i](lower_latent_list[0].reshape(lower_latent_list[0].shape[0], -1, lower_latent_list[0].shape[2]*(2**i))) for i in range(len(self.scaler_lower))]


        face_latent_list_cnn = self.encode_extractor_face(face_latent_list[0])
        upper_latent_list_cnn = self.encode_extractor_upper(upper_latent_list[0])
        hands_latent_list_cnn = self.encode_extractor_hands(hands_latent_list[0])
        lower_latent_list_cnn = self.encode_extractor_lower(lower_latent_list[0])

        fullbody_latent = [(upper_latent_list_cnn[i] + hands_latent_list_cnn[i] + lower_latent_list_cnn[i])/3 for i in range(len(self.scaler_face))]

        face_latent_list_trans = [self.face2latent2(self.mstrans_face[i](tgt=face_latent_list_origin[i], memory=face_latent_list_cnn[i])) for i in range(len(self.scaler_face))]
        upper_latent_list_trans = [self.upper2latent(self.mstrans_body[i](tgt=upper_latent_list_origin[i], memory=fullbody_latent[i])) for i in range(len(self.scaler_face))]
        hands_latent_list_trans = [self.hands2latent(self.mstrans_body[i](tgt=hands_latent_list_origin[i], memory=fullbody_latent[i])) for i in range(len(self.scaler_face))]
        lower_latent_list_trans = [self.lower2latent(self.mstrans_body[i](tgt=lower_latent_list_origin[i], memory=fullbody_latent[i])) for i in range(len(self.scaler_face))]

        face_latent_list = face_latent_list_trans#[(face_latent_list[j]*self.ratio[0]+face_latent_list_trans[j]*(1 - self.ratio[0])) for j in range(len(self.scaler_face))]
        upper_latent_list = upper_latent_list_trans#[(upper_latent_list[j]*self.ratio[1]+upper_latent_list_origin[j]*(1 - self.ratio[1])) for j in range(len(self.scaler_face))]
        hands_latent_list = hands_latent_list_trans#[(hands_latent_list[j]*self.ratio[2]+hands_latent_list_origin[j]*(1 - self.ratio[2])) for j in range(len(self.scaler_face))]
        lower_latent_list = lower_latent_list_trans#[(lower_latent_list[j]*self.ratio[3]+lower_latent_list_origin[j]*(1 - self.ratio[3])) for j in range(len(self.scaler_face))]

        face_latent_list_rec = [self.latent2rec(face_latent_list[i]) for i in range(len(face_latent_list))]  # [(face_latent_list[j]*self.ratio[0]+face_latent_list_trans[j]*(1 - self.ratio[0])) for j in range(len(self.scaler_face))]
        upper_latent_list_rec = [self.latent2rec(upper_latent_list[i]) for i in range(len(face_latent_list))]  # [(upper_latent_list[j]*self.ratio[1]+upper_latent_list_origin[j]*(1 - self.ratio[1])) for j in range(len(self.scaler_face))]
        hands_latent_list_rec = [self.latent2rec(hands_latent_list[i]) for i in range(len(face_latent_list))]  # [(hands_latent_list[j]*self.ratio[2]+hands_latent_list_origin[j]*(1 - self.ratio[2])) for j in range(len(self.scaler_face))]
        lower_latent_list_rec = [self.latent2rec(lower_latent_list[i]) for i in range(len(face_latent_list))]

        cls_face = self.index_pred_mlp(self.face_classifier, face_latent_list, caption_style_embedding_body)  # [(self.face_classifier(face_latent_list[i])) for i in range(len(face_latent_list))] #.append(self.face_classifier(face_latent))
        cls_lower = self.index_pred_mlp(self.lower_classifier, lower_latent_list, caption_style_embedding_body)  # [(self.lower_classifier(lower_latent_list[i])) for i in range(len(face_latent_list))]  # .append(self.lower_classifier(lower_latent))
        cls_upper = self.index_pred_mlp(self.upper_classifier, upper_latent_list, caption_style_embedding_body)  # [(self.upper_classifier(upper_latent_list[i])) for i in range(len(face_latent_list))]#.append(self.upper_classifier(upper_latent))
        cls_hands = self.index_pred_mlp(self.hands_classifier, hands_latent_list, caption_style_embedding_body)  # [(self.hands_classifier(hands_latent_list[i])) for i in range(len(face_latent_list))]#.append(self.hands_classifier(hands_latent))

        '''
        if tar_data != None:
            #cls_face, latent_face =  self.index_pred(self.token_decoder, face_latent_list, caption_style_embedding_body, tar_data["latent_face_top"])# "tar_index_value_face_top" [(self.face_classifier(face_latent_list[i])) for i in range(len(face_latent_list))] #.append(self.face_classifier(face_latent))
            cls_lower, latent_lower = self.index_pred(self.token_decoder, lower_latent_list, caption_style_embedding_body, tar_data["latent_lower_top"])# "tar_index_value_lower_top" [(self.lower_classifier(lower_latent_list[i])) for i in range(len(face_latent_list))]  # .append(self.lower_classifier(lower_latent))
            cls_upper, latent_upper = self.index_pred(self.token_decoder, upper_latent_list, caption_style_embedding_body, tar_data["latent_upper_top"])# "tar_index_value_upper_top" [(self.upper_classifier(upper_latent_list[i])) for i in range(len(face_latent_list))]#.append(self.upper_classifier(upper_latent))
            cls_hands, latent_hands = self.index_pred(self.token_decoder, hands_latent_list, caption_style_embedding_body, tar_data["latent_hands_top"])# "tar_index_value_hands_top" [(self.hands_classifier(hands_latent_list[i])) for i in range(len(face_latent_list))]#.append(self.hands_classifier(hands_latent))
        else:
            #cls_face, latent_face = self.index_pred(self.token_decoder, face_latent_list, caption_style_embedding_body)  # [(self.face_classifier(face_latent_list[i])) for i in range(len(face_latent_list))] #.append(self.face_classifier(face_latent))
            cls_lower, latent_lower = self.index_pred(self.token_decoder, lower_latent_list, caption_style_embedding_body)  # [(self.lower_classifier(lower_latent_list[i])) for i in range(len(face_latent_list))]  # .append(self.lower_classifier(lower_latent))
            cls_upper, latent_upper = self.index_pred(self.token_decoder, upper_latent_list, caption_style_embedding_body)  # [(self.upper_classifier(upper_latent_list[i])) for i in range(len(face_latent_list))]#.append(self.upper_classifier(upper_latent))
            cls_hands, latent_hands = self.index_pred(self.token_decoder, hands_latent_list, caption_style_embedding_body)  # [(self.hands_classifier(hands_latent_list[i])) for i in range(len(face_latent_list))]#.append(self.hands_classifier(hands_latent))
        '''

        '''
        full_latent = torch.cat((face_latent, lower_latent, upper_latent, hands_latent), dim=2)
        style_latent = self.caption_encoder_styleloss(in_caption)
        full_style = self.style_decoder(tgt=style_latent[:,:1,:], memory=full_latent).repeat(1, style_latent.shape[1], 1)
        '''
        #full_style = self.caption_decoder_style(full_style).repeat(1, )
        #caption_style_embedding = self.caption_encoder_index(caption_style_embedding_body)
        #style_loss = nn.MSELoss(full_style, style_latent[:,:1,:])

        return {
            "rec_face": face_latent_list_rec,
            "rec_upper": upper_latent_list,
            "rec_lower": lower_latent_list,
            "rec_hands": hands_latent_list,
            "cls_face": cls_face,
            "cls_upper": cls_upper,
            "cls_lower": cls_lower,
            "cls_hands": cls_hands,
            #"full_style": full_style,
        }