
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import time
import logging
import numpy as np
import torch

from outer_tools.lib.core.evaluate import accuracy,accuracy_kps,dist_acc
from outer_tools.lib.core.inference import get_final_preds
from outer_tools.lib.utils.transforms import flip_back
logger = logging.getLogger(__name__)


def train(config, train_loader, model, criterion, optimizer, epoch,
          output_dir, tb_log_dir, writer_dict):
    """Train ``model`` for one epoch over ``train_loader``.

    For every batch: forward pass, heatmap loss summed over every returned
    stage, backward pass, optimizer step, and running loss / accuracy
    bookkeeping.  Every ``config.PRINT_FREQ`` batches a progress line is
    logged and the latest loss and accuracy are written via ``writer_dict``.

    Args:
        config: experiment config; reads ``PRINT_FREQ`` and
            ``TRAIN.END_EPOCH``.
        train_loader: iterable of ``(input, target, target_weight, meta)``
            batches; ``meta`` is not used here.
        model: network whose call returns a pair ``(outputs, _)``, where
            ``outputs`` is either one heatmap tensor or a list of them
            (one per supervised stage).
        criterion: loss called as ``criterion(output, target, target_weight)``.
        optimizer: optimizer updating ``model``'s parameters.
        epoch: current epoch index; only used in the log message.
        output_dir: unused; kept for interface compatibility.
        tb_log_dir: unused; kept for interface compatibility.
        writer_dict: dict holding ``'writer'`` (an object exposing
            ``add_scalar``) and the integer ``'train_global_steps'``, which is
            incremented in place on every logging step.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()
    for i, (images, target, target_weight, meta) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.cuda()
        target = target.cuda(non_blocking=True)
        target_weight = target_weight.cuda(non_blocking=True)

        # compute output
        outputs, _ = model(images)

        if isinstance(outputs, list):
            # Intermediate supervision: sum the loss over every stage and
            # measure accuracy on the last one.  Binding ``output`` here
            # (instead of relying on the loop variable) keeps a one-element
            # list from leaving it undefined.
            output = outputs[-1]
            loss = criterion(outputs[0], target, target_weight)
            for stage_output in outputs[1:]:
                loss += criterion(stage_output, target, target_weight)
        else:
            output = outputs
            loss = criterion(output, target, target_weight)

        # compute gradient and do update step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        batch_size = images.size(0)
        losses.update(loss.item(), batch_size)

        _, avg_acc, cnt, _ = accuracy(output.detach().cpu().numpy(),
                                      target.detach().cpu().numpy())
        acc.update(avg_acc, cnt)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{0}/{1}][{2}/{3}]\t' \
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                  'Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                  'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
                  'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch, config.TRAIN.END_EPOCH, i, len(train_loader),
                      batch_time=batch_time,
                      speed=batch_size / batch_time.val,
                      data_time=data_time, loss=losses, acc=acc)
            logger.info(msg)

            writer = writer_dict['writer']
            global_steps = writer_dict['train_global_steps']
            writer.add_scalar('train_loss', losses.val, global_steps)
            writer.add_scalar('train_acc', acc.val, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1


def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             writer_dict=None, animalpose=False):
    """Evaluate ``model`` on the validation set.

    Runs every batch of ``val_loader`` under ``torch.no_grad()``, optionally
    averaging with a horizontally flipped pass, records loss and heatmap
    accuracy, converts heatmaps to final keypoint predictions and hands them
    to ``val_dataset.evaluate``.

    Args:
        config: experiment config; reads ``MODEL.NUM_JOINTS``, ``MODEL.NAME``,
            ``TEST.FLIP_TEST``, ``TEST.SHIFT_HEATMAP`` and ``PRINT_FREQ``.
        val_loader: iterable of ``(input, target, target_weight, meta)``;
            ``meta`` must provide ``'center'``, ``'scale'``, ``'score'``,
            ``'image'`` and, when ``animalpose``, ``'bbox_id'``.
        val_dataset: dataset providing ``len()``, ``flip_pairs`` and
            ``evaluate(...)``.
        model: network returning heatmaps, a list of per-stage heatmaps, or a
            ``(outputs, extra)`` pair like the models driven by ``train``.
        criterion: loss called as ``criterion(output, target, target_weight)``.
        output_dir: forwarded to ``val_dataset.evaluate``.
        writer_dict: optional dict with ``'writer'`` and
            ``'valid_global_steps'``; the step counter is incremented in place.
        animalpose: when true, ``meta['bbox_id']`` is stored in column 6 of
            the boxes array.

    Returns:
        ``(perf_indicator, acc_avg)``: the indicator returned by
        ``val_dataset.evaluate`` and the running-average heatmap accuracy.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    def _final_stage(outputs):
        # train()/validate_mt() unpack model(...) as ``(outputs, _)``; accept
        # that form too (a tuple would otherwise fail in criterion/flip_back),
        # then keep the last stage when a list of stages is returned.
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        return outputs[-1] if isinstance(outputs, list) else outputs

    def _predict(images):
        # Final-stage heatmaps, averaged with the flipped pass when
        # config.TEST.FLIP_TEST is set.
        output = _final_stage(model(images))
        if config.TEST.FLIP_TEST:
            # Flip along dim 3 (presumably the width axis of N, C, H, W).
            output_flipped = _final_stage(model(images.flip(3)))
            output_flipped = flip_back(output_flipped.cpu().numpy(),
                                       val_dataset.flip_pairs)
            output_flipped = torch.from_numpy(output_flipped.copy()).cuda()

            # feature is not aligned, shift flipped heatmap for higher accuracy
            if config.TEST.SHIFT_HEATMAP:
                output_flipped[:, :, :, 1:] = \
                    output_flipped.clone()[:, :, :, 0:-1]

            output = (output + output_flipped) * 0.5
        return output

    # switch to evaluate mode
    model.eval()

    num_samples = len(val_dataset)
    all_preds = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3),
                         dtype=np.float32)
    # Columns: center (0:2), scale (2:4), prod(scale * 200) (4), score (5)
    # and, for animalpose, bbox id (6).
    all_boxes = np.zeros((num_samples, 7 if animalpose else 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0

    with torch.no_grad():
        end = time.time()
        for i, (images, target, target_weight, meta) in enumerate(val_loader):
            # compute output
            output = _predict(images)

            target = target.cuda(non_blocking=True)
            target_weight = target_weight.cuda(non_blocking=True)

            loss = criterion(output, target, target_weight)

            num_images = images.size(0)
            # measure accuracy and record loss
            losses.update(loss.item(), num_images)
            _, avg_acc, cnt, _ = accuracy(output.cpu().numpy(),
                                          target.cpu().numpy())
            acc.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            score = meta['score'].numpy()

            preds, maxvals = get_final_preds(
                config, output.clone().cpu().numpy(), c, s)

            rows = slice(idx, idx + num_images)
            all_preds[rows, :, 0:2] = preds[:, :, 0:2]
            all_preds[rows, :, 2:3] = maxvals
            # double check this all_boxes parts
            all_boxes[rows, 0:2] = c[:, 0:2]
            all_boxes[rows, 2:4] = s[:, 0:2]
            all_boxes[rows, 4] = np.prod(s * 200, 1)
            all_boxes[rows, 5] = score
            if animalpose:
                all_boxes[rows, 6] = meta['bbox_id'].numpy()

            image_path.extend(meta['image'])
            idx += num_images

            if i % config.PRINT_FREQ == 0:
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, acc=acc)
                logger.info(msg)

        name_values, perf_indicator = val_dataset.evaluate(
            config, all_preds, output_dir, all_boxes, image_path,
            filenames, imgnums
        )
        # evaluate() may return one name->value mapping or a list of them.
        name_value_list = (name_values if isinstance(name_values, list)
                           else [name_values])

        model_name = config.MODEL.NAME
        for name_value in name_value_list:
            _print_name_value(name_value, model_name)

        if writer_dict:
            writer = writer_dict['writer']
            global_steps = writer_dict['valid_global_steps']
            writer.add_scalar('valid_loss', losses.avg, global_steps)
            writer.add_scalar('valid_acc', acc.avg, global_steps)
            for name_value in name_value_list:
                writer.add_scalars('valid', dict(name_value), global_steps)
            writer_dict['valid_global_steps'] = global_steps + 1
    logger.info(f"Average PCK: {acc.avg}")
    return perf_indicator, acc.avg


# evaluate both student and teacher models
def validate_mt(config, val_loader, val_dataset, model, model_ema, criterion, output_dir,
                tb_log_dir, writer_dict=None, animalpose=False):
    """Evaluate a student model and its EMA (teacher) copy side by side.

    Both networks see the same batches under ``torch.no_grad()``.  Per batch
    this records the student's supervised loss against ``target``, a
    consistency loss between student and EMA heatmaps (unit weight on every
    joint), and the heatmap accuracy of each model.  Both prediction sets are
    scored with ``val_dataset.evaluate`` and logged as separately labelled
    tables.

    Args:
        config: experiment config; reads ``MODEL.NUM_JOINTS``, ``MODEL.NAME``,
            ``TEST.FLIP_TEST``, ``TEST.SHIFT_HEATMAP`` and ``PRINT_FREQ``.
        val_loader: iterable of ``(input, target, target_weight, meta)``.
        val_dataset: dataset providing ``len()``, ``flip_pairs`` and
            ``evaluate(...)``.
        model: student network, called as ``model(x) -> (outputs, _)`` where
            ``outputs`` is a heatmap tensor or a list of per-stage heatmaps.
        model_ema: EMA (teacher) network with the same call convention.
        criterion: loss called as ``criterion(output, target, weight)``.
        output_dir: forwarded to ``val_dataset.evaluate``.
        tb_log_dir: unused; kept for interface compatibility.
        writer_dict: optional dict with ``'writer'`` and
            ``'valid_global_steps'``; the step counter is incremented in place.
        animalpose: when true, ``meta['bbox_id']`` is stored in column 6 of
            the boxes array.

    Returns:
        The performance indicator of the EMA (teacher) predictions; student
        metrics are only logged.
    """
    batch_time = AverageMeter()
    losses_sup = AverageMeter()
    losses_const = AverageMeter()
    acc = AverageMeter()
    acc_ema = AverageMeter()

    def _predict(net, images):
        # Final-stage heatmaps of ``net``, averaged with the flipped pass when
        # config.TEST.FLIP_TEST is set.
        outputs, _ = net(images)
        output = outputs[-1] if isinstance(outputs, list) else outputs
        if config.TEST.FLIP_TEST:
            outputs_flipped, _ = net(images.flip(3))
            output_flipped = (outputs_flipped[-1]
                              if isinstance(outputs_flipped, list)
                              else outputs_flipped)
            output_flipped = flip_back(output_flipped.cpu().numpy(),
                                       val_dataset.flip_pairs)
            output_flipped = torch.from_numpy(output_flipped.copy()).cuda()

            # feature is not aligned, shift flipped heatmap for higher accuracy
            if config.TEST.SHIFT_HEATMAP:
                output_flipped[:, :, :, 1:] = \
                    output_flipped.clone()[:, :, :, 0:-1]

            output = (output + output_flipped) * 0.5
        return output

    # switch to evaluate mode
    model.eval()
    model_ema.eval()

    num_samples = len(val_dataset)
    all_preds = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3),
                         dtype=np.float32)
    all_preds_ema = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3),
                             dtype=np.float32)
    # Columns: center (0:2), scale (2:4), prod(scale * 200) (4), score (5)
    # and, for animalpose, bbox id (6).  Shared by both prediction sets.
    all_boxes = np.zeros((num_samples, 7 if animalpose else 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    with torch.no_grad():
        end = time.time()
        for i, (images, target, target_weight, meta) in enumerate(val_loader):
            # compute output
            output = _predict(model, images)
            output_ema = _predict(model_ema, images)

            target = target.cuda(non_blocking=True)
            target_weight = target_weight.cuda(non_blocking=True)
            # ones_like already inherits target_weight's (CUDA) device.
            const_weight = torch.ones_like(target_weight)
            loss_sup = criterion(output, target, target_weight)
            loss_const = criterion(output, output_ema, const_weight)

            num_images = images.size(0)
            # measure accuracy and record loss
            losses_sup.update(loss_sup.item(), num_images)
            losses_const.update(loss_const.item(), num_images)
            _, avg_acc, cnt, _ = accuracy(output.cpu().numpy(),
                                          target.cpu().numpy())
            _, avg_acc_ema, cnt_ema, _ = accuracy(output_ema.cpu().numpy(),
                                                  target.cpu().numpy())
            acc.update(avg_acc, cnt)
            acc_ema.update(avg_acc_ema, cnt_ema)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            score = meta['score'].numpy()

            rows = slice(idx, idx + num_images)

            preds, maxvals = get_final_preds(
                config, output.clone().cpu().numpy(), c, s)
            all_preds[rows, :, 0:2] = preds[:, :, 0:2]
            all_preds[rows, :, 2:3] = maxvals

            preds_ema, maxvals_ema = get_final_preds(
                config, output_ema.clone().cpu().numpy(), c, s)
            all_preds_ema[rows, :, 0:2] = preds_ema[:, :, 0:2]
            all_preds_ema[rows, :, 2:3] = maxvals_ema

            # double check this all_boxes parts
            all_boxes[rows, 0:2] = c[:, 0:2]
            all_boxes[rows, 2:4] = s[:, 0:2]
            all_boxes[rows, 4] = np.prod(s * 200, 1)
            all_boxes[rows, 5] = score
            if animalpose:
                all_boxes[rows, 6] = meta['bbox_id'].numpy()

            image_path.extend(meta['image'])
            idx += num_images

            if i % config.PRINT_FREQ == 0:
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss_sup {loss.val:.4f} ({loss.avg:.4f})\t' \
                      'Loss_const {loss_const.val:.4f} ({loss_const.avg:.4f})\t' \
                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})\t' \
                      'Accuracy_ema {acc_ema.val:.3f} ({acc_ema.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses_sup, loss_const=losses_const, acc=acc, acc_ema=acc_ema)
                logger.info(msg)

        name_values, perf_indicator = val_dataset.evaluate(
            config, all_preds, output_dir, all_boxes, image_path,
            filenames, imgnums
        )

        name_values_ema, perf_indicator_ema = val_dataset.evaluate(
            config, all_preds_ema, output_dir, all_boxes, image_path,
            filenames, imgnums
        )

        def _as_list(results):
            # evaluate() may return one name->value mapping or a list of them.
            return results if isinstance(results, list) else [results]

        # Both tables share config.MODEL.NAME, so label them explicitly;
        # otherwise student and EMA rows are indistinguishable in the log.
        model_name = config.MODEL.NAME
        for label, results in (('Student', name_values),
                               ('EMA (teacher)', name_values_ema)):
            logger.info('%s model results:', label)
            for name_value in _as_list(results):
                _print_name_value(name_value, model_name)

        if writer_dict:
            writer = writer_dict['writer']
            global_steps = writer_dict['valid_global_steps']
            writer.add_scalar('valid_loss', losses_sup.avg, global_steps)
            writer.add_scalar('valid_acc', acc.avg, global_steps)
            writer.add_scalar('valid_acc_ema', acc_ema.avg, global_steps)
            for name_value in _as_list(name_values):
                writer.add_scalars('valid', dict(name_value), global_steps)
            writer_dict['valid_global_steps'] = global_steps + 1

    logger.info(f"Average PCK: {acc.avg}")
    logger.info(f"Average PCK (EMA): {acc_ema.avg}")
    return perf_indicator_ema


def validate_kps(config, val_loader, val_dataset, model, criterion, output_dir,
                 writer_dict=None, animalpose=False, vis=False, return_pck=False, level=3):
    """Evaluate ``model`` and also report per-keypoint and grouped PCK.

    Intended for AP-10K (17 keypoints).  Performs the same evaluation pass as
    ``validate``, additionally collecting per-keypoint distances from
    ``accuracy_kps``.  Afterwards it logs a PCK (via ``dist_acc``) for every
    keypoint and for five groups: Head, Front, Back, Shared and Exclusive.
    Head/Front/Back only aggregate keypoints that are shared at ``level``.

    Args:
        config: experiment config; reads ``MODEL.NUM_JOINTS``, ``MODEL.NAME``,
            ``TEST.FLIP_TEST``, ``TEST.SHIFT_HEATMAP`` and ``PRINT_FREQ``.
        val_loader: iterable of ``(input, target, target_weight, meta)``.
        val_dataset: dataset providing ``len()``, ``flip_pairs`` and
            ``evaluate(...)``.
        model: network returning heatmaps, a list of per-stage heatmaps, or a
            ``(outputs, extra)`` pair like the models driven by ``train``.
        criterion: loss called as ``criterion(output, target, target_weight)``.
        output_dir: forwarded to ``val_dataset.evaluate``.
        writer_dict: optional dict with ``'writer'`` and
            ``'valid_global_steps'``; the step counter is incremented in place.
        animalpose: when true, ``meta['bbox_id']`` is stored in column 6 of
            the boxes array.
        vis: unused.
        return_pck: when true, also return the mean heatmap accuracy.
        level: index (0-3) into the shared/exclusive keypoint splits below;
            higher levels share more keypoints and level 3 shares all 17.
            It is resolved before the evaluation pass, so an out-of-range
            value raises ``IndexError`` immediately.

    Returns:
        ``perf_indicator``, or ``(perf_indicator, acc_avg)`` if ``return_pck``.
    """
    # AP-10K keypoint names; presumably in heatmap channel order.
    kps = ["L_eye", "R_eye", "nose", "neck", "tail",
           "L_F_hip", "L_F_knee", "L_F_paw", "R_F_hip", "R_F_knee", "R_F_paw",
           "L_B_hip", "L_B_knee", "L_B_paw", "R_B_hip", "R_B_knee", "R_B_paw"]

    head_index = [0, 1, 2]
    front_index = [5, 6, 7, 8, 9, 10]
    back_index = [4, 11, 12, 13, 14, 15, 16]
    shared_indices = [
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4, 7, 10, 13, 16],
        [0, 1, 2, 3, 4, 5, 7, 8, 10, 11, 13, 14, 16],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    ]
    exclusive_indices = [
        [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
        [5, 6, 8, 9, 11, 12, 14, 15],
        [6, 9, 12, 15],
        [],
    ]
    # Resolve the split up front so a bad ``level`` fails before the
    # expensive evaluation pass rather than after it.
    shared_index = shared_indices[level]
    exclusive_index = exclusive_indices[level]

    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    def _final_stage(outputs):
        # train()/validate_mt() unpack model(...) as ``(outputs, _)``; accept
        # that form too (a tuple would otherwise fail in criterion/flip_back),
        # then keep the last stage when a list of stages is returned.
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        return outputs[-1] if isinstance(outputs, list) else outputs

    def _predict(images):
        # Final-stage heatmaps, averaged with the flipped pass when
        # config.TEST.FLIP_TEST is set.
        output = _final_stage(model(images))
        if config.TEST.FLIP_TEST:
            output_flipped = _final_stage(model(images.flip(3)))
            output_flipped = flip_back(output_flipped.cpu().numpy(),
                                       val_dataset.flip_pairs)
            output_flipped = torch.from_numpy(output_flipped.copy()).cuda()

            # feature is not aligned, shift flipped heatmap for higher accuracy
            if config.TEST.SHIFT_HEATMAP:
                output_flipped[:, :, :, 1:] = \
                    output_flipped.clone()[:, :, :, 0:-1]

            output = (output + output_flipped) * 0.5
        return output

    # switch to evaluate mode
    model.eval()

    num_samples = len(val_dataset)
    all_preds = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3),
                         dtype=np.float32)
    # Columns: center (0:2), scale (2:4), prod(scale * 200) (4), score (5)
    # and, for animalpose, bbox id (6).
    all_boxes = np.zeros((num_samples, 7 if animalpose else 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    # One list of accumulated distances per keypoint.
    kps_dist = [[] for _ in kps]

    with torch.no_grad():
        end = time.time()
        for i, (images, target, target_weight, meta) in enumerate(val_loader):
            # compute output
            output = _predict(images)

            target = target.cuda(non_blocking=True)
            target_weight = target_weight.cuda(non_blocking=True)

            loss = criterion(output, target, target_weight)

            num_images = images.size(0)
            # measure accuracy and record loss
            losses.update(loss.item(), num_images)
            _, avg_acc, cnt, _ = accuracy(output.cpu().numpy(),
                                          target.cpu().numpy())
            dists_kps = accuracy_kps(output.cpu().numpy(), target.cpu().numpy())
            for j, dist_list in enumerate(kps_dist):
                dist_list.extend(dists_kps[j])
            acc.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            score = meta['score'].numpy()

            preds, maxvals = get_final_preds(
                config, output.clone().cpu().numpy(), c, s)

            rows = slice(idx, idx + num_images)
            all_preds[rows, :, 0:2] = preds[:, :, 0:2]
            all_preds[rows, :, 2:3] = maxvals
            # double check this all_boxes parts
            all_boxes[rows, 0:2] = c[:, 0:2]
            all_boxes[rows, 2:4] = s[:, 0:2]
            all_boxes[rows, 4] = np.prod(s * 200, 1)
            all_boxes[rows, 5] = score
            if animalpose:
                all_boxes[rows, 6] = meta['bbox_id'].numpy()

            image_path.extend(meta['image'])
            idx += num_images

            if i % config.PRINT_FREQ == 0:
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, acc=acc)
                logger.info(msg)

        name_values, perf_indicator = val_dataset.evaluate(
            config, all_preds, output_dir, all_boxes, image_path,
            filenames, imgnums
        )
        # evaluate() may return one name->value mapping or a list of them.
        name_value_list = (name_values if isinstance(name_values, list)
                           else [name_values])

        model_name = config.MODEL.NAME
        for name_value in name_value_list:
            _print_name_value(name_value, model_name)

        if writer_dict:
            writer = writer_dict['writer']
            global_steps = writer_dict['valid_global_steps']
            writer.add_scalar('valid_loss', losses.avg, global_steps)
            writer.add_scalar('valid_acc', acc.avg, global_steps)
            for name_value in name_value_list:
                writer.add_scalars('valid', dict(name_value), global_steps)
            writer_dict['valid_global_steps'] = global_steps + 1
    logger.info(f"Average PCK: {acc.avg}")

    # Compute PCK of single keypoints
    kps_pck = np.zeros(len(kps_dist))
    for k, dist_list in enumerate(kps_dist):
        kps_pck[k] = dist_acc(np.array(dist_list))
    for name, pck in zip(kps, kps_pck):
        logger.info("{} PCK : {}".format(name, pck))

    # Compute PCK of group keypoints.
    # group_dist / group_pck order: head / front / back / shared / exclusive
    group_dist = [[] for _ in range(5)]
    for k, dist_list in enumerate(kps_dist):
        if k in shared_index:
            group_dist[3].extend(dist_list)
            if k in head_index:
                group_dist[0].extend(dist_list)
            elif k in front_index:
                group_dist[1].extend(dist_list)
            elif k in back_index:
                group_dist[2].extend(dist_list)
        elif k in exclusive_index:
            group_dist[4].extend(dist_list)

    # Clamped at 0: dist_acc presumably returns a negative sentinel (e.g. -1)
    # for an empty group such as "Exclusive" at level 3 — confirm.
    group_pck = np.zeros(len(group_dist))
    for g, dist_list in enumerate(group_dist):
        group_pck[g] = max(dist_acc(np.array(dist_list)), 0)

    logger.info("===================For Groups======================")
    groups = ["Head", "Front", "Back", "Shared", "Exclusive"]
    for name, pck in zip(groups, group_pck):
        logger.info("{} : {}".format(name, pck))

    if return_pck:
        return perf_indicator, acc.avg
    return perf_indicator


def validate_mix_kps(config, val_loader, val_dataset, model, criterion, output_dir,
                     writer_dict=None, animalpose=False, vis=False, return_pck=False):
    """Evaluate ``model`` on the merged AP-10K + Animal-Pose keypoint set.

    Performs the same evaluation pass as ``validate`` while collecting
    per-keypoint distances from ``accuracy_kps`` for 21 keypoints.  Then it
    logs a PCK (via ``dist_acc``) for every keypoint and for five groups:
    Head, Front, Back, Share and Exclusive.  Head/Front/Back only aggregate
    keypoints in the shared set.

    Args:
        config: experiment config; reads ``MODEL.NUM_JOINTS``, ``MODEL.NAME``,
            ``TEST.FLIP_TEST``, ``TEST.SHIFT_HEATMAP`` and ``PRINT_FREQ``.
        val_loader: iterable of ``(input, target, target_weight, meta)``.
        val_dataset: dataset providing ``len()``, ``flip_pairs`` and
            ``evaluate(...)``.
        model: network returning heatmaps, a list of per-stage heatmaps, or a
            ``(outputs, extra)`` pair like the models driven by ``train``.
        criterion: loss called as ``criterion(output, target, target_weight)``.
        output_dir: forwarded to ``val_dataset.evaluate``.
        writer_dict: optional dict with ``'writer'`` and
            ``'valid_global_steps'``; the step counter is incremented in place.
        animalpose: when true, ``meta['bbox_id']`` is stored in column 6 of
            the boxes array.
        vis: unused.
        return_pck: when true, also return the mean heatmap accuracy.

    Returns:
        ``perf_indicator``, or ``(perf_indicator, acc_avg)`` if ``return_pck``.
    """
    # Merged keypoint names; presumably in heatmap channel order.
    kps = ["L_eye", "R_eye", "nose", "neck", "tail",
           "L_F_hip", "L_F_knee", "L_F_paw", "R_F_hip", "R_F_knee", "R_F_paw",
           "L_B_hip", "L_B_knee", "L_B_paw", "R_B_hip", "R_B_knee", "R_B_paw",
           "L_ear", "R_ear", "Throat", "Wither"]

    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    def _final_stage(outputs):
        # train()/validate_mt() unpack model(...) as ``(outputs, _)``; accept
        # that form too (a tuple would otherwise fail in criterion/flip_back),
        # then keep the last stage when a list of stages is returned.
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        return outputs[-1] if isinstance(outputs, list) else outputs

    def _predict(images):
        # Final-stage heatmaps, averaged with the flipped pass when
        # config.TEST.FLIP_TEST is set.
        output = _final_stage(model(images))
        if config.TEST.FLIP_TEST:
            output_flipped = _final_stage(model(images.flip(3)))
            output_flipped = flip_back(output_flipped.cpu().numpy(),
                                       val_dataset.flip_pairs)
            output_flipped = torch.from_numpy(output_flipped.copy()).cuda()

            # feature is not aligned, shift flipped heatmap for higher accuracy
            if config.TEST.SHIFT_HEATMAP:
                output_flipped[:, :, :, 1:] = \
                    output_flipped.clone()[:, :, :, 0:-1]

            output = (output + output_flipped) * 0.5
        return output

    # switch to evaluate mode
    model.eval()

    num_samples = len(val_dataset)
    all_preds = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3),
                         dtype=np.float32)
    # Columns: center (0:2), scale (2:4), prod(scale * 200) (4), score (5)
    # and, for animalpose, bbox id (6).
    all_boxes = np.zeros((num_samples, 7 if animalpose else 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    # One list of accumulated distances per keypoint.
    kps_dist = [[] for _ in kps]

    with torch.no_grad():
        end = time.time()
        for i, (images, target, target_weight, meta) in enumerate(val_loader):
            # compute output
            output = _predict(images)

            target = target.cuda(non_blocking=True)
            target_weight = target_weight.cuda(non_blocking=True)

            loss = criterion(output, target, target_weight)

            num_images = images.size(0)
            # measure accuracy and record loss
            losses.update(loss.item(), num_images)
            _, avg_acc, cnt, _ = accuracy(output.cpu().numpy(),
                                          target.cpu().numpy())
            dists_kps = accuracy_kps(output.cpu().numpy(), target.cpu().numpy())
            for j, dist_list in enumerate(kps_dist):
                dist_list.extend(dists_kps[j])
            acc.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            score = meta['score'].numpy()

            preds, maxvals = get_final_preds(
                config, output.clone().cpu().numpy(), c, s)

            rows = slice(idx, idx + num_images)
            all_preds[rows, :, 0:2] = preds[:, :, 0:2]
            all_preds[rows, :, 2:3] = maxvals
            # double check this all_boxes parts
            all_boxes[rows, 0:2] = c[:, 0:2]
            all_boxes[rows, 2:4] = s[:, 0:2]
            all_boxes[rows, 4] = np.prod(s * 200, 1)
            all_boxes[rows, 5] = score
            if animalpose:
                all_boxes[rows, 6] = meta['bbox_id'].numpy()

            image_path.extend(meta['image'])
            idx += num_images

            if i % config.PRINT_FREQ == 0:
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, acc=acc)
                logger.info(msg)

        name_values, perf_indicator = val_dataset.evaluate(
            config, all_preds, output_dir, all_boxes, image_path,
            filenames, imgnums
        )
        # evaluate() may return one name->value mapping or a list of them.
        name_value_list = (name_values if isinstance(name_values, list)
                           else [name_values])

        model_name = config.MODEL.NAME
        for name_value in name_value_list:
            _print_name_value(name_value, model_name)

        if writer_dict:
            writer = writer_dict['writer']
            global_steps = writer_dict['valid_global_steps']
            writer.add_scalar('valid_loss', losses.avg, global_steps)
            writer.add_scalar('valid_acc', acc.avg, global_steps)
            for name_value in name_value_list:
                writer.add_scalars('valid', dict(name_value), global_steps)
            writer_dict['valid_global_steps'] = global_steps + 1
    logger.info(f"Average PCK: {acc.avg}")

    # Compute PCK of single keypoints
    kps_pck = np.zeros(len(kps_dist))
    for k, dist_list in enumerate(kps_dist):
        kps_pck[k] = dist_acc(np.array(dist_list))
    for name, pck in zip(kps, kps_pck):
        logger.info("{} PCK : {}".format(name, pck))

    # Compute PCK of group keypoints.
    # group_dist / group_pck order: head / front / back / share / exclusive
    share_index = [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
    head_index = [0, 1, 2]
    front_index = [5, 6, 7, 8, 9, 10]
    back_index = [4, 11, 12, 13, 14, 15, 16]
    exclusive_index = [3, 17, 18, 19, 20]

    group_dist = [[] for _ in range(5)]
    for k, dist_list in enumerate(kps_dist):
        if k in share_index:
            group_dist[3].extend(dist_list)
            if k in head_index:
                group_dist[0].extend(dist_list)
            elif k in front_index:
                group_dist[1].extend(dist_list)
            elif k in back_index:
                group_dist[2].extend(dist_list)
        elif k in exclusive_index:
            group_dist[4].extend(dist_list)

    # Clamped at 0: dist_acc presumably returns a negative sentinel (e.g. -1)
    # when a group has no valid distances — confirm.
    group_pck = np.zeros(len(group_dist))
    for g, dist_list in enumerate(group_dist):
        group_pck[g] = max(dist_acc(np.array(dist_list)), 0)

    groups = ["Head", "Front", "Back", "Share", "Exclusive"]
    logger.info("===================For Groups======================")
    for name, pck in zip(groups, group_pck):
        logger.info("{} PCK : {}".format(name, pck))

    if return_pck:
        return perf_indicator, acc.avg
    return perf_indicator


# markdown format output
def _print_name_value(name_value, full_arch_name):
    """Log a metrics mapping as a markdown table (header, rule, one data row).

    Args:
        name_value: mapping of metric name -> numeric value; its iteration
            order gives the column order.  Values are printed with 3 decimals.
        full_arch_name: text for the leading "Arch" cell; a name longer than
            15 characters is cut to its first 8 characters plus '...'.
    """
    metric_names = name_value.keys()
    metric_values = name_value.values()
    column_count = len(name_value)

    header_cells = ' '.join('| {}'.format(metric) for metric in metric_names)
    logger.info('| Arch ' + header_cells + ' |')
    logger.info('{}|'.format('|---' * (column_count + 1)))

    if len(full_arch_name) <= 15:
        arch_label = full_arch_name
    else:
        arch_label = full_arch_name[:8] + '...'
    value_cells = ' '.join('| {:.3f}'.format(metric) for metric in metric_values)
    logger.info(''.join(['| ', arch_label, ' ', value_cells, ' |']))


class AverageMeter:
    """Track the most recent value together with a weighted running mean.

    Attributes (all start at 0 and return to 0 on ``reset``):
        val: the value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
        avg: ``sum / count``; stays 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the mean.

        Args:
            val: newly observed value (e.g. a batch-mean loss).
            n: its weight, typically the number of samples it summarizes.
        """
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = 0 if self.count == 0 else self.sum / self.count
