import os
import pickle
import torch
import argparse
import random
import joblib
import time
import numpy as np

import tensorboardX
from torch import gt
from torch.utils.data import DataLoader
from sklearn.mixture import GaussianMixture
from model.model import AnchorRecognitionNetE2E, AnchorRefinementNet, AutoEncoder, AnchorRecognitionNet, TransGenerator, TransRefinementNet, TransRefinementNetAR
from model.auxiliary import ConceptDataset, collate_helper, collate_helper_trans, collate_helper_afn, collate_helper_aprn, collate_helper_refine, collate_helper_aprn_ar
from utils.util import get_lr_scheduler, visualize_single_video
from train import eval_sequence_refinement_metrics_humanact12, eval_sequence_refinement_metrics_humanact12_full, eval_sequence_refinement_metrics_uestc, eval_sequence_refinement_metrics_uestc_full, train_SequenceAE, eval_SequenceAE, train_arn, eval_arn, train_trans_generator, eval_trans_generator, train_afn, eval_afn, train_afn_classifier, eval_afn_classifier, train_afn_together, train_aprn, eval_aprn, eval_sequence_refinement, train_aprn_ar, eval_aprn_ar, eval_sequence_refinement_ar

# ---------------------------------------------------------------------------
# Command-line configuration. Parsed once at import time; main() reads the
# resulting `args` namespace and also attaches per-concept fields to it
# (args.concept, args.action_name, args.n_clusters, args.save_dir_base, ...).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Data / output locations and run selection.
parser.add_argument('--data_dir', type=str, default='data')
parser.add_argument('--generate_data_dir', type=str, default='data_generate')
parser.add_argument('--model_name', type=str, default='actor')
parser.add_argument('--data_name', type=str, default='saved_data.pkl')
parser.add_argument('--save_model_dir', type=str, default='saved_models')
parser.add_argument('--save_vis_dir', type=str, default='saved_models')
parser.add_argument('--save_refine_dir', type=str, default='saved_best_refine_results')
parser.add_argument('-d', '--dataset', type=str, default='humanact12', help='dataset name')
parser.add_argument('--save_anchor_label_name', type=str, default='euclidean_labels_manual.npy')
# main() loops over --concept_list and, for each concept, over --operations
# without a None check, so both must be supplied. Marking them required turns
# a late "'NoneType' object is not iterable" crash into an argparse usage error.
parser.add_argument('--concept_list', type=str, nargs='+', required=True)
# When set, each stage in main() trains before evaluating; otherwise it only
# loads the saved checkpoint(s) and evaluates.
parser.add_argument("-train", "--train", action="store_true")
# Pipeline stages run by main(), in the order given. Visible stages:
# 0 = AutoEncoder, 1 = AnchorRecognitionNet (ARN), 2 = GMM over ARN
# embeddings, 3 = TransGenerator, 4/5/6 = AnchorRefinementNet (AFN) stages.
parser.add_argument('-o', '--operations', type=int, nargs='+', required=True)
# Seeds `random`, which drives the 90/10 train/test split on first run.
parser.add_argument('--seed', type=int, default=10)
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--base_lr', type=float, default=0.0005)
parser.add_argument('-bs', '--batch_size', type=int, default=16)

# transformer parameters
parser.add_argument('--transformer_dim', type=int, default=48)
parser.add_argument('--transformer_depth', type=int, default=2)
parser.add_argument('--transformer_heads', type=int, default=4)
parser.add_argument('--transformer_mlp_dim', type=int, default=256)

# gcn parameters
# NOTE(review): help says 'gcn or gat' but the default is 'pgbig' — confirm
# the accepted values.
parser.add_argument('--gcn_type', type=str, default='pgbig', help='gcn or gat')
parser.add_argument('--num_stage', type=int, default=3, help='number of stages')
# NOTE(review): --d_model and --kernel_size reuse the 'past frame number' help
# text of --input_n; this looks copy-pasted — confirm their real meaning.
parser.add_argument('--d_model', type=int, default=20, help='past frame number')
parser.add_argument('--dct_n', type=int, default=35)
parser.add_argument('--node_n', type=int, default=25)
parser.add_argument('--kernel_size', type=int, default=10, help='past frame number')
parser.add_argument('--drop_out', type=float, default=0.1, help='drop out probability')
parser.add_argument('--encoder_n', type=int, default=1, help='encoder layer num')
parser.add_argument('--decoder_n', type=int, default=2, help='decoder layer num')
parser.add_argument('--input_n', type=int, default=10, help='past frame number')
parser.add_argument('--output_n', type=int, default=25, help='future frame number')

# gaussian mixture (stage 2: sklearn GaussianMixture n_components)
parser.add_argument('--n_components', type=int, default=2)

# AnchorRecognitionNet
parser.add_argument('--arn_epochs', type=int, default=10000)
parser.add_argument('--arn_base_lr', type=float, default=0.0005)
parser.add_argument('--arn_transformer_dim', type=int, default=96)
parser.add_argument('--arn_transformer_depth', type=int, default=2)
parser.add_argument('--arn_transformer_heads', type=int, default=4)
parser.add_argument('--arn_transformer_mlp_dim', type=int, default=128)
parser.add_argument('--arn_d_model', type=int, default=48)
parser.add_argument('--arn_threshold', type=int, default=-50)

# trans_generator
parser.add_argument('--trans_epochs', type=int, default=10000)
parser.add_argument('--trans_base_lr', type=float, default=0.0002)
# NOTE(review): the TransGenerator train loader visible in main() is built with
# --batch_size, not this value — confirm which one is intended.
parser.add_argument('--trans_batch_size', type=int, default=32)
parser.add_argument('--trans_embedding_dim', type=int, default=256)
parser.add_argument('--first_anchor_embedding_dim', type=int, default=64)
parser.add_argument('--second_anchor_embedding_dim', type=int, default=64)
parser.add_argument('--duration_embedding_dim', type=int, default=24)
parser.add_argument('--trans_middle_embedding_dim', type=int, default=512)
parser.add_argument('--global_feature_dim', type=int, default=64)
parser.add_argument('--z_dim', type=int, default=20)
parser.add_argument('--load_trans_checkpoint', type=str, default='20000')
parser.add_argument('--trans_add_global_feature', action="store_true")
parser.add_argument('--trans_kl_factor', type=float, default=0.001)
parser.add_argument('--trans_loss_weight', type=float, default=1.0)

# AnchorRefinementNet
parser.add_argument('--afn_epochs', type=int, default=20000)
parser.add_argument('--afn_base_lr', type=float, default=0.0002)
parser.add_argument('--afn_batch_size', type=int, default=32)
parser.add_argument('--label_dim', type=int, default=8)
parser.add_argument('--decoder_dim', type=int, default=128)
parser.add_argument('--anchor_embedding_dim', type=int, default=128)
parser.add_argument('--afn_middle_embedding_dim', type=int, default=256)
parser.add_argument('--afn_loss_weight', type=float, default=1.0)

# TransRefinementNet
parser.add_argument('--aprn_epochs', type=int, default=20000)
parser.add_argument('--aprn_base_lr', type=float, default=0.0002)
parser.add_argument('--aprn_batch_size', type=int, default=32)
parser.add_argument('--aprn_middle_embedding_dim', type=int, default=256)
parser.add_argument('--aprn_trans_loss_weight', type=float, default=1.0)
parser.add_argument('--aprn_anchor_loss_weight', type=float, default=1.0)

# Transition Optimization
parser.add_argument('--optim_iterations', type=int, default=100)
parser.add_argument('--optim_lr', type=float, default=0.02)

# other parameters
parser.add_argument('--middle_embedding_dim', type=int, default=256)
parser.add_argument('--likelihood_threshold', type=float, default=0.5)
parser.add_argument('--z_factor', type=float, default=1.0)
parser.add_argument('--topK', type=int, default=5)
parser.add_argument('--save_freq', type=int, default=1000)
parser.add_argument('--num_samples', type=int, default=20)
parser.add_argument('--unsuitable_type', type=str, default="arn")
parser.add_argument('--experiment_name', type=str, default='test')

args = parser.parse_args()
# Learning-rate schedule spec that main() passes to
# utils.util.get_lr_scheduler() for every training stage
# (presumably polynomial decay with the given power — confirm in utils.util).
policy = {"name": "Poly", "power": 0.95}
args.policy = policy

# Skeleton edges as (joint_index, joint_index) pairs, one list per dataset.
# main() selects the list for args.dataset through `limbs_dict` and passes it
# on as `limbs` (e.g. to eval_SequenceAE).

# 25-joint layout; appears to follow the NTU RGB+D bone list — confirm.
babel_limbs = [
    (0, 1), (1, 20), (2, 20), (3, 2), (4, 20),
    (5, 4), (6, 5), (7, 6), (8, 20), (9, 8),
    (10, 9), (11, 10), (12, 0), (13, 12), (14, 13),
    (15, 14), (16, 0), (17, 16), (18, 17), (19, 18),
    (21, 22), (22, 7), (23, 24), (24, 11),
]

# 24-joint layout: every joint 1..23 appears exactly once as a child, so the
# pairs form a single tree rooted at joint 0 (presumably the SMPL hierarchy).
humanact12_limbs = [
    (0, 1), (1, 4), (4, 7), (7, 10),
    (0, 2), (2, 5), (5, 8), (8, 11),
    (0, 3), (3, 6), (6, 9),
    (9, 12), (12, 15),
    (9, 13), (13, 16), (16, 18), (18, 20), (20, 22),
    (9, 14), (14, 17), (17, 19), (19, 21), (21, 23),
]

# 18-joint layout: every joint 1..17 appears exactly once as a child.
uestc_limbs = [
    (0, 1),
    (0, 9), (9, 10), (10, 11), (11, 16),
    (0, 12), (12, 13), (13, 14), (14, 15),
    (1, 2), (2, 3), (3, 4),
    (1, 5), (5, 6), (6, 7),
    (1, 8), (8, 17),
]

limbs_dict = dict(
    babel=babel_limbs,
    humanact12=humanact12_limbs,
    uestc=uestc_limbs,
)
# HumanAct12 action names keyed by class index as a string ("0".."11").
# main() maps a concept through this table to get the per-action data / model
# sub-directory name, and stores int(concept) as args.action_idx.
humanact12_match = {
    str(idx): name
    for idx, name in enumerate((
        "warm_up", "walk", "run", "jump", "drink", "lift_dumbbell",
        "sit", "eat", "turn_steer_wheel", "phone", "boxing", "throw",
    ))
}

# The 40 UESTC action class names, alphabetically ordered.
# NOTE(review): main() uses the UESTC concept string directly as the action
# name; this list is presumably for tooling that needs the full class set.
uestc_list = [
    'alternate-knee-lifting',
    'arm-circling',
    'bent-over-twist',
    'deltoid-muscle-stretching',
    'dumbbell-one-arm-shoulder-pressing',
    'dumbbell-shrugging',
    'dumbbell-side-bend',
    'elbow-circling',
    'forward-lunging',
    'front-raising',
    'head-anticlockwise-circling',
    'high-knees-running',
    'jumping-jack',
    'knee-circling',
    'knee-to-chest',
    'left-kicking',
    'left-lunging',
    'left-stretching',
    'marking-time-and-knee-lifting',
    'overhead-stretching',
    'pinching-back',
    'pulling-chest-expanders',
    'punching',
    'punching-and-knee-lifting',
    'raising-hand-and-jumping',
    'rope-skipping',
    'rotation-clapping',
    'shoulder-abduction',
    'shoulder-raising',
    'single-dumbbell-raising',
    'single-leg-lateral-hopping',
    'spinal-stretching',
    'squatting',
    'standing-gastrocnemius-calf',
    'standing-opposite-elbow-to-knee-crunch',
    'standing-rotation',
    'standing-toe-touches',
    'straight-forward-flexion',
    'upper-back-stretching',
    'wrist-circling',
]

def main(args):
    random.seed(args.seed)
    for concept in args.concept_list:
        args.concept = concept
        if args.dataset == "humanact12":
            action_name = humanact12_match[concept]
            args.action_idx = int(concept)
        elif args.dataset == "uestc":
            action_name = concept
        args.action_name = action_name
        limbs = limbs_dict[args.dataset]
        data_dir = os.path.join(args.data_dir, args.dataset, action_name)
        anchor_labels = np.load(os.path.join(data_dir, args.save_anchor_label_name), allow_pickle=True)    
        with open(os.path.join(data_dir, args.data_name), "rb") as f:
            saved_data = pickle.load(f)
        keypoint_sequences = saved_data["keypoint_sequence"]
        keypoint_sequences_norm = saved_data["keypoint_sequence_norm"]
        rotated_keypoint_sequences = saved_data["rotated_keypoint_sequence"]
        anchor_pos_labels = saved_data["anchor_pos_label"]
        transes = saved_data["prims"]
        num_cluster = len(np.unique(anchor_labels))
        args.n_clusters = num_cluster
        anchor_labels_reshape = []
        anchor_idx = 0
        for i in range(len(anchor_pos_labels)):
            anchor_num = len(np.where(anchor_pos_labels[i] == 1)[0])
            anchor_labels_reshape.append(anchor_labels[anchor_idx: anchor_idx + anchor_num])
            anchor_idx += anchor_num
        
        print("==> Load preprocessed data, num_cluster: ", num_cluster, "sequence number:", len(keypoint_sequences), "sequence norm number:", len(keypoint_sequences_norm), "anchor_pos_labels:", len(anchor_pos_labels), "anchor_labels:", len(anchor_labels))
            
        save_dir_base = os.path.join(args.save_model_dir, args.dataset, action_name)
        args.save_dir_base = save_dir_base
        if not os.path.exists(os.path.join(save_dir_base)):
            os.makedirs(save_dir_base)
        if not os.path.exists(os.path.join(save_dir_base, "train_idx.npy")):
            train_idx = random.sample(range(len(keypoint_sequences)), int(len(keypoint_sequences)*0.9))
            test_idx = list(set(range(len(keypoint_sequences))) - set(train_idx))
            np.save(os.path.join(save_dir_base, "train_idx.npy"), train_idx)
            np.save(os.path.join(save_dir_base, "test_idx.npy"), test_idx)
        else:
            print("Load idx...")
            train_idx = np.load(os.path.join(save_dir_base, "train_idx.npy"))
            test_idx = np.load(os.path.join(save_dir_base, "test_idx.npy"))
            
        perf_dict = {
            "AE": 1000.0,
            "ARN": [0.0, 0.0],
            "TransGenerator": [1000.0, 1000.0],
            "AFN": [0.0, 1000.0],
            "TRN": [1000.0, 1000.0]
        }
        
        for op in args.operations:
            if op == 0:
                # Prepare data
                train_data = []
                test_data = []
                for idx in train_idx:
                    anchor_pos = np.where(anchor_pos_labels[idx] == 1)[0]
                    if len(anchor_pos) == 1:
                        continue
                    tmp = {}
                    tmp_keypoint_sequence = keypoint_sequences_norm[idx].copy()
                    tmp["complete_keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                    tmp["keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                    tmp["anchor_pos_label"] = torch.from_numpy(anchor_pos_labels[idx]).float()
                    tmp["transition"] = torch.from_numpy(transes[idx]).float()
                    train_data.append(tmp)
                    for i in range(len(anchor_pos)):
                        tmp = {}
                        tmp_keypoint_sequence = keypoint_sequences_norm[idx].copy()
                        tmp["complete_keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                        if i == 0:
                            tmp_keypoint_sequence[:anchor_pos[i+1]-1] = np.zeros_like(tmp_keypoint_sequence[:anchor_pos[i+1]-1])
                        elif i == len(anchor_pos) - 1:
                            tmp_keypoint_sequence[anchor_pos[i-1]+1:] = np.zeros_like(tmp_keypoint_sequence[anchor_pos[i-1]+1:])
                        else:
                            tmp_keypoint_sequence[anchor_pos[i-1]+1:anchor_pos[i+1]-1] = np.zeros_like(tmp_keypoint_sequence[anchor_pos[i-1]+1:anchor_pos[i+1]-1])
                        tmp["keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                        tmp["anchor_pos_label"] = torch.from_numpy(anchor_pos_labels[idx]).float()
                        tmp["transition"] = torch.from_numpy(transes[idx]).float()
                        train_data.append(tmp)
                for idx in test_idx:
                    anchor_pos = np.where(anchor_pos_labels[idx] == 1)[0]
                    if len(anchor_pos) == 1:
                        continue
                    tmp = {}
                    tmp_keypoint_sequence = keypoint_sequences_norm[idx].copy()
                    tmp["complete_keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                    tmp["keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                    tmp["anchor_pos_label"] = torch.from_numpy(anchor_pos_labels[idx]).float()
                    tmp["transition"] = torch.from_numpy(transes[idx]).float()
                    test_data.append(tmp)
                    for i in range(len(anchor_pos)):
                        tmp = {}
                        tmp_keypoint_sequence = keypoint_sequences_norm[idx].copy()
                        tmp["complete_keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                        if i == 0:
                            tmp_keypoint_sequence[:anchor_pos[i+1]-1] = np.zeros_like(tmp_keypoint_sequence[:anchor_pos[i+1]-1])
                        elif i == len(anchor_pos) - 1:
                            tmp_keypoint_sequence[anchor_pos[i-1]+1:] = np.zeros_like(tmp_keypoint_sequence[anchor_pos[i-1]+1:])
                        else:
                            tmp_keypoint_sequence[anchor_pos[i-1]+1:anchor_pos[i+1]-1] = np.zeros_like(tmp_keypoint_sequence[anchor_pos[i-1]+1:anchor_pos[i+1]-1])
                        tmp["keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                        tmp["anchor_pos_label"] = torch.from_numpy(anchor_pos_labels[idx]).float()
                        tmp["transition"] = torch.from_numpy(transes[idx]).float()
                        test_data.append(tmp)
                train_dataset = ConceptDataset(train_data)
                test_dataset = ConceptDataset(test_data)
                print("==> Load data for AutoEncoder training and testing: train data number:", len(train_dataset), "test data number:", len(test_dataset))
                train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_helper)
                test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_helper)
                batch_num = len(train_loader)
                
                save_dir = os.path.join(save_dir_base, "AE")
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                autoencoder = AutoEncoder(args).cuda()
                if args.train:
                    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=args.base_lr)
                    scheduler = get_lr_scheduler(args.policy, optimizer, max_iter=args.epochs*batch_num)
                    print("==> Start training AE for concept ", action_name + "...")
                    exp_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                    loss_writer = tensorboardX.SummaryWriter(os.path.join(save_dir, "loss", exp_time))
                    for epoch in range(args.epochs):
                        loss, perf_dict = train_SequenceAE(args, autoencoder, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict)
                        for key in loss.keys():
                            loss_writer.add_scalar(key, loss[key], epoch)
                    loss_writer.close()
                print("==> Start evaluating AE for concept", action_name + "...")
                checkpoint = torch.load(os.path.join(save_dir, "best.pth"))
                autoencoder.load_state_dict(checkpoint)
                eval_SequenceAE(args, autoencoder, test_loader, save_dir, limbs=limbs)
            
            elif op in [1, 2]:
                # Anchor Recognition Net
                train_data = []
                test_data = []
                for idx in train_idx:
                    tmp = {}
                    tmp["keypoint_sequence"] = torch.from_numpy(keypoint_sequences_norm[idx]).float()
                    tmp["anchor_pos_label"] = torch.from_numpy(anchor_pos_labels[idx]).float()
                    tmp["complete_keypoint_sequence"] = torch.from_numpy(keypoint_sequences_norm[idx]).float()
                    tmp["transition"] = torch.from_numpy(transes[idx]).float()
                    train_data.append(tmp)
                for idx in test_idx:
                    tmp = {}
                    tmp["keypoint_sequence"] = torch.from_numpy(keypoint_sequences_norm[idx]).float()
                    tmp["anchor_pos_label"] = torch.from_numpy(anchor_pos_labels[idx]).float()
                    tmp["complete_keypoint_sequence"] = torch.from_numpy(keypoint_sequences_norm[idx]).float()
                    tmp["transition"] = torch.from_numpy(transes[idx]).float()
                    test_data.append(tmp)
                
                train_dataset = ConceptDataset(train_data)
                test_dataset = ConceptDataset(test_data)
                train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_helper)
                test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_helper)
                batch_num = len(train_loader)
                if op == 1:
                    save_dir = os.path.join(save_dir_base, "ARN")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir)
                    arn = AnchorRecognitionNetE2E(args, gmm=None).cuda()
                    if args.train:
                        optimizer = torch.optim.Adam(arn.parameters(), lr=args.arn_base_lr)
                        scheduler = get_lr_scheduler(args.policy, optimizer, max_iter=args.arn_epochs*batch_num)
                        print("==> Start training ARN for concept ", action_name + "...")
                        exp_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                        loss_writer = tensorboardX.SummaryWriter(os.path.join(save_dir, "loss", exp_time))
                        for epoch in range(args.arn_epochs):
                            loss, perf_dict = train_arn(args, arn, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict)
                            for key in loss.keys():
                                loss_writer.add_scalar(key, loss[key], epoch)
                        loss_writer.close()
                    print("==> Start evaluating ARN for concept", action_name + "...")
                    print("==> Load best model...")
                    checkpoint = torch.load(os.path.join(save_dir, "best_acc.pth"))
                    arn.load_state_dict(checkpoint)
                    eval_arn(arn, train_loader, test_loader, save_dir)
                    print("==> Load best model for anchor accuracy...")
                    checkpoint = torch.load(os.path.join(save_dir, "best_anchor_acc.pth"))
                    arn.load_state_dict(checkpoint)
                    eval_arn(arn, train_loader, test_loader, save_dir)
                    print("==> Load checkpoint 300 for anchor accuracy...")
                    checkpoint = torch.load(os.path.join(save_dir, "epoch_300.pth"))
                    arn.load_state_dict(checkpoint)
                    eval_arn(arn, train_loader, test_loader, save_dir)
                    
                elif op == 2:
                    # GMM of Anchor Recognition Net
                    save_dir = os.path.join(save_dir_base, "GMM")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir)
                    gmm = GaussianMixture(n_components=args.n_components)
                    arn = AnchorRecognitionNetE2E(args, gmm).cuda()
                    checkpoint = torch.load(os.path.join(save_dir_base, "ARN", "epoch_300.pth"))
                    arn.load_state_dict(checkpoint)
                    train_embedding, test_embedding = [], []
                    train_keypoint_embedding, test_keypoint_embedding = [], []
                    for data in train_loader:
                        sequence_data, mask_data, anchor_pos_label = data["keypoint_sequence"].cuda(), data["mask_data"], data["anchor_pos_label"]
                        output = arn.encoder(sequence_data, mask_data)
                        for i in range(output.shape[0]):
                            output_anchor = output[i][anchor_pos_label[i] > 0]
                            train_embedding.append(output_anchor.detach().cpu())
                            output_keypoint = output[i][anchor_pos_label[i] == 0]
                            train_keypoint_embedding.append(output_keypoint.detach().cpu())
                    for data in test_loader:
                        sequence_data, mask_data, anchor_pos_label = data["keypoint_sequence"].cuda(), data["mask_data"], data["anchor_pos_label"]
                        output = arn.encoder(sequence_data, mask_data)
                        for i in range(output.shape[0]):
                            output_anchor = output[i][anchor_pos_label[i] > 0]
                            test_embedding.append(output_anchor.detach().cpu())
                            output_keypoint = output[i][anchor_pos_label[i] == 0]
                            test_keypoint_embedding.append(output_keypoint.detach().cpu())
                    train_embedding = torch.cat(train_embedding, dim=0)
                    test_embedding = torch.cat(test_embedding, dim=0)
                    train_keypoint_embedding = torch.cat(train_keypoint_embedding, dim=0)
                    test_keypoint_embedding = torch.cat(test_keypoint_embedding, dim=0)
                    if args.train:
                        print("==> Start training GMM for concept ", action_name + "...")
                        gmm.fit(train_embedding)
                        joblib.dump(gmm, os.path.join(save_dir, "gmm.joblib"))
                    gmm = joblib.load(os.path.join(save_dir, "gmm.joblib"))
                    print("==> Start evaluating GMM for concept", action_name + "...")
                    log_likelihood = gmm.score_samples(train_embedding)
                    print("train_embedding log_likelihood:", np.mean(log_likelihood), np.std(log_likelihood))
                    log_likelihood = gmm.score_samples(test_embedding)
                    print("test_embedding log_likelihood:", np.mean(log_likelihood), np.std(log_likelihood))
                    log_likelihood = gmm.score_samples(train_keypoint_embedding)
                    print("train_keypoint_embedding log_likelihood:", np.mean(log_likelihood), np.std(log_likelihood))
                    log_likelihood = gmm.score_samples(test_keypoint_embedding)
                    print("test_keypoint_embedding log_likelihood:", np.mean(log_likelihood), np.std(log_likelihood))
            
            elif op == 3:
                # Transition Generator
                save_dir = os.path.join(save_dir_base, "TransGenerator")
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                train_data = []
                test_data = []
                for idx in train_idx:
                    anchor_pos = np.where(anchor_pos_labels[idx] == 1)[0]
                    for i in range(len(anchor_pos) - 1):
                        anchor_pair = np.concatenate([keypoint_sequences[idx][anchor_pos[i]:anchor_pos[i]+1], keypoint_sequences[idx][anchor_pos[i+1]:anchor_pos[i+1]+1]], axis=0)
                        anchor_pair_norm = anchor_pair.copy()
                        anchor_pair_norm -= anchor_pair_norm[0, 0]
                        tmp = {}
                        tmp["full_keypoint_sequence"] = torch.from_numpy(keypoint_sequences_norm[idx]).float()
                        tmp_keypoint_sequence = keypoint_sequences[idx][anchor_pos[i]:anchor_pos[i+1]+1].copy()
                        tmp_keypoint_sequence -= tmp_keypoint_sequence[0, 0]
                        tmp["keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                        tmp["anchor_pos_label"] = torch.from_numpy(anchor_pos_labels[idx][anchor_pos[i]:anchor_pos[i+1]]).float()
                        tmp["transition"] = torch.from_numpy(transes[idx][i][:, [0, 1, 2, 4, 5, 6, 8, 9, 10]]).float()
                        tmp["anchor_pair"] = torch.from_numpy(anchor_pair_norm).float()
                        tmp["duration"] = torch.tensor(anchor_pos[i+1] + 1 - anchor_pos[i]).float()
                        train_data.append(tmp)
                for idx in test_idx:
                    anchor_pos = np.where(anchor_pos_labels[idx] == 1)[0]
                    for i in range(len(anchor_pos) - 1):
                        anchor_pair = np.concatenate([keypoint_sequences[idx][anchor_pos[i]:anchor_pos[i]+1], keypoint_sequences[idx][anchor_pos[i+1]:anchor_pos[i+1]+1]], axis=0)
                        anchor_pair_norm = anchor_pair.copy()
                        anchor_pair_norm -= anchor_pair_norm[0, 0]
                        tmp = {}
                        tmp["full_keypoint_sequence"] = torch.from_numpy(keypoint_sequences_norm[idx]).float()
                        tmp_keypoint_sequence = keypoint_sequences[idx][anchor_pos[i]:anchor_pos[i+1]+1].copy()
                        tmp_keypoint_sequence -= tmp_keypoint_sequence[0, 0]
                        tmp["keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                        tmp["anchor_pos_label"] = torch.from_numpy(anchor_pos_labels[idx][anchor_pos[i]:anchor_pos[i+1]]+1).float()
                        tmp["transition"] = torch.from_numpy(transes[idx][i][:, [0, 1, 2, 4, 5, 6, 8, 9, 10]]).float()
                        tmp["anchor_pair"] = torch.from_numpy(anchor_pair_norm).float()
                        tmp["duration"] = torch.tensor(anchor_pos[i+1] + 1 - anchor_pos[i]).float()
                        test_data.append(tmp)
                train_dataset = ConceptDataset(train_data)
                test_dataset = ConceptDataset(test_data)
                print("==> Load data for TransGenerator training and testing: train data number:", len(train_dataset), "test data number:", len(test_dataset))
                train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_helper_trans)
                test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_helper_trans)
                batch_num = len(train_loader)
                autoencoder = AutoEncoder(args).cuda()
                checkpoint = torch.load(os.path.join(save_dir_base, "AE", "best.pth"))
                autoencoder.load_state_dict(checkpoint)
                trans_generator = TransGenerator(args, no_joint0=False, add_global_feature=args.trans_add_global_feature).cuda()
                if args.train:
                    optimizer = torch.optim.Adam(trans_generator.parameters(), lr=args.trans_base_lr)
                    scheduler = get_lr_scheduler(args.policy, optimizer, max_iter=args.trans_epochs*batch_num)
                    print("==> Start training TransGenerator for concept ", action_name + "...")
                    exp_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                    loss_writer = tensorboardX.SummaryWriter(os.path.join(save_dir, "loss", exp_time))
                    for epoch in range(args.trans_epochs):
                        loss, perf_dict = train_trans_generator(args, autoencoder, trans_generator, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict)
                        for key in loss.keys():
                            loss_writer.add_scalar(key, loss[key], epoch)
                    loss_writer.close()
                print("==> Start evaluating TransGenerator for concept", action_name + "...")
                checkpoint = torch.load(os.path.join(save_dir, "best.pth"))
                trans_generator.load_state_dict(checkpoint)
                eval_trans_generator(args, autoencoder, trans_generator, train_loader, test_loader)
            
            elif op in [4, 5, 6]:
                # Anchor Refinement Net
                save_dir = os.path.join(save_dir_base, "AFN")
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                train_data = []
                test_data = []
                for idx in train_idx:
                    anchor_pos = np.where(anchor_pos_labels[idx] == 1)[0]
                    if len(anchor_pos) == 1:
                        continue
                    for i in range(len(anchor_pos)):
                        tmp = {}
                        tmp_keypoint_sequence = keypoint_sequences_norm[idx].copy()
                        tmp_gt_anchor = tmp_keypoint_sequence[anchor_pos[i]].copy()
                        tmp_global_translation = keypoint_sequences[idx][:, 0:1].copy()
                        if i == 0:
                            tmp_keypoint_sequence[:anchor_pos[i+1]-1] = np.zeros_like(tmp_keypoint_sequence[:anchor_pos[i+1]-1])
                        elif i == len(anchor_pos)-1:
                            tmp_keypoint_sequence[anchor_pos[i-1]+1:] = np.zeros_like(tmp_keypoint_sequence[anchor_pos[i-1]+1:])
                        else:
                            tmp_keypoint_sequence[anchor_pos[i-1]+1:anchor_pos[i+1]-1] = np.zeros_like(tmp_keypoint_sequence[anchor_pos[i-1]+1:anchor_pos[i+1]-1])
                        tmp["keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                        tmp["gt_anchor"] = torch.from_numpy(tmp_gt_anchor).float()
                        tmp["anchor_global_translation"] = tmp_global_translation
                        tmp["anchor_pos_label"] = torch.from_numpy(anchor_pos_labels[idx]).float()
                        tmp["anchor_class"] = torch.tensor(anchor_labels_reshape[idx][i])
                        anchor_class_onehot = torch.zeros(args.n_clusters)
                        anchor_class_onehot[anchor_labels_reshape[idx][i]] = 1
                        tmp["anchor_class_onehot"] = anchor_class_onehot
                        tmp["transition"] = torch.from_numpy(transes[idx]).float()
                        tmp["anchor_pos"] = torch.tensor(anchor_pos[i])
                        train_data.append(tmp)
                
                # Build AnchorRefinementNet test samples with the same preprocessing as the
                # training loop: one sample per anchor frame of each test sequence.
                for idx in test_idx:
                    anchor_pos = np.where(anchor_pos_labels[idx] == 1)[0]
                    # Sequences with a single anchor are skipped. (A sequence with zero
                    # anchors is not skipped, but the loop below then yields no samples.)
                    if len(anchor_pos) == 1:
                        continue
                    for i in range(len(anchor_pos)):
                        tmp = {}
                        tmp_keypoint_sequence = keypoint_sequences_norm[idx].copy()
                        # Target: the keypoint_sequences_norm frame at this anchor, copied
                        # before masking.
                        tmp_gt_anchor = tmp_keypoint_sequence[anchor_pos[i]].copy()
                        # Index 0 of axis 1 of the un-normalized sequence, for all frames
                        # (presumably the root joint -- confirm against the data layout).
                        tmp_global_translation = keypoint_sequences[idx][:, 0:1].copy()
                        # Zero the frames around the target anchor so the network must
                        # reconstruct it from context. First anchor: everything before the
                        # next anchor. Last anchor: everything after the previous anchor.
                        # Otherwise: the span between the two neighbouring anchors.
                        # NOTE(review): the upper bound `anchor_pos[i+1]-1` is exclusive, so
                        # frame anchor_pos[i+1]-1 stays visible. When the next anchor is
                        # adjacent (anchor_pos[i+1] == anchor_pos[i]+1), the target frame
                        # itself is NOT zeroed. If `anchor_pos[i+1]` was intended, change it
                        # together with the training loop so both splits stay consistent.
                        if i == 0:
                            tmp_keypoint_sequence[:anchor_pos[i+1]-1] = np.zeros_like(tmp_keypoint_sequence[:anchor_pos[i+1]-1])
                        elif i == len(anchor_pos)-1:
                            tmp_keypoint_sequence[anchor_pos[i-1]+1:] = np.zeros_like(tmp_keypoint_sequence[anchor_pos[i-1]+1:])
                        else:
                            tmp_keypoint_sequence[anchor_pos[i-1]+1:anchor_pos[i+1]-1] = np.zeros_like(tmp_keypoint_sequence[anchor_pos[i-1]+1:anchor_pos[i+1]-1])
                        tmp["keypoint_sequence"] = torch.from_numpy(tmp_keypoint_sequence).float()
                        tmp["gt_anchor"] = torch.from_numpy(tmp_gt_anchor).float()
                        tmp["anchor_global_translation"] = tmp_global_translation
                        tmp["anchor_pos_label"] = torch.from_numpy(anchor_pos_labels[idx]).float()
                        # Cluster label of this anchor, both as an index and as a one-hot of
                        # length args.n_clusters.
                        tmp["anchor_class"] = torch.tensor(anchor_labels_reshape[idx][i])
                        anchor_class_onehot = torch.zeros(args.n_clusters)
                        anchor_class_onehot[anchor_labels_reshape[idx][i]] = 1
                        tmp["anchor_class_onehot"] = anchor_class_onehot
                        tmp["transition"] = torch.from_numpy(transes[idx]).float()
                        tmp["anchor_pos"] = torch.tensor(anchor_pos[i])
                        test_data.append(tmp)
                
                # Wrap the per-anchor samples in loaders. Testing runs one sample at a time.
                train_dataset = ConceptDataset(train_data)
                test_dataset = ConceptDataset(test_data)
                print("==> Load data for AnchorRefinementNet training and testing: train data number:", len(train_dataset), "test data number:", len(test_dataset))
                train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_helper_afn)
                test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_helper_afn)
                batch_num = len(train_loader)
                afn = AnchorRefinementNet(args).cuda()
                if op == 4:
                    # Train Anchor Refinement Net Classifier
                    save_dir_classifier = os.path.join(save_dir, "classifier")
                    os.makedirs(save_dir_classifier, exist_ok=True)
                    if args.train:
                        # NOTE(review): only afn.classifier's parameters are optimized, yet
                        # best_encoder.pth is loaded into afn.classifier_encoder below.
                        # Confirm the encoder is meant to stay fixed here (or is a
                        # submodule of afn.classifier).
                        optimizer = torch.optim.Adam(afn.classifier.parameters(), lr=args.afn_base_lr)
                        scheduler = get_lr_scheduler(args.policy, optimizer, max_iter=args.afn_epochs*batch_num)
                        print("==> Start training AFN classifier for concept ", action_name + "...")
                        exp_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                        loss_writer = tensorboardX.SummaryWriter(os.path.join(save_dir_classifier, "loss", exp_time))
                        try:
                            for epoch in range(args.afn_epochs):
                                loss, perf_dict = train_afn_classifier(args, afn, train_loader, test_loader, optimizer, scheduler, epoch, save_dir_classifier, perf_dict)
                                for key in loss.keys():
                                    loss_writer.add_scalar(key, loss[key], epoch)
                        finally:
                            # Close the event file even if training raises part-way.
                            loss_writer.close()
                    print("==> Start evaluating AFN classifier for concept", action_name, "...")
                    checkpoint = torch.load(os.path.join(save_dir_classifier, "best_encoder.pth"))
                    afn.classifier_encoder.load_state_dict(checkpoint)
                    checkpoint = torch.load(os.path.join(save_dir_classifier, "best_classifier.pth"))
                    afn.classifier.load_state_dict(checkpoint)
                    eval_afn_classifier(afn, train_loader, test_loader)

                elif op == 5:
                    # Train Anchor Refinement Net Decoder, starting from the op-4 classifier.
                    checkpoint = torch.load(os.path.join(save_dir, "classifier", "best_encoder.pth"))
                    afn.classifier_encoder.load_state_dict(checkpoint)
                    checkpoint = torch.load(os.path.join(save_dir, "classifier", "best_classifier.pth"))
                    afn.classifier.load_state_dict(checkpoint)
                    if args.train:
                        # NOTE(review): afn.parameters() also covers the pretrained
                        # classifier, so it is fine-tuned here. The evaluation below
                        # re-loads the op-4 classifier weights over whatever best.pth holds.
                        optimizer = torch.optim.Adam(afn.parameters(), lr=args.afn_base_lr)
                        scheduler = get_lr_scheduler(args.policy, optimizer, max_iter=args.afn_epochs*batch_num)
                        print("==> Start training AFN decoder for concept ", action_name + "...")
                        exp_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                        loss_writer = tensorboardX.SummaryWriter(os.path.join(save_dir, "loss", exp_time))
                        try:
                            for epoch in range(args.afn_epochs):
                                loss, perf_dict = train_afn(args, afn, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict)
                                for key in loss.keys():
                                    loss_writer.add_scalar(key, loss[key], epoch)
                        finally:
                            loss_writer.close()
                    print("==> Start evaluating AFN decoder for concept ", action_name, "...")
                    checkpoint = torch.load(os.path.join(save_dir, "best.pth"))
                    afn.load_state_dict(checkpoint)
                    checkpoint = torch.load(os.path.join(save_dir, "classifier", "best_encoder.pth"))
                    afn.classifier_encoder.load_state_dict(checkpoint)
                    checkpoint = torch.load(os.path.join(save_dir, "classifier", "best_classifier.pth"))
                    afn.classifier.load_state_dict(checkpoint)
                    eval_afn(afn, train_loader, test_loader)

                elif op == 6:
                    # Train Anchor Refinement Net Classifier and Decoder together
                    save_dir_together = os.path.join(save_dir, "together")
                    os.makedirs(save_dir_together, exist_ok=True)
                    if args.train:
                        optimizer = torch.optim.Adam(afn.parameters(), lr=args.afn_base_lr)
                        scheduler = get_lr_scheduler(args.policy, optimizer, max_iter=args.afn_epochs*batch_num)
                        print("==> Start training AFN for concept ", action_name + "...")
                        exp_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                        # Log inside this run's own directory, as op 4 does with
                        # classifier/loss. These events previously landed in save_dir/loss,
                        # mixed with the op-5 decoder runs.
                        loss_writer = tensorboardX.SummaryWriter(os.path.join(save_dir_together, "loss", exp_time))
                        try:
                            for epoch in range(args.afn_epochs):
                                loss, perf_dict = train_afn_together(args, afn, train_loader, test_loader, optimizer, scheduler, epoch, save_dir_together, perf_dict)
                                for key in loss.keys():
                                    loss_writer.add_scalar(key, loss[key], epoch)
                        finally:
                            loss_writer.close()
                    print("==> Start evaluating AFN for concept ", action_name, "...")
                    checkpoint = torch.load(os.path.join(save_dir_together, "best_recon.pth"))
                    afn.load_state_dict(checkpoint)
                    eval_afn(afn, train_loader, test_loader)
                    
            elif op == 7:
                # Transition Refinement Net
                save_dir = os.path.join(save_dir_base, "TRN")
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                train_data = []
                test_data = []
                for idx in train_idx:
                    anchor_pos = np.where(anchor_pos_labels[idx] == 1)[0]
                    if len(anchor_pos) == 1:
                        continue
                    for i in range(len(anchor_pos) - 1):
                        tmp = {}
                        keypoint_sequence_slice = keypoint_sequences[idx][anchor_pos[i]:anchor_pos[i+1]+1].copy()
                        keypoint_sequence_slice -= keypoint_sequence_slice[0, 0]
                        tmp["full_keypoint_sequence"] = torch.from_numpy(keypoint_sequences[idx]).float()
                        tmp["keypoint_sequence"] = torch.from_numpy(keypoint_sequence_slice).float()
                        tmp["pos_record"] = torch.tensor([idx, i])
                        tmp["transition"] = torch.from_numpy(transes[idx][i][:, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]).float()
                        tmp["duration"] = torch.tensor(anchor_pos[i+1] - anchor_pos[i]).long()
                        train_data.append(tmp)
                for idx in test_idx:
                    anchor_pos = np.where(anchor_pos_labels[idx] == 1)[0]
                    if len(anchor_pos) == 1:
                        continue
                    for i in range(len(anchor_pos) - 1):
                        tmp = {}
                        keypoint_sequence_slice = keypoint_sequences[idx][anchor_pos[i]:anchor_pos[i+1]+1].copy()
                        keypoint_sequence_slice -= keypoint_sequence_slice[0, 0]
                        tmp["full_keypoint_sequence"] = torch.from_numpy(keypoint_sequences[idx]).float()
                        tmp["keypoint_sequence"] = torch.from_numpy(keypoint_sequence_slice).float()
                        tmp["pos_record"] = torch.tensor([idx, i])
                        tmp["transition"] = torch.from_numpy(transes[idx][i][:, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]).float()
                        tmp["duration"] = torch.tensor(anchor_pos[i+1] - anchor_pos[i]).long()
                        test_data.append(tmp)
                train_dataset = ConceptDataset(train_data)
                test_dataset = ConceptDataset(test_data)
                print("==> Load data for TransRefinementNet training and testing: train data number:", len(train_dataset), "test data number:", len(test_dataset))
                train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_helper_aprn)
                test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_helper_aprn)
                batch_num = len(train_loader)
                aprn = TransRefinementNet(args).cuda()
                if args.train:
                    optimizer = torch.optim.Adam(aprn.parameters(), lr=args.aprn_base_lr)
                    scheduler = get_lr_scheduler(args.policy, optimizer, max_iter=args.aprn_epochs*batch_num)
                    print("==> Start training TRN for concept ", action_name + "...")
                    exp_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                    loss_writer = tensorboardX.SummaryWriter(os.path.join(save_dir, "loss", exp_time))
                    for epoch in range(args.aprn_epochs):
                        loss, perf_dict = train_aprn(args, aprn, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict)
                        for key in loss.keys():
                            loss_writer.add_scalar(key, loss[key], epoch)
                    loss_writer.close()
                print("==> Start evaluating TRN for concept", action_name + "...")
                checkpoint = torch.load(os.path.join(save_dir, "best_recon.pth"))
                aprn.load_state_dict(checkpoint)
                eval_aprn(args, aprn, train_loader, test_loader, save_dir, limbs=limbs)

            elif op == 8:
                # Transition Refinement Net AutoRegressive version, not used in the paper
                save_dir = os.path.join(save_dir_base, "TRN_AR")
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                train_data = []
                test_data = []
                for idx in train_idx:
                    tmp = {}
                    anchor_pos = np.where(anchor_pos_labels[idx] == 1)[0]
                    if len(anchor_pos) == 1:
                        continue
                    keypoint_sequence_slice_gather = []
                    keypoint_sequence_slice_len = []
                    transition_gather = []
                    for i in range(len(anchor_pos) - 1):
                        keypoint_sequence_slice = keypoint_sequences[idx][anchor_pos[i]:anchor_pos[i+1]+1].copy()
                        keypoint_sequence_slice -= keypoint_sequence_slice[0, 0]
                        keypoint_sequence_slice = torch.from_numpy(keypoint_sequence_slice).float()
                        keypoint_sequence_slice_gather.append(keypoint_sequence_slice)
                        keypoint_sequence_slice_len.append(torch.tensor(anchor_pos[i+1] - anchor_pos[i]).long())
                        transition_gather.append(torch.from_numpy(transes[idx][i][:, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]).float()) 
                    keypoint_sequence_slice_pad = [torch.cat([item, item[-1].unsqueeze(0).repeat(max(keypoint_sequence_slice_len)+1-item.shape[0], 1, 1)], dim=0) for item in keypoint_sequence_slice_gather]
                    keypoint_sequence_slice_pad = torch.stack(keypoint_sequence_slice_pad, dim=0)
                    keypoint_sequence_slice_len = torch.stack(keypoint_sequence_slice_len)
                    transition_gather = torch.stack(transition_gather)
                    tmp["full_keypoint_sequence"] = torch.from_numpy(keypoint_sequences[idx]).float()
                    tmp["keypoint_sequence"] = keypoint_sequence_slice_pad
                    tmp["transition"] = transition_gather
                    tmp["duration"] = keypoint_sequence_slice_len
                    train_data.append(tmp)
                for idx in test_idx:
                    tmp = {}
                    anchor_pos = np.where(anchor_pos_labels[idx] == 1)[0]
                    if len(anchor_pos) == 1:
                        continue
                    keypoint_sequence_slice_gather = []
                    keypoint_sequence_slice_len = []
                    transition_gather = []
                    for i in range(len(anchor_pos) - 1):
                        keypoint_sequence_slice = keypoint_sequences[idx][anchor_pos[i]:anchor_pos[i+1]+1].copy()
                        keypoint_sequence_slice -= keypoint_sequence_slice[0, 0]
                        keypoint_sequence_slice = torch.from_numpy(keypoint_sequence_slice).float()
                        keypoint_sequence_slice_gather.append(keypoint_sequence_slice)
                        keypoint_sequence_slice_len.append(torch.tensor(anchor_pos[i+1] - anchor_pos[i]).long())
                        transition_gather.append(torch.from_numpy(transes[idx][i][:, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]).float()) 
                    keypoint_sequence_slice_pad = [torch.cat([item, item[-1].unsqueeze(0).repeat(max(keypoint_sequence_slice_len)+1-item.shape[0], 1, 1)], dim=0) for item in keypoint_sequence_slice_gather]
                    keypoint_sequence_slice_pad = torch.stack(keypoint_sequence_slice_pad, dim=0)
                    keypoint_sequence_slice_len = torch.stack(keypoint_sequence_slice_len)
                    transition_gather = torch.stack(transition_gather)
                    tmp["full_keypoint_sequence"] = torch.from_numpy(keypoint_sequences[idx]).float()
                    tmp["keypoint_sequence"] = keypoint_sequence_slice_pad
                    tmp["transition"] = transition_gather
                    tmp["duration"] = keypoint_sequence_slice_len
                    test_data.append(tmp)
                train_dataset = ConceptDataset(train_data)
                test_dataset = ConceptDataset(test_data)
                print("==> Load data for TransRefinementNet training and testing: train data number:", len(train_dataset), "test data number:", len(test_dataset))
                train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_helper_aprn_ar)
                test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_helper_aprn_ar)
                batch_num = len(train_loader)
                aprn = TransRefinementNetAR(args).cuda()
                if args.train:
                    optimizer = torch.optim.Adam(aprn.parameters(), lr=args.aprn_base_lr)
                    scheduler = get_lr_scheduler(args.policy, optimizer, max_iter=args.aprn_epochs*batch_num)
                    print("==> Start training TRN_AR for concept ", action_name + "...")
                    exp_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                    loss_writer = tensorboardX.SummaryWriter(os.path.join(save_dir, "loss", exp_time))
                    for epoch in range(args.aprn_epochs):
                        loss, perf_dict = train_aprn_ar(args, aprn, train_loader, test_loader, optimizer, scheduler, epoch, save_dir, perf_dict)
                        for key in loss.keys():
                            loss_writer.add_scalar(key, loss[key], epoch)
                    loss_writer.close()
                print("==> Start evaluating TRN_AR for concept", action_name + "...")
                checkpoint = torch.load(os.path.join(save_dir, "best_recon.pth"))
                aprn.load_state_dict(checkpoint)
                eval_aprn_ar(args, aprn, train_loader, test_loader, save_dir)
            
            elif op in [9, 10, 11, 12]:
                # HumanAct12 and UESTC actions evaluation
                # Build every stage of the refinement pipeline and load its pretrained
                # weights from save_dir_base: ARN (+GMM), AE, TransGenerator, AFN, TRN.
                arn = AnchorRecognitionNetE2E(args, gmm=None).cuda()
                # NOTE(review): ARN loads the fixed epoch_300.pth checkpoint, while the other
                # stages load their "best" checkpoints -- confirm this is intentional.
                checkpoint = torch.load(os.path.join(save_dir_base, "ARN", "epoch_300.pth"))
                arn.load_state_dict(checkpoint)
                # The fitted GMM is attached after construction (the constructor got gmm=None).
                gmm = joblib.load(os.path.join(save_dir_base, "GMM", "gmm.joblib"))
                arn.gmm = gmm
                autoencoder = AutoEncoder(args).cuda()
                checkpoint = torch.load(os.path.join(save_dir_base, "AE", "best.pth"))
                autoencoder.load_state_dict(checkpoint)
                trans_generator = TransGenerator(args, no_joint0=False, add_global_feature=args.trans_add_global_feature).cuda()
                checkpoint = torch.load(os.path.join(save_dir_base, "TransGenerator", "best.pth"))
                trans_generator.load_state_dict(checkpoint)
                # AFN: load the full decoder checkpoint first, then overwrite the classifier
                # encoder/head with the classifier-stage checkpoints (same order as op 5's
                # evaluation).
                afn = AnchorRefinementNet(args).cuda()
                checkpoint = torch.load(os.path.join(save_dir_base, "AFN", "best.pth"))
                afn.load_state_dict(checkpoint)
                checkpoint = torch.load(os.path.join(save_dir_base, "AFN", "classifier", "best_encoder.pth"))
                afn.classifier_encoder.load_state_dict(checkpoint)
                checkpoint = torch.load(os.path.join(save_dir_base, "AFN", "classifier", "best_classifier.pth"))
                afn.classifier.load_state_dict(checkpoint)
                aprn = TransRefinementNet(args).cuda()
                checkpoint = torch.load(os.path.join(save_dir_base, "TRN", "best_recon.pth"))
                aprn.load_state_dict(checkpoint)
                # Everything below only evaluates: put every network in eval mode.
                arn.eval()
                autoencoder.eval()
                trans_generator.eval()
                afn.eval()
                aprn.eval()
                if op in [9, 10]:
                    # Humanact12 actions evaluation.
                    # Fail fast on configurations this branch prepares no data for.
                    # Otherwise `generate_data_concept`, `sample_idx` or
                    # `gt_keypoint_sequence_concept` would be undefined in the loop below.
                    if args.dataset != "humanact12":
                        raise NotImplementedError("ops 9/10 only support the humanact12 dataset")
                    if args.model_name not in ("actor", "motiongpt", "action2motion"):
                        raise NotImplementedError("ops 9/10 only support the actor, motiongpt and action2motion models")
                    with open(os.path.join(args.generate_data_dir, args.model_name, args.dataset, "saved_processed_data.pkl"), "rb") as f:
                        generate_data = pickle.load(f)
                    # Generated samples of this concept, indexed below as [key][sample].
                    generate_data_concept = generate_data[int(concept)]
                    if args.model_name == "actor":
                        gt_keypoint_sequence_for_eval = np.load(os.path.join("actor_keypoint_sequence", "gt_keypoint_sequence_data.npy"))
                        gt_keypoint_labels_for_eval = np.load(os.path.join("actor_keypoint_sequence", "gt_keypoint_labels.npy"))
                        # Keep only the GT sequences labelled with this concept. The former
                        # `elif op == 11` alternative was unreachable inside this
                        # `op in [9, 10]` branch and has been dropped.
                        gt_keypoint_sequence_concept = np.array([
                            sequence
                            for sequence, label in zip(gt_keypoint_sequence_for_eval, gt_keypoint_labels_for_eval)
                            if label == int(concept)
                        ])
                        gt_keypoint_sequence_concept = torch.from_numpy(gt_keypoint_sequence_concept).float().cuda()
                    elif args.model_name == "motiongpt":
                        with open(os.path.join(args.generate_data_dir, args.model_name, args.dataset, "saved_gt_data.pkl"), "rb") as f:
                            gt_data = pickle.load(f)
                        gt_keypoint_sequence_concept = gt_data[int(concept)]
                        motiongpt_sample_idx_dir = os.path.join(save_dir_base, args.experiment_name, args.dataset)
                        motiongpt_sample_idx_path = os.path.join(motiongpt_sample_idx_dir, args.model_name+"_sample_idx.npy")
                        if not os.path.exists(motiongpt_sample_idx_path):
                            os.makedirs(motiongpt_sample_idx_dir, exist_ok=True)
                            # Pick 30% of the GT sequences once and persist the choice, so
                            # later runs refine the same subset.
                            # NOTE(review): np.random is not seeded here (ops 11/12 seed it),
                            # so the first-run subset is not reproducible.
                            sample_idx = np.random.choice(len(gt_keypoint_sequence_concept), int(len(gt_keypoint_sequence_concept) * 0.3), replace=False)
                            sample_idx = sorted(sample_idx)
                            print("== Generate sample_idx:", sample_idx)
                            np.save(motiongpt_sample_idx_path, sample_idx)
                        else:
                            sample_idx = np.load(motiongpt_sample_idx_path)
                            print("==> Load sample_idx:", sample_idx)
                        gt_keypoint_sequence_concept = [torch.from_numpy(gt_keypoint_sequence_concept[i]).float().cuda() for i in sample_idx]
                    # 20 generated repetitions ("keys") are refined and evaluated per concept.
                    for key in range(20):
                        print("==> Refine sequence_data for key:", key, "concept:", concept, "...")
                        if args.model_name == "action2motion":
                            # action2motion ships a separate GT set per key.
                            gt_keypoint_sequence_for_eval = np.load(os.path.join(args.generate_data_dir, args.model_name, args.dataset, "ground_truth", f"motion_{key}", "gt_keypoint_sequence.npy"))
                            gt_labels_for_eval = np.load(os.path.join(args.generate_data_dir, args.model_name, args.dataset, "ground_truth", f"motion_{key}", "gt_labels.npy"))
                            gt_keypoint_sequence_concept = np.array([
                                sequence
                                for sequence, label in zip(gt_keypoint_sequence_for_eval, gt_labels_for_eval)
                                if label == int(concept)
                            ])
                            gt_keypoint_sequence_concept = torch.from_numpy(gt_keypoint_sequence_concept).float().cuda()
                            print("gt_keypoint_sequence_concept:", gt_keypoint_sequence_concept.shape)
                        if args.model_name in ("actor", "action2motion"):
                            # Refine every generated sample of this key.
                            sample_idx = np.arange(generate_data_concept[key].shape[0])
                        save_dir = os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, str(key))
                        os.makedirs(save_dir, exist_ok=True)
                        refine_data = []
                        for idx in sample_idx:
                            # NOTE(review): the * 100 presumably rescales generator output to
                            # the units of the training data -- confirm.
                            data_to_refine = generate_data_concept[key][idx] * 100
                            # Index 0 of axis 1 (presumably the root joint) for every frame, and
                            # the sequence expressed relative to it.
                            global_position = data_to_refine[:, 0:1].copy()
                            data_to_refine_norm = data_to_refine - data_to_refine[:, 0:1]
                            refine_data.append({
                                "keypoint_sequence": torch.from_numpy(data_to_refine).float(),
                                "keypoint_sequence_norm": torch.from_numpy(data_to_refine_norm).float(),
                                "global_position": torch.from_numpy(global_position).float(),
                            })
                        refine_dataset = ConceptDataset(refine_data)
                        print("==> Load data for refinement:", len(refine_dataset))
                        refine_loader = DataLoader(refine_dataset, batch_size=1, shuffle=False, collate_fn=collate_helper_refine)
                        batch_num = len(refine_loader)
                        if op == 9:
                            # Only the first key's results are visualized.
                            add_visualize = key == 0
                            eval_sequence_refinement(args, arn, autoencoder, trans_generator, afn, aprn, refine_loader, gt_keypoint_sequence_concept, save_dir, limbs=limbs, add_visualize=add_visualize)
                        elif op == 10:
                            eval_sequence_refinement_metrics_humanact12(args, gt_keypoint_sequence_concept, save_dir)
                        
                elif op in [11, 12]:
                    # UESTC actions evaluation
                    torch.manual_seed(0)
                    random.seed(0)
                    np.random.seed(0)
                    splits = ["train", "test"]
                    for split in splits:
                        for key in range(20):
                            print("==> Refine sequence_data for concept:", concept, "split:", split, "key:", key, "...")
                            if args.dataset == "uestc":
                                if args.model_name == "actor":
                                    with open(os.path.join(args.generate_data_dir, args.model_name, args.dataset, "uestc_xyz", "generated_motions_"+split+str(key)+".pkl"), "rb") as f:
                                        generate_data = pickle.load(f)
                                    gt_keypoint_sequence_for_eval = generate_data["gt"]
                                    gen_keypoint_sequence_for_eval = generate_data["output_xyz"]
                                    gt_labels_for_eval = generate_data["y"]
                                    gt_keypoint_sequence_concept = []
                                    gen_keypoint_sequence_concept = []
                                    for i in range(len(gt_labels_for_eval)):
                                        if uestc_list[gt_labels_for_eval[i]] == concept:
                                            gt_keypoint_sequence_concept.append(gt_keypoint_sequence_for_eval[i])
                                            gen_keypoint_sequence_concept.append(gen_keypoint_sequence_for_eval[i])
                                    gt_keypoint_sequence_concept = np.array(gt_keypoint_sequence_concept)
                                    gen_keypoint_sequence_concept = np.array(gen_keypoint_sequence_concept)
                                    gt_keypoint_sequence_concept = np.transpose(gt_keypoint_sequence_concept, (0, 3, 1, 2))
                                    gen_keypoint_sequence_concept = np.transpose(gen_keypoint_sequence_concept, (0, 3, 1, 2))
                                    if not os.path.exists(os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key), "sample_idx.npy")):
                                        if not os.path.exists(os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key))):
                                            os.makedirs(os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key)))
                                        sample_idx = np.random.choice(len(gt_keypoint_sequence_concept), int(len(gt_keypoint_sequence_concept) * 0.3), replace=False)
                                        sample_idx = sorted(sample_idx)
                                        np.save(os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key), "sample_idx.npy"), sample_idx)
                                    else:
                                        sample_idx = np.load(os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key), "sample_idx.npy"))
                                        print("==> Load sample_idx for concept:", concept, "split:", split, "key:", key, "sample_idx:", sample_idx)
                                    gen_keypoint_sequence_concept = gen_keypoint_sequence_concept[sample_idx]
                                    gt_keypoint_sequence_concept = torch.from_numpy(gt_keypoint_sequence_concept[sample_idx]).float().cuda()
                                    print("gt_keypoint_sequence_concept:", gt_keypoint_sequence_concept.shape, "gen_keypoint_sequence_concept:", gen_keypoint_sequence_concept.shape)
                                elif args.model_name == "action2motion":
                                    with open(os.path.join(args.generate_data_dir, args.model_name, args.dataset, split, "train", "generated_motions_"+str(key)+".pkl"), "rb") as f:
                                        generate_data = pickle.load(f)
                                    gen_keypoint_sequence_for_eval = generate_data["motions"]
                                    gen_labels = generate_data["labels"]
                                    gt_keypoint_sequence_for_eval = generate_data["gt_motions"]
                                    gt_keypoint_sequence_concept = []
                                    gen_keypoint_sequence_concept = []
                                    for i in range(len(gen_labels)):
                                        if uestc_list[int(gen_labels[i])] == concept:
                                            gt_keypoint_sequence_concept.append(gt_keypoint_sequence_for_eval[i])
                                            gen_keypoint_sequence_concept.append(gen_keypoint_sequence_for_eval[i])
                                    gt_keypoint_sequence_concept = np.array(gt_keypoint_sequence_concept)
                                    gen_keypoint_sequence_concept = np.array(gen_keypoint_sequence_concept)
                                    gt_keypoint_sequence_concept = gt_keypoint_sequence_concept.reshape(gt_keypoint_sequence_concept.shape[0], gt_keypoint_sequence_concept.shape[1], -1, 3)
                                    gen_keypoint_sequence_concept = gen_keypoint_sequence_concept.reshape(gen_keypoint_sequence_concept.shape[0], gen_keypoint_sequence_concept.shape[1], -1, 3)
                                    print("gt_keypoint_sequence_concept:", gt_keypoint_sequence_concept.shape, "gen_keypoint_sequence_concept:", gen_keypoint_sequence_concept.shape)
                                    if not os.path.exists(os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key), "sample_idx.npy")):
                                        if not os.path.exists(os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key))):
                                            os.makedirs(os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key)))
                                        sample_idx = np.random.choice(len(gt_keypoint_sequence_concept), int(len(gt_keypoint_sequence_concept) * 0.3), replace=False)
                                        sample_idx = sorted(sample_idx)
                                        np.save(os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key), "sample_idx.npy"), sample_idx)
                                    else:
                                        sample_idx = np.load(os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key), "sample_idx.npy"))
                                        print("==> Load sample_idx for concept:", concept, "split:", split, "key:", key, "sample_idx:", sample_idx)
                                    gen_keypoint_sequence_concept = gen_keypoint_sequence_concept[sample_idx]
                                    gt_keypoint_sequence_concept = torch.from_numpy(gt_keypoint_sequence_concept[sample_idx]).float().cuda()
                            else:
                                raise NotImplementedError
                            save_dir = os.path.join(save_dir_base, args.experiment_name, args.dataset, args.unsuitable_type, args.model_name, split, str(key))
                            if not os.path.exists(save_dir):
                                os.makedirs(save_dir)
                            data_to_refine = gen_keypoint_sequence_concept * 100
                            global_position = data_to_refine[:, :, 0:1].copy()
                            data_to_refine_norm = data_to_refine.copy()
                            data_to_refine_norm = data_to_refine_norm - data_to_refine_norm[:, :, 0:1]
                            refine_data = []
                            for idx in range(len(data_to_refine)):
                                tmp = {}
                                tmp["keypoint_sequence"] = torch.from_numpy(data_to_refine[idx]).float()
                                tmp["keypoint_sequence_norm"] = torch.from_numpy(data_to_refine_norm[idx]).float()
                                tmp["global_position"] = torch.from_numpy(global_position[idx]).float()
                                refine_data.append(tmp)
                            refine_dataset = ConceptDataset(refine_data)
                            print("==> Load data for refinement:", len(refine_dataset))
                            refine_loader = DataLoader(refine_dataset, batch_size=1, shuffle=False, collate_fn=collate_helper_refine)
                            batch_num = len(refine_loader)
                            if op == 11:
                                if key == 0:
                                    add_visualize = True
                                else:
                                    add_visualize = False
                                eval_sequence_refinement(args, arn, autoencoder, trans_generator, afn, aprn, refine_loader, gt_keypoint_sequence_concept, save_dir, limbs=limbs, add_visualize=add_visualize)
                            elif op == 12:
                                eval_sequence_refinement_metrics_uestc(args, gt_keypoint_sequence_concept, save_dir)
            elif op == 13:
                # HumanAct12 all actions evaluation
                save_dir = os.path.join(args.save_refine_dir, args.dataset, args.model_name)
                eval_sequence_refinement_metrics_humanact12_full(args, save_dir)
            elif op == 14:
                # UESTC all actions evaluation
                save_dir = os.path.join(args.save_refine_dir, args.dataset, args.model_name)
                eval_sequence_refinement_metrics_uestc_full(args, save_dir)
                
if __name__ == "__main__":
    # Script entry point: run the pipeline only when this file is executed
    # directly, not when it is imported. `args` is the module-level namespace
    # (presumably produced by `parser.parse_args()` earlier in the file).
    main(args)
