import json
# Input 1 (org_file): DETR detections, e.g. /data/linxi/workspace/detr/plot_results/scores_boxes.json
# Input 2 (std_file): reference file in the target format, e.g. /data/linxi/workspace/POPE/llava_qa/obj/I4_pred_obj_detr.json
#
# Match records that refer to the same image, convert input 1 into the same
# format as input 2, and write the converted records to a new file.
def match_results(org_file, save_pred_file, std_file):
    """Convert per-image DETR detections into the POPE object-list format.

    ``org_file`` and ``std_file`` are read as JSON Lines (one JSON value per
    line; blank lines are skipped). Each ``org_file`` line is expected to be a
    single-key object mapping an image name to its detections, e.g.
    ``{"COCO_val2014_000000047940.jpg": {"classes": [...], "scores": [...]}}``.

    ``std_file`` is only used as a sanity check. Nothing is written when:
    - the two files hold a different number of records,
    - both are empty, or
    - either holds exactly one record.

    The result is a JSON list written to ``save_pred_file``. Each record looks
    like ``{"id": i, "image": name, "objects": [...]}``. ``id`` is the 0-based
    record index. ``objects`` is the detected class list with duplicates
    removed, keeping first-seen order.

    Args:
        org_file: JSON Lines file with DETR detections.
        save_pred_file: destination path for the converted JSON list.
        std_file: JSON Lines reference file, used only for the length checks.
    """
    # Read both files with context managers so the handles are closed promptly.
    with open(org_file, "r") as f:
        data = [json.loads(line) for line in f if line.strip()]
    with open(std_file, "r") as f:
        std_data = [json.loads(line) for line in f if line.strip()]
    # data[0] looks like {"COCO_val2014_000000047940.jpg": {"classes": ["person", "book", ...], "scores": [0.9914, ...]}}
    # std_data[0] looks like {"id": 1, "image": "COCO_val2014_000000144305.jpg", "objects": ["tv", "keyboard", "laptop"]}
    if len(data) != len(std_data):
        print("data and std_data have different length")
        return
    if not data:
        # Nothing to convert; the previews and the average below would fail.
        print("data and std_data are empty")
        return
    if len(data) == 1 or len(std_data) == 1:
        # NOTE(review): presumably guards against a whole (non-JSON-Lines)
        # JSON document being read as a single record -- confirm.
        print("data and std_data have length 1")
        return
    print(len(data), data[0])
    print(len(std_data), std_data[0])

    results = []
    obj_len = 0
    for i, record in enumerate(data):
        # Each record is a one-key dict: {image_name: detections}.
        image = next(iter(record))
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # output is identical across runs (set order depends on the hash seed).
        objects = list(dict.fromkeys(record[image]["classes"]))
        obj_len += len(objects)
        results.append({"id": i, "image": image, "objects": objects})

    with open(save_pred_file, "w") as f:
        json.dump(results, f)
    print("avg obj len: ", obj_len / len(data))
    print(f"save to {save_pred_file}")
    
def match_results2(org_file, save_pred_file, std_file):
    """Write the ground-truth objects of every image listed in a reference file.

    ``org_file`` is a JSON list of records with at least ``"image"`` and
    ``"gt_objects"`` keys. ``std_file`` is a JSON list of records with
    ``"id"`` and ``"image"`` keys.

    For each reference record, every ``org_file`` record with the same image
    yields one output record ``{"id": ref_id, "image": image, "objects": [...]}``.
    ``objects`` is ``gt_objects`` with duplicates removed, keeping first-seen
    order. Output follows ``std_file`` order, and matches within an image follow
    ``org_file`` order. Reference images with no match are skipped.

    Args:
        org_file: JSON file with ground-truth records.
        save_pred_file: destination path for the aligned JSON list.
        std_file: JSON file whose records define the output order and ids.
    """
    with open(org_file, "r") as f:
        data = json.load(f)
    with open(std_file, "r") as f:
        std_data = json.load(f)
    # data[0] looks like {"image": "COCO_val2014_....jpg", "gt_objects": [...], ...}
    # std_data[0] looks like {"id": 1, "image": "COCO_val2014_000000144305.jpg", "objects": ["tv", "keyboard", "laptop"]}
    # Guard the previews so empty inputs do not raise IndexError.
    print(len(data), data[0] if data else None)
    print(len(std_data), std_data[0] if std_data else None)

    # Index org records by image once instead of rescanning all of `data`
    # for every reference record (O(n + m) rather than O(n * m)).
    by_image = {}
    for d in data:
        by_image.setdefault(d["image"], []).append(d)

    results = []
    for item in std_data:
        image = item["image"]
        for d in by_image.get(image, []):
            results.append({
                "id": item["id"],
                "image": image,
                # Deterministic de-duplication (set order depends on the hash seed).
                "objects": list(dict.fromkeys(d["gt_objects"])),
            })

    with open(save_pred_file, "w") as f:
        json.dump(results, f)
    print(f"save to {save_pred_file}")
        


def test_pred_acc(gt_file, pred_file):
    """Compare predicted object lists with ground truth and print the metrics.

    Both files are JSON lists of records carrying an ``"objects"`` list.
    Record i of ``gt_file`` is paired with record i of ``pred_file``.

    For each pair, every distinct label in the union of the two lists counts
    once, as one of:
    - a true positive (in both lists),
    - a false positive (predicted only), or
    - a false negative (ground truth only).

    There are no true negatives, so ``accuracy`` equals TP / (TP + FP + FN).
    Any ratio whose denominator is zero is reported as 0.0.

    Args:
        gt_file: JSON file with ground-truth records.
        pred_file: JSON file with predicted records.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall``, ``num_detected``
        (TP + FP) and ``num_gt`` (TP + FN). The same values are printed.
    """
    with open(gt_file, "r") as f:
        gt_data = json.load(f)
    with open(pred_file, "r") as f:
        pred_data = json.load(f)

    # zip() stops at the shorter list, so make dropped records visible.
    if len(gt_data) != len(pred_data):
        print(f"warning: {len(gt_data)} gt records vs {len(pred_data)} pred records; "
              f"only the first {min(len(gt_data), len(pred_data))} pairs are scored")

    tp = fp = fn = 0
    misaligned = 0
    for gt, pred in zip(gt_data, pred_data):
        # Pairing is positional; flag pairs that describe different images.
        if gt.get("image") != pred.get("image"):
            misaligned += 1
        gt_labels = set(gt["objects"])
        pred_labels = set(pred["objects"])
        tp += len(gt_labels & pred_labels)
        fp += len(pred_labels - gt_labels)
        fn += len(gt_labels - pred_labels)
    if misaligned:
        print(f"warning: {misaligned} gt/pred pairs refer to different images")

    total = tp + fp + fn
    accuracy = tp / total if total else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    num_detected = tp + fp
    num_gt = tp + fn

    print(f"accuracy: {accuracy}, precision: {precision}, recall: {recall}")
    print(f"num_detected: {num_detected}, num_gt: {num_gt}")
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "num_detected": num_detected,
        "num_gt": num_gt,
    }


# The name matches pytest's ``test_*`` discovery pattern, but this is an
# evaluation helper that needs file arguments; keep pytest from collecting it.
test_pred_acc.__test__ = False
    
    
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert DETR detection dumps into POPE-style object lists."
    )
    parser.add_argument(
        "--type", type=str, default="I4",
        help="tag used in the input/output file names (default: I4)",
    )
    # One or more score thresholds. Each value selects the
    # scores_boxes_<type>_th<th>.json dump to convert.
    parser.add_argument(
        "--th", type=float, nargs="+", default=[0.95, 0.5],
        help="score threshold(s) to process (default: 0.95 0.5)",
    )
    args = parser.parse_args()

    # Only consumed by the optional follow-up stages commented out below.
    gt_file = "/data/linxi/workspace/POPE/data/lables_val2014_2.json"
    # gt_file = "/data/linxi/workspace/POPE/llava_qa/obj/I4_gt_obj.json"
    for th in args.th:
        org_file = f"/data/linxi/workspace/detr/plot_results/scores_boxes_{args.type}_th{th}.json"
        save_pred_file = f"/data/linxi/workspace/POPE/llava_qa/obj/{args.type}_pred_obj_detr_th{th}.json"
        save_gt_file = f"/data/linxi/workspace/POPE/llava_qa/obj/{args.type}_gt_obj_detr_th{th}.json"
        # std_file="/data/linxi/workspace/POPE/llava_qa/obj/I4_pred_obj_detr_0.5.json"
        print(f"org_file: {org_file}")
        print(f"th: {th}")
        # The detections file doubles as its own reference, so the
        # length-equality check inside match_results trivially passes.
        match_results(org_file, save_pred_file, std_file=org_file)
        # Optional follow-up stages (disabled):
        # match_results2(gt_file, save_gt_file, std_file=save_pred_file)
        # test_pred_acc(save_gt_file, save_pred_file)
