import train
import os
import time
import csv
import sys
import warnings
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
import time
import pprint
from loguru import logger
from utils import rotation_conversions as rc
import smplx
from utils import config, logger_tools, other_tools, metric, data_transfer
from dataloaders import data_tools
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from optimizers.loss_factory import get_loss_func
from dataloaders.data_tools import joints_list
import librosa

class CustomTrainer(train.BaseTrainer):
    def __init__(self, args):
        """Set up joint masks, pretrained VQ/VAE motion models, style modules and losses.

        Args:
            args: experiment config namespace. Besides what the base trainer
                reads, this method reads ``ori_joints``, ``data_path_1``,
                ``e_name`` and ``vae_codebook_size``. It OVERWRITES
                ``vae_layer``, ``vae_length`` and ``vae_test_dim`` while
                building each motion model and leaves them at 4 / 240 / 330.
        """
        super().__init__(args)
        self.args = args
        # NOTE(review): this value is overwritten with 55 a few lines below.
        self.joints = self.train_data.joints
        # Joint-name lookup tables; the loops below use entry[0] as a width and
        # entry[1] as an end offset into the flattened per-joint vector.
        self.ori_joint_list = joints_list[self.args.ori_joints]
        self.tar_joint_list_face = joints_list["beat_smplx_face"]
        self.tar_joint_list_upper = joints_list["beat_smplx_upper"]
        self.tar_joint_list_hands = joints_list["beat_smplx_hands"]
        self.tar_joint_list_lower = joints_list["beat_smplx_lower"]

        # Build one 0/1 mask per body part over the flattened pose vector
        # (3 entries per joint of ori_joint_list); 1 marks channels of that part.
        self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
        self.joints = 55
        for joint_name in self.tar_joint_list_face:
            self.joint_mask_face[
            self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][
                1]] = 1
        self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
        for joint_name in self.tar_joint_list_upper:
            self.joint_mask_upper[
            self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][
                1]] = 1
        self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
        for joint_name in self.tar_joint_list_hands:
            self.joint_mask_hands[
            self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][
                1]] = 1
        self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
        for joint_name in self.tar_joint_list_lower:
            self.joint_mask_lower[
            self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][
                1]] = 1

        # Metric names and one flag per metric for EpochTracker (only "l1div" and
        # "bc" are True — presumably "higher is better"; confirm in other_tools).
        self.tracker = other_tools.EpochTracker(
            ["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'exp', 'lvd', 'mse',
             "cls", "rec_face", "latent", "cls_full", "cls_self", "cls_word", "latent_word", "latent_self", "style_loss", "rec_style_loss", "tar_style_loss", "tar_rec_loss", "ind_style_loss"],
            [False, True, True, False, False, False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False, False, False])

        # The pretrained motion models below are built from args.vae_*; those args
        # are mutated before each constructor call, so the order matters.
        vq_model_module = __import__(f"models.motion_representation", fromlist=["something"])
        self.args.vae_layer = 2
        self.args.vae_length = 256
        # Face: input width 106 (6-d jaw rotation + expression channels, cf. tar_pose_face in _load_data).
        self.args.vae_test_dim = 106
        self.vq_model_face = getattr(vq_model_module, "VQVAEConvGlobal")(self.args).to(self.rank)
        # print(self.vq_model_face)
        #crossid
        #other_tools.load_checkpoints(self.vq_model_face, self.args.data_path_1 + "pretrained_vq_global_crossid/face_last_300.bin", args.e_name)
        # id2
        #other_tools.load_checkpoints(self.vq_model_face, self.args.data_path_1 + "pretrained_vq_global_id2/face_last_160.bin", args.e_name)
        # id2 resnorm
        # other_tools.load_checkpoints(self.vq_model_face, self.args.data_path_1 + "pretrained_vq_id2_res2norm/face_last_440.bin", args.e_name)
        # AMG
        other_tools.load_checkpoints(self.vq_model_face,self.args.data_path_1 + "pretrained/face.bin",args.e_name)

        # Upper body: 78 = 13 joints x 6-d rotation (cf. tar_pose_upper in _load_data).
        self.args.vae_test_dim = 78
        self.vq_model_upper = getattr(vq_model_module, "VQVAEConvGlobal")(self.args).to(self.rank)
        #crossid
        #other_tools.load_checkpoints(self.vq_model_upper, self.args.data_path_1 + "pretrained_vq_global_crossid/upper_last_140.bin", args.e_name) #"pretrained_vq_global_id2/upper_last_300.bin"
        # id2
        #other_tools.load_checkpoints(self.vq_model_upper, self.args.data_path_1 + "pretrained_vq_global_id2/upper_last_300.bin", args.e_name)  #
        # id2 resnorm
        #other_tools.load_checkpoints(self.vq_model_upper, self.args.data_path_1 + "pretrained_vq_id2_res2norm/upper_last_320.bin", args.e_name)
        # AMG
        other_tools.load_checkpoints(self.vq_model_upper,self.args.data_path_1 + "pretrained/upper.bin",args.e_name)

        # Hands: 180 = 30 joints x 6-d rotation (cf. tar_pose_hands in _load_data).
        self.args.vae_test_dim = 180
        self.vq_model_hands = getattr(vq_model_module, "VQVAEConvGlobal")(self.args).to(self.rank)
        #crossid
        #other_tools.load_checkpoints(self.vq_model_hands, self.args.data_path_1 + "pretrained_vq_global_crossid/hands_last_120.bin", args.e_name) #"pretrained_vq_global_id2/hands_last_60.bin"
        # id2
        #other_tools.load_checkpoints(self.vq_model_hands, self.args.data_path_1 + "pretrained_vq_global_id2/hands_last_60.bin", args.e_name)  #
        # id2 resnorm
        #other_tools.load_checkpoints(self.vq_model_hands, self.args.data_path_1 + "pretrained_vq_id2_res2norm/hands_last_260.bin", args.e_name)
        # AMG
        other_tools.load_checkpoints(self.vq_model_hands,self.args.data_path_1 + "pretrained/hands.bin",args.e_name) # 253

        # Lower body: 61 = 9 leg joints x 6-d rotation + translation + 4 contact
        # channels (cf. tar_pose_lower in _load_data); deeper (4-layer) model.
        self.args.vae_test_dim = 61
        self.args.vae_layer = 4
        self.vq_model_lower = getattr(vq_model_module, "VQVAEConvGlobal_noatt")(self.args).to(self.rank)
        # crossid
        #other_tools.load_checkpoints(self.vq_model_lower, self.args.data_path_1 + "pretrained_vq_global_crossid/lowerfoot_last_220.bin", args.e_name)
        # id2
        #other_tools.load_checkpoints(self.vq_model_lower, self.args.data_path_1 + "pretrained_vq_global_id2/lowerfoot_last_300.bin", args.e_name)
        # id2 resnorm
        other_tools.load_checkpoints(self.vq_model_lower, self.args.data_path_1 + "pretrained/lowerfoot.bin", args.e_name)
        # AMG
        #other_tools.load_checkpoints(self.vq_model_lower,self.args.data_path_1 + "amg_qres_256/lowerfoot_395.bin",args.e_name) #395

        # Global-motion VAE over the same 61-channel lower-body representation.
        self.args.vae_test_dim = 61
        self.args.vae_layer = 2
        self.global_motion = getattr(vq_model_module, "VAEConvGlobal")(self.args).to(self.rank)
        # crossid
        #other_tools.load_checkpoints(self.global_motion, self.args.data_path_1 + "pretrained_vq_global_crossid/lower_last_400.bin", args.e_name)
        # id2
        #other_tools.load_checkpoints(self.global_motion, self.args.data_path_1 + "pretrained_vq_global_id2/lower_last_400.bin", args.e_name)
        # id2 resnorm
        other_tools.load_checkpoints(self.global_motion, self.args.data_path_1 + "pretrained/lowerlast.bin", args.e_name)
        # AMG
        #other_tools.load_checkpoints(self.global_motion,self.args.data_path_1 + "amg_qres_256/lower_335.bin",args.e_name)

        # Leave the shared args in the state later code expects.
        self.args.vae_test_dim = 330
        self.args.vae_layer = 4
        self.args.vae_length = 240

        # Pretrained motion models are only used in eval mode here (this does not
        # by itself disable their gradients).
        self.vq_model_face.eval()
        self.vq_model_upper.eval()
        self.vq_model_hands.eval()
        self.vq_model_lower.eval()
        self.global_motion.eval()

        # Style modules: a 1-layer transformer decoder that attends from a caption
        # query to motion latents, and an MLP projecting 4096-d caption features
        # to vae_codebook_size.
        self.transformer_de_layer_style = nn.TransformerDecoderLayer(
            d_model=args.vae_codebook_size,
            nhead=4,
            dim_feedforward=args.vae_codebook_size * 8,
            batch_first=True,
            activation='gelu',
        )
        self.style_decoder = nn.TransformerDecoder(self.transformer_de_layer_style, num_layers=1).to(self.rank)
        self.caption_encoder_style = nn.Sequential(nn.Linear(4096, 768, bias=True),
                                                    nn.ReLU(),
                                                    nn.Linear(768, self.args.vae_codebook_size, bias=True)).to(self.rank)
        #self.style_encoder_caption = nn.Linear(4 * self.args.vae_codebook_size, 768).to(self.rank)
        # NOTE: the string literal below is disabled code (an unused expression statement).
        '''
        self.encoder = [
            nn.Conv1d(4 * self.args.vae_codebook_size, 4 * self.args.vae_codebook_size, 5, 4, 2),
            nn.Sigmoid(),
            nn.Conv1d(4 * self.args.vae_codebook_size, 4 * self.args.vae_codebook_size, 5, 4, 2),
            nn.Sigmoid(),
            nn.Conv1d(4 * self.args.vae_codebook_size, 4 * self.args.vae_codebook_size, 5, 4, 2),
            nn.Sigmoid(),
            nn.Conv1d(4 * self.args.vae_codebook_size, 4 * self.args.vae_codebook_size, 5, 4, 2),
            nn.Sigmoid(),
        ]
        '''
        #self.style_quantizer = nn.Embedding(self.args.vae_codebook_size, 768).to(self.rank)
        #self.seq_style_projector = ContrastiveLoss(args.batch_size, self.rank, )

        #self.style_encoder = nn.Sequential(*self.encoder).to(self.rank)
        # Losses. cls_loss (NLL) is paired with log_softmax over dim 2 of the
        # classifier logits in _g_training.
        self.cls_loss = nn.NLLLoss().to(self.rank)
        self.reclatent_loss = nn.MSELoss().to(self.rank)
        self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank)
        self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank)
        self.style_loss = torch.nn.MSELoss().to(self.rank)
        self.log_softmax = nn.LogSoftmax(dim=2).to(self.rank)
        # NOTE(review): only referenced from commented-out code in the visible methods.
        self.style_index_emb = nn.Embedding(self.args.vae_codebook_size, 256).to(self.rank)

        # NOTE: the string literal below is disabled code (an unused expression statement).
        '''
        for its, batch_data in enumerate(self.train_loader):
            loaded_data = self._load_data(batch_data)
            style_emb = loaded_data["in_caption"][:, 0, :]
            d = torch.cosine_similarity(style_emb, style_emb, dim=1)
        '''

    def inverse_selection(self, filtered_t, selection_array, n):
        original_shape_t = np.zeros((n, selection_array.size))
        selected_indices = np.where(selection_array == 1)[0]
        for i in range(n):
            original_shape_t[i, selected_indices] = filtered_t[i]
        return original_shape_t

    def inverse_selection_tensor(self, filtered_t, selection_array, n):
        selection_array = torch.from_numpy(selection_array).cuda()
        original_shape_t = torch.zeros((n, 165)).cuda()
        selected_indices = torch.where(selection_array == 1)[0]
        for i in range(n):
            original_shape_t[i, selected_indices] = filtered_t[i]
        return original_shape_t

    def _load_data(self, dict_data):
        """Move one batch to ``self.rank`` and derive per-part targets, VQ codes and latents.

        Args:
            dict_data: batch dict. Reads keys ``pose``, ``trans``, ``facial``,
                ``caption``, ``audio``, ``word``, ``beta`` and ``id``. Channels
                ``[:165]`` of ``pose`` are used as the 55-joint pose (reshaped to
                (bs, n, 55, 3) below) and channels ``[165:169]`` as 4 contact channels.

        Returns:
            dict with the raw inputs, 6-d rotation targets per body part
            (``tar_pose_face/upper/hands/lower/leg/jaw``, ``tar_pose_6d``), the
            per-level VQ code indices (``tar_index_value_*_top``) and latents
            (``latent_*_top``), the per-level fused ``latent_in`` / ``index_in``
            (upper, hands, lower), and ``latent_all`` (full 6-d pose + trans + contact).
        """
        tar_pose_raw = dict_data["pose"]
        tar_pose = tar_pose_raw[:, :, :165].to(self.rank)
        tar_contact = tar_pose_raw[:, :, 165:169].to(self.rank)
        tar_trans = dict_data["trans"].to(self.rank)
        tar_exps = dict_data["facial"].to(self.rank)
        in_caption = dict_data["caption"].to(self.rank)
        in_audio = dict_data["audio"].to(self.rank)
        in_word = dict_data["word"].to(self.rank)
        tar_beta = dict_data["beta"].to(self.rank)
        tar_id = dict_data["id"].to(self.rank).long()
        # NOTE(review): `j` is not used in this method.
        bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints

        # Jaw = channels 66:69 (one joint); converted to 6-d rotation and joined
        # with the expression channels to form the face target.
        tar_pose_jaw = tar_pose[:, :, 66:69]
        tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
        tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1 * 6)
        tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)

        # Hands = joints 25..54 (30 joints) -> 180 channels of 6-d rotation.
        tar_pose_hands = tar_pose[:, :, 25 * 3:55 * 3]
        tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
        tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30 * 6)

        # Upper body selected with the boolean joint mask built in __init__ (13 joints).
        tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
        tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
        tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13 * 6)

        # Legs (9 joints) -> 6-d; the lower-body target also carries trans + contact.
        tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
        tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
        tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9 * 6)
        tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2)

        # tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
        # tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
        tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2)

        # Each map2index / map2latent call returns one entry per level (indexed
        # with [i] and len() below).
        tar_index_value_face_top = self.vq_model_face.map2index(tar_pose_face)  # bs*n/4
        tar_index_value_upper_top = self.vq_model_upper.map2index(tar_pose_upper)  # bs*n/4
        tar_index_value_hands_top = self.vq_model_hands.map2index(tar_pose_hands)  # bs*n/4
        tar_index_value_lower_top = self.vq_model_lower.map2index(tar_pose_lower)  # bs*n/4

        latent_face_top = self.vq_model_face.map2latent(tar_pose_face)  # bs*n/4
        latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper)  # bs*n/4
        latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands)  # bs*n/4
        latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower)  # bs*n/4


        # Per level: body latents concatenated on dim 2, and the three code
        # indices stacked on a new last axis (face is excluded from both).
        latent_in = [torch.cat([latent_upper_top[i], latent_hands_top[i], latent_lower_top[i]], dim=2) for i in range(len(latent_hands_top))]

        index_in = [torch.stack([tar_index_value_upper_top[i], tar_index_value_hands_top[i], tar_index_value_lower_top[i]], dim=-1).long() for i in range(len(latent_hands_top))]

        # Full 55-joint pose in 6-d (330 channels) + trans + contact.
        tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
        tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55 * 6)
        latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
        # print(tar_index_value_upper_top.shape, index_in.shape)
        return {
            "tar_pose_jaw": tar_pose_jaw,
            "tar_pose_face": tar_pose_face,
            "tar_pose_upper": tar_pose_upper,
            "tar_pose_lower": tar_pose_lower,
            "tar_pose_hands": tar_pose_hands,
            'tar_pose_leg': tar_pose_leg,
            "in_audio": in_audio,
            "in_caption": in_caption,
            "in_word": in_word,
            "tar_trans": tar_trans,
            "tar_exps": tar_exps,
            "tar_beta": tar_beta,
            "tar_pose": tar_pose,
            "tar4dis": tar4dis,
            "tar_index_value_face_top": tar_index_value_face_top,
            "tar_index_value_upper_top": tar_index_value_upper_top,
            "tar_index_value_hands_top": tar_index_value_hands_top,
            "tar_index_value_lower_top": tar_index_value_lower_top,
            "latent_face_top": latent_face_top,
            "latent_upper_top": latent_upper_top,
            "latent_hands_top": latent_hands_top,
            "latent_lower_top": latent_lower_top,
            "latent_in": latent_in,
            "index_in": index_in,
            "tar_id": tar_id,
            "latent_all": latent_all,
            "tar_pose_6d": tar_pose_6d,
            "tar_contact": tar_contact,
        }

    def fulllatent2pose(self, latent_list):
        prior_latent = latent_list[0]
        for i in range(1, len(latent_list)):
            prior_latent += latent_list[i].repeat_interleave(2**i, dim=1)
        return prior_latent/len(latent_list)

    def style_latent_constrasitive(self, rec_data, tar_data):
        # contrastive latent space
        #loss_func = self.ContrastiveLoss(batch_size=rec_data["rec_face"][0].shape[0])
        rec_latent_list = [torch.cat((rec_data["rec_face"][i], rec_data["rec_lower"][i],
                                      rec_data["rec_hands"][i], rec_data["rec_upper"][i]), dim=-1) for i in
                           range(len(rec_data["rec_face"]))]
        rec_full_latent = self.fulllatent2pose(rec_latent_list)
        enc_style = self.caption_encoder_style(tar_data['in_caption'])
        #neg_idx = torch.randperm(enc_style.shape[0])
        #neg_style = enc_style[neg_idx, :, :].repeat(1, tar_data['in_caption'].shape[1], 1)
        # rec_style_latent = self.style_encoder(rec_full_latent.permute(0, 2, 1)).permute(0, 2, 1)
        rec_style_latent = self.style_decoder(tgt=enc_style[:, :1, :], memory=rec_full_latent).repeat(1, tar_data['in_caption'].shape[1], 1)
        sim_score = torch.cosine_similarity(tar_data['in_caption'], tar_data['in_caption'].permute(1, 0, 2), dim=2)
        rec_style_latent = torch.mean(rec_style_latent, dim=1)

        enc_style_loss = self.encode_contra_loss(rec_style_latent, torch.mean(enc_style, dim=1), sim_score)#self.style_loss(rec_style_latent, enc_style)#self.style_loss(rec_style_latent, enc_style)
        rec_style_loss = self.style_contra_loss(rec_style_latent, torch.mean(enc_style, dim=1))
        #rec_style_com = torch.cat([rec_style_latent, torch.mean(enc_style, dim=1)], dim=0)
        #rec_style_sim = F.cosine_similarity(rec_style_com.unsqueeze(1), rec_style_com.unsqueeze(0), dim=2)


        tar_latent_list = [torch.cat((tar_data["latent_face_top"][i], tar_data["latent_lower_top"][i],
                                      tar_data["latent_hands_top"][i], tar_data["latent_upper_top"][i]),
                                     dim=-1) for i in range(len(tar_data["latent_face_top"]))]
        tar_full_latent = self.fulllatent2pose(tar_latent_list)
        #tar_style_latent = self.style_encoder(tar_full_latent.permute(0, 2, 1)).permute(0, 2, 1)
        tar_style_latent = self.style_decoder(tgt=enc_style[:, :1, :], memory=tar_full_latent).repeat(1, tar_data['in_caption'].shape[1], 1)
        tar_style_latent = torch.mean(tar_style_latent, dim=1)
        tar_style_loss = self.style_contra_loss(tar_style_latent, torch.mean(enc_style, dim=1))
        #tar_style_loss = self.style_loss(tar_style_latent, enc_style)
        tar_rec_loss = self.style_contra_loss(tar_style_latent, rec_style_latent)#self.style_loss(tar_style_latent, rec_style_latent)
        # style_loss = rec_style_loss + tar_style_loss + tar_rec_loss
        return enc_style_loss, rec_style_loss, tar_rec_loss

    def style_multi_latents_constrasitive(self, rec_data, masked_rec_data, masked_word_data, tar_data):
        # contrastive latent space
        #loss_func = self.ContrastiveLoss(batch_size=rec_data["rec_face"][0].shape[0])
        rec_latent_list = rec_data["rec_face"]#[torch.cat((rec_data["rec_face"][i], rec_data["rec_lower"][i], rec_data["rec_hands"][i], rec_data["rec_upper"][i]), dim=-1) for i in range(len(rec_data["rec_face"]))]
        rec_mask_latent_list = masked_rec_data["rec_face"]#[torch.cat((masked_rec_data["rec_face"][i], masked_rec_data["rec_lower"][i], masked_rec_data["rec_hands"][i], masked_rec_data["rec_upper"][i]), dim=-1) for i in range(len(masked_rec_data["rec_face"]))]
        word_mask_latent_list = masked_word_data["rec_face"]#[torch.cat((masked_word_data["rec_face"][i], masked_word_data["rec_lower"][i], masked_word_data["rec_hands"][i], masked_word_data["rec_upper"][i]), dim=-1) for i in range(len(masked_word_data["rec_face"]))]
        tar_latent_list = tar_data["latent_face_top"]#[torch.cat((tar_data["latent_face_top"][i], tar_data["latent_lower_top"][i], tar_data["latent_hands_top"][i], tar_data["latent_upper_top"][i]), dim=-1) for i in range(len(tar_data["latent_face_top"]))]

        rec_full_latent = torch.cat(rec_latent_list, dim=1)#self.fulllatent2pose(rec_latent_list)
        rec_mask_full_latent = torch.cat(rec_mask_latent_list, dim=1)#self.fulllatent2pose(rec_mask_latent_list)
        rec_word_full_latent = torch.cat(word_mask_latent_list, dim=1)#self.fulllatent2pose(word_mask_latent_list)
        tar_full_latent = torch.cat(tar_latent_list, dim=1)

        enc_style = self.caption_encoder_style(tar_data['in_caption'])

        #neg_idx = torch.randperm(enc_style.shape[0])
        #neg_style = enc_style[neg_idx, :, :].repeat(1, tar_data['in_caption'].shape[1], 1)
        # rec_style_latent = self.style_encoder(rec_full_latent.permute(0, 2, 1)).permute(0, 2, 1)
        tar_style = self.style_decoder(tgt=enc_style[:, :1, :], memory=tar_full_latent).repeat(1, tar_data['in_caption'].shape[1], 1)
        style_rec = self.style_decoder(tgt=enc_style[:, :1, :], memory=rec_full_latent).repeat(1, tar_data['in_caption'].shape[1], 1)
        style_rec_self = self.style_decoder(tgt=enc_style[:, :1, :], memory=rec_mask_full_latent).repeat(1, tar_data['in_caption'].shape[1], 1)
        style_rec_word = self.style_decoder(tgt=enc_style[:, :1, :], memory=rec_word_full_latent).repeat(1, tar_data['in_caption'].shape[1], 1)

        style_rec_loss = self.style_loss(tar_style, enc_style[:, :1, :]) + self.style_loss(style_rec, enc_style[:, :1, :])+self.style_loss(style_rec_self, enc_style[:, :1, :])+self.style_loss(style_rec_word, enc_style[:, :1, :])

        sim_score = torch.cosine_similarity(tar_data['in_caption'], tar_data['in_caption'].permute(1, 0, 2), dim=2)
        #rec_style_latent = torch.mean(rec_style_latent, dim=1)
        encoder_cont_loss = self.encode_contra_loss(enc_style[:, 0, :], enc_style[:, 0, :], sim_score)
        style_rec_loss += encoder_cont_loss

        rec_self_cont_loss = self.latent_contra_loss(rec_full_latent.reshape(rec_full_latent.shape[0], -1), rec_mask_full_latent.reshape(rec_full_latent.shape[0], -1), sim_score)#self.style_loss(rec_style_latent, enc_style)#self.style_loss(rec_style_latent, enc_style)
        rec_word_cont_loss = self.latent_contra_loss(rec_full_latent.reshape(rec_full_latent.shape[0], -1), rec_word_full_latent.reshape(rec_full_latent.shape[0], -1), sim_score)#self.style_loss(rec_style_latent, enc_style)#self.style_loss(rec_style_latent, enc_style)
        self_self_cont_loss = self.latent_contra_loss(rec_word_full_latent.reshape(rec_full_latent.shape[0], -1), rec_mask_full_latent.reshape(rec_full_latent.shape[0], -1), sim_score)#self.style_loss(rec_style_latent, enc_style)#self.style_loss(rec_style_latent, enc_style)

        latent_cont_loss =  rec_self_cont_loss + rec_word_cont_loss + self_self_cont_loss
        #rec_style_loss = self.style_contra_loss(rec_style_latent, torch.mean(enc_style, dim=1))
        #rec_style_com = torch.cat([rec_style_latent, torch.mean(enc_style, dim=1)], dim=0)
        #rec_style_sim = F.cosine_similarity(rec_style_com.unsqueeze(1), rec_style_com.unsqueeze(0), dim=2)

        #tar_style_latent = torch.mean(tar_style_latent, dim=1)
        #tar_style_loss = self.style_contra_loss(tar_style_latent, torch.mean(enc_style, dim=1))
        #tar_style_loss = self.style_loss(tar_style_latent, enc_style)
        #tar_rec_loss = self.style_contra_loss(tar_style_latent, rec_style_latent)#self.style_loss(tar_style_latent, rec_style_latent)
        # style_loss = rec_style_loss + tar_style_loss + tar_rec_loss
        return latent_cont_loss

    def style_index_constrasitive(self, rec_data, masked_rec_data, masked_word_data, tar_data):
        # contrastive latent space

        index_tensor = [torch.LongTensor(list(range(256))).repeat(64, int(64/(2**i)), 1).cuda() for i in range(len(rec_data["cls_lower"]))]
        rec_upper_prob_list = [torch.softmax(rec_data["cls_upper"][i], dim=-1) for i in range(len(rec_data["cls_lower"]))]
        rec_upper_quant_list = self.vq_model_upper.softquantizer(rec_upper_prob_list)
        tar_upper_quant_list = self.vq_model_upper.quantizer.get_codebook_entry(tar_data["tar_index_value_upper_top"])
        rec_upper_ind_list = torch.cat([torch.max(rec_upper_prob_list[i], dim=2).indices for i in range(len(rec_data["cls_lower"]))], dim=1)
        tar_upper_ind_list = torch.cat(tar_data["tar_index_value_upper_top"], dim=1)
        rec_upper_ind_latent = self.vq_model_upper.quantizer.get_codebook_entry([torch.max(rec_upper_prob_list[i], dim=2).indices for i in range(len(rec_data["cls_lower"]))])#self.style_index_emb(rec_upper_ind_list)
        tar_upper_ind_latent = self.vq_model_upper.quantizer.get_codebook_entry(tar_data["tar_index_value_upper_top"])#self.style_index_emb(tar_upper_ind_list)


        rec_lower_prob_list = [torch.softmax(rec_data["cls_lower"][i], dim=-1) for i in range(len(rec_data["cls_lower"]))]
        rec_lower_quant_list = self.vq_model_upper.softquantizer(rec_lower_prob_list)
        tar_lower_quant_list = self.vq_model_upper.quantizer.get_codebook_entry(tar_data["tar_index_value_lower_top"])
        rec_lower_ind_list = torch.cat([torch.max(rec_upper_prob_list[i], dim=2).indices for i in range(len(rec_data["cls_lower"]))], dim=1)
        tar_lower_ind_list = torch.cat(tar_data["tar_index_value_lower_top"], dim=1)
        rec_lower_ind_latent = self.vq_model_upper.quantizer.get_codebook_entry([torch.max(rec_lower_prob_list[i], dim=2).indices for i in range(len(rec_data["cls_lower"]))])#self.style_index_emb(rec_lower_ind_list)
        tar_lower_ind_latent = self.vq_model_upper.quantizer.get_codebook_entry(tar_data["tar_index_value_lower_top"])#self.style_index_emb(tar_lower_ind_list)

        rec_hands_prob_list = [torch.softmax(rec_data["cls_hands"][i], dim=-1) for i in range(len(rec_data["cls_lower"]))]
        rec_hands_quant_list = self.vq_model_upper.softquantizer(rec_hands_prob_list)
        tar_hands_quant_list = self.vq_model_upper.quantizer.get_codebook_entry(tar_data["tar_index_value_hands_top"])
        rec_hands_ind_list = torch.cat([torch.max(rec_upper_prob_list[i], dim=2).indices for i in range(len(rec_data["cls_lower"]))], dim=1)
        tar_hands_ind_list = torch.cat(tar_data["tar_index_value_hands_top"], dim=1)
        rec_hands_ind_latent = self.vq_model_upper.quantizer.get_codebook_entry([torch.max(rec_hands_prob_list[i], dim=2).indices for i in range(len(rec_data["cls_lower"]))])#self.style_index_emb(rec_hands_ind_list)
        tar_hands_ind_latent = self.vq_model_upper.quantizer.get_codebook_entry(tar_data["tar_index_value_hands_top"])#self.style_index_emb(tar_hands_ind_list)

        #hard
        style_rec_loss = self.style_loss(torch.cat(rec_upper_ind_latent, dim=1), torch.cat(tar_upper_ind_latent, dim=1)) + \
                         self.style_loss(torch.cat(rec_lower_ind_latent, dim=1), torch.cat(tar_lower_ind_latent, dim=1)) + \
                         self.style_loss(torch.cat(rec_hands_ind_latent, dim=1), torch.cat(tar_hands_ind_latent, dim=1))

        return style_rec_loss
    def _g_training(self, loaded_data, use_adv, mode="train", epoch=0):
        bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints


        # ------ full generatation task ------ #
        mask_val = torch.ones(bs, n, self.args.pose_dims + 3 + 4).float().cuda()
        mask_val[:, :self.args.pre_frames, :] = 0.0

        net_out_val = self.model(
            loaded_data['in_audio'], loaded_data['in_word'], loaded_data['in_caption'], mask=mask_val,
            in_id=loaded_data['tar_id'], in_motion=loaded_data['latent_all'], tar_data=loaded_data,
            use_attentions=True)
        g_loss_final = 0

        #loss_latent_face = self.reclatent_loss(self.vq_model_face.decoder(self.vq_model_face.quantizer(net_out_val["rec_face"])[1]), loaded_data["tar_pose_face"])
        loss_latent_face = sum([self.reclatent_loss(net_out_val["rec_face"][i], loaded_data["latent_face_top"][i]) for i in range(len(net_out_val["rec_face"]))])
        loss_latent_lower = sum([self.reclatent_loss(net_out_val["rec_lower"][i], loaded_data["latent_lower_top"][i]) for i in range(len(net_out_val["rec_face"]))])#self.reclatent_loss(net_out_val["rec_lower"], loaded_data["latent_lower_top"])
        loss_latent_hands = sum([self.reclatent_loss(net_out_val["rec_hands"][i], loaded_data["latent_hands_top"][i]) for i in range(len(net_out_val["rec_face"]))])#self.reclatent_loss(net_out_val["rec_hands"], loaded_data["latent_hands_top"])
        loss_latent_upper = sum([self.reclatent_loss(net_out_val["rec_upper"][i], loaded_data["latent_upper_top"][i]) for i in range(len(net_out_val["rec_face"]))])#self.reclatent_loss(net_out_val["rec_upper"], loaded_data["latent_upper_top"])
        loss_latent = self.args.lf * loss_latent_face + self.args.ll * loss_latent_lower + self.args.lh * loss_latent_hands + self.args.lu * loss_latent_upper
        self.tracker.update_meter("latent", "train", loss_latent.item())
        g_loss_final += loss_latent

        rec_index_face_val = [self.log_softmax(net_out_val["cls_face"][i]).reshape(-1, self.args.vae_codebook_size) for i in range(len(net_out_val["cls_face"]))]
        rec_index_upper_val =  [self.log_softmax(net_out_val["cls_upper"][i]).reshape(-1, self.args.vae_codebook_size) for i in range(len(net_out_val["cls_face"]))]#self.log_softmax(net_out_val["cls_upper"]).reshape(-1, self.args.vae_codebook_size)
        rec_index_lower_val =  [self.log_softmax(net_out_val["cls_lower"][i]).reshape(-1, self.args.vae_codebook_size) for i in range(len(net_out_val["cls_face"]))]#self.log_softmax(net_out_val["cls_lower"]).reshape(-1, self.args.vae_codebook_size)
        rec_index_hands_val =  [self.log_softmax(net_out_val["cls_hands"][i]).reshape(-1, self.args.vae_codebook_size) for i in range(len(net_out_val["cls_face"]))]#self.log_softmax(net_out_val["cls_hands"]).reshape(-1, self.args.vae_codebook_size)
        tar_index_value_face_top = [loaded_data["tar_index_value_face_top"][i].reshape(-1) for i in range(len(loaded_data["tar_index_value_face_top"]))]
        tar_index_value_upper_top = [loaded_data["tar_index_value_upper_top"][i].reshape(-1) for i in range(len(loaded_data["tar_index_value_face_top"]))]#loaded_data["tar_index_value_upper_top"].reshape(-1)
        tar_index_value_lower_top =  [loaded_data["tar_index_value_lower_top"][i].reshape(-1) for i in range(len(loaded_data["tar_index_value_face_top"]))]#loaded_data["tar_index_value_lower_top"].reshape(-1)
        tar_index_value_hands_top =  [loaded_data["tar_index_value_hands_top"][i].reshape(-1) for i in range(len(loaded_data["tar_index_value_face_top"]))]#loaded_data["tar_index_value_hands_top"].reshape(-1)

        face_cls_loss = sum([self.cls_loss(rec_index_face_val[i], tar_index_value_face_top[i]) for i in range(len(rec_index_face_val))])/len(rec_index_face_val)
        upper_cls_loss = sum([self.cls_loss(rec_index_upper_val[i], tar_index_value_upper_top[i]) for i in range(len(rec_index_face_val))])/len(rec_index_face_val)
        lower_cls_loss = sum([self.cls_loss(rec_index_lower_val[i], tar_index_value_lower_top[i]) for i in range(len(rec_index_face_val))])/len(rec_index_face_val)
        hands_cls_loss = sum([self.cls_loss(rec_index_hands_val[i], tar_index_value_hands_top[i]) for i in range(len(rec_index_face_val))])/len(rec_index_face_val)

        #index_contra_loss = self.style_index_constrasitive(net_out_val["cls_face"], tar_index_value_face_top, loaded_data['in_caption'])

        loss_cls = self.args.cf * face_cls_loss + self.args.cu * upper_cls_loss + self.args.cl * lower_cls_loss + self.args.ch * hands_cls_loss
        self.tracker.update_meter("cls_full", "train", loss_cls.item())
        g_loss_final += loss_cls

        if mode == 'train':
                 # ------ masked gesture moderling------ #
            mask_ratio = (epoch / self.args.epochs) * 0.95 + 0.05
            mask = torch.rand(bs, n, self.args.pose_dims + 3 + 4) < mask_ratio
            mask = mask.float().cuda()
            net_out_self = self.model(
                loaded_data['in_audio'], loaded_data['in_word'], loaded_data['in_caption'], mask=mask,
                in_id=loaded_data['tar_id'], in_motion=loaded_data['latent_all'], tar_data=loaded_data,
                use_attentions=True, use_word=False)

            # contrastive latent space
            '''
            rec_style_loss, tar_style_loss, tar_rec_loss = self.style_latent_constrasitive(net_out_val, loaded_data)

            g_loss_final += rec_style_loss + tar_style_loss + tar_rec_loss
            self.tracker.update_meter("rec_style_loss", "train", rec_style_loss.item())
            self.tracker.update_meter("tar_style_loss", "train", tar_style_loss.item())
            self.tracker.update_meter("tar_rec_loss", "train", tar_rec_loss.item())
            '''

            loss_latent_face_self = sum(
                [self.reclatent_loss(net_out_self["rec_face"][i], loaded_data["latent_face_top"][i]) for i in
                 range(len(net_out_self["rec_face"]))])
            loss_latent_lower_self = sum(
                [self.reclatent_loss(net_out_self["rec_lower"][i], loaded_data["latent_lower_top"][i]) for i in range(
                    len(net_out_self[
                            "rec_face"]))])  # self.reclatent_loss(net_out_val["rec_lower"], loaded_data["latent_lower_top"])
            loss_latent_hands_self = sum(
                [self.reclatent_loss(net_out_self["rec_hands"][i], loaded_data["latent_hands_top"][i]) for i in range(
                    len(net_out_self[
                            "rec_face"]))])  # self.reclatent_loss(net_out_val["rec_hands"], loaded_data["latent_hands_top"])
            loss_latent_upper_self = sum(
                [self.reclatent_loss(net_out_self["rec_upper"][i], loaded_data["latent_upper_top"][i]) for i in range(
                    len(net_out_self[
                            "rec_face"]))])  # self.reclatent_loss(net_out_val["rec_upper"], loaded_data["latent_upper_top"])
            loss_latent_self = self.args.lf * loss_latent_face_self + self.args.ll * loss_latent_lower_self + self.args.lh * loss_latent_hands_self + self.args.lu * loss_latent_upper_self


            self.tracker.update_meter("latent_self", "train", loss_latent_self.item())
            #g_loss_final += loss_latent_self
            rec_index_face_self = [self.log_softmax(net_out_self["cls_face"][i]).reshape(-1, self.args.vae_codebook_size)
                                  for i in range(len(net_out_self["cls_face"]))]
            rec_index_upper_self = [
                self.log_softmax(net_out_self["cls_upper"][i]).reshape(-1, self.args.vae_codebook_size) for i in range(
                    len(net_out_self[
                            "cls_face"]))]  # self.log_softmax(net_out_val["cls_upper"]).reshape(-1, self.args.vae_codebook_size)
            rec_index_lower_self = [
                self.log_softmax(net_out_self["cls_lower"][i]).reshape(-1, self.args.vae_codebook_size) for i in range(
                    len(net_out_self[
                            "cls_face"]))]  # self.log_softmax(net_out_val["cls_lower"]).reshape(-1, self.args.vae_codebook_size)
            rec_index_hands_self = [
                self.log_softmax(net_out_self["cls_hands"][i]).reshape(-1, self.args.vae_codebook_size) for i in range(
                    len(net_out_self["cls_face"]))]  # self.log_softmax(net_out_val["cls_hands"]).reshape(-1, self.args.vae_codebook_size)


            face_cls_loss_self = sum([self.cls_loss(rec_index_face_self[i], tar_index_value_face_top[i]) for i in
                                 range(len(rec_index_face_self))])/len(rec_index_face_val)
            upper_cls_loss_self = sum([self.cls_loss(rec_index_upper_self[i], tar_index_value_upper_top[i]) for i in
                                  range(len(rec_index_face_self))])/len(rec_index_face_val)
            lower_cls_loss_self = sum([self.cls_loss(rec_index_lower_self[i], tar_index_value_lower_top[i]) for i in
                                  range(len(rec_index_face_self))])/len(rec_index_face_val)
            hands_cls_loss_self = sum([self.cls_loss(rec_index_hands_self[i], tar_index_value_hands_top[i]) for i in
                                  range(len(rec_index_face_self))])/len(rec_index_face_val)

            index_loss_top_self = face_cls_loss_self + upper_cls_loss_self + lower_cls_loss_self + hands_cls_loss_self

            #index_loss_top_self = self.cls_loss(rec_index_face_self, tar_index_value_face_top) + self.cls_loss(rec_index_upper_self, tar_index_value_upper_top) + self.cls_loss(rec_index_lower_self,tar_index_value_lower_top) + self.cls_loss(rec_index_hands_self, tar_index_value_hands_top)
            self.tracker.update_meter("cls_self", "train", index_loss_top_self.item())
            #g_loss_final += index_loss_top_self

            # ------ masked audio gesture moderling ------ #
            net_out_word = self.model(
                loaded_data['in_audio'], loaded_data['in_word'], loaded_data['in_caption'], mask=mask,
                in_id=loaded_data['tar_id'], in_motion=loaded_data['latent_all'], tar_data=loaded_data,
                use_attentions=True, use_word=True)

            # contrastive latent space
            '''
            rec_style_loss, tar_style_loss, tar_rec_loss = self.style_latent_constrasitive(net_out_val, loaded_data)
            #g_loss_final += rec_style_loss + tar_style_loss + tar_rec_loss
            self.tracker.update_meter("rec_style_loss", "train", rec_style_loss.item())
            self.tracker.update_meter("tar_style_loss", "train", tar_style_loss.item())
            self.tracker.update_meter("tar_rec_loss", "train", tar_rec_loss.item())
            '''

            loss_latent_face_word = sum(
                [self.reclatent_loss(net_out_word["rec_face"][i], loaded_data["latent_face_top"][i]) for i in
                 range(len(net_out_word["rec_face"]))])
            loss_latent_lower_word = sum(
                [self.reclatent_loss(net_out_word["rec_lower"][i], loaded_data["latent_lower_top"][i]) for i in range(
                    len(net_out_word[
                            "rec_face"]))])  # self.reclatent_loss(net_out_val["rec_lower"], loaded_data["latent_lower_top"])
            loss_latent_hands_word = sum(
                [self.reclatent_loss(net_out_word["rec_hands"][i], loaded_data["latent_hands_top"][i]) for i in range(
                    len(net_out_word[
                            "rec_face"]))])  # self.reclatent_loss(net_out_val["rec_hands"], loaded_data["latent_hands_top"])
            loss_latent_upper_word = sum(
                [self.reclatent_loss(net_out_word["rec_upper"][i], loaded_data["latent_upper_top"][i]) for i in range(
                    len(net_out_word[
                            "rec_face"]))])  # self.reclatent_loss(net_out_val["rec_upper"], loaded_data["latent_upper_top"])
            loss_latent_word = self.args.lf * loss_latent_face_word + self.args.ll * loss_latent_lower_word + self.args.lh * loss_latent_hands_word + self.args.lu * loss_latent_upper_word
            '''
            loss_latent_face_word = self.reclatent_loss(net_out_word["rec_face"], loaded_data["latent_face_top"])
            loss_latent_lower_word = self.reclatent_loss(net_out_word["rec_lower"], loaded_data["latent_lower_top"])
            loss_latent_hands_word = self.reclatent_loss(net_out_word["rec_hands"], loaded_data["latent_hands_top"])
            loss_latent_upper_word = self.reclatent_loss(net_out_word["rec_upper"], loaded_data["latent_upper_top"])
            loss_latent_word = self.args.lf * loss_latent_face_word + self.args.ll * loss_latent_lower_word + self.args.lh * loss_latent_hands_word + self.args.lu * loss_latent_upper_word
            '''
            self.tracker.update_meter("latent_word", "train", loss_latent_word.item())
            #g_loss_final += loss_latent_word

            rec_index_face_word = ([self.log_softmax(net_out_word["cls_face"][i]).reshape(-1, self.args.vae_codebook_size)
                                   for i in range(len(net_out_word["cls_face"]))])
            rec_index_upper_word = ([
                self.log_softmax(net_out_word["cls_upper"][i]).reshape(-1, self.args.vae_codebook_size) for i in range(
                    len(net_out_word[
                            "cls_face"]))])  # self.log_softmax(net_out_val["cls_upper"]).reshape(-1, self.args.vae_codebook_size)
            rec_index_lower_word = ([
                self.log_softmax(net_out_word["cls_lower"][i]).reshape(-1, self.args.vae_codebook_size) for i in range(
                    len(net_out_word[
                            "cls_face"]))])  # self.log_softmax(net_out_val["cls_lower"]).reshape(-1, self.args.vae_codebook_size)
            rec_index_hands_word = ([
                self.log_softmax(net_out_word["cls_hands"][i]).reshape(-1, self.args.vae_codebook_size) for i in range(
                    len(net_out_word[
                            "cls_face"]))])  # self.log_softmax(net_out_val["cls_hands"]).reshape(-1, self.args.vae_codebook_size)

            face_cls_loss_word = sum([self.cls_loss(rec_index_face_word[i], tar_index_value_face_top[i]) for i in
                                      range(len(rec_index_face_word))])/len(rec_index_face_val)
            upper_cls_loss_word = sum([self.cls_loss(rec_index_upper_word[i], tar_index_value_upper_top[i]) for i in
                                       range(len(rec_index_face_word))])/len(rec_index_face_val)
            lower_cls_loss_word = sum([self.cls_loss(rec_index_lower_word[i], tar_index_value_lower_top[i]) for i in
                                       range(len(rec_index_face_word))])/len(rec_index_face_val)
            hands_cls_loss_word = sum([self.cls_loss(rec_index_hands_word[i], tar_index_value_hands_top[i]) for i in
                                       range(len(rec_index_face_word))])/len(rec_index_face_val)

            index_loss_top_word = face_cls_loss_word + upper_cls_loss_word + lower_cls_loss_word + hands_cls_loss_word
            '''
            rec_index_face_word = self.log_softmax(net_out_word["cls_face"]).reshape(-1, self.args.vae_codebook_size)
            rec_index_upper_word = self.log_softmax(net_out_word["cls_upper"]).reshape(-1, self.args.vae_codebook_size)
            rec_index_lower_word = self.log_softmax(net_out_word["cls_lower"]).reshape(-1, self.args.vae_codebook_size)
            rec_index_hands_word = self.log_softmax(net_out_word["cls_hands"]).reshape(-1, self.args.vae_codebook_size)
            index_loss_top_word = self.cls_loss(rec_index_face_word, tar_index_value_face_top) + self.cls_loss(rec_index_upper_word, tar_index_value_upper_top) + self.cls_loss(rec_index_lower_word,tar_index_value_lower_top) + self.cls_loss(rec_index_hands_word, tar_index_value_hands_top)
            '''
            self.tracker.update_meter("cls_word", "train", index_loss_top_word.item())
            #g_loss_final += index_loss_top_word




        #if mode != 'train':
        if self.args.cu != 0:
            rec_index_upper = [torch.max(
                rec_index_upper_val[i].reshape(self.args.batch_size, -1, self.args.vae_codebook_size), dim=2)[1] for i in range(len(rec_index_upper_val))]
            rec_upper = self.vq_model_upper.decode(rec_index_upper)
        else:
            _, rec_index_upper, _, _ = self.vq_model_upper.quantizer(net_out_val["rec_upper"])
            rec_upper = self.vq_model_upper.decoder(rec_index_upper)
        if self.args.cl != 0:
            rec_index_lower = [torch.max(
                rec_index_lower_val[i].reshape(self.args.batch_size, -1, self.args.vae_codebook_size), dim=2)[1] for i in range(len(rec_index_upper_val))]
            rec_lower = self.vq_model_lower.decode(rec_index_lower)
        else:
            _, rec_index_lower, _, _ = self.vq_model_lower.quantizer(net_out_val["rec_lower"])
            rec_lower = self.vq_model_lower.decoder(rec_index_lower)
        if self.args.ch != 0:
            rec_index_hands = [torch.max(
                rec_index_hands_val[i].reshape(self.args.batch_size, -1, self.args.vae_codebook_size), dim=2)[1] for i in range(len(rec_index_upper_val))]
            rec_hands = self.vq_model_hands.decode(rec_index_hands)
        else:
            _, rec_index_hands, _, _ = self.vq_model_hands.quantizer(net_out_val["rec_hands"])
            rec_hands = self.vq_model_hands.decoder(rec_index_hands)
        if self.args.cf != 0:
            rec_index_face = [torch.max(
                rec_index_face_val[i].reshape(self.args.batch_size, -1, self.args.vae_codebook_size), dim=2)[1] for i in range(len(rec_index_upper_val))]
            rec_face = self.vq_model_face.decode(rec_index_face)
        else:
            _, rec_index_face, _, _ = self.vq_model_face.quantizer(net_out_val["rec_face"])
            rec_face = self.vq_model_face.decoder(rec_index_face)

        rec_pose_jaw = rec_face[:, :, :6]
        rec_pose_legs = rec_lower[:, :, :54]

        rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
        rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)  #
        rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs * n, 13 * 3)
        rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs * n)
        rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
        rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
        rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs * n, 9 * 3)
        rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs * n)
        rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
        rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
        rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs * n, 30 * 3)
        rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs * n)
        rec_pose_jaw = rec_pose_jaw.reshape(bs * n, 6)
        rec_pose_jaw = rc.rotation_6d_to_matrix(rec_pose_jaw)
        rec_pose_jaw = rc.matrix_to_axis_angle(rec_pose_jaw).reshape(bs * n, 1 * 3)
        rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
        rec_pose[:, 66:69] = rec_pose_jaw
        # print(rec_pose.shape, tar_pose.shape)

        # tar_trans = loaded_data["tar_trans"]
        # rec_trans_v_s = rec_lower[:, :, 54:57]
        # rec_x_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 0:1], 1/self.args.pose_fps, tar_trans[:, 0, 0:1])
        # rec_z_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 2:3], 1/self.args.pose_fps, tar_trans[:, 0, 2:3])
        # rec_y_trans = rec_trans_v_s[:,:,1:2]
        # rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1)
        # tar_pose = loaded_data["tar_pose"]
        # tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
        # tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
        rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, j, 3))
        rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6)

        if mode == 'train':
            return {
                'rec_pose': rec_pose,
                # rec_trans': rec_pose_trans,
                'tar_pose': loaded_data["tar_pose_6d"],
                'g_loss_final': g_loss_final,
            }
        elif mode == 'val':
            return {
                'rec_pose': rec_pose,
                # rec_trans': rec_pose_trans,
                'tar_pose': loaded_data["tar_pose_6d"],
            }
        else:
            return {
                'rec_pose': rec_pose,
                # 'rec_trans': rec_trans,
                'tar_pose': loaded_data["tar_pose"],
                'tar_exps': loaded_data["tar_exps"],
                'tar_beta': loaded_data["tar_beta"],
                'tar_trans': loaded_data["tar_trans"],
                # 'rec_exps': rec_exps,
            }

    def _g_test(self, loaded_data):
        mode = 'test'

        bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints
        tar_pose = loaded_data["tar_pose"]
        tar_beta = loaded_data["tar_beta"]
        in_word = loaded_data["in_word"]
        tar_exps = loaded_data["tar_exps"]
        tar_contact = loaded_data["tar_contact"]
        in_audio = loaded_data["in_audio"]
        in_caption = loaded_data['in_caption']
        tar_trans = loaded_data["tar_trans"]
        tar_index_value_face_top = [loaded_data["tar_index_value_face_top"][i] for i in
                                    range(len(loaded_data["tar_index_value_face_top"]))]
        tar_index_value_upper_top = [loaded_data["tar_index_value_upper_top"][i] for i in range(
            len(loaded_data["tar_index_value_face_top"]))]  # loaded_data["tar_index_value_upper_top"].reshape(-1)
        tar_index_value_lower_top = [loaded_data["tar_index_value_lower_top"][i] for i in range(
            len(loaded_data["tar_index_value_face_top"]))]  # loaded_data["tar_index_value_lower_top"].reshape(-1)
        tar_index_value_hands_top = [loaded_data["tar_index_value_hands_top"][i] for i in range(
            len(loaded_data["tar_index_value_face_top"]))]  # loaded_data["tar_index_value_hands_top"].reshape(-1)
        rec_index_face_tar = tar_index_value_face_top
        rec_index_upper_tar = tar_index_value_upper_top
        rec_index_lower_tar = tar_index_value_lower_top
        rec_index_hands_tar = tar_index_value_hands_top


        remain = n % 8
        if remain != 0:
            tar_pose = tar_pose[:, :-remain, :]
            tar_beta = tar_beta[:, :-remain, :]
            tar_trans = tar_trans[:, :-remain, :]
            in_word = in_word[:, :-remain]
            in_caption = in_caption[:, :-remain]
            tar_exps = tar_exps[:, :-remain, :]
            tar_contact = tar_contact[:, :-remain, :]

            rec_index_face_tar = [rec_index_face_tar[k][:, :int((len(tar_index_value_face_top[0][0])-remain)/(2**k))] for k in range(len(rec_index_face_tar))]
            rec_index_upper_tar = [rec_index_upper_tar[k][:, :int((len(tar_index_value_face_top[0][0])-remain)/(2**k))] for k in range(len(rec_index_face_tar))]
            rec_index_lower_tar = [rec_index_lower_tar[k][:, :int((len(tar_index_value_face_top[0][0])-remain)/(2**k))] for k in range(len(rec_index_face_tar))]
            rec_index_hands_tar = [rec_index_hands_tar[k][:, :int((len(tar_index_value_face_top[0][0])-remain)/(2**k))] for k in range(len(rec_index_face_tar))]
            n = n - remain

        #for i in range(len(rec_index_face_tar)):
        #if len(rec_index_face_tar[0][0]) != len(in_caption[0]):
        #    print("Error")

        rec_index_face_tarlist = []
        rec_index_upper_tarlist = []
        rec_index_lower_tarlist = []
        rec_index_hands_tarlist = []

        tar_pose_jaw = tar_pose[:, :, 66:69]
        tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
        tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1 * 6)
        tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)

        tar_pose_hands = tar_pose[:, :, 25 * 3:55 * 3]
        tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
        tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30 * 6)

        tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
        tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
        tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13 * 6)

        tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
        tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
        tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9 * 6)
        tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2)

        tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
        tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55 * 6)
        latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)

        rec_index_all_face = []
        rec_index_all_upper = []
        rec_index_all_lower = []
        rec_index_all_hands = []

        rec_softindex_all_face = []
        rec_softindex_all_upper = []
        rec_softindex_all_lower = []
        rec_softindex_all_hands = []

        quantized_loss = []
        quantized_face_loss = []
        decode_loss = []
        decode_face_loss = []

        # rec_index_all_face_bot = []
        # rec_index_all_upper_bot = []
        # rec_index_all_lower_bot = []
        # rec_index_all_hands_bot = []

        roundt = (n - self.args.pre_frames) // (self.args.pose_length - self.args.pre_frames)
        remain = (n - self.args.pre_frames) % (self.args.pose_length - self.args.pre_frames)
        round_l = self.args.pose_length - self.args.pre_frames

        # pad latent_all_9 to the same length with latent_all
        # if n - latent_all_9.shape[1] >= 0:
        #     latent_all = torch.cat([latent_all_9, torch.zeros(bs, n - latent_all_9.shape[1], latent_all_9.shape[2]).cuda()], dim=1)
        # else:
        #     latent_all = latent_all_9[:, :n, :]

        for i in range(0, roundt):
            in_word_tmp = in_word[:, i * (round_l):(i + 1) * (round_l) + self.args.pre_frames]
            # audio fps is 16000 and pose fps is 30
            in_audio_tmp = in_audio[:, i * (16000 // 30 * round_l):(i + 1) * (
                        16000 // 30 * round_l) + 16000 // 30 * self.args.pre_frames]
            in_caption_tmp = in_caption[:, i * (round_l):(i + 1) * (round_l) + self.args.pre_frames]
            in_id_tmp = loaded_data['tar_id'][:, i * (round_l):(i + 1) * (round_l) + self.args.pre_frames]
            rec_index_face_tmp = [rec_index_face_tar[k][:, i * (int(round_l/(2**k))):(i + 1) * (int(round_l/(2**k))) + int(self.args.pre_frames/(2**k))] for k in range(len(rec_index_face_tar)) ]
            rec_index_upper_tmp = [rec_index_upper_tar[k][:, i * (int(round_l/(2**k))):(i + 1) * (int(round_l/(2**k))) + int(self.args.pre_frames/(2**k))] for k in range(len(rec_index_face_tar)) ]
            rec_index_lower_tmp = [rec_index_lower_tar[k][:, i * (int(round_l/(2**k))):(i + 1) * (int(round_l/(2**k))) + int(self.args.pre_frames/(2**k))] for k in range(len(rec_index_face_tar)) ]
            rec_index_hands_tmp = [rec_index_hands_tar[k][:, i * (int(round_l/(2**k))):(i + 1) * (int(round_l/(2**k))) + int(self.args.pre_frames/(2**k))] for k in range(len(rec_index_face_tar)) ]
            mask_val = torch.ones(bs, self.args.pose_length, self.args.pose_dims + 3 + 4).float().cuda()
            mask_val[:, :self.args.pre_frames, :] = 0.0
            if i == 0:
                latent_all_tmp = latent_all[:, i * (round_l):(i + 1) * (round_l) + self.args.pre_frames, :]
            else:
                latent_all_tmp = latent_all[:, i * (round_l):(i + 1) * (round_l) + self.args.pre_frames, :]
                # print(latent_all_tmp.shape, latent_last.shape)
                latent_all_tmp[:, :self.args.pre_frames, :] = latent_last[:, -self.args.pre_frames:, :]

            net_out_val = self.model(
                in_audio=in_audio_tmp,
                in_word=in_word_tmp,
                in_caption=in_caption_tmp,
                mask=mask_val,
                in_motion=latent_all_tmp,
                in_id=in_id_tmp,
                use_attentions=True, )

            if self.args.cu != 0:
                rec_index_upper = [self.log_softmax(net_out_val["cls_upper"][k]).reshape(-1, self.args.vae_codebook_size) for k in range(len(net_out_val["cls_face"]))]
                rec_index_upper = [torch.max(
                    rec_index_upper[k].reshape(-1, rec_index_upper[k].shape[0], self.args.vae_codebook_size), dim=2)[1] for k in range(len(rec_index_upper))]
                rec_softindex_upper = [torch.softmax(net_out_val["cls_upper"][k], dim=1) for k in range(len(net_out_val["cls_face"]))]
                #rec_upper = self.vq_model_upper.decode(rec_index_upper)
            else:
                _, rec_index_upper, _, _ = self.vq_model_upper.quantizer(net_out_val["rec_upper"])
                #rec_upper = self.vq_model_upper.decoder(rec_index_upper)
            if self.args.cl != 0:
                rec_index_lower = [self.log_softmax(net_out_val["cls_lower"][k]).reshape(-1, self.args.vae_codebook_size) for k in range(len(net_out_val["cls_face"]))]
                rec_index_lower = [torch.max(
                    rec_index_lower[k].reshape(-1, rec_index_lower[k].shape[0], self.args.vae_codebook_size), dim=2)[1] for k in range(len(rec_index_upper))]
                rec_softindex_lower = [torch.softmax(net_out_val["cls_lower"][k], dim=1) for k in range(len(net_out_val["cls_face"]))]
                #rec_lower = self.vq_model_lower.decode(rec_index_lower)
            else:
                _, rec_index_lower, _, _ = self.vq_model_lower.quantizer(net_out_val["rec_lower"])
                #rec_lower = self.vq_model_lower.decoder(rec_index_lower)
            if self.args.ch != 0:
                rec_index_hands = [self.log_softmax(net_out_val["cls_hands"][k]).reshape(-1, self.args.vae_codebook_size) for k in range(len(net_out_val["cls_face"]))]
                rec_index_hands = [torch.max(rec_index_hands[k].reshape(-1, rec_index_hands[k].shape[0], self.args.vae_codebook_size), dim=2)[1] for k in range(len(rec_index_upper))]
                rec_softindex_hands = [torch.softmax(net_out_val["cls_hands"][k], dim=1) for k in range(len(net_out_val["cls_face"]))]

                #rec_hands = self.vq_model_hands.decode(rec_index_hands)
            else:
                _, rec_index_hands, _, _ = self.vq_model_hands.quantizer(net_out_val["rec_hands"])
                #rec_hands = self.vq_model_hands.decoder(rec_index_hands)
            if self.args.cf != 0:
                rec_index_face = [self.log_softmax(net_out_val["cls_face"][k]).reshape(-1, self.args.vae_codebook_size) for k in range(len(net_out_val["cls_face"]))]
                rec_index_face = [torch.max(rec_index_face[k].reshape(-1, rec_index_face[k].shape[0], self.args.vae_codebook_size), dim=2)[1] for k in range(len(rec_index_upper))]
                rec_softindex_face = [torch.softmax(net_out_val["cls_face"][k], dim=1) for k in range(len(net_out_val["cls_face"]))]

                #rec_face = self.vq_model_face.decode(rec_index_face)
            else:
                _, rec_index_face, _, _ = self.vq_model_face.quantizer(net_out_val["rec_face"])
                #rec_face = self.vq_model_face.decoder(rec_index_face)

            if i == 0:
                rec_index_face_tarlist = [[rec_index_face_tmp[k]] for k in range(len(rec_index_face_tar))]
                rec_index_upper_tarlist = [[rec_index_upper_tmp[k]] for k in range(len(rec_index_face_tar))]
                rec_index_lower_tarlist = [[rec_index_lower_tmp[k]] for k in range(len(rec_index_face_tar))]
                rec_index_hands_tarlist = [[rec_index_hands_tmp[k]] for k in range(len(rec_index_face_tar))]
                rec_index_all_face = [[rec_index_face[k]] for k in range(len(rec_index_face))]
                rec_index_all_upper = [[rec_index_upper[k]] for k in range(len(rec_index_upper))]
                rec_index_all_lower = [[rec_index_lower[k]] for k in range(len(rec_index_lower))]
                rec_index_all_hands = [[rec_index_hands[k]] for k in range(len(rec_index_hands))]
                #rec_softindex_all_face = [[rec_softindex_face[k]] for k in range(len(rec_softindex_face))]  # .append(rec_index_face[:, self.args.pre_frames:])
                #rec_softindex_all_upper = [[rec_softindex_upper[k]] for k in range(len(rec_softindex_upper))]  # .append(rec_index_upper[:, self.args.pre_frames:])
                #rec_softindex_all_lower = [[rec_softindex_lower[k]] for k in range(len(rec_softindex_lower))]  # .append(rec_index_lower[:, self.args.pre_frames:])
                #rec_softindex_all_hands = [[rec_softindex_hands[k]] for k in range(len(rec_softindex_hands))]  # .append(rec_index_hands[:, self.args.pre_frames:])

            else:
                rec_index_face_tarlist = [rec_index_face_tarlist[k] + [rec_index_face_tmp[k][:, int(self.args.pre_frames/(2**k)):]] for k in range(len(rec_index_face_tar))]
                rec_index_upper_tarlist = [rec_index_upper_tarlist[k] + [rec_index_upper_tmp[k][:, int(self.args.pre_frames/(2**k)):]] for k in range(len(rec_index_face_tar))]
                rec_index_lower_tarlist = [rec_index_lower_tarlist[k] + [rec_index_lower_tmp[k][:, int(self.args.pre_frames/(2**k)):]] for k in range(len(rec_index_face_tar))]
                rec_index_hands_tarlist = [rec_index_hands_tarlist[k] + [rec_index_hands_tmp[k][:, int(self.args.pre_frames/(2**k)):]] for k in range(len(rec_index_face_tar))]
                rec_index_all_face = [rec_index_all_face[k]+[rec_index_face[k][:, int(self.args.pre_frames/(2**k)):]] for k in range(len(rec_index_face))]#.append(rec_index_face[:, self.args.pre_frames:])
                rec_index_all_upper = [rec_index_all_upper[k]+[rec_index_upper[k][:, int(self.args.pre_frames/(2**k)):]] for k in range(len(rec_index_upper))]#.append(rec_index_upper[:, self.args.pre_frames:])
                rec_index_all_lower = [rec_index_all_lower[k]+[rec_index_lower[k][:, int(self.args.pre_frames/(2**k)):]] for k in range(len(rec_index_lower))]#.append(rec_index_lower[:, self.args.pre_frames:])
                rec_index_all_hands = [rec_index_all_hands[k]+[rec_index_hands[k][:, int(self.args.pre_frames/(2**k)):]] for k in range(len(rec_index_hands))]#.append(rec_index_hands[:, self.args.pre_frames:])
                #rec_softindex_all_face = [rec_softindex_all_face[k] + [rec_softindex_face[k][:, int(self.args.pre_frames / (2 ** k)):,:]] for k in range(len(rec_softindex_face))]  # .append(rec_index_face[:, self.args.pre_frames:])
                #rec_softindex_all_upper = [
                #    rec_softindex_all_upper[k] + [rec_softindex_upper[k][:, int(self.args.pre_frames / (2 ** k)):,:]] for k in
                #    range(len(rec_softindex_upper))]  # .append(rec_index_upper[:, self.args.pre_frames:])
                #rec_softindex_all_lower = [
                #    rec_softindex_all_lower[k] + [rec_softindex_lower[k][:, int(self.args.pre_frames / (2 ** k)):,:]] for k in
                #    range(len(rec_softindex_lower))]  # .append(rec_index_lower[:, self.args.pre_frames:])
                #rec_softindex_all_hands = [
                #    rec_softindex_all_hands[k] + [rec_softindex_hands[k][:, int(self.args.pre_frames / (2 ** k)):,:]] for k in
                #    range(len(rec_softindex_hands))]  # .append(rec_index_hands[:, self.args.pre_frames:])

            if self.args.cf != 0:
                face_cls_loss = sum([self.style_loss(self.vq_model_upper.quantizer.get_codebook_entry(rec_index_face)[i],
                                                      self.vq_model_upper.quantizer.get_codebook_entry(rec_index_face_tmp)[
                                                          i]) for i in range(len(rec_index_upper))]) / len(rec_index_upper)
                quantized_face_loss.append(face_cls_loss)
            else:
                face_cls_loss = sum(
                    [self.style_loss(rec_index_face[i],
                                     self.vq_model_upper.quantizer.get_codebook_entry(rec_index_face_tmp)[
                                         i]) for i in range(len(rec_index_upper))]) / len(rec_index_upper)
                quantized_face_loss.append(face_cls_loss)
            upper_cls_loss = sum([self.style_loss(self.vq_model_upper.quantizer.get_codebook_entry(rec_index_upper)[i], self.vq_model_upper.quantizer.get_codebook_entry(rec_index_upper_tmp)[i]) for i in
                                  range(len(rec_index_upper))]) / len(rec_index_upper)
            lower_cls_loss = sum([self.style_loss(self.vq_model_lower.quantizer.get_codebook_entry(rec_index_lower)[i], self.vq_model_lower.quantizer.get_codebook_entry(rec_index_lower_tmp)[i]) for i in
                                  range(len(rec_index_upper))]) / len(rec_index_upper)
            hands_cls_loss = sum([self.style_loss(self.vq_model_hands.quantizer.get_codebook_entry(rec_index_hands)[i], self.vq_model_hands.quantizer.get_codebook_entry(rec_index_hands_tmp)[i]) for i in
                                  range(len(rec_index_upper))]) / len(rec_index_upper)
            quantized_loss.append(upper_cls_loss+lower_cls_loss+hands_cls_loss)

            if self.args.cf != 0:
                face_decode_loss = self.style_loss(self.vq_model_upper.decode(rec_index_face), self.vq_model_upper.decode(rec_index_face_tmp))
                decode_face_loss.append(face_decode_loss)
            else:
                face_decode_loss = self.style_loss(self.vq_model_upper.decoder(rec_index_face),
                                                   self.vq_model_upper.decode(rec_index_face_tmp))
                decode_face_loss.append(face_decode_loss)
            upper_decode_loss = self.style_loss(self.vq_model_upper.decode(rec_index_upper), self.vq_model_upper.decode(rec_index_upper_tmp))
            lower_decode_loss = self.style_loss(self.vq_model_upper.decode(rec_index_lower), self.vq_model_upper.decode(rec_index_lower_tmp))
            hands_decode_loss = self.style_loss(self.vq_model_upper.decode(rec_index_hands), self.vq_model_upper.decode(rec_index_hands_tmp))
            decode_loss.append(upper_decode_loss+lower_decode_loss+hands_decode_loss)

            if self.args.cu != 0:
                rec_upper_last = self.vq_model_upper.decode(rec_index_upper)
                #rec_upper_last = self.vq_model_upper.decoder(rec_softindex_upper)
            else:
                rec_upper_last = self.vq_model_upper.decoder(rec_index_upper)
            if self.args.cl != 0:
                rec_lower_last = self.vq_model_lower.decode(rec_index_lower)
                #rec_lower_last = self.vq_model_lower.decoder(rec_softindex_lower)
            else:
                rec_lower_last = self.vq_model_lower.decoder(rec_index_lower)
            if self.args.ch != 0:
                rec_hands_last = self.vq_model_hands.decode(rec_index_hands)
                #rec_hands_last = self.vq_model_hands.decoder(rec_softindex_hands)
            else:
                rec_hands_last = self.vq_model_hands.decoder(rec_index_hands)
            #if self.args.cf != 0:
            #    rec_face_last = self.vq_model_face.decode(rec_index_face)
            #else:
            #    rec_face_last = self.vq_model_face.decoder(rec_index_face)

            rec_pose_legs = rec_lower_last[:, :, :54]
            bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]
            rec_pose_upper = rec_upper_last.reshape(bs, n, 13, 6)
            rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)  #
            rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs * n, 13 * 3)
            rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs * n)
            rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
            rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
            rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs * n, 9 * 3)
            rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs * n)
            rec_pose_hands = rec_hands_last.reshape(bs, n, 30, 6)
            rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
            rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs * n, 30 * 3)
            rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs * n)
            rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
            rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs, n, j, 3))
            rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6)
            rec_trans_v_s = rec_lower_last[:, :, 54:57]
            rec_x_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 0:1], 1 / self.args.pose_fps,
                                                        tar_trans[:, 0, 0:1])
            rec_z_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 2:3], 1 / self.args.pose_fps,
                                                        tar_trans[:, 0, 2:3])
            rec_y_trans = rec_trans_v_s[:, :, 1:2]
            rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1)
            latent_last = torch.cat([rec_pose, rec_trans, rec_lower_last[:, :, 57:61]], dim=-1)

        rec_index_face_tar_full = [torch.cat(rec_index_face_tarlist[i], dim=1) for i in range(len(rec_index_face_tarlist))]
        rec_index_upper_tar_full = [torch.cat(rec_index_upper_tarlist[i], dim=1) for i in range(len(rec_index_face_tarlist))]  # torch.cat(rec_index_all_upper, dim=1)
        rec_index_lower_tar_full = [torch.cat(rec_index_lower_tarlist[i], dim=1) for i in range(len(rec_index_face_tarlist))]  # torch.cat(rec_index_all_lower, dim=1)
        rec_index_hands_tar_full = [torch.cat(rec_index_hands_tarlist[i], dim=1) for i in range(len(rec_index_face_tarlist))]

        rec_index_face = [torch.cat(rec_index_all_face[i], dim=1) for i in range(len(rec_index_all_face))]
        rec_index_upper = [torch.cat(rec_index_all_upper[i], dim=1) for i in range(len(rec_index_all_upper))]#torch.cat(rec_index_all_upper, dim=1)
        rec_index_lower = [torch.cat(rec_index_all_lower[i], dim=1) for i in range(len(rec_index_all_lower))]#torch.cat(rec_index_all_lower, dim=1)
        rec_index_hands = [torch.cat(rec_index_all_hands[i], dim=1) for i in range(len(rec_index_all_hands))]#torch.cat(rec_index_all_hands, dim=1)

        for i in range(len(rec_index_upper_tar_full)):
            if len(rec_index_face[i][0]) != len(rec_index_upper_tar_full[i][0]):
                print("Error")
        # Hard Decode
        if self.args.cu != 0:
            rec_upper = self.vq_model_upper.decode(rec_index_upper)
        else:
            rec_upper = self.vq_model_upper.decoder(rec_index_upper)
        if self.args.cl != 0:
            rec_lower = self.vq_model_lower.decode(rec_index_lower)
        else:
            rec_lower = self.vq_model_lower.decoder(rec_index_lower)
        if self.args.ch != 0:
            rec_hands = self.vq_model_hands.decode(rec_index_hands)
        else:
            rec_hands = self.vq_model_hands.decoder(rec_index_hands)
        if self.args.cf != 0:
            rec_face = self.vq_model_face.decode(rec_index_face)
        else:
            rec_face = self.vq_model_face.decoder(rec_index_face)

        rec_exps = rec_face[:, :, 6:]
        rec_pose_jaw = rec_face[:, :, :6]
        rec_pose_legs = rec_lower[:, :, :54]
        bs, n = rec_pose_jaw.shape[0], rec_pose_jaw.shape[1]
        rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
        rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)  #
        rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs * n, 13 * 3)
        rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs * n)
        rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
        rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
        rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9 * 6)
        rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs * n, 9 * 3)
        rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs * n)
        rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
        rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
        rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs * n, 30 * 3)
        rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs * n)
        rec_pose_jaw = rec_pose_jaw.reshape(bs * n, 6)
        rec_pose_jaw = rc.rotation_6d_to_matrix(rec_pose_jaw)
        rec_pose_jaw = rc.matrix_to_axis_angle(rec_pose_jaw).reshape(bs * n, 1 * 3)
        rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
        rec_pose[:, 66:69] = rec_pose_jaw

        to_global = rec_lower
        to_global[:, :, 54:57] = 0.0
        to_global[:, :, :54] = rec_lower2global
        rec_global = self.global_motion(to_global)

        rec_trans_v_s = rec_global["rec_pose"][:, :, 54:57]
        rec_x_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 0:1], 1 / self.args.pose_fps,
                                                    tar_trans[:, 0, 0:1])
        rec_z_trans = other_tools.velocity2position(rec_trans_v_s[:, :, 2:3], 1 / self.args.pose_fps,
                                                    tar_trans[:, 0, 2:3])
        rec_y_trans = rec_trans_v_s[:, :, 1:2]
        rec_trans = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1)
        tar_pose = tar_pose[:, :n, :]
        tar_exps = tar_exps[:, :n, :]
        tar_trans = tar_trans[:, :n, :]
        tar_beta = tar_beta[:, :n, :]

        rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs * n, j, 3))
        rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6)
        tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs * n, j, 3))
        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j * 6)

        return {
            'rec_pose': rec_pose,
            'rec_trans': rec_trans,
            'tar_pose': tar_pose,
            'tar_exps': tar_exps,
            'tar_beta': tar_beta,
            'tar_trans': tar_trans,
            'rec_exps': rec_exps,
        }

    def train(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        For every batch: load it with ``self._load_data``, compute the generator
        loss through ``self._g_training(..., 'train', epoch)``, back-propagate,
        optionally clip gradients to ``self.args.grad_norm`` and step
        ``self.opt``. Every ``self.args.log_period`` iterations the data/compute
        timings, reserved CUDA memory (GB) and current learning rate are passed
        to ``self.train_recording``. ``self.opt_s`` is stepped once per epoch.

        Args:
            epoch: current epoch index. The adversarial flag handed to
                ``_g_training`` is ``epoch >= self.args.no_adv_epoch``.
        """
        use_adv = bool(epoch >= self.args.no_adv_epoch)
        self.model.train()
        t_start = time.time()
        self.tracker.reset()

        for its, batch_data in enumerate(self.train_loader):
            loaded_data = self._load_data(batch_data)
            t_data = time.time() - t_start

            self.opt.zero_grad()
            net_out = self._g_training(loaded_data, use_adv, 'train', epoch)
            g_loss_final = net_out["g_loss_final"]
            g_loss_final.backward()
            if self.args.grad_norm != 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm)
            self.opt.step()

            # memory_cached() is a deprecated alias of memory_reserved() that
            # raises a FutureWarning on every call; the value is identical.
            mem_cost = torch.cuda.memory_reserved() / 1E9
            lr_g = self.opt.param_groups[0]['lr']
            t_train = time.time() - t_start - t_data
            t_start = time.time()
            if its % self.args.log_period == 0:
                self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g)
            # Debug runs stop after the second batch.
            if self.args.debug and its == 1:
                break
        self.opt_s.step(epoch)

    def val(self, epoch):
        """Compute the motion FID over the loader used for validation and record it.

        Each batch goes through ``self._g_training(..., False, 'val', epoch)``
        under ``torch.no_grad()``. Target and reconstructed poses are linearly
        resampled to 30 fps when ``self.args.pose_fps`` differs. They are trimmed
        to a multiple of ``self.args.vae_test_len`` frames and encoded with
        ``self.eval_copy.map2latent``. The Frechet distance between both latent
        sets is stored in the "fid"/"val" meter, then ``self.val_recording(epoch)``
        is called.

        Args:
            epoch: current epoch index, forwarded to the forward pass and to
                ``self.val_recording``.

        Raises:
            RuntimeError: if the loader yields no batches (nothing to compare).
        """
        self.model.eval()
        latent_out_list = []
        latent_ori_list = []
        with torch.no_grad():
            # NOTE(review): this iterates ``self.train_loader``; confirm that no
            # dedicated validation loader is intended.
            for its, batch_data in enumerate(self.train_loader):
                loaded_data = self._load_data(batch_data)
                net_out = self._g_training(loaded_data, False, 'val', epoch)
                tar_pose = net_out['tar_pose']
                rec_pose = net_out['rec_pose']
                if (30 / self.args.pose_fps) != 1:
                    assert 30 % self.args.pose_fps == 0
                    # The frame count is re-read from the resampled tensor below;
                    # ``n`` must not be updated here because it is not yet bound
                    # on the first iteration.
                    tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1),
                                                               scale_factor=30 / self.args.pose_fps,
                                                               mode='linear').permute(0, 2, 1)
                    rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1),
                                                               scale_factor=30 / self.args.pose_fps,
                                                               mode='linear').permute(0, 2, 1)
                n = tar_pose.shape[1]
                # Keep a whole number of vae_test_len-frame windows.
                remain = n % self.args.vae_test_len
                tar_pose = tar_pose[:, :n - remain, :]
                rec_pose = rec_pose[:, :n - remain, :]
                latent_out_list.append(
                    self.eval_copy.map2latent(rec_pose).reshape(-1, self.args.vae_length).cpu().numpy())
                latent_ori_list.append(
                    self.eval_copy.map2latent(tar_pose).reshape(-1, self.args.vae_length).cpu().numpy())
                # Debug runs stop after the second batch.
                if self.args.debug and its == 1:
                    break
        if not latent_out_list:
            raise RuntimeError("val(): the loader produced no batches; cannot compute FID")
        # Single concatenate instead of re-concatenating the growing array each batch.
        latent_out_motion_all = np.concatenate(latent_out_list, axis=0)
        latent_ori_all = np.concatenate(latent_ori_list, axis=0)
        fid_motion = data_tools.FIDCalculator.frechet_distance(latent_out_motion_all, latent_ori_all)
        self.tracker.update_meter("fid", "val", fid_motion)
        self.val_recording(epoch)

    def test(self, epoch):
        """Evaluate on ``self.test_loader`` and log/record the test metrics.

        The directory ``<checkpoint_path>/<epoch>/`` acts as a done-marker. If it
        already exists, evaluation is skipped and ``0`` is returned; otherwise
        it is created.

        For every test sequence the method:

        * runs ``self._g_test`` and, if ``self.args.pose_fps`` is not 30,
          linearly resamples the 6D poses to 30 fps;
        * round-trips the 6D poses through rotation matrices and encodes them
          with ``self.eval_copy.map2latent`` (FID latents);
        * runs SMPL-X on the reconstructed full body (joints) and on face-only
          reconstructed / target poses (facial vertices);
        * accumulates the length-weighted facial L2 (``self.reclatent_loss``)
          and velocity (``self.vel_loss``) terms;
        * feeds the joints to ``self.l1_calculator``;
        * when ``self.alignmenter`` is set, accumulates the beat-alignment
          score against the sequence's audio.

        The "l2 loss" and "lvel loss" metrics are only logged. The "fid", "bc"
        and "l1div" metrics are logged and also passed to
        ``self.test_recording``.

        Args:
            epoch: epoch index; names the output directory and tags records.

        Returns:
            ``0`` if the output directory already existed, otherwise ``None``.
        """
        results_save_path = self.checkpoint_path + f"/{epoch}/"
        if os.path.exists(results_save_path):
            return 0
        os.makedirs(results_save_path)
        start_time = time.time()
        total_length = 0
        test_seq_list = self.test_data.selected_file
        align = 0
        latent_out = []
        latent_ori = []
        l2_all = 0
        lvel = 0
        self.model.eval()
        self.smplx.eval()
        self.eval_copy.eval()
        with torch.no_grad():
            for its, batch_data in enumerate(self.test_loader):
                loaded_data = self._load_data(batch_data)
                net_out = self._g_test(loaded_data)
                tar_pose = net_out['tar_pose']
                rec_pose = net_out['rec_pose']
                tar_exps = net_out['tar_exps']
                tar_beta = net_out['tar_beta']
                rec_trans = net_out['rec_trans']
                rec_exps = net_out['rec_exps']
                bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
                if (30 / self.args.pose_fps) != 1:
                    # NOTE(review): only the poses are resampled here. The
                    # translations, expressions and betas keep the original frame
                    # count, so the ``reshape(bs * n, ...)`` calls below cannot
                    # succeed when pose_fps != 30. Confirm before relying on this.
                    assert 30 % self.args.pose_fps == 0
                    n *= int(30 / self.args.pose_fps)
                    tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1),
                                                               scale_factor=30 / self.args.pose_fps,
                                                               mode='linear').permute(0, 2, 1)
                    rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1),
                                                               scale_factor=30 / self.args.pose_fps,
                                                               mode='linear').permute(0, 2, 1)

                # 6D -> matrix -> 6D round trip before encoding (presumably
                # re-projects possibly-interpolated 6D values onto valid rotations).
                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs * n, j, 6))
                rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6)
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs * n, j, 6))
                tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j * 6)
                # Encode a whole number of vae_test_len-frame windows only.
                remain = n % self.args.vae_test_len
                latent_out.append(self.eval_copy.map2latent(rec_pose[:, :n - remain]).reshape(
                    -1, self.args.vae_length).detach().cpu().numpy())
                latent_ori.append(self.eval_copy.map2latent(tar_pose[:, :n - remain]).reshape(
                    -1, self.args.vae_length).detach().cpu().numpy())

                # Back to axis-angle (j * 3 values per frame) for SMPL-X.
                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs * n, j, 6))
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs * n, j * 3)
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs * n, j, 6))
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs * n, j * 3)

                # Reconstructed body with the target global orientation and zero
                # translation / expression (x - x keeps shape, dtype and device).
                vertices_rec = self.smplx(
                    betas=tar_beta.reshape(bs * n, 300),
                    transl=rec_trans.reshape(bs * n, 3) - rec_trans.reshape(bs * n, 3),
                    expression=tar_exps.reshape(bs * n, 100) - tar_exps.reshape(bs * n, 100),
                    jaw_pose=rec_pose[:, 66:69],
                    global_orient=tar_pose[:, :3],
                    body_pose=rec_pose[:, 3:21 * 3 + 3],
                    left_hand_pose=rec_pose[:, 25 * 3:40 * 3],
                    right_hand_pose=rec_pose[:, 40 * 3:55 * 3],
                    return_joints=True,
                    leye_pose=rec_pose[:, 69:72],
                    reye_pose=rec_pose[:, 72:75],
                )
                # Face-only meshes: every pose except the jaw is zeroed, so the
                # vertices differ only through jaw pose and expression.
                vertices_rec_face = self.smplx(
                    betas=tar_beta.reshape(bs * n, 300),
                    transl=rec_trans.reshape(bs * n, 3) - rec_trans.reshape(bs * n, 3),
                    expression=rec_exps.reshape(bs * n, 100),
                    jaw_pose=rec_pose[:, 66:69],
                    global_orient=rec_pose[:, :3] - rec_pose[:, :3],
                    body_pose=rec_pose[:, 3:21 * 3 + 3] - rec_pose[:, 3:21 * 3 + 3],
                    left_hand_pose=rec_pose[:, 25 * 3:40 * 3] - rec_pose[:, 25 * 3:40 * 3],
                    right_hand_pose=rec_pose[:, 40 * 3:55 * 3] - rec_pose[:, 40 * 3:55 * 3],
                    return_verts=True,
                    return_joints=True,
                    leye_pose=rec_pose[:, 69:72] - rec_pose[:, 69:72],
                    reye_pose=rec_pose[:, 72:75] - rec_pose[:, 72:75],
                )
                vertices_tar_face = self.smplx(
                    betas=tar_beta.reshape(bs * n, 300),
                    transl=net_out['tar_trans'].reshape(bs * n, 3) - net_out['tar_trans'].reshape(bs * n, 3),
                    expression=tar_exps.reshape(bs * n, 100),
                    jaw_pose=tar_pose[:, 66:69],
                    global_orient=tar_pose[:, :3] - tar_pose[:, :3],
                    body_pose=tar_pose[:, 3:21 * 3 + 3] - tar_pose[:, 3:21 * 3 + 3],
                    left_hand_pose=tar_pose[:, 25 * 3:40 * 3] - tar_pose[:, 25 * 3:40 * 3],
                    right_hand_pose=tar_pose[:, 40 * 3:55 * 3] - tar_pose[:, 40 * 3:55 * 3],
                    return_verts=True,
                    return_joints=True,
                    leye_pose=tar_pose[:, 69:72] - tar_pose[:, 69:72],
                    reye_pose=tar_pose[:, 72:75] - tar_pose[:, 72:75],
                )
                # reshape(1, n, ...) assumes one sequence per batch (bs == 1).
                joints_rec = vertices_rec["joints"].detach().cpu().numpy().reshape(1, n, 127 * 3)[0, :n, :55 * 3]
                facial_rec = vertices_rec_face['vertices'].reshape(1, n, -1)[0, :n]
                facial_tar = vertices_tar_face['vertices'].reshape(1, n, -1)[0, :n]
                # NOTE(review): the first term is rec[t] - tar[t-1] rather than the
                # reconstructed velocity rec[t] - rec[t-1]. Left unchanged so the
                # reported "lvel" value does not silently change; confirm intent.
                face_vel_loss = self.vel_loss(facial_rec[1:, :] - facial_tar[:-1, :],
                                              facial_tar[1:, :] - facial_tar[:-1, :])
                l2 = self.reclatent_loss(facial_rec, facial_tar)
                # Weighted by frame count; divided by total_length after the loop.
                l2_all += l2.item() * n
                lvel += face_vel_loss.item() * n

                _ = self.l1_calculator.run(joints_rec)
                if self.alignmenter is not None:
                    in_audio_eval, sr = librosa.load(
                        self.args.data_path + "wave16k/" + test_seq_list.iloc[its]['id'] + ".wav")
                    in_audio_eval = librosa.resample(in_audio_eval, orig_sr=sr, target_sr=self.args.audio_sr)
                    a_offset = int(self.align_mask * (self.args.audio_sr / self.args.pose_fps))
                    onset_bt = self.alignmenter.load_audio(
                        in_audio_eval[:int(self.args.audio_sr / self.args.pose_fps * n)], a_offset,
                        len(in_audio_eval) - a_offset, True)
                    beat_vel = self.alignmenter.load_pose(joints_rec, self.align_mask, n - self.align_mask, 30, True)
                    align += (self.alignmenter.calculate_align(onset_bt, beat_vel, 30) * (n - 2 * self.align_mask))

                total_length += n

        if total_length == 0:
            # Avoid ZeroDivisionError / empty np.concatenate on an empty loader.
            logger.warning("test(): the test loader yielded no sequences; no metrics computed")
            return None

        logger.info(f"l2 loss: {l2_all / total_length}")
        logger.info(f"lvel loss: {lvel / total_length}")

        latent_out_all = np.concatenate(latent_out, axis=0)
        latent_ori_all = np.concatenate(latent_ori, axis=0)
        fid = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all)
        logger.info(f"fid score: {fid}")
        self.test_recording("fid", fid, epoch)

        # The first and last align_mask frames of every sequence are excluded
        # from the alignment weighting above, so they are excluded here too.
        align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask)
        logger.info(f"align score: {align_avg}")
        self.test_recording("bc", align_avg, epoch)

        l1div = self.l1_calculator.avg()
        logger.info(f"l1div score: {l1div}")
        self.test_recording("l1div", l1div, epoch)

        end_time = time.time() - start_time
        logger.info(f"total inference time: {int(end_time)} s for {int(total_length / self.args.pose_fps)} s motion")

    def test_demo(self, epoch):
        """Run inference on ``self.test_loader`` and save motion files for visualisation.

        No losses or metrics are computed. The output directory
        ``<checkpoint_path>/<epoch>/`` is created unless it already exists, in
        which case nothing is done and ``0`` is returned.

        For every test sequence, the ground-truth and reconstructed axis-angle
        poses, expressions and translations are written as ``gt_<id>.npz`` and
        ``res_<id>.npz``. The betas are copied from the source file
        ``<data_path><pose_rep>/<id>.npz``. Afterwards the directory is handed to
        ``data_tools.result2target_vis`` (presumably renders the saved motions).

        Args:
            epoch: epoch index used to name the output directory.

        Returns:
            ``0`` if the output directory already existed, otherwise ``None``.
        """
        results_save_path = self.checkpoint_path + f"/{epoch}/"
        if os.path.exists(results_save_path):
            return 0
        os.makedirs(results_save_path)
        start_time = time.time()
        total_length = 0
        test_seq_list = self.test_data.selected_file
        self.model.eval()
        self.smplx.eval()
        with torch.no_grad():
            for its, batch_data in enumerate(self.test_loader):
                loaded_data = self._load_data(batch_data)
                net_out = self._g_test(loaded_data)
                tar_pose = net_out['tar_pose']
                rec_pose = net_out['rec_pose']
                tar_exps = net_out['tar_exps']
                rec_trans = net_out['rec_trans']
                tar_trans = net_out['tar_trans']
                rec_exps = net_out['rec_exps']
                bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints

                # interpolate to 30fps
                if (30 / self.args.pose_fps) != 1:
                    # NOTE(review): only the poses are resampled. The translations
                    # and expressions keep the original frame count, so the
                    # ``reshape(bs * n, ...)`` calls below cannot succeed when
                    # pose_fps != 30. Confirm before relying on this branch.
                    assert 30 % self.args.pose_fps == 0
                    n *= int(30 / self.args.pose_fps)
                    tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1),
                                                               scale_factor=30 / self.args.pose_fps,
                                                               mode='linear').permute(0, 2, 1)
                    rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1),
                                                               scale_factor=30 / self.args.pose_fps,
                                                               mode='linear').permute(0, 2, 1)

                # 6D rotations -> axis-angle (j * 3 values per frame) for saving.
                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs * n, j, 6))
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs * n, j * 3)
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs * n, j, 6))
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs * n, j * 3)

                tar_pose_np = tar_pose.detach().cpu().numpy()
                rec_pose_np = rec_pose.detach().cpu().numpy()
                rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs * n, 3)
                rec_exp_np = rec_exps.detach().cpu().numpy().reshape(bs * n, 100)
                tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs * n, 100)
                tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs * n, 3)

                seq_id = test_seq_list.iloc[its]['id']
                # An .npz NpzFile keeps its file open until closed; read the betas
                # and release the handle right away instead of leaking one per file.
                with np.load(self.args.data_path + self.args.pose_rep + "/" + seq_id + ".npz",
                             allow_pickle=True) as gt_npz:
                    betas = gt_npz["betas"]
                np.savez(results_save_path + "gt_" + seq_id + '.npz',
                         betas=betas,
                         poses=tar_pose_np,
                         expressions=tar_exp_np,
                         trans=tar_trans_np,
                         model='smplx2020',
                         gender='neutral',
                         mocap_frame_rate=30,
                         )
                np.savez(results_save_path + "res_" + seq_id + '.npz',
                         betas=betas,
                         poses=rec_pose_np,
                         expressions=rec_exp_np,
                         trans=rec_trans_np,
                         model='smplx2020',
                         gender='neutral',
                         mocap_frame_rate=30,
                         )
                total_length += n

        # NOTE(review): ``self.test_demo`` is passed as an argument here; confirm it
        # resolves to what result2target_vis expects (e.g. an instance attribute
        # set elsewhere) rather than to this bound method.
        data_tools.result2target_vis(self.args.pose_version, results_save_path, results_save_path, self.test_demo,
                                     False)
        end_time = time.time() - start_time
        logger.info(f"total inference time: {int(end_time)} s for {int(total_length / self.args.pose_fps)} s motion")
