import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import smplx
import copy
from .motion_encoder_att import *

# -------- Style Classifier ------- #

class StyleCls(nn.Module):
    """Regress a projected caption embedding onto a motion encoding.

    ``forward`` encodes the motion with ``VQEncoderV10``, projects the caption
    embedding with ``Linear(4096, args.vae_codebook_size)`` and returns the
    mean-squared error between the two as ``embedding_loss``.
    """

    def __init__(self, args):
        super().__init__()
        self.motion_encoder = VQEncoderV10(args)
        # Caption features are fed through a projection with a fixed input
        # width of 4096.
        self.caption_encoder = nn.Linear(4096, args.vae_codebook_size)
        # NOTE(review): constructed but never used by forward(); kept so the
        # module's attributes (and any parameters it registers) are unchanged.
        self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        self.style_loss = nn.MSELoss()

    def forward(self, inputs, cond_emb):
        """Return ``{"embedding_loss": MSE(motion_encoder(inputs), caption_encoder(cond_emb))}``.

        NOTE(review): the encoder output and the projected caption must have
        matching (or broadcastable) shapes for the MSE - confirm against
        VQEncoderV10's output shape.
        """
        motion_feat = self.motion_encoder(inputs)
        caption_feat = self.caption_encoder(cond_emb)
        return {"embedding_loss": self.style_loss(motion_feat, caption_feat)}

# ----------- AE, VAE ------------- #
class VAEConvCondition(nn.Module):
    """Conditional convolutional autoencoder with no quantization bottleneck.

    The conditioning embedding is passed to both the encoder (``VQEncoderV7``)
    and the decoder (``VQDecoderV8``).
    """

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV7(args)
        self.decoder = VQDecoderV8(args)

    def forward(self, inputs, cond_emb):
        """Reconstruct ``inputs`` under ``cond_emb``; returns ``{"rec_pose": ...}``."""
        latent = self.encoder(inputs, cond_emb)
        return {"rec_pose": self.decoder(latent, cond_emb)}

class VAEConvGlobal(nn.Module):
    """Unquantized autoencoder built from VQEncoderV19 / VQDecoderV19.

    ``cap_emb`` is accepted for interface compatibility with the conditional
    models and is ignored.
    """

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV19(args)
        self.decoder = VQDecoderV19(args)

    def forward(self, inputs, cap_emb=None):
        """Encode then decode ``inputs``; returns ``{"rec_pose": ...}``."""
        return {"rec_pose": self.decoder(self.encoder(inputs))}
# ----------- AE, VAE ------------- #
class VAEConv(nn.Module):
    """Convolutional (V)AE: VQEncoderV3 -> optional Gaussian bottleneck -> VQDecoderV3.

    When ``args.variational`` is truthy, the encoder output is projected to a
    mean and log-variance by two linear heads, and the latent is drawn with
    ``reparameterize``. Otherwise the encoder output is used directly. The
    heads are built in both cases.
    """

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV3(args)
        self.decoder = VQDecoderV3(args)
        self.fc_mu = nn.Linear(args.vae_length, args.vae_length)
        self.fc_logvar = nn.Linear(args.vae_length, args.vae_length)
        self.variational = args.variational

    def _encode_latent(self, inputs):
        """Encode ``inputs`` and, if variational, sample; returns (latent, mu, logvar).

        ``mu`` and ``logvar`` are ``None`` for the non-variational model.
        """
        latent = self.encoder(inputs)
        if not self.variational:
            return latent, None, None
        mu = self.fc_mu(latent)
        logvar = self.fc_logvar(latent)
        # `reparameterize` comes from the star import at the top of the file;
        # presumably the usual mu + eps * exp(0.5 * logvar) sampler - confirm.
        return reparameterize(mu, logvar), mu, logvar

    def forward(self, inputs):
        """Encode, (optionally) sample, and decode ``inputs``.

        Returns a dict with ``poses_feat`` (latent fed to the decoder),
        ``rec_pose`` (reconstruction), and ``pose_mu`` / ``pose_logvar``
        (``None`` unless variational).
        """
        latent, mu, logvar = self._encode_latent(inputs)
        return {
            "poses_feat": latent,
            "rec_pose": self.decoder(latent),
            "pose_mu": mu,
            "pose_logvar": logvar,
        }

    def map2latent(self, inputs):
        """Return the latent ``forward`` would decode (sampled if variational)."""
        latent, _, _ = self._encode_latent(inputs)
        return latent

    def decode(self, pre_latent):
        """Decode a latent back to poses."""
        return self.decoder(pre_latent)

class VAESKConv(VAEConv):
    """VAEConv whose encoder is a skeleton-aware ``LocalEncoder``.

    The kinematic tree (``kintree_table``) is read from the SMPL-X neutral
    model file under ``args.data_path_1``. It is converted into an edge
    topology for the encoder.
    """

    def __init__(self, args):
        super(VAESKConv, self).__init__(args)
        # NOTE(review): plain string concatenation, so data_path_1 must end
        # with a path separator.
        smpl_fname = args.data_path_1+'smplx_models/smplx/SMPLX_NEUTRAL_2020.npz'
        # np.load on an .npz returns an NpzFile that holds an open file
        # handle. The with-block closes it once the parent indices have been
        # copied out (astype returns an independent array).
        with np.load(smpl_fname, encoding='latin1') as smpl_data:
            parents = smpl_data['kintree_table'][0].astype(np.int32)
        edges = build_edge_topology(parents)
        self.encoder = LocalEncoder(args, edges)
        # Re-created with fresh weights even though VAEConv.__init__ already
        # built a VQDecoderV3. Kept to preserve the original construction order.
        self.decoder = VQDecoderV3(args)
        
class VAEConvMLP(VAEConv):
    """VAEConv with PoseEncoderConv / PoseDecoderConv replacing the V3 conv pair.

    ``VAEConv.__init__`` still builds the default encoder/decoder, which are
    then replaced here. The mu/logvar heads and the ``variational`` flag are
    inherited unchanged.
    """

    def __init__(self, args):
        super().__init__(args)
        seq_len, pose_dim, latent_dim = args.vae_test_len, args.vae_test_dim, args.vae_length
        self.encoder = PoseEncoderConv(seq_len, pose_dim, feature_length=latent_dim)
        self.decoder = PoseDecoderConv(seq_len, pose_dim, feature_length=latent_dim)
 
class VAELSTM(VAEConv):
    """VAEConv with an LSTM pose encoder (ResNet variant) and LSTM pose decoder.

    The default VAEConv encoder/decoder are built by the parent constructor
    and then replaced.
    """

    def __init__(self, args):
        super().__init__(args)
        self.encoder = PoseEncoderLSTM_Resnet(args.vae_test_dim, feature_length=args.vae_length)
        self.decoder = PoseDecoderLSTM(args.vae_test_dim, feature_length=args.vae_length)

class VAETransformer(VAEConv):
    """VAEConv whose encoder/decoder are the transformer variants.

    The parent constructor's default encoder/decoder are replaced here.
    """

    def __init__(self, args):
        super().__init__(args)
        self.encoder = Encoder_TRANSFORMER(args)
        self.decoder = Decoder_TRANSFORMER(args)

# ----------- VQVAE --------------- #

class VQVAEConvGlobal(nn.Module):
    """Multi-scale VQ-VAE: VQEncoderV9 -> MSQuantizer -> VQDecoderV9.

    Inputs are 3-D ``(batch, seq_len, feat)`` tensors. ``forward`` right-pads
    dim 1 with zeros to a multiple of ``_PAD_MULTIPLE`` before encoding, then
    crops the reconstruction back to the original length.
    """

    # NOTE(review): presumably the encoder's total temporal downsampling
    # factor - confirm against VQEncoderV9.
    _PAD_MULTIPLE = 16

    def __init__(self, args):
        super(VQVAEConvGlobal, self).__init__()

        self.encoder = VQEncoderV9(args)
        self.quantizer = MSQuantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        self.decoder = VQDecoderV9(args)

    def forward(self, inputs, cap_emb=None):
        """Reconstruct ``inputs`` through the quantized bottleneck.

        Args:
            inputs: ``(batch, seq_len, feat)`` tensor. ``seq_len`` need not be
                a multiple of ``_PAD_MULTIPLE``.
            cap_emb: accepted for interface compatibility; unused.

        Returns:
            dict with ``poses_feat`` (quantized latent of the padded input),
            ``embedding_loss`` and ``perplexity`` from the quantizer, and
            ``rec_pose`` cropped to the original ``seq_len``.
        """
        input_len = inputs.shape[1]
        remainder = input_len % self._PAD_MULTIPLE
        if remainder != 0:
            # new_zeros inherits dtype and device from `inputs`. A bare
            # torch.zeros(...) is a CPU tensor of the default dtype, which
            # makes torch.cat fail for CUDA inputs and upcasts half inputs.
            padding = inputs.new_zeros((inputs.shape[0], self._PAD_MULTIPLE - remainder, inputs.shape[-1]))
            inputs = torch.cat((inputs, padding), dim=1)
        pre_latent = self.encoder(inputs)
        embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent)
        rec_pose = self.decoder(vq_latent)
        rec_pose = rec_pose[:, :input_len, :]
        return {
            "poses_feat": vq_latent,
            "embedding_loss": embedding_loss,
            "perplexity": perplexity,
            "rec_pose": rec_pose
        }

    def map2index(self, inputs):
        """Return the quantizer's code indices for ``inputs`` (no padding applied)."""
        pre_latent = self.encoder(inputs)
        index = self.quantizer.map2index(pre_latent)
        return index

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs`` (no padding applied)."""
        pre_latent = self.encoder(inputs)
        index = self.quantizer.map2index(pre_latent)
        z_q = self.quantizer.get_codebook_entry(index)
        return z_q

    def decode(self, index):
        """Decode code indices back to poses."""
        z_q = self.quantizer.get_codebook_entry(index)
        rec_pose = self.decoder(z_q)
        return rec_pose

    def softquantizer(self, ind_dist):
        """Return the expected codebook vector under per-position code distributions.

        Args:
            ind_dist: sequence of 3-D tensors ``(batch, len_j, num_codes)``,
                presumably one per MSQuantizer scale. Entry ``[b, t, k]`` is
                the weight of code ``k`` at position ``t``. All scales must
                share the same ``num_codes``.

        Returns:
            list with one tensor per scale, equal to
            ``sum_k ind_dist[j][:, :, k, None] * codebook_entry(k)``.
        """
        num_scales = len(ind_dist)
        if num_scales == 0:
            return []
        num_codes = ind_dist[0].shape[2]
        soft_latents = [None] * num_scales
        for code in range(num_codes):
            # The same code index at every position of every scale, so one
            # lookup yields codebook entry `code` laid out on each scale's grid.
            code_index = [
                torch.full(tuple(dist.shape[:2]), code, dtype=torch.long, device=dist.device)
                for dist in ind_dist
            ]
            z_q = self.quantizer.get_codebook_entry(code_index)
            for j in range(num_scales):
                weighted = z_q[j] * ind_dist[j][:, :, code].unsqueeze(2)
                soft_latents[j] = weighted if soft_latents[j] is None else soft_latents[j] + weighted
        return soft_latents

class VQVAEConvGlobal_noatt(nn.Module):
    """Multi-scale VQ-VAE: VQEncoderV19 -> MSQuantizer -> VQDecoderV19.

    Inputs are 3-D ``(batch, seq_len, feat)`` tensors. ``forward`` right-pads
    dim 1 with zeros to a multiple of ``_PAD_MULTIPLE`` before encoding, then
    crops the reconstruction back to the original length.
    """

    # NOTE(review): presumably the encoder's total temporal downsampling
    # factor - confirm against VQEncoderV19.
    _PAD_MULTIPLE = 16

    def __init__(self, args):
        super(VQVAEConvGlobal_noatt, self).__init__()

        self.encoder = VQEncoderV19(args)
        self.quantizer = MSQuantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        self.decoder = VQDecoderV19(args)

    def forward(self, inputs, cap_emb=None):
        """Reconstruct ``inputs`` through the quantized bottleneck.

        Args:
            inputs: ``(batch, seq_len, feat)`` tensor. ``seq_len`` need not be
                a multiple of ``_PAD_MULTIPLE``.
            cap_emb: accepted for interface compatibility; unused.

        Returns:
            dict with ``poses_feat`` (quantized latent of the padded input),
            ``embedding_loss`` and ``perplexity`` from the quantizer, and
            ``rec_pose`` cropped to the original ``seq_len``.
        """
        input_len = inputs.shape[1]
        remainder = input_len % self._PAD_MULTIPLE
        if remainder != 0:
            # new_zeros inherits dtype and device from `inputs`. A bare
            # torch.zeros(...) is a CPU tensor of the default dtype, which
            # makes torch.cat fail for CUDA inputs and upcasts half inputs.
            padding = inputs.new_zeros((inputs.shape[0], self._PAD_MULTIPLE - remainder, inputs.shape[-1]))
            inputs = torch.cat((inputs, padding), dim=1)
        pre_latent = self.encoder(inputs)
        embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent)
        rec_pose = self.decoder(vq_latent)
        rec_pose = rec_pose[:, :input_len, :]
        return {
            "poses_feat": vq_latent,
            "embedding_loss": embedding_loss,
            "perplexity": perplexity,
            "rec_pose": rec_pose
        }

    def map2index(self, inputs):
        """Return the quantizer's code indices for ``inputs`` (no padding applied)."""
        pre_latent = self.encoder(inputs)
        index = self.quantizer.map2index(pre_latent)
        return index

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs`` (no padding applied)."""
        pre_latent = self.encoder(inputs)
        index = self.quantizer.map2index(pre_latent)
        z_q = self.quantizer.get_codebook_entry(index)
        return z_q

    def decode(self, index):
        """Decode code indices back to poses."""
        z_q = self.quantizer.get_codebook_entry(index)
        rec_pose = self.decoder(z_q)
        return rec_pose

    def softquantizer(self, ind_dist):
        """Return the expected codebook vector under per-position code distributions.

        Args:
            ind_dist: sequence of 3-D tensors ``(batch, len_j, num_codes)``,
                presumably one per MSQuantizer scale. Entry ``[b, t, k]`` is
                the weight of code ``k`` at position ``t``. All scales must
                share the same ``num_codes``.

        Returns:
            list with one tensor per scale, equal to
            ``sum_k ind_dist[j][:, :, k, None] * codebook_entry(k)``.
        """
        num_scales = len(ind_dist)
        if num_scales == 0:
            return []
        num_codes = ind_dist[0].shape[2]
        soft_latents = [None] * num_scales
        for code in range(num_codes):
            # The same code index at every position of every scale, so one
            # lookup yields codebook entry `code` laid out on each scale's grid.
            code_index = [
                torch.full(tuple(dist.shape[:2]), code, dtype=torch.long, device=dist.device)
                for dist in ind_dist
            ]
            z_q = self.quantizer.get_codebook_entry(code_index)
            for j in range(num_scales):
                weighted = z_q[j] * ind_dist[j][:, :, code].unsqueeze(2)
                soft_latents[j] = weighted if soft_latents[j] is None else soft_latents[j] + weighted
        return soft_latents

class VQVAEConvCondition(nn.Module):
    """Conditional VQ-VAE: VQEncoderV7 -> Quantizer -> VQDecoderV8.

    The conditioning embedding is fed to both the encoder and the decoder.
    """

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV7(args)
        self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        self.decoder = VQDecoderV8(args)

    def forward(self, inputs, cond_emb):
        """Encode, quantize and decode ``inputs`` under ``cond_emb``.

        Returns a dict with ``poses_feat`` (quantized latent),
        ``embedding_loss`` and ``perplexity`` (from the quantizer), and
        ``rec_pose`` (reconstruction).
        """
        loss, quantized, _, ppl = self.quantizer(self.encoder(inputs, cond_emb))
        return {
            "poses_feat": quantized,
            "embedding_loss": loss,
            "perplexity": ppl,
            "rec_pose": self.decoder(quantized, cond_emb),
        }

    def map2index(self, inputs, cond_emb):
        """Return the codebook indices for ``inputs`` under ``cond_emb``."""
        return self.quantizer.map2index(self.encoder(inputs, cond_emb))

    def map2latent(self, inputs, cond_emb):
        """Return the codebook vectors selected for ``inputs`` under ``cond_emb``."""
        codes = self.quantizer.map2index(self.encoder(inputs, cond_emb))
        return self.quantizer.get_codebook_entry(codes)

    def decode(self, index, cond_emb):
        """Decode codebook indices back to poses under ``cond_emb``."""
        return self.decoder(self.quantizer.get_codebook_entry(index), cond_emb)

class VQVAEConv(nn.Module):
    """Convolutional VQ-VAE: VQEncoderV3 -> Quantizer -> VQDecoderV3."""

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV3(args)
        self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        self.decoder = VQDecoderV3(args)

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``.

        Returns a dict with ``poses_feat`` (quantized latent),
        ``embedding_loss`` and ``perplexity`` (from the quantizer), and
        ``rec_pose`` (reconstruction).
        """
        loss, quantized, _, ppl = self.quantizer(self.encoder(inputs))
        return {
            "poses_feat": quantized,
            "embedding_loss": loss,
            "perplexity": ppl,
            "rec_pose": self.decoder(quantized),
        }

    def map2index(self, inputs):
        """Return the codebook indices selected for ``inputs``."""
        return self.quantizer.map2index(self.encoder(inputs))

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs``."""
        codes = self.quantizer.map2index(self.encoder(inputs))
        return self.quantizer.get_codebook_entry(codes)

    def decode(self, index):
        """Decode codebook indices back to poses."""
        return self.decoder(self.quantizer.get_codebook_entry(index))

class VQVAESKConv(VQVAEConv):
    """VQVAEConv whose encoder is a skeleton-aware ``LocalEncoder``.

    The kinematic tree (``kintree_table``) is read from the SMPL-X neutral
    model file under ``args.data_path_1``. It is converted into an edge
    topology for the encoder. Quantizer and decoder are inherited from
    VQVAEConv.
    """

    def __init__(self, args):
        super(VQVAESKConv, self).__init__(args)
        # NOTE(review): plain string concatenation, so data_path_1 must end
        # with a path separator.
        smpl_fname = args.data_path_1+'smplx_models/smplx/SMPLX_NEUTRAL_2020.npz'
        # np.load on an .npz returns an NpzFile that holds an open file
        # handle. The with-block closes it once the parent indices have been
        # copied out (astype returns an independent array).
        with np.load(smpl_fname, encoding='latin1') as smpl_data:
            parents = smpl_data['kintree_table'][0].astype(np.int32)
        edges = build_edge_topology(parents)
        # Replaces the VQEncoderV3 built by VQVAEConv.__init__.
        self.encoder = LocalEncoder(args, edges)


class VQVAEConvStride(nn.Module):
    """Strided-conv VQ-VAE: VQEncoderV4 -> Quantizer -> VQDecoderV4."""

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV4(args)
        self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        self.decoder = VQDecoderV4(args)

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``.

        Returns a dict with ``poses_feat`` (quantized latent),
        ``embedding_loss`` and ``perplexity`` (from the quantizer), and
        ``rec_pose`` (reconstruction).
        """
        loss, quantized, _, ppl = self.quantizer(self.encoder(inputs))
        return {
            "poses_feat": quantized,
            "embedding_loss": loss,
            "perplexity": ppl,
            "rec_pose": self.decoder(quantized),
        }

    def map2index(self, inputs):
        """Return the codebook indices selected for ``inputs``."""
        return self.quantizer.map2index(self.encoder(inputs))

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs``."""
        codes = self.quantizer.map2index(self.encoder(inputs))
        return self.quantizer.get_codebook_entry(codes)

    def decode(self, index):
        """Decode codebook indices back to poses."""
        return self.decoder(self.quantizer.get_codebook_entry(index))

class VQVAEConvZero(nn.Module):
    """VQ-VAE built from VQEncoderV5 -> Quantizer -> VQDecoderV5.

    ``forward`` accepts a ``cond_emb`` argument for interface compatibility
    with the conditional models, but ignores it.
    """

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV5(args)
        self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        self.decoder = VQDecoderV5(args)

    def forward(self, inputs, cond_emb=None):
        """Encode, quantize and decode ``inputs``.

        Returns a dict with ``poses_feat`` (quantized latent),
        ``embedding_loss`` and ``perplexity`` (from the quantizer), and
        ``rec_pose`` (reconstruction).
        """
        loss, quantized, _, ppl = self.quantizer(self.encoder(inputs))
        return {
            "poses_feat": quantized,
            "embedding_loss": loss,
            "perplexity": ppl,
            "rec_pose": self.decoder(quantized),
        }

    def map2index(self, inputs):
        """Return the codebook indices selected for ``inputs``."""
        return self.quantizer.map2index(self.encoder(inputs))

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs``."""
        codes = self.quantizer.map2index(self.encoder(inputs))
        return self.quantizer.get_codebook_entry(codes)

    def decode(self, index):
        """Decode codebook indices back to poses."""
        return self.decoder(self.quantizer.get_codebook_entry(index))
    

class VAEConvZero(nn.Module):
    """Plain autoencoder built from VQEncoderV5 -> VQDecoderV5, with no quantizer.

    Because there is no codebook, this model exposes no index-based API
    (map2index / map2latent / decode). ``emb_cap`` is accepted for interface
    compatibility and ignored.
    """

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV5(args)
        self.decoder = VQDecoderV5(args)

    def forward(self, inputs, emb_cap = None):
        """Encode then decode ``inputs``; returns ``{"rec_pose": ...}``."""
        latent = self.encoder(inputs)
        return {"rec_pose": self.decoder(latent)}


class VQVAEConvZero3(nn.Module):
    """VQ-VAE built from VQEncoderV5 -> Quantizer -> VQDecoderV5 (no conditioning argument)."""

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV5(args)
        self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        self.decoder = VQDecoderV5(args)

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``.

        Returns a dict with ``poses_feat`` (quantized latent),
        ``embedding_loss`` and ``perplexity`` (from the quantizer), and
        ``rec_pose`` (reconstruction).
        """
        loss, quantized, _, ppl = self.quantizer(self.encoder(inputs))
        return {
            "poses_feat": quantized,
            "embedding_loss": loss,
            "perplexity": ppl,
            "rec_pose": self.decoder(quantized),
        }

    def map2index(self, inputs):
        """Return the codebook indices selected for ``inputs``."""
        return self.quantizer.map2index(self.encoder(inputs))

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs``."""
        codes = self.quantizer.map2index(self.encoder(inputs))
        return self.quantizer.get_codebook_entry(codes)

    def decode(self, index):
        """Decode codebook indices back to poses."""
        return self.decoder(self.quantizer.get_codebook_entry(index))

class VQVAEConvZero2(nn.Module):
    """VQ-VAE built from VQEncoderV5 -> Quantizer -> VQDecoderV7."""

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV5(args)
        self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        self.decoder = VQDecoderV7(args)

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``.

        Returns a dict with ``poses_feat`` (quantized latent),
        ``embedding_loss`` and ``perplexity`` (from the quantizer), and
        ``rec_pose`` (reconstruction).
        """
        loss, quantized, _, ppl = self.quantizer(self.encoder(inputs))
        return {
            "poses_feat": quantized,
            "embedding_loss": loss,
            "perplexity": ppl,
            "rec_pose": self.decoder(quantized),
        }

    def map2index(self, inputs):
        """Return the codebook indices selected for ``inputs``."""
        return self.quantizer.map2index(self.encoder(inputs))

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs``."""
        codes = self.quantizer.map2index(self.encoder(inputs))
        return self.quantizer.get_codebook_entry(codes)

    def decode(self, index):
        """Decode codebook indices back to poses."""
        return self.decoder(self.quantizer.get_codebook_entry(index))

class VQVAE2(nn.Module):
    """Two-level (top/bottom) hierarchical VQ-VAE.

    Pipeline:

    * The bottom encoder (VQEncoderV6, 2 layers) encodes the input.
    * The top encoder (VQEncoderV3, 3 layers, input width ``vae_length``)
      encodes that bottom feature again.
    * The quantized top code is decoded by ``top_decoder`` and concatenated
      with the bottom feature on the channel axis.
    * The concatenation is fused by a 1x1 Conv1d and quantized by the bottom
      quantizer.
    * For reconstruction, the top code is upsampled by ``upsample_t`` (three
      Upsample(x2)+Conv1d+LeakyReLU stages), concatenated with the bottom
      code and passed to the bottom decoder.

    Features are kept channels-last (dim 2). They are permuted to
    channels-first only around the Conv1d/Upsample calls.
    """

    def __init__(self, args):
        super().__init__()
        # Bottom level: 2-layer encoder/decoder on the raw pose features.
        args_bottom = copy.deepcopy(args)
        args_bottom.vae_layer = 2
        self.bottom_encoder = VQEncoderV6(args_bottom)
        self.bottom_quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        # Already equal after the deepcopy; kept to mirror the original setup.
        args_bottom.vae_test_dim = args.vae_test_dim
        self.bottom_decoder = VQDecoderV6(args_bottom)

        # Top level: 3-layer encoder/decoder operating on vae_length-wide
        # bottom features.
        args_top = copy.deepcopy(args)
        args_top.vae_layer = 3
        args_top.vae_test_dim = args.vae_length
        self.top_encoder = VQEncoderV3(args_top)
        self.quantize_conv_t = nn.Conv1d(2 * args.vae_length, args.vae_length, 1)
        self.top_quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        upsample_layers = []
        for _ in range(3):
            upsample_layers += [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.upsample_t = nn.Sequential(*upsample_layers)
        self.top_decoder = VQDecoderV3(args_top)

    def _fuse_bottom(self, dec_t, enc_b):
        """Concatenate decoded top features with bottom features and fuse them with the 1x1 conv."""
        stacked = torch.cat([dec_t, enc_b], dim=2).permute(0, 2, 1)
        return self.quantize_conv_t(stacked).permute(0, 2, 1)

    def _upsample_top(self, quant_t):
        """Run ``upsample_t`` on a channels-last top code, returning channels-last."""
        return self.upsample_t(quant_t.permute(0, 2, 1)).permute(0, 2, 1)

    def forward(self, inputs):
        """Full encode/quantize/decode pass.

        Returns a dict with the top code (``poses_feat_top``), the bottom
        code (``pose_feat_bottom``), the summed quantizer losses
        (``embedding_loss``) and the reconstruction (``rec_pose``).
        Perplexities are computed by the quantizers but not returned.
        """
        enc_b = self.bottom_encoder(inputs)
        enc_t = self.top_encoder(enc_b)
        top_loss, quant_t, _, _ = self.top_quantizer(enc_t)
        fused_b = self._fuse_bottom(self.top_decoder(quant_t), enc_b)
        bottom_loss, quant_b, _, _ = self.bottom_quantizer(fused_b)
        merged = torch.cat([self._upsample_top(quant_t), quant_b], 2)
        rec_pose = self.bottom_decoder(merged)
        return {
            "poses_feat_top": quant_t,
            "pose_feat_bottom": quant_b,
            "embedding_loss": top_loss + bottom_loss,
            "rec_pose": rec_pose,
        }

    def map2index(self, inputs):
        """Return ``(top_index, bottom_index)`` codebook indices for ``inputs``."""
        enc_b = self.bottom_encoder(inputs)
        enc_t = self.top_encoder(enc_b)
        # The quantizer's forward is run only to obtain the quantized top
        # latent that conditions the bottom level; its losses are discarded.
        _, quant_t, _, _ = self.top_quantizer(enc_t)
        top_index = self.top_quantizer.map2index(enc_t)
        fused_b = self._fuse_bottom(self.top_decoder(quant_t), enc_b)
        bottom_index = self.bottom_quantizer.map2index(fused_b)
        return top_index, bottom_index

    def get_top_laent(self, top_index):
        """Look up top-level codebook vectors. (The name's misspelling is kept for callers.)"""
        return self.top_quantizer.get_codebook_entry(top_index)

    def map2latent(self, inputs):
        """Return ``(top_latent, bottom_latent)`` codebook vectors selected for ``inputs``."""
        enc_b = self.bottom_encoder(inputs)
        enc_t = self.top_encoder(enc_b)
        _, quant_t, _, _ = self.top_quantizer(enc_t)
        top_index = self.top_quantizer.map2index(enc_t)
        fused_b = self._fuse_bottom(self.top_decoder(quant_t), enc_b)
        bottom_index = self.bottom_quantizer.map2index(fused_b)
        return (
            self.top_quantizer.get_codebook_entry(top_index),
            self.bottom_quantizer.get_codebook_entry(bottom_index),
        )

    def map2latent_top(self, inputs):
        """Return only the top-level codebook vectors selected for ``inputs``."""
        enc_t = self.top_encoder(self.bottom_encoder(inputs))
        return self.top_quantizer.get_codebook_entry(self.top_quantizer.map2index(enc_t))

    def decode(self, top_index, bottom_index):
        """Reconstruct poses from top and bottom codebook indices."""
        quant_t = self.top_quantizer.get_codebook_entry(top_index)
        quant_b = self.bottom_quantizer.get_codebook_entry(bottom_index)
        return self.bottom_decoder(torch.cat([self._upsample_top(quant_t), quant_b], 2))