from openai import OpenAI
import openai
from tqdm import tqdm
import pdb
import time
import os
import shutil
import copy
import tiktoken
import base64
import requests
import json

## Configure the OpenAI client
## NOTE(review): the credentials below are placeholders hard-coded in source;
## consider loading them from the environment before sharing this script.
openai.organization = "your openai organization"
# The module-level `openai.organization` setting is only consulted by the SDK's
# implicit module client; an explicitly constructed `OpenAI(...)` instance does
# not read it, so the organization is passed to the constructor as well.
chatgpt_client = OpenAI(
    api_key='your openai key',
    organization=openai.organization,
)

## Load the extracted triplets to be judged
## (a JSON object keyed by sample id; see the evaluation loop for the fields used)
with open('/data/instruct_blip_7b_triplets.json', 'r', encoding='utf-8') as file:
    triplets = json.load(file)

## Prompt template sent to the GPT-4 judge.
## It contains exactly three positional `{}` slots, filled in this order via
## `str.format`: the reference triplets, the list of objects, the claim triplet.
evaluate_instruction = (
    'Given a list of reference triplets ("object1", "relation", "object2") extracted from the scene graph of an image, along with a list of objects observed in this image, your task is:\n\n'
    'Task 1. Determine if a claim triplet ("object1", "relation", "object2") is directly supported by any single triplet in the reference, or can be logically inferred from multiple reference triplets and the list of objects. Follow these steps when finishing the task:\n\n'
    '1. Answer "yes" if the claim appears in the reference.\n\n'
    '2. Answer "yes" if the claim can be logically inferred from one or more triplets in the reference. Consider:\n\n'
    'a. General Inferences: Assess common associations or implications.\n'
    'b. Conditional Phrases: Note phrases like "could be", "might", "suggests", which allow broader inferences.\n'
    'c. Equivalence of Objects: In your judgment, treat objects of the same kind as equal. For example, "woman", "man" should be considered under the general category of "person".\n'
    'd. Support from Object List: If the claim is not directly supported or inferable from the triplets, assess whether the list of objects provides additional evidence to support or infer the claim.\n\n'
    '3. Answer "no" if the claim neither directly matches any triplet in the reference nor can be reasonably inferred from the triplets and the object list.\n\n'
    'Task 2: Error categorization.\n\n'
    'If your answer to the previous task is "no", determine whether the not supported/inferred part in the claim is "object1" or "object2" or "relation".\n\n'
    'Reference:\n{}\n\n'
    'List of Objects:\n{}\n\n'
    'Claim:\n{}\n\n'
    'Please output your answer to the first task only in the format of "My answer is \'yes\'/\'no\'". If your answer is "no", output your answer to the second task only in the format of "The error is related to \'object1\'/\'object2\'/\'relation\'".'
)


## Calculate the number of tokens in the input
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
    """Return the number of prompt tokens used by a list of chat messages.

    Args:
        messages: Iterable of chat-message dicts such as
            ``{"role": "user", "content": "..."}``. Every value of every
            message is tokenized, so all values must be strings.
        model: Model name that selects the tokenizer and the per-message
            framing overhead.

    Returns:
        The token count of all message values plus the chat-format overhead
        (per message, per ``name`` field, and for priming the reply).

    Raises:
        NotImplementedError: If the framing overhead for ``model`` is unknown.
    """
    # Snapshots that share the same chat framing overhead.
    supported_models = {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
    }
    # Reject unknown models before loading any encoding.
    if model not in supported_models:
        raise NotImplementedError(f"""num_tokens_from_messages() is not presently implemented for model {model}.
    See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # tiktoken does not know this model name; fall back to cl100k_base.
        encoding = tiktoken.get_encoding("cl100k_base")

    # The previous constants (4 per message, -1 for a name, +2 priming) belong
    # to gpt-3.5-turbo-0301; the snapshots accepted here use the values below.
    tokens_per_message = 3
    tokens_per_name = 1

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens

## Start to evaluate
def _parse_judgement(answer):
    """Map a raw judge reply to 'yes', 'no', or 'null'.

    The reply is matched case-insensitively against the answer format the
    prompt requests (``My answer is 'yes'`` / ``"yes"``, likewise for "no").
    'yes' is checked first; a missing (``None``) or unrecognized reply
    yields 'null'.
    """
    text = (answer or '').lower()
    if "my answer is 'yes'" in text or 'my answer is "yes"' in text:
        return 'yes'
    if "my answer is 'no'" in text or 'my answer is "no"' in text:
        return 'no'
    return 'null'


# Raw judge replies: {sample id: [[reply per claim] per instance]}.
model_responses = {}

for index in tqdm(triplets):
    entry = triplets[index]
    reference = entry['triplets']
    object_list = entry['all_object']

    # Each reference item is a string (presumably of the form "(a, rel, b)");
    # drop the surrounding parentheses and split it into a tuple.
    new_reference = [tuple(item.strip('()').split(', ')) for item in reference]
    model_response = []

    # tqdm wraps the sequence itself so the progress bar can show a total.
    for instance in tqdm(entry['instance']):
        judgements = []
        ori_responses = []
        for claim in instance['instruct_blip_7b_triplets']:
            # The template has three slots: reference, object list, claim.
            judge_prompt = evaluate_instruction.format(
                new_reference, object_list, tuple(claim))

            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": judge_prompt}]
            print('token number:', num_tokens_from_messages(messages))
            # NOTE(review): no temperature is passed, so the API default is
            # used and judgements may differ between runs.
            response = chatgpt_client.chat.completions.create(
                model="gpt-4-1106-preview",
                messages=messages,
            )
            answer = response.choices[0].message.content
            ori_responses.append(answer)

            ## judge the result
            judgements.append(_parse_judgement(answer))
        model_response.append(ori_responses)
        # Judgements are written back onto the loaded data, aligned by
        # position with the instance's claim list.
        instance['instruct_blip_7b_triplets_judgements'] = judgements

    model_responses[index] = model_response

    ## Store the results
    ## Both files are rewritten after every sample id, so the output on disk
    ## always reflects the samples processed so far.
    with open('/data/instruct_blip_triplets_new.json', 'w', encoding='utf-8') as file:
        json.dump(triplets, file)

    with open('/data/instruct_blip_response.json', 'w', encoding='utf-8') as file:
        json.dump(model_responses, file)
