import os
from typing import Sequence
from sklearn.covariance import log_likelihood
import torch
import torch.nn as nn
import joblib
from torch.nn import Module
from sklearn.mixture import GaussianMixture
from model.auxiliary import EncoderLSTM, PositionalEncoding, TransformerEncoder, TransformerDecoder
from model.BaseModel import GCN_encoder, GCN_decoder, GCN_encoder_nobn, GCN_decoder_nobn
from model.stgcn import STGCN_encoder
from model.aagcn import AAGCN_encoder
from model.msg3dutils.msg3d import msg3d_encoder
from model.shiftgcnutils.shift_gcn import shiftgcn_encoder
from model.gcnnasutils.agcn3 import gcnnas_encoder
from model.asgcnutils.as_gcn import asgcn_encoder
from utils.util import get_lr_scheduler, transform_torch
from utils.losses import reconstruction_loss_mpjpe
    

# [L, D] Linear
class SequenceEncoder(Module):
    """Encode a pose sequence with a per-frame linear embedding and a transformer.

    Each frame is flattened and projected ``node_n * 3 -> d_model -> dim``,
    summed with a positional encoding and passed through ``TransformerEncoder``.
    ``TransformerDecoder`` then decodes it, with the positional encodings as
    queries and the encoder output as memory.

    When ``add_global_feature`` is True, a learnable token is prepended to the
    sequence. Only the decoder output at that token is returned (one vector per
    sample); otherwise one vector per frame is returned.
    """

    def __init__(self, args, pool='cls', add_global_feature=False, transformer_dim=128, transformer_mlp_dim=128, transformer_depth=2, transformer_heads=4, d_model=20):
        super(SequenceEncoder, self).__init__()
        self.kernel_size = args.kernel_size
        self.d_model = d_model
        self.dct_n = args.dct_n
        self.add_global_feature = add_global_feature
        assert args.kernel_size == 10

        self.num_stage = args.num_stage
        self.node_n = args.node_n
        self.encoder_layer_num = args.encoder_n
        self.decoder_layer_num = args.decoder_n
        self.input_n = args.input_n
        self.output_n = args.output_n
        self.drop_out = args.drop_out
        self.pool = pool  # stored only; forward() does not read it

        # Transformer hyper-parameters.
        self.dim = transformer_dim
        self.depth = transformer_depth
        self.heads = transformer_heads
        self.mlp_dim = transformer_mlp_dim

        # Per-frame linear embedding of the flattened (node_n * 3) coordinates.
        self.gcn_encoder = nn.Linear(self.node_n*3, self.d_model)

        self.transformerencoder = TransformerEncoder(self.dim, self.depth, self.heads, self.mlp_dim)
        self.gcn_to_embedding = nn.Linear(self.d_model, self.dim)
        self.to_latent = nn.Identity()
        self.pos_embedding = PositionalEncoding(self.dim)
        self.transformerdecoder = TransformerDecoder(self.dim, self.depth, self.heads, self.mlp_dim)

        # Learnable global ("cls") token, prepended when add_global_feature is True.
        # It used to be created as ``nn.Parameter(...).cuda()``. ``.cuda()`` returns
        # a plain non-leaf tensor, so the token was never registered as a
        # parameter: it was not trained, not saved in state_dict, and
        # re-randomised on every construction, and building the module required
        # a GPU. The random draw stays unconditional so the RNG stream (and the
        # initialisation of everything built afterwards) is unchanged; under the
        # same seed the initial token equals the old fixed one. Older
        # checkpoints that lack this key can therefore be loaded with
        # strict=False.
        initial_token = torch.randn(1, 1, self.dim)
        if add_global_feature:
            self.global_feature = nn.Parameter(initial_token)
        else:
            # Not registered, so models that never use the token keep exactly
            # the same state_dict keys as before.
            self.global_feature = None

    def forward(self, data, mask=None):
        """Encode ``data`` and return transformer features.

        Args:
            data: tensor whose first two dims are (bs, seq_len). The remaining
                dims are flattened and must total ``node_n * 3``.
            mask: forwarded unchanged to the transformer encoder and decoder.
                NOTE(review): with ``add_global_feature`` the token sequence is
                one longer than ``data``; confirm callers size the mask for that.

        Returns:
            Per-frame features of shape (bs, seq_len, dim). When
            ``add_global_feature`` is True, returns the (bs, dim) output at the
            global token instead. This assumes the decoder returns a tensor
            shaped like its query.
        """
        bs = data.shape[0]
        seq_len = data.shape[1]
        data = data.reshape(bs, seq_len, -1)  # (bs, seq_len, node_n * 3)
        gcn_encoded = self.gcn_encoder(data)  # (bs, seq_len, d_model)
        gcn_encoded = self.gcn_to_embedding(gcn_encoded)  # (bs, seq_len, dim)

        # One extra position for the prepended global token.
        num_tokens = seq_len + 1 if self.add_global_feature else seq_len
        pos_embedding = self.pos_embedding(torch.arange(num_tokens))
        pos_embedding = pos_embedding.repeat(bs, 1, 1)  # (bs, num_tokens, dim)

        if self.add_global_feature:
            global_feature = self.global_feature.expand(bs, -1, -1)
            gcn_encoded = torch.cat([global_feature, gcn_encoded], dim=1)

        gcn_encoded = gcn_encoded + pos_embedding
        # The transformer blocks are fed sequence-first tensors: (num_tokens, bs, dim).
        gcn_encoded = gcn_encoded.permute(1, 0, 2).contiguous()
        transformer_encoded = self.transformerencoder(gcn_encoded, mask)

        # The positional encodings act as the decoder queries.
        pos_embedding = pos_embedding.permute(1, 0, 2).contiguous()
        transformer_decoded = self.transformerdecoder(pos_embedding, transformer_encoded, mask=mask)
        transformer_decoded = transformer_decoded.permute(1, 0, 2).contiguous()  # back to batch-first
        if self.add_global_feature:
            transformer_decoded = transformer_decoded[:, 0]

        return transformer_decoded


# AutoEncoder for Encoder Training Linear
class AutoEncoder(Module):
    """Per-frame pose autoencoder built around ``SequenceEncoder``.

    The encoder maps a sequence to per-frame transformer features. Two linear
    layers then project each frame back to ``node_n * input_feature`` values.
    """

    def __init__(self, args, input_feature=3):
        super(AutoEncoder, self).__init__()
        self.input_feature = input_feature
        self.kernel_size = args.kernel_size
        self.d_model = args.d_model
        self.dct_n = args.dct_n
        assert args.kernel_size == 10

        # Copied from args for reference; forward() only reads node_n.
        self.num_stage = args.num_stage
        self.node_n = args.node_n
        self.encoder_layer_num = args.encoder_n
        self.decoder_layer_num = args.decoder_n
        self.input_n = args.input_n
        self.output_n = args.output_n
        self.drop_out = args.drop_out
        self.dim = args.transformer_dim

        encoder_kwargs = dict(
            transformer_dim=args.transformer_dim,
            transformer_depth=args.transformer_depth,
            transformer_heads=args.transformer_heads,
            transformer_mlp_dim=args.transformer_mlp_dim,
            d_model=args.d_model,
        )
        self.encoder = SequenceEncoder(args, **encoder_kwargs)

        # Mirror of the encoder's input projections: dim -> d_model -> joints.
        self.embedding_to_gcn = nn.Linear(self.dim, self.d_model)
        self.gcn_decoder = nn.Linear(self.d_model, self.node_n * self.input_feature)

    def forward(self, data, mask=None):
        """Reconstruct ``data`` as a (bs, seq_len, node_n, input_feature) tensor."""
        batch_size, num_frames = data.shape[0], data.shape[1]
        features = self.encoder(data, mask)  # (bs, seq_len, dim)
        hidden = self.embedding_to_gcn(features)  # (bs, seq_len, d_model)
        flat_pose = self.gcn_decoder(hidden)  # (bs, seq_len, node_n * input_feature)
        return flat_pose.reshape(batch_size, num_frames, self.node_n, self.input_feature)


# AnchorRecognitionNetwork
class AnchorRecognitionNet(Module):
    """Frame-wise anchor classifier on top of a frozen ``SequenceEncoder``.

    In testing mode, the output also carries a GMM log-likelihood for each
    frame's encoder feature. Features of frames not predicted as anchors are
    zeroed first.
    """

    def __init__(self, args, gmm):
        super(AnchorRecognitionNet, self).__init__()
        # The classifier consumes the encoder output, whose width is the ARN
        # transformer dim (as in AnchorRecognitionNetE2E). Using
        # args.transformer_dim broke the first Linear whenever the two differed.
        self.dim = args.arn_transformer_dim
        self.middle_embedding_dim = args.middle_embedding_dim
        self.likelihood_threshold = args.likelihood_threshold  # stored; not read in this class

        # NOTE(review): unlike AnchorRecognitionNetE2E, d_model is not passed, so
        # the encoder uses SequenceEncoder's default d_model.
        self.encoder = SequenceEncoder(args, transformer_dim=args.arn_transformer_dim, transformer_depth=args.arn_transformer_depth, transformer_heads=args.arn_transformer_heads, transformer_mlp_dim=args.arn_transformer_mlp_dim)
        self.classifier = nn.Sequential(nn.Linear(self.dim, self.middle_embedding_dim),
                                        nn.Tanh(),
                                        nn.Linear(self.middle_embedding_dim, 2 * self.middle_embedding_dim),
                                        nn.Tanh(),
                                        nn.Linear(2 * self.middle_embedding_dim, self.middle_embedding_dim),
                                        nn.Tanh(),
                                        nn.Linear(self.middle_embedding_dim, 1))
        self.sigmoid = nn.Sigmoid()
        # Must expose ``score_samples(X)`` for X of shape (n_samples, n_features).
        # Presumably a fitted sklearn GaussianMixture; TODO confirm.
        self.gmm = gmm

    def forward(self, data, mask=None, testing=False):
        """Classify every frame as anchor / non-anchor.

        Args:
            data: input sequence, passed to ``SequenceEncoder``.
            mask: forwarded to the encoder.
            testing: if True, also score the features with ``self.gmm``.

        Returns:
            If not testing: anchor probabilities shaped like the encoder output
            with a last dim of 1. If testing: a tensor on the same device that
            concatenates ``[probability, gmm_log_likelihood]`` on the last dim.
        """
        # The encoder is frozen here: no gradients flow into it.
        with torch.no_grad():
            seq_encoded = self.encoder(data, mask)
        seq_encoded = torch.relu(seq_encoded)
        output_label = self.sigmoid(self.classifier(seq_encoded))
        if not testing:
            return output_label

        # Zero the features of frames not predicted as anchors. The previous
        # in-place boolean assignment used a (..., 1) mask on a (..., dim)
        # tensor, which raises. It would also have mutated the tensor the
        # classifier saved for backward. masked_fill broadcasts the mask.
        not_anchor = ~(output_label > 0.5)
        masked = seq_encoded.masked_fill(not_anchor, 0.0)

        # score_samples needs a 2-D (n_samples, n_features) NumPy array on the CPU.
        features = masked.detach().cpu().reshape(-1, masked.shape[-1]).numpy()
        frame_log_likelihood = self.gmm.score_samples(features)
        frame_log_likelihood = torch.from_numpy(frame_log_likelihood).reshape(output_label.shape)
        frame_log_likelihood = frame_log_likelihood.to(device=output_label.device, dtype=output_label.dtype)
        return torch.cat([output_label, frame_log_likelihood], dim=-1)


class AnchorRecognitionNetE2E(Module):
    """End-to-end trained frame-wise anchor classifier.

    A ``SequenceEncoder`` followed by a single linear layer and a sigmoid. In
    testing mode, per-frame GMM log-likelihoods of the encoder features are
    appended to the binarised predictions.
    """

    def __init__(self, args, gmm=None):
        super(AnchorRecognitionNetE2E, self).__init__()
        self.dim = args.arn_transformer_dim
        self.encoder = SequenceEncoder(args, transformer_dim=args.arn_transformer_dim, transformer_depth=args.arn_transformer_depth, transformer_heads=args.arn_transformer_heads, transformer_mlp_dim=args.arn_transformer_mlp_dim, d_model=args.arn_d_model)
        self.top_layer = nn.Linear(self.dim, 1)
        self.sigmoid = nn.Sigmoid()
        # Needed only when forward(..., testing=True). It must expose
        # ``score_samples``; presumably a fitted sklearn GaussianMixture.
        self.gmm = gmm

    def forward(self, data, mask=None, testing=False):
        """Predict per-frame anchor probabilities.

        Args:
            data: input sequence; ``data.shape[0]`` is the batch size.
            mask: forwarded to the encoder.
            testing: if True, return CPU tensor ``[binary_pred, log_likelihood]``
                concatenated on the last dim instead of probabilities.

        Raises:
            RuntimeError: if ``testing`` is True but no GMM was provided.
        """
        bs = data.shape[0]
        # Encode once and reuse the result. The testing branch used to run
        # the whole encoder a second time on the same input.
        seq_encoded = self.encoder(data, mask)
        output_label = self.sigmoid(self.top_layer(seq_encoded))

        if not testing:
            return output_label

        if self.gmm is None:
            raise RuntimeError("AnchorRecognitionNetE2E: a fitted gmm is required when testing=True")
        features = seq_encoded.detach().cpu()
        output_preds = (output_label > 0.5).float().detach().cpu()
        # score_samples expects (n_samples, n_features); fold batch and time together.
        flat_features = features.reshape(-1, features.shape[2]).numpy()
        frame_log_likelihood = self.gmm.score_samples(flat_features)
        frame_log_likelihood = torch.from_numpy(frame_log_likelihood.reshape(bs, -1)).unsqueeze(-1)
        return torch.cat([output_preds, frame_log_likelihood], dim=-1)


class AnchorRefinementNet(Module):
    """Classify an anchor frame into one of ``n_clusters`` and decode its pose.

    Two ``SequenceEncoder`` s encode the input sequence. The feature at frame
    ``anchor_pos`` of each sample feeds a cluster classifier (logits) and a
    pose decoder that outputs a (bs, node_n, 3) tensor.
    """
    def __init__(self, args):
        super(AnchorRefinementNet, self).__init__()
        self.n_clusters = args.n_clusters
        # NOTE(review): the encoders below are built with SequenceEncoder's
        # default transformer_dim (128), not args.transformer_dim. self.dim must
        # equal their output width - confirm the configs keep these in sync.
        self.dim = args.transformer_dim
        self.middle_embedding_dim = args.afn_middle_embedding_dim
        self.label_dim = args.label_dim
        self.node_n = args.node_n
        self.d_model = args.d_model
        self.decoder_dim = args.decoder_dim
        self.anchor_embedding_dim = args.anchor_embedding_dim
        self.anchor_encoder = SequenceEncoder(args)
        self.classifier_encoder = SequenceEncoder(args)
        # Produces raw logits over n_clusters (no softmax is applied).
        self.classifier = nn.Sequential(nn.Linear(self.dim, self.middle_embedding_dim),
                                       nn.Tanh(),
                                       nn.Linear(self.middle_embedding_dim, self.middle_embedding_dim),
                                       nn.Tanh(),
                                       nn.Linear(self.middle_embedding_dim, self.middle_embedding_dim),
                                       nn.Tanh(),
                                    #    nn.Linear(self.middle_embedding_dim, self.middle_embedding_dim),
                                    #    nn.Tanh(),
                                       nn.Linear(self.middle_embedding_dim, self.n_clusters))
        
        self.label_fc = nn.Linear(self.n_clusters, self.label_dim)
        self.anchor_fc = nn.Linear(self.dim, self.anchor_embedding_dim)
        self.fuse = nn.Linear(self.label_dim + self.anchor_embedding_dim, self.decoder_dim)
        # self.decoder = nn.Sequential(nn.Linear(self.decoder_dim, self.middle_embedding_dim),
        #                              nn.Tanh(),
        #                              nn.Linear(self.middle_embedding_dim, self.middle_embedding_dim),
        #                              nn.Tanh(),
        #                              nn.Linear(self.middle_embedding_dim, self.node_n * 3))
        # Two linear layers: decoder_dim -> d_model -> node_n * 3.
        self.decoder = nn.Sequential(nn.Linear(self.decoder_dim, self.d_model),
                                     nn.Linear(self.d_model, self.node_n * 3))
        # Registered (so part of state_dict) but not used by forward().
        self.trans_generator = TransGenerator(args, no_joint0=False)
        
    def forward(self, data, anchor_pos, mask=None, anchor_class=None, training_classifier=False, training_together=False, testing=False):
        """Classify and decode the anchor frame of each sample.

        Args:
            data: input sequence passed to both encoders.
            anchor_pos: per-sample frame index, used together with
                ``arange(bs)`` to pick one feature per sample.
            mask: forwarded to both encoders.
            anchor_class: label input to ``label_fc`` when ``testing`` is False
                (presumably a one-hot of width n_clusters - confirm).
            training_classifier: if True, return only the classifier logits.
            training_together: if True (or training_classifier), the classifier
                runs with gradients; otherwise it runs under ``torch.no_grad``.
            testing: if True, feed the predicted logits to ``label_fc`` instead
                of ``anchor_class``.

        Returns:
            ``(output_label_onehot, decoded_anchor)``: classifier logits and a
            (bs, node_n, 3) decoded pose. Only the logits are returned when
            ``training_classifier`` is True.
        """
        # with torch.no_grad():
        seq_encoded = self.classifier_encoder(data, mask)
        # One feature vector per sample: the one at its anchor frame.
        anchor_encoded = seq_encoded[torch.arange(seq_encoded.shape[0]), anchor_pos]
        if training_together or training_classifier:
            output_label_onehot = self.classifier(anchor_encoded)
            if training_classifier:
                return output_label_onehot
        else:
            with torch.no_grad():
                output_label_onehot = self.classifier(anchor_encoded)
            
        # NOTE(review): at test time raw logits are fed to label_fc, while in
        # training it gets anchor_class - confirm these distributions are meant to match.
        if testing:
            output_label = output_label_onehot
        else:
            output_label = anchor_class
        label_encoded = self.label_fc(output_label)
        label_encoded = torch.relu(label_encoded)
        seq_encoded = self.anchor_encoder(data, mask)
        anchor_encoded = seq_encoded[torch.arange(seq_encoded.shape[0]), anchor_pos]
        anchor_decoded = self.anchor_fc(anchor_encoded)
        anchor_decoded = torch.relu(anchor_decoded)
        anchor_decoded = torch.cat([anchor_decoded, label_encoded], dim=-1)
        anchor_decoded = self.fuse(anchor_decoded)
        anchor_decoded = torch.relu(anchor_decoded)
        # NOTE(review): this overwrites the fused label/anchor features computed
        # above, so label_fc, anchor_fc and fuse do not affect the output. The
        # decoder's first Linear expects decoder_dim inputs but receives a
        # vector of the encoder width. This works only if decoder_dim equals
        # that width. Confirm whether this is an intentional ablation or a
        # leftover; removing it changes the outputs of trained models.
        anchor_decoded = anchor_encoded
        decoded_anchor = self.decoder(anchor_decoded) # [bs, node_n*3]
        decoded_anchor = decoded_anchor.reshape(decoded_anchor.shape[0], self.node_n, 3) # [bs, node_n, 3]
        
        return output_label_onehot, decoded_anchor
    

# TransGenerator CVAE
class TransGenerator(Module):
    """Conditional VAE over transition data between two anchor poses.

    ``encode`` maps transition data ``x`` plus a condition to the mean and
    log-variance of a ``z_dim`` Gaussian. The condition is the two anchor poses
    in ``y``, a duration ``t`` and, optionally, a global sequence feature.
    ``decode`` maps a latent sample plus the same condition back to transition
    data through ``gcn_decoder``.

    Args:
        args: configuration namespace; ``args.gcn_type`` selects the graph
            encoders used for the transition data and for each anchor.
        seq_len: length of the sequence behind ``global_feature``. It only sets
            the input size of ``global_feature_fc`` (``seq_len * transformer_dim``).
        no_joint0: if True the transition data has ``node_n - 1`` nodes,
            otherwise ``node_n``.
        add_global_feature: if True, ``global_feature`` is required by
            ``encode``/``decode``/``forward`` and is appended to the condition.

    Raises:
        NotImplementedError: if ``args.gcn_type`` is not supported.
    """

    def __init__(self, args, seq_len=60, no_joint0=True, add_global_feature=True):
        super(TransGenerator, self).__init__()
        self.kernel_size = args.kernel_size
        self.d_model = args.d_model
        self.dct_n = args.dct_n
        self.seq_len = seq_len
        assert args.kernel_size == 10

        self.num_stage = args.num_stage
        self.node_n = args.node_n
        if no_joint0:
            self.trans_node_n = args.node_n - 1
        else:
            self.trans_node_n = args.node_n
        self.encoder_layer_num = args.encoder_n
        self.decoder_layer_num = args.decoder_n
        self.input_n = args.input_n
        self.output_n = args.output_n
        self.drop_out = args.drop_out
        self.trans_embedding_dim = args.trans_embedding_dim
        self.first_anchor_embedding_dim = args.first_anchor_embedding_dim
        self.second_anchor_embedding_dim = args.second_anchor_embedding_dim
        self.duration_embedding_dim = args.duration_embedding_dim
        self.middle_embedding_dim = args.trans_middle_embedding_dim
        self.transformer_dim = args.transformer_dim
        self.global_feature_dim = args.global_feature_dim
        self.add_global_feature = add_global_feature

        self.z_dim = args.z_dim

        # Graph encoders: transition data has 9 input channels, each anchor 3.
        if args.gcn_type == "pgbig":
            self.gcn_encoder_trans_data = GCN_encoder(in_channal=9, out_channal=self.d_model,
                                            node_n=self.trans_node_n,
                                            seq_len=1,
                                            p_dropout=self.drop_out,
                                            num_stage=self.encoder_layer_num)

            self.gcn_encoder_first_anchor = GCN_encoder(in_channal=3, out_channal=self.d_model,
                                            node_n=self.node_n,
                                            seq_len=1,
                                            p_dropout=self.drop_out,
                                            num_stage=self.encoder_layer_num)

            self.gcn_encoder_second_anchor = GCN_encoder(in_channal=3, out_channal=self.d_model,
                                            node_n=self.node_n,
                                            seq_len=1,
                                            p_dropout=self.drop_out,
                                            num_stage=self.encoder_layer_num)
        elif args.gcn_type == "stgcn":
            self.gcn_encoder_trans_data = STGCN_encoder(in_channels=9,
                                            out_channels=self.d_model,
                                            graph_args={"layout": "humanact12", "strategy": "spatial"},
                                            edge_importance_weighting=True)

            self.gcn_encoder_first_anchor = STGCN_encoder(in_channels=3,
                                            out_channels=self.d_model,
                                            graph_args={"layout": "humanact12", "strategy": "spatial"},
                                            edge_importance_weighting=True)

            self.gcn_encoder_second_anchor = STGCN_encoder(in_channels=3,
                                            out_channels=self.d_model,
                                            graph_args={"layout": "humanact12", "strategy": "spatial"},
                                            edge_importance_weighting=True)

        elif args.gcn_type == "aagcn":
            self.gcn_encoder_trans_data = AAGCN_encoder(num_point=args.node_n, graph_args={'labeling_mode': 'spatial'}, in_channels=9, out_channels=self.d_model, drop_out=0, adaptive=True, attention=True)
            self.gcn_encoder_first_anchor = AAGCN_encoder(num_point=args.node_n, graph_args={'labeling_mode': 'spatial'}, in_channels=3, out_channels=self.d_model, drop_out=0, adaptive=True, attention=True)
            self.gcn_encoder_second_anchor = AAGCN_encoder(num_point=args.node_n, graph_args={'labeling_mode': 'spatial'}, in_channels=3, out_channels=self.d_model, drop_out=0, adaptive=True, attention=True)

        elif args.gcn_type == "msg3d":
            self.gcn_encoder_trans_data = msg3d_encoder(num_point=args.node_n, num_gcn_scales=13, num_g3d_scales=6, in_channels=9, out_channels=self.d_model)
            self.gcn_encoder_first_anchor = msg3d_encoder(num_point=args.node_n, num_gcn_scales=13, num_g3d_scales=6, in_channels=3, out_channels=self.d_model)
            self.gcn_encoder_second_anchor = msg3d_encoder(num_point=args.node_n, num_gcn_scales=13, num_g3d_scales=6, in_channels=3, out_channels=self.d_model)

        elif args.gcn_type == "shiftgcn":
            self.gcn_encoder_trans_data = shiftgcn_encoder(num_point=args.node_n, graph_args={'labeling_mode': 'spatial'}, in_channels=9, out_channels=self.d_model)
            self.gcn_encoder_first_anchor = shiftgcn_encoder(num_point=args.node_n, graph_args={'labeling_mode': 'spatial'}, in_channels=3, out_channels=self.d_model)
            self.gcn_encoder_second_anchor = shiftgcn_encoder(num_point=args.node_n, graph_args={'labeling_mode': 'spatial'}, in_channels=3, out_channels=self.d_model)

        elif args.gcn_type == "gcnnas":
            self.gcn_encoder_trans_data = gcnnas_encoder(num_point=args.node_n, graph_args={'labeling_mode': 'spatial'}, in_channels=9, out_channels=self.d_model)
            self.gcn_encoder_first_anchor = gcnnas_encoder(num_point=args.node_n, graph_args={'labeling_mode': 'spatial'}, in_channels=3, out_channels=self.d_model)
            self.gcn_encoder_second_anchor = gcnnas_encoder(num_point=args.node_n, graph_args={'labeling_mode': 'spatial'}, in_channels=3, out_channels=self.d_model)

        elif args.gcn_type == "asgcn":
            self.gcn_encoder_trans_data = asgcn_encoder(in_channels=9, out_channels=self.d_model, graph_args={'layout': 'humanact12', 'strategy': 'spatial', 'max_hop': 4}, edge_importance_weighting=True)
            self.gcn_encoder_first_anchor = asgcn_encoder(in_channels=3, out_channels=self.d_model, graph_args={'layout': 'humanact12', 'strategy': 'spatial', 'max_hop': 4}, edge_importance_weighting=True)
            self.gcn_encoder_second_anchor = asgcn_encoder(in_channels=3, out_channels=self.d_model, graph_args={'layout': 'humanact12', 'strategy': 'spatial', 'max_hop': 4}, edge_importance_weighting=True)

        else:
            raise NotImplementedError("gcn type not implemented")

        self.duration_encoder = nn.Linear(1, self.duration_embedding_dim)

        # All supported gcn_types share the same decoder and projection heads.
        # An unknown type has already raised above, so the former repeated type
        # check (and its else branch) was unreachable.
        self.gcn_decoder = GCN_decoder(in_channal=self.d_model, out_channal=9,
                                        node_n=self.trans_node_n,
                                        seq_len=1,
                                        p_dropout=self.drop_out,
                                        num_stage=self.decoder_layer_num)
        self.encoder_fc_trans = nn.Linear(self.d_model*self.trans_node_n*1, self.trans_embedding_dim)
        self.encoder_fc_first_anchor = nn.Linear(self.d_model*self.node_n*1, self.first_anchor_embedding_dim)
        self.encoder_fc_second_anchor = nn.Linear(self.d_model*self.node_n*1, self.second_anchor_embedding_dim)

        # Width of the condition shared by encoder and decoder:
        # [first anchor, second anchor, duration(, global feature)].
        condition_dim = self.first_anchor_embedding_dim + self.second_anchor_embedding_dim + self.duration_embedding_dim
        if self.add_global_feature:
            condition_dim += self.global_feature_dim
        self.encoder_fc1 = nn.Linear(self.trans_embedding_dim + condition_dim, self.middle_embedding_dim)
        self.encoder_fc_mean = nn.Linear(self.middle_embedding_dim, self.z_dim)
        self.encoder_fc_log_var = nn.Linear(self.middle_embedding_dim, self.z_dim)
        self.decoder_fc1 = nn.Linear(self.z_dim + condition_dim, self.middle_embedding_dim)
        self.decoder_fc2 = nn.Linear(self.middle_embedding_dim, self.d_model*self.trans_node_n*1)
        # Built even when add_global_feature is False, so it is always in state_dict.
        self.global_feature_fc = nn.Linear(self.seq_len * self.transformer_dim, self.global_feature_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)

        return eps.mul(std).add_(mu)

    def _encode_condition(self, y, t, global_feature):
        """Embed the conditioning inputs shared by ``encode`` and ``decode``.

        Returns the list ``[first_anchor, second_anchor, duration(, global)]``
        of (bs, *) tensors, in the order both fc1 layers expect.

        Raises:
            ValueError: if ``add_global_feature`` is True and ``global_feature`` is None.
        """
        # The two anchors are stacked on the last axis of y.
        first_anchor = self.gcn_encoder_first_anchor(y[..., :1])
        second_anchor = self.gcn_encoder_second_anchor(y[..., 1:])
        first_anchor = first_anchor.view(first_anchor.shape[0], -1)
        second_anchor = second_anchor.view(second_anchor.shape[0], -1)
        first_anchor = torch.relu(self.encoder_fc_first_anchor(first_anchor))
        second_anchor = torch.relu(self.encoder_fc_second_anchor(second_anchor))
        encoded_duration = self.duration_encoder(t)
        parts = [first_anchor, second_anchor, encoded_duration]
        if self.add_global_feature:
            if global_feature is None:
                raise ValueError("TransGenerator: global_feature is required when add_global_feature=True")
            global_feature = global_feature.view(global_feature.shape[0], -1)
            parts.append(torch.relu(self.global_feature_fc(global_feature)))
        return parts

    def encode(self, x, y, t, global_feature=None):
        """Return ``(mean, log_var)`` of the approximate posterior over z.

        Args:
            x: transition data for ``gcn_encoder_trans_data`` (9 input channels;
                presumably (bs, 9, trans_node_n, 1) - confirm per gcn_type).
            y: anchor poses (3 channels) with the two anchors on the last axis.
            t: duration; fed to ``nn.Linear(1, ...)``, so its last dim is 1.
            global_feature: required when ``add_global_feature`` is True;
                flattened per sample to ``seq_len * transformer_dim`` values.
        """
        trans = self.gcn_encoder_trans_data(x)
        trans = trans.view(trans.shape[0], -1)
        trans = torch.relu(self.encoder_fc_trans(trans))
        condition = self._encode_condition(y, t, global_feature)
        encoded = torch.relu(self.encoder_fc1(torch.cat([trans] + condition, dim=1)))
        mean = self.encoder_fc_mean(encoded)
        log_var = self.encoder_fc_log_var(encoded)

        return mean, log_var

    def decode(self, z, y, t, global_feature=None):
        """Decode latent ``z`` under the condition ``(y, t, global_feature)``.

        Returns the output of ``gcn_decoder`` applied to a
        (bs, d_model, trans_node_n, 1) tensor.
        """
        condition = self._encode_condition(y, t, global_feature)
        x = torch.relu(self.decoder_fc1(torch.cat([z] + condition, dim=1)))
        # Sigmoid bounds the gcn_decoder input to (0, 1).
        x = torch.sigmoid(self.decoder_fc2(x))
        x = x.view(-1, self.d_model, self.trans_node_n, 1)
        x = self.gcn_decoder(x)

        return x

    def forward(self, x, y, t, global_feature=None):
        """Run a full CVAE pass: encode, sample z, decode.

        ``x`` and ``y`` have dims 1 and 2 swapped on the way in, and the
        reconstruction is swapped back on the way out.

        Returns:
            ``(reconstructed, mean, logvar)``.
        """
        x = x.permute(0, 2, 1, 3).contiguous()
        y = y.permute(0, 2, 1 ,3).contiguous()
        mean, logvar = self.encode(x, y, t, global_feature=global_feature)
        z = self.reparameterize(mean, logvar)
        reconstructed = self.decode(z, y, t, global_feature=global_feature)
        reconstructed = reconstructed.permute(0, 2, 1, 3).contiguous()

        return reconstructed, mean, logvar


class TransRefinementNet(Module):
    """Refine a keypoint sequence using a summary of the full sequence.

    Two global-token ``SequenceEncoder`` s summarise ``full_keypoint_sequence``
    (bias) and ``keypoint_sequence`` (core). The core summary is decoded into
    per-joint transition features. The fused summaries form a one-token memory
    for a transformer decoder queried with positional encodings, which emits
    one pose per frame.
    """

    def __init__(self, args):
        super(TransRefinementNet, self).__init__()
        self.transformer_dim = args.transformer_dim
        self.transformer_depth = args.transformer_depth
        self.transformer_heads = args.transformer_heads
        self.transformer_mlp_dim = args.transformer_mlp_dim
        self.middle_embedding_dim = args.aprn_middle_embedding_dim
        self.node_n = args.node_n
        self.trans_feature_dim = 12  # per-joint transition feature width
        self.input_feature = 3
        # NOTE(review): these encoders use SequenceEncoder's default
        # transformer_dim (128), while the layers below use args.transformer_dim;
        # the two must agree.
        self.bias_encoder = SequenceEncoder(args, add_global_feature=True)
        self.core_encoder = SequenceEncoder(args, add_global_feature=True)
        self.core_decoder = nn.Sequential(nn.Linear(self.transformer_dim, self.middle_embedding_dim),
                                          nn.Tanh(),
                                          nn.Linear(self.middle_embedding_dim, self.middle_embedding_dim),
                                          nn.Tanh(),
                                          nn.Linear(self.middle_embedding_dim, self.node_n * self.trans_feature_dim))

        self.fuse = nn.Linear(self.transformer_dim * 2, self.transformer_dim)
        self.pos_embedding = PositionalEncoding(self.transformer_dim)
        self.transformer_decoder = TransformerDecoder(self.transformer_dim, self.transformer_depth, self.transformer_heads, self.transformer_mlp_dim)
        self.decoder = nn.Linear(self.transformer_dim, self.node_n * self.input_feature)

    def forward(self, full_keypoint_sequence, keypoint_sequence):
        """Predict transition features and a refined sequence.

        Returns:
            output_trans: (bs, node_n, 12) transition features.
            output: (bs, seq_len, node_n, 3) poses, where
                ``seq_len = keypoint_sequence.shape[1]``.
        """
        bs = len(full_keypoint_sequence)
        seq_len = keypoint_sequence.shape[1]
        bias_feature = self.bias_encoder(full_keypoint_sequence)  # (bs, dim)
        core_feature = self.core_encoder(keypoint_sequence)  # (bs, dim)
        output_trans = self.core_decoder(core_feature)
        output_trans = output_trans.reshape(bs, self.node_n, self.trans_feature_dim)
        gather_feature = torch.cat([bias_feature, core_feature], dim=-1)
        gather_feature = self.fuse(gather_feature)  # (bs, dim)
        # The decoder takes sequence-first inputs (tokens, batch, dim), as the
        # query permute below and every other decoder call in this module show.
        # The memory is therefore one token per sample: (1, bs, dim). The old
        # unsqueeze(1) gave (bs, 1, dim), i.e. bs memory tokens for a single
        # batch entry, which is wrong for bs > 1. For bs == 1 both are identical.
        gather_feature = gather_feature.unsqueeze(0)
        pos_embedding = self.pos_embedding(torch.arange(seq_len))
        pos_embedding = pos_embedding.repeat(bs, 1, 1)
        pos_embedding = pos_embedding.permute(1, 0, 2).contiguous()  # (seq_len, bs, dim)
        transformer_decoded = self.transformer_decoder(pos_embedding, gather_feature)
        transformer_decoded = transformer_decoded.permute(1, 0, 2).contiguous()
        transformer_decoded = transformer_decoded.reshape(bs*seq_len, -1)
        output = self.decoder(transformer_decoded)
        output = output.reshape(bs, seq_len, self.node_n, self.input_feature)

        return output_trans, output
    
    
class TransRefinementNetAR(Module):
    """Slice-based variant of ``TransRefinementNet`` for a single sequence.

    Each slice in ``keypoint_sequence`` is summarised by ``core_encoder`` and
    fused with the summary of the full sequence. A transformer encoder then
    mixes the slices, and a transformer decoder queried with per-frame
    positional encodings reconstructs the full-length sequence.
    """

    def __init__(self, args):
        super(TransRefinementNetAR, self).__init__()
        for attr_name in ("transformer_dim", "transformer_depth", "transformer_heads", "transformer_mlp_dim"):
            setattr(self, attr_name, getattr(args, attr_name))
        self.middle_embedding_dim = args.aprn_middle_embedding_dim
        self.node_n = args.node_n
        self.trans_feature_dim = 12  # per-joint transition feature width
        self.input_feature = 3

        self.bias_encoder = SequenceEncoder(args, add_global_feature=True)
        self.core_encoder = SequenceEncoder(args, add_global_feature=True)
        core_layers = [
            nn.Linear(self.transformer_dim, self.middle_embedding_dim),
            nn.Tanh(),
            nn.Linear(self.middle_embedding_dim, self.middle_embedding_dim),
            nn.Tanh(),
            nn.Linear(self.middle_embedding_dim, self.node_n * self.trans_feature_dim),
        ]
        self.core_decoder = nn.Sequential(*core_layers)
        self.fuse = nn.Linear(self.transformer_dim * 2, self.transformer_dim)
        self.pos_embedding = PositionalEncoding(self.transformer_dim)
        transformer_args = (self.transformer_dim, self.transformer_depth, self.transformer_heads, self.transformer_mlp_dim)
        self.transformer_decoder = TransformerDecoder(*transformer_args)
        self.decoder = nn.Linear(self.transformer_dim, self.node_n * self.input_feature)
        self.slice_encoder = TransformerEncoder(*transformer_args)

    def forward(self, full_keypoint_sequence, keypoint_sequence):
        """Refine one sequence from its slices.

        Assumes a batch of one full sequence: its summary is tiled once per
        slice, and the output is reshaped with a leading batch dim of 1.

        Returns:
            output_trans: (num_slice, node_n, 12) per-slice transition features.
            output: (1, seq_len, node_n, 3), where
                ``seq_len = full_keypoint_sequence.shape[1]``.
        """
        slice_count = keypoint_sequence.shape[0]
        frame_count = full_keypoint_sequence.shape[1]

        full_summary = self.bias_encoder(full_keypoint_sequence).repeat(slice_count, 1)
        slice_summary = self.core_encoder(keypoint_sequence)  # (num_slice, dim)
        output_trans = self.core_decoder(slice_summary).reshape(slice_count, self.node_n, self.trans_feature_dim)

        fused = self.fuse(torch.cat([full_summary, slice_summary], dim=-1))  # (num_slice, dim)
        slice_tokens = fused.unsqueeze(0) + self.pos_embedding(torch.arange(slice_count)).unsqueeze(0)
        # Sequence-first for the transformer: (num_slice, 1, dim).
        slice_tokens = slice_tokens.permute(1, 0, 2).contiguous()
        memory = self.slice_encoder(slice_tokens)

        frame_queries = self.pos_embedding(torch.arange(frame_count)).unsqueeze(1)  # (seq_len, 1, dim)
        decoded = self.transformer_decoder(frame_queries, memory)
        decoded = decoded.permute(1, 0, 2).contiguous().reshape(frame_count, -1)
        output = self.decoder(decoded).reshape(1, frame_count, self.node_n, self.input_feature)

        return output_trans, output
        

class SequenceOptim(Module):
    """Refine per-joint cubic trajectory coefficients by gradient descent.

    ``opt_params["trans_param_opt"]`` has shape ``[1, 9, node_n, 1]``. For each
    joint, its 9 channels are reshaped to ``[3 axes, 3 coefficients]`` holding the
    ``(a, b, c)`` terms of ``a*t**3 + b*t**2 + c*t + d``. The constant ``d`` is
    fixed to the first anchor pose. :meth:`fitting` minimizes the MSE between the
    curve evaluated at ``t = duration[0]`` and the second anchor pose.
    """

    def __init__(self, args):
        super(SequenceOptim, self).__init__()
        # Fall back to CPU instead of failing when the parameter is allocated
        # on a machine without CUDA.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.args = args
        self.config_optimizers()

    def config_optimizers(self):
        """Allocate the optimized parameter, the Adam optimizer and the MSE loss.

        Reads ``args.optim_iterations``, ``args.node_n`` and ``args.optim_lr``.
        """
        bs = 1
        self.bs = bs
        self.num_iters = self.args.optim_iterations
        self.policy = {"name": "Poly", "power": 0.95}
        self.opt_params = {
            "trans_param_opt": torch.randn(bs, 9, self.args.node_n, 1, device=self.device, requires_grad=True)
        }
        self.optimizer = torch.optim.Adam(list(self.opt_params.values()), lr=self.args.optim_lr)
        self.loss = nn.MSELoss(reduction="mean")

    @staticmethod
    def _build_coefficients(param, first_anchor):
        """Turn a raw ``[1, 9, V, 1]`` parameter into ``[1, V, 3, 4]`` coefficients.

        The last axis is ``(a, b, c, d)`` with ``d`` taken from ``first_anchor``
        (``[1, V, 3, 1]``).
        """
        flat = param[..., 0].permute(0, 2, 1)  # [1, V, 9]
        abc = flat[..., :9].reshape(flat.shape[0], flat.shape[1], 3, 3)
        return torch.cat((abc, first_anchor), dim=-1)

    @staticmethod
    def _cubic_endpoint(coeffs, t):
        """Evaluate ``a*t**3 + b*t**2 + c*t + d`` for every joint/axis -> ``[..., V, 3]``."""
        return coeffs[..., 0] * t ** 3 + coeffs[..., 1] * t ** 2 + coeffs[..., 2] * t + coeffs[..., 3]

    def fitting(self, trans_param, anchor_pair, duration):
        """Fit the cubic coefficients so the curve reaches the second anchor.

        Args:
            trans_param: initial value copied into ``opt_params["trans_param_opt"]``
                (presumably ``[1, 9, node_n, 1]`` -- confirm with caller).
            anchor_pair: anchor poses. After ``permute(0, 3, 2, 1)``, entries 0 and 1
                of dim 1 (sample 0) are the first and second anchor, each ``[node_n, 3]``
                (so presumably ``[B, 3, node_n, 2]`` on input -- TODO confirm).
            duration: indexable; ``duration[0]`` is the evaluation time ``t``.

        Returns:
            dict with ``"trans_param_opt_xyz"`` (detached ``[1, node_n, 3, 4]``
            ``(a, b, c, d)`` coefficients) and ``"trans_param_opt"`` (detached copy
            of the raw parameter). Both reflect the parameters after the final
            optimizer step, or the initial value when ``optim_iterations`` is 0.
        """
        optimizer = self.optimizer
        # Each call is an independent fit: drop Adam moments left over from a
        # previous call (they would bias the first steps, or mismatch in shape).
        optimizer.state.clear()
        scheduler = get_lr_scheduler(self.policy, optimizer, max_iter=self.num_iters)
        param = self.opt_params["trans_param_opt"]
        param.data = trans_param.clone().detach()

        anchors = anchor_pair.clone().detach().permute(0, 3, 2, 1).contiguous()
        second_anchor = anchors[0, 1:2]  # [1, V, 3]
        first_anchor = anchors[0, 0:1].unsqueeze(-1)  # [1, V, 3, 1]
        t = duration[0]

        for _ in range(self.num_iters):
            optimizer.zero_grad()
            coeffs = self._build_coefficients(param, first_anchor)
            # Cast to the target dtype, matching the original element-wise writes
            # into a zeros_like(second_anchor) buffer.
            predicted = self._cubic_endpoint(coeffs, t).to(second_anchor.dtype)
            losses = self.loss(predicted, second_anchor)
            losses.backward()
            optimizer.step()
            scheduler.step()

        # Snapshot *after* the loop so the last optimizer step is not discarded.
        final_param = param.detach().clone()
        opt_result = {}
        opt_result["trans_param_opt_xyz"] = self._build_coefficients(final_param, first_anchor)
        opt_result["trans_param_opt"] = final_param

        return opt_result
    
    
class TransOptim(Module):
    """Refine per-joint cubic trajectory coefficients by gradient descent.

    ``opt_params["trans_param_opt"]`` has shape ``[1, 9, node_n, 1]``. For each
    joint, its 9 channels are reshaped to ``[3 axes, 3 coefficients]`` holding the
    ``(a, b, c)`` terms of ``a*t**3 + b*t**2 + c*t + d``. The constant ``d`` is
    fixed to the first anchor pose. :meth:`fitting` minimizes the MSE between the
    curve evaluated at ``t = duration[0]`` and the second anchor pose.
    """

    def __init__(self, args):
        super(TransOptim, self).__init__()
        # Fall back to CPU instead of failing when the parameter is allocated
        # on a machine without CUDA.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.args = args
        self.config_optimizers()

    def config_optimizers(self):
        """Allocate the optimized parameter, the Adam optimizer and the MSE loss.

        Reads ``args.optim_iterations``, ``args.node_n`` and ``args.optim_lr``.
        """
        bs = 1
        self.bs = bs
        self.num_iters = self.args.optim_iterations
        self.policy = {"name": "Poly", "power": 0.95}
        self.opt_params = {
            "trans_param_opt": torch.randn(bs, 9, self.args.node_n, 1, device=self.device, requires_grad=True)
        }
        self.optimizer = torch.optim.Adam([self.opt_params["trans_param_opt"]], lr=self.args.optim_lr)
        self.loss = nn.MSELoss(reduction="mean")

    @staticmethod
    def _build_coefficients(param, first_anchor):
        """Turn a raw ``[1, 9, V, 1]`` parameter into ``[1, V, 3, 4]`` coefficients.

        The last axis is ``(a, b, c, d)`` with ``d`` taken from ``first_anchor``
        (``[1, V, 3, 1]``).
        """
        flat = param[..., 0].permute(0, 2, 1)  # [1, V, 9]
        abc = flat[..., :9].reshape(flat.shape[0], flat.shape[1], 3, 3)
        return torch.cat((abc, first_anchor), dim=-1)

    @staticmethod
    def _cubic_endpoint(coeffs, t):
        """Evaluate ``a*t**3 + b*t**2 + c*t + d`` for every joint/axis -> ``[..., V, 3]``."""
        return coeffs[..., 0] * t ** 3 + coeffs[..., 1] * t ** 2 + coeffs[..., 2] * t + coeffs[..., 3]

    def fitting(self, trans_param, anchor_pair, duration):
        """Fit the cubic coefficients so the curve reaches the second anchor.

        Args:
            trans_param: initial value copied into ``opt_params["trans_param_opt"]``
                (presumably ``[1, 9, node_n, 1]`` -- confirm with caller).
            anchor_pair: anchor poses. After ``permute(0, 3, 2, 1)``, entries 0 and 1
                of dim 1 (sample 0) are the first and second anchor, each ``[node_n, 3]``
                (so presumably ``[B, 3, node_n, 2]`` on input -- TODO confirm).
            duration: indexable; ``duration[0]`` is the evaluation time ``t``.

        Returns:
            dict with ``"trans_param_opt_xyz"`` (detached ``[1, node_n, 3, 4]``
            ``(a, b, c, d)`` coefficients) and ``"trans_param_opt"`` (detached copy
            of the raw parameter). Both reflect the parameters after the final
            optimizer step, or the initial value when ``optim_iterations`` is 0.
        """
        optimizer = self.optimizer
        # Each call is an independent fit: drop Adam moments left over from a
        # previous call (they would bias the first steps, or mismatch in shape).
        optimizer.state.clear()
        scheduler = get_lr_scheduler(self.policy, optimizer, max_iter=self.num_iters)
        param = self.opt_params["trans_param_opt"]
        param.data = trans_param.clone().detach()

        anchors = anchor_pair.clone().detach().permute(0, 3, 2, 1).contiguous()
        second_anchor = anchors[0, 1:2]  # [1, V, 3]
        first_anchor = anchors[0, 0:1].unsqueeze(-1)  # [1, V, 3, 1]
        t = duration[0]

        for _ in range(self.num_iters):
            optimizer.zero_grad()
            coeffs = self._build_coefficients(param, first_anchor)
            # Cast to the target dtype, matching the original element-wise writes
            # into a zeros_like(second_anchor) buffer.
            predicted = self._cubic_endpoint(coeffs, t).to(second_anchor.dtype)
            losses = self.loss(predicted, second_anchor)
            losses.backward()
            optimizer.step()
            scheduler.step()

        # Snapshot *after* the loop so the last optimizer step is not discarded.
        final_param = param.detach().clone()
        opt_result = {}
        opt_result["trans_param_opt_xyz"] = self._build_coefficients(final_param, first_anchor)
        opt_result["trans_param_opt"] = final_param

        return opt_result
