# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from common.arguments import parse_args
import os
args = parse_args()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

import numpy as np
import random

import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sys
import errno
import math

from einops import rearrange, repeat
from copy import deepcopy

from common.camera import *
import collections

from common.finepose_3dhp import *

from common.loss import *
from common.generators_3dhp import ChunkedGenerator_Seq, UnchunkedGenerator_Seq
from time import time
from common.utils import *
from common.logging import Logger
# from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

import scipy.io as scio

# cuDNN configuration: the autotuner (benchmark) is switched off and the
# deterministic flag is switched on (the opposite setting,
# `cudnn.benchmark = True`, was previously tried and left disabled).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# The visible device count is printed before and after importing clip.
# NOTE(review): presumably a debugging aid to check that the import does not
# change device visibility - confirm whether both prints are still needed.
print("CUDA Device Count: ", torch.cuda.device_count())
import clip
print("CUDA Device Count: ", torch.cuda.device_count())

# NOTE(review): this repeats the argument parsing and CUDA environment setup
# already done at the top of the file (before torch was imported). By this
# point torch.cuda.device_count() has already been called, so re-assigning
# CUDA_VISIBLE_DEVICES here is presumably a no-op - confirm and consider
# removing.
args = parse_args()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu


# Run mode: a non-empty --evaluate value selects evaluation, otherwise training.
description = "Evaluate!" if args.evaluate != '' else "Train!"

# initial setting
# The timestamp ends with '/', so args.log + '_' + TIMESTAMP names a directory.
TIMESTAMP = "{0:%Y%m%dT%H-%M-%S/}".format(datetime.now())
# tensorboard
if not args.nolog:
    # The module-level SummaryWriter import is commented out, which made this
    # branch fail with NameError. Import it here so that runs with --nolog do
    # not need tensorboard installed at all.
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(args.log+'_'+TIMESTAMP)
    writer.add_text('description', description)
    writer.add_text('command', 'python ' + ' '.join(sys.argv))
    # logging setting: from here on stdout goes through Logger, which is given
    # the path of a log file inside the run directory.
    logfile = os.path.join(args.log+'_'+TIMESTAMP, 'logging.log')
    sys.stdout = Logger(logfile)
print(description)
print('python ' + ' '.join(sys.argv))
print("CUDA Device Count: ", torch.cuda.device_count())
print(args)

# Seed every random number generator the script touches (Python, PyTorch CPU,
# NumPy, and all CUDA devices) with the same fixed value.
manualSeed = 1
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed, torch.cuda.manual_seed_all):
    _seed_rng(manualSeed)

# If no checkpoint path was given, checkpoints are saved into the log folder.
if args.checkpoint=='':
    args.checkpoint = args.log+'_'+TIMESTAMP
try:
    # Create the checkpoint directory if it does not exist. exist_ok tolerates
    # an existing directory but still raises when the path exists as something
    # else (e.g. a regular file), which the old EEXIST check silently ignored.
    os.makedirs(args.checkpoint, exist_ok=True)
except OSError as e:
    # Same exception type and arguments as before; the cause is now chained.
    raise RuntimeError('Unable to create checkpoint directory:', args.checkpoint) from e

# dataset loading
print('Loading dataset...')
print('Preparing data...')

# Pose containers filled by the loading loops below. The training dicts are
# keyed by (subject, sequence, camera), the test dicts by sequence name.
out_poses_3d_train, out_poses_2d_train = dict(), dict()
out_poses_3d_test, out_poses_2d_test = dict(), dict()
# Per-sequence "valid" arrays taken from the test file.
valid_frame = dict()

# Left/right index lists; the 2D keypoint lists and the 3D joint lists hold the
# same indices but are kept as separate list objects.
kps_left = [5, 6, 7, 11, 12, 13]
kps_right = [2, 3, 4, 8, 9, 10]
joints_left = list(kps_left)
joints_right = list(kps_right)



# Load the training split. The 'data' entry is a pickled dict (hence
# allow_pickle=True and .item()). NOTE: unpickling executes arbitrary code, so
# only load archives from a trusted source.
# The archive is opened with a context manager so its file handle is closed as
# soon as the dict has been materialised (it was previously left open).
with np.load("./data/data_train_3dhp_ori.npz", allow_pickle=True) as _train_npz:
    data_train = _train_npz['data'].item()
for seq in data_train.keys():
    for cam in data_train[seq][0].keys():
        anim = data_train[seq][0][cam]

        # Sequence keys have the form "<subject> <sequence>".
        subject_name, seq_name = seq.split(" ")

        # Make every joint except index 14 relative to joint 14 (presumably the
        # root joint - TODO confirm). Joint 14 keeps its absolute value here.
        # The subtraction is in place, so it also modifies the array stored in
        # data_train.
        data_3d = anim['data_3d']
        data_3d[:, :14] -= data_3d[:, 14:15]
        data_3d[:, 15:] -= data_3d[:, 14:15]
        out_poses_3d_train[(subject_name, seq_name, cam)] = data_3d

        data_2d = anim['data_2d']

        # Normalise the first two channels with a fixed 2048x2048 image size.
        data_2d[..., :2] = normalize_screen_coordinates(data_2d[..., :2], w=2048, h=2048)
        out_poses_2d_train[(subject_name, seq_name, cam)] = data_2d


# Load the test split. The 'data' entry is a pickled dict (hence
# allow_pickle=True and .item()). NOTE: unpickling executes arbitrary code, so
# only load archives from a trusted source.
# The archive is opened with a context manager so its file handle is closed as
# soon as the dict has been materialised (it was previously left open).
with np.load("./data/data_test_3dhp_ori.npz", allow_pickle=True) as _test_npz:
    data_test = _test_npz['data'].item()
for seq in data_test.keys():

    anim = data_test[seq]

    # Per-sequence "valid" array, later handed to the test generator.
    valid_frame[seq] = anim["valid"]

    # Make every joint except index 14 relative to joint 14 (presumably the
    # root joint - TODO confirm). The subtraction is in place.
    data_3d = anim['data_3d']
    data_3d[:, :14] -= data_3d[:, 14:15]
    data_3d[:, 15:] -= data_3d[:, 14:15]
    out_poses_3d_test[seq] = data_3d

    data_2d = anim['data_2d']

    # TS5 and TS6 are normalised with 1920x1080, every other sequence with
    # 2048x2048.
    if seq in ("TS5", "TS6"):
        width, height = 1920, 1080
    else:
        width, height = 2048, 2048
    data_2d[..., :2] = normalize_screen_coordinates(data_2d[..., :2], w=width, h=height)
    out_poses_2d_test[seq] = data_2d


# Subject lists come from comma-separated command-line values.
subjects_train = args.subjects_train.split(',')
if args.subjects_unlabeled:
    subjects_semi = args.subjects_unlabeled.split(',')
else:
    subjects_semi = []
# When rendering, only the visualisation subject is used for testing.
subjects_test = [args.viz_subject] if args.render else args.subjects_test.split(',')


# '*' means "no action filtering"; anything else is a comma-separated list.
if args.actions == '*':
    action_filter = None
else:
    action_filter = args.actions.split(',')
    print('Selected actions:', action_filter)


# set receptive_field as number assigned
receptive_field = args.number_of_frames
print('INFO: Receptive field: {} frames'.format(receptive_field))
if not args.nolog:
    writer.add_text(args.log+'_'+TIMESTAMP + '/Receptive field', str(receptive_field))
# Padding on each side
pad = (receptive_field - 1) // 2
# Validation-loss threshold a new best checkpoint has to beat.
min_loss = args.min_loss

# Three FinePOSE instances are built from the same args and left/right joint lists:
#   model_pos_train     - is_train=True; its parameters are handed to the optimizer in the training section.
#   model_pos_test_temp - is_train=False; receives the training weights for the end-of-epoch validation pass.
#   model_pos           - is_train=False with explicit num_proposals / sampling_timesteps; used by evaluate().
model_pos_train = FinePOSE(args, joints_left, joints_right, is_train=True)
model_pos_test_temp = FinePOSE(args,joints_left, joints_right, is_train=False)
model_pos = FinePOSE(args,joints_left, joints_right,  is_train=False, num_proposals=args.num_proposals, sampling_timesteps=args.sampling_timesteps)


def set_requires_grad(nets, requires_grad=False):
    """Switch gradient tracking on or off for every parameter of the given networks.

    Args:
        nets (nn.Module | list[nn.Module]): One network or a list of networks.
            ``None`` entries in the list are skipped.
        requires_grad (bool): Value assigned to ``requires_grad`` of each
            parameter.
    """
    networks = nets if isinstance(nets, list) else [nets]
    for network in networks:
        if network is None:
            continue
        for weight in network.parameters():
            weight.requires_grad = requires_grad


def encode_text(text):
    """Tokenize ``text`` with CLIP's tokenizer (truncating long input) and return the tokens on the GPU."""
    with torch.no_grad():
        tokens = clip.tokenize(text, truncate=True)
        return tokens.cuda()

#################
causal_shift = 0
# Total number of elements over all parameters of model_pos.
model_params = sum(weight.numel() for weight in model_pos.parameters())
print('INFO: Trainable parameter count:', model_params/1000000, 'Million')
if not args.nolog:
    writer.add_text(
        args.log+'_'+TIMESTAMP + '/Trainable parameter count',
        str(model_params/1000000) + ' Million',
    )

# make model parallel: when CUDA is available each model is wrapped in
# DataParallel and moved to the GPU; otherwise the bare modules are kept.
if torch.cuda.is_available():
    model_pos = nn.DataParallel(model_pos).cuda()
    model_pos_train = nn.DataParallel(model_pos_train).cuda()
    model_pos_test_temp = nn.DataParallel(model_pos_test_temp).cuda()

if args.resume or args.evaluate:
    # --resume takes precedence over --evaluate when choosing the file name.
    chk_filename = os.path.join(args.checkpoint, args.resume or args.evaluate)
    print('Loading checkpoint', chk_filename)
    # The map_location callback returns the storage unchanged instead of
    # moving it to the device recorded in the file.
    checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
    print('This model was trained for {} epochs'.format(checkpoint['epoch']))
    # strict=False: missing or unexpected keys are tolerated.
    for _restored_model in (model_pos_train, model_pos):
        _restored_model.load_state_dict(checkpoint['model_pos'], strict=False)


# Generator over the test sequences: no augmentation, with the per-sequence
# valid arrays passed along.
test_generator = UnchunkedGenerator_Seq(
    None,
    out_poses_3d_test,
    out_poses_2d_test,
    pad=pad,
    causal_shift=causal_shift,
    augment=False,
    kps_left=kps_left,
    kps_right=kps_right,
    joints_left=joints_left,
    joints_right=joints_right,
    valid_frame=valid_frame,
)
print('INFO: Testing on {} frames'.format(test_generator.num_frames()))
if not args.nolog:
    writer.add_text(
        args.log+'_'+TIMESTAMP + '/Testing Frames',
        str(test_generator.num_frames()),
    )

def _pad_frames_to(seq, length):
    """Right-pad ``seq`` along dim 0 to ``length`` entries by repeating its last entry."""
    missing = length - seq.shape[0]
    if missing <= 0:
        return seq
    tail = seq[-1:].expand(missing, *seq.shape[1:])
    return torch.cat([seq, tail], dim=0)


def eval_data_prepare(receptive_field, inputs_2d, inputs_3d, valid_frame):
    """Cut one sequence into fixed-length windows of ``receptive_field`` frames.

    The windows are non-overlapping except for the last one, which always
    holds the final ``receptive_field`` frames and therefore overlaps its
    predecessor when the frame count is not a multiple of ``receptive_field``.
    Sequences shorter than ``receptive_field`` are right-padded by repeating
    the last frame.

    Args:
        receptive_field (int): Number of frames per window.
        inputs_2d (torch.Tensor): 2D input; all size-1 dims are squeezed away
            and the result is indexed as (frames, joints, channels).
        inputs_3d (torch.Tensor): 3D input; must match ``inputs_2d`` in every
            dim but the last.
        valid_frame (torch.Tensor): Per-frame flags; a trailing dim of size 1
            is added, so it is presumably 1-D - TODO confirm.

    Returns:
        tuple: ``(eval_input_2d, eval_input_3d, eval_valid_frame)`` of shape
        ``(num_windows, receptive_field, ...)``. They are allocated with
        ``torch.empty`` using its default dtype and device.
    """
    assert inputs_2d.shape[:-1] == inputs_3d.shape[:-1], "2d and 3d inputs shape must be same! "+str(inputs_2d.shape)+str(inputs_3d.shape)
    # NOTE(review): squeeze() drops every size-1 dim, so a one-frame sequence
    # would also lose its frame axis - confirm callers never pass one.
    inputs_2d_p = torch.squeeze(inputs_2d)
    inputs_3d_p = torch.squeeze(inputs_3d)
    valid_frame = valid_frame.unsqueeze(1)

    # Ceil division in integers (replaces the float-division comparisons).
    num_frames = inputs_2d_p.shape[0]
    out_num = (num_frames + receptive_field - 1) // receptive_field

    eval_input_2d = torch.empty(out_num, receptive_field, inputs_2d_p.shape[1], inputs_2d_p.shape[2])
    eval_input_3d = torch.empty(out_num, receptive_field, inputs_3d_p.shape[1], inputs_3d_p.shape[2])
    eval_valid_frame = torch.empty(out_num, receptive_field, 1)

    # Every window but the last is a plain consecutive slice.
    for i in range(out_num - 1):
        start = i * receptive_field
        stop = start + receptive_field
        eval_input_2d[i] = inputs_2d_p[start:stop]
        eval_input_3d[i] = inputs_3d_p[start:stop]
        eval_valid_frame[i] = valid_frame[start:stop]

    # Short sequences are extended by repeating the last frame. This gives the
    # same values as the former rearrange + F.pad(mode='replicate') code, but
    # needs no branch-local import (which made F a local name that was unbound
    # if a later branch ran without the first) and has no rank or dtype
    # restrictions.
    inputs_2d_p = _pad_frames_to(inputs_2d_p, receptive_field)
    inputs_3d_p = _pad_frames_to(inputs_3d_p, receptive_field)
    valid_frame = _pad_frames_to(valid_frame, receptive_field)

    # The last window is anchored at the end of the sequence.
    eval_input_2d[-1] = inputs_2d_p[-receptive_field:]
    eval_input_3d[-1] = inputs_3d_p[-receptive_field:]
    eval_valid_frame[-1] = valid_frame[-receptive_field:]

    return eval_input_2d, eval_input_3d, eval_valid_frame


def pose_post_process(pose_pred, data_list, keys, receptive_field):
    """Write windowed predictions back into ``data_list[keys]`` and reverse its axes.

    Window ``i`` (all but the last) fills positions
    ``[i * receptive_field, (i + 1) * receptive_field)`` along axis 1 of the
    stored array. The last window is written to the final ``receptive_field``
    positions. The stored array is then replaced by its ``(3, 2, 1, 0)``
    transpose.

    Returns:
        The same ``data_list`` object, modified in place.
    """
    buffer = data_list[keys]
    leading_windows = pose_pred.shape[0] - 1
    for window_idx in range(leading_windows):
        begin = window_idx * receptive_field
        buffer[:, begin:begin + receptive_field] = pose_pred[window_idx]
    # The last window is aligned with the end of the sequence.
    buffer[:, -receptive_field:] = pose_pred[-1]
    data_list[keys] = buffer.transpose(3, 2, 1, 0)
    return data_list

def cam_mm_to_pix(cam, cam_data):
    """Rescale the first four entries of ``cam`` in place and return ``cam``.

    ``cam_data`` is ``(w, h, ss_x, ss_y)``. Entries 0 and 2 of ``cam`` are
    multiplied by ``w / ss_x`` and entries 1 and 3 by ``h / ss_y``. Entries 2
    and 3 are additionally shifted by half the width and half the height.
    """
    width, height = cam_data[0], cam_data[1]
    scale_x = width / cam_data[2]
    scale_y = height / cam_data[3]

    cam[0] = cam[0] * scale_x
    cam[1] = cam[1] * scale_y
    cam[2] = cam[2] * scale_x + width/2
    cam[3] = cam[3] * scale_y + height/2
    return cam
###################


# Fixed text prompts passed to the model alongside the poses.
pre_text_information = ["A person", "speed", "head", "body", "arm", "leg"]

# Tokenize each prompt separately and concatenate the results along dim 0.
pre_text_tensor = torch.cat(
    [encode_text(prompt) for prompt in pre_text_information],
    dim=0,
)



# Training start
# Training section: runs only when no --evaluate checkpoint was requested.
if not args.evaluate:
    # NOTE(review): flag_best_20_10, initial_momentum and final_momentum are
    # not referenced anywhere in the visible part of the file.
    flag_best_20_10 = False

    lr = args.learning_rate
    optimizer = optim.AdamW(model_pos_train.parameters(), lr=lr, weight_decay=0.1)

    lr_decay = args.lr_decay
    # Per-epoch history lists. Only losses_3d_train, losses_3d_pos_train and
    # losses_3d_valid are appended to by the loop below; the others stay empty.
    losses_3d_train = []
    losses_3d_pos_train = []
    losses_3d_diff_train = []
    losses_3d_train_eval = []
    losses_3d_valid = []
    losses_3d_depth_valid = []

    epoch = 0
    best_epoch = 0
    initial_momentum = 0.1
    final_momentum = 0.001

    # get training data

    # Shuffled, chunked generator used for optimisation; its first argument is
    # batch_size // stride.
    train_generator = ChunkedGenerator_Seq(args.batch_size//args.stride, None, out_poses_3d_train, out_poses_2d_train, args.number_of_frames,
                                       pad=pad, causal_shift=causal_shift, shuffle=True, augment=args.data_augmentation,
                                       kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right)
    # Unchunked generator over the same data; in the visible code it is only
    # used to report the number of training frames.
    train_generator_eval = UnchunkedGenerator_Seq(None, out_poses_3d_train, out_poses_2d_train,
                                              pad=pad, causal_shift=causal_shift, augment=False)
    print('INFO: Training on {} frames'.format(train_generator_eval.num_frames()))
    if not args.nolog:
        writer.add_text(args.log+'_'+TIMESTAMP + '/Training Frames', str(train_generator_eval.num_frames()))

    # Resume: restore the epoch counter and, when stored, the optimizer state
    # together with the generator's random state.
    if args.resume:
        epoch = checkpoint['epoch']
        if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
            train_generator.set_random_state(checkpoint['random_state'])
        else:
            print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.')
        # Unless --coverlr is set, the tracked learning rate also comes from
        # the checkpoint rather than from the command line.
        if not args.coverlr:
            lr = checkpoint['lr']

    print('** Note: reported losses are averaged over all frames.')
    print('** The final evaluation will be carried out after the last training epoch.')

    # Pos model only
    # One iteration of this loop is one epoch: train, validate, log, checkpoint.
    while epoch < args.epochs:
        start_time = time()
        # Running sums weighted by the number of frames seen; divided by N at
        # the end of the training pass.
        epoch_loss_3d_train = 0
        epoch_loss_3d_pos_train = 0
        epoch_loss_3d_diff_train = 0
        epoch_loss_traj_train = 0
        epoch_loss_2d_train_unlabeled = 0
        N = 0
        N_semi = 0
        model_pos_train.train()
        iteration = 0

        num_batches = train_generator.batch_num()

        # Just train 1 time, for quick debug
        quickdebug=args.debug
        for cameras_train, batch_3d, batch_2d in train_generator.next_epoch():

            

            # Progress report every 1000 batches.
            if iteration % 1000 == 0:
                print("%d/%d"% (iteration, num_batches))

            if cameras_train is not None:
                cameras_train = torch.from_numpy(cameras_train.astype('float32'))
            inputs_3d = torch.from_numpy(batch_3d.astype('float32'))
            inputs_2d = torch.from_numpy(batch_2d.astype('float32'))

            # Repeat the prompt tokens once per sample in the batch.
            pre_text_tensor_train = pre_text_tensor.unsqueeze(dim=0)
            pre_text_tensor_train = pre_text_tensor_train.repeat(inputs_3d.shape[0], 1, 1)

            if torch.cuda.is_available():
                inputs_3d = inputs_3d.cuda()
                inputs_2d = inputs_2d.cuda()
                pre_text_tensor_train = pre_text_tensor_train.cuda()
                if cameras_train is not None:
                    cameras_train = cameras_train.cuda()
            # Keep a copy of joint 14 (not used further in the visible code),
            # then zero that joint in the target.
            inputs_traj = inputs_3d[:, :, 14:15].clone()
            inputs_3d[:, :, 14] = 0

            optimizer.zero_grad()

            # Predict 3D poses
            predicted_3d_pos = model_pos_train(inputs_2d, inputs_3d, pre_text_tensor_train)

            loss_3d_pos = mpjpe(predicted_3d_pos, inputs_3d)

            loss_total = loss_3d_pos
            
            # NOTE(review): the detached loss is passed as the `gradient`
            # argument of backward(), which scales all parameter gradients by
            # the current loss value - confirm this is intended.
            loss_total.backward(loss_total.clone().detach())

            loss_total = torch.mean(loss_total)

            # Accumulate losses weighted by batch size * frames per sample.
            epoch_loss_3d_train += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_total.item()
            epoch_loss_3d_pos_train += inputs_3d.shape[0] * inputs_3d.shape[1] * loss_3d_pos.item()
            N += inputs_3d.shape[0] * inputs_3d.shape[1]

            optimizer.step()


            iteration += 1

            # In debug mode, stop once N equals the frame count of a single
            # batch, i.e. after the first batch.
            if quickdebug:
                if N==inputs_3d.shape[0] * inputs_3d.shape[1]:
                    break

        # Frame-averaged training losses for this epoch.
        losses_3d_train.append(epoch_loss_3d_train / N)
        losses_3d_pos_train.append(epoch_loss_3d_pos_train / N)

        # End-of-epoch evaluation
        with torch.no_grad():
            # Copy the current training weights into the evaluation model
            # (strict=False tolerates key mismatches).
            model_pos_test_temp.load_state_dict(model_pos_train.state_dict(), strict=False)
            model_pos_test_temp.eval()

            epoch_loss_3d_valid = None
            epoch_loss_3d_depth_valid = 0
            epoch_loss_traj_valid = 0
            epoch_loss_2d_valid = 0
            epoch_loss_3d_vel = 0
            # N and iteration are reused from the training pass and reset here.
            N = 0
            iteration = 0
            if not args.no_eval:
                # Evaluate on test set
                for cam, batch, batch_2d, batch_valid, _ in test_generator.next_epoch():
                    inputs_3d = torch.from_numpy(batch.astype('float32'))
                    inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
                    inputs_valid = torch.from_numpy(batch_valid.astype('float32'))

                    ##### apply test-time-augmentation (following D3DP)
                    # Flipped copy: negate channel 0 and swap the left/right
                    # keypoint indices.
                    inputs_2d_flip = inputs_2d.clone()
                    inputs_2d_flip[:, :, :, 0] *= -1
                    inputs_2d_flip[:, :, kps_left + kps_right, :] = inputs_2d_flip[:, :, kps_right + kps_left, :]

                    ##### convert size
                    # Cut the sequence into windows of receptive_field frames.
                    # NOTE(review): this rebinds the module-level name
                    # `valid_frame` (the per-sequence dict built during data
                    # loading) to a tensor - check that no later code still
                    # expects the dict.
                    inputs_3d_p = inputs_3d
                    inputs_2d, inputs_3d, valid_frame = eval_data_prepare(receptive_field, inputs_2d, inputs_3d_p, inputs_valid)
                    inputs_2d_flip, _, _ = eval_data_prepare(receptive_field, inputs_2d_flip, inputs_3d_p, inputs_valid)
                    pre_text_tensor_valid = pre_text_tensor.unsqueeze(dim=0)
                    pre_text_tensor_valid = pre_text_tensor_valid.repeat(inputs_2d.shape[0], 1, 1)

                    # valid_frame is not moved to the GPU here.
                    if torch.cuda.is_available():
                        inputs_3d = inputs_3d.cuda()
                        inputs_2d = inputs_2d.cuda()
                        inputs_2d_flip = inputs_2d_flip.cuda()
                        pre_text_tensor_valid = pre_text_tensor_valid.cuda()
                    inputs_3d[:, :, 14] = 0

                    # The windows are evaluated in mini-batches of 4.
                    bs = 4
                    total_batch = (inputs_3d.shape[0] + bs - 1) // bs

                    for batch_cnt in range(total_batch):

                        # The last mini-batch may hold fewer than bs windows.
                        if (batch_cnt + 1) * bs > inputs_3d.shape[0]:
                            inputs_2d_single = inputs_2d[batch_cnt * bs:]
                            inputs_2d_flip_single = inputs_2d_flip[batch_cnt * bs:]
                            inputs_3d_single = inputs_3d[batch_cnt * bs:]
                            pre_text_tensor_valid_single = pre_text_tensor_valid[batch_cnt * bs:]
                            valid_frame_single = valid_frame[batch_cnt * bs:]
                        else:
                            inputs_2d_single = inputs_2d[batch_cnt * bs:(batch_cnt + 1) * bs]
                            inputs_2d_flip_single = inputs_2d_flip[batch_cnt * bs:(batch_cnt + 1) * bs]
                            inputs_3d_single = inputs_3d[batch_cnt * bs:(batch_cnt + 1) * bs]
                            pre_text_tensor_valid_single = pre_text_tensor_valid[batch_cnt * bs:(batch_cnt + 1) * bs]
                            valid_frame_single = valid_frame[batch_cnt * bs:(batch_cnt + 1) * bs]


                        predicted_3d_pos_single = model_pos_test_temp(inputs_2d_single, inputs_3d_single, pre_text_tensor_valid_single,
                                                      input_2d_flip=inputs_2d_flip_single)

                        # The prediction is indexed with five dims here; index
                        # 14 of the fifth dim is zeroed, mirroring the zeroing
                        # of joint 14 in the target.
                        predicted_3d_pos_single[:, :, :, :, 14] = 0

                        error = mpjpe_diffusion_3dhp(predicted_3d_pos_single, inputs_3d_single, valid_frame_single.type(torch.bool))

                        # The accumulator starts as None, so the first
                        # mini-batch assigns and later ones add.
                        if iteration == 0:
                            epoch_loss_3d_valid = inputs_3d_single.shape[0] * inputs_3d_single.shape[1] * error.clone()
                        else:
                            epoch_loss_3d_valid += inputs_3d_single.shape[0] * inputs_3d_single.shape[1] * error.clone()

                        N += inputs_3d_single.shape[0] * inputs_3d_single.shape[1]


                        iteration += 1

                        # Debug mode: stop after the first mini-batch ...
                        if quickdebug:
                            if N == inputs_3d_single.shape[0] * inputs_3d_single.shape[1]:
                                break

                    # ... and after the first sequence.
                    if quickdebug:
                        if N == inputs_3d_single.shape[0] * inputs_3d_single.shape[1]:
                            break


                # Frame-weighted mean error for this epoch; later code reads
                # element [0] of this entry.
                losses_3d_valid.append(epoch_loss_3d_valid / N)


        # Wall-clock duration of this epoch in minutes.
        elapsed = (time() - start_time) / 60

        log_path = os.path.join(args.checkpoint, 'training_log.txt')
        if args.no_eval:
            # losses_3d_diff_train is never appended to by the training loop,
            # so indexing [-1] raised IndexError; report NaN while it is empty.
            diff_train = losses_3d_diff_train[-1] * 1000 if losses_3d_diff_train else float('nan')
            epoch_summary = '[%d] time %.2f lr %f 3d_train %f 3d_pos_train %f 3d_diff_train %f' % (
                epoch + 1,
                elapsed,
                lr,
                losses_3d_train[-1] * 1000,
                losses_3d_pos_train[-1] * 1000,
                diff_train
            )
        else:
            epoch_summary = '[%d] time %.2f lr %f 3d_train %f 3d_pos_train %f 3d_pos_valid %f' % (
                epoch + 1,
                elapsed,
                lr,
                losses_3d_train[-1],
                losses_3d_pos_train[-1],
                losses_3d_valid[-1][0]
            )

        # The same line goes to stdout and is appended to the training log; the
        # context manager closes the file even if the write fails.
        print(epoch_summary)
        with open(log_path, mode='a') as f:
            f.write(epoch_summary + '\n')

        if not args.nolog:
            if not args.no_eval:
                # Log element [0], the value that is printed and used for
                # best-checkpoint selection, rather than the whole tensor.
                writer.add_scalar("Loss/3d validation loss", losses_3d_valid[-1][0] * 1000, epoch+1)
            writer.add_scalar("Loss/3d training loss", losses_3d_train[-1] * 1000, epoch+1)
            writer.add_scalar("Parameters/learing rate", lr, epoch+1)
            writer.add_scalar('Parameters/training time per epoch', elapsed, epoch+1)
        # Decay learning rate exponentially
        lr *= lr_decay
        for param_group in optimizer.param_groups:
            param_group['lr'] *= lr_decay
        epoch += 1

        # Save checkpoint if necessary (periodic checkpoints start at epoch 80)
        if epoch % args.checkpoint_frequency == 0 and epoch >= 80:
            chk_path = os.path.join(args.checkpoint, 'epoch_{}.bin'.format(epoch))
            print('Saving checkpoint to', chk_path)

            torch.save({
                'epoch': epoch,
                'lr': lr,
                'random_state': train_generator.random_state(),
                'optimizer': optimizer.state_dict(),
                'model_pos': model_pos_train.state_dict(),
            }, chk_path)

        #### save best checkpoint
        best_chk_path = os.path.join(args.checkpoint, 'best_epoch_1_1.bin')
        best_chk_path_epoch = os.path.join(args.checkpoint, 'best_epoch_20_10.bin')
        # losses_3d_valid stays empty when evaluation is disabled, so guard the
        # lookup instead of raising IndexError.
        if losses_3d_valid and losses_3d_valid[-1][0] < min_loss:
            # Track the same scalar that is compared above (element [0]), as a
            # plain float, instead of the whole per-epoch tensor.
            min_loss = float(losses_3d_valid[-1][0])
            best_epoch = epoch
            print("save best checkpoint")
            best_state = {
                'epoch': epoch,
                'lr': lr,
                'random_state': train_generator.random_state(),
                'optimizer': optimizer.state_dict(),
                'model_pos': model_pos_train.state_dict(),
            }
            torch.save(best_state, best_chk_path)
            # From epoch 80 on, the best model is also written to a second file.
            if epoch >= 80:
                torch.save(best_state, best_chk_path_epoch)

            # Mark the best epoch in the training log.
            with open(log_path, mode='a') as f:
                f.write('best epoch\n')

        # Save training curves after every epoch, as .png images (if requested)
        if args.export_training_curves and epoch > 3:
            # Import unconditionally: the former `'matplotlib' not in
            # sys.modules` guard left `plt` unbound whenever another library
            # had already imported matplotlib.
            import matplotlib
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt

            plt.figure()
            # The first three recorded epochs are skipped.
            epoch_x = np.arange(3, len(losses_3d_train)) + 1
            curves = (
                ('3d train', losses_3d_train[3:], '--', 'C0'),
                ('3d train (eval)', losses_3d_train_eval[3:], '-', 'C0'),
                # Validation entries are indexable tensors; plot element [0],
                # the value reported in the log, as a plain float.
                ('3d valid (eval)', [float(v[0]) for v in losses_3d_valid[3:]], '-', 'C1'),
            )
            plotted_labels = []
            for label, values, line_style, color in curves:
                # Skip histories that were not recorded for every plotted epoch
                # (losses_3d_train_eval is never filled; losses_3d_valid is
                # empty without evaluation) - plt.plot would otherwise raise on
                # the length mismatch.
                if len(values) != len(epoch_x):
                    continue
                plt.plot(epoch_x, values, line_style, color=color)
                plotted_labels.append(label)
            plt.legend(plotted_labels)
            plt.ylabel('MPJPE (m)')
            plt.xlabel('Epoch')
            plt.xlim((3, epoch))
            plt.savefig(os.path.join(args.checkpoint, 'loss_3d.png'))

            plt.close('all')
# Training end

# Evaluate
def evaluate(test_generator, action=None, return_predictions=False, use_trajectory_model=False, newmodel=None):
    """Run multi-hypothesis inference over ``test_generator`` and report MPJPE.

    For every sequence yielded by ``test_generator.next_epoch()`` the model is
    run in mini-batches, and four aggregations of the hypotheses are collected
    (labels follow the original inline comments: P-Agg, P-Best, J-Best, J-Agg).
    Running errors are printed and appended to a text log in
    ``args.checkpoint``; after the last sequence the four aggregated
    predictions are written there as ``.mat`` files.

    Args:
        test_generator: object exposing ``next_epoch()`` (yielding
            ``(_, batch_3d, batch_2d, batch_valid, key)``) and
            ``augment_enabled()``.
        action: unused; kept for interface compatibility.
        return_predictions: if True, return the raw model output of the first
            mini-batch (squeezed, as a NumPy array) instead of error metrics.
        use_trajectory_model: if True, the checkpoint is not (re)loaded into
            ``model_pos``.
        newmodel: optional comparison model; when given, its weights are loaded
            from a fixed checkpoint path and a sliding-window input layout with
            a window of 81 frames is used.

    Returns:
        ``(e1, e1_mean)``: running mean errors, one entry per sampling
        timestep, computed with ``mpjpe_diffusion_3dhp`` without and with
        ``mean_pos=True``. When ``return_predictions`` is True, a NumPy array
        of predictions is returned instead.
    """
    epoch_loss_3d_pos = torch.zeros(args.sampling_timesteps).cuda()
    epoch_loss_3d_pos_mean = torch.zeros(args.sampling_timesteps).cuda()

    def eval_data_prepare_pf(receptive_field, inputs_2d, inputs_3d):
        # Sliding-window layout for the comparison model: replicate-pad the
        # frame axis by half a window on each side, then cut one window per frame.
        # `rearrange` resolves to the module-level einops import; the former
        # in-loop import made it an (unassigned) closure variable here.
        inputs_2d_p = torch.squeeze(inputs_2d)
        inputs_3d_p = inputs_3d.permute(1,0,2,3)
        padding = int(receptive_field//2)
        inputs_2d_p = rearrange(inputs_2d_p, 'b f c -> f c b')
        inputs_2d_p = F.pad(inputs_2d_p, (padding,padding), mode='replicate')
        inputs_2d_p = rearrange(inputs_2d_p, 'f c b -> b f c')
        out_num = inputs_2d_p.shape[0] - receptive_field + 1
        eval_input_2d = torch.empty(out_num, receptive_field, inputs_2d_p.shape[1], inputs_2d_p.shape[2])
        for i in range(out_num):
            eval_input_2d[i,:,:,:] = inputs_2d_p[i:i+receptive_field, :, :]
        return eval_input_2d, inputs_3d_p

    with torch.no_grad():
        if newmodel is not None:
            print('Loading comparison model')
            model_eval = newmodel
            chk_file_path = 'checkpoint/train_pf_00/epoch_60.bin'
            print('Loading evaluate checkpoint of comparison model', chk_file_path)
            checkpoint = torch.load(chk_file_path, map_location=lambda storage, loc: storage)
            model_eval.load_state_dict(checkpoint['model_pos'], strict=False)
            model_eval.eval()
        else:
            model_eval = model_pos
            if not use_trajectory_model:
                if args.evaluate == '':
                    # Must match the file name the training loop saves the best
                    # checkpoint under.
                    chk_file_path = os.path.join(args.checkpoint, 'best_epoch_1_1.bin')
                    print('Loading best checkpoint', chk_file_path)
                else:
                    chk_file_path = os.path.join(args.checkpoint, args.evaluate)
                    print('Loading evaluate checkpoint', chk_file_path)
                checkpoint = torch.load(chk_file_path, map_location=lambda storage, loc: storage)
                print('This model was trained for {} epochs'.format(checkpoint['epoch']))
                model_eval.load_state_dict(checkpoint['model_pos'])
                model_eval.eval()

        N = 0
        # NOTE(review): data_inference_all is filled below but never written to disk.
        data_inference_all = {}
        data_inference_mean = {}
        data_inference_h_min = {}
        data_inference_joint_min = {}
        data_inference_reproj_min = {}

        # Two camera parameter sets, converted from mm to pixel units.
        cam_1 = torch.tensor([7.32506, 7.32506, -0.0322884, 0.0929296, 0, 0, 0, 0, 0])
        cam_data_1 = [2048, 2048, 10, 10] #width, height, sensorSize_x, sensorSize_y
        cam_2 = torch.tensor([8.770747185, 8.770747185, -0.104908645, 0.104899704, 0, 0, 0, 0, 0])
        cam_data_2 = [1920, 1080, 10, 5.625]  # width, height, sensorSize_x, sensorSize_y
        cam_1 = cam_mm_to_pix(cam_1, cam_data_1)
        cam_2 = cam_mm_to_pix(cam_2, cam_data_2)

        quickdebug=args.debug
        for _, batch, batch_2d, batch_valid, keys in test_generator.next_epoch():

            inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
            inputs_3d = torch.from_numpy(batch.astype('float32'))
            inputs_valid = torch.from_numpy(batch_valid.astype('float32'))

            _, f_sz, j_sz, c_sz = inputs_3d.shape
            data_inference_all[keys] = np.zeros((args.sampling_timesteps, args.num_proposals, f_sz, j_sz, c_sz))
            data_inference_mean[keys] = np.zeros((args.sampling_timesteps, f_sz, j_sz, c_sz))
            data_inference_h_min[keys] = np.zeros((args.sampling_timesteps, f_sz, j_sz, c_sz))
            data_inference_joint_min[keys] = np.zeros((args.sampling_timesteps, f_sz, j_sz, c_sz))
            data_inference_reproj_min[keys] = np.zeros((args.sampling_timesteps, f_sz, j_sz, c_sz))

            print(keys)

            ##### apply test-time-augmentation (following D3DP)
            inputs_2d_flip = inputs_2d.clone()
            inputs_2d_flip[:, :, :, 0] *= -1
            inputs_2d_flip[:, :, kps_left + kps_right,:] = inputs_2d_flip[:, :, kps_right + kps_left,:]

            ##### convert size
            inputs_3d_p = inputs_3d
            if newmodel is not None:
                inputs_2d, inputs_3d = eval_data_prepare_pf(81, inputs_2d, inputs_3d_p)
                inputs_2d_flip, _ = eval_data_prepare_pf(81, inputs_2d_flip, inputs_3d_p)
            else:
                inputs_2d, inputs_3d, valid_frame = eval_data_prepare(receptive_field, inputs_2d, inputs_3d_p, inputs_valid)
                inputs_2d_flip, _, _ = eval_data_prepare(receptive_field, inputs_2d_flip, inputs_3d_p, inputs_valid)

            # One copy of the text embedding per prepared chunk.
            pre_text_tensor_valid = pre_text_tensor.unsqueeze(dim=0)
            pre_text_tensor_valid = pre_text_tensor_valid.repeat(inputs_2d.shape[0], 1, 1)

            if torch.cuda.is_available():
                inputs_2d = inputs_2d.cuda()
                inputs_2d_flip = inputs_2d_flip.cuda()
                inputs_3d = inputs_3d.cuda()
                pre_text_tensor_valid = pre_text_tensor_valid.cuda()

            # Camera / projection model depends only on the sequence key, so it
            # is selected once per sequence instead of once per mini-batch.
            if keys in ("TS5", "TS6"):
                cam, cam_data, reproject_func = cam_2, cam_data_2, project_to_2d
            else:
                cam, cam_data, reproject_func = cam_1, cam_data_1, project_to_2d_linear

            bs = 2
            total_batch = (inputs_3d.shape[0] + bs - 1) // bs

            # Per-batch results are collected in lists and concatenated once.
            all_chunks = []
            mean_chunks = []
            h_min_chunks = []
            joint_min_chunks = []
            reproj_min_chunks = []

            for batch_cnt in range(total_batch):
                # Slicing past the end is clamped, so the last (possibly
                # shorter) batch needs no special case.
                batch_slice = slice(batch_cnt * bs, (batch_cnt + 1) * bs)
                inputs_2d_single = inputs_2d[batch_slice]
                inputs_2d_flip_single = inputs_2d_flip[batch_slice]
                inputs_3d_single = inputs_3d[batch_slice]
                pre_text_tensor_valid_single = pre_text_tensor_valid[batch_slice]

                # Keep joint 14 as the trajectory and zero it in the target
                # (presumably the root joint — TODO confirm).
                traj = inputs_3d_single[:, :, 14:15].clone()
                inputs_3d_single[:, :, 14] = 0

                predicted_3d_pos_single = model_eval(inputs_2d_single, inputs_3d_single, pre_text_tensor_valid_single, input_2d_flip=inputs_2d_flip_single)

                predicted_3d_pos_single[:, :, :, :, 14] = 0

                if return_predictions:
                    # Only the raw prediction of the first mini-batch is needed;
                    # return before any aggregation or metric work.
                    return predicted_3d_pos_single.squeeze().cpu().numpy()

                # Sliced only after the early return: valid_frame is not
                # defined on the comparison-model path.
                valid_frame_single = valid_frame[batch_slice]

                # P-Agg
                mean_pose = torch.mean(predicted_3d_pos_single, dim=2, keepdim=False)

                # P-Best
                b,t,h,f,_,_ = predicted_3d_pos_single.shape
                target = inputs_3d_single.unsqueeze(1).unsqueeze(1).repeat(1, t, h, 1, 1, 1)
                errors = torch.norm(predicted_3d_pos_single - target, dim=-1)
                errors_h = rearrange(errors, 'b t h f n  -> t h b f n', ).reshape(t, h, -1)
                errors_h = torch.mean(errors_h, dim=-1, keepdim=True)
                h_min_indices = torch.min(errors_h, dim=1, keepdim=True).indices #t,1,1
                h_min_indices = h_min_indices.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).repeat(b, 1, 1, f, j_sz, c_sz)
                h_min_pose = torch.gather(predicted_3d_pos_single, 2, h_min_indices).squeeze(2)

                # J-Best
                joint_min_indices = torch.min(errors, dim=2, keepdim=True).indices  # b, t, 1, f, n
                joint_min_indices = joint_min_indices.unsqueeze(-1).repeat(1, 1, 1, 1, 1, c_sz)
                joint_min_pose = torch.gather(predicted_3d_pos_single, 2, joint_min_indices).squeeze(2)

                # J-Agg: reproject every hypothesis (with the trajectory added
                # back) and pick, per joint, the one closest to the 2D input.
                inputs_traj_single_all = traj.unsqueeze(1).unsqueeze(1).repeat(1, t, h, 1, 1, 1)
                predicted_3d_pos_abs_single = predicted_3d_pos_single + inputs_traj_single_all
                predicted_3d_pos_abs_single = predicted_3d_pos_abs_single.reshape(b * t * h * f, j_sz, c_sz)

                cam_single_all = cam.unsqueeze(0).repeat(b * t * h * f, 1).cuda()
                reproj_2d = reproject_func(predicted_3d_pos_abs_single, cam_single_all)
                reproj_2d = reproj_2d.reshape(b, t, h, f, j_sz, 2)

                target_2d = torch.from_numpy(image_coordinates(inputs_2d_single[..., :2].cpu().numpy(), w=cam_data[0], h=cam_data[1])).cuda()
                target_2d = target_2d.unsqueeze(1).unsqueeze(1).repeat(1, t, h, 1, 1, 1)
                errors_2d = torch.norm(reproj_2d - target_2d, dim=-1)
                reproj_min_indices = torch.min(errors_2d, dim=2, keepdim=True).indices
                reproj_min_indices = reproj_min_indices.unsqueeze(-1).repeat(1, 1, 1, 1, 1, c_sz)
                reproj_min_pose = torch.gather(predicted_3d_pos_single, 2, reproj_min_indices).squeeze(2)

                all_chunks.append(predicted_3d_pos_single.cpu().numpy())
                mean_chunks.append(mean_pose.cpu().numpy())
                h_min_chunks.append(h_min_pose.cpu().numpy())
                joint_min_chunks.append(joint_min_pose.cpu().numpy())
                reproj_min_chunks.append(reproj_min_pose.cpu().numpy())

                valid_mask = valid_frame_single.type(torch.bool)
                error = mpjpe_diffusion_3dhp(predicted_3d_pos_single, inputs_3d_single, valid_mask)
                error_mean = mpjpe_diffusion_3dhp(predicted_3d_pos_single, inputs_3d_single, valid_mask, mean_pos=True)

                # Weight each batch by its chunk-count * frames-per-chunk.
                n_batch_frames = inputs_3d_single.shape[0] * inputs_3d_single.shape[1]
                epoch_loss_3d_pos += n_batch_frames * error.clone()
                epoch_loss_3d_pos_mean += n_batch_frames * error_mean.clone()
                N += n_batch_frames

            all_3d_pose_pred = np.concatenate(all_chunks, axis=0)
            mean_3d_pose_pred = np.concatenate(mean_chunks, axis=0)
            h_min_3d_pose_pred = np.concatenate(h_min_chunks, axis=0)
            joint_min_3d_pose_pred = np.concatenate(joint_min_chunks, axis=0)
            reproj_min_3d_pose_pred = np.concatenate(reproj_min_chunks, axis=0)

            # Lay the chunks back onto the frame axis; the last chunk is
            # aligned to the end of the sequence.
            for ii in range(all_3d_pose_pred.shape[0]-1):
                data_inference_all[keys][:, :, ii*receptive_field:(ii+1)*receptive_field] = all_3d_pose_pred[ii]
            data_inference_all[keys][:, :, -receptive_field:] = all_3d_pose_pred[-1]
            data_inference_all[keys] = data_inference_all[keys].transpose(4,3,2,1,0)

            data_inference_mean = pose_post_process(mean_3d_pose_pred, data_inference_mean, keys, receptive_field)
            data_inference_h_min = pose_post_process(h_min_3d_pose_pred, data_inference_h_min, keys, receptive_field)
            data_inference_joint_min = pose_post_process(joint_min_3d_pose_pred, data_inference_joint_min, keys, receptive_field)
            data_inference_reproj_min = pose_post_process(reproj_min_3d_pose_pred, data_inference_reproj_min, keys, receptive_field)

            # Running averages over everything evaluated so far.
            e1 = (epoch_loss_3d_pos / N)
            e1_mean = (epoch_loss_3d_pos_mean / N)

            log_path = os.path.join(args.checkpoint, '3dhp_test_log_H%d_K%d.txt' %(args.num_proposals, args.sampling_timesteps))
            # Context manager so the log is closed even if a write fails.
            with open(log_path, mode='a') as f:
                if keys is None:
                    print('----------')
                else:
                    print('----'+keys+'----')
                    f.write('----'+keys+'----\n')

                print('Test time augmentation:', test_generator.augment_enabled())
                for ii in range(e1.shape[0]):
                    print('step %d : Protocol #1 Error (MPJPE) P_Best:' % ii, e1[ii].item(), 'mm')
                    f.write('step %d : Protocol #1 Error (MPJPE) P_Best: %f mm\n' % (ii, e1[ii].item()))
                    print('step %d : Protocol #1 Error (MPJPE) P_Agg:' % ii, e1_mean[ii].item(), 'mm')
                    f.write('step %d : Protocol #1 Error (MPJPE) P_Agg: %f mm\n' % (ii, e1_mean[ii].item()))

                print('----------')
                f.write('----------\n')

            if quickdebug:
                break


        mat_path_mean = os.path.join(args.checkpoint, 'inference_data_P_Agg.mat')
        scio.savemat(mat_path_mean, data_inference_mean)
        mat_path_h_min = os.path.join(args.checkpoint, 'inference_data_P_Best.mat')
        scio.savemat(mat_path_h_min, data_inference_h_min)
        mat_path_joint_min = os.path.join(args.checkpoint, 'inference_data_J_Best.mat')
        scio.savemat(mat_path_joint_min, data_inference_joint_min)
        mat_path_reproj_min = os.path.join(args.checkpoint, 'inference_data_J_Agg.mat')
        scio.savemat(mat_path_reproj_min, data_inference_reproj_min)

    return e1, e1_mean

# Entry point: either render one subject/action/camera, or evaluate the test set.
if args.render:
    print('Rendering...')

    input_keypoints = keypoints[args.viz_subject][args.viz_action][args.viz_camera].copy()
    ground_truth = None
    if args.viz_subject in dataset.subjects() and args.viz_action in dataset[args.viz_subject]:
        if 'positions_3d' in dataset[args.viz_subject][args.viz_action]:
            ground_truth = dataset[args.viz_subject][args.viz_action]['positions_3d'][args.viz_camera].copy()
    if ground_truth is None:
        print('INFO: this action is unlabeled. Ground truth will not be rendered.')

    gen = UnchunkedGenerator_Seq(None, [ground_truth], [input_keypoints],
                             pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
                             kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right)
    prediction = evaluate(gen, return_predictions=True)
    if args.compare:
        from common.model_poseformer import PoseTransformer
        model_pf = PoseTransformer(num_frame=81, num_joints=17, in_chans=2, num_heads=8, mlp_ratio=2., qkv_bias=False, qk_scale=None,drop_path_rate=0.1)
        if torch.cuda.is_available():
            model_pf = nn.DataParallel(model_pf)
            model_pf = model_pf.cuda()
        prediction_pf = evaluate(gen, newmodel=model_pf, return_predictions=True)

    # Stitch the chunked prediction back into one (frames, 17, 3) sequence.
    # The frame count is taken from the ground truth, so this is skipped for
    # unlabeled actions instead of crashing on `None`.
    if ground_truth is not None:
        num_frames = ground_truth.shape[0]
        if num_frames % receptive_field != 0:
            batch_num = (num_frames // receptive_field) + 1
            prediction2 = np.empty_like(ground_truth)
            for i in range(batch_num-1):
                prediction2[i*receptive_field:(i+1)*receptive_field,:,:] = prediction[i,:,:,:]
            # The last chunk only contributes its trailing, not-yet-covered frames.
            left_frames = num_frames - (batch_num-1)*receptive_field
            prediction2[-left_frames:,:,:] = prediction[-1,-left_frames:,:,:]
            prediction = prediction2
        else:
            # reshape returns a new array; the result has to be assigned.
            prediction = prediction.reshape(num_frames, 17, 3)

    if args.viz_export is not None:
        print('Exporting joint positions to', args.viz_export)
        # Predictions are in camera space
        np.save(args.viz_export, prediction)

    if args.viz_output is not None:
        if ground_truth is not None:
            # Reapply trajectory
            trajectory = ground_truth[:, :1]
            ground_truth[:, 1:] += trajectory
            prediction += trajectory
            if args.compare:
                prediction_pf += trajectory

        # Invert camera transformation
        cam = dataset.cameras()[args.viz_subject][args.viz_camera]
        if ground_truth is not None:
            if args.compare:
                prediction_pf = camera_to_world(prediction_pf, R=cam['orientation'], t=cam['translation'])
            prediction = camera_to_world(prediction, R=cam['orientation'], t=cam['translation'])
            ground_truth = camera_to_world(ground_truth, R=cam['orientation'], t=cam['translation'])
        else:
            # If the ground truth is not available, take the camera extrinsic params from a random subject.
            # They are almost the same, and anyway, we only need this for visualization purposes.
            for subject in dataset.cameras():
                if 'orientation' in dataset.cameras()[subject][args.viz_camera]:
                    rot = dataset.cameras()[subject][args.viz_camera]['orientation']
                    break
            if args.compare:
                prediction_pf = camera_to_world(prediction_pf, R=rot, t=0)
                prediction_pf[:, :, 2] -= np.min(prediction_pf[:, :, 2])
            prediction = camera_to_world(prediction, R=rot, t=0)
            # We don't have the trajectory, but at least we can rebase the height
            prediction[:, :, 2] -= np.min(prediction[:, :, 2])

        if args.compare:
            anim_output = {'PoseFormer': prediction_pf}
            anim_output['Ours'] = prediction
        else:
            # Render the model output (not a noisy copy of the ground truth).
            anim_output = {'Reconstruction': prediction}

        if ground_truth is not None and not args.viz_no_ground_truth:
            anim_output['Ground truth'] = ground_truth

        input_keypoints = image_coordinates(input_keypoints[..., :2], w=cam['res_w'], h=cam['res_h'])

        from common.visualization import render_animation
        render_animation(input_keypoints, keypoints_metadata, anim_output,
                        dataset.skeleton(), dataset.fps(), args.viz_bitrate, cam['azimuth'], args.viz_output,
                        limit=args.viz_limit, downsample=args.viz_downsample, size=args.viz_size,
                        input_video_path=args.viz_video, viewport=(cam['res_w'], cam['res_h']),
                        input_video_skip=args.viz_skip)

else:
    print('Evaluating...')
    all_actions = {}
    all_actions_flatten = []
    all_actions_by_subject = {}

    def run_evaluation_all_actions(actions, action_filter=None):
        """Evaluate the whole test set in one pass.

        ``actions`` and ``action_filter`` are accepted but not used; the
        generator is built from the module-level test arrays. Metrics are
        printed and logged by ``evaluate``; nothing is returned.
        """
        gen = UnchunkedGenerator_Seq(None, out_poses_3d_test, out_poses_2d_test,
                                     pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
                                     kps_left=kps_left, kps_right=kps_right, joints_left=joints_left,
                                     joints_right=joints_right, valid_frame=valid_frame)
        evaluate(gen)

    if not args.by_subject:
        run_evaluation_all_actions(all_actions_flatten, action_filter)

if not args.nolog:
    writer.close()