import json
def get_category_dict(path='./data/annotations/category_dict.json'):
    """Load the category-id -> name mapping from a JSON file.

    Args:
        path: location of the JSON mapping. Defaults to the project's
            ``./data/annotations/category_dict.json`` (relative to the
            current working directory).

    Returns:
        The decoded dict. JSON object keys are always strings, so callers
        look ids up as ``category_dict[str(label_id)]``.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)

def get_keep_labels_scores(keep_dict):
    """Return the category names and scores of detections flagged as kept.

    Args:
        keep_dict: mapping with nested lists ``"keep_ls"``, ``"labels"`` and
            ``"scores"``, all indexed ``[i][j]``. A detection counts as kept
            when ``keep_ls[i][j] == True``.

    Returns:
        ``(keep_labels, keep_scores)``. ``keep_labels`` is a list of lists
        (one inner list of category names per row ``i``), whereas
        ``keep_scores`` is ONE flat list of the kept scores in row-major
        order. NOTE(review): the two outputs have different nesting —
        confirm callers expect this.
    """
    id_to_name = get_category_dict()
    keep_flags = keep_dict["keep_ls"]

    keep_labels = []
    for row_idx, flag_row in enumerate(keep_flags):
        kept_names = []
        keep_labels.append(kept_names)
        for col_idx, flag in enumerate(flag_row):
            # Kept as ``== True`` to preserve the original matching semantics.
            if flag == True:  # noqa: E712
                label_id = keep_dict["labels"][row_idx][col_idx]
                kept_names.append(id_to_name[str(label_id)])

    keep_scores = []
    for row_idx, score_row in enumerate(keep_dict["scores"]):
        for col_idx, score in enumerate(score_row):
            if keep_flags[row_idx][col_idx] == True:  # noqa: E712
                keep_scores.append(score)

    return keep_labels, keep_scores

import json
import os
import shutil
def get_interest_obj_ls(path="../POPE/output/coco/minival2014_pope_obj_dict.json"):
    """Pop and return the next entry from a persistent, cycling JSON queue.

    The queue is a working copy of the JSON dict at *path*, stored beside it
    with a ``_copy`` suffix (``foo.json`` -> ``foo_copy.json``). Each call
    removes the first key of the copy, rewrites the copy without it, and
    returns that key's value. When the copy is exhausted it is refreshed
    from *path*, so entries are served round-robin across calls. The state
    lives on disk, so it also persists across processes. The source file
    itself is never modified.

    Args:
        path: JSON file holding a dict of entries to serve.

    Returns:
        The value stored under the first remaining key.

    Raises:
        ValueError: if the source dict at *path* is itself empty.
    """
    root, ext = os.path.splitext(path)
    # splitext (rather than str.replace) only rewrites the extension, so a
    # path without ".json" can never map the working copy onto the source
    # file and have entries popped out of the original.
    file_copy = f"{root}_copy{ext}"
    if not os.path.exists(file_copy):
        shutil.copy(path, file_copy)
    with open(file_copy, 'r', encoding='utf-8') as f:
        interest_obj_ls = json.load(f)
    if not interest_obj_ls:
        # Queue exhausted: start a new pass from the pristine source file.
        shutil.copy(path, file_copy)
        with open(file_copy, 'r', encoding='utf-8') as f:
            interest_obj_ls = json.load(f)
        if not interest_obj_ls:
            raise ValueError(f"no entries to serve: {path!r} is empty")
    first_key = next(iter(interest_obj_ls))
    interest_obj = interest_obj_ls.pop(first_key)
    with open(file_copy, 'w', encoding='utf-8') as f:
        json.dump(interest_obj_ls, f)
    return interest_obj

def found_scores_given_labels(keep_dict, obj_ls):
    """Collect detection scores for the requested object names.

    Args:
        keep_dict: mapping with parallel nested lists ``"labels"`` and
            ``"scores"``; ``labels[i][j]`` is a category id (looked up in the
            category dict by its string form) and ``scores[i][j]`` is the
            matching score.
        obj_ls: a single object name, or a list of names, to look for.

    Returns:
        ``(found_scores, found_scores_given_obj)``. ``found_scores`` maps
        each requested name to the list of its scores rounded to 4 decimals,
        or ``[0]`` if it was never detected. ``found_scores_given_obj`` maps
        each name to the maximum of that list.
    """
    wanted = [obj_ls] if isinstance(obj_ls, str) else obj_ls
    id_to_name = get_category_dict()

    found_scores = {}
    for row_idx, label_row in enumerate(keep_dict["labels"]):
        for col_idx, label_id in enumerate(label_row):
            name = id_to_name[str(label_id)]
            if name not in wanted:
                continue
            score = keep_dict["scores"][row_idx][col_idx]
            found_scores.setdefault(name, []).append(round(score, 4))

    # Requested names that were never detected get a zero placeholder.
    for name in wanted:
        if name not in found_scores:
            found_scores[name] = [0]

    found_scores_given_obj = {}
    for name, scores in found_scores.items():
        found_scores_given_obj[name] = max(scores)

    return found_scores, found_scores_given_obj

# Box colour palette (RGB in [0, 1]); cycled when there are more boxes than colours.
_DEFAULT_COLORS = [
    [0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933],
]


def plot_results(pil_img, prob, boxes,
                 save_path="/data/linxi/workspace/detr/detr_demo.png",
                 colors=None):
    """Draw labelled bounding boxes over an image, save it, and show it.

    Args:
        pil_img: image passed to ``plt.imshow``.
        prob: per-box class-probability rows; each row must support
            ``argmax()`` and indexing (presumably a tensor/array — confirm).
        boxes: per-box ``(xmin, ymin, xmax, ymax)`` coordinates exposing
            ``.tolist()``.
        save_path: where the rendered figure is written.
        colors: optional list of matplotlib colours to cycle through;
            defaults to ``_DEFAULT_COLORS``.
    """
    import matplotlib.pyplot as plt

    category_dict = get_category_dict()
    palette = _DEFAULT_COLORS if colors is None else colors

    fig = plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    for idx, (p, (xmin, ymin, xmax, ymax)) in enumerate(zip(prob, boxes.tolist())):
        c = palette[idx % len(palette)]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = int(p.argmax())
        # category_dict is decoded JSON, so its keys are strings; fall back
        # to the raw index for ids missing from the mapping.
        name = category_dict.get(str(cl), str(cl))
        text = f'{name}: {float(p[cl]):0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    # Save BEFORE show(): with interactive backends the figure may be torn
    # down when the window closes, which would save a blank image.
    plt.savefig(save_path)
    print(f"save to {save_path}")
    plt.show()
    plt.close(fig)


if __name__ == "__main__":
    with open('./llava/backbone/results/keep_COCO_val2014_000000144305.json') as f:
        keep_dict = json.load(f)

    keep_labels, keep_scores = get_keep_labels_scores(keep_dict)
    print(keep_labels, keep_scores)

    # NOTE(review): "people" and "table" differ from the COCO category names
    # ("person", "dining table"); if category_dict.json uses COCO names these
    # will always report a score of 0 — confirm against the dict.
    query_objs = ["keyboard", "mouse", "people", "table", "tv", "chair",
                  "laptop", "cell phone", "book", "clock"]
    found_scores, found_scores_given_obj = found_scores_given_labels(keep_dict, query_objs)
    print(found_scores)
    print(found_scores_given_obj)

    # Evaluation notes:
    # - Evaluation needs keep_dict, which holds "keep_ls", "scores" and "labels".
    # - For detailed evaluation, inspect keep_labels, keep_scores, the output of
    #   found_scores_given_labels, and the person ("people") scores in particular.
    # - The key thing to examine is the score distribution.