import os
import torch
import time
import pickle
import random
import numpy as np
import torch.nn as nn
from model.model import TransOptim
from model.stgcn import STGCN
from evaluate.action2motion.models import load_classifier, load_classifier_for_fid
from evaluate.action2motion.fid import calculate_fid
from evaluate.action2motion.diversity import calculate_diversity_multimodality
from utils.losses import kl_divergence_loss, reconstruction_loss_mpjpe
from utils.util import calculate_activation_statistics, transform_torch, visualize_single_video, visualize_anchor, calculate_classifier_metrics, spline_fit


# Pairs of joint indices, one tuple per limb segment.
# NOTE(review): presumably the HumanAct12 skeleton connectivity used for
# drawing -- confirm against the visualisation helpers.
humanact12_limbs = [
    (0, 1), (1, 4), (4, 7), (7, 10),
    (0, 2), (2, 5), (5, 8), (8, 11),
    (0, 3), (3, 6), (6, 9),
    (9, 12), (12, 15),
    (9, 13), (13, 16), (16, 18), (18, 20), (20, 22),
    (9, 14), (14, 17), (17, 19), (19, 21), (21, 23),
]

# Action names, in fixed order.
# NOTE(review): presumably the list position is the UESTC class id -- confirm
# against the dataset loader before reordering.
uestc_list = [
    'alternate-knee-lifting',
    'arm-circling',
    'bent-over-twist',
    'deltoid-muscle-stretching',
    'dumbbell-one-arm-shoulder-pressing',
    'dumbbell-shrugging',
    'dumbbell-side-bend',
    'elbow-circling',
    'forward-lunging',
    'front-raising',
    'head-anticlockwise-circling',
    'high-knees-running',
    'jumping-jack',
    'knee-circling',
    'knee-to-chest',
    'left-kicking',
    'left-lunging',
    'left-stretching',
    'marking-time-and-knee-lifting',
    'overhead-stretching',
    'pinching-back',
    'pulling-chest-expanders',
    'punching',
    'punching-and-knee-lifting',
    'raising-hand-and-jumping',
    'rope-skipping',
    'rotation-clapping',
    'shoulder-abduction',
    'shoulder-raising',
    'single-dumbbell-raising',
    'single-leg-lateral-hopping',
    'spinal-stretching',
    'squatting',
    'standing-gastrocnemius-calf',
    'standing-opposite-elbow-to-knee-crunch',
    'standing-rotation',
    'standing-toe-touches',
    'straight-forward-flexion',
    'upper-back-stretching',
    'wrist-circling',
]


def train_SequenceAE(args, model, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict):
    """Train the sequence autoencoder for one epoch, then score it on the test set.

    Each training batch is reconstructed from ``keypoint_sequence`` (with its
    ``mask_data``) and compared to ``complete_keypoint_sequence`` via
    ``reconstruction_loss_mpjpe``.  The optimizer and the scheduler are both
    stepped once per batch.  The test pass runs in eval mode without gradients.

    Checkpoints written to ``save_dir``:
        * ``best.pth``      -- when the mean test loss is below ``perf_dict['AE']``
          (which is then updated in place);
        * ``epoch_<n>.pth`` -- every ``args.save_freq`` epochs (1-based ``n``).

    Returns:
        ``(loss, perf_dict)`` where ``loss`` holds the per-batch mean
        ``'train_loss'`` and ``'test_loss'``.
    """
    def _checkpoint(filename):
        # Write the current weights and announce the destination.
        target = os.path.join(save_dir, filename)
        torch.save(model.state_dict(), target)
        print('Model saved at {}'.format(target))

    tic = time.time()

    # ---- optimisation pass -------------------------------------------------
    model.train()
    running_train = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        inputs = batch["keypoint_sequence"].cuda()
        targets = batch["complete_keypoint_sequence"].cuda()
        prediction = model(inputs, mask=batch["mask_data"])
        batch_loss = reconstruction_loss_mpjpe(prediction, targets)
        batch_loss.backward()
        optimizer.step()
        scheduler.step()
        running_train += batch_loss.item()

    # ---- held-out pass -----------------------------------------------------
    model.eval()
    running_test = 0.0
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch["keypoint_sequence"].cuda()
            targets = batch["complete_keypoint_sequence"].cuda()
            prediction = model(inputs, mask=batch["mask_data"])
            running_test += reconstruction_loss_mpjpe(prediction, targets).item()

    toc = time.time()
    mean_train = running_train / len(train_loader)
    mean_test = running_test / len(test_loader)
    print('Epoch: {}, Train Loss: {:.4f}, Test Loss: {:.4f}, Time: {:.2f}s'.format(epoch, mean_train, mean_test, toc - tic))

    summary = {'train_loss': mean_train, 'test_loss': mean_test}

    # Lower is better: keep the weights with the smallest test loss so far.
    if mean_test < perf_dict['AE']:
        perf_dict['AE'] = mean_test
        _checkpoint('best.pth')

    if (epoch + 1) % args.save_freq == 0:
        _checkpoint('epoch_{}.pth'.format(epoch + 1))

    return summary, perf_dict


def eval_SequenceAE(args, model, test_loader, save_dir, limbs=humanact12_limbs):
    """Evaluate the sequence autoencoder on ``test_loader`` and print the mean loss.

    The model is put in eval mode and run without gradients; every batch is
    scored with ``reconstruction_loss_mpjpe`` against
    ``complete_keypoint_sequence``.

    ``args``, ``save_dir`` and ``limbs`` are not used by this function; they
    are kept so existing call sites keep working.

    Returns:
        The per-batch mean test loss (previously the value was only printed).
    """
    test_loss = 0.0
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            sequence_data = data["keypoint_sequence"].cuda()
            complete_sequence_data = data["complete_keypoint_sequence"].cuda()
            output = model(sequence_data, mask=data["mask_data"])
            # The earlier version also copied every target/reconstruction to
            # host memory into lists that were never read; that dead work is
            # dropped so memory no longer grows with the dataset size.
            test_loss += reconstruction_loss_mpjpe(output, complete_sequence_data).item()

    test_loss /= len(test_loader)
    print('Test Loss: {:.4f}'.format(test_loss))
    return test_loss


def train_arn(args, model, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict):
    """Train the anchor recognition network for one epoch and evaluate it.

    The model output is compared to ``anchor_pos_label`` (with a trailing unit
    axis added) using ``nn.BCELoss``.  Predictions are the outputs thresholded
    at 0.5.  During training the optimizer and scheduler are stepped once per
    batch.

    Reported metrics (for both the train and the test split):
        * loss        -- per-batch mean BCE loss;
        * acc         -- fraction of label entries whose thresholded prediction
          matches the label thresholded at 0.5;
        * anchor_acc  -- among entries whose label is exactly 1, 0.8 or 0.6,
          the fraction predicted as 1 (0.0 when a split has no such entry);
        * tp/tn/fp/fn -- per-sample values from
          ``calculate_classifier_metrics`` summed and divided by the number of
          label entries.

    Checkpoints written to ``save_dir``: ``best_acc.pth`` and
    ``best_anchor_acc.pth`` when the respective test metric beats
    ``perf_dict["ARN"][0]`` / ``perf_dict["ARN"][1]`` (updated in place), plus
    ``epoch_<n>.pth`` every 100 epochs.

    Returns:
        ``(loss, perf_dict)`` where ``loss`` maps metric names to values.
    """
    start_time = time.time()
    criterion = nn.BCELoss()

    def _run_epoch(loader, training):
        # One pass over `loader`; returns raw, un-normalised sums.
        sums = {"loss": 0.0, "correct": 0, "count": 0, "anchor_correct": 0,
                "anchor_count": 0, "tp": 0, "fp": 0, "tn": 0, "fn": 0}
        for data in loader:
            sequence_data = data["keypoint_sequence"].cuda()
            anchor_pos_label = data["anchor_pos_label"].cuda()
            mask = data["mask_data"]

            if training:
                optimizer.zero_grad()
            output = model(sequence_data, mask=mask)
            loss = criterion(output, anchor_pos_label.unsqueeze(-1))
            if training:
                loss.backward()
                optimizer.step()
                scheduler.step()
            sums["loss"] += loss.item()

            preds = (output > 0.5).float()
            sums["correct"] += torch.sum(preds == (anchor_pos_label > 0.5).float().unsqueeze(-1)).item()
            sums["count"] += anchor_pos_label.shape[0] * anchor_pos_label.shape[1]

            # Anchor entries are selected with one boolean mask instead of a
            # Python loop over every (sample, frame) element, which issued a
            # separate device comparison per element.
            is_anchor = (anchor_pos_label == 1) | (anchor_pos_label == 0.8) | (anchor_pos_label == 0.6)
            sums["anchor_count"] += int(is_anchor.sum().item())
            sums["anchor_correct"] += int((preds[:, :, 0][is_anchor] == 1).sum().item())

            for i in range(sequence_data.shape[0]):
                tp, tn, fp, fn = calculate_classifier_metrics(preds[i, :, 0], anchor_pos_label[i])
                sums["tp"] += tp
                sums["fp"] += fp
                sums["tn"] += tn
                sums["fn"] += fn
        return sums

    def _anchor_accuracy(sums):
        # A split without any anchor-labelled entry yields 0.0 rather than
        # aborting the epoch with ZeroDivisionError.
        if not sums["anchor_count"]:
            return 0.0
        return sums["anchor_correct"] / sums["anchor_count"]

    model.train()
    train_sums = _run_epoch(train_loader, training=True)

    model.eval()
    with torch.no_grad():
        test_sums = _run_epoch(test_loader, training=False)

    end_time = time.time()
    train_loss = train_sums["loss"] / len(train_loader)
    test_loss = test_sums["loss"] / len(test_loader)
    train_count, test_count = train_sums["count"], test_sums["count"]
    train_acc = train_sums["correct"] / train_count
    test_acc = test_sums["correct"] / test_count
    train_anchor_acc = _anchor_accuracy(train_sums)
    test_anchor_acc = _anchor_accuracy(test_sums)
    train_tp = train_sums["tp"] / train_count
    test_tp = test_sums["tp"] / test_count
    train_tn = train_sums["tn"] / train_count
    test_tn = test_sums["tn"] / test_count
    train_fp = train_sums["fp"] / train_count
    test_fp = test_sums["fp"] / test_count
    train_fn = train_sums["fn"] / train_count
    test_fn = test_sums["fn"] / test_count

    print('Epoch: {}, Train Loss: {:.4f}, Test Loss: {:.4f}, Train Acc: {:.4f}, Test Acc: {:.4f}, Train Anchor Acc: {:.4f}, Test Anchor Acc: {:.4f}, Train TP: {:.4f}, Test TP: {:.4f}, Train FP: {:.4f}, Test FP: {:.4f}, Train TN: {:.4f}, Test TN: {:.4f}, Train FN: {:.4f}, Test FN: {:.4f}, Time: {:.2f}s'.format(
        epoch, train_loss, test_loss, train_acc, test_acc, train_anchor_acc, test_anchor_acc, train_tp, test_tp, train_fp, test_fp, train_tn, test_tn, train_fn, test_fn, end_time - start_time))

    loss = {
        'train_loss': train_loss,
        'test_loss': test_loss,
        'train_acc': train_acc,
        'test_acc': test_acc,
        'train_anchor_acc': train_anchor_acc,
        'test_anchor_acc': test_anchor_acc,
        'train_tp': train_tp,
        'test_tp': test_tp,
        'train_tn': train_tn,
        'test_tn': test_tn,
        'train_fp': train_fp,
        'test_fp': test_fp,
        'train_fn': train_fn,
        'test_fn': test_fn,
    }

    if test_acc > perf_dict["ARN"][0]:
        perf_dict["ARN"][0] = test_acc
        save_path = os.path.join(save_dir, 'best_acc.pth')
        torch.save(model.state_dict(), save_path)
        print('Model saved at {}'.format(save_path))

    if test_anchor_acc > perf_dict["ARN"][1]:
        perf_dict["ARN"][1] = test_anchor_acc
        save_path = os.path.join(save_dir, 'best_anchor_acc.pth')
        torch.save(model.state_dict(), save_path)
        print('Model saved at {}'.format(save_path))

    # NOTE(review): the other trainers in this file use args.save_freq here;
    # the fixed 100-epoch period is kept so existing runs behave the same.
    if (epoch+1) % 100 == 0:
        save_path = os.path.join(save_dir, 'epoch_{}.pth'.format(epoch+1))
        torch.save(model.state_dict(), save_path)
        print('Model saved at {}'.format(save_path))

    return loss, perf_dict
    

def eval_arn(model, train_loader, test_loader, save_dir):
    """Evaluate the anchor recognition network on both splits and dump predictions.

    The model runs in eval mode without gradients.  For each split the
    function reports the per-batch mean ``nn.BCELoss``, the accuracy of the
    0.5-thresholded predictions, the accuracy restricted to entries whose
    label is exactly 1, 0.8 or 0.6 (0.0 if a split has none), and the
    tp/tn/fp/fn values from ``calculate_classifier_metrics`` divided by the
    number of label entries of that split.

    Four arrays are saved to ``save_dir``: ``train_gt_anchor_pos_label.npy``,
    ``test_gt_anchor_pos_label.npy``, ``train_preds.npy`` and
    ``test_preds.npy`` (one row per sample).

    Returns:
        None; results are printed and written to disk.
    """
    start_time = time.time()
    criterion = nn.BCELoss()

    def _evaluate(loader):
        # One gradient-free pass; returns raw sums plus per-sample labels/preds.
        sums = {"loss": 0.0, "correct": 0, "count": 0, "anchor_correct": 0,
                "anchor_count": 0, "tp": 0, "fp": 0, "tn": 0, "fn": 0}
        gt_labels, pred_labels = [], []
        for data in loader:
            sequence_data = data["keypoint_sequence"].cuda()
            anchor_pos_label = data["anchor_pos_label"].cuda()
            mask = data["mask_data"]

            output = model(sequence_data, mask=mask)
            sums["loss"] += criterion(output, anchor_pos_label.unsqueeze(-1)).item()

            preds = (output > 0.5).float()
            sums["correct"] += torch.sum(preds == (anchor_pos_label > 0.5).float().unsqueeze(-1)).item()
            sums["count"] += anchor_pos_label.shape[0] * anchor_pos_label.shape[1]

            # One boolean mask replaces the per-element Python loop.
            is_anchor = (anchor_pos_label == 1) | (anchor_pos_label == 0.8) | (anchor_pos_label == 0.6)
            sums["anchor_count"] += int(is_anchor.sum().item())
            sums["anchor_correct"] += int((preds[:, :, 0][is_anchor] == 1).sum().item())

            for i in range(sequence_data.shape[0]):
                gt_labels.append(anchor_pos_label[i].cpu().numpy())
                pred_labels.append(preds[i, :, 0].cpu().numpy())
                tp, tn, fp, fn = calculate_classifier_metrics(preds[i, :, 0], anchor_pos_label[i])
                sums["tp"] += tp
                sums["fp"] += fp
                sums["tn"] += tn
                sums["fn"] += fn
        return sums, gt_labels, pred_labels

    def _anchor_accuracy(sums):
        # No anchor-labelled entry in the split -> report 0.0 instead of
        # raising ZeroDivisionError.
        if not sums["anchor_count"]:
            return 0.0
        return sums["anchor_correct"] / sums["anchor_count"]

    model.eval()
    with torch.no_grad():
        train_sums, train_gt_anchor_pos_label, train_preds = _evaluate(train_loader)
        test_sums, test_gt_anchor_pos_label, test_preds = _evaluate(test_loader)

    end_time = time.time()
    train_loss = train_sums["loss"] / len(train_loader)
    test_loss = test_sums["loss"] / len(test_loader)
    train_count, test_count = train_sums["count"], test_sums["count"]
    train_acc = train_sums["correct"] / train_count
    test_acc = test_sums["correct"] / test_count
    train_anchor_acc = _anchor_accuracy(train_sums)
    test_anchor_acc = _anchor_accuracy(test_sums)
    train_tp = train_sums["tp"] / train_count
    test_tp = test_sums["tp"] / test_count
    train_tn = train_sums["tn"] / train_count
    test_tn = test_sums["tn"] / test_count
    train_fp = train_sums["fp"] / train_count
    test_fp = test_sums["fp"] / test_count
    train_fn = train_sums["fn"] / train_count
    # Fixed: the test false-negative rate was normalised by the *train* entry
    # count, which skewed it whenever the two splits differ in size.
    test_fn = test_sums["fn"] / test_count

    print('Train Loss: {:.4f}, Test Loss: {:.4f}, Train Acc: {:.4f}, Test Acc: {:.4f}, Train Anchor Acc: {:.4f}, Test Anchor Acc: {:.4f}, Train TP: {:.4f}, Test TP: {:.4f}, Train FP: {:.4f}, Test FP: {:.4f}, Train TN: {:.4f}, Test TN: {:.4f}, Train FN: {:.4f}, Test FN: {:.4f}, Time: {:.2f}s'.format(
        train_loss, test_loss, train_acc, test_acc, train_anchor_acc, test_anchor_acc, train_tp, test_tp, train_fp, test_fp, train_tn, test_tn, train_fn, test_fn, end_time - start_time))

    np.save(os.path.join(save_dir, "train_gt_anchor_pos_label.npy"), np.array(train_gt_anchor_pos_label))
    np.save(os.path.join(save_dir, "test_gt_anchor_pos_label.npy"), np.array(test_gt_anchor_pos_label))
    np.save(os.path.join(save_dir, "train_preds.npy"), np.array(train_preds))
    np.save(os.path.join(save_dir, "test_preds.npy"), np.array(test_preds))


def train_trans_generator(args, autoencoder, trans_generator, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict):
    """Train the transition generator for one epoch and evaluate it.

    ``autoencoder`` stays in eval mode and its encoder is always run under
    ``torch.no_grad()``, so the conditioning feature carries no gradient.
    For every batch three terms are computed:

        * recon  -- ``reconstruction_loss_mpjpe`` between the key-points that
          ``transform_torch`` produces from the generated and from the
          ground-truth transition parameters;
        * kl     -- ``kl_divergence_loss(mu, logvar)``;
        * anchor -- ``reconstruction_loss_mpjpe`` between the generated frame
          at index ``duration`` and the second anchor.

    The optimised loss is
    ``args.trans_kl_factor * kl + recon + args.trans_loss_weight * anchor``.
    The optimizer and scheduler are stepped once per training batch.

    Checkpoints written to ``save_dir``: ``best.pth`` when the mean test recon
    loss beats ``perf_dict["TransGenerator"][0]`` (slots 0 and 1 are then
    updated in place) and ``epoch_<n>.pth`` every ``args.save_freq`` epochs.

    Returns:
        ``(loss, perf_dict)`` where ``loss`` maps the eight train/test metric
        names to their per-batch means.
    """
    start_time = time.time()

    def _batch_losses(batch):
        # Forward one batch and return (total, recon, kl, second_anchor) losses.
        full_keypoint_sequence = batch["full_keypoint_sequence"].cuda()
        mask_data = batch["mask_data"]
        anchor_pair = batch["anchor_pair"].cuda().permute(0, 2, 3, 1).contiguous() # [bs, num_joints, 3, 2]
        transition = batch["transition"].cuda().unsqueeze(-1) # [bs, num_joints, 9, 1]
        duration = batch["duration"].cuda().unsqueeze(-1)
        first_anchor = anchor_pair[:, :, :, 0:1].clone().detach()
        second_anchor = anchor_pair[:, :, :, 1:2].clone().detach().permute(0, 3, 1, 2).contiguous()

        with torch.no_grad():
            global_feature = autoencoder.encoder(full_keypoint_sequence, mask_data) # [bs, seq_len, dim]
        reconstructed, mu, logvar = trans_generator(transition, anchor_pair, duration, global_feature)

        def _to_keypoints(params):
            # Drop the trailing axis, view the first 9 channels as a 3x3 block,
            # append the first anchor and expand with transform_torch.
            flat = params[..., 0]
            xyz = flat[..., :9].reshape(flat.shape[0], flat.shape[1], 3, 3)
            xyz = torch.cat([xyz, first_anchor], dim=-1)
            return transform_torch(xyz, duration+1)

        reconstructed_keypoint = _to_keypoints(reconstructed)
        gt_keypoint = _to_keypoints(transition)

        # Pick, per sample, the generated frame located at index `duration`.
        second_anchor_positions = duration.long()
        reconstructed_second_anchor = reconstructed_keypoint[torch.arange(reconstructed_keypoint.shape[0]), second_anchor_positions.squeeze()]
        reconstructed_second_anchor = reconstructed_second_anchor.unsqueeze(1)

        # Each term is evaluated once and reused for both logging and the
        # total (it used to be computed twice per batch).
        recon_loss = reconstruction_loss_mpjpe(reconstructed_keypoint, gt_keypoint)
        kl_loss = kl_divergence_loss(mu, logvar)
        second_anchor_loss = reconstruction_loss_mpjpe(reconstructed_second_anchor, second_anchor)
        total_loss = args.trans_kl_factor * kl_loss + recon_loss + args.trans_loss_weight * second_anchor_loss
        return total_loss, recon_loss, kl_loss, second_anchor_loss

    train_loss, train_recon_loss, train_kl_loss, train_second_anchor_loss = 0.0, 0.0, 0.0, 0.0
    test_loss, test_recon_loss, test_kl_loss, test_second_anchor_loss = 0.0, 0.0, 0.0, 0.0

    autoencoder.eval()
    trans_generator.train()
    for batch in train_loader:
        optimizer.zero_grad()
        total, recon, kl, anchor = _batch_losses(batch)
        train_recon_loss += recon.item()
        train_kl_loss += kl.item()
        train_second_anchor_loss += anchor.item()

        total.backward()
        train_loss += total.item()
        optimizer.step()
        scheduler.step()

    trans_generator.eval()
    with torch.no_grad():
        for batch in test_loader:
            total, recon, kl, anchor = _batch_losses(batch)
            test_recon_loss += recon.item()
            test_kl_loss += kl.item()
            test_second_anchor_loss += anchor.item()
            test_loss += total.item()

    end_time = time.time()
    train_loss /= len(train_loader)
    train_recon_loss /= len(train_loader)
    train_kl_loss /= len(train_loader)
    train_second_anchor_loss /= len(train_loader)
    test_loss /= len(test_loader)
    test_recon_loss /= len(test_loader)
    test_kl_loss /= len(test_loader)
    test_second_anchor_loss /= len(test_loader)
    print("Epoch: {}, Train Loss: {:.4f}, Test Loss: {:.4f}, Train Recon Loss: {:.4f}, Test Recon Loss: {:.4f}, Train KL Loss: {:.4f}, Test KL Loss: {:.4f}, Train Second Anchor Loss: {:.4f}, Test Second Anchor Loss: {:.4f}, Time: {:.4f}".format(
        epoch, train_loss, test_loss, train_recon_loss, test_recon_loss, train_kl_loss, test_kl_loss, train_second_anchor_loss, test_second_anchor_loss, end_time - start_time
    ))

    loss = {
        'train_loss': train_loss,
        'train_recon_loss': train_recon_loss,
        'train_kl_loss': train_kl_loss,
        'train_second_anchor_loss': train_second_anchor_loss,
        'test_loss': test_loss,
        'test_recon_loss': test_recon_loss,
        'test_kl_loss': test_kl_loss,
        'test_second_anchor_loss': test_second_anchor_loss,
    }

    # Model selection is driven by the test reconstruction loss only; the
    # second-anchor loss of that same epoch is recorded alongside it.
    if test_recon_loss < perf_dict["TransGenerator"][0]:
        perf_dict["TransGenerator"][0] = test_recon_loss
        perf_dict["TransGenerator"][1] = test_second_anchor_loss
        save_path = os.path.join(save_dir, "best.pth")
        torch.save(trans_generator.state_dict(), save_path)
        print("Model saved at {}".format(save_path))

    if (epoch+1) % args.save_freq == 0:
        save_path = os.path.join(save_dir, "epoch_{}.pth".format(epoch+1))
        torch.save(trans_generator.state_dict(), save_path)
        print("Model saved at {}".format(save_path))

    return loss, perf_dict


def eval_trans_generator(args, autoencoder, trans_generator, train_loader, test_loader):
    """Evaluate the transition generator on the train and test splits.

    Both networks run in eval mode without gradients.  For each batch the
    function computes the reconstruction loss between the key-points expanded
    (via ``transform_torch``) from the generated and the ground-truth
    transition parameters, the KL term ``kl_divergence_loss(mu, logvar)``, the
    loss between the generated frame at index ``duration`` and the second
    anchor, and their weighted total
    ``args.trans_kl_factor * kl + recon + args.trans_loss_weight * anchor``.

    Returns:
        None; the per-batch means of all terms are printed.
    """
    start_time = time.time()

    def _batch_losses(batch):
        # Forward one batch and return (total, recon, kl, second_anchor) losses.
        full_keypoint_sequence = batch["full_keypoint_sequence"].cuda()
        mask_data = batch["mask_data"]
        anchor_pair = batch["anchor_pair"].cuda().permute(0, 2, 3, 1).contiguous()
        transition = batch["transition"].cuda().unsqueeze(-1)
        duration = batch["duration"].cuda().unsqueeze(-1)
        first_anchor = anchor_pair[:, :, :, 0:1].clone().detach()
        second_anchor = anchor_pair[:, :, :, 1:2].clone().detach().permute(0, 3, 1, 2).contiguous()

        global_feature = autoencoder.encoder(full_keypoint_sequence, mask_data)
        reconstructed, mu, logvar = trans_generator(transition, anchor_pair, duration, global_feature)

        def _to_keypoints(params):
            # Drop the trailing axis, view the first 9 channels as a 3x3 block,
            # append the first anchor and expand with transform_torch.
            flat = params[..., 0]
            xyz = flat[..., :9].reshape(flat.shape[0], flat.shape[1], 3, 3)
            xyz = torch.cat([xyz, first_anchor], dim=-1)
            return transform_torch(xyz, duration+1)

        reconstructed_keypoint = _to_keypoints(reconstructed)
        gt_keypoint = _to_keypoints(transition)

        # Pick, per sample, the generated frame located at index `duration`.
        second_anchor_positions = duration.long()
        reconstructed_second_anchor = reconstructed_keypoint[torch.arange(reconstructed_keypoint.shape[0]), second_anchor_positions.squeeze()]
        reconstructed_second_anchor = reconstructed_second_anchor.unsqueeze(1)

        # Evaluate every term once and reuse it for the weighted total.
        recon_loss = reconstruction_loss_mpjpe(reconstructed_keypoint, gt_keypoint)
        kl_loss = kl_divergence_loss(mu, logvar)
        second_anchor_loss = reconstruction_loss_mpjpe(reconstructed_second_anchor, second_anchor)
        total_loss = args.trans_kl_factor * kl_loss + recon_loss + args.trans_loss_weight * second_anchor_loss
        return total_loss, recon_loss, kl_loss, second_anchor_loss

    def _evaluate(loader):
        # Sum the four loss terms over a loader: [total, recon, kl, anchor].
        sums = [0.0, 0.0, 0.0, 0.0]
        for batch in loader:
            total, recon, kl, anchor = _batch_losses(batch)
            # Same accumulation order as before: components first, total last.
            sums[1] += recon.item()
            sums[2] += kl.item()
            sums[3] += anchor.item()
            sums[0] += total.item()
        return sums

    autoencoder.eval()
    trans_generator.eval()
    # A single no_grad scope covers everything, including the encoder call.
    with torch.no_grad():
        train_loss, train_recon_loss, train_kl_loss, train_second_anchor_loss = _evaluate(train_loader)
        test_loss, test_recon_loss, test_kl_loss, test_second_anchor_loss = _evaluate(test_loader)

    end_time = time.time()
    train_loss /= len(train_loader)
    train_recon_loss /= len(train_loader)
    train_kl_loss /= len(train_loader)
    train_second_anchor_loss /= len(train_loader)
    test_loss /= len(test_loader)
    test_recon_loss /= len(test_loader)
    test_kl_loss /= len(test_loader)
    test_second_anchor_loss /= len(test_loader)
    print("Train Loss: {:.4f}, Test Loss: {:.4f}, Train Recon Loss: {:.4f}, Test Recon Loss: {:.4f}, Train KL Loss: {:.4f}, Test KL Loss: {:.4f}, Train Second Anchor Loss: {:.4f}, Test Second Anchor Loss: {:.4f}, Time: {:.4f}".format(
        train_loss, test_loss, train_recon_loss, test_recon_loss, train_kl_loss, test_kl_loss, train_second_anchor_loss, test_second_anchor_loss, end_time - start_time
    ))
    
    
def train_afn_classifier(args, afn, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict):
    """Train the AFN classifier branch for one epoch and evaluate it.

    ``afn`` is called with ``training_classifier=True`` and its output is
    scored against ``anchor_class`` with ``nn.CrossEntropyLoss``.  The
    optimizer and scheduler are stepped once per training batch; the test
    pass runs in eval mode without gradients.  Accuracy is the fraction of
    samples whose arg-max prediction equals ``anchor_class``.

    Checkpoints written to ``save_dir`` (state dicts of
    ``afn.classifier_encoder`` and ``afn.classifier``):
        * ``best_encoder.pth`` / ``best_classifier.pth`` when the test
          accuracy beats ``perf_dict["AFN"][0]`` (updated in place);
        * ``epoch_encoder_<n>.pth`` / ``epoch_classifier_<n>.pth`` every
          ``args.save_freq`` epochs.

    Returns:
        ``(loss, perf_dict)``.  In ``loss`` the ``*_loss`` and
        ``*_classifier_loss`` entries are the same quantity (the classifier
        loss is the only term); both keys are kept for existing consumers.
    """
    start_time = time.time()
    classifier_criterion = nn.CrossEntropyLoss()

    def _forward(batch):
        # Returns (loss tensor, number of correct predictions, batch size).
        keypoint_sequence = batch["keypoint_sequence"].cuda()
        anchor_class = batch["anchor_class"].cuda()
        anchor_pos = batch["anchor_pos"].cuda()
        mask_data = batch["mask_data"]
        output_label = afn(keypoint_sequence, anchor_pos, mask_data, training_classifier=True)
        # The criterion is evaluated once; it used to be run twice per batch
        # to produce two identical numbers.
        batch_loss = classifier_criterion(output_label, anchor_class)
        hits = torch.sum(torch.argmax(output_label, dim=-1) == anchor_class).item()
        return batch_loss, hits, anchor_class.shape[0]

    train_loss, test_loss = 0.0, 0.0
    train_correct, test_correct, train_count, test_count = 0, 0, 0, 0

    afn.train()
    for batch in train_loader:
        optimizer.zero_grad()
        batch_loss, hits, size = _forward(batch)
        train_loss += batch_loss.item()
        train_correct += hits
        train_count += size
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

    afn.eval()
    with torch.no_grad():
        for batch in test_loader:
            batch_loss, hits, size = _forward(batch)
            test_loss += batch_loss.item()
            test_correct += hits
            test_count += size

    end_time = time.time()
    train_loss /= len(train_loader)
    train_classifier_loss = train_loss
    train_acc = train_correct / train_count
    test_loss /= len(test_loader)
    test_classifier_loss = test_loss
    test_acc = test_correct / test_count
    print("Epoch: {}, Train Loss: {:.4f}, Test Loss: {:.4f}, Train Classifier Loss: {:.4f}, Test Classifier Loss: {:.4f}, Train Classifier Acc: {:.4f}, Test Classifier Acc: {:.4f}, Time: {:.4f}".format(
        epoch, train_loss, test_loss, train_classifier_loss, test_classifier_loss, train_acc, test_acc, end_time - start_time
    ))

    loss = {
        'train_loss': train_loss,
        'train_classifier_loss': train_classifier_loss,
        'train_acc': train_acc,
        'test_loss': test_loss,
        'test_classifier_loss': test_classifier_loss,
        'test_acc': test_acc,
    }

    if test_acc > perf_dict["AFN"][0]:
        perf_dict["AFN"][0] = test_acc
        save_path = os.path.join(save_dir, "best_encoder.pth")
        torch.save(afn.classifier_encoder.state_dict(), save_path)
        save_path = os.path.join(save_dir, "best_classifier.pth")
        torch.save(afn.classifier.state_dict(), save_path)
        # Only the classifier path is reported, as before.
        print("Model saved at {}".format(save_path))

    if (epoch+1) % args.save_freq == 0:
        save_path = os.path.join(save_dir, "epoch_encoder_{}.pth".format(epoch+1))
        torch.save(afn.classifier_encoder.state_dict(), save_path)
        save_path = os.path.join(save_dir, "epoch_classifier_{}.pth".format(epoch+1))
        torch.save(afn.classifier.state_dict(), save_path)
        print("Model saved at {}".format(save_path))

    return loss, perf_dict


def eval_afn_classifier(afn, train_loader, test_loader):
    """Print the AFN classifier's loss and accuracy on the train and test loaders.

    The network is put in eval mode and run without gradients over every batch
    of both loaders (called with ``training_classifier=True``). Losses are
    averaged over batches, accuracies over samples. Nothing is returned.
    """
    tic = time.time()
    criterion = nn.CrossEntropyLoss()
    # Per split: [summed batch loss, correct predictions, samples seen].
    totals = {"train": [0.0, 0, 0], "test": [0.0, 0, 0]}

    afn.eval()
    with torch.no_grad():
        for split, loader in (("train", train_loader), ("test", test_loader)):
            running = totals[split]
            for batch in loader:
                sequence = batch["keypoint_sequence"].cuda()
                labels = batch["anchor_class"].cuda()
                positions = batch["anchor_pos"].cuda()
                mask = batch["mask_data"]
                logits = afn(sequence, positions, mask, training_classifier=True)
                running[0] += criterion(logits, labels).item()
                running[2] += labels.shape[0]
                running[1] += (logits.argmax(dim=-1) == labels).sum().item()

    toc = time.time()
    train_sum, train_hits, train_seen = totals["train"]
    test_sum, test_hits, test_seen = totals["test"]
    # The total loss of this stage is the classifier loss alone, so the two
    # reported figures per split are the same number.
    train_loss = train_classifier_loss = train_sum / len(train_loader)
    train_acc = train_hits / train_seen
    test_loss = test_classifier_loss = test_sum / len(test_loader)
    test_acc = test_hits / test_seen
    print("Train Loss: {:.4f}, Test Loss: {:.4f}, Train Classifier Loss: {:.4f}, Test Classifier Loss: {:.4f}, Train Classifier Acc: {:.4f}, Test Classifier Acc: {:.4f}, Time: {:.4f}".format(
        train_loss, test_loss, train_classifier_loss, test_classifier_loss, train_acc, test_acc, toc - tic
    ))
    

def train_afn_together(args, model, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict):
    """Train for one epoch on the joint reconstruction + classification loss, then evaluate.

    The objective is ``reconstruction_loss_mpjpe(output_anchor, gt_anchor)
    + args.afn_loss_weight * CrossEntropy(output_label, anchor_class)``.

    Args:
        args: Options object; ``afn_loss_weight`` and ``save_freq`` are read.
        model: Network called as ``model(keypoint_sequence, anchor_pos, mask_data,
            anchor_class=..., training_together=True)``; it must return
            ``(output_label, output_anchor)``.
        train_loader: Iterable of batch dicts with the keys ``keypoint_sequence``,
            ``gt_anchor``, ``anchor_class``, ``anchor_class_onehot``,
            ``anchor_pos`` and ``mask_data``.
        test_loader: Iterable of batch dicts with the same keys.
        optimizer: Optimizer, stepped once per training batch.
        scheduler: Learning-rate scheduler, also stepped once per training batch.
        epoch: Current epoch index; used for logging and, as ``epoch + 1``,
            in the periodic checkpoint file name.
        save_dir: Directory the checkpoints are written to.
        perf_dict: Dict whose ``"AFN"`` entry holds
            ``[best test accuracy, best test reconstruction loss]``; updated in place.

    Returns:
        ``(loss, perf_dict)`` where ``loss`` maps metric names to their epoch
        values (losses averaged over batches, accuracies over samples).
    """
    start_time = time.time()
    train_loss, train_recon_loss, train_classifier_loss = 0.0, 0.0, 0.0
    test_loss, test_recon_loss, test_classifier_loss = 0.0, 0.0, 0.0
    train_correct, test_correct, train_count, test_count = 0, 0, 0, 0
    classifier_criterion = nn.CrossEntropyLoss()

    model.train()
    for batch in train_loader:
        keypoint_sequence, gt_anchor, anchor_class, anchor_class_onehot, anchor_pos, mask_data = batch["keypoint_sequence"].cuda(), batch["gt_anchor"].cuda(), batch["anchor_class"].cuda(), batch["anchor_class_onehot"].cuda(), batch["anchor_pos"].cuda(), batch["mask_data"]
        optimizer.zero_grad()
        output_label, output_anchor = model(keypoint_sequence, anchor_pos, mask_data, anchor_class=anchor_class_onehot, training_together=True)
        gt_anchor = gt_anchor.unsqueeze(1)
        output_anchor = output_anchor.unsqueeze(1)
        classifier_loss = classifier_criterion(output_label, anchor_class)
        recon_loss = reconstruction_loss_mpjpe(output_anchor, gt_anchor)
        # The classifier term must stay a tensor: converting it with .item()
        # would turn it into a constant and no gradient would reach the
        # classification branch.
        loss = recon_loss + args.afn_loss_weight * classifier_loss
        train_classifier_loss += classifier_loss.item()
        train_recon_loss += recon_loss.item()
        train_count += anchor_class.shape[0]
        train_correct += torch.sum(torch.argmax(output_label, dim=-1) == anchor_class).item()
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            keypoint_sequence, gt_anchor, anchor_class, anchor_class_onehot, anchor_pos, mask_data = batch["keypoint_sequence"].cuda(), batch["gt_anchor"].cuda(), batch["anchor_class"].cuda(), batch["anchor_class_onehot"].cuda(), batch["anchor_pos"].cuda(), batch["mask_data"]
            output_label, output_anchor = model(keypoint_sequence, anchor_pos, mask_data, anchor_class=anchor_class_onehot, training_together=True)
            gt_anchor = gt_anchor.unsqueeze(1)
            output_anchor = output_anchor.unsqueeze(1)
            classifier_loss = classifier_criterion(output_label, anchor_class)
            recon_loss = reconstruction_loss_mpjpe(output_anchor, gt_anchor)
            loss = recon_loss + args.afn_loss_weight * classifier_loss
            test_classifier_loss += classifier_loss.item()
            test_recon_loss += recon_loss.item()
            test_count += anchor_class.shape[0]
            test_correct += torch.sum(torch.argmax(output_label, dim=-1) == anchor_class).item()
            test_loss += loss.item()

    end_time = time.time()
    train_loss /= len(train_loader)
    train_recon_loss /= len(train_loader)
    train_classifier_loss /= len(train_loader)
    train_acc = train_correct / train_count
    test_loss /= len(test_loader)
    test_recon_loss /= len(test_loader)
    test_classifier_loss /= len(test_loader)
    test_acc = test_correct / test_count
    print("Epoch: {}, Train Loss: {:.4f}, Test Loss: {:.4f}, Train Recon Loss: {:.4f}, Test Recon Loss: {:.4f}, Train Classifier Loss: {:.4f}, Test Classifier Loss: {:.4f}, Train Acc: {:.4f}, Test Acc: {:.4f}, Time: {:.4f}".format(
        epoch, train_loss, test_loss, train_recon_loss, test_recon_loss, train_classifier_loss, test_classifier_loss, train_acc, test_acc, end_time - start_time
    ))

    loss = {}
    loss['train_loss'] = train_loss
    loss['train_recon_loss'] = train_recon_loss
    loss['train_classifier_loss'] = train_classifier_loss
    loss['train_acc'] = train_acc
    loss['test_loss'] = test_loss
    loss['test_recon_loss'] = test_recon_loss
    loss['test_classifier_loss'] = test_classifier_loss
    loss['test_acc'] = test_acc

    # Best-accuracy checkpoint.
    if test_acc > perf_dict["AFN"][0]:
        perf_dict["AFN"][0] = test_acc
        save_path = os.path.join(save_dir, "best_acc.pth")
        torch.save(model.state_dict(), save_path)
        print("Model saved at {}".format(save_path))

    # Best-reconstruction checkpoint.
    if test_recon_loss < perf_dict["AFN"][1]:
        perf_dict["AFN"][1] = test_recon_loss
        save_path = os.path.join(save_dir, "best_recon.pth")
        torch.save(model.state_dict(), save_path)
        print("Model saved at {}".format(save_path))

    # Periodic checkpoint every args.save_freq epochs.
    if (epoch+1) % args.save_freq == 0:
        save_path = os.path.join(save_dir, "epoch_{}.pth".format(epoch+1))
        torch.save(model.state_dict(), save_path)
        print("Model saved at {}".format(save_path))

    return loss, perf_dict

    
def train_afn(args, afn, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict):
    """Train the AFN for one epoch on the anchor reconstruction loss and evaluate it.

    Each batch dict must provide ``keypoint_sequence``, ``gt_anchor``,
    ``anchor_class``, ``anchor_class_onehot``, ``anchor_pos`` and ``mask_data``.
    The optimizer and the scheduler are both stepped once per training batch.
    ``perf_dict["AFN"][1]`` tracks the lowest test loss seen so far and is
    updated in place; a checkpoint is written whenever it improves and every
    ``args.save_freq`` epochs.

    Returns:
        ``(loss, perf_dict)`` with ``loss`` holding the batch-averaged
        ``train_loss`` and ``test_loss``.
    """
    tic = time.time()
    train_total = 0.0
    test_total = 0.0

    def _unpack(batch):
        # Every tensor entry is moved to the GPU; the mask is passed through as loaded.
        return (
            batch["keypoint_sequence"].cuda(),
            batch["gt_anchor"].cuda(),
            batch["anchor_class"].cuda(),
            batch["anchor_class_onehot"].cuda(),
            batch["anchor_pos"].cuda(),
            batch["mask_data"],
        )

    def _anchor_loss(sequence, target, onehot, positions, mask):
        # Only the second output of the network (the anchor) is used here.
        _, predicted = afn(sequence, positions, mask, anchor_class=onehot)
        target = target.unsqueeze(1)
        predicted = predicted.unsqueeze(1)
        return reconstruction_loss_mpjpe(predicted, target)

    afn.train()
    for batch in train_loader:
        sequence, target, _, onehot, positions, mask = _unpack(batch)
        optimizer.zero_grad()
        batch_loss = _anchor_loss(sequence, target, onehot, positions, mask)
        train_total += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

    afn.eval()
    with torch.no_grad():
        for batch in test_loader:
            sequence, target, _, onehot, positions, mask = _unpack(batch)
            test_total += _anchor_loss(sequence, target, onehot, positions, mask).item()

    toc = time.time()
    train_avg = train_total / len(train_loader)
    test_avg = test_total / len(test_loader)
    print("Epoch: {}, Train Loss: {:.4f}, Test Loss: {:.4f}, Time: {:.4f}".format(
        epoch, train_avg, test_avg, toc - tic
    ))

    metrics = {'train_loss': train_avg, 'test_loss': test_avg}

    improved = test_avg < perf_dict["AFN"][1]
    if improved:
        perf_dict["AFN"][1] = test_avg
        best_path = os.path.join(save_dir, "best.pth")
        torch.save(afn.state_dict(), best_path)
        print("Model saved at {}".format(best_path))

    if (epoch + 1) % args.save_freq == 0:
        periodic_path = os.path.join(save_dir, "epoch_{}.pth".format(epoch + 1))
        torch.save(afn.state_dict(), periodic_path)
        print("Model saved at {}".format(periodic_path))

    return metrics, perf_dict


def eval_afn(afn, train_loader, test_loader):
    """Print classifier and anchor-reconstruction metrics of the AFN on both loaders.

    The network is put in eval mode and called with ``testing=True`` under
    ``torch.no_grad()``. Losses are averaged over batches, accuracies over
    samples. Nothing is returned.

    Args:
        afn: Network returning ``(output_label, output_anchor)``.
        train_loader: Iterable of batch dicts with the keys ``keypoint_sequence``,
            ``gt_anchor``, ``anchor_class``, ``anchor_class_onehot``,
            ``anchor_pos`` and ``mask_data``.
        test_loader: Iterable of batch dicts with the same keys.
    """
    start_time = time.time()
    classifier_criterion = nn.CrossEntropyLoss()

    def _run_split(loader):
        """Return (summed classifier loss, summed recon loss, correct, seen) for one loader."""
        classifier_sum, recon_sum = 0.0, 0.0
        correct, count = 0, 0
        for batch in loader:
            keypoint_sequence, gt_anchor, anchor_class, anchor_class_onehot, anchor_pos, mask_data = batch["keypoint_sequence"].cuda(), batch["gt_anchor"].cuda(), batch["anchor_class"].cuda(), batch["anchor_class_onehot"].cuda(), batch["anchor_pos"].cuda(), batch["mask_data"]
            output_label, output_anchor = afn(keypoint_sequence, anchor_pos, mask_data, anchor_class=anchor_class_onehot, testing=True)
            classifier_sum += classifier_criterion(output_label, anchor_class).item()
            gt_anchor = gt_anchor.unsqueeze(1)
            output_anchor = output_anchor.unsqueeze(1)
            recon_sum += reconstruction_loss_mpjpe(output_anchor, gt_anchor).item()
            count += anchor_class.shape[0]
            correct += torch.sum(torch.argmax(output_label, dim=-1) == anchor_class).item()
        return classifier_sum, recon_sum, correct, count

    afn.eval()
    with torch.no_grad():
        train_classifier_loss, train_anchor_recon_loss, train_correct, train_count = _run_split(train_loader)
        test_classifier_loss, test_anchor_recon_loss, test_correct, test_count = _run_split(test_loader)

    end_time = time.time()
    train_classifier_loss /= len(train_loader)
    train_anchor_recon_loss /= len(train_loader)
    train_acc = train_correct / train_count
    test_classifier_loss /= len(test_loader)
    test_anchor_recon_loss /= len(test_loader)
    test_acc = test_correct / test_count
    print("Train Classifier Loss: {:.4f}, Test Classifier Loss: {:.4f}, Train Classifier Acc: {:.4f}, Test Classifier Acc: {:.4f}, Train Anchor Recon Loss: {:.4f}, Test Anchor Recon Loss: {:.4f}, Time: {:.4f}".format(
        train_classifier_loss, test_classifier_loss, train_acc, test_acc, train_anchor_recon_loss, test_anchor_recon_loss, end_time - start_time
    ))
    
    
def train_aprn(args, model, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict):
    """Train the APRN model for one epoch and evaluate it on the test loader.

    Per batch the objective is the weighted transition loss
    (``args.aprn_trans_loss_weight``) plus, for every sample, the reconstruction
    loss over its first ``duration`` frames and an extra term weighted by
    ``args.aprn_anchor_loss_weight`` on the first and last of those frames.
    The optimizer and the scheduler are stepped once per training batch.

    ``perf_dict["TRN"]`` holds ``[best test recon loss, best test trans loss]``
    and is updated in place; checkpoints are written on improvement and every
    ``args.save_freq`` epochs.

    Returns:
        ``(loss, perf_dict)`` where ``loss`` maps metric names to values
        averaged over the number of batches.
    """
    tic = time.time()
    meters = {
        "train": {"loss": 0.0, "recon": 0.0, "trans": 0.0},
        "test": {"loss": 0.0, "recon": 0.0, "trans": 0.0},
    }

    def _unpack(batch):
        return (
            batch["full_keypoint_sequence"].cuda(),
            batch["keypoint_sequence"].cuda(),
            batch["transition"].cuda(),
            batch["duration"].cuda(),
        )

    def _objective(full_sequence, sequence, transition, duration, meter):
        """Forward one batch, add its metrics to ``meter`` and return the loss tensor."""
        duration = duration.unsqueeze(1)
        pred_transition, pred_sequence = model(full_sequence, sequence)
        transition = transition.reshape(transition.shape[0], transition.shape[1], 3, 4)
        pred_transition = pred_transition.reshape(pred_transition.shape[0], pred_transition.shape[1], 3, 4)
        target_points = transform_torch(transition, duration + 1)
        pred_points = transform_torch(pred_transition, duration + 1)
        trans_term = reconstruction_loss_mpjpe(pred_points, target_points)
        meter["trans"] += trans_term.item()
        total = args.aprn_trans_loss_weight * trans_term
        for i in range(sequence.shape[0]):
            row = slice(i, i + 1)
            stop = duration[i]
            clip_term = reconstruction_loss_mpjpe(pred_sequence[row, :stop], sequence[row, :stop])
            meter["recon"] += clip_term.item()
            # Extra emphasis on the first and the last valid frame of the clip.
            first_term = reconstruction_loss_mpjpe(pred_sequence[row, 0:1], sequence[row, 0:1])
            last_term = reconstruction_loss_mpjpe(pred_sequence[row, stop - 1:stop], sequence[row, stop - 1:stop])
            total += clip_term + args.aprn_anchor_loss_weight * (first_term + last_term)
        meter["loss"] += total.item()
        return total

    model.train()
    for batch in train_loader:
        full_sequence, sequence, transition, duration = _unpack(batch)
        optimizer.zero_grad()
        batch_loss = _objective(full_sequence, sequence, transition, duration, meters["train"])
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            full_sequence, sequence, transition, duration = _unpack(batch)
            _objective(full_sequence, sequence, transition, duration, meters["test"])

    toc = time.time()
    n_train = len(train_loader)
    train_loss = meters["train"]["loss"] / n_train
    train_recon_loss = meters["train"]["recon"] / n_train
    train_trans_loss = meters["train"]["trans"] / n_train
    n_test = len(test_loader)
    test_loss = meters["test"]["loss"] / n_test
    test_recon_loss = meters["test"]["recon"] / n_test
    test_trans_loss = meters["test"]["trans"] / n_test
    print("Epoch: {}, Train Loss: {:.4f}, Test Loss: {:.4f}, Train Recon Loss: {:.4f}, Test Recon Loss: {:.4f}, Train Trans Loss: {:.4f}, Test Trans Loss: {:.4f}, Time: {:.4f}".format(
        epoch, train_loss, test_loss, train_recon_loss, test_recon_loss, train_trans_loss, test_trans_loss, toc - tic
    ))

    metrics = {
        'train_loss': train_loss,
        'train_recon_loss': train_recon_loss,
        'train_trans_loss': train_trans_loss,
        'test_loss': test_loss,
        'test_recon_loss': test_recon_loss,
        'test_trans_loss': test_trans_loss,
    }

    def _checkpoint(file_name):
        target = os.path.join(save_dir, file_name)
        torch.save(model.state_dict(), target)
        print("Model saved at {}".format(target))

    if test_recon_loss < perf_dict["TRN"][0]:
        perf_dict["TRN"][0] = test_recon_loss
        _checkpoint("best_recon.pth")

    if test_trans_loss < perf_dict["TRN"][1]:
        perf_dict["TRN"][1] = test_trans_loss
        _checkpoint("best_trans.pth")

    if (epoch + 1) % args.save_freq == 0:
        _checkpoint("epoch_{}.pth".format(epoch + 1))

    return metrics, perf_dict


def eval_aprn(args, model, train_loader, test_loader, save_dir, limbs=humanact12_limbs):
    """Evaluate the APRN model on both loaders and regroup its outputs per source sequence.

    Prints the total, reconstruction and transition losses (divided by the
    number of batches) for the train and test loaders. Afterwards the per-clip
    ground-truth and predicted keypoints are grouped by ``pos_record[i][0]``,
    ordered by ``pos_record[i][1]`` and concatenated along axis 0.

    Args:
        args: Options object; ``aprn_trans_loss_weight`` and
            ``aprn_anchor_loss_weight`` are read.
        model: Network called as ``model(full_keypoint_sequence, keypoint_sequence)``
            and returning ``(output_trans, output_keypoint_sequence)``.
        train_loader: Iterable of batch dicts with the keys
            ``full_keypoint_sequence``, ``keypoint_sequence``, ``transition``,
            ``duration`` and ``pos_record``.
        test_loader: Iterable of batch dicts with the same keys.
        save_dir: Not used by the code in this function.
        limbs: Not used by the code in this function.

    Returns:
        None.
    """
    start_time = time.time()

    def _run_split(loader):
        """Evaluate one loader; return the summed losses and the per-clip arrays."""
        loss_sum, recon_sum, trans_sum = 0.0, 0.0, 0.0
        gt_clips, out_clips = [], []
        gt_trans_clips, out_trans_clips = [], []
        pos_records = []
        for batch in loader:
            full_keypoint_sequence, keypoint_sequence, transition, duration, pos_record = batch["full_keypoint_sequence"].cuda(), batch["keypoint_sequence"].cuda(), batch["transition"].cuda(), batch["duration"].cuda(), batch["pos_record"]
            duration = duration.unsqueeze(1)
            bs = keypoint_sequence.shape[0]
            output_trans, output_keypoint_sequence = model(full_keypoint_sequence, keypoint_sequence)
            transition = transition.reshape(transition.shape[0], transition.shape[1], 3, 4)
            output_trans = output_trans.reshape(output_trans.shape[0], output_trans.shape[1], 3, 4)
            trans_gt_keypoint = transform_torch(transition, duration + 1)
            trans_out_keypoint = transform_torch(output_trans, duration + 1)
            trans_loss = reconstruction_loss_mpjpe(trans_out_keypoint, trans_gt_keypoint)
            trans_sum += trans_loss.item()
            loss = args.aprn_trans_loss_weight * trans_loss
            for i in range(bs):
                row = slice(i, i + 1)
                stop = duration[i]
                clip_loss = reconstruction_loss_mpjpe(output_keypoint_sequence[row, :stop], keypoint_sequence[row, :stop])
                recon_sum += clip_loss.item()
                # Extra terms on the first and the last valid frame of the clip.
                first_loss = reconstruction_loss_mpjpe(output_keypoint_sequence[row, 0:1], keypoint_sequence[row, 0:1])
                last_loss = reconstruction_loss_mpjpe(output_keypoint_sequence[row, stop - 1:stop], keypoint_sequence[row, stop - 1:stop])
                loss = loss + clip_loss + args.aprn_anchor_loss_weight * (first_loss + last_loss)
            loss_sum += loss.item()
            for i in range(bs):
                gt_clips.append(keypoint_sequence[i, :duration[i]].cpu().numpy())
                out_clips.append(output_keypoint_sequence[i, :duration[i]].cpu().numpy())
                # Per-sample transition keypoints use `duration` here, whereas the
                # batch-level loss above uses `duration + 1`.
                sample_gt = transform_torch(transition[i:i+1], duration[i:i+1])
                sample_out = transform_torch(output_trans[i:i+1], duration[i:i+1])
                gt_trans_clips.append(sample_gt[0].cpu().numpy())
                out_trans_clips.append(sample_out[0].cpu().numpy())
                pos_records.append(pos_record[i].numpy())
        return loss_sum, recon_sum, trans_sum, gt_clips, out_clips, gt_trans_clips, out_trans_clips, pos_records

    def _stitch(gt_clips, out_clips, pos_records):
        """Concatenate clips that share a sequence id, ordered by their slice index."""
        sequence_idx = np.array([record[0] for record in pos_records])
        slice_idx = np.array([record[1] for record in pos_records])
        gt_together, out_together = [], []
        for sequence_id in np.unique(sequence_idx):
            members = np.where(sequence_idx == sequence_id)[0]
            members = members[np.argsort(slice_idx[members])]
            gt_together.append(np.concatenate([gt_clips[j] for j in members], axis=0))
            out_together.append(np.concatenate([out_clips[j] for j in members], axis=0))
        return gt_together, out_together

    model.eval()
    with torch.no_grad():
        (train_loss, train_recon_loss, train_trans_loss,
         train_gt_keypoint_sequence, train_out_keypoint_sequence,
         train_gt_trans_keypoint, train_out_trans_keypoint, train_pos_record) = _run_split(train_loader)
        (test_loss, test_recon_loss, test_trans_loss,
         test_gt_keypoint_sequence, test_out_keypoint_sequence,
         test_gt_trans_keypoint, test_out_trans_keypoint, test_pos_record) = _run_split(test_loader)
    end_time = time.time()
    train_loss /= len(train_loader)
    train_recon_loss /= len(train_loader)
    train_trans_loss /= len(train_loader)
    test_loss /= len(test_loader)
    test_recon_loss /= len(test_loader)
    test_trans_loss /= len(test_loader)
    print("Train Loss: {:.4f}, Test Loss: {:.4f}, Train Recon Loss: {:.4f}, Test Recon Loss: {:.4f}, Train Trans Loss: {:.4f}, Test Trans Loss: {:.4f}, Time: {:.4f}".format(
        train_loss, test_loss, train_recon_loss, test_recon_loss, train_trans_loss, test_trans_loss, end_time - start_time
    ))

    # NOTE(review): the stitched sequences and the per-sample transition
    # keypoints gathered above are neither returned nor written anywhere in
    # this function - confirm whether a consumer is missing.
    train_gt_keypoint_sequence_together, train_out_keypoint_sequence_together = _stitch(
        train_gt_keypoint_sequence, train_out_keypoint_sequence, train_pos_record
    )
    test_gt_keypoint_sequence_together, test_out_keypoint_sequence_together = _stitch(
        test_gt_keypoint_sequence, test_out_keypoint_sequence, test_pos_record
    )


def eval_sequence_refinement(args, arn, autoencoder, trans_generator, afn, aprn, refine_loader, gt_keypoint_sequence_for_eval, save_dir, limbs=humanact12_limbs, add_visualize=False):
    start_time = time.time()
    if args.dataset == "humanact12":
        input_size_raw = 72
        num_classes = 12
    elif args.dataset == "uestc":
        input_size_raw = 54
        num_classes = 40
    trans_optimizer = TransOptim(args=args).cuda()
    trans_optimizer.train()
    
    output_recognition_record, output_unsuitable_anchor_pos_record, output_anchor_label_record = [], [], []
    global_position_record, gt_keypoint_sequence, output_refined_anchor_keypoint_sequence_record, output_refined_keypoint_sequence_record = [], [], [], []
    
    for batch_idx, data in enumerate(refine_loader):
        sequence_data, sequence_data_norm, global_position = data["keypoint_sequence"].cuda(), data["keypoint_sequence_norm"].cuda(), data["global_position"].cuda()
        ### Anchor Recognition Net ###
        with torch.no_grad():
            output_recognition = arn(sequence_data_norm, mask=None, testing=True)
        output_recognition_record.append(output_recognition.clone().detach().cpu().numpy())
        if args.unsuitable_type == "arn":
            output_recognition[:, :, 1] = (output_recognition[:, :, 1] < args.arn_threshold).float()
        elif args.unsuitable_type == "all" or args.unsuitable_type == "one":
            output_recognition[:, :, 1] = 1
        else:
            raise ValueError("Invalid unsuitable type!")
    
        for i in range(output_recognition.shape[0]):
            anchor_pos_label = output_recognition[i, :, 0]
            processed_anchor_pos_label = anchor_pos_label.clone()
            start_idx = None
            for j in range(anchor_pos_label.shape[0]):
                if anchor_pos_label[j] == 1:
                    if start_idx is None:
                        start_idx = j
                    elif j == len(anchor_pos_label) - 1:
                        middle_idx = (start_idx + j) // 2
                        processed_anchor_pos_label[start_idx:j+1] = 0
                        processed_anchor_pos_label[middle_idx] = 1
                elif start_idx is not None:
                    if start_idx == 0:
                        middle_idx = start_idx
                    else:
                        middle_idx = (start_idx + j - 1) // 2
                    processed_anchor_pos_label[start_idx:j] = 0
                    processed_anchor_pos_label[middle_idx] = 1
                    start_idx = None
            processed_anchor_pos = torch.where(processed_anchor_pos_label == 1)[0]
            if len(processed_anchor_pos) == 0 or len(processed_anchor_pos) == 1:
                output_unsuitable_anchor_pos_record.append([])
                gt_keypoint_sequence.append(sequence_data[i].detach().cpu().numpy())
                global_position_record.append(global_position[i].detach().cpu().numpy())
                output_refined_anchor_keypoint_sequence_record.append(sequence_data[i].detach().cpu().numpy())
                output_anchor_label_record.append([])
                output_refined_keypoint_sequence_record.append(sequence_data[i].detach().cpu().numpy())
                print("No anchor recognized!")
            else:
                # fit transition
                processed_transition = []
                processed_duration = []
                for j in range(len(processed_anchor_pos) - 1):
                    start = processed_anchor_pos[j]
                    end = processed_anchor_pos[j+1] + 1
                    tmp_duration = end - start - 1
                    tmp_keypoint_sequence = sequence_data[i, start:end].clone().detach().cpu().numpy()
                    tmp_keypoint_sequence = tmp_keypoint_sequence - tmp_keypoint_sequence[0, 0]
                    tmp_trans = []
                    for k in range(tmp_keypoint_sequence.shape[1]):
                        tp, _, _ = spline_fit(tmp_keypoint_sequence[:, k])
                        tmp_trans.append(tp)
                    tmp_trans = np.array(tmp_trans)
                    tmp_trans = tmp_trans.reshape(tmp_keypoint_sequence.shape[1], -1)
                    processed_transition.append(tmp_trans)
                    processed_duration.append(tmp_duration)
                processed_transition = np.array(processed_transition)
                processed_transition = torch.from_numpy(processed_transition).float().cuda()
                processed_duration = np.array(processed_duration)
                processed_duration = torch.from_numpy(processed_duration).float().cuda()
                unsuitable_pos = torch.where(output_recognition[i, :, 1] == 1)[0]
                print("unsuitable_pos:", unsuitable_pos)
                combined = torch.cat([processed_anchor_pos, unsuitable_pos])
                combined_val, counts = combined.unique(return_counts=True)
                unsuitable_anchor_pos = combined_val[counts > 1]
                unsuitable_anchor_idx = []
                for j in range(len(unsuitable_anchor_pos)):
                    unsuitable_anchor_idx.append(torch.where(processed_anchor_pos == unsuitable_anchor_pos[j])[0][0])
                if args.unsuitable_type == "one":
                    selected_idx = random.randint(0, len(unsuitable_anchor_pos)-1)
                    unsuitable_anchor_pos = [unsuitable_anchor_pos[selected_idx]]
                    unsuitable_anchor_idx = [unsuitable_anchor_idx[selected_idx]]
                print("unsuitable_anchor_pos:", unsuitable_anchor_pos)
                if len(unsuitable_anchor_pos) == 0:
                    print("No anchor is unsuitable!")
                output_unsuitable_anchor_pos_record.append(unsuitable_anchor_pos)
                if add_visualize:
                    visualize_single_video(os.path.join(save_dir, "visualization", "gt_keypoint_sequence"), limbs, sequence_data[i].detach().cpu().numpy(), batch_idx)
                gt_keypoint_sequence.append(sequence_data[i].detach().cpu().numpy())
                global_position_record.append(global_position[i].detach().cpu().numpy())
                refined_anchor_keypoint_sequence = sequence_data[i].clone()
                tmp_output_anchor_label_record = []
                for j, anchor_pos in enumerate(unsuitable_anchor_pos):
                    anchor_idx = unsuitable_anchor_idx[j]
                    tmp_keypoint_sequence = sequence_data_norm[i].clone()
                    if anchor_idx == 0:
                        tmp_keypoint_sequence[:processed_anchor_pos[anchor_idx+1]-1] = torch.zeros_like(tmp_keypoint_sequence[:processed_anchor_pos[anchor_idx+1]-1])
                    elif anchor_idx == len(processed_anchor_pos) - 1:
                        tmp_keypoint_sequence[processed_anchor_pos[anchor_idx-1]+1:] = torch.zeros_like(tmp_keypoint_sequence[processed_anchor_pos[anchor_idx-1]+1:])
                    else:
                        tmp_keypoint_sequence[processed_anchor_pos[anchor_idx-1]+1:processed_anchor_pos[anchor_idx+1]-1] = torch.zeros_like(tmp_keypoint_sequence[processed_anchor_pos[anchor_idx-1]+1:processed_anchor_pos[anchor_idx+1]-1])
                    anchor_pos = torch.tensor([anchor_pos])
                    ### Anchor Refinement Net ###
                    with torch.no_grad():
                        output_label, output_anchor = afn(tmp_keypoint_sequence.unsqueeze(0), anchor_pos, testing=True)
                        # norm output_anchor
                        output_anchor_global = output_anchor.clone().detach()
                        output_anchor_global = output_anchor_global - output_anchor_global[0, 0] + global_position[i, processed_anchor_pos[anchor_idx], 0]
                        if tmp_keypoint_sequence.shape[0] < 60:
                            tmp_keypoint_sequence_pad = torch.cat((tmp_keypoint_sequence, tmp_keypoint_sequence[-1].unsqueeze(0).repeat(60-tmp_keypoint_sequence.shape[0], 1, 1)), dim=0)
                        else:
                            tmp_keypoint_sequence_pad = tmp_keypoint_sequence[:60]
                        global_feature = autoencoder.encoder(tmp_keypoint_sequence_pad.unsqueeze(0))
                    tmp_output_anchor_label_record.append(output_label.detach().cpu().numpy())
                    if anchor_idx == 0:
                        anchor_pair_right = torch.cat([output_anchor_global.unsqueeze(-1), sequence_data[i][processed_anchor_pos[anchor_idx+1]:processed_anchor_pos[anchor_idx+1]+1].unsqueeze(-1)], dim=-1)
                        anchor_pair_right_norm = anchor_pair_right.clone()
                        anchor_pair_right_norm = anchor_pair_right_norm - global_position[i, processed_anchor_pos[anchor_idx], 0].unsqueeze(-1)
                        transition_right = processed_transition[anchor_idx]
                        transition_t = processed_duration[anchor_idx].unsqueeze(-1).unsqueeze(-1)
                        transition_right = transition_right.unsqueeze(0)
                        z = torch.randn(1, args.z_dim).cuda()
                        anchor_pair_right_norm = anchor_pair_right_norm.permute(0, 2, 1, 3).contiguous() # [1, 3, 24, 2]
                        with torch.no_grad():
                            output_trans = trans_generator.decode(z, anchor_pair_right_norm, transition_t, global_feature)
                        output_trans_for_opt = output_trans.clone().detach()
                        output_trans = output_trans[..., 0]
                        output_trans = output_trans.permute(0, 2, 1).contiguous()
                        output_trans = output_trans.reshape(output_trans.shape[0], output_trans.shape[1], 3, 3)
                        opt_output = trans_optimizer.fitting(output_trans_for_opt, anchor_pair_right_norm, transition_t)
                        opt_output_trans = opt_output["trans_param_opt_xyz"]
                        anchor_pair_right_norm = anchor_pair_right_norm.permute(0, 2, 1, 3).contiguous()
                        output_trans = torch.cat([output_trans, anchor_pair_right_norm[..., 0:1]], dim=-1)
                        right_keypoint_sequence = transform_torch(opt_output_trans, transition_t)
                        right_keypoint_sequence += global_position[i, processed_anchor_pos[anchor_idx], 0]
                        refined_anchor_keypoint_sequence[processed_anchor_pos[anchor_idx]:processed_anchor_pos[anchor_idx+1]] = right_keypoint_sequence
                    elif anchor_idx == len(processed_anchor_pos) - 1:
                        anchor_pair_left = torch.cat([sequence_data[i][processed_anchor_pos[anchor_idx-1]:processed_anchor_pos[anchor_idx-1]+1].unsqueeze(-1), output_anchor_global.unsqueeze(-1)], dim=-1)
                        anchor_pair_left_norm = anchor_pair_left.clone()
                        anchor_pair_left_norm = anchor_pair_left_norm - global_position[i, processed_anchor_pos[anchor_idx-1], 0].unsqueeze(-1)
                        transition_left = processed_transition[anchor_idx-1]
                        transition_t = processed_duration[anchor_idx-1].unsqueeze(-1).unsqueeze(-1)
                        transition_left = transition_left.unsqueeze(0)
                        z = torch.randn(1, args.z_dim).cuda()
                        anchor_pair_left_norm = anchor_pair_left_norm.permute(0, 2, 1, 3).contiguous()
                        with torch.no_grad():
                            output_trans = trans_generator.decode(z, anchor_pair_left_norm, transition_t, global_feature)
                        output_trans_for_opt = output_trans.clone().detach()
                        output_trans = output_trans[..., 0]
                        output_trans = output_trans.permute(0, 2, 1).contiguous()
                        output_trans = output_trans.reshape(output_trans.shape[0], output_trans.shape[1], 3, 3)
                        opt_output = trans_optimizer.fitting(output_trans_for_opt, anchor_pair_left_norm, transition_t)
                        opt_output_trans = opt_output["trans_param_opt_xyz"]
                        anchor_pair_left_norm = anchor_pair_left_norm.permute(0, 2, 1, 3).contiguous()
                        output_trans = torch.cat([output_trans, anchor_pair_left_norm[..., 0:1]], dim=-1)
                        left_keypoint_sequence = transform_torch(opt_output_trans, transition_t)
                        left_keypoint_sequence += global_position[i, processed_anchor_pos[anchor_idx-1], 0]
                        refined_anchor_keypoint_sequence[processed_anchor_pos[anchor_idx-1]:processed_anchor_pos[anchor_idx]] = left_keypoint_sequence
                    else:
                        anchor_pair_left = torch.cat([sequence_data[i][processed_anchor_pos[anchor_idx-1]:processed_anchor_pos[anchor_idx-1]+1].unsqueeze(-1), output_anchor_global.unsqueeze(-1)], dim=-1)
                        anchor_pair_left_norm = anchor_pair_left.clone()
                        anchor_pair_left_norm = anchor_pair_left_norm - global_position[i, processed_anchor_pos[anchor_idx-1], 0].unsqueeze(-1)
                        transition_left = processed_transition[anchor_idx-1]
                        transition_t_left = processed_duration[anchor_idx-1].unsqueeze(-1).unsqueeze(-1)
                        transition_left = transition_left.unsqueeze(0)
                        z = torch.randn(1, args.z_dim).cuda()
                        anchor_pair_left_norm = anchor_pair_left_norm.permute(0, 2, 1, 3).contiguous()
                        with torch.no_grad():
                            output_trans_left = trans_generator.decode(z, anchor_pair_left_norm, transition_t_left, global_feature)
                        output_trans_left_for_opt = output_trans_left.clone().detach()
                        output_trans_left = output_trans_left[..., 0]
                        output_trans_left = output_trans_left.permute(0, 2, 1).contiguous()
                        output_trans_left = output_trans_left.reshape(output_trans_left.shape[0], output_trans_left.shape[1], 3, 3)
                        opt_output = trans_optimizer.fitting(output_trans_left_for_opt, anchor_pair_left_norm, transition_t_left)
                        opt_output_trans_left = opt_output["trans_param_opt_xyz"]
                        anchor_pair_left_norm = anchor_pair_left_norm.permute(0, 2, 1, 3).contiguous()
                        output_trans_left = torch.cat([output_trans_left, anchor_pair_left_norm[..., 0:1]], dim=-1)
                        left_keypoint_sequence = transform_torch(opt_output_trans_left, transition_t_left)
                        left_keypoint_sequence += global_position[i, processed_anchor_pos[anchor_idx-1], 0]
                        anchor_pair_right = torch.cat([output_anchor_global.unsqueeze(-1), sequence_data[i][processed_anchor_pos[anchor_idx+1]:processed_anchor_pos[anchor_idx+1]+1].unsqueeze(-1)], dim=-1)
                        anchor_pair_right_norm = anchor_pair_right.clone()
                        anchor_pair_right_norm = anchor_pair_right_norm - global_position[i, processed_anchor_pos[anchor_idx], 0].unsqueeze(-1)
                        transition_right = processed_transition[anchor_idx]
                        transition_t_right = processed_duration[anchor_idx].unsqueeze(-1).unsqueeze(-1)
                        transition_right = transition_right.unsqueeze(0)
                        anchor_pair_right_norm = anchor_pair_right_norm.permute(0, 2, 1, 3).contiguous()
                        with torch.no_grad():
                            output_trans_right = trans_generator.decode(z, anchor_pair_right_norm, transition_t_right, global_feature)
                        output_trans_right_for_opt = output_trans_right.clone().detach()
                        output_trans_right = output_trans_right[..., 0]
                        output_trans_right = output_trans_right.permute(0, 2, 1).contiguous()
                        output_trans_right = output_trans_right.reshape(output_trans_right.shape[0], output_trans_right.shape[1], 3, 3)
                        opt_output = trans_optimizer.fitting(output_trans_right_for_opt, anchor_pair_right_norm, transition_t_right)
                        opt_output_trans_right = opt_output["trans_param_opt_xyz"]
                        anchor_pair_right_norm = anchor_pair_right_norm.permute(0, 2, 1, 3).contiguous()
                        output_trans_right = torch.cat([output_trans_right, anchor_pair_right_norm[..., 0:1]], dim=-1)
                        right_keypoint_sequence = transform_torch(opt_output_trans_right, transition_t_right)
                        right_keypoint_sequence += global_position[i, processed_anchor_pos[anchor_idx], 0]
                        refined_anchor_keypoint_sequence[processed_anchor_pos[anchor_idx-1]:processed_anchor_pos[anchor_idx]] = left_keypoint_sequence
                        refined_anchor_keypoint_sequence[processed_anchor_pos[anchor_idx]:processed_anchor_pos[anchor_idx+1]] = right_keypoint_sequence
                if add_visualize:
                    visualize_single_video(os.path.join(save_dir, "visualization", "refined_anchor_keypoint_sequence"), limbs, refined_anchor_keypoint_sequence.detach().cpu().numpy(), batch_idx)
                tmp_output_anchor_label_record = np.array(tmp_output_anchor_label_record)
                output_anchor_label_record.append(tmp_output_anchor_label_record)
                output_refined_anchor_keypoint_sequence_record.append(refined_anchor_keypoint_sequence.detach().cpu().numpy())
                refined_keypoint_sequence = refined_anchor_keypoint_sequence.clone()
                for j in range(len(processed_anchor_pos) - 1):
                    keypoint_sequence_slice = refined_anchor_keypoint_sequence[processed_anchor_pos[j]:processed_anchor_pos[j+1]+1]
                    keypoint_sequence_slice_norm = keypoint_sequence_slice.clone()
                    keypoint_sequence_slice_norm = keypoint_sequence_slice_norm - keypoint_sequence_slice_norm[0, 0]
                    output_trans, output_keypoint_sequence = aprn(refined_anchor_keypoint_sequence.unsqueeze(0), keypoint_sequence_slice_norm.unsqueeze(0))
                    output_keypoint_sequence = output_keypoint_sequence - output_keypoint_sequence[0, 0, 0] + global_position[i, processed_anchor_pos[j], 0]
                    refined_keypoint_sequence[processed_anchor_pos[j]:processed_anchor_pos[j+1]] = output_keypoint_sequence[0, :-1]
                output_refined_keypoint_sequence_record.append(refined_keypoint_sequence.detach().cpu().numpy())
                if add_visualize:
                    visualize_single_video(os.path.join(save_dir, "visualization", "refined_keypoint_sequence"), limbs, refined_keypoint_sequence.detach().cpu().numpy(), batch_idx)
                print("==> Refinement Done!")
                # TBD: Optimization
    # Save output
    saved_output = {}
    saved_output["output_recognition_record"] = output_recognition_record
    saved_output["output_unsuitable_anchor_pos_record"] = output_unsuitable_anchor_pos_record
    saved_output["output_anchor_label_record"] = output_anchor_label_record
    saved_output["output_refined_anchor_keypoint_sequence_record"] = output_refined_anchor_keypoint_sequence_record
    saved_output["output_refined_keypoint_sequence_record"] = output_refined_keypoint_sequence_record
    saved_output["global_position_record"] = global_position_record
    saved_output["gt_keypoint_sequence"] = gt_keypoint_sequence
    saved_output["gt_keypoint_sequence_for_eval"] = [item.detach().cpu().numpy() for item in gt_keypoint_sequence_for_eval]
    with open(os.path.join(save_dir, "saved_output.pkl"), "wb") as f:
        pickle.dump(saved_output, f)
        
    anchor_pos, log_likelihood = [], []
    for i in range(len(output_recognition_record)):
        for j in range(len(output_recognition_record[i])):
            anchor_pos.append(output_recognition_record[i][j][0])
            log_likelihood.append(output_recognition_record[i][j][1])      

    # Evaluation
    if args.dataset == "humanact12":
        eval_sequence_refinement_metrics_humanact12(args, gt_keypoint_sequence_for_eval, save_dir)
    elif args.dataset == "uestc":
        eval_sequence_refinement_metrics_uestc(args, gt_keypoint_sequence_for_eval, save_dir)
        

def eval_sequence_refinement_metrics_humanact12(args, gt_keypoint_sequence_for_eval, save_dir, eval_full=False):
    """Score reference, generated and refined sequences with the GRU action classifier.

    Loads ``saved_output.pkl`` from ``save_dir``, runs every sequence through
    the recognition classifier and its feature-extracting counterpart, prints
    accuracy / diversity / multimodality / FID, and pickles the same numbers
    to ``evaluation_results.pkl`` in ``save_dir``.

    Four sets are scored:
      * "GT": ``gt_keypoint_sequence_for_eval`` (fed to the classifier unscaled),
      * "Generated": ``saved_output["gt_keypoint_sequence"]`` divided by 100,
      * "Refined Anchor": ``saved_output["output_refined_anchor_keypoint_sequence_record"]`` divided by 100,
      * "Refined": ``saved_output["output_refined_keypoint_sequence_record"]`` divided by 100.

    Args:
        args: Run configuration. ``args.dataset`` selects the classifier
            dimensions and ``args.concept`` is used as the class label of
            every sequence when building the confusion matrices.
        gt_keypoint_sequence_for_eval: Iterable of reference sequence tensors.
            NOTE(review): presumably laid out as (frames, joints, coords) --
            confirm against the caller.
        save_dir: Directory holding ``saved_output.pkl`` (and, when
            ``eval_full`` is set, the ``*_labels.npy`` files); also receives
            ``evaluation_results.pkl``.
        eval_full: When True, the labels passed to the diversity /
            multimodality computation are loaded from ``save_dir`` instead of
            being filled with ``args.concept``.

    Raises:
        ValueError: If ``args.dataset`` is neither "humanact12" nor "uestc".
    """
    start_time = time.time()
    if args.dataset == "humanact12":
        input_size_raw = 72
        num_classes = 12
    elif args.dataset == "uestc":
        input_size_raw = 54
        num_classes = 40
    else:
        raise ValueError("Invalid dataset!")
    # Load classifier for evaluation
    device = torch.cuda.current_device()
    gru_classifier_for_fid = load_classifier_for_fid(args.dataset, input_size_raw, num_classes, device="cuda:{}".format(str(device)))
    gru_classifier = load_classifier(args.dataset, input_size_raw, num_classes, device="cuda:{}".format(str(device)))
    # Inference only: make sure train-time layers (e.g. dropout) are disabled,
    # matching eval_sequence_refinement_metrics_humanact12_full.
    gru_classifier_for_fid.eval()
    gru_classifier.eval()

    concept = int(args.concept)

    def _classify(sequences):
        """Run each sequence through both classifiers; return (activations, confusion matrix)."""
        confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)
        collected = []
        for sequence in sequences:
            batch = sequence.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
            # Subtract the [:, 0:1, :, 0:1] slice from every element. This is done
            # out-of-place on purpose: an in-place `-=` would read the reference
            # slice from the very memory it is overwriting.
            batch = batch - batch[:, 0:1, :, 0:1]
            lengths = torch.tensor([batch.shape[-1]]).cuda()
            batch_labels = torch.full((batch.shape[0],), concept, dtype=torch.long)
            with torch.no_grad():
                activations = gru_classifier_for_fid(batch, lengths=lengths)
                batch_prob = gru_classifier(batch, lengths=lengths)
                batch_pred = batch_prob.max(dim=1).indices
                for label, pred in zip(batch_labels, batch_pred):
                    confusion[label, pred] += 1
            collected.append(activations)
        return torch.cat(collected, dim=0).cuda(), confusion

    def _load_scaled(record):
        """Convert saved numpy sequences to CUDA float tensors, scaled by 1/100."""
        return [torch.from_numpy(item / 100).float().cuda() for item in record]

    with open(os.path.join(save_dir, "saved_output.pkl"), "rb") as f:
        saved_output = pickle.load(f)

    gt_activations, gt_confusion = _classify(gt_keypoint_sequence_for_eval)
    gen_activations, gen_confusion = _classify(_load_scaled(saved_output["gt_keypoint_sequence"]))
    refined_anchor_activations, refined_anchor_confusion = _classify(
        _load_scaled(saved_output["output_refined_anchor_keypoint_sequence_record"]))
    refined_activations, refined_confusion = _classify(
        _load_scaled(saved_output["output_refined_keypoint_sequence_record"]))

    # Accuracy = correctly classified (diagonal) / total classified.
    gt_accuracy = torch.trace(gt_confusion) / torch.sum(gt_confusion)
    gen_accuracy = torch.trace(gen_confusion) / torch.sum(gen_confusion)
    refined_anchor_accuracy = torch.trace(refined_anchor_confusion) / torch.sum(refined_anchor_confusion)
    refined_accuracy = torch.trace(refined_confusion) / torch.sum(refined_confusion)

    if not eval_full:
        # Single-concept run: every sample carries the same label.
        gt_labels = torch.ones(gt_activations.shape[0]) * concept
        gt_labels = gt_labels.cuda()
        generated_labels = torch.ones(gen_activations.shape[0]) * concept
        generated_labels = generated_labels.cuda()
        refined_anchor_labels = torch.ones(refined_anchor_activations.shape[0]) * concept
        refined_anchor_labels = refined_anchor_labels.cuda()
        refined_labels = torch.ones(refined_activations.shape[0]) * concept
        refined_labels = refined_labels.cuda()
    else:
        gt_labels = torch.from_numpy(np.load(os.path.join(save_dir, "gt_labels.npy"))).cuda()
        generated_labels = torch.from_numpy(np.load(os.path.join(save_dir, "generated_labels.npy"))).cuda()
        refined_anchor_labels = torch.from_numpy(np.load(os.path.join(save_dir, "refined_anchor_labels.npy"))).cuda()
        refined_labels = torch.from_numpy(np.load(os.path.join(save_dir, "refined_labels.npy"))).cuda()

    gt_diversity, gt_multimodality = calculate_diversity_multimodality(gt_activations, gt_labels, concept, num_classes)
    generated_diversity, generated_multimodality = calculate_diversity_multimodality(gen_activations, generated_labels, concept, num_classes)
    refined_anchor_diversity, refined_anchor_multimodality = calculate_diversity_multimodality(refined_anchor_activations, refined_anchor_labels, concept, num_classes)
    refined_diversity, refined_multimodality = calculate_diversity_multimodality(refined_activations, refined_labels, concept, num_classes)

    gt_stats = calculate_activation_statistics(gt_activations)
    generated_stats = calculate_activation_statistics(gen_activations)
    refined_anchor_stats = calculate_activation_statistics(refined_anchor_activations)
    refined_stats = calculate_activation_statistics(refined_activations)

    # FID of each set is measured against the reference ("GT") statistics.
    fid_generated = float(calculate_fid(gt_stats, generated_stats))
    fid_refined_anchor = float(calculate_fid(gt_stats, refined_anchor_stats))
    fid_refined = float(calculate_fid(gt_stats, refined_stats))
    end_time = time.time()

    print("Evaluation Results: GT Accuracy: {:.4f}, GT Diversity: {:.4f}, GT Multimodality: {:.4f}, Generated Accuracy: {:.4f}, Generated Diversity: {:.4f}, Generated Multimodality: {:.4f}, Refined Anchor Accuracy: {:.4f}, Refined Anchor Diversity: {:.4f}, Refined Anchor Multimodality: {:.4f}, Refined Accuracy: {:.4f}, Refined Diversity: {:.4f}, Refined Multimodality: {:.4f}, FID Generated: {:.4f}, FID Refined Anchor: {:.4f}, FID Refined: {:.4f}, Time: {:.4f}".format(gt_accuracy, gt_diversity, gt_multimodality, gen_accuracy, generated_diversity, generated_multimodality, refined_anchor_accuracy, refined_anchor_diversity, refined_anchor_multimodality, refined_accuracy, refined_diversity, refined_multimodality, fid_generated, fid_refined_anchor, fid_refined, end_time - start_time))

    evaluation_results = {
        "gt_accuracy": gt_accuracy,
        "gt_diversity": gt_diversity,
        "gt_multimodality": gt_multimodality,
        "generated_accuracy": gen_accuracy,
        "generated_diversity": generated_diversity,
        "generated_multimodality": generated_multimodality,
        "refined_anchor_accuracy": refined_anchor_accuracy,
        "refined_anchor_diversity": refined_anchor_diversity,
        "refined_anchor_multimodality": refined_anchor_multimodality,
        "refined_accuracy": refined_accuracy,
        "refined_diversity": refined_diversity,
        "refined_multimodality": refined_multimodality,
        "fid_generated": fid_generated,
        "fid_refined_anchor": fid_refined_anchor,
        "fid_refined": fid_refined,
    }
    with open(os.path.join(save_dir, "evaluation_results.pkl"), "wb") as f:
        pickle.dump(evaluation_results, f)
        

def eval_sequence_refinement_metrics_humanact12_full(args, save_dir):
    """Evaluate every saved sample set with the GRU action classifier.

    For each ``s`` in ``range(args.num_samples)`` this loads
    ``gt_keypoint_sequence_{s}.npy``, ``gen_keypoint_sequence_{s}.npy``,
    ``refined_keypoint_sequence_{s}.npy`` and ``labels_{s}.npy`` from
    ``save_dir``, classifies every sequence, prints accuracy / diversity /
    multimodality / FID and pickles them to ``evaluation_results_{s}.pkl``.

    The generated and refined arrays are divided by 100 before
    classification; the ground-truth array is used as stored.

    Args:
        args: Run configuration. Uses ``args.dataset`` (classifier
            dimensions), ``args.num_samples`` (number of sample sets) and
            ``args.concept`` (forwarded to the diversity computation).
        save_dir: Directory containing the ``.npy`` inputs; also receives the
            per-sample result pickles.

    Raises:
        ValueError: If ``args.dataset`` is neither "humanact12" nor "uestc".
    """
    start_time = time.time()
    if args.dataset == "humanact12":
        input_size_raw = 72
        num_classes = 12
    elif args.dataset == "uestc":
        input_size_raw = 54
        num_classes = 40
    else:
        raise ValueError("Invalid dataset!")
    # Load classifier for evaluation
    device = torch.cuda.current_device()
    gru_classifier_for_fid = load_classifier_for_fid(args.dataset, input_size_raw, num_classes, device="cuda:{}".format(str(device)))
    gru_classifier = load_classifier(args.dataset, input_size_raw, num_classes, device="cuda:{}".format(str(device)))
    gru_classifier_for_fid.eval()
    gru_classifier.eval()

    def _classify(sequences, sample_labels):
        """Classify each sequence one at a time; return (activations, confusion matrix).

        ``sample_labels`` is the numpy label array; entry ``i`` is taken as the
        true class of ``sequences[i]``.
        """
        confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)
        collected = []
        for i in range(sequences.shape[0]):
            batch = sequences[i:i+1].permute(0, 2, 3, 1).contiguous()
            # Subtract the [:, 0:1, :, 0:1] slice from every element. This is done
            # out-of-place on purpose: an in-place `-=` would read the reference
            # slice from the very memory it is overwriting.
            batch = batch - batch[:, 0:1, :, 0:1]
            lengths = torch.tensor([batch.shape[-1]]).cuda()
            batch_labels = torch.from_numpy(sample_labels[i:i+1]).long()
            with torch.no_grad():
                activations = gru_classifier_for_fid(batch, lengths=lengths)
                batch_prob = gru_classifier(batch, lengths=lengths)
                batch_pred = batch_prob.max(dim=1).indices
                for label, pred in zip(batch_labels, batch_pred):
                    confusion[label, pred] += 1
            collected.append(activations)
        return torch.cat(collected, dim=0).cuda(), confusion

    for s in range(args.num_samples):
        gt_keypoint_sequence = np.load(os.path.join(save_dir, "gt_keypoint_sequence_{}.npy".format(s)))
        gen_keypoint_sequence = np.load(os.path.join(save_dir, "gen_keypoint_sequence_{}.npy".format(s)))
        refined_keypoint_sequence = np.load(os.path.join(save_dir, "refined_keypoint_sequence_{}.npy".format(s)))
        labels_np = np.load(os.path.join(save_dir, "labels_{}.npy".format(s)))
        gt_keypoint_sequence = torch.from_numpy(gt_keypoint_sequence).float().cuda()
        gen_keypoint_sequence = torch.from_numpy(gen_keypoint_sequence).float().cuda()
        refined_keypoint_sequence = torch.from_numpy(refined_keypoint_sequence).float().cuda()
        # Only the generated / refined arrays are rescaled; ground truth is used as stored.
        gen_keypoint_sequence /= 100
        refined_keypoint_sequence /= 100

        # The same label array is used for all three sets, index by index.
        gt_activations, gt_confusion = _classify(gt_keypoint_sequence, labels_np)
        gen_activations, gen_confusion = _classify(gen_keypoint_sequence, labels_np)
        refined_activations, refined_confusion = _classify(refined_keypoint_sequence, labels_np)

        labels = torch.from_numpy(labels_np).cuda()

        # Accuracy = correctly classified (diagonal) / total classified.
        gt_accuracy = torch.trace(gt_confusion) / torch.sum(gt_confusion)
        gen_accuracy = torch.trace(gen_confusion) / torch.sum(gen_confusion)
        refined_accuracy = torch.trace(refined_confusion) / torch.sum(refined_confusion)

        gt_diversity, gt_multimodality = calculate_diversity_multimodality(gt_activations, labels, int(args.concept), num_classes)
        generated_diversity, generated_multimodality = calculate_diversity_multimodality(gen_activations, labels, int(args.concept), num_classes)
        refined_diversity, refined_multimodality = calculate_diversity_multimodality(refined_activations, labels, int(args.concept), num_classes)

        gt_stats = calculate_activation_statistics(gt_activations)
        generated_stats = calculate_activation_statistics(gen_activations)
        refined_stats = calculate_activation_statistics(refined_activations)

        # FID of each set is measured against the ground-truth statistics.
        fid_generated = float(calculate_fid(gt_stats, generated_stats))
        fid_refined = float(calculate_fid(gt_stats, refined_stats))
        # start_time is taken once at function entry, so the reported time is
        # cumulative across sample sets rather than per set.
        end_time = time.time()

        print("Evaluation Results: GT Accuracy: {:.4f}, GT Diversity: {:.4f}, GT Multimodality: {:.4f}, Generated Accuracy: {:.4f}, Generated Diversity: {:.4f}, Generated Multimodality: {:.4f}, Refined Accuracy: {:.4f}, Refined Diversity: {:.4f}, Refined Multimodality: {:.4f}, FID Generated: {:.4f}, FID Refined: {:.4f}, Time: {:.4f}".format(
            gt_accuracy, gt_diversity, gt_multimodality, gen_accuracy, generated_diversity, generated_multimodality, refined_accuracy, refined_diversity, refined_multimodality, fid_generated, fid_refined, end_time - start_time))

        evaluation_results = {
            "gt_accuracy": gt_accuracy,
            "gt_diversity": gt_diversity,
            "gt_multimodality": gt_multimodality,
            "generated_accuracy": gen_accuracy,
            "generated_diversity": generated_diversity,
            "generated_multimodality": generated_multimodality,
            "refined_accuracy": refined_accuracy,
            "refined_diversity": refined_diversity,
            "refined_multimodality": refined_multimodality,
            "fid_generated": fid_generated,
            "fid_refined": fid_refined,
        }
        with open(os.path.join(save_dir, "evaluation_results_"+str(s)+".pkl"), "wb") as f:
            pickle.dump(evaluation_results, f)


def _uestc_classify_sequences(model, sequences, label_idx, num_classes):
    """Run the classifier over ``sequences`` one sample at a time.

    Args:
        model: Callable invoked with ``{"x": tensor}``; its result is indexed
            with ``"features"`` and ``"yhat"``.
        sequences: Iterable of 3-D tensors (presumably frames x joints x xyz --
            TODO confirm against the producer of these sequences).
        label_idx: Ground-truth class index shared by every sequence.
        num_classes: Side length of the square confusion matrix.

    Returns:
        ``(activations, confusion)`` where ``activations`` is the per-sample
        ``"features"`` outputs stacked along a new leading axis and moved to
        CUDA, and ``confusion`` is a ``num_classes x num_classes`` long tensor
        counting ``[label_idx, predicted_class]`` pairs.
    """
    features = []
    confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for sequence in sequences:
        # Add a batch axis and move axis 1 of the batched tensor to the end,
        # which is the layout the classifier is fed throughout this file.
        x = sequence.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
        # Centre on the first entry of axis 1 at the first entry of the last
        # axis. The subtrahend is a broadcast view into ``x`` itself, so an
        # in-place ``-=`` would overwrite the reference value while other
        # elements still need to read it; subtract out-of-place instead.
        x = x - x[:, 0:1, :, 0:1]
        with torch.no_grad():
            # One forward pass yields both the features and the class scores.
            output = model({"x": x})
        features.append(output["features"])
        for pred in output["yhat"].max(dim=1).indices:
            confusion[label_idx, int(pred)] += 1
    return torch.stack(features, dim=0).cuda(), confusion


def eval_sequence_refinement_metrics_uestc(args, gt_keypoint_sequence_for_eval, save_dir, eval_full=False):
    """Score ground-truth, generated and refined sequences with the UESTC STGCN.

    Loads the classifier checkpoint, reads ``saved_output.pkl`` from
    ``save_dir``, and computes accuracy, diversity, multimodality and FID
    (against the ground-truth activations) for four sets of sequences. The
    results are printed and pickled to ``evaluation_results.pkl`` in
    ``save_dir``.

    Args:
        args: Namespace providing ``dataset`` ("humanact12" or "uestc") and
            ``concept`` (an entry of ``uestc_list``).
        gt_keypoint_sequence_for_eval: Indexable collection of ground-truth
            sequence tensors; used as-is (no rescaling).
        save_dir: Directory holding ``saved_output.pkl`` and, when
            ``eval_full`` is true, the ``*_labels.npy`` files.
        eval_full: When false, every sample is labelled with the index of
            ``args.concept``; when true, labels are loaded from ``save_dir``.

    Raises:
        ValueError: If ``args.dataset`` is not recognised, or ``args.concept``
            is not in ``uestc_list``.
    """
    start_time = time.time()
    # NOTE(review): the "humanact12" branch only changes num_classes; the UESTC
    # layout and checkpoint are loaded regardless -- confirm this is intended.
    if args.dataset == "humanact12":
        num_classes = 12
    elif args.dataset == "uestc":
        num_classes = 40
    else:
        raise ValueError("Invalid dataset!")

    device = torch.cuda.current_device()
    model = STGCN(in_channels=3,
                  num_class=num_classes,
                  graph_args={"layout": "uestc", "strategy": "spatial"},
                  edge_importance_weighting=True)
    model = model.to(device)
    state_dict = torch.load("models/actionrecognition/uestc_xyz_stgcn.tar")
    model.load_state_dict(state_dict)
    model.eval()

    # pickle.load can execute arbitrary code; only use a trusted save_dir.
    with open(os.path.join(save_dir, "saved_output.pkl"), "rb") as f:
        saved_output = pickle.load(f)

    # The sequences stored under "gt_keypoint_sequence" feed the *generated*
    # metrics below; the ground-truth metrics use the function argument.
    saved_generated = saved_output["gt_keypoint_sequence"]
    saved_refined_anchor = saved_output["output_refined_anchor_keypoint_sequence_record"]
    saved_refined = saved_output["output_refined_keypoint_sequence_record"]

    concept_idx = int(uestc_list.index(args.concept))

    def _saved_to_cuda(items):
        # Saved arrays are divided by 100 before classification (presumably a
        # unit conversion -- TODO confirm).
        return [torch.from_numpy(item / 100).float().cuda() for item in items]

    gt_activations, gt_confusion = _uestc_classify_sequences(
        model, gt_keypoint_sequence_for_eval, concept_idx, num_classes)
    gen_activations, gen_confusion = _uestc_classify_sequences(
        model, _saved_to_cuda(saved_generated), concept_idx, num_classes)
    refined_anchor_activations, refined_anchor_confusion = _uestc_classify_sequences(
        model, _saved_to_cuda(saved_refined_anchor), concept_idx, num_classes)
    refined_activations, refined_confusion = _uestc_classify_sequences(
        model, _saved_to_cuda(saved_refined), concept_idx, num_classes)

    # Accuracy = correctly classified samples / all samples (kept as tensors).
    gt_accuracy = torch.trace(gt_confusion) / torch.sum(gt_confusion)
    gen_accuracy = torch.trace(gen_confusion) / torch.sum(gen_confusion)
    refined_anchor_accuracy = torch.trace(refined_anchor_confusion) / torch.sum(refined_anchor_confusion)
    refined_accuracy = torch.trace(refined_confusion) / torch.sum(refined_confusion)

    if not eval_full:
        gt_labels = (torch.ones(gt_activations.shape[0]) * concept_idx).cuda()
        generated_labels = (torch.ones(gen_activations.shape[0]) * concept_idx).cuda()
        refined_anchor_labels = (torch.ones(refined_anchor_activations.shape[0]) * concept_idx).cuda()
        refined_labels = (torch.ones(refined_activations.shape[0]) * concept_idx).cuda()
    else:
        gt_labels = np.load(os.path.join(save_dir, "gt_labels.npy"))
        generated_labels = np.load(os.path.join(save_dir, "generated_labels.npy"))
        refined_anchor_labels = np.load(os.path.join(save_dir, "refined_anchor_labels.npy"))
        refined_labels = np.load(os.path.join(save_dir, "refined_labels.npy"))
        gt_labels = torch.from_numpy(gt_labels).cuda()
        generated_labels = torch.from_numpy(generated_labels).cuda()
        refined_anchor_labels = torch.from_numpy(refined_anchor_labels).cuda()
        refined_labels = torch.from_numpy(refined_labels).cuda()

    gt_diversity, gt_multimodality = calculate_diversity_multimodality(
        gt_activations, gt_labels, concept_idx, num_classes)
    generated_diversity, generated_multimodality = calculate_diversity_multimodality(
        gen_activations, generated_labels, concept_idx, num_classes)
    refined_anchor_diversity, refined_anchor_multimodality = calculate_diversity_multimodality(
        refined_anchor_activations, refined_anchor_labels, concept_idx, num_classes)
    refined_diversity, refined_multimodality = calculate_diversity_multimodality(
        refined_activations, refined_labels, concept_idx, num_classes)

    gt_stats = calculate_activation_statistics(gt_activations)
    generated_stats = calculate_activation_statistics(gen_activations)
    refined_anchor_stats = calculate_activation_statistics(refined_anchor_activations)
    refined_stats = calculate_activation_statistics(refined_activations)

    # Every FID is measured against the ground-truth activation statistics.
    fid_generated = float(calculate_fid(gt_stats, generated_stats))
    fid_refined_anchor = float(calculate_fid(gt_stats, refined_anchor_stats))
    fid_refined = float(calculate_fid(gt_stats, refined_stats))
    end_time = time.time()

    print(
        "Evaluation Results: GT Accuracy: {:.4f}, GT Diversity: {:.4f}, GT Multimodality: {:.4f}, "
        "Generated Accuracy: {:.4f}, Generated Diversity: {:.4f}, Generated Multimodality: {:.4f}, "
        "Refined Anchor Accuracy: {:.4f}, Refined Anchor Diversity: {:.4f}, Refined Anchor Multimodality: {:.4f}, "
        "Refined Accuracy: {:.4f}, Refined Diversity: {:.4f}, Refined Multimodality: {:.4f}, "
        "FID Generated: {:.4f}, FID Refined Anchor: {:.4f}, FID Refined: {:.4f}, Time: {:.4f}".format(
            gt_accuracy, gt_diversity, gt_multimodality,
            gen_accuracy, generated_diversity, generated_multimodality,
            refined_anchor_accuracy, refined_anchor_diversity, refined_anchor_multimodality,
            refined_accuracy, refined_diversity, refined_multimodality,
            fid_generated, fid_refined_anchor, fid_refined, end_time - start_time))

    evaluation_results = {
        "gt_accuracy": gt_accuracy,
        "gt_diversity": gt_diversity,
        "gt_multimodality": gt_multimodality,
        "generated_accuracy": gen_accuracy,
        "generated_diversity": generated_diversity,
        "generated_multimodality": generated_multimodality,
        "refined_anchor_accuracy": refined_anchor_accuracy,
        "refined_anchor_diversity": refined_anchor_diversity,
        "refined_anchor_multimodality": refined_anchor_multimodality,
        "refined_accuracy": refined_accuracy,
        "refined_diversity": refined_diversity,
        "refined_multimodality": refined_multimodality,
        "fid_generated": fid_generated,
        "fid_refined_anchor": fid_refined_anchor,
        "fid_refined": fid_refined,
    }
    with open(os.path.join(save_dir, "evaluation_results.pkl"), "wb") as f:
        pickle.dump(evaluation_results, f)


def _uestc_stgcn_features(model, sequences):
    """Return the classifier's ``"features"`` output for each sample.

    Args:
        model: Callable invoked with ``{"x": tensor}``; its result is indexed
            with ``"features"``.
        sequences: 4-D tensor whose axis 0 enumerates samples (remaining axes
            presumably frames x joints x xyz -- TODO confirm).

    Returns:
        The per-sample feature tensors stacked along a new leading axis and
        moved to CUDA.
    """
    features = []
    for i in range(sequences.shape[0]):
        # Keep a batch axis of size 1 and move axis 1 to the end.
        x = sequences[i:i + 1].permute(0, 2, 3, 1).contiguous()
        # Centre on the first entry of axis 1 at the first entry of the last
        # axis. The subtrahend is a broadcast view into ``x`` itself, so an
        # in-place ``-=`` would overwrite the reference value while other
        # elements still need to read it; subtract out-of-place instead.
        x = x - x[:, 0:1, :, 0:1]
        with torch.no_grad():
            features.append(model({"x": x})["features"])
    return torch.stack(features, dim=0).cuda()


def eval_sequence_refinement_metrics_uestc_full(args, save_dir):
    """Compute diversity, multimodality and FID per split and sample set.

    For each split in ("train", "test") and each ``s`` in
    ``range(args.num_samples)``, loads the ground-truth, generated and refined
    sequences plus labels from ``save_dir/<split>/``, extracts STGCN features,
    prints the metrics, and pickles them to
    ``save_dir/<split>/evaluation_results_<s>.pkl``.

    Args:
        args: Namespace providing ``num_samples`` and ``concept`` (an entry of
            ``uestc_list``).
        save_dir: Root directory containing the ``train`` and ``test``
            sub-directories with the ``.npy`` inputs.

    Raises:
        ValueError: If ``args.concept`` is not in ``uestc_list`` (raised once
            metrics are computed for the first sample set).
    """
    start_time = time.time()
    num_classes = 40
    device = torch.cuda.current_device()
    model = STGCN(in_channels=3,
                  num_class=num_classes,
                  graph_args={"layout": "uestc", "strategy": "spatial"},
                  edge_importance_weighting=True)
    model = model.to(device)
    state_dict = torch.load("models/actionrecognition/uestc_xyz_stgcn.tar")
    model.load_state_dict(state_dict)
    model.eval()

    for split in ("train", "test"):
        split_dir = os.path.join(save_dir, split)
        for s in range(args.num_samples):
            gt_array = np.load(os.path.join(split_dir, "gt_keypoint_sequence_{}.npy".format(s)))
            gen_array = np.load(os.path.join(split_dir, "gen_keypoint_sequence_{}.npy".format(s)))
            refined_array = np.load(os.path.join(split_dir, "refined_keypoint_sequence_{}.npy".format(s)))
            labels = np.load(os.path.join(split_dir, "labels_{}.npy".format(s)))

            gt_keypoint_sequence = torch.from_numpy(gt_array).float().cuda()
            # Only generated/refined sequences are divided by 100; the ground
            # truth is used as loaded (presumably already in the classifier's
            # units -- TODO confirm).
            gen_keypoint_sequence = torch.from_numpy(gen_array).float().cuda() / 100
            refined_keypoint_sequence = torch.from_numpy(refined_array).float().cuda() / 100

            gt_activations = _uestc_stgcn_features(model, gt_keypoint_sequence)
            gen_activations = _uestc_stgcn_features(model, gen_keypoint_sequence)
            refined_activations = _uestc_stgcn_features(model, refined_keypoint_sequence)

            labels = torch.from_numpy(labels).cuda()
            concept_idx = int(uestc_list.index(args.concept))

            gt_diversity, gt_multimodality = calculate_diversity_multimodality(
                gt_activations, labels, concept_idx, num_classes)
            generated_diversity, generated_multimodality = calculate_diversity_multimodality(
                gen_activations, labels, concept_idx, num_classes)
            refined_diversity, refined_multimodality = calculate_diversity_multimodality(
                refined_activations, labels, concept_idx, num_classes)

            gt_stats = calculate_activation_statistics(gt_activations)
            generated_stats = calculate_activation_statistics(gen_activations)
            refined_stats = calculate_activation_statistics(refined_activations)

            # Both FIDs are measured against the ground-truth statistics.
            fid_generated = float(calculate_fid(gt_stats, generated_stats))
            fid_refined = float(calculate_fid(gt_stats, refined_stats))
            # start_time is taken once at function entry, so the reported time
            # is cumulative across splits and sample sets.
            end_time = time.time()

            print(
                "Evaluation Results: Split: {} GT Diversity: {:.4f}, GT Multimodality: {:.4f}, "
                "Generated Diversity: {:.4f}, Generated Multimodality: {:.4f}, "
                "Refined Diversity: {:.4f}, Refined Multimodality: {:.4f}, "
                "FID Generated: {:.4f}, FID Refined: {:.4f}, Time: {:.4f}".format(
                    split, gt_diversity, gt_multimodality,
                    generated_diversity, generated_multimodality,
                    refined_diversity, refined_multimodality,
                    fid_generated, fid_refined, end_time - start_time))

            evaluation_results = {
                "gt_diversity": gt_diversity,
                "gt_multimodality": gt_multimodality,
                "generated_diversity": generated_diversity,
                "generated_multimodality": generated_multimodality,
                "refined_diversity": refined_diversity,
                "refined_multimodality": refined_multimodality,
                "fid_generated": fid_generated,
                "fid_refined": fid_refined,
            }
            with open(os.path.join(split_dir, "evaluation_results_"+str(s)+".pkl"), "wb") as f:
                pickle.dump(evaluation_results, f)


def eval_sequence_refinement_ar(args, arn, autoencoder, trans_generator, afn, aprn, refine_loader, gt_keypoint_sequence_for_eval, save_dir, limbs=humanact12_limbs):
    start_time = time.time()
    input_size_raw = 72
    num_classes = 12
    # Load classifier for evaluation
    device = torch.cuda.current_device()
    gru_classifier_for_fid = load_classifier_for_fid("humanact12", input_size_raw, num_classes, device="cuda:{}".format(str(device)))
    gru_classifier = load_classifier("humanact12", input_size_raw, num_classes, device="cuda:{}".format(str(device)))
    
    output_recognition_record, output_unsuitable_anchor_pos_record, output_anchor_label_record = [], [], []
    global_position_record, gt_keypoint_sequence, output_refined_anchor_keypoint_sequence_record, output_refined_keypoint_sequence_record = [], [], [], []
    
    for batch_idx, data in enumerate(refine_loader):
        sequence_data, global_position = data["keypoint_sequence"].cuda(), data["global_position"]
        ### Anchor Recognition Net ###
        with torch.no_grad():
            output_recognition = arn(sequence_data, mask=None, testing=True)
        output_recognition_record.append(output_recognition.clone().detach().cpu().numpy())
        if args.unsuitable_type == "arn":
            output_recognition[:, :, 1] = (output_recognition[:, :, 1] < args.arn_threshold).float()
        elif args.unsuitable_type == "all" or args.unsuitable_type == "one":
            output_recognition[:, :, 1] = 1
        else:
            raise ValueError("Invalid unsuitable type!")
    
        for i in range(output_recognition.shape[0]):
            anchor_pos_label = output_recognition[i, :, 0]
            processed_anchor_pos_label = anchor_pos_label.clone()
            start_idx = None
            for j in range(anchor_pos_label.shape[0]):
                if anchor_pos_label[j] == 1:
                    if start_idx is None:
                        start_idx = j
                    elif j == len(anchor_pos_label) - 1:
                        middle_idx = (start_idx + j) // 2
                        processed_anchor_pos_label[start_idx:j+1] = 0
                        processed_anchor_pos_label[middle_idx] = 1
                elif start_idx is not None:
                    if start_idx == 0:
                        middle_idx = start_idx
                    else:
                        middle_idx = (start_idx + j - 1) // 2
                    processed_anchor_pos_label[start_idx:j] = 0
                    processed_anchor_pos_label[middle_idx] = 1
                    start_idx = None
            # print("all_anchor_pos:", all_anchor_pos)
            processed_anchor_pos = torch.where(processed_anchor_pos_label == 1)[0]
            print(batch_idx, "processed_anchor_pos:", processed_anchor_pos)
            if len(processed_anchor_pos) == 0:
                print("No anchor recognized!")
            else:
                # fit transition
                processed_transition = []
                processed_duration = []
                for j in range(len(processed_anchor_pos) - 1):
                    start = processed_anchor_pos[j]
                    end = processed_anchor_pos[j+1] + 1
                    tmp_duration = end - start - 1
                    tmp_keypoint_sequence = sequence_data[i, start:end].clone().detach().cpu().numpy()
                    tmp_trans = []
                    for k in range(tmp_keypoint_sequence.shape[1]):
                        tp, _, _ = spline_fit(tmp_keypoint_sequence[:, k])
                        tmp_trans.append(tp)
                    tmp_trans = np.array(tmp_trans)
                    tmp_trans = tmp_trans.reshape(tmp_keypoint_sequence.shape[1], -1)
                    processed_transition.append(tmp_trans)
                    processed_duration.append(tmp_duration)
                processed_transition = np.array(processed_transition)
                processed_transition = torch.from_numpy(processed_transition).float().cuda()
                processed_duration = np.array(processed_duration)
                processed_duration = torch.from_numpy(processed_duration).float().cuda()
                # get unsuitable anchor pos
                unsuitable_pos = torch.where(output_recognition[i, :, 1] == 1)[0]
                print("unsuitable_pos:", unsuitable_pos)
                combined = torch.cat([processed_anchor_pos, unsuitable_pos])
                combined_val, counts = combined.unique(return_counts=True)
                unsuitable_anchor_pos = combined_val[counts > 1]
                unsuitable_anchor_idx = []
                for j in range(len(unsuitable_anchor_pos)):
                    unsuitable_anchor_idx.append(torch.where(processed_anchor_pos == unsuitable_anchor_pos[j])[0][0])
                if args.unsuitable_type == "one":
                    selected_idx = random.randint(0, len(unsuitable_anchor_pos)-1)
                    unsuitable_anchor_pos = [unsuitable_anchor_pos[selected_idx]]
                    unsuitable_anchor_idx = [unsuitable_anchor_idx[selected_idx]]
                print("unsuitable_anchor_pos:", unsuitable_anchor_pos)
                if len(unsuitable_anchor_pos) == 0:
                    print("No anchor is unsuitable!")
            output_unsuitable_anchor_pos_record.append(unsuitable_anchor_pos)
            visualize_single_video(os.path.join(save_dir, "visualization", "gt_keypoint_sequence"), limbs, sequence_data[i].detach().cpu().numpy(), batch_idx)
            gt_keypoint_sequence.append(sequence_data[i].detach().cpu().numpy())
            global_position_record.append(global_position[i].numpy())
            refined_anchor_keypoint_sequence = sequence_data[i].clone()
            trans_optimizer = TransOptim(args=args).cuda()
            trans_optimizer.train()
            tmp_output_anchor_label_record = []
            for j, anchor_pos in enumerate(unsuitable_anchor_pos):
                anchor_idx = unsuitable_anchor_idx[j]
                tmp_keypoint_sequence = sequence_data[i].clone()
                if anchor_idx == 0:
                    tmp_keypoint_sequence[:processed_anchor_pos[anchor_idx+1]-1] = torch.zeros_like(tmp_keypoint_sequence[:processed_anchor_pos[anchor_idx+1]-1])
                elif anchor_idx == len(processed_anchor_pos) - 1:
                    tmp_keypoint_sequence[processed_anchor_pos[anchor_idx-1]+1:] = torch.zeros_like(tmp_keypoint_sequence[processed_anchor_pos[anchor_idx-1]+1:])
                else:
                    tmp_keypoint_sequence[processed_anchor_pos[anchor_idx-1]+1:processed_anchor_pos[anchor_idx+1]-1] = torch.zeros_like(tmp_keypoint_sequence[processed_anchor_pos[anchor_idx-1]+1:processed_anchor_pos[anchor_idx+1]-1])
                anchor_pos = torch.tensor([anchor_pos])
                ### Anchor Refinement Net ###
                with torch.no_grad():
                    output_label, output_anchor = afn(tmp_keypoint_sequence.unsqueeze(0), anchor_pos, testing=True)
                    global_feature = autoencoder.encoder(tmp_keypoint_sequence.unsqueeze(0))
                tmp_output_anchor_label_record.append(output_label.detach().cpu().numpy())
                if anchor_idx == 0:
                    anchor_pair_right = torch.cat([output_anchor.unsqueeze(-1), tmp_keypoint_sequence[processed_anchor_pos[anchor_idx+1]:processed_anchor_pos[anchor_idx+1]+1].unsqueeze(-1)], dim=-1)
                    transition_right = processed_transition[anchor_idx]
                    transition_t = processed_duration[anchor_idx].unsqueeze(-1).unsqueeze(-1)
                    transition_right = transition_right.unsqueeze(0)
                    z = torch.randn(1, args.z_dim).cuda()
                    anchor_pair_right = anchor_pair_right.permute(0, 2, 1, 3).contiguous() # [1, 3, 24, 2]
                    with torch.no_grad():
                        output_trans = trans_generator.decode(z, anchor_pair_right, transition_t, global_feature)
                    output_trans_for_opt = output_trans.clone().detach()
                    output_trans = output_trans[..., 0]
                    output_trans = torch.cat([torch.zeros_like(output_trans[:, :, :1]), output_trans], dim=2)
                    output_trans = output_trans.permute(0, 2, 1).contiguous()
                    output_trans = output_trans.reshape(output_trans.shape[0], output_trans.shape[1], 3, 3)
                    opt_output = trans_optimizer.fitting(output_trans_for_opt, anchor_pair_right, transition_t)
                    opt_output_trans = opt_output["trans_param_opt_xyz"]
                    anchor_pair_right = anchor_pair_right.permute(0, 2, 1, 3).contiguous()
                    output_trans = torch.cat([output_trans, anchor_pair_right[..., 0:1]], dim=-1)
                    right_keypoint_sequence = transform_torch(opt_output_trans, transition_t)
                    refined_anchor_keypoint_sequence[processed_anchor_pos[anchor_idx]:processed_anchor_pos[anchor_idx+1]] = right_keypoint_sequence
                elif anchor_idx == len(processed_anchor_pos) - 1:
                    anchor_pair_left = torch.cat([tmp_keypoint_sequence[processed_anchor_pos[anchor_idx-1]:processed_anchor_pos[anchor_idx-1]+1].unsqueeze(-1), output_anchor.unsqueeze(-1)], dim=-1)
                    transition_left = processed_transition[anchor_idx-1]
                    transition_t = processed_duration[anchor_idx-1].unsqueeze(-1).unsqueeze(-1)
                    transition_left = transition_left.unsqueeze(0)
                    z = torch.randn(1, args.z_dim).cuda()
                    anchor_pair_left = anchor_pair_left.permute(0, 2, 1, 3).contiguous()
                    with torch.no_grad():
                        output_trans = trans_generator.decode(z, anchor_pair_left, transition_t, global_feature)
                    output_trans_for_opt = output_trans.clone().detach()
                    output_trans = output_trans[..., 0]
                    output_trans = torch.cat([torch.zeros_like(output_trans[:, :, :1]), output_trans], dim=2)
                    output_trans = output_trans.permute(0, 2, 1).contiguous()
                    output_trans = output_trans.reshape(output_trans.shape[0], output_trans.shape[1], 3, 3)
                    opt_output = trans_optimizer.fitting(output_trans_for_opt, anchor_pair_left, transition_t)
                    opt_output_trans = opt_output["trans_param_opt_xyz"]
                    anchor_pair_left = anchor_pair_left.permute(0, 2, 1, 3).contiguous()
                    output_trans = torch.cat([output_trans, anchor_pair_left[..., 0:1]], dim=-1)
                    left_keypoint_sequence = transform_torch(opt_output_trans, transition_t)
                    refined_anchor_keypoint_sequence[processed_anchor_pos[anchor_idx-1]:processed_anchor_pos[anchor_idx]] = left_keypoint_sequence
                else:
                    anchor_pair_left = torch.cat([tmp_keypoint_sequence[processed_anchor_pos[anchor_idx-1]:processed_anchor_pos[anchor_idx-1]+1].unsqueeze(-1), output_anchor.unsqueeze(-1)], dim=-1)
                    transition_left = processed_transition[anchor_idx-1]
                    transition_t_left = processed_duration[anchor_idx-1].unsqueeze(-1).unsqueeze(-1)
                    transition_left = transition_left.unsqueeze(0)
                    z = torch.randn(1, args.z_dim).cuda()
                    anchor_pair_left = anchor_pair_left.permute(0, 2, 1, 3).contiguous()
                    with torch.no_grad():
                        output_trans_left = trans_generator.decode(z, anchor_pair_left, transition_t_left, global_feature)
                    output_trans_left_for_opt = output_trans_left.clone().detach()
                    output_trans_left = output_trans_left[..., 0]
                    output_trans_left = torch.cat([torch.zeros_like(output_trans_left[:, :, :1]), output_trans_left], dim=2)
                    output_trans_left = output_trans_left.permute(0, 2, 1).contiguous()
                    output_trans_left = output_trans_left.reshape(output_trans_left.shape[0], output_trans_left.shape[1], 3, 3)
                    opt_output = trans_optimizer.fitting(output_trans_left_for_opt, anchor_pair_left, transition_t_left)
                    opt_output_trans_left = opt_output["trans_param_opt_xyz"]
                    anchor_pair_left = anchor_pair_left.permute(0, 2, 1, 3).contiguous()
                    output_trans_left = torch.cat([output_trans_left, anchor_pair_left[..., 0:1]], dim=-1)
                    left_keypoint_sequence = transform_torch(opt_output_trans_left, transition_t_left)
                    anchor_pair_right = torch.cat([output_anchor.unsqueeze(-1), tmp_keypoint_sequence[processed_anchor_pos[anchor_idx+1]:processed_anchor_pos[anchor_idx+1]+1].unsqueeze(-1)], dim=-1)
                    transition_right = processed_transition[anchor_idx]
                    transition_t_right = processed_duration[anchor_idx].unsqueeze(-1).unsqueeze(-1)
                    transition_right = transition_right.unsqueeze(0)
                    anchor_pair_right = anchor_pair_right.permute(0, 2, 1, 3).contiguous()
                    with torch.no_grad():
                        output_trans_right = trans_generator.decode(z, anchor_pair_right, transition_t_right, global_feature)
                    output_trans_right_for_opt = output_trans_right.clone().detach()
                    output_trans_right = output_trans_right[..., 0]
                    output_trans_right = torch.cat([torch.zeros_like(output_trans_right[:, :, :1]), output_trans_right], dim=2)
                    output_trans_right = output_trans_right.permute(0, 2, 1).contiguous()
                    output_trans_right = output_trans_right.reshape(output_trans_right.shape[0], output_trans_right.shape[1], 3, 3)
                    opt_output = trans_optimizer.fitting(output_trans_right_for_opt, anchor_pair_right, transition_t_right)
                    opt_output_trans_right = opt_output["trans_param_opt_xyz"]
                    anchor_pair_right = anchor_pair_right.permute(0, 2, 1, 3).contiguous()
                    output_trans_right = torch.cat([output_trans_right, anchor_pair_right[..., 0:1]], dim=-1)
                    right_keypoint_sequence = transform_torch(opt_output_trans_right, transition_t_right)
                    refined_anchor_keypoint_sequence[processed_anchor_pos[anchor_idx-1]:processed_anchor_pos[anchor_idx]] = left_keypoint_sequence
                    refined_anchor_keypoint_sequence[processed_anchor_pos[anchor_idx]:processed_anchor_pos[anchor_idx+1]] = right_keypoint_sequence
            visualize_single_video(os.path.join(save_dir, "visualization", "refined_anchor_keypoint_sequence"), limbs, refined_anchor_keypoint_sequence.detach().cpu().numpy(), batch_idx)
            tmp_output_anchor_label_record = np.array(tmp_output_anchor_label_record)
            output_anchor_label_record.append(tmp_output_anchor_label_record)
            output_refined_anchor_keypoint_sequence_record.append(refined_anchor_keypoint_sequence.detach().cpu().numpy())
            keypoint_sequence_slices = []
            keypoint_sequence_slices_len = []
            for j in range(len(processed_anchor_pos) - 1):
                tmp_keypoint_sequence_slice = refined_anchor_keypoint_sequence[processed_anchor_pos[j]:processed_anchor_pos[j+1]+1]
                keypoint_sequence_slices.append(tmp_keypoint_sequence_slice)
                keypoint_sequence_slices_len.append(tmp_keypoint_sequence_slice.shape[0])
            keypoint_sequence_slices_pad = [torch.cat([item, item[-1].unsqueeze(0).repeat(max(keypoint_sequence_slices_len) - item.shape[0], 1, 1)], dim=0) for item in keypoint_sequence_slices]
            keypoint_sequence_slices_pad = torch.stack(keypoint_sequence_slices_pad, dim=0)
            print("keypoint_sequence_slices_pad:", keypoint_sequence_slices_pad.shape)

            with torch.no_grad():
                output_trans, refined_keypoint_sequence = aprn(refined_anchor_keypoint_sequence.unsqueeze(0), keypoint_sequence_slices_pad)
            refined_keypoint_sequence = refined_keypoint_sequence[0]
            output_refined_keypoint_sequence_record.append(refined_keypoint_sequence.detach().cpu().numpy())
            visualize_single_video(os.path.join(save_dir, "visualization", "refined_keypoint_sequence"), limbs, refined_keypoint_sequence.detach().cpu().numpy(), batch_idx)
            print("==> Refinement Done!")
            # TBD: Optimization
    # Save output
    output_recognition_record = np.concatenate(output_recognition_record, axis=0)
    output_refined_anchor_keypoint_sequence_record = np.array(output_refined_anchor_keypoint_sequence_record)
    output_refined_keypoint_sequence_record = np.array(output_refined_keypoint_sequence_record)
    gt_keypoint_sequence = np.array(gt_keypoint_sequence)
    global_position_record = np.array(global_position_record)
    gt_keypoint_sequence_global = gt_keypoint_sequence.copy()
    gt_keypoint_sequence_global += global_position_record
    output_refined_anchor_keypoint_sequence_global = output_refined_anchor_keypoint_sequence_record.copy()
    output_refined_anchor_keypoint_sequence_global += global_position_record
    output_refined_keypoint_sequence_global = output_refined_keypoint_sequence_record.copy()
    output_refined_keypoint_sequence_global += global_position_record
    saved_output = {}
    saved_output["output_recognition_record"] = output_recognition_record
    saved_output["output_unsuitable_anchor_pos_record"] = output_unsuitable_anchor_pos_record
    saved_output["output_anchor_label_record"] = output_anchor_label_record
    saved_output["output_refined_anchor_keypoint_sequence_record"] = output_refined_anchor_keypoint_sequence_record
    saved_output["output_refined_keypoint_sequence_record"] = output_refined_keypoint_sequence_record
    saved_output["global_position_record"] = global_position_record
    saved_output["gt_keypoint_sequence"] = gt_keypoint_sequence
    saved_output["gt_keypoint_sequence_global"] = gt_keypoint_sequence_global
    saved_output["output_refined_anchor_keypoint_sequence_global"] = output_refined_anchor_keypoint_sequence_global
    saved_output["output_refined_keypoint_sequence_global"] = output_refined_keypoint_sequence_global
    saved_output["gt_keypoint_sequence_for_eval"] = gt_keypoint_sequence_for_eval.detach().cpu().numpy()
    with open(os.path.join(save_dir, "saved_output.pkl"), "wb") as f:
        pickle.dump(saved_output, f)
        
    output_recognition_record = output_recognition_record.reshape(-1, 2)
    anchor_pos = output_recognition_record[:, 0]
    log_likelihood = output_recognition_record[:, 1]
    print("log_likelihood:", np.mean(log_likelihood), np.std(log_likelihood))
    print("anchor_log_likelihood:", np.mean(log_likelihood[anchor_pos == 1]), np.std(log_likelihood[anchor_pos == 1]))
    print(log_likelihood[anchor_pos == 1][:20])
    print("keypoint_log_likelihood:", np.mean(log_likelihood[anchor_pos == 0]), np.std(log_likelihood[anchor_pos == 0]))
    print(log_likelihood[anchor_pos == 0][:20])
    # Evaluation
    gt_activations = []
    generated_activations = []
    refined_anchor_activations = []
    refined_activations = []
    for i in range(gt_keypoint_sequence_for_eval.shape[0]):
        input_gt_keypoint_sequence = gt_keypoint_sequence_for_eval[i:i+1]
        input_gt_keypoint_sequence = input_gt_keypoint_sequence.permute(0, 2, 3, 1).contiguous()
        input_gt_keypoint_sequence -= input_gt_keypoint_sequence[:, 0:1, :, 0:1]
        lengths = torch.tensor([input_gt_keypoint_sequence.shape[-1]]).cuda()
        gt_activations.append(gru_classifier_for_fid(input_gt_keypoint_sequence, lengths=lengths))
    gt_keypoint_sequence_global = torch.from_numpy(gt_keypoint_sequence_global).float().cuda()
    gt_keypoint_sequence_global /= 100
    for i in range(gt_keypoint_sequence.shape[0]):
        input_keypoint_sequence = gt_keypoint_sequence_global[i:i+1]
        input_keypoint_sequence = input_keypoint_sequence.permute(0, 2, 3, 1).contiguous()
        input_keypoint_sequence -= input_keypoint_sequence[:, 0:1, :, 0:1]
        lengths = torch.tensor([input_keypoint_sequence.shape[-1]]).cuda()
        generated_activations.append(gru_classifier_for_fid(input_keypoint_sequence, lengths=lengths))
    output_refined_anchor_keypoint_sequence_global = torch.from_numpy(output_refined_anchor_keypoint_sequence_global).float().cuda()
    output_refined_anchor_keypoint_sequence_global /= 100
    for i in range(output_refined_anchor_keypoint_sequence_global.shape[0]):
        input_keypoint_sequence = output_refined_anchor_keypoint_sequence_global[i:i+1]
        input_keypoint_sequence = input_keypoint_sequence.permute(0, 2, 3, 1).contiguous()
        input_keypoint_sequence -= input_keypoint_sequence[:, 0:1, :, 0:1]
        lengths = torch.tensor([input_keypoint_sequence.shape[-1]]).cuda()
        refined_anchor_activations.append(gru_classifier_for_fid(input_keypoint_sequence, lengths=lengths))
    output_refined_keypoint_sequence_global = torch.from_numpy(output_refined_keypoint_sequence_global).float().cuda()
    output_refined_keypoint_sequence_global /= 100
    for i in range(output_refined_keypoint_sequence_global.shape[0]):
        input_keypoint_sequence = output_refined_keypoint_sequence_global[i:i+1]
        input_keypoint_sequence = input_keypoint_sequence.permute(0, 2, 3, 1).contiguous()
        input_keypoint_sequence -= input_keypoint_sequence[:, 0:1, :, 0:1]
        lengths = torch.tensor([input_keypoint_sequence.shape[-1]]).cuda()
        refined_activations.append(gru_classifier_for_fid(input_keypoint_sequence, lengths=lengths))
    
    gt_activations = torch.cat(gt_activations, dim=0).cuda()
    gt_labels = torch.ones(gt_activations.shape[0]) * int(args.concept)
    gt_labels = gt_labels.cuda()
    generated_activations = torch.cat(generated_activations, dim=0).cuda()
    generated_labels = torch.ones(generated_activations.shape[0]) * int(args.concept)
    generated_labels = generated_labels.cuda()
    refined_anchor_activations = torch.cat(refined_anchor_activations, dim=0).cuda()
    refined_anchor_labels = torch.ones(refined_anchor_activations.shape[0]) * int(args.concept)
    refined_anchor_labels = refined_anchor_labels.cuda()
    refined_activations = torch.cat(refined_activations, dim=0).cuda()
    refined_labels = torch.ones(refined_activations.shape[0]) * int(args.concept)
    refined_labels = refined_labels.cuda()
    
    gt_diversity, gt_multimodality = calculate_diversity_multimodality(gt_activations, gt_labels, int(args.concept), num_classes)
    generated_diversity, generated_multimodality = calculate_diversity_multimodality(generated_activations, generated_labels, int(args.concept), num_classes)
    refined_anchor_diversity, refined_anchor_multimodality = calculate_diversity_multimodality(refined_anchor_activations, refined_anchor_labels, int(args.concept), num_classes)
    refined_diversity, refined_multimodality = calculate_diversity_multimodality(refined_activations, refined_labels, int(args.concept), num_classes)
    
    gt_stats = calculate_activation_statistics(gt_activations)
    generated_stats = calculate_activation_statistics(generated_activations)
    refined_anchor_stats = calculate_activation_statistics(refined_anchor_activations)
    refined_stats = calculate_activation_statistics(refined_activations)
    
    fid_generated = float(calculate_fid(gt_stats, generated_stats))
    fid_refined_anchor = float(calculate_fid(gt_stats, refined_anchor_stats))
    fid_refined = float(calculate_fid(gt_stats, refined_stats))
    end_time = time.time()
    
    print("Evaluation Results: GT Diversity: {:.4f}, GT Multimodality: {:.4f}, Generated Diversity: {:.4f}, Generated Multimodality: {:.4f}, Refined Anchor Diversity: {:.4f}, Refined Anchor Multimodality: {:.4f}, Refined Diversity: {:.4f}, Refined Multimodality: {:.4f}, FID Generated: {:.4f}, FID Refined Anchor: {:.4f}, FID Refined: {:.4f}, Time: {:.4f}".format(
        gt_diversity, gt_multimodality, generated_diversity, generated_multimodality, refined_anchor_diversity, refined_anchor_multimodality, refined_diversity, refined_multimodality, fid_generated, fid_refined_anchor, fid_refined, end_time - start_time
    ))
        
    evaluation_results = {}
    evaluation_results["gt_diversity"] = gt_diversity
    evaluation_results["gt_multimodality"] = gt_multimodality
    evaluation_results["generated_diversity"] = generated_diversity
    evaluation_results["generated_multimodality"] = generated_multimodality
    evaluation_results["refined_anchor_diversity"] = refined_anchor_diversity
    evaluation_results["refined_anchor_multimodality"] = refined_anchor_multimodality
    evaluation_results["refined_diversity"] = refined_diversity
    evaluation_results["refined_multimodality"] = refined_multimodality
    evaluation_results["fid_generated"] = fid_generated
    evaluation_results["fid_refined_anchor"] = fid_refined_anchor
    evaluation_results["fid_refined"] = fid_refined
    with open(os.path.join(save_dir, "evaluation_results.pkl"), "wb") as f:
        pickle.dump(evaluation_results, f)

    
def _aprn_ar_train_batch_losses(args, model, batch):
    """Run ``model`` on one loader batch and compute its APRN losses.

    Each loss term is evaluated exactly once; the resulting tensor is used
    both for the combined loss and for the logged scalar.

    Args:
        args: Namespace providing ``aprn_trans_loss_weight``.
        model: Callable returning ``(output_trans, output_keypoint_sequence)``
            when given ``(full_keypoint_sequence, keypoint_sequence)``.
        batch: Mapping with the keys ``"full_keypoint_sequence"``,
            ``"keypoint_sequence"``, ``"transition"`` and ``"duration"``.

    Returns:
        tuple: ``(loss, recon_value, trans_values)`` where ``loss`` is the
        combined loss tensor (reconstruction + weighted transition terms),
        ``recon_value`` is the reconstruction loss as a Python number and
        ``trans_values`` lists one Python number per transition.
    """
    full_keypoint_sequence = batch["full_keypoint_sequence"].cuda()
    # Index 0 drops the leading dimension of these three tensors
    # (presumably a loader batch dimension of size 1 -- TODO confirm).
    keypoint_sequence = batch["keypoint_sequence"].cuda()[0]
    transition = batch["transition"].cuda()[0]
    duration = batch["duration"].cuda()[0].unsqueeze(-1)

    output_trans, output_keypoint_sequence = model(full_keypoint_sequence, keypoint_sequence)
    # Unflatten the last axis of both transition tensors into 3 x 4.
    transition = transition.reshape(transition.shape[0], transition.shape[1], 3, 4)
    output_trans = output_trans.reshape(output_trans.shape[0], output_trans.shape[1], 3, 4)

    # NOTE: an additional anchor-frame term weighted by
    # args.aprn_anchor_loss_weight was disabled in the original code and is
    # intentionally not part of the loss here.
    recon_loss = reconstruction_loss_mpjpe(output_keypoint_sequence, full_keypoint_sequence)
    recon_value = recon_loss.item()

    loss = recon_loss
    trans_values = []
    for i in range(transition.shape[0]):
        trans_gt_keypoint = transform_torch(transition[i:i+1], duration[i:i+1])
        trans_out_keypoint = transform_torch(output_trans[i:i+1], duration[i:i+1])
        trans_loss = reconstruction_loss_mpjpe(trans_out_keypoint, trans_gt_keypoint)
        trans_values.append(trans_loss.item())
        # Out-of-place add so that ``recon_loss`` is never mutated.
        loss = loss + args.aprn_trans_loss_weight * trans_loss
    return loss, recon_value, trans_values


def train_aprn_ar(args, model, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict):
    """Train ``model`` for one epoch, evaluate it on ``test_loader`` and checkpoint.

    Args:
        args: Namespace providing ``aprn_trans_loss_weight`` and ``save_freq``.
        model: Network to optimise (see ``_aprn_ar_train_batch_losses`` for
            the call signature that is used).
        train_loader: Iterable of training batches; must support ``len()``.
        test_loader: Iterable of test batches; must support ``len()``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler, also stepped once per training batch.
        epoch: Zero-based epoch index (used for logging and periodic saves).
        save_dir: Directory that receives the ``.pth`` checkpoints.
        perf_dict: Dict whose ``"TRN"`` entry holds
            ``[best_test_recon_loss, best_test_trans_loss]``; updated in place.

    Returns:
        tuple: ``(loss, perf_dict)`` where ``loss`` maps
        ``train_/test_`` + ``loss``/``recon_loss``/``trans_loss`` to the
        per-batch averages of this epoch.
    """
    start_time = time.time()
    train_loss, train_recon_loss, train_trans_loss = 0.0, 0.0, 0.0
    test_loss, test_recon_loss, test_trans_loss = 0.0, 0.0, 0.0

    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        batch_loss, recon_value, trans_values = _aprn_ar_train_batch_losses(args, model, batch)
        train_recon_loss += recon_value
        for trans_value in trans_values:
            train_trans_loss += trans_value
        train_loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            batch_loss, recon_value, trans_values = _aprn_ar_train_batch_losses(args, model, batch)
            test_recon_loss += recon_value
            for trans_value in trans_values:
                test_trans_loss += trans_value
            test_loss += batch_loss.item()

    end_time = time.time()
    # All totals are averaged over the number of batches; the transition
    # total is the sum over every transition of every batch.
    train_loss /= len(train_loader)
    train_recon_loss /= len(train_loader)
    train_trans_loss /= len(train_loader)
    test_loss /= len(test_loader)
    test_recon_loss /= len(test_loader)
    test_trans_loss /= len(test_loader)
    print("Epoch: {}, Train Loss: {:.4f}, Train Recon Loss: {:.4f}, Train Trans Loss: {:.4f}, Test Loss: {:.4f}, Test Recon Loss: {:.4f}, Test Trans Loss: {:.4f}, Time: {:.4f}".format(
        epoch, train_loss, train_recon_loss, train_trans_loss, test_loss, test_recon_loss, test_trans_loss, end_time - start_time
    ))

    loss = {
        'train_loss': train_loss,
        'train_recon_loss': train_recon_loss,
        'train_trans_loss': train_trans_loss,
        'test_loss': test_loss,
        'test_recon_loss': test_recon_loss,
        'test_trans_loss': test_trans_loss,
    }

    # Keep a separate "best" checkpoint for each of the two tracked metrics.
    if test_recon_loss < perf_dict["TRN"][0]:
        perf_dict["TRN"][0] = test_recon_loss
        save_path = os.path.join(save_dir, "best_recon.pth")
        torch.save(model.state_dict(), save_path)
        print("Model saved at {}".format(save_path))

    if test_trans_loss < perf_dict["TRN"][1]:
        perf_dict["TRN"][1] = test_trans_loss
        save_path = os.path.join(save_dir, "best_trans.pth")
        torch.save(model.state_dict(), save_path)
        print("Model saved at {}".format(save_path))

    # Periodic snapshot every ``args.save_freq`` epochs (1-based count).
    if (epoch+1) % args.save_freq == 0:
        save_path = os.path.join(save_dir, "epoch_{}.pth".format(epoch+1))
        torch.save(model.state_dict(), save_path)
        print("Model saved at {}".format(save_path))

    return loss, perf_dict
    
    
def _eval_aprn_ar_accumulate(args, model, loader):
    """Sum the APRN losses of ``model`` over every batch of ``loader``.

    Each loss term is evaluated exactly once per batch / transition and the
    same tensor feeds both the combined loss and the logged scalar. The
    caller is responsible for ``model.eval()`` and ``torch.no_grad()``.

    Args:
        args: Namespace providing ``aprn_trans_loss_weight``.
        model: Callable returning ``(output_trans, output_keypoint_sequence)``
            when given ``(full_keypoint_sequence, keypoint_sequence)``.
        loader: Iterable of mappings with the keys
            ``"full_keypoint_sequence"``, ``"keypoint_sequence"``,
            ``"transition"`` and ``"duration"``.

    Returns:
        tuple: ``(total_loss, total_recon_loss, total_trans_loss)`` as
        un-normalised sums over the loader.
    """
    total_loss, total_recon_loss, total_trans_loss = 0.0, 0.0, 0.0
    for batch in loader:
        full_keypoint_sequence = batch["full_keypoint_sequence"].cuda()
        # Index 0 drops the leading dimension of these three tensors
        # (presumably a loader batch dimension of size 1 -- TODO confirm).
        keypoint_sequence = batch["keypoint_sequence"].cuda()[0]
        transition = batch["transition"].cuda()[0]
        duration = batch["duration"].cuda()[0].unsqueeze(-1)

        output_trans, output_keypoint_sequence = model(full_keypoint_sequence, keypoint_sequence)
        # Unflatten the last axis of both transition tensors into 3 x 4.
        transition = transition.reshape(transition.shape[0], transition.shape[1], 3, 4)
        output_trans = output_trans.reshape(output_trans.shape[0], output_trans.shape[1], 3, 4)

        recon_loss = reconstruction_loss_mpjpe(output_keypoint_sequence, full_keypoint_sequence)
        total_recon_loss += recon_loss.item()

        batch_loss = recon_loss
        for i in range(transition.shape[0]):
            trans_gt_keypoint = transform_torch(transition[i:i+1], duration[i:i+1])
            trans_out_keypoint = transform_torch(output_trans[i:i+1], duration[i:i+1])
            trans_loss = reconstruction_loss_mpjpe(trans_out_keypoint, trans_gt_keypoint)
            total_trans_loss += trans_loss.item()
            # Out-of-place add so that ``recon_loss`` is never mutated.
            batch_loss = batch_loss + args.aprn_trans_loss_weight * trans_loss
        total_loss += batch_loss.item()
    return total_loss, total_recon_loss, total_trans_loss


def eval_aprn_ar(args, model, train_loader, test_loader, save_dir):
    """Report the APRN losses of ``model`` on the train and test loaders.

    The model is put in eval mode and run without gradient tracking; the
    per-batch averages of the combined, reconstruction and transition losses
    are printed for both loaders.

    Args:
        args: Namespace providing ``aprn_trans_loss_weight``.
        model: Network to evaluate.
        train_loader: Iterable of training batches; must support ``len()``.
        test_loader: Iterable of test batches; must support ``len()``.
        save_dir: Not used by the visible part of this function.
    """
    start_time = time.time()

    model.eval()
    with torch.no_grad():
        train_loss, train_recon_loss, train_trans_loss = _eval_aprn_ar_accumulate(args, model, train_loader)
        test_loss, test_recon_loss, test_trans_loss = _eval_aprn_ar_accumulate(args, model, test_loader)

    end_time = time.time()
    # All totals are averaged over the number of batches; the transition
    # total is the sum over every transition of every batch.
    train_loss /= len(train_loader)
    train_recon_loss /= len(train_loader)
    train_trans_loss /= len(train_loader)
    test_loss /= len(test_loader)
    test_recon_loss /= len(test_loader)
    test_trans_loss /= len(test_loader)
    print("Evaluation Results: Train Loss: {:.4f}, Train Recon Loss: {:.4f}, Train Trans Loss: {:.4f}, Test Loss: {:.4f}, Test Recon Loss: {:.4f}, Test Trans Loss: {:.4f}, Time: {:.4f}".format(
        train_loss, train_recon_loss, train_trans_loss, test_loss, test_recon_loss, test_trans_loss, end_time - start_time
    ))
