import json
import random
from PIL import Image
from matplotlib import pyplot as plt
import os
from tqdm import tqdm
from utils import *


# Prefix that image_id_2_image_path() prepends verbatim to an image id. It is
# string-concatenated, not path-joined, so a directory prefix must end with a
# path separator. Empty means ids are resolved relative to the current working
# directory. NOTE(review): looks like a placeholder meant to be set to the
# image directory before use — confirm.
image_file_path: str = ""

def generate_question(sample, image_token=None):
    """Build a question/answer prompt from ``sample["question"]``.

    Args:
        sample: Mapping holding the question text under the ``"question"`` key.
        image_token: Optional placeholder string for the image. When it is
            truthy, the prompt starts with it and uses upper-case
            ``QUESTION:``/``ANSWER:`` labels. Otherwise a text-only
            ``Question:``/``Answer:`` prompt is produced.

    Returns:
        The prompt string, ending with the answer label so a model can
        continue from it.
    """
    question_text = sample["question"]

    if not image_token:
        return "".join(("Question: ", question_text, "\n Answer: "))

    return "".join((image_token, "\n QUESTION: ", question_text, "\n ANSWER: "))

def image_id_2_image_path(image_id, base_path=None,
                          extensions=(".JPEG", ".jpg", ".png")):
    """Resolve an image id to the path of an existing image file.

    Candidate paths are built as ``base_path + image_id + ext`` for each
    extension in ``extensions``, tried in order. The first candidate that
    names an existing regular file is returned.

    Args:
        image_id: Image identifier (file stem without extension).
        base_path: Prefix prepended verbatim to ``image_id``. It is
            concatenated, not path-joined, so a directory prefix must end
            with a path separator. Defaults to the module-level
            ``image_file_path``.
        extensions: File extensions to try, in priority order.

    Returns:
        The first matching file path, or ``None`` if no candidate exists.
    """
    if base_path is None:
        # Read at call time so later reassignments of the module-level
        # setting are honoured.
        base_path = image_file_path

    for ext in extensions:
        candidate = base_path + image_id + ext
        # isfile rather than exists: a directory that happens to carry an
        # image-like name must not be reported as an image.
        if os.path.isfile(candidate):
            return candidate

    # Deliberately best-effort: a miss is reported as None instead of raising.
    return None



if __name__ == '__main__':

    # Example record from answer_BLIPOPT_entity.json:
    # {'predict': 'capybara', 'gold': 'Capybara', 'prompt': 'Question: what animal is presented in the image?\n Answer: '}

    # The file is a JSON object keyed by image id. Decode it explicitly as
    # UTF-8 instead of relying on the platform's default encoding.
    with open("answer_BLIPOPT_entity.json", encoding="utf-8") as f:
        answer_entity = json.load(f)

    # Print each record whose prediction matches the gold answer, ignoring
    # case and surrounding whitespace on BOTH sides. Print it together with
    # the resolved image path, which is None when no image file is found.
    for img, record in answer_entity.items():
        predict = record["predict"]
        gold = record["gold"]

        if predict.strip().lower() == gold.strip().lower():
            print(record, image_id_2_image_path(img))


    # with open("/group/30105/zhouyang/MMFact/triples.json") as f :
    #     triples = json.loads(f.read())

    
    # answer_entities = []
    # for t in triples:
    #     answer_entities.append(t["entity_text"])
    
    # print(len(answer_entities))
    # print(len(set(answer_entities)))
        

