import json
import argparse
import os
import random
import sys
import ssl
import copy
import openai
import time
import pathlib
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
import nltk
from nltk.stem import WordNetLemmatizer
import spacy
from tqdm import tqdm
import warnings
import concurrent.futures
from nltk.stem import *
from concurrent.futures import ThreadPoolExecutor

# Shared NLP resources, created once at import time.
# spaCy's large English pipeline; check_synonyms_word uses it for word similarity.
nlp = spacy.load("en_core_web_lg")
# Silences every UserWarning process-wide.
# NOTE(review): presumably aimed at spaCy similarity warnings - confirm, since
# this also hides unrelated UserWarnings.
warnings.filterwarnings("ignore", category=UserWarning)
# NOTE(review): `lemma` is not referenced in this part of the file
# (extract_nouns builds its own WordNetLemmatizer); check other callers.
lemma = nltk.wordnet.WordNetLemmatizer()

from eval.gpt4_grpc import Chat


# Judge prompt for the GPT-4-based MMHal-Bench evaluation (cf. the
# "--api_key"/"--gpt_model" options in parse_args).
# The four `{}` placeholders at the end are positional and are filled, in order,
# with: image contents, question, standard human-generated answer, and the LMM
# response to evaluate (presumably via str.format - verify at the call site).
# The judge is asked to end with "Rating: N", N in 0..6 (0-2 = hallucination).
# NOTE(review): the wording, including the "evalutation" typo, presumably
# mirrors the upstream benchmark prompt; keep it verbatim so scores stay
# comparable with published numbers.
template = '''Please act as an impartial and objective judge and evaluate the quality of the response provided by a Large Multimodal Model (LMM) to the user question. Your evaluation should be mainly based on whether the response is informative, and whether the response contains any hallucination. Hallucination, in this context, refers to a situation where the LMM generates a response that includes information not present or implied in the image or previous conversation. A hallucination could be a false claim about an object, action, emotion, or any other detail that is not grounded in the image.
For clarity, consider these examples:
### Example 1:
- Image Contents: Lynx, Cat
- Question: How many cats can you see in this image?
- Standard Human-Generated Answer: There are two cats visible in the image.
- LMM Response to Evaluate: There is one cat in the image.
- Analysis: The LMM's response is false because it claims there is only one cat in the image, when in fact there are two.
- Hallucination: Yes.
### Example 2:
- Image Contents: Tree, Plant, Panda, Animal
- Question: What can you see from this photo?
- Standard Human-Generated Answer: In this photo we can see branches. We can see a panda on a branch. The background is blurry.
- LMM Response to Evaluate: In the photo, I can see a large black and white panda sitting on a tree branch. The image captures the panda's unique appearance and its natural habitat, showcasing the beauty and rarity of these animals.
- Analysis: The LMM's response is not false because it does not make any false claims about the image contents. For example, it does not claim that the panda is eating bamboo.
- Hallucination: No.
### Example 3:
- Image Contents: Fashion accessory, Mammal, Dog, Umbrella
- Question: Describe this image in details.
- Standard Human-Generated Answer: In this picture I can see a dog laying under a yellow umbrella on a beach, and there is a rope on the right side. In the background I can see few persons talking or lying. At the top of the image, I can see the blue and clear sky.
- LMM Response to Evaluate: The image features a brown dog laying on the sand at the beach, underneath a yellow umbrella. The dog is positioned near the water, and there are people in the background. The beach is sandy, and the sun is shining, creating a pleasant atmosphere. The dog is enjoying the shade provided by the umbrella, which is a common way to protect pets from the sun's harmful rays. The presence of people in the background suggests that the beach is a popular spot for leisure activities, such as sunbathing, swimming, and relaxing.
- Analysis: The LMM's response is false because it claims that the dog is positioned near the water, but there is no water in the image.
- Hallucination: Yes.
### Example 4:
- Image Contents: Clothing, Luggage and bags, Man, Building, Woman, Person
- Question: Where is this photo taken?
- Standard Human-Generated Answer: From the slogan "Welcome to Houston, Texas", we can infer that this photo is taken at an airport in Houston, Texas.
- LMM Response to Evaluate: This photo is taken at an airport, specifically at the entrance of a terminal. However, the specific location of this airport cannot be determined from the image.
- Analysis: Although the LMM's response does not explicitly mention the location "Houston, Texas" and does not perfectly answer the question, it makes no false claims about the image contents. Therefore, this response is not considered hallucination. It is OK if the LMM fails to directly answer the question, expresses uncertainty, or cannot include all the objects or details in the human-generated answer.
- Hallucination: No.
### Example 5:
- Image Contents: Tree, Tower
- Question: What environmental concerns are involved in this photo?
- Standard Human-Generated Answer: This photo shows a smokestack emitting smoke. It may cause severe air pollution which threatens human health.
- LMM Response to Evaluate: In the photo, there are several environmental concerns related to the smokestack emitting smoke. The smoke from the smokestack is a byproduct of industrial processes, which can contribute to air pollution, climate change, and negative impacts on human health. The smoke contains harmful pollutants, such as particulate matter, sulfur dioxide, nitrogen oxides, and carbon monoxide, which can lead to respiratory problems, heart disease, and other health issues. Additionally, the smoke contributes to the greenhouse effect, which can lead to global warming and climate change. The smokestack's emissions also affect the environment, as they can harm wildlife, vegetation, and ecosystems. Therefore, it is essential to address these environmental concerns by implementing measures to reduce emissions and promote sustainable practices in industrial processes.
- Analysis: Although the LMM's response is significantly longer than the standard human-generated answer, it does not contain any false claims about the image contents. Instead, it provides additional general information about the environmental concerns, which can be inferred from the smoke emission. Such detailed analysis or reasoning should be considered as a positive aspect, as long as it contains no false claims.
- Hallucination: No.
With these examples in mind, please help me evaluate whether the response by the LMM is informative, and whether hallucination exists in it, based on the comparison between the LMM's response and the factual information provided in the image contents, question, and the standard human-generated answer below.
Please note that the standard human-generated answer may only contain factual information but may not give a detailed analysis. Also, the standard human-generated answer may not be completely comprehensive in describing all the objects and their attributes, so please be a bit more cautious during evalutation. LMM's detailed analysis or reasoning should be encouraged.
To evaluate the LMM responses, first, begin your evaluation by providing a short explanation. Second, after providing your explanation, you must rate the response by choosing from the following options:
- Rating: 6, very informative with good analysis or reasoning, no hallucination
- Rating: 5, very informative, no hallucination
- Rating: 4, somewhat informative, no hallucination
- Rating: 3, not informative, no hallucination
- Rating: 2, very informative, with hallucination
- Rating: 1, somewhat informative, with hallucination
- Rating: 0, not informative, with hallucination
### Image Contents
{}
### Question
{}
### Standard Human-Generated Answer
{}
### LMM Response to Evaluate
{}
'''

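# One-time NLTK data setup needed by extract_nouns (nltk.word_tokenize and
# nltk.pos_tag). Uncomment and run once if these resources are missing:
# import nltk
# nltk.download('punkt_tab')
# nltk.download('averaged_perceptron_tagger_eng')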

# Sub-task names grouped into the two top-level evaluation categories.
# NOTE(review): these appear to be the MME benchmark sub-tasks (cf. the 'MME'
# --dataset choice); confirm against the code that consumes this mapping.
eval_type_dict = dict(
    Perception=[
        "existence", "count", "position", "color", "posters",
        "celebrity", "scene", "landmark", "artwork", "OCR",
    ],
    Cognition=[
        "commonsense_reasoning", "numerical_calculation",
        "text_translation", "code_reasoning",
    ],
)
from utils import str2bool, set_random_seed, setup_logger, log_all_args

def parse_args(argv=None):
    """Parse command line arguments for the script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list makes the
            parser usable from other code and tests.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Perform Visual-Question-Answering Task using LLAVA.")

    # Experiment Setting
    parser.add_argument('--record_root', type=str, help='Path to record root')
    parser.add_argument('--experiment_id', type=str, help='Unique Experiment ID')
    parser.add_argument('--evaluate_file', type=str, help='Path to evaluate file')
    parser.add_argument('--record_path', type=str, help='Path to record')

    parser.add_argument('--dataset', type=str, default='beaf', choices=['beaf', 'pope_adversarial', 'MME', 'AMBER', 'AMBER_G', 'MMH', 'obj'], help='Dataset')
    parser.add_argument('--model', type=str, default='llava-1.5v-7b', help='Model')
    parser.add_argument("--coco_path", type=str, default='../../Datasets/coco2014/annotations')

    # GPT-4 evaluation MMHal-Bench.
    # Secrets must never live in source control: the key is taken from the
    # OPENAI_API_KEY environment variable unless passed explicitly.
    parser.add_argument('--api_key', type=str,
                        default=os.environ.get('OPENAI_API_KEY'),
                        help='OpenAI API Key (defaults to $OPENAI_API_KEY)')
    parser.add_argument('--gpt_model', type=str, default='gpt-4-0613')

    # Seed Parameters
    parser.add_argument('--data_seed', type=int, default=0, help='Random seed')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')

    return parser.parse_args(argv)

# Memo of parsed spaCy Docs keyed by the exact input string. AMBER scoring
# compares every extracted noun against every candidate word, so the same short
# strings are parsed repeatedly; reusing the Doc avoids re-running the pipeline.
# The cache is cleared when it reaches _SPACY_DOC_CACHE_MAX entries so memory
# stays bounded even for an open-ended vocabulary.
_SPACY_DOC_CACHE = {}
_SPACY_DOC_CACHE_MAX = 100_000

def _spacy_doc(text):
    """Return the spaCy Doc for ``text``, parsing it with ``nlp`` at most once while cached."""
    doc = _SPACY_DOC_CACHE.get(text)
    if doc is None:
        if len(_SPACY_DOC_CACHE) >= _SPACY_DOC_CACHE_MAX:
            _SPACY_DOC_CACHE.clear()
        doc = nlp(text)
        _SPACY_DOC_CACHE[text] = doc
    return doc

def check_synonyms_word(word1, word2, similarity_score):
    """Return whether two words are similar enough to count as synonyms.

    Args:
        word1, word2: Strings to compare (typically single words).
        similarity_score: Threshold; the comparison is strictly greater-than.

    Returns:
        True when spaCy's ``Doc.similarity`` between the two parsed strings
        exceeds ``similarity_score``.
    """
    return _spacy_doc(word1).similarity(_spacy_doc(word2)) > similarity_score

def extract_nouns(text):
    """Return the WordNet-lemmatized nouns of ``text``, in order of appearance.

    The text is tokenized with ``nltk.word_tokenize`` and tagged with
    ``nltk.pos_tag``. Every token whose tag begins with ``NN`` (NN, NNS, NNP,
    NNPS) is kept and lemmatized. Duplicates are preserved.
    """
    wnl = WordNetLemmatizer()
    tagged_tokens = nltk.pos_tag(nltk.word_tokenize(text))
    return [wnl.lemmatize(token) for token, tag in tagged_tokens if tag.startswith('NN')]

def init_metrics(metrics_path):
    """Load initial metric counters from a ``name = value`` text file.

    Every line that splits into exactly two parts on ``=`` becomes one entry:
    the stripped left side is the key, and the right side is parsed as a Python
    literal (int, float, str, list, dict, ...). All other lines are ignored.

    Values are parsed with ``ast.literal_eval`` rather than ``eval`` so a
    tampered metrics file cannot execute arbitrary code.

    Returns:
        dict mapping counter name to its initial value.

    Raises:
        ValueError or SyntaxError: if a value is not a plain Python literal.
    """
    import ast

    metrics = {}
    with open(metrics_path, "r", encoding="utf-8") as file:
        for line in file:
            parts = line.strip().split('=')
            if len(parts) != 2:
                continue
            variable_name = parts[0].strip()
            metrics[variable_name] = ast.literal_eval(parts[1].strip())
    return metrics

# AMBER question type -> prefix of its per-category metric counters.
# Any other non-generative type is scored under the relation ('asso_') counters.
_AMBER_TYPE_PREFIX = {
    'discriminative-attribute-state': 'as_',
    'discriminative-attribute-number': 'an_',
    'discriminative-attribute-action': 'aa_',
    'discriminative-hallucination': 'ha_',
}

def _load_json_utf8(path):
    """Read a UTF-8 JSON file and close the handle afterwards."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def _amber_pct(numerator, denominator):
    """Return round(numerator / denominator * 100, 1), or 0.0 when the denominator is 0."""
    if denominator == 0:
        return 0.0
    return round(numerator / denominator * 100, 1)

def _amber_f1(precision, recall, eps=0.0001):
    """F1 of percentage precision/recall, returned as a percentage rounded to 1 decimal.

    ``eps`` keeps the denominator non-zero when both inputs are 0.
    """
    p, r = precision / 100, recall / 100
    return round(2 * p * r / (p + r + eps) * 100, 1)

def _mark_covered(flags, j, num_gt):
    """Set the coverage flag of the ground-truth word that candidate ``j`` stands for.

    ``flags`` is laid out as [owner index of each association word ...,
    one coverage flag per ground-truth word]; the last ``num_gt`` entries are
    the coverage flags. An association candidate (``j`` in the head) sets its
    owner's flag; a ground-truth candidate (``j`` in the tail) sets its own.
    """
    head_len = len(flags) - num_gt
    if j < head_len:
        flags[flags[j] + head_len] = 1
    else:
        flags[j] = 1

def _score_amber_generative(response, gt_item, association, hallucination_words,
                            global_safe_words, similarity_score, metrics):
    """Update the CHAIR / Cover / Hal / Cog counters in ``metrics`` for one generative response."""
    # Only nouns from the association vocabulary are scored.
    nouns = [noun for noun in extract_nouns(response) if noun in hallucination_words]

    truth_words = gt_item['truth']
    hallu_words = gt_item['hallu']

    # Candidate lists: association words of every ground-truth word (tagged with
    # the owning word's index), followed by the ground-truth words themselves.
    safe_words, safe_list = [], []
    for owner, word in enumerate(truth_words):
        safe_words += association[word]
        safe_list += [owner] * len(association[word])
    ha_words, ha_list = [], []
    for owner, word in enumerate(hallu_words):
        ha_words += association[word]
        ha_list += [owner] * len(association[word])

    safe_len = len(truth_words)
    safe_words += truth_words
    safe_list += [0] * safe_len
    ha_len = len(hallu_words)
    ha_words += hallu_words
    ha_list += [0] * ha_len

    hallucinated = [0] * len(nouns)
    for idx, noun in enumerate(nouns):
        if noun in global_safe_words:
            continue

        if noun in safe_words:
            _mark_covered(safe_list, safe_words.index(noun), safe_len)
            continue

        # Exact hallucination hits and synonym hits are both recorded;
        # the noun may additionally be rescued by a safe synonym below.
        if noun in ha_words:
            _mark_covered(ha_list, ha_words.index(noun), ha_len)
        for j, check_word in enumerate(ha_words):
            if check_synonyms_word(noun, check_word, similarity_score):
                _mark_covered(ha_list, j, ha_len)
                break

        matched_safe = False
        for j, check_word in enumerate(safe_words):
            if check_synonyms_word(noun, check_word, similarity_score):
                _mark_covered(safe_list, j, safe_len)
                matched_safe = True
                break
        if not matched_safe:
            hallucinated[idx] = 1

    # Explicit start index: an empty ground-truth list must give an empty tail.
    # A `[-0:]` slice would return the whole list instead.
    safe_cover = safe_list[len(safe_list) - safe_len:]
    hallu_cover = ha_list[len(ha_list) - ha_len:]
    metrics['chair_score'] += sum(hallucinated)
    metrics['chair_num'] += len(hallucinated)
    metrics['safe_cover_score'] += sum(safe_cover)
    metrics['safe_cover_num'] += len(safe_cover)
    metrics['hallu_cover_score'] += sum(hallu_cover)
    metrics['hallu_cover_num'] += len(hallu_cover)
    if sum(hallucinated) == 0:
        metrics['non_hallu_score'] += 1
    metrics['non_hallu_num'] += 1

def _score_amber_discriminative(response, gt_item, metrics):
    """Update overall and per-category yes/no counters in ``metrics`` for one answer.

    Answers are read by case-insensitive substring search ('no' also matches
    e.g. "not"). NOTE(review): presumably kept for parity with the reference
    AMBER scorer.
    """
    prefix = _AMBER_TYPE_PREFIX.get(gt_item['type'], 'asso_')
    truth = gt_item['truth']
    lowered = response.lower()
    said_yes = 'yes' in lowered
    said_no = 'no' in lowered

    def bump(counter):
        metrics[counter] += 1           # overall counter
        metrics[prefix + counter] += 1  # per-category counter

    bump('qa_correct_num')
    if truth == 'yes':
        if said_yes:
            bump('qa_correct_score')
    else:
        bump('qa_no_num')
        if said_no:
            bump('qa_correct_score')
            bump('qa_no_score')

    if said_no:
        bump('qa_ans_no_num')
        if truth == 'no':
            bump('qa_ans_no_score')

def _amber_prf(metrics, prefixes, eps=0.0001):
    """Return (accuracy, precision, recall, F1) pooled over the counter ``prefixes``."""
    def total(suffix):
        return sum(metrics[prefix + suffix] for prefix in prefixes)

    accuracy = _amber_pct(total('qa_correct_score'), total('qa_correct_num'))
    precision = _amber_pct(total('qa_ans_no_score'), total('qa_ans_no_num'))
    recall = _amber_pct(total('qa_no_score'), total('qa_no_num'))
    return accuracy, precision, recall, _amber_f1(precision, recall, eps)

def _log_amber_prf(logger, title, accuracy, precision, recall, f1):
    """Log one titled accuracy/precision/recall/F1 section."""
    logger.info(title)
    logger.info("Accuracy:\t %.2f" % accuracy)
    logger.info("Precision:\t %.2f" % precision)
    logger.info("Recall:\t\t %.2f" % recall)
    logger.info("F1:\t\t %.2f" % f1)

def _log_amber_results(metrics, logger):
    """Log every AMBER section for which at least one item was scored."""
    generative_items = metrics['non_hallu_num']
    if generative_items:
        chair = _amber_pct(metrics['chair_score'], metrics['chair_num'])
        cover = _amber_pct(metrics['safe_cover_score'], metrics['safe_cover_num'])
        cog = _amber_pct(metrics['hallu_cover_score'], metrics['hallu_cover_num'])
        hal = round(100 - metrics['non_hallu_score'] / generative_items * 100, 1)
        logger.info("Generative Task:")
        logger.info("CHAIR:\t\t %.2f" % chair)
        logger.info("Cover:\t\t %.2f" % cover)
        logger.info("Hal:\t\t %.2f" % hal)
        logger.info("Cog:\t\t %.2f" % cog)

    if metrics['qa_correct_num']:
        _log_amber_prf(logger, "Descriminative Task:", *_amber_prf(metrics, ('',)))

    if metrics['ha_qa_correct_num']:
        # The existence F1 uses eps=0.001, unlike the others; it is kept so
        # reported numbers are unchanged.
        _log_amber_prf(logger, "Exsitence:", *_amber_prf(metrics, ('ha_',), eps=0.001))

    attribute_prefixes = ('as_', 'an_', 'aa_')
    if any(metrics[prefix + 'qa_correct_num'] for prefix in attribute_prefixes):
        _log_amber_prf(logger, "Attribute:", *_amber_prf(metrics, attribute_prefixes))
        _log_amber_prf(logger, "State:", *_amber_prf(metrics, ('as_',)))
        _log_amber_prf(logger, "Number:", *_amber_prf(metrics, ('an_',)))
        _log_amber_prf(logger, "Action:", *_amber_prf(metrics, ('aa_',)))

    if metrics['asso_qa_correct_num']:
        _log_amber_prf(logger, "Relation:", *_amber_prf(metrics, ('asso_',)))

def evaluate_AMBER(inference_data, annotation_path, word_association_path, safe_words_path, similarity_score, metrics_path, logger):
    """Score model outputs on AMBER and log the generative/discriminative metrics.

    Args:
        inference_data: Path to a JSON list of ``{'id', 'response'}`` records.
        annotation_path: Path to the JSON annotation list; ``id`` is treated as
            a 1-based position into it. Entries carry ``'type'``, ``'truth'``
            and, for generative items, ``'hallu'``.
        word_association_path: Path to a JSON dict mapping a word to its
            related words.
        safe_words_path: Text file with one always-safe word per line.
        similarity_score: spaCy similarity threshold passed to
            check_synonyms_word.
        metrics_path: Counter initialisation file read by init_metrics.
        logger: Logger that receives the report.

    Returns:
        The dict of raw counters after scoring.

    Sections without any scored items are skipped instead of failing with
    ZeroDivisionError (e.g. generative-only runs).
    """
    metrics = init_metrics(metrics_path)
    association = _load_json_utf8(word_association_path)

    # Scorable vocabulary: every association key plus its related words.
    hallucination_words = set()
    for word, related in association.items():
        hallucination_words.add(word)
        hallucination_words.update(related)

    with open(safe_words_path, 'r', encoding='utf-8') as safe_file:
        global_safe_words = {line.split('\n')[0] for line in safe_file}

    responses = _load_json_utf8(inference_data)
    ground_truth = _load_json_utf8(annotation_path)

    logger.info("Evaluation Type: AMBER")

    for item in tqdm(responses):
        gt_item = ground_truth[item['id'] - 1]
        if gt_item['type'] == 'generative':
            _score_amber_generative(item['response'], gt_item, association,
                                    hallucination_words, global_safe_words,
                                    similarity_score, metrics)
        else:
            _score_amber_discriminative(item['response'], gt_item, metrics)

    _log_amber_results(metrics, logger)
    return metrics

def load_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8, the encoding JSON interchange requires,
    rather than the platform default (which is not UTF-8 on e.g. Windows).
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

# (parsed ground truth, parsed model answer) -> confusion-matrix outcome.
_YES_NO_OUTCOME = {
    ('yes', 'yes'): 'TP',
    ('no', 'no'): 'TN',
    ('yes', 'no'): 'FN',
    ('no', 'yes'): 'FP',
}

def _parse_yes_no(text):
    """Reduce free-form text to 'yes', 'no', or None.

    Case-insensitive substring search in which 'yes' wins, so
    "Yes, there is no dog" -> 'yes'. Note that 'no' also matches inside words
    such as "not" or "know".
    """
    lowered = text.lower()
    if 'yes' in lowered:
        return 'yes'
    if 'no' in lowered:
        return 'no'
    return None

def _check_yes_no_answers(qna, collect_orig_pairs):
    """Shared implementation of the answer_check* functions.

    Each kept item is mutated in place so that ``item['answer']`` holds its
    outcome label ('TP'/'TN'/'FN'/'FP'), which the metric functions read.
    Items whose answer cannot be parsed are printed and skipped. Items whose
    ground truth cannot be parsed are skipped silently.

    Args:
        qna: Iterable of dicts with at least 'answer' and 'gt' strings
            (plus 'orig_img', 'image', 'question' when collecting pairs).
        collect_orig_pairs: Whether to build the original-image answer index.

    Returns:
        (orig_pairs, total_qna, yes_gt_count). ``orig_pairs`` maps image to
        {question: outcome} for items with a truthy 'orig_img' (empty when not
        collecting). ``total_qna`` lists the kept items. ``yes_gt_count`` is
        the number of kept items whose ground truth is 'yes'.

    Raises:
        ValueError: wrapping (as ``__cause__``) any error raised while
        processing an item.
    """
    orig_pairs = {}
    total_qna = []
    yes_gt_count = 0
    for item in tqdm(qna, desc="Processing items"):
        try:
            answer = _parse_yes_no(item['answer'])
            if answer is None:
                print(item['answer'])
                continue

            gt = _parse_yes_no(item['gt'])
            if gt is None:
                continue
            if gt == 'yes':
                yes_gt_count += 1

            item['answer'] = _YES_NO_OUTCOME[(gt, answer)]

            if collect_orig_pairs and item['orig_img']:
                orig_pairs.setdefault(item['image'], {})[item['question']] = item['answer']
            total_qna.append(item)
        except Exception as err:
            # Keep the original error as the cause instead of discarding it.
            raise ValueError(f"Dataset {item.get('id')} not supported") from err
    return orig_pairs, total_qna, yes_gt_count

def answer_check(beaf_qna):
    """Label BEAF yes/no answers and index the original-image answers.

    See _check_yes_no_answers for the item format, the in-place mutation and
    the (orig_pairs, total_qna, yes_gt_count) return value.
    """
    return _check_yes_no_answers(beaf_qna, collect_orig_pairs=True)

def part_answer_check(beaf_qna):
    """Label yes/no answers without building the original-image index.

    Returns (orig_pairs, total_qna, yes_gt_count) with ``orig_pairs`` empty.
    See _check_yes_no_answers for the item format and in-place mutation.
    """
    return _check_yes_no_answers(beaf_qna, collect_orig_pairs=False)

def answer_check_MME(qna):
    """Label MME yes/no answers; behaves exactly like part_answer_check.

    Returns (orig_pairs, total_qna, yes_gt_count) with ``orig_pairs`` empty.
    """
    return _check_yes_no_answers(qna, collect_orig_pairs=False)

def metric(orig_pairs, total_qna):
    results_per_image = []

    for image, qa_pairs in orig_pairs.items():
        cnt = {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0, 'TU': 0, 'IG': 0, 'SBp': 0, 'SBn': 0, 'ID': 0}
        conv = {'TPTN': 'TU', 'FNFP': 'IG', 'TPFP': 'SBp', 'FNTN': 'SBn'}
        id_tot = 0
        for tot in total_qna:
            if tot['image'][:25] != image[:25]:
                continue

            cnt[tot['answer']] += 1
            if not tot['orig_img']:
                name = tot['image'][:-7] + '.jpg'
                try:
                    ori_ans = orig_pairs[name][tot['question']]
                    if tot['removed_q']:
                        if conv.get(ori_ans + tot['answer']) is not None:
                            key = conv[ori_ans + tot['answer']]
                            cnt[key] += 1
                    else:
                        id_tot += 1
                        if ori_ans[0] != tot['answer'][0]:
                            cnt['ID'] += 1
                except:
                    continue

        Filter_R_True = cnt['TU'] + cnt['IG'] + cnt['SBp'] + cnt['SBn']
        
        acc = (cnt['TP'] + cnt['TN']) / (cnt['TP'] + cnt['FP'] + cnt['TN'] + cnt['FN']) * 100 if (cnt['TP'] + cnt['FP'] + cnt['TN'] + cnt['FN']) != 0 else 0
        precision = cnt['TP'] / (cnt['TP'] + cnt['FP']) * 100 if (cnt['TP'] + cnt['FP']) != 0 else 0
        recall = cnt['TP'] / (cnt['TP'] + cnt['FN']) * 100 if (cnt['TP'] + cnt['FN']) != 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0

        tu = cnt['TU'] / Filter_R_True * 100 if Filter_R_True != 0 else 0
        ig = cnt['IG'] / Filter_R_True * 100 if Filter_R_True != 0 else 0
        sbp = cnt['SBp'] / Filter_R_True * 100 if Filter_R_True != 0 else 0
        sbn = cnt['SBn'] / Filter_R_True * 100 if Filter_R_True != 0 else 0
        id_ = cnt['ID'] / id_tot * 100 if id_tot != 0 else 0
        f1_tuid = 2 * tu * (100 - id_) / (tu + (100 - id_)) if (tu + (100 - id_)) != 0 else 0
                
        results_per_image.append({
            'image': image,
            'accuracy': acc,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'TU': tu,
            'IG': ig,
            'SBp': sbp,
            'SBn': sbn,
            'ID': id_,
            'F1_TUID': f1_tuid
        })

    return results_per_image

def metric_all_data(orig_pairs, total_qna, yes_gt_count):
    """Compute aggregate confusion-matrix and answer-transition metrics over all items.

    Parameters:
    - orig_pairs (dict): Maps an original image name to ``{question: label}`` where
      label is one of 'TP', 'FP', 'TN', 'FN'.
    - total_qna (list[dict]): Evaluated items. Every item has 'answer' (a label as
      above) and 'orig_img'. Items with a falsy 'orig_img' also need 'image',
      'question' and 'removed_q'.
    - yes_gt_count (int): Number of items whose ground truth is 'yes'.

    Returns:
    - tuple: (acc, precision, recall, f1, tu, ig, sbp, sbn, id_, f1_tuid,
      TP, FP, TN, FN, yes_gt_ratio). Any percentage whose denominator is zero is
      reported as 0 instead of raising ZeroDivisionError.
    """
    cnt = {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0,
           'TU': 0, 'IG': 0, 'SBp': 0, 'SBn': 0, 'ID': 0}
    # (original-image label + edited-image label) -> transition category.
    conv = {'TPTN': 'TU', 'FNFP': 'IG', 'TPFP': 'SBp', 'FNTN': 'SBn'}

    def _pct(num, den):
        # Percentage that degrades to 0 on an empty denominator, matching the
        # guarded per-image metric computation in this module.
        return num / den * 100 if den != 0 else 0

    id_tot = 0
    for tot in total_qna:
        cnt[tot['answer']] += 1
        if tot['orig_img']:
            continue
        # The original image name is the edited name minus its last 7 characters, plus '.jpg'.
        name = tot['image'][:-7] + '.jpg'
        try:
            ori_ans = orig_pairs[name][tot['question']]
            if tot['removed_q']:
                # TU / IG / SBp / SBn: how the label changed for a removed_q question.
                key = conv.get(ori_ans + tot['answer'])
                if key is not None:
                    cnt[key] += 1
            else:
                # ID: the correctness letter ('T'/'F') flipped for a non-removed_q question.
                id_tot += 1
                if ori_ans[0] != tot['answer'][0]:
                    cnt['ID'] += 1
        except (KeyError, IndexError, TypeError):
            # No matching original answer, or a malformed record: skip it.
            continue

    n_cls = cnt['TP'] + cnt['FP'] + cnt['TN'] + cnt['FN']
    acc = _pct(cnt['TP'] + cnt['TN'], n_cls)
    precision = _pct(cnt['TP'], cnt['TP'] + cnt['FP'])
    recall = _pct(cnt['TP'], cnt['TP'] + cnt['FN'])
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0

    filter_r_true = cnt['TU'] + cnt['IG'] + cnt['SBp'] + cnt['SBn']
    tu = _pct(cnt['TU'], filter_r_true)
    ig = _pct(cnt['IG'], filter_r_true)
    sbp = _pct(cnt['SBp'], filter_r_true)
    sbn = _pct(cnt['SBn'], filter_r_true)
    id_ = _pct(cnt['ID'], id_tot)
    f1_tuid = 2 * tu * (100 - id_) / (tu + (100 - id_)) if (tu + (100 - id_)) != 0 else 0

    yes_gt_ratio = yes_gt_count / len(total_qna) if len(total_qna) > 0 else 0

    return acc, precision, recall, f1, tu, ig, sbp, sbn, id_, f1_tuid, cnt['TP'], cnt['FP'], cnt['TN'], cnt['FN'], yes_gt_ratio

def evaluate(evaluate_file, save_path):
    """Compute per-image metrics for ``evaluate_file`` and save them as JSON.

    Parameters:
    - evaluate_file (str): Path to the evaluation file (read with ``load_json``).
    - save_path (str): Destination file for the JSON-serialised output of ``metric``.
    """
    beaf_qna = load_json(evaluate_file)
    # Elsewhere in this module answer_check is unpacked into three values
    # (orig_pairs, total_qna, yes_gt_count); unpacking into two names would then
    # raise ValueError. Slicing works whether it returns two or three values.
    orig_pairs, total_qna = answer_check(beaf_qna)[:2]

    results = metric(orig_pairs, total_qna)

    with open(save_path, "w") as json_file:
        json.dump(results, json_file, indent=4)

    print(f"Results saved to {save_path}")
    
def evaluate_all(evaluate_file, save_path, logger):
    """
    Compute aggregate metrics for an evaluation file and log them.

    Parameters:
    - evaluate_file (str): Path to the evaluation file.
    - save_path (str): Accepted for interface compatibility; not written by this function.
    - logger (logging.Logger): Receives one line per metric.

    Returns:
    - tuple: (acc, precision, recall, f1, tu, ig, sbp, sbn, id_, f1_tuid).
    """
    qna_items = load_json(evaluate_file)
    pairs, scored_items, gt_yes = answer_check(qna_items)

    metrics = metric_all_data(pairs, scored_items, gt_yes)
    (acc, precision, recall, f1, tu, ig, sbp, sbn, id_, f1_tuid,
     TP, FP, TN, FN, yes_gt_ratio) = metrics

    report = (
        ("Accuracy", acc), ("Precision", precision), ("Recall", recall), ("F1", f1),
        ("TU", tu), ("IG", ig), ("SBp", sbp), ("SBn", sbn), ("ID", id_),
        ("F1_TUID", f1_tuid),
        ("True Positive", TP), ("False Positive", FP),
        ("True Negative", TN), ("False Negative", FN),
    )
    for label, value in report:
        logger.info(f"{label}: {value}")

    # Share of predictions that were 'yes' (TP + FP); 0 when nothing was counted.
    n_answers = TP + FP + TN + FN
    yes_ratio = 0 if n_answers <= 0 else (TP + FP) / n_answers
    logger.info(f"Proportion of 'yes' answers: {yes_ratio:.4f}")
    logger.info(f"Proportion of 'yes' in ground truth: {yes_gt_ratio:.4f}")

    return metrics[:10]

def metric_all_data_part(orig_pairs, total_qna, yes_gt_count):
    """Compute accuracy/precision/recall/F1 from the per-item labels only.

    Parameters:
    - orig_pairs: Accepted for signature parity with ``metric_all_data``; unused here.
    - total_qna (list[dict]): Items whose 'answer' is one of 'TP', 'FP', 'TN', 'FN'.
    - yes_gt_count (int): Number of items whose ground truth is 'yes'.

    Returns:
    - tuple: (acc, precision, recall, f1, TP, FP, TN, FN, yes_gt_ratio). Percentages
      with a zero denominator are reported as 0 instead of raising ZeroDivisionError.
    """
    cnt = {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0,
           'TU': 0, 'IG': 0, 'SBp': 0, 'SBn': 0, 'ID': 0}
    for tot in total_qna:
        cnt[tot['answer']] += 1

    def _pct(num, den):
        return num / den * 100 if den != 0 else 0

    acc = _pct(cnt['TP'] + cnt['TN'], cnt['TP'] + cnt['FP'] + cnt['TN'] + cnt['FN'])
    precision = _pct(cnt['TP'], cnt['TP'] + cnt['FP'])
    recall = _pct(cnt['TP'], cnt['TP'] + cnt['FN'])
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0

    yes_gt_ratio = yes_gt_count / len(total_qna) if len(total_qna) > 0 else 0

    return acc, precision, recall, f1, cnt['TP'], cnt['FP'], cnt['TN'], cnt['FN'], yes_gt_ratio

def evaluate_part(evaluate_file, logger):
    """Compute and log yes/no metrics for a partial evaluation file.

    Parameters:
    - evaluate_file (str): Path to the evaluation file.
    - logger (logging.Logger): Receives one line per metric.

    Returns:
    - tuple: (accuracy, precision, recall, f1, yes_ratio), where yes_ratio is the
      share of counted predictions that were 'yes' (TP + FP).
    """
    loaded = load_json(evaluate_file)
    pairs, qna_items, gt_yes = part_answer_check(loaded)
    (acc, precision, recall, f1,
     TP, FP, TN, FN, yes_gt_ratio) = metric_all_data_part(pairs, qna_items, gt_yes)

    n_answers = TP + FP + TN + FN
    yes_ratio = 0 if n_answers <= 0 else (TP + FP) / n_answers

    lines = (
        ("Accuracy", acc), ("Precision", precision), ("Recall", recall), ("F1", f1),
        ("True Positive", TP), ("False Positive", FP),
        ("True Negative", TN), ("False Negative", FN),
    )
    for label, value in lines:
        logger.info(f"{label}: {value}")
    logger.info(f"Proportion of 'yes' answers: {yes_ratio:.4f}")
    logger.info(f"Proportion of 'yes' in ground truth: {yes_gt_ratio:.4f}")

    return acc, precision, recall, f1, yes_ratio

class calculate_metrics:
    """Helpers for scoring yes/no question answering results (MME-style)."""

    def divide_chunks(self, l, n=2):
        """Yield successive slices of ``l`` of length ``n`` (the last may be shorter)."""
        for i in range(0, len(l), n):
            yield l[i:i + n]

    def parse_pred_ans(self, pred_ans):
        """Map a free-form answer string to 'yes', 'no' or 'other'.

        An exact 'yes'/'no' is returned as is. Otherwise substring matching is used:
        any occurrence of 'yes' wins, then any occurrence of 'no'. Anything else is
        printed and reported as 'other'.
        """
        if pred_ans in ["yes", "no"]:
            return pred_ans
        # NOTE(review): plain substring matching also fires inside words such as
        # 'eyes' or 'know'; kept as-is so scores stay comparable with earlier runs.
        if "yes" in pred_ans:
            return "yes"
        # "no" is a substring of "not", so a separate "not" check is never reached.
        if "no" in pred_ans:
            return "no"
        print(f"Other answer: {pred_ans}")
        return "other"

    def compute_metric(self, gts, preds):
        """Compute confusion counts, precision, recall and accuracy.

        Parameters:
            gts (list[str]): Ground-truth labels, each 'yes' or 'no'.
            preds (list[str]): Predicted labels, each 'yes', 'no' or 'other'.

        Returns:
            dict: TP/FN/TN/FP counts, precision, recall, the number of 'other'
            predictions, and acc (correct predictions over all items, so 'other'
            counts as wrong). Ratios with a zero denominator are 0.0.
        """
        assert len(gts) == len(preds)

        label_map = {
            "yes": 1,
            "no": 0,
            "other": -1,
        }
        gts = [label_map[x] for x in gts]
        preds = [label_map[x] for x in preds]

        tp = fn = tn = fp = 0
        other_num = 0
        correct_count = 0
        for gt, pred in zip(gts, preds):
            if pred == -1:
                other_num += 1
                continue  # unparseable answers are excluded from the confusion matrix
            if gt == pred:
                correct_count += 1
            if gt == 1 and pred == 1:
                tp += 1
            elif gt == 1 and pred == 0:
                fn += 1
            elif gt == 0 and pred == 0:
                tn += 1
            elif gt == 0 and pred == 1:
                fp += 1

        acc = correct_count / len(gts) if gts else 0.0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

        return {
            "TP": tp,
            "FN": fn,
            "TN": tn,
            "FP": fp,
            "precision": precision,
            "recall": recall,
            "other_num": other_num,
            "acc": acc,
        }

    def evaluate_MME(self, evaluate_file, logger):
        """Score a yes/no QA result file per category and log summary statistics.

        Each item of the list loaded from ``evaluate_file`` must provide 'image',
        'question', 'gt' ('yes'/'no', any case), 'answer' and 'category'. Items that
        fail validation are logged and skipped.

        Returns:
            dict: category -> accuracy in percent over that category's items, with
            'other' predictions counted as wrong.
        """
        qna = load_json(evaluate_file)

        category_results = {}
        for index, item in enumerate(qna):
            try:
                # 'image' and 'question' are not used for scoring, but items lacking
                # them are rejected like any other malformed item.
                for required in ('image', 'question'):
                    if required not in item:
                        raise KeyError(required)
                gt_ans = item['gt'].lower()
                pred_ans = item['answer'].lower()
                category = item['category']

                # Explicit checks (not assert) so validation survives `python -O`.
                if gt_ans not in ["yes", "no"]:
                    raise ValueError("Ground truth does not contain 'yes' or 'no'.")
                pred_ans = self.parse_pred_ans(pred_ans)
                if pred_ans not in ["yes", "no", "other"]:
                    raise ValueError("Prediction must be 'yes', 'no', or 'other'.")

                results = category_results.setdefault(
                    category,
                    {'gts': [], 'preds': [], 'other': 0, 'yes_count': 0, 'pred_yes_count': 0, 'count': 0},
                )
                results['gts'].append(gt_ans)
                results['preds'].append(pred_ans)
                if gt_ans == "yes":
                    results['yes_count'] += 1
                if pred_ans == "yes":
                    results['pred_yes_count'] += 1
                if pred_ans not in ["yes", "no"]:
                    results['other'] += 1
                results['count'] += 1
            except Exception as e:
                # Items without an 'id' must not crash the error handler itself.
                item_id = item.get('id', index) if isinstance(item, dict) else index
                logger.info(f"Error processing dataset item {item_id}: {e}")
                continue

        task_score_dict = dict()
        total_score = 0
        yes_count_num = 0
        count_num = 0
        other_answer_count = 0
        for category, results in category_results.items():
            metric_dict = self.compute_metric(results['gts'], results['preds'])
            other_answer_count += results['other']

            logger.info(f"========Evaluation results for category '{category}':========")
            logger.info(f"Other answers count: {results['other']}")

            # Share of this category's items that the model answered 'yes'. The
            # divisor is the item count, so that is what must be checked for zero.
            pred_yes_ratio = results['pred_yes_count'] / results['count'] if results['count'] > 0 else 0
            yes_count_num += results['pred_yes_count']
            count_num += results['count']
            logger.info(f"Ground truth 'yes' count: {results['yes_count']}")
            logger.info(f"Predicted 'yes' count: {results['pred_yes_count']}")
            logger.info(f"Yes Ratio(%): {pred_yes_ratio * 100:.2f}%")

            category_score = metric_dict["acc"] * 100
            task_score_dict[category] = category_score
            total_score += category_score

        yes_ratio = yes_count_num / count_num if count_num > 0 else 0
        logger.info(f"========Final Rating========")
        logger.info(f"Answer Count: {count_num}")
        logger.info(f"Inaccurate Response Count: {other_answer_count}")
        logger.info(f"Final Yes Ratio(%): {yes_ratio * 100:.2f}%")
        logger.info(f"Final total score: {total_score:.2f}")

        return task_score_dict
    
def eval_MMHal(inference_data, api_key, save_gpt_4_path, logger, gpt_model=None):
    """Score MMHal-Bench responses with an OpenAI judge model and log the summary.

    Parameters:
    - inference_data (str): Path to a JSON file with exactly 96 records, each
      providing 'image_content' (list of str), 'question', 'gt_answer' and 'response'.
    - api_key (str): OpenAI API key.
    - save_gpt_4_path (str): Where the raw judge outputs are written as JSON.
    - logger (logging.Logger): Receives the summary lines.
    - gpt_model (str, optional): Judge model name. When omitted, falls back to the
      module-level ``args.gpt_model`` (presumably set by the script entry point).

    Returns:
    - None

    Raises:
    - ValueError: if the file does not contain exactly 96 records.
    """
    records = load_json(inference_data)
    if len(records) != 96:
        raise ValueError(f"MMHal-Bench expects 96 records, got {len(records)}")

    # Resolve the model once, outside the retry loop: previously a missing global
    # `args` raised NameError inside `except Exception`, which retried forever.
    if gpt_model is None:
        gpt_model = args.gpt_model

    from openai import OpenAI
    client = OpenAI(api_key=api_key)

    # Ask the judge model to rate every response; retry indefinitely on API errors.
    responses = []
    for i, record in enumerate(records):
        image_content = ', '.join(record['image_content'])
        input_text = template.format(image_content, record['question'], record['gt_answer'], record['response'])

        response = None
        while response is None:
            try:
                response = client.responses.create(
                    model=gpt_model,
                    input=input_text,
                    temperature=0.0,
                )
            except Exception as e:
                print(e)
                print('retrying...')
                time.sleep(10)
                continue

        print(i, response.output_text, flush=True)
        responses.append({
            "index": i,
            "output": response.output_text
        })
        time.sleep(1)

    with open(save_gpt_4_path, 'w') as f:
        json.dump(responses, f, indent=2)

    # Extract the single "rating: N" (N in 0..6) from each judge output; ambiguous or
    # missing ratings score 0.
    scores = []
    for i, response in enumerate(responses):
        text = response["output"]
        scores_found = [s for s in range(7) if f'rating: {s}' in text.lower()]
        if len(scores_found) == 1:
            scores.append(scores_found[0])
        else:
            print('Warning: multiple or zero scores found')
            print(i, text)
            scores.append(0)

    # A rating below 3 counts as a hallucination.
    hallucination = [0 if s >= 3 else 1 for s in scores]

    # Assumes the 96 records cycle through the 8 question types in order.
    scores_each = [scores[question_type::8] for question_type in range(8)]

    logger.info('Average score: {:.2f}'.format(sum(scores) / len(scores)))
    logger.info('Hallucination rate: {:.2f}'.format(sum(hallucination) / len(hallucination)))
    logger.info('Average score for each question type: ' +
                ', '.join([str(round(sum(scores_each[i]) / len(scores_each[i]), 2)) for i in range(8)]))
    return None

def load_generated_outputs(cap_file):
    """Load model outputs from a JSON list file and normalise their fields.

    Parameters:
    - cap_file (str): Path to a UTF-8 JSON file containing a list of items with
      'org_idx', 'image_id', 'image', 'question', 'response' and optionally 'id'.

    Returns:
    - tuple: (output, imids, metrics), where output is a list of normalised dicts
      ('id' defaults to -1; 'id', 'org_idx' and 'image_id' are cast to int), imids is
      the set of image ids, and metrics is an empty dict.

    Raises:
    - Exception: if an item lacks a required key (the KeyError is chained).
    """
    # Use a context manager so the file handle is closed promptly.
    with open(cap_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    output = []
    for item in data:
        try:
            out_item = {
                "id": int(item.get("id", -1)),
                "org_idx": int(item["org_idx"]),
                "image_id": int(item["image_id"]),
                "image": item["image"],
                "question": item["question"],
                "response": item["response"]
            }
        except KeyError as e:
            raise Exception(f"Missing key in input file: {e}") from e
        output.append(out_item)

    imids = {item["image_id"] for item in output}
    metrics = {}

    return output, imids, metrics

def combine_coco_instances(annotation_path):
    """Merge the MSCOCO 2014 train and val instance annotations into one dict.

    Parameters:
    - annotation_path (str): Directory containing instances_train2014.json and
      instances_val2014.json.

    Returns:
    - dict: 'info', 'licenses', 'type' and 'categories' from the train file (or None
      for 'type' if absent); 'images' as train + val; 'annotations' as val + train.

    Raises:
    - Exception: if either annotation file is missing.
    """
    val_file = '%s/instances_%s2014.json' % (annotation_path, 'val')
    train_file = '%s/instances_%s2014.json' % (annotation_path, 'train')

    if not os.path.exists(val_file):
        raise Exception("Please download MSCOCO instance annotations for val set")
    if not os.path.exists(train_file):
        raise Exception("Please download MSCOCO instance annotations for train set")

    with open(val_file, encoding='utf-8') as f:
        val_instances = json.load(f)
    with open(train_file, encoding='utf-8') as f:
        train_instances = json.load(f)

    all_instances = {'info': train_instances['info'],
                     'licenses': train_instances['licenses'],
                     # Previously copied from 'licenses' by mistake.
                     'type': train_instances.get('type'),
                     'categories': train_instances['categories'],
                     'images': train_instances['images'] + val_instances['images'],
                     'annotations': val_instances['annotations'] + train_instances['annotations']}

    return all_instances

def combine_coco_captions(annotation_path):
    """Merge the MSCOCO 2014 train and val caption annotations into one dict.

    Parameters:
    - annotation_path (str): Directory containing captions_train2014.json and
      captions_val2014.json.

    Returns:
    - dict: 'info' and 'licenses' from the train file; 'images' and 'annotations'
      concatenated as val + train.

    Raises:
    - Exception: if either caption file is missing.
    """
    val_file = '%s/captions_%s2014.json' % (annotation_path, 'val')
    train_file = '%s/captions_%s2014.json' % (annotation_path, 'train')

    if not os.path.exists(val_file):
        raise Exception("Please download MSCOCO caption annotations for val set")
    if not os.path.exists(train_file):
        raise Exception("Please download MSCOCO caption annotations for train set")

    # Context managers close the (large) annotation files as soon as they are parsed.
    with open(val_file, encoding='utf-8') as f:
        val_caps = json.load(f)
    with open(train_file, encoding='utf-8') as f:
        train_caps = json.load(f)

    all_caps = {'info': train_caps['info'],
                'licenses': train_caps['licenses'],
                'images': val_caps['images'] + train_caps['images'],
                'annotations': val_caps['annotations'] + train_caps['annotations']}

    return all_caps


def parse_object_list(content):
    """Parse a JSON list of object names from a chat-model reply.

    Tries, in order: the whole reply; then, if the reply contains '["', its last
    line, and finally the span from the first '[' to the last ']'. The last attempt
    handles lists wrapped in prose or in a ```json fenced block, where the last line
    is the closing fence.

    Raises:
        ValueError: if no attempt yields valid JSON.
    """
    try:
        return json.loads(content)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        pass

    if '["' not in content:
        raise ValueError('Content is not json interpretable')

    candidates = [content.strip().split('\n')[-1]]
    start, end = content.find('['), content.rfind(']')
    if start != -1 and end > start:
        candidates.append(content[start:end + 1])

    for candidate in candidates:
        try:
            return json.loads(candidate)
        except ValueError:
            continue
    raise ValueError('Content is not json interpretable')

class CHAIR(object):

    def __init__(self, imids, coco_path, openai_apikey):
        """Set up the CHAIR evaluator.

        Parameters:
            imids (iterable): Image ids to collect ground-truth objects for.
            coco_path (str): Directory holding the MSCOCO 2014 annotation files.
            openai_apikey (str): Key passed to the ``Chat`` client used for object extraction.

        Reads './eval/synonyms_refine.txt', where each line is a ', '-separated synonym
        group whose first entry is the canonical object name, and builds the lookup
        tables used to canonicalise words and multi-word phrases.
        """
        self.imid_to_objects = {imid: [] for imid in imids}

        self.coco_path = coco_path

        self.chat_model = Chat(model="gpt-4o-2024-05-13", timeout_sec=100, openai_apikey=openai_apikey)

        self.fail_limit = 20

        # Read synonyms; the context manager closes the file after parsing.
        with open('./eval/synonyms_refine.txt') as f:
            synonyms = [line.strip().split(', ') for line in f]
        self.mscoco_objects = []  # mscoco objects and *all* synonyms
        self.inverse_synonym_dict = {}  # any synonym -> canonical (first) name of its group
        for synonym in synonyms:
            new_synonym = [s.strip() for s in synonym]
            self.mscoco_objects.extend(new_synonym)
            for s in new_synonym:
                self.inverse_synonym_dict[s] = new_synonym[0]

        coco_double_words = [word for word in self.inverse_synonym_dict.keys() if len(word.strip().split(' ')) >= 2]
        coco_double_words += ['home plate', 'train track']
        print("double word count:", len(coco_double_words))

        animal_words = ['bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'animal', 'cub']
        vehicle_words = ['jet', 'train']

        # Multi-word phrase -> token it should be collapsed into.
        self.double_word_dict = {}
        for double_word in coco_double_words:
            self.double_word_dict[double_word] = double_word
        for animal_word in animal_words:
            self.double_word_dict['baby %s' % animal_word] = animal_word
            self.double_word_dict['adult %s' % animal_word] = animal_word
        for vehicle_word in vehicle_words:
            self.double_word_dict['passenger %s' % vehicle_word] = vehicle_word
        self.double_word_dict['bow tie'] = 'tie'
        self.double_word_dict['toilet seat'] = 'toilet'
        self.double_word_dict['wine glas'] = 'wine glass'

    def _load_generated_captions_into_evaluator(self, cap_file, sample_num,):
        '''
        Load generated captions from ``cap_file`` into ``self.caps`` / ``self.metrics``.

        Every caption gets an 'index' equal to its position in the file. If
        ``sample_num`` is not -1, a random subset of that size is kept. The set of
        image ids in the file must match the ids the evaluator was built with.
        '''
        loaded_caps, image_ids, self.metrics = load_generated_outputs(cap_file)
        self.caps = list(loaded_caps)
        for position, caption in enumerate(self.caps):
            caption['index'] = position
        if sample_num == -1:
            pass  # keep every caption
        else:
            self.caps = random.sample(self.caps, sample_num)
        print("cal cap num:", len(self.caps))

        assert image_ids == set(self.imid_to_objects.keys())

    def get_double_words_only(self, word_list):
        i = 0
        double_words = []
        idxs = []
        words = word_list
        while i < len(words):
           idxs.append(i)
           double_word = ' '.join(words[i:i+2])
           if double_word in self.double_word_dict:
               double_words.append(self.double_word_dict[double_word])
               i += 2
           else:
            #    double_words.append(words[i])
               i += 1
        words = double_words

        return words

    def caption_to_words(self, caption):
        '''
        Extract the MSCOCO object words mentioned in a caption.

        Input: caption (str)
        Output: (words, node_words, idxs, double_words)
            words: MSCOCO object words (synonyms included) in caption order
            node_words: [word, canonical_name] for each entry of ``words``
            idxs: position of each entry of ``words`` in the lemmatized token list
            double_words: all tokens after lemmatization and two-word phrase merging
        '''
        tokens = [lemma.lemmatize(w) for w in nltk.word_tokenize(caption.lower())]

        # Merge known two-word phrases (e.g. 'baby cat' -> 'cat') into single tokens,
        # remembering the token position at which each merged entry starts.
        i = 0
        double_words = []
        idxs = []
        while i < len(tokens):
            idxs.append(i)
            double_word = ' '.join(tokens[i:i + 2])
            if double_word in self.double_word_dict:
                double_words.append(self.double_word_dict[double_word])
                i += 2
            else:
                double_words.append(tokens[i])
                i += 1

        # Keep (position, word) pairs together so positions stay aligned with words
        # even after entries are filtered out.
        kept = list(zip(idxs, double_words))

        # toilet seat is not chair (sentences like "the seat of the toilet" will fire
        # for "chair" if we do not include this line)
        if 'toilet' in double_words and 'seat' in double_words:
            kept = [(pos, word) for pos, word in kept if word != 'seat']

        # Build the lookup set once instead of once per word.
        coco_objects = set(self.mscoco_objects)
        kept = [(pos, word) for pos, word in kept if word in coco_objects]

        idxs = [pos for pos, _ in kept]
        words = [word for _, word in kept]
        node_words = [[word, self.inverse_synonym_dict[word]] for word in words]
        # return all the MSCOCO objects in the caption
        return words, node_words, idxs, double_words

    def caption_objects_to_coco_objects(self, words):
        idxs = list(range(len(words)))
        if ('toilet' in words) & ('seat' in words): words = [word for word in words if word != 'seat']
        #get synonyms for all words in the caption
        idxs = [idxs[idx] for idx, word in enumerate(words) \
                if word in set(self.mscoco_objects)]
        words = [word for word in words if word in set(self.mscoco_objects)]
        node_words = []
        for word in words:
            node_words.append([word, self.inverse_synonym_dict[word]])

        #return all the MSCOCO objects in the caption
        return words, node_words, idxs

    def get_annotations_from_segments(self):
        '''
        Collect ground-truth objects from MSCOCO segmentation annotations.

        For every annotation whose image id is tracked, the category name is mapped
        to its canonical synonym and appended to ``self.imid_to_objects``. Afterwards
        each per-image collection is converted to a set.
        '''
        coco_segments = combine_coco_instances(self.coco_path)
        annotations = coco_segments['annotations']
        n_annotations = len(annotations)
        category_names = {cat['id']: cat['name'] for cat in coco_segments['categories']}

        for position, annotation in enumerate(annotations):
            sys.stdout.write("\rGetting annotations for %d/%d segmentation masks"
                             % (position, n_annotations))
            image_id = annotation['image_id']
            if image_id not in self.imid_to_objects:
                continue
            canonical = self.inverse_synonym_dict[category_names[annotation['category_id']]]
            self.imid_to_objects[image_id].append(canonical)
        print("\n")

        for image_id in list(self.imid_to_objects):
            self.imid_to_objects[image_id] = set(self.imid_to_objects[image_id])

    def get_annotations_from_captions(self):
        '''
        Add ground-truth objects mentioned in MSCOCO reference captions.

        The per-image collections must already be sets (they are updated in place),
        so this is expected to run after the segmentation pass.
        '''
        coco_caps = combine_coco_captions(self.coco_path)
        annotations = coco_caps['annotations']

        for position, annotation in enumerate(annotations):
            sys.stdout.write('\rGetting annotations for %d/%d ground truth captions'
                             % (position, len(annotations)))
            image_id = annotation['image_id']
            if image_id not in self.imid_to_objects:
                continue
            node_words = self.caption_to_words(annotation['caption'])[1]
            self.imid_to_objects[image_id].update([node[-1] for node in node_words])
        print("\n")

        for image_id in list(self.imid_to_objects):
            self.imid_to_objects[image_id] = set(self.imid_to_objects[image_id])

    def get_annotations(self):

        '''
        Get annotations from both segmentation and captions.  Need both annotation types for CHAIR metric.
        '''

        self.get_annotations_from_segments()
        self.get_annotations_from_captions()

    def get_gpt_resp(self, data_item):
        prompt_template = """You are an expert in image objects extraction according to a question answer pair. We asked an examiner to answer a question about a picture.
            [Start of Question]

            <image> {question}

            [End of Question]

            [Start of Examiner's Answer]

            {answer}

            [End of Examiner's Answer]


            Assume that the answer is correct, please identify all visible objects that are directly shown in the image. Please following the instructions in below:

            1. You should only mention objects that are explicitly mentioned in the examiner's answer.
            2. You should only extract the object names without the attributes of the objects.
            3. You should not include the properties of the object, like the color, material, etc. as part of the object name in your result.
            4. Make your answer precise. Present the results in a JSON list format: [\"object_1\", ..., \"object_n\"].
            5. You should return an empty JSON list () if no visible objects can be found.
        """
        prompt = copy.deepcopy(prompt_template)
        prompt = prompt.replace("{question}", data_item["question"])
        prompt = prompt.replace("{answer}", data_item["response"])

        messages = [
            {"role": "system", "content": prompt}
        ]

        fail_cnt = 0
        used_tokens = {"total": 0, "input": 0, "output": 0}
        while True:
            if len(data_item["response"].strip().split()) <= 3:
                data_item["extract_objs"] = []
                print(f"**[Short Answer]**@{data_item['response']}@", data_item["extract_objs"])
                return data_item, used_tokens, {"total": 0, "input": 0, "output": 0}

            if fail_cnt == self.fail_limit:
                data_item["extract_objs"] = f'-1\n<no_response>'
                print("**[Wrong Return]**", data_item["extract_objs"])
                return data_item, used_tokens, {"total": 0, "input": 0, "output": 0}

            resp = None
            # try:
            # print('input: ', messages)
            resp = self.chat_model.chat_completion(messages=messages)
            # print(resp["model"])

            # Logging consumption
            used_tokens["total"] += resp['usage']["total_tokens"]
            used_tokens["input"] += resp['usage']["prompt_tokens"]
            used_tokens["output"] += resp['usage']["completion_tokens"]

            # Parsing ChatGPT response
            content = resp["choices"][0]["message"]["content"]
            
            content = parse_object_list(content)

            # API Rest
            time.sleep(1)

            data_item["extract_objs"] = content
            # print('output: ', content)
            data_item['input'] = messages
            data_item['output'] = resp
            success_tokens = {"total": resp['usage']['total_tokens'],
                                "input": resp['usage']['prompt_tokens'],
                                "output": resp['usage']['completion_tokens']}
            return data_item, used_tokens, success_tokens
            # except Exception as e:
            #     fail_cnt += 1
            #     print("Exception:", e, 'resp is ', resp)

    def gpt_caption_processor(self, max_workers=2):
        """Run ``self.get_gpt_resp`` concurrently over every item in ``self.caps``.

        Args:
            max_workers: Number of worker threads used to issue the requests.

        Returns:
            A tuple ``(new_data, all_used_tokens, all_success_tokens)``:
            ``new_data`` holds the items returned by ``get_gpt_resp`` in the same
            order as ``self.caps``; the two dicts hold the per-key sums
            (``"total"``, ``"input"``, ``"output"``) of the second and third
            values returned by ``get_gpt_resp`` respectively.

        Raises:
            Any exception raised by ``get_gpt_resp`` is re-raised when its result
            is collected.
        """
        data_list = self.caps
        token_keys = ("total", "input", "output")
        all_used_tokens = dict.fromkeys(token_keys, 0)
        all_success_tokens = dict.fromkeys(token_keys, 0)

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(self.get_gpt_resp, data_item)
                for data_item in data_list
            ]

            # Advance the progress bar as requests finish (any order) ...
            with tqdm(total=len(futures)) as pb:
                for _ in concurrent.futures.as_completed(futures):
                    pb.update(1)

            # ... but collect results in submission order so the output lines up
            # with self.caps and is deterministic across runs.
            results = [future.result() for future in futures]

        new_data = []
        for new_data_item, used_tokens, success_tokens in results:
            for key in token_keys:
                all_used_tokens[key] += used_tokens[key]
                all_success_tokens[key] += success_tokens[key]
            new_data.append(new_data_item)
        print('Done loop, waiting resource finalization', flush=True)

        return new_data, all_used_tokens, all_success_tokens

    def postagging(self, doc):
        """Collect runs of consecutive noun tokens as lemma phrases.

        Tokens whose ``tag_`` is one of NNP, NNPS, NN or NNS are grouped with
        their adjacent noun neighbours; each maximal run becomes one phrase made
        of the tokens' ``lemma_`` values joined by single spaces.

        Args:
            doc: Iterable of tokens exposing ``tag_`` and ``lemma_`` (e.g. a spaCy Doc).

        Returns:
            List of phrase strings in document order.
        """
        noun_tags = ("NNP", "NNPS", "NN", "NNS")
        phrases = []
        run = []
        for tok in doc:
            if tok.tag_ in noun_tags:
                run.append(tok.lemma_)
                continue
            # A non-noun token closes the current noun run, if any.
            if run:
                phrases.append(" ".join(run).strip())
                run = []
        # Flush a run that reaches the end of the document.
        if run:
            phrases.append(" ".join(run).strip())
        return phrases

    def get_pred_objs_match(self, caps):
        """Normalize GPT-extracted object phrases into per-caption object names.

        For every caption item, each phrase in ``item["extract_objs"]`` is parsed
        with the module-level spaCy pipeline ``nlp``. If
        ``self.get_double_words_only`` returns a non-empty list for the phrase's
        lemmas, that list is used; otherwise the noun runs found by
        ``self.postagging`` are used.

        Args:
            caps: List of caption dicts (as returned by ``gpt_caption_processor``);
                each must carry an ``"extract_objs"`` entry.

        Returns:
            A new list of deep-copied items, each with an ``"objs"`` list of unique
            object names in first-seen order. The input items are not modified.

        Raises:
            ValueError: If an item's ``"extract_objs"`` is the failure sentinel
                that ``get_gpt_resp`` stores after exhausting its retries.
        """
        no_response = '-1\n<no_response>'
        new_caps = []
        for idx, item in enumerate(caps, 1):
            print(idx, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
            caps_gpt_objs = item["extract_objs"]
            # Explicit check instead of `assert`: assertions are stripped under
            # `python -O`, after which the sentinel string would be iterated
            # character by character and produce bogus "objects".
            if caps_gpt_objs == no_response:
                raise ValueError(
                    f"GPT object extraction failed for caption #{idx} "
                    f"(image_id={item.get('image_id')!r}); cannot match objects."
                )

            refined_objs = []
            for text in caps_gpt_objs:
                # Prefix an article, presumably so spaCy tags the bare phrase as
                # a noun phrase — NOTE(review): confirm the intent.
                doc = nlp(f"a {text}")
                single_tokens = [token.lemma_ for token in doc]
                double_words_objs = self.get_double_words_only(single_tokens)

                if double_words_objs != []:
                    refined_objs += double_words_objs
                    continue

                refined_objs += self.postagging(doc)

            new_item = copy.deepcopy(item)

            # Keep only the first occurrence of each object, preserving order.
            unique_objs = []
            for robj in refined_objs:
                if robj not in unique_objs:
                    unique_objs.append(robj)
            new_item["objs"] = unique_objs

            new_caps.append(new_item)

        return new_caps

    def compute_chair(self, cap_file, sample_num, gpt_process=True):
        """Determine which generated captions contain hallucinated COCO objects.

        Captions are loaded via ``self._load_generated_captions_into_evaluator``.
        For each caption, the COCO classes it mentions are compared with the
        ground-truth classes in ``self.imid_to_objects[image_id]``. A mention
        whose class is absent from the ground truth counts as a hallucination.

        Args:
            cap_file: Path passed to ``_load_generated_captions_into_evaluator``.
            sample_num: Sample count passed to the same loader (its meaning,
                e.g. for ``-1``, is defined by the loader).
            gpt_process: If True, object mentions are extracted with GPT
                (``gpt_caption_processor`` then ``get_pred_objs_match``) and
                mapped with ``caption_objects_to_coco_objects``. Otherwise they
                come from the raw caption via ``caption_to_words``.

        Returns:
            A dict with ``'sentences'`` (per-caption details and CHAIRs/CHAIRi)
            and ``'overall_metrics'``. Aggregate ratios whose denominator is zero
            (no captions, no caption mentioning a COCO object, no ground-truth
            objects) are reported as 0.0 rather than raising ZeroDivisionError.
        """

        def _ratio(num, den):
            # Every aggregate below divides by a count that can legitimately be 0.
            return num / den if den else 0.0

        self._load_generated_captions_into_evaluator(cap_file, sample_num)

        imid_to_objects = self.imid_to_objects
        caps = self.caps

        if gpt_process:
            caps, all_used_tokens, all_success_tokens = self.gpt_caption_processor()
            caps = self.get_pred_objs_match(caps)
        else:
            all_used_tokens = {}
            all_success_tokens = {}

        num_caps = 0.
        num_coco_caps = 0.
        num_hallucinated_caps = 0.
        hallucinated_word_count = 0.
        coco_word_count = 0.
        gt_word_count = 0.
        coco_obj_cls_count = 0.

        output = {'sentences': []}
        total_cap_word_num = 0
        for cap_eval in caps:
            cap = cap_eval['response']
            # split(" ") (not split()) is kept so avg_word_len stays comparable
            # with numbers produced by earlier versions of this script.
            total_cap_word_num += len(cap.strip().split(" "))
            imid = cap_eval['image_id']

            # Object words in the caption and the COCO node words they map to.
            if gpt_process:
                ext_objs = cap_eval["objs"]
                words, node_words, idxs = self.caption_objects_to_coco_objects(ext_objs)
                raw_words = ext_objs
            else:
                words, node_words, idxs, raw_words = self.caption_to_words(cap)

            gt_objects = imid_to_objects[imid]
            gt_word_count += len(gt_objects)
            cap_dict = {'image_id': cap_eval['image_id'],
                        'response': cap,  # original caption
                        'mscoco_hallucinated_words': [],
                        'mscoco_gt_words': list(gt_objects),  # ground-truth COCO objects
                        'mscoco_generated_words': list(node_words),  # generated, mapped to COCO
                        'hallucination_idxs': [],
                        # GPT path: extracted objects; otherwise caption_to_words output
                        'words': raw_words,
                        }

            # Every mapped mention counts toward CHAIRi's denominator; correctly
            # grounded classes are de-duplicated per caption for object recall.
            coco_word_count += len(node_words)
            matched_classes = set()

            hallucinated = False
            for word, node_word, idx in zip(words, node_words, idxs):
                coco_cls = node_word[-1]
                if coco_cls not in gt_objects:
                    hallucinated_word_count += 1
                    cap_dict['mscoco_hallucinated_words'].append((word, node_word))
                    cap_dict['hallucination_idxs'].append(idx)
                    hallucinated = True
                else:
                    matched_classes.add(coco_cls)

            coco_obj_cls_count += len(matched_classes)

            num_caps += 1
            if hallucinated:
                num_hallucinated_caps += 1

            caption_chair_i = 0.
            if len(words) > 0:
                num_coco_caps += 1
                caption_chair_i = len(cap_dict['mscoco_hallucinated_words']) / float(len(words))
            cap_dict['metrics'] = {'CHAIRs': int(hallucinated),
                                   'CHAIRi': caption_chair_i}

            output['sentences'].append(cap_dict)

        chair_s = _ratio(num_hallucinated_caps, num_caps)
        # Only captions mentioning at least one COCO object can hallucinate, so
        # this refines CHAIRs to that subset.
        chair_s_refine = _ratio(num_hallucinated_caps, num_coco_caps)
        chair_i = _ratio(hallucinated_word_count, coco_word_count)
        avg_word_len = _ratio(float(total_cap_word_num), num_caps)
        obj_rec = _ratio(coco_obj_cls_count, gt_word_count)
        output['overall_metrics'] = {'CHAIRs': chair_s,
                                     'CHAIRs_refine': chair_s_refine,
                                     'CHAIRi': chair_i,
                                     'obj_rec': obj_rec,
                                     'sentence_num': num_caps,
                                     'coco_sentence_num': num_coco_caps,
                                     'coco_word_count': coco_obj_cls_count,  # predicted COCO object classes
                                     'gt_word_count': gt_word_count,  # ground-truth COCO object classes
                                     'avg_word_len': avg_word_len,
                                     'all_gpt_used_tokens': all_used_tokens,
                                     'all_gpt_success_tokens': all_success_tokens,
                                     'correct_rate': 1 - chair_s_refine,
                                     'object_correct_rate': 1 - chair_i
                                     }

        return output

def print_metrics(hallucination_cap_dict, quiet=False):
    """Format the CHAIR overall metrics as one tab-separated row.

    Args:
        hallucination_cap_dict: Output of ``compute_chair``; only its
            ``'overall_metrics'`` entry is read.
        quiet: If False (default), print a header line followed by the row.

    Returns:
        The formatted row: CHAIRs, CHAIRs_refine and CHAIRi as percentages with
        one decimal, sentence and COCO-sentence counts as integers, and the
        average caption length with one decimal. The row is returned whether or
        not it was printed.
    """
    sentence_metrics = hallucination_cap_dict['overall_metrics']
    metric_string = "%.1f\t%.1f\t%.1f\t%d\t%d\t%.1f" % (
        sentence_metrics['CHAIRs'] * 100,
        sentence_metrics['CHAIRs_refine'] * 100,
        sentence_metrics['CHAIRi'] * 100,
        sentence_metrics['sentence_num'],
        sentence_metrics['coco_sentence_num'],
        sentence_metrics['avg_word_len'],
    )

    if not quiet:
        print("CHAIRs\tCHAIRsr\tCHAIRi\tsent_num\tcoco_num\tavg_len")
        print(metric_string)

    return metric_string

if __name__ == "__main__":
    args = parse_args()
    set_random_seed(args.seed)
    # Predictions are read from record_root; logs and results go to record_path.
    read_record = os.path.join(args.record_root, str(args.experiment_id), args.model, args.dataset)
    record_path = os.path.join(args.record_path, str(args.experiment_id), args.model, args.dataset)

    os.makedirs(record_path, exist_ok=True)

    eval_file = os.path.join(read_record, args.evaluate_file)
    # NOTE: [:-5] assumes evaluate_file ends with ".json".
    file_stem = args.evaluate_file[:-5]
    logger = setup_logger(args.dataset, os.path.join(record_path, file_stem + '.log'))
    save_gpt_4_path = os.path.join(record_path, file_stem + '_gpt4.json')
    log_all_args(args, logger)

    if args.dataset == 'beaf':
        evaluate_all(eval_file, os.path.join(record_path, args.evaluate_file), logger)
    elif args.dataset == 'MME':
        cal = calculate_metrics()
        total_score = cal.evaluate_MME(eval_file, logger)
        logger.info(f"Total score: {total_score}")
    elif args.dataset in ('AMBER', 'AMBER_G'):
        evaluate_AMBER(inference_data=eval_file,
                       word_association_path='eval/AMBER_data/relation.json',
                       safe_words_path='eval/AMBER_data/safe_words.txt',
                       similarity_score=0.8,
                       metrics_path='eval/AMBER_data/metrics.txt',
                       annotation_path='eval/AMBER_data/annotations.json',
                       logger=logger
                       )
    elif args.dataset == 'MMH':
        eval_MMHal(inference_data=eval_file,
                   api_key=args.api_key,
                   save_gpt_4_path=save_gpt_4_path,
                   logger=logger
                   )
    elif args.dataset == 'obj':
        # Only the image ids are needed to build the evaluator; compute_chair
        # loads the captions itself, so the file is not re-read here.
        _, imids, _ = load_generated_outputs(eval_file)
        evaluator = CHAIR(imids, args.coco_path, args.api_key)
        evaluator.get_annotations()
        cap_dict = evaluator.compute_chair(eval_file, -1, gpt_process=False)
        print_metrics(cap_dict)
        # Persist the metrics in the run log like the other dataset branches do.
        logger.info(f"CHAIR overall metrics: {cap_dict['overall_metrics']}")
    else:
        evaluate_part(eval_file, logger)
