import json
import numpy as np
import matplotlib.pyplot as plt
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import seaborn as sns
import os


def _load_ground_truth(gt_path):
    """Load a COCO-format ground truth JSON and index it in memory.

    Building the index directly from the parsed dict avoids writing a patched
    copy of the annotations to a temporary file in the working directory.

    Args:
        gt_path (str): Path to a COCO-format ground truth JSON.

    Returns:
        COCO: Indexed ground truth.
    """
    with open(gt_path, 'r') as f:
        gt_data = json.load(f)

    # The original script treated these keys as required; presumably some
    # pycocotools builds read them when loading results, so keep filling them.
    gt_data.setdefault('info', {'description': 'ForestPersons Dataset'})
    gt_data.setdefault('licenses', [{'url': '', 'id': 1, 'name': 'none'}])

    coco_gt = COCO()
    coco_gt.dataset = gt_data
    coco_gt.createIndex()
    return coco_gt


def _evaluate_pr(coco_gt, predictions, iou_thresh):
    """Run bbox evaluation at one IoU threshold and extract the PR curve.

    Args:
        coco_gt (COCO): Indexed ground truth.
        predictions (list[dict]): Non-empty list of COCO-format detections.
        iou_thresh (float): The single IoU threshold to evaluate.

    Returns:
        tuple: (recall thresholds, precision at each threshold, achieved recall).
    """
    coco_dt = coco_gt.loadRes(predictions)

    coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')
    coco_eval.params.iouThrs = np.array([iou_thresh])
    coco_eval.evaluate()
    coco_eval.accumulate()

    # precision is indexed [iou, recall, category, area, maxDets] and recall
    # is indexed [iou, category, area, maxDets]. Only one IoU threshold is set
    # above; category 0 is the first category, area 0 is "all" and maxDets
    # index 2 is 100 with the default COCOeval parameters.
    precision = coco_eval.eval['precision'][0, :, 0, 0, 2]
    achieved_recall = coco_eval.eval['recall'][0, 0, 0, 2]
    return coco_eval.params.recThrs, precision, achieved_recall


def plot_pr_curves(gt_json_paths, pred_json_paths, labels=None, iou_thresh=0.5,
                   output_path=None):
    """
    Plot PR curves for multiple prediction files with corresponding ground truth files.

    The figure is saved to disk; nothing is returned. Progress and summary
    statistics are printed for every model.

    Args:
        gt_json_paths (str | list[str]): Path(s) to COCO-format ground truth JSONs.
            A single string is reused for every prediction file.
        pred_json_paths (str | list[str]): Prediction JSON file path(s).
        labels (list[str], optional): Legend label for each prediction file.
            Defaults to "Model 1", "Model 2", ...
        iou_thresh (float): IoU threshold to evaluate.
        output_path (str, optional): Where to save the figure. Defaults to
            'pr_curve_iou_{iou_thresh}.png' in the working directory.

    Raises:
        ValueError: If the number of ground truth files or labels does not
            match the number of prediction files.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(pred_json_paths, str):
        pred_json_paths = [pred_json_paths]

    # Ensure we have matching ground truth for each prediction
    if isinstance(gt_json_paths, str):
        gt_json_paths = [gt_json_paths] * len(pred_json_paths)
    elif len(gt_json_paths) != len(pred_json_paths):
        raise ValueError("Number of ground truth files must match number of prediction files")

    if labels is None:
        labels = [f"Model {i+1}" for i in range(len(pred_json_paths))]
    elif len(labels) != len(pred_json_paths):
        # zip() below would silently drop the unmatched curves.
        raise ValueError("Number of labels must match number of prediction files")

    if output_path is None:
        output_path = f'pr_curve_iou_{iou_thresh}.png'

    plt.figure(figsize=(8, 6))
    sns.set_style("whitegrid")

    # Models usually share one annotation file; parse and index it only once.
    gt_cache = {}

    # Process each pair of ground truth and prediction files
    for gt_path, pred_path, label in zip(gt_json_paths, pred_json_paths, labels):
        print(f"Processing: {label} with GT: {gt_path}")

        if gt_path not in gt_cache:
            gt_cache[gt_path] = _load_ground_truth(gt_path)
        coco_gt = gt_cache[gt_path]

        # Print ground truth stats
        print(f"\nGround Truth Stats for {label}:")
        print(f"Number of images: {len(coco_gt.imgs)}")
        print(f"Number of annotations: {len(coco_gt.anns)}")
        print(f"Categories: {coco_gt.cats}")

        # Load and check predictions
        with open(pred_path, 'r') as f:
            pred_data = json.load(f)
        print(f"\nPrediction Stats for {label}:")
        print(f"Number of predictions: {len(pred_data)}")
        if not pred_data:
            # loadRes inspects the first result, so an empty list cannot be
            # evaluated; skip this model instead of aborting the whole plot.
            print(f"No predictions found for {label}; skipping its curve.")
            continue

        print(f"Sample prediction: {pred_data[0]}")
        # Check confidence scores
        scores = [pred.get('score', 0) for pred in pred_data]
        print(f"Score range: {min(scores):.3f} to {max(scores):.3f}")
        print(f"Number of predictions with score > 0.1: {sum(1 for s in scores if s > 0.1)}")

        # Hand over the already-parsed list so the file is not read twice.
        recall, pr_array, achieved_recall = _evaluate_pr(coco_gt, pred_data, iou_thresh)

        # Print evaluation stats
        print("\nEvaluation Stats:")
        # Precision stays at 0 for recall thresholds the detector never reaches.
        print(f"Recall thresholds reached: {int(np.sum(pr_array > 0))}/{len(recall)}")
        print(f"Max precision: {np.max(pr_array):.3f}")
        print(f"Max recall: {achieved_recall:.3f}")

        plt.plot(recall, pr_array, label=label, linewidth=2)

    plt.xlabel('Recall', fontsize=25)
    plt.ylabel('Precision', fontsize=25)

    plt.xticks(fontsize=25)
    plt.yticks(fontsize=25)

    plt.title('Precision-Recall Curves', fontsize=30, pad=20)
    plt.grid(True)
    plt.legend(fontsize=20, handlelength=2)
    plt.tight_layout()

    # Draw the borderline of figure
    ax = plt.gca()
    for spine in ax.spines.values():
        spine.set_edgecolor('black')
        spine.set_linewidth(2)
    plt.savefig(output_path)


# Example usage:
# Using same ground truth for all predictions:
# gt_path = '/mnt/home2/annonymous/ForestPersons_v3/annotations/test.json'
# gt_paths = [gt_path] * 4  # Or specify individual GT files for each prediction
# 
# prediction_paths = [
#     '/mnt/home/annonymous/neurips2025/mmdet/model1_predictions.json',
#     '/mnt/home/annonymous/neurips2025/mmdet/model2_predictions.json',
#     '/mnt/home/annonymous/neurips2025/mmdet/model3_predictions.json',
#     "/mnt/home2/annonymous/model_checkpoints/benchmark/dino_r50_4scale_12ep_all/coco_instances_results.json"
# ]
# labels = [
#     'Yolov3', 
#     'Yolox', 
#     'Faster R-CNN', 
#     'RetinaNet', 
#     'SSD',
#     "DETR",
#     "DINO"
# ]  # optional; trim this list so there is exactly one label per prediction file, in the same order

# Every benchmarked model is scored against the same annotation file.
GT_PATH = '/mnt/home2/annonymous/ForestPersons_v3/annotations/test.json'
# Disabled alternative ground-truth file:
# '/mnt/home/annonymous/neurips2025/yolov11_forestpersons/data/test.json'

# (legend label, prediction JSON) pairs. Keeping each label next to its file
# prevents the parallel lists below from drifting out of alignment.
_BENCHMARK_RUNS = [
    ("DINO", "/mnt/home2/annonymous/model_checkpoints/benchmark/dino_r50_4scale_12ep_all/coco_instances_results.json"),
    # Disabled alternative prediction file:
    # "/mnt/home/annonymous/neurips2025/drone_detectron2/DroneDetectron2/outputs_FPN_CROP_forestpersons_fix/inference/inference/coco_instances_results.json",
    ("CZ Det", "converted_coco_instances_results_all.json"),
    ('YOLOv3', '/mnt/home/annonymous/neurips2025/mmdet/250506_missing_person_yolo3.bbox.json'),
    ('YOLOX', "/mnt/home/annonymous/neurips2025/mmdet/250506_missing_person_yolox.bbox.json"),
    ('YOLOv11n', "/mnt/home/annonymous/neurips2025/yolov11_forestpersons/converted_predictions_test_val25.json"),
    ('Faster R-CNN', '/mnt/home/annonymous/neurips2025/mmdet/250506_missing_person_faster_rcnn.bbox.json'),
    ('RetinaNet', '/mnt/home/annonymous/neurips2025/mmdet/250506_missing_person_retinanet.bbox.json'),
    ('Deformable R-CNN', "/mnt/home/annonymous/neurips2025/mmdet/missing_person_iclr2026/original_results.json.bbox.json"),
    ('SSD', "/mnt/home/annonymous/neurips2025/mmdet/250503_missing_person_ssd_forestperson_v3.bbox.json"),
    ("DETR", "/mnt/home/annonymous/neurips2025/mmdet/250508_missing_person_detr_forestperson_v3.bbox.json"),
]

# The original module-level names are kept, derived from the pairs above.
ground_truth_paths = [GT_PATH] * len(_BENCHMARK_RUNS)
prediction_paths = [pred_path for _, pred_path in _BENCHMARK_RUNS]
labels = [run_label for run_label, _ in _BENCHMARK_RUNS]  # optional

# Only run the benchmark when executed as a script, so importing this module
# to reuse plot_pr_curves does not trigger an evaluation.
if __name__ == "__main__":
    plot_pr_curves(ground_truth_paths, prediction_paths, labels=labels, iou_thresh=0.5)



# Example usage with a single model:
# gt_file = 'path/to/ground_truth.json'
# pred_file = 'path/to/predictions.json'
# plot_pr_curves(gt_file, [pred_file], iou_thresh=0.5)
