import numpy as np

from .common import voc_ap, viou


def compute_detection_scores_per_class(gt_actions, pred_actions, viou_threshold):
    """
    Greedily match the predictions of one class to its ground truth and build
    the cumulative precision/recall curve.

    Predictions are visited in descending ``score`` order. Each one is matched
    to the not-yet-detected ground-truth instance from the same video (equal
    ``id``) with the highest overlap as computed by ``viou``, provided that
    overlap is at least ``viou_threshold``. A matched ground truth is marked
    detected and cannot be matched again. On equal overlap the earlier
    ground-truth entry wins.

    Args:
        gt_actions: list of dicts with keys ``id``, ``trajectory`` and
            ``duration``.
        pred_actions: list of dicts with keys ``id``, ``score``, ``trajectory``
            and ``duration``. It is not modified; a sorted copy is used.
        viou_threshold: minimum overlap for a prediction to be a true positive.

    Returns:
        (prec, rec, hit_scores): cumulative precision and recall after each
        prediction, in descending score order. ``hit_scores`` holds the score
        of every true-positive prediction and ``-inf`` for false positives.
    """
    pred_actions = sorted(pred_actions, key=lambda x: x['score'], reverse=True)
    gt_detected = np.zeros((len(gt_actions),), dtype=bool)

    # Index ground truth by video id once. Each prediction is then compared
    # only against instances from its own video instead of scanning the whole
    # list. Index order is preserved, so tie-breaking matches a linear scan.
    gt_indices_by_id = {}
    for gt_idx, gt_action in enumerate(gt_actions):
        gt_indices_by_id.setdefault(gt_action['id'], []).append(gt_idx)

    hit_scores = np.full((len(pred_actions),), -np.inf)
    for pred_idx, pred_action in enumerate(pred_actions):
        ov_max = -float('Inf')
        k_max = -1
        for gt_idx in gt_indices_by_id.get(pred_action['id'], ()):
            if gt_detected[gt_idx]:
                continue
            gt_action = gt_actions[gt_idx]
            ov = viou(pred_action['trajectory'], pred_action['duration'],
                      gt_action['trajectory'], gt_action['duration'])
            if ov >= viou_threshold and ov > ov_max:
                ov_max = ov
                k_max = gt_idx
        if k_max >= 0:
            hit_scores[pred_idx] = pred_action['score']
            gt_detected[k_max] = True

    tp = np.isfinite(hit_scores)
    fp = ~tp
    cum_tp = np.cumsum(tp).astype(np.float32)
    cum_fp = np.cumsum(fp).astype(np.float32)
    # eps guards against division by zero when there is no ground truth.
    rec = cum_tp / np.maximum(len(gt_actions), np.finfo(np.float32).eps)
    prec = cum_tp / np.maximum(cum_tp + cum_fp, np.finfo(np.float32).eps)
    return prec, rec, hit_scores


def _group_groundtruth_by_category(groundtruth):
    """Map category -> list of {id, duration, trajectory} ground-truth actions."""
    grouped = dict()
    for vid, tracks in groundtruth.items():
        for traj in tracks:
            grouped.setdefault(traj['category'], []).append({
                "id": vid,
                "duration": traj['duration'],
                "trajectory": traj['trajectory']
            })
    return grouped


def _group_predictions_by_category(prediction):
    """Map category -> list of {id, score, duration, trajectory} predicted actions."""
    grouped = dict()
    for vid, tracks in prediction.items():
        for traj in tracks:
            grouped.setdefault(traj['category'], []).append({
                "id": vid,
                "score": traj['score'],
                "duration": traj['duration'],
                "trajectory": traj['trajectory']
            })
    return grouped


def evaluate(groundtruth, prediction, viou_threshold=0.5):
    """
    Evaluate action detection and print a per-class AP table.

    Args:
        groundtruth: dict mapping video id -> list of dicts with keys
            ``category``, ``duration`` and ``trajectory``.
        prediction: dict mapping video id -> list of dicts with keys
            ``category``, ``score``, ``duration`` and ``trajectory``.
        viou_threshold: minimum overlap for a prediction to count as a hit.

    Returns:
        (mean_ap, ap_class): ``ap_class`` maps every ground-truth category, in
        sorted order, to its AP. Categories with no predictions get 0.
        Predicted categories that never appear in the ground truth are
        ignored. ``mean_ap`` is the mean over ``ap_class``, or NaN when the
        ground truth contains no actions.
    """
    # Group both sides by category once, rather than rescanning the whole
    # ground truth for every class.
    gt_by_class = _group_groundtruth_by_category(groundtruth)
    pred_by_class = _group_predictions_by_category(prediction)

    ap_class = dict()
    print('Computing average precision AP over {} classes...'.format(len(gt_by_class)))

    # Sorted iteration gives the returned dict a deterministic order.
    for each_action in sorted(gt_by_class):
        pred_actions = pred_by_class.get(each_action)
        if not pred_actions:
            ap_class[each_action] = 0.
            continue
        det_prec, det_rec, _ = compute_detection_scores_per_class(
                gt_by_class[each_action], pred_actions, viou_threshold)
        ap_class[each_action] = voc_ap(det_rec, det_prec)

    # compute mean ap and print
    print('=' * 30)
    for i, (category, ap) in enumerate(ap_class.items()):
        print('{:>2}{:>20}\t{:.4f}'.format(i+1, category, ap))
    # np.mean([]) would warn before returning NaN; return NaN directly instead.
    mean_ap = np.mean(list(ap_class.values())) if ap_class else np.nan
    print('=' * 30)
    print('{:>22}\t{:.4f}'.format('mean AP', mean_ap))

    return mean_ap, ap_class


if __name__ == '__main__':
    # You can directly run this script from the parent directory, e.g.,
    # python -m evaluation.action_detection val_action_groundtruth.json val_action_prediction.json
    import json
    from argparse import ArgumentParser

    parser = ArgumentParser(description='Action detection evaluation.')
    parser.add_argument('groundtruth', type=str, help='A ground truth json file generated by yourself')
    parser.add_argument('prediction', type=str, help='A prediction file')
    args = parser.parse_args()

    # JSON is UTF-8; don't depend on the platform's locale-default encoding.
    print('Loading ground truth from {}'.format(args.groundtruth))
    with open(args.groundtruth, 'r', encoding='utf-8') as fp:
        gt = json.load(fp)
    print('Number of videos in ground truth: {}'.format(len(gt)))

    print('Loading prediction from {}'.format(args.prediction))
    with open(args.prediction, 'r', encoding='utf-8') as fp:
        pred = json.load(fp)
    # The prediction file is expected to wrap per-video results in 'results'.
    if 'results' not in pred:
        parser.error("prediction file {} has no 'results' field".format(args.prediction))
    print('Number of videos in prediction: {}'.format(len(pred['results'])))

    mean_ap, ap_class = evaluate(gt, pred['results'])
