from openai import OpenAI
import openai
from tqdm import tqdm
import pdb
import time
import os
import shutil
import copy
import tiktoken
import base64
import requests
import json

# Input files (hard-coded, as this is a one-off analysis script).
TRIPLETS_PATH = '/data/instruct_blip_7b_triplets_new.json'
RESPONSES_PATH = '/data/instruct_blip_response.json'

# Key under which each question ("instance") stores its list of judgements.
TRIPLET_JUDGEMENTS_KEY = 'instruct_blip_7b_triplets_judgements'


def _mean(values):
    """Return the arithmetic mean of the list *values*, or NaN if it is empty.

    NaN (instead of a ``ZeroDivisionError``) lets the report keep printing the
    remaining metrics when one of them has no usable data.
    """
    return sum(values) / len(values) if values else float("nan")


def load_json(path):
    """Load and return the JSON document stored at *path* (read as UTF-8)."""
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


def triplet_no_rates(images, key=TRIPLET_JUDGEMENTS_KEY):
    """Compute, per image, the fraction of 'no' judgements of each question.

    Args:
        images: mapping whose values each hold an ``'instance'`` sequence of
            dicts; a dict may carry its judgements under *key*.
        key: name of the judgements entry inside each instance dict.

    Returns:
        A list with one entry per image (in mapping order); each entry is the
        list of per-question rates ``judgements.count('no') / len(judgements)``.
        Instances whose judgements are missing (or null) are skipped silently;
        instances with empty judgements are skipped after printing their
        ``(image index, question index)`` to stdout.
    """
    per_image = []
    for i, image in enumerate(images.values()):
        rates = []
        for j, instance in enumerate(image['instance']):
            judgements = instance.get(key)
            if judgements is None:
                # No judgements recorded for this question: nothing to score.
                continue
            if not judgements:
                # Report where an empty judgement list was found.
                print(i, j)
                continue
            rates.append(judgements.count('no') / len(judgements))
        per_image.append(rates)
    return per_image


def classify_judgement(judgement):
    """Classify one judge response string.

    Returns:
        ``None`` if the text does not contain "my answer is 'no'" (single or
        double quotes, case-insensitive); ``'object'`` if it additionally
        contains "the error is related to 'object1'" or "... 'object2'";
        ``'relation'`` for every other 'no' answer.
    """
    text = judgement.lower()  # lower-case once instead of once per test
    if "my answer is 'no'" not in text and "my answer is \"no\"" not in text:
        return None
    if ("the error is related to 'object1'" in text
            or "the error is related to 'object2'" in text):
        return 'object'
    return 'relation'


def error_rates_per_image(judge):
    """Compute per-question object and relation error rates for every image.

    Args:
        judge: mapping whose values are sequences of judgement lists (one list
            of response strings per question).

    Returns:
        A pair ``(object_rates, relation_rates)``. Each is a list with one
        entry per image (in mapping order) holding, for every question with at
        least one judgement, the fraction of its judgements classified as an
        object error / relation error by :func:`classify_judgement`.
    """
    object_rates = []
    relation_rates = []
    for question_judgements in judge.values():
        image_objects = []
        image_relations = []
        for judgements in question_judgements:
            if not judgements:
                # A question without judgements has no defined rate.
                continue
            kinds = [classify_judgement(judgement) for judgement in judgements]
            image_objects.append(kinds.count('object') / len(judgements))
            image_relations.append(kinds.count('relation') / len(judgements))
        object_rates.append(image_objects)
        relation_rates.append(image_relations)
    return object_rates, relation_rates


def question_level_mean(per_image_rates):
    """Return the mean over all questions of all images (NaN if there are none)."""
    return _mean([rate for rates in per_image_rates for rate in rates])


def image_level_means(per_image_rates):
    """Return the mean rate of every image that has at least one scored question."""
    return [_mean(rates) for rates in per_image_rates if rates]


# The analysis only runs when the file is executed as a script, so the helpers
# above can be imported without touching the data files.
if __name__ == "__main__":
    # ---- Overall hallucination rate (triplet judgements file) ----
    instruct_blip = load_json(TRIPLETS_PATH)
    triplet_rates = triplet_no_rates(instruct_blip)
    print("question level hallucination rate: ", question_level_mean(triplet_rates))

    instruct_blip_scores = []
    for i, (image, rates) in enumerate(zip(instruct_blip.values(), triplet_rates)):
        if rates:
            instruct_blip_scores.append(_mean(rates))
        else:
            # Image without any scored question: report its index together
            # with the index of its last question (-1 if it has none at all).
            print(i, len(image['instance']) - 1)

    print("image level hallucination rate: ", _mean(instruct_blip_scores))

    # ---- Object vs. relation hallucination rates (judge responses file) ----
    instruct_blip_judge = load_json(RESPONSES_PATH)
    object_rates, relation_rates = error_rates_per_image(instruct_blip_judge)
    print("question level object hallucination rate: ", question_level_mean(object_rates))
    print("question level relation hallucination rate: ", question_level_mean(relation_rates))

    instruct_blip_object_scores = image_level_means(object_rates)
    instruct_blip_relation_scores = image_level_means(relation_rates)

    print("image level object hallucination rate: ", _mean(instruct_blip_object_scores))
    print("image level relation hallucination rate: ", _mean(instruct_blip_relation_scores))