import json
import numpy as np
from sklearn import metrics

def calculate_iou(box1, box2):
    """Return the intersection-over-union of two boxes.

    Each box is read by index: positions 0 and 1 are used as the lower
    x/y corner and positions 2 and 3 as the upper x/y corner.

    Returns:
        ``intersection / union``, or ``0`` when the union area is not
        positive (e.g. both boxes are degenerate).
    """
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # Clamp each extent at zero so disjoint boxes yield no overlap area.
    overlap = max(0, right - left) * max(0, bottom - top)

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    total = _area(box1) + _area(box2) - overlap
    if total > 0:
        return overlap / total
    return 0

def calculate_acc(image_data, iou_threshold=0.5):
    """Build per-image binary labels suitable for an accuracy score.

    For every entry of ``image_data`` (read via its ``'pred_boxes'`` and
    ``'gt_boxes'`` keys):

    * the ground-truth label is 1 when the image has at least one
      ground-truth box, otherwise 0;
    * for an image without ground-truth boxes, the predicted label is 1
      when any box was predicted, otherwise 0;
    * for an image with ground-truth boxes, the predicted label is 1 only
      when every ground-truth box and every predicted box was matched.

    A predicted box is matched when its highest-IoU ground-truth box has
    IoU >= ``iou_threshold`` and has not been claimed by an earlier
    prediction.

    Returns:
        ``(gt_label, pred_label)`` as two NumPy float arrays, one slot per
        image in iteration order.
    """
    num_images = len(image_data)
    pred_label = np.zeros(num_images)
    gt_label = np.zeros(num_images)

    for idx, record in enumerate(image_data.values()):
        preds = record['pred_boxes']
        gts = record['gt_boxes']
        num_gt = len(gts)
        num_pred = len(preds)

        if num_gt == 0:
            # No objects expected: any prediction at all flags the image.
            gt_label[idx] = 0
            pred_label[idx] = 0 if num_pred == 0 else 1
            continue

        gt_label[idx] = 1
        gt_taken = [False] * num_gt
        pred_matched = []
        for box in preds:
            overlaps = [calculate_iou(box, gt) for gt in gts]
            best = np.argmax(overlaps)
            hit = overlaps[best] >= iou_threshold and not gt_taken[best]
            if hit:
                gt_taken[best] = True
            pred_matched.append(hit)

        # Correct only when nothing is missed and nothing is spurious.
        if all(gt_taken) and all(pred_matched):
            pred_label[idx] = 1
        else:
            pred_label[idx] = 0

    return gt_label, pred_label


def calculate_precision_recall(image_data, iou_threshold=0.5):
    """Compute dataset-level precision and recall over all boxes.

    For every entry of ``image_data`` (read via its ``'pred_boxes'`` and
    ``'gt_boxes'`` keys), each predicted box is compared with its
    highest-IoU ground-truth box. It counts as a true positive when that
    IoU is >= ``iou_threshold`` and the ground-truth box has not already
    been claimed; otherwise it is a false positive. Ground-truth boxes
    left unclaimed are false negatives. On images without ground-truth
    boxes every prediction is a false positive.

    Returns:
        ``(precision, recall)``; each is ``0`` when its denominator is
        zero.
    """
    true_pos = 0
    false_pos = 0
    false_neg = 0

    for record in image_data.values():
        preds = record['pred_boxes']
        gts = record['gt_boxes']
        num_gt = len(gts)
        gt_taken = [False] * num_gt

        if num_gt == 0:
            false_pos += len(preds)
        else:
            for box in preds:
                overlaps = [calculate_iou(box, gt) for gt in gts]
                best = np.argmax(overlaps)
                if overlaps[best] >= iou_threshold and not gt_taken[best]:
                    gt_taken[best] = True
                    true_pos += 1
                else:
                    false_pos += 1

        # Every ground-truth box nobody claimed is a miss.
        false_neg += gt_taken.count(False)

    def _ratio(hits, total):
        return hits / total if total > 0 else 0

    precision = _ratio(true_pos, true_pos + false_pos)
    recall = _ratio(true_pos, true_pos + false_neg)
    return precision, recall

def filter_info(data, score_threshold=0.1):
    """Drop low-confidence predictions from every result in ``data``.

    Each value of ``data`` must provide ``'scores'`` and ``'pred_boxes'``
    sequences, which are paired element by element. Pairs whose score is
    below ``score_threshold`` are discarded; the surviving scores and
    boxes are written back to the same keys as new lists. Any other keys
    of a result are left untouched.

    Args:
        data: Mapping of result records; it is modified in place.
        score_threshold: Minimum score (inclusive) a prediction needs in
            order to be kept. Defaults to ``0.1``.

    Returns:
        The same ``data`` object, after filtering.
    """
    for result in data.values():
        # zip() pairs scores with boxes; if the two sequences differ in
        # length, the unpaired tail of the longer one is dropped.
        kept = [
            (score, box)
            for score, box in zip(result['scores'], result['pred_boxes'])
            if score >= score_threshold
        ]
        result['scores'] = [score for score, _ in kept]
        result['pred_boxes'] = [box for _, box in kept]
    return data

# Path of the JSON file holding the per-image results to evaluate. Each
# value is expected to carry 'scores', 'pred_boxes' and 'gt_boxes', which
# are the keys the functions in this file read.
json_file = 'test.json'

# Decode explicitly as UTF-8 rather than relying on the platform's default
# text encoding, which may differ and mis-decode non-ASCII content.
with open(json_file, "r", encoding="utf-8") as f:
    image_data = json.load(f)

# Discard low-score predictions once, before any metric is computed.
image_data = filter_info(image_data)

# One report line per IoU threshold, printed in the order listed here.
for iou_threshold in [0.5, 0.75]:
    dataset_precision, dataset_recall = calculate_precision_recall(image_data, iou_threshold)
    true_label, pred_label = calculate_acc(image_data, iou_threshold)
    acc = metrics.accuracy_score(true_label, pred_label)
    print(f"Precision: {dataset_precision:.2f}   Recall: {dataset_recall:.2f}   Acc: {acc:.2f}")