import argparse
import json
import os
import pickle

import numpy as np


def parse_args():
    parser = argparse.ArgumentParser(
        description="Train the Frequenct Prior For RelDN."
    )
    parser.add_argument(
        "--overlap", action="store_true", help="Only count overlap boxes."
    )
    parser.add_argument(
        "--json-path",
        type=str,
        default="~/.mxnet/datasets/visualgenome",
        help="Only count overlap boxes.",
    )
    args = parser.parse_args()
    return args


args = parse_args()
use_overlap = args.overlap
PATH_TO_DATASETS = os.path.expanduser(args.json_path)
path_to_json = os.path.join(PATH_TO_DATASETS, "rel_annotations_train.json")

# format in y1y2x1x2
def with_overlap(boxA, boxB):
    xA = max(boxA[2], boxB[2])
    xB = min(boxA[3], boxB[3])

    if xB > xA:
        yA = max(boxA[0], boxB[0])
        yB = min(boxA[1], boxB[1])

        if yB > yA:
            return 1

    return 0


def box_ious(boxes):
    n = len(boxes)
    res = np.zeros((n, n))
    for i in range(n - 1):
        for j in range(i + 1, n):
            iou_val = with_overlap(boxes[i], boxes[j])
            res[i, j] = iou_val
            res[j, i] = iou_val
    return res


with open(path_to_json, "r") as f:
    tmp = f.read()
    train_data = json.loads(tmp)

fg_matrix = np.zeros((150, 150, 51), dtype=np.int64)
bg_matrix = np.zeros((150, 150), dtype=np.int64)

for _, item in train_data.items():
    gt_box_to_label = {}
    for rel in item:
        sub_bbox = rel["subject"]["bbox"]
        ob_bbox = rel["object"]["bbox"]
        sub_class = rel["subject"]["category"]
        ob_class = rel["object"]["category"]
        rel_class = rel["predicate"]

        sub_node = tuple(sub_bbox)
        ob_node = tuple(ob_bbox)
        if sub_node not in gt_box_to_label:
            gt_box_to_label[sub_node] = sub_class
        if ob_node not in gt_box_to_label:
            gt_box_to_label[ob_node] = ob_class

        fg_matrix[sub_class, ob_class, rel_class + 1] += 1

    if use_overlap:
        gt_boxes = [*gt_box_to_label]
        gt_classes = np.array([*gt_box_to_label.values()])
        iou_mat = box_ious(gt_boxes)
        cols, rows = np.where(iou_mat)
        if len(cols) and len(rows):
            for col, row in zip(cols, rows):
                bg_matrix[gt_classes[col], gt_classes[row]] += 1
        else:
            all_possib = np.ones_like(iou_mat, dtype=np.bool)
            np.fill_diagonal(all_possib, 0)
            cols, rows = np.where(all_possib)
            for col, row in zip(cols, rows):
                bg_matrix[gt_classes[col], gt_classes[row]] += 1
    else:
        for b1, l1 in gt_box_to_label.items():
            for b2, l2 in gt_box_to_label.items():
                if b1 == b2:
                    continue
                bg_matrix[l1, l2] += 1


eps = 1e-3
bg_matrix += 1
fg_matrix[:, :, 0] = bg_matrix
pred_dist = np.log(fg_matrix / (fg_matrix.sum(2)[:, :, None] + eps) + eps)


if use_overlap:
    with open("freq_prior_overlap.pkl", "wb") as f:
        pickle.dump(pred_dist, f)
else:
    with open("freq_prior.pkl", "wb") as f:
        pickle.dump(pred_dist, f)
