import json
import argparse
import re
import os
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
import pandas as pd


# Every spelling of a box reference recognized in model output (matched
# case-insensitively); each alternative captures the 1-based box number.
_BOX_REFERENCE_RE = re.compile(
    r"`region (\d+)`|`r(\d+)`|region (\d+)|r(\d+)|'r(\d+)'|box (\d+)",
    flags=re.IGNORECASE,
)
# Canonical form that every reference is rewritten to before extraction.
_CANONICAL_BOX_RE = re.compile(r"`R(\d+)`")
# Digits directly followed by "b" in a model name, e.g. the "7" in "llama-7b".
_MODEL_SIZE_RE = re.compile(r'(\d+)b', flags=re.IGNORECASE)


def _build_parser():
    """Create the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser()

    # Required: without it there is nothing to evaluate, and iterating over
    # the default ``None`` would fail with an unhelpful TypeError.
    parser.add_argument('--model-output', type=str, nargs='+', required=True, help='model output files')
    parser.add_argument('--punish-hallucination', action='store_true', help='punish hallucinations')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--micro', action='store_true', help='micro average (set-level)')
    group.add_argument('--macro', action='store_true', help='macro average (image-level)')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--annotations', type=str, nargs='+', help='post-processed annotation data in json')
    group.add_argument('--box-ground-truth', type=str, help='ground truth bounding boxes data')
    group.add_argument('--consolidated-annotations', type=str, help='consolidated annotation data in jsonl')
    return parser


def _box_labels(boxes):
    """Return the last element of every box (presumably its 0/1 label -- confirm against the data)."""
    return [box[-1] for box in boxes]


def _load_ground_truth(args):
    """Load the per-qid ground-truth labels from whichever source was selected on the CLI."""
    y_true = {}

    if args.box_ground_truth is not None:
        # A single JSON object mapping qid -> list of boxes.
        with open(args.box_ground_truth, "r", encoding="utf-8") as bbox_file:
            bboxes = json.load(bbox_file)
        for qid, boxes in bboxes.items():
            y_true[qid] = _box_labels(boxes)
    elif args.annotations is not None:
        for annotation_file in args.annotations:
            with open(annotation_file, "r", encoding="utf-8") as f:
                annotations = json.load(f)
            # Only the entries stored under the first top-level key are used.
            for entry in annotations[list(annotations)[0]]:
                y_true[entry["qid"]] = _box_labels(entry["bounding_box_labels"])
    elif args.consolidated_annotations is not None:
        # JSONL: one annotation record per non-blank line.
        with open(args.consolidated_annotations, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    record = json.loads(line)
                    y_true[record["qid"]] = _box_labels(record["bounding_box_labels"])

    return y_true


def _model_size(model_str):
    """Return the first ``<digits>b`` number in the model name, or ``inf`` when there is none."""
    match = _MODEL_SIZE_RE.search(model_str)
    return int(match.group(1)) if match else float('inf')


def _mentioned_box_indices(output_text):
    """Return the sorted, de-duplicated 0-based indices of all boxes referenced in the text."""
    # Exactly one capture group participates in each match, so ``lastindex``
    # identifies the group that holds the box number.
    standardized = _BOX_REFERENCE_RE.sub(lambda m: f"`R{m.group(m.lastindex)}`", output_text)
    return sorted({int(number) - 1 for number in _CANONICAL_BOX_RE.findall(standardized)})


def _aligned_vectors(labels, mentioned, punish_hallucination):
    """Build equally long truth/prediction vectors for one qid.

    ``labels`` is never modified, so the ground truth stays pristine for the
    next model file.  A mentioned index outside ``[0, len(labels))`` -- which
    includes the index -1 produced by a reference to box 0 -- refers to a box
    that does not exist.  It is ignored, or, when ``punish_hallucination`` is
    set, recorded as an extra (truth=0, pred=1) pair, i.e. a false positive.
    """
    truth = list(labels)
    pred = [0] * len(labels)
    for idx in mentioned:
        if 0 <= idx < len(labels):
            pred[idx] = 1
        elif punish_hallucination:
            truth.append(0)
            pred.append(1)
    return truth, pred


def _read_model_output(file_path, y_true, punish_hallucination):
    """Parse one JSONL model-output file into per-qid truth and prediction vectors.

    Records whose qid has no ground truth are skipped; when a qid occurs more
    than once, the last record wins.

    Raises:
        KeyError: if a ground-truth qid has no record in the file.
    """
    truth_by_qid = {}
    pred_by_qid = {}
    with open(file_path, "r", encoding="utf-8") as fp:
        for line in fp:
            if not line.strip():
                continue
            record = json.loads(line)
            qid = record["qid"]
            if qid not in y_true:
                continue
            mentioned = _mentioned_box_indices(record['output_text'])
            truth_by_qid[qid], pred_by_qid[qid] = _aligned_vectors(
                y_true[qid], mentioned, punish_hallucination
            )

    missing = [qid for qid in y_true if qid not in pred_by_qid]
    if missing:
        raise KeyError(
            f"{file_path} has no output for {len(missing)} ground-truth qid(s), e.g. {missing[0]!r}"
        )
    return truth_by_qid, pred_by_qid


def _precision_recall_f1(truth, pred):
    """Return (precision, recall, F1) for the positive class 1; undefined ratios count as 0."""
    precision, recall, f1, _ = precision_recall_fscore_support(
        np.array(truth), np.array(pred), pos_label=1, average="binary", zero_division=0
    )
    return precision, recall, f1


def _aggregate_scores(qids, truth_by_qid, pred_by_qid, micro):
    """Combine per-qid vectors into one (precision, recall, F1) triple.

    Micro averaging pools every box of every qid into a single evaluation;
    macro averaging scores each qid separately and takes the unweighted mean.
    """
    if micro:
        flat_truth = [label for qid in qids for label in truth_by_qid[qid]]
        flat_pred = [label for qid in qids for label in pred_by_qid[qid]]
        return _precision_recall_f1(flat_truth, flat_pred)

    if not qids:
        # Nothing to average; report zeros instead of dividing by zero.
        return 0.0, 0.0, 0.0
    per_qid = [_precision_recall_f1(truth_by_qid[qid], pred_by_qid[qid]) for qid in qids]
    count = len(per_qid)
    return (
        sum(score[0] for score in per_qid) / count,
        sum(score[1] for score in per_qid) / count,
        sum(score[2] for score in per_qid) / count,
    )


def main():
    """Score the bounding boxes mentioned by each model output file against the ground truth.

    Reads its configuration from the command line: one or more JSONL model
    output files, exactly one ground-truth source, and either ``--micro`` or
    ``--macro`` averaging.  Prints a markdown table with precision, recall and
    F1 (as percentages) per model, sorted by the size parsed from the model
    name.
    """
    args = _build_parser().parse_args()
    y_true = _load_ground_truth(args)
    qids = list(y_true)

    rows = []
    for file_path in args.model_output:
        # The model name is the last "_"-separated part of the file name,
        # without its extension.
        model_str = os.path.basename(file_path).split('_')[-1].rsplit('.', 1)[0]
        truth_by_qid, pred_by_qid = _read_model_output(file_path, y_true, args.punish_hallucination)
        precision, recall, f1 = _aggregate_scores(qids, truth_by_qid, pred_by_qid, args.micro)
        rows.append([model_str, _model_size(model_str), precision * 100, recall * 100, f1 * 100])

    output_df = pd.DataFrame(rows, columns=['Model Name', 'Size', 'Precision', 'Recall', 'F1'])
    output_df = output_df.sort_values(by=['Size'], ascending=[True])
    print(output_df.to_markdown(index=False))

# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()