import os
import sys
import ssl
import json
import copy
import glob
import time
import pathlib
import random
import jsonlines

import nltk
import spacy
import argparse
import concurrent.futures
import openai
from concurrent.futures import ThreadPoolExecutor
from nltk.stem import *
from tqdm import tqdm


# Make the default HTTPS context skip certificate verification when the
# interpreter provides the private ``ssl._create_unverified_context`` helper
# (the AttributeError guard covers interpreters that lack it).
# NOTE(review): this is process-wide and affects every connection that relies
# on ssl's default HTTPS context (e.g. urllib-based downloads) — presumably a
# workaround for an environment with broken certificates; confirm it is still
# needed, since it disables a security check.
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context


# Loaded once at import time (can be slow): the spaCy transformer pipeline used
# to lemmatise / POS-tag GPT-extracted object phrases (CHAIR.get_pred_objs_match),
# and the WordNet lemmatizer used when scanning caption text directly
# (CHAIR.caption_to_words).
nlp = spacy.load("en_core_web_trf")
lemma = nltk.wordnet.WordNetLemmatizer()

class Chat:
    """Thin wrapper around an Azure OpenAI chat-completions client.

    Args:
        model: model / deployment name passed to ``chat.completions.create``.
        timeout_sec: per-request timeout in seconds, applied to every call
            made by ``chat_completion``.
        openai_apikey: API key for the Azure OpenAI endpoint.
        apibase: Azure OpenAI endpoint URL.
    """

    def __init__(self, model="", timeout_sec=20, openai_apikey='', apibase=''):
        self.model = model
        self.timeout = timeout_sec
        # Kept for backward compatibility with code reading the module-level
        # key; the client below receives the key explicitly.
        openai.api_key = openai_apikey
        self.client = openai.AzureOpenAI(
            azure_endpoint=apibase,
            api_version="2023-03-15-preview",
            api_key=openai_apikey
        )

    def chat_completion(self, messages, temperature=0, top_p=1, max_tokens=512,
                        presence_penalty=0, frequency_penalty=0):
        """Send ``messages`` to the chat-completions endpoint.

        Args:
            messages: list of ``{"role": ..., "content": ...}`` dicts.
            temperature, top_p, max_tokens, presence_penalty, frequency_penalty:
                sampling parameters forwarded unchanged to the API.

        Returns:
            The SDK response object (``response.choices[0].message.content``
            holds the generated text).
        """
        return self.client.chat.completions.create(
            # NOTE(review): "xxx" looks like a placeholder log id — confirm
            # whether the gateway expects a real value.
            extra_headers={"X-TT-LOGID": "xxx"},
            model=self.model,
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            # Previously hard-coded to 120, silently ignoring ``timeout_sec``.
            timeout=self.timeout,
        )



def parse_object_list(content):
    """Parse the JSON object list returned by the extraction model.

    Accepted shapes:

    * plain JSON, e.g. ``["cat", "dog"]``;
    * JSON wrapped in a fenced ```json block;
    * free text whose last line is the JSON list (detected by ``["``);
    * a bare ``()``, which is the "empty JSON list" wording used in the
      extraction prompt, returned as ``[]``.

    Args:
        content: raw message text from the model (may be ``None``).

    Returns:
        The decoded list.

    Raises:
        ValueError: if the content cannot be decoded, or decodes to something
            other than a list.
    """
    error = 'Content is not json interpretable'
    if not isinstance(content, str):
        # e.g. the API returned no message content at all.
        raise ValueError(error)

    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        if '```json' in content:
            candidate = content.replace('```json', '').replace('```', '').strip()
        elif '["' in content:
            # Assume any explanation comes first and the list is on the last line.
            candidate = content.strip().split('\n')[-1]
        elif content.strip() == '()':
            return []
        else:
            raise ValueError(error) from None
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError as err:
            raise ValueError(error) from err

    if not isinstance(parsed, list):
        raise ValueError('Content is not a JSON list')
    return parsed


# Prompt for GPT-based object extraction (filled in by CHAIR.get_gpt_resp).
# ``{question}`` and ``{answer}`` are substituted with ``str.replace`` rather
# than ``str.format``, so any other braces in the text are left untouched.
# NOTE(review): item 5 asks for an empty JSON list written as "()", which is
# not valid JSON (``[]`` would be). The wording is kept as-is because editing
# the prompt changes extraction behaviour and benchmark comparability.
prompt_template = """You are an expert in image objects extraction according to a question answer pair. We asked an examiner to answer a question about a picture.

[Start of Question]

<image> {question}

[End of Question]

[Start of Examiner's Answer]

{answer}

[End of Examiner's Answer]


Assume that the answer is correct, please identify all visible objects that are directly shown in the image. Please following the instructions in below:

1. You should only mention objects that are explicitly mentioned in the examiner's answer.
2. You should only extract the object names without the attributes of the objects.
3. You should not include the properties of the object, like the color, material, etc. as part of the object name in your result.
4. Make your answer precise. Present the results in a JSON list format: [\"object_1\", ..., \"object_n\"].
5. You should return an empty JSON list () if no visible objects can be found.
"""

def preprocess_coh_results(caps):
    """Strip "chain of hindsight" control phrases from generated captions.

    Each caption dict is trimmed in place: for the first marker found (checked
    in the order below), either the text after it (``-1``) or the text before
    it (``0``) is kept and stripped. Captions without any marker are left
    unchanged. The same dict objects are returned, in a new list.
    """
    markers = (
        ("The following is a response without hallucination.", -1),
        ("The following is a response with hallucination.", 0),
        ("Generate a response without errors.", -1),
        ("Generate a response with errors.", 0),
    )
    cleaned = []
    for item in caps:
        text = item["caption"]
        for marker, keep in markers:
            if marker in text:
                text = text.split(marker)[keep].strip()
                break
        item['caption'] = text
        cleaned.append(item)
    return cleaned

def combine_coco_captions(annotation_path):
    """Merge the MSCOCO 2014 train and val caption annotation files.

    Args:
        annotation_path: directory containing ``captions_train2014.json`` and
            ``captions_val2014.json``.

    Returns:
        A COCO-style dict with ``info``/``licenses`` taken from the train file
        and ``images``/``annotations`` concatenated as val followed by train.

    Raises:
        Exception: if either annotation file is missing.
    """
    val_path = '%s/captions_%s2014.json' % (annotation_path, 'val')
    train_path = '%s/captions_%s2014.json' % (annotation_path, 'train')
    if not os.path.exists(val_path):
        raise Exception("Please download MSCOCO caption annotations for val set")
    if not os.path.exists(train_path):
        raise Exception("Please download MSCOCO caption annotations for train set")

    # Use context managers so the (large) files are closed promptly.
    with open(val_path) as f:
        val_caps = json.load(f)
    with open(train_path) as f:
        train_caps = json.load(f)

    return {'info': train_caps['info'],
            'licenses': train_caps['licenses'],
            'images': val_caps['images'] + train_caps['images'],
            'annotations': val_caps['annotations'] + train_caps['annotations']}

def combine_coco_instances(annotation_path):
    """Merge the MSCOCO 2014 train and val instance (segmentation) annotations.

    Args:
        annotation_path: directory containing ``instances_train2014.json`` and
            ``instances_val2014.json``.

    Returns:
        A COCO-style dict with ``info``/``licenses``/``categories`` taken from
        the train file, ``type`` taken from the train file if present (else
        ``None``), and ``images``/``annotations`` concatenated as val followed
        by train.

    Raises:
        Exception: if either annotation file is missing.
    """
    val_path = '%s/instances_%s2014.json' % (annotation_path, 'val')
    train_path = '%s/instances_%s2014.json' % (annotation_path, 'train')
    if not os.path.exists(val_path):
        raise Exception("Please download MSCOCO instance annotations for val set")
    if not os.path.exists(train_path):
        raise Exception("Please download MSCOCO instance annotations for train set")

    with open(val_path) as f:
        val_instances = json.load(f)
    with open(train_path) as f:
        train_instances = json.load(f)

    return {'info': train_instances['info'],
            'licenses': train_instances['licenses'],
            # Was a copy-paste of 'licenses'; the key is optional in COCO files.
            'type': train_instances.get('type'),
            'categories': train_instances['categories'],
            # Same val-then-train order as the annotations below.
            'images': val_instances['images'] + train_instances['images'],
            'annotations': val_instances['annotations'] + train_instances['annotations']}

class CHAIR(object):
    """CHAIR object-hallucination metric for generated image descriptions.

    Ground-truth objects for each image are gathered from MSCOCO instance
    annotations and ground-truth captions (``get_annotations``). Objects
    mentioned in a generated answer are mapped onto MSCOCO classes through a
    synonym list and compared against that ground truth (``compute_chair``).
    Mentions come either from a rule-based scan of the answer text
    (``caption_to_words``) or, with ``gpt_process=True``, from a GPT extraction
    step normalised with spaCy (``gpt_caption_processor`` +
    ``get_pred_objs_match``).
    """

    # Default synonym file. Each line is a ", "-separated list whose first
    # entry is the canonical MSCOCO class for every entry on that line.
    DEFAULT_SYNONYMS_PATH = ('/opt/tiger/vtcl/svco-eval/eval_vlm/playground/'
                             'data/eval/objectHal/synonyms_refine.txt')

    # Stored in ``extract_objs`` when GPT never produced a usable answer.
    NO_RESPONSE = '-1\n<no_response>'

    def __init__(self, imids, coco_path, openai_apikey, apibase='', synonyms_path=None):
        """
        Args:
            imids: image ids to evaluate; ground truth is collected only for these.
            coco_path: directory holding the MSCOCO 2014 caption / instance files.
            openai_apikey: key for the Azure OpenAI endpoint used for extraction.
            apibase: Azure OpenAI endpoint URL.
            synonyms_path: synonym file to read; defaults to
                ``DEFAULT_SYNONYMS_PATH``.
        """
        self.imid_to_objects = {imid: [] for imid in imids}
        self.coco_path = coco_path
        self.chat_model = Chat(model="gpt-4o-2024-05-13", timeout_sec=100,
                               openai_apikey=openai_apikey, apibase=apibase)
        # Failed GPT calls tolerated per answer before giving up on it.
        self.fail_limit = 20

        # Read in synonyms.
        if synonyms_path is None:
            synonyms_path = self.DEFAULT_SYNONYMS_PATH
        with open(synonyms_path) as f:
            synonyms = [line.strip().split(', ') for line in f.readlines()]
        self.mscoco_objects = []  # MSCOCO objects and *all* synonyms
        self.inverse_synonym_dict = {}  # synonym -> canonical MSCOCO class
        for synonym in synonyms:
            new_synonym = [s.strip() for s in synonym]
            self.mscoco_objects.extend(new_synonym)
            for s in new_synonym:
                self.inverse_synonym_dict[s] = new_synonym[0]

        # Multi-word phrases must be matched as a unit before single words.
        coco_double_words = [word for word in self.inverse_synonym_dict
                             if len(word.strip().split(' ')) >= 2]
        coco_double_words += ['home plate', 'train track']
        print("double word count:", len(coco_double_words))

        animal_words = ['bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'animal', 'cub']
        vehicle_words = ['jet', 'train']

        # phrase -> the word it should be counted as
        self.double_word_dict = {word: word for word in coco_double_words}
        for animal_word in animal_words:
            self.double_word_dict['baby %s' % animal_word] = animal_word
            self.double_word_dict['adult %s' % animal_word] = animal_word
        for vehicle_word in vehicle_words:
            self.double_word_dict['passenger %s' % vehicle_word] = vehicle_word
        self.double_word_dict['bow tie'] = 'tie'
        self.double_word_dict['toilet seat'] = 'toilet'
        # Lemmatisation turns "glass" into "glas", so match that form as well.
        self.double_word_dict['wine glas'] = 'wine glass'

    def _load_generated_captions_into_evaluator(self, cap_file, sample_num, org_dir=None):
        """Load generated answers into ``self.caps`` (optionally sub-sampled).

        Each answer gets an ``'index'`` with its position in the file before
        sampling. ``sample_num == -1`` keeps all answers.

        Raises:
            ValueError: if the file's image ids differ from those this
                evaluator was constructed with.
        """
        self.caps, imids, self.metrics = load_generated_captions(cap_file, org_dir=org_dir)
        self.caps = list(self.caps)
        for index, cap in enumerate(self.caps):
            cap['index'] = index
        if sample_num != -1:
            self.caps = random.sample(self.caps, sample_num)
        print("cal cap num:", len(self.caps))

        if imids != set(self.imid_to_objects.keys()):
            raise ValueError("Image ids in %s do not match the evaluator's image ids" % cap_file)

    def get_double_words_only(self, word_list):
        """Return only the known two-word phrases found in ``word_list``.

        Adjacent tokens are scanned left to right; each matching pair is
        replaced by its mapped word from ``double_word_dict`` and consumed.
        Tokens that are not part of a known pair are dropped.
        """
        double_words = []
        i = 0
        while i < len(word_list):
            pair = ' '.join(word_list[i:i + 2])
            if pair in self.double_word_dict:
                double_words.append(self.double_word_dict[pair])
                i += 2
            else:
                i += 1
        return double_words

    def caption_to_words(self, caption):
        """Find MSCOCO object words in a caption with a rule-based scan.

        The caption is lower-cased, tokenised with NLTK and lemmatised with
        WordNet; known two-word phrases are merged via ``double_word_dict``.
        'seat' is dropped when 'toilet' is also present, so that "the seat of
        the toilet" does not fire for "chair".

        Returns:
            (words, node_words, idxs, double_words): the MSCOCO words found,
            ``[word, canonical_class]`` pairs for them, each word's token
            position in the lemmatised caption, and the full token list after
            phrase merging.
        """
        words = [lemma.lemmatize(w) for w in nltk.word_tokenize(caption.lower())]

        # Replace double words, remembering where each merged token started.
        i = 0
        double_words = []
        positions = []
        while i < len(words):
            positions.append(i)
            double_word = ' '.join(words[i:i + 2])
            if double_word in self.double_word_dict:
                double_words.append(self.double_word_dict[double_word])
                i += 2
            else:
                double_words.append(words[i])
                i += 1

        # Filter words and positions together so they stay aligned.
        positioned = list(zip(positions, double_words))
        if 'toilet' in double_words and 'seat' in double_words:
            positioned = [(pos, word) for pos, word in positioned if word != 'seat']
        coco_objects = set(self.mscoco_objects)  # build once, not per word
        positioned = [(pos, word) for pos, word in positioned if word in coco_objects]

        idxs = [pos for pos, _ in positioned]
        found = [word for _, word in positioned]
        node_words = [[word, self.inverse_synonym_dict[word]] for word in found]
        return found, node_words, idxs, double_words

    def caption_objects_to_coco_objects(self, words):
        """Map an already-extracted object list onto MSCOCO classes.

        Applies the same toilet/seat rule as ``caption_to_words`` and keeps
        only words that are MSCOCO objects or synonyms.

        Returns:
            (words, node_words, idxs): the kept words, ``[word,
            canonical_class]`` pairs, and each kept word's position in the
            input list.
        """
        positioned = list(enumerate(words))
        if 'toilet' in words and 'seat' in words:
            positioned = [(pos, word) for pos, word in positioned if word != 'seat']
        coco_objects = set(self.mscoco_objects)
        positioned = [(pos, word) for pos, word in positioned if word in coco_objects]

        idxs = [pos for pos, _ in positioned]
        found = [word for _, word in positioned]
        node_words = [[word, self.inverse_synonym_dict[word]] for word in found]
        return found, node_words, idxs

    def get_annotations_from_segments(self):
        """Add ground-truth objects from MSCOCO instance (segmentation) annotations.

        Category names are mapped to canonical classes through the synonym
        list. Afterwards every value of ``imid_to_objects`` is a set.
        """
        coco_segments = combine_coco_instances(self.coco_path)
        segment_annotations = coco_segments['annotations']
        id_to_name = {cat['id']: cat['name'] for cat in coco_segments['categories']}

        # Work on sets so this method can run before or after the captions pass.
        for imid in self.imid_to_objects:
            self.imid_to_objects[imid] = set(self.imid_to_objects[imid])

        total = len(segment_annotations)
        for i, annotation in enumerate(segment_annotations):
            sys.stdout.write("\rGetting annotations for %d/%d segmentation masks"
                             % (i, total))
            imid = annotation['image_id']
            if imid in self.imid_to_objects:
                node_word = self.inverse_synonym_dict[id_to_name[annotation['category_id']]]
                self.imid_to_objects[imid].add(node_word)
        print("\n")

    def get_annotations_from_captions(self):
        """Add ground-truth objects mentioned in MSCOCO ground-truth captions.

        Afterwards every value of ``imid_to_objects`` is a set.
        """
        coco_caps = combine_coco_captions(self.coco_path)
        caption_annotations = coco_caps['annotations']

        for imid in self.imid_to_objects:
            self.imid_to_objects[imid] = set(self.imid_to_objects[imid])

        total = len(caption_annotations)
        for i, annotation in enumerate(caption_annotations):
            sys.stdout.write('\rGetting annotations for %d/%d ground truth captions'
                             % (i, total))
            imid = annotation['image_id']
            if imid in self.imid_to_objects:
                _, node_words, _, _ = self.caption_to_words(annotation['caption'])
                self.imid_to_objects[imid].update(item[-1] for item in node_words)
        print("\n")

    def get_annotations(self):
        """Collect ground truth from both segmentations and captions (both are needed for CHAIR)."""
        self.get_annotations_from_segments()
        self.get_annotations_from_captions()

    def get_gpt_resp(self, data_item):
        """Ask GPT to list the objects mentioned in one answer.

        ``data_item`` (needs ``'question'`` and ``'caption'``) is updated in
        place and returned.
        ``'extract_objs'`` receives the parsed object list. On success,
        ``'input'`` and ``'output'`` hold the request messages and the raw
        response. Answers of at most three words skip the API call and get an
        empty list. After ``self.fail_limit`` failed attempts,
        ``'extract_objs'`` is set to ``NO_RESPONSE``.

        Returns:
            (data_item, used_tokens, success_tokens). Token counters are
            currently always zero: usage accounting is disabled.
        """
        used_tokens = {"total": 0, "input": 0, "output": 0}

        if len(data_item["caption"].strip().split()) <= 3:
            data_item["extract_objs"] = []
            print(f"**[Short Answer]**@{data_item['caption']}@", data_item["extract_objs"])
            return data_item, used_tokens, {"total": 0, "input": 0, "output": 0}

        prompt = prompt_template.replace("{question}", data_item["question"])
        prompt = prompt.replace("{answer}", data_item["caption"])
        messages = [{"role": "system", "content": prompt}]

        fail_cnt = 0
        while fail_cnt < self.fail_limit:
            resp = None
            try:
                resp = self.chat_model.chat_completion(messages=messages)
                resp_dict = resp.model_dump() if hasattr(resp, 'model_dump') else resp.__dict__
                content = parse_object_list(resp.choices[0].message.content)

                # Brief pause between API calls.
                time.sleep(1)

                data_item["extract_objs"] = content
                data_item['input'] = messages
                data_item['output'] = resp_dict
                # Raises (and so triggers a retry) if the item is not
                # JSON-serialisable; results are dumped to JSON later.
                json.dumps(data_item, ensure_ascii=False)
                return data_item, used_tokens, {"total": 0, "input": 0, "output": 0}
            except Exception as e:
                fail_cnt += 1
                print("Exception:", e, 'resp is ', resp)
                if fail_cnt < self.fail_limit:
                    # Linear back-off; no pointless sleep after the last attempt.
                    time.sleep(10 + fail_cnt)

        data_item["extract_objs"] = self.NO_RESPONSE
        print("**[Wrong Return]**", data_item["extract_objs"])
        return data_item, used_tokens, {"total": 0, "input": 0, "output": 0}

    def gpt_caption_processor(self, max_workers=10):
        """Run ``get_gpt_resp`` over ``self.caps`` with a thread pool.

        Results are returned in the same order as ``self.caps`` (items whose
        worker raised are reported and omitted).

        Returns:
            (new_data, all_used_tokens, all_success_tokens) with token
            counters summed over all items.
        """
        data_list = self.caps
        results = [None] * len(data_list)
        all_used_tokens = {"total": 0, "input": 0, "output": 0}
        all_success_tokens = {"total": 0, "input": 0, "output": 0}

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_pos = {executor.submit(self.get_gpt_resp, data_item): pos
                             for pos, data_item in enumerate(data_list)}
            with tqdm(total=len(future_to_pos)) as pb:
                for future in concurrent.futures.as_completed(future_to_pos):
                    pb.update(1)
                    try:
                        new_data_item, used_tokens, success_tokens = future.result()
                        for key in all_used_tokens:
                            all_used_tokens[key] += used_tokens[key]
                        for key in all_success_tokens:
                            all_success_tokens[key] += success_tokens[key]
                        results[future_to_pos[future]] = new_data_item
                    except Exception as e:
                        print(f"@@@ Exception: {e}\n")
        print('Done loop, waiting resource finalization', flush=True)

        new_data = [item for item in results if item is not None]
        return new_data, all_used_tokens, all_success_tokens

    def postagging(self, doc):
        """Return runs of consecutive noun tokens (NN/NNS/NNP/NNPS) as lemma phrases."""
        obj_list = []
        current = []
        for token in doc:
            if token.tag_ in ["NNP", "NNPS", "NN", "NNS"]:
                current.append(token.lemma_)
            elif current:
                obj_list.append(' '.join(current))
                current = []
        if current:
            obj_list.append(' '.join(current))
        return obj_list

    def get_pred_objs_match(self, caps):
        """Normalise GPT-extracted object phrases into candidate object words.

        Each phrase is parsed with spaCy. Its lemmas are first scanned for
        known two-word phrases (``get_double_words_only``). If none match,
        runs of noun tokens are used instead (``postagging``).

        Returns:
            Deep copies of ``caps``, each with the de-duplicated results in
            first-seen order under ``'objs'``.

        Raises:
            RuntimeError: if an item's GPT extraction failed (``NO_RESPONSE``).
        """
        new_caps = []
        for idx, item in enumerate(caps, 1):
            print(idx, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
            caps_gpt_objs = item["extract_objs"]
            if caps_gpt_objs == self.NO_RESPONSE:
                raise RuntimeError("GPT object extraction failed for item %d: %r"
                                   % (idx, item.get('caption')))
            refined_objs = []
            for text in caps_gpt_objs:
                # Presumably the article nudges spaCy toward a noun-phrase parse.
                doc = nlp(f"a {text}")
                single_tokens = [token.lemma_ for token in doc]
                double_words_objs = self.get_double_words_only(single_tokens)
                if double_words_objs:
                    refined_objs += double_words_objs
                else:
                    refined_objs += self.postagging(doc)

            new_item = copy.deepcopy(item)
            new_item["objs"] = list(dict.fromkeys(refined_objs))  # unique, order kept
            new_caps.append(new_item)

        return new_caps

    def compute_chair(self, cap_file, sample_num, gpt_process=False, org_dir=None):
        """Evaluate the answers in ``cap_file`` against the ground truth.

        Call ``get_annotations`` first to populate ``imid_to_objects``.

        Args:
            cap_file: generated-answer file (see ``load_generated_captions``).
            sample_num: number of answers to sample at random, or -1 for all.
            gpt_process: extract objects with GPT + spaCy instead of the
                rule-based ``caption_to_words`` scan.
            org_dir: original question file or directory, needed for
                ``.jsonl`` answers that lack ``image_id``.

        Returns:
            dict with per-answer ``'sentences'`` and ``'overall_metrics'``.
        """
        self._load_generated_captions_into_evaluator(cap_file, sample_num, org_dir=org_dir)

        caps = self.caps
        if gpt_process:
            caps, all_used_tokens, all_success_tokens = self.gpt_caption_processor()
            caps = self.get_pred_objs_match(caps)
        else:
            all_used_tokens = {}
            all_success_tokens = {}

        return self._score_captions(caps, gpt_process, all_used_tokens, all_success_tokens)

    def _score_captions(self, caps, gpt_process=False, all_used_tokens=None,
                        all_success_tokens=None):
        """Compute per-answer and overall CHAIR metrics for prepared answers.

        Args:
            caps: answer dicts with ``'image_id'`` and ``'caption'``. With
                ``gpt_process`` they must also carry the normalised ``'objs'``
                list. ``'input'``/``'output'`` (the GPT request/response) are
                copied through when present and are ``None`` otherwise.
            gpt_process: use ``'objs'`` rather than scanning ``'caption'``.
            all_used_tokens, all_success_tokens: token counters echoed into
                the overall metrics (default ``{}``).

        Ratios whose denominator is zero (e.g. no answer mentions any MSCOCO
        object) are reported as 0.0 instead of raising ZeroDivisionError.
        """
        if all_used_tokens is None:
            all_used_tokens = {}
        if all_success_tokens is None:
            all_success_tokens = {}

        def ratio(num, den):
            return num / den if den else 0.

        imid_to_objects = self.imid_to_objects
        num_caps = 0.
        num_coco_caps = 0.
        num_hallucinated_caps = 0.
        hallucinated_word_count = 0.
        coco_word_count = 0.
        gt_word_count = 0.
        coco_obj_cls_count = 0.

        output = {'sentences': []}
        total_cap_word_num = 0
        for cap_eval in caps:
            cap = cap_eval['caption']
            total_cap_word_num += len(cap.strip().split(" "))
            imid = cap_eval['image_id']

            # All MSCOCO words in the answer, with their canonical classes.
            if gpt_process:
                raw_words = cap_eval["objs"]
                words, node_words, idxs = self.caption_objects_to_coco_objects(raw_words)
            else:
                words, node_words, idxs, raw_words = self.caption_to_words(cap)

            gt_objects = imid_to_objects[imid]
            gt_word_count += len(gt_objects)
            cap_dict = {'image_id': cap_eval['image_id'],
                        'caption': cap,  # original answer text
                        'mscoco_hallucinated_words': [],
                        'mscoco_gt_words': list(gt_objects),  # ground-truth classes
                        'mscoco_generated_words': list(node_words),  # mapped predictions
                        'hallucination_idxs': [],
                        # gpt_process: extracted + normalised objects;
                        # otherwise: lemmatised caption tokens
                        'words': raw_words,
                        # Absent for short answers and for non-GPT evaluation.
                        'input': cap_eval.get('input'),
                        'output': cap_eval.get('output'),
                        }
            cap_dict['metrics'] = {'CHAIRs': 0,
                                   'CHAIRi': 0}

            # A correct prediction counts once per unique ground-truth class.
            coco_word_count += len(node_words)
            matched_classes = set()

            hallucinated = False
            for word, node_word, idx in zip(words, node_words, idxs):
                if node_word[-1] not in gt_objects:
                    hallucinated_word_count += 1
                    cap_dict['mscoco_hallucinated_words'].append((word, node_word))
                    cap_dict['hallucination_idxs'].append(idx)
                    hallucinated = True
                else:
                    matched_classes.add(node_word[-1])
            coco_obj_cls_count += len(matched_classes)

            num_caps += 1
            if hallucinated:
                num_hallucinated_caps += 1

            cap_dict['metrics']['CHAIRs'] = int(hallucinated)
            cap_dict['metrics']['CHAIRi'] = 0.
            if len(words) > 0:
                num_coco_caps += 1
                cap_dict['metrics']['CHAIRi'] = len(cap_dict['mscoco_hallucinated_words']) / float(len(words))

            output['sentences'].append(cap_dict)

        chair_s = ratio(num_hallucinated_caps, num_caps)
        # Only answers mentioning at least one MSCOCO object are counted here.
        chair_s_refine = ratio(num_hallucinated_caps, num_coco_caps)
        chair_i = ratio(hallucinated_word_count, coco_word_count)
        avg_word_len = ratio(float(total_cap_word_num), num_caps)
        obj_rec = ratio(coco_obj_cls_count, gt_word_count)
        output['overall_metrics'] = {'CHAIRs': chair_s,
                                     'CHAIRs_refine': chair_s_refine,
                                     'CHAIRi': chair_i,
                                     'obj_rec': obj_rec,
                                     'sentence_num': num_caps,
                                     'coco_sentence_num': num_coco_caps,
                                     'coco_word_count': coco_obj_cls_count,  # predicted coco object classes
                                     'gt_word_count': gt_word_count,  # ground truth coco object classes
                                     'avg_word_len': avg_word_len,
                                     'all_gpt_used_tokens': all_used_tokens,
                                     'all_gpt_success_tokens': all_success_tokens,
                                     'correct_rate': 1 - chair_s_refine,
                                     'object_correct_rate': 1 - chair_i
                                     }

        return output

def read_jsonl(jsonl_file):
    """Read a UTF-8 JSON Lines file and return its decoded records as a list."""
    with open(jsonl_file, 'r', encoding='utf-8') as handle:
        return list(jsonlines.Reader(handle))

# Keys searched, in order, for the model's answer in ``.jsonl`` prediction files.
_ANSWER_KEYS = ("text", "response", "answer", "model_answer")


def _load_eval_json(cap_file):
    """Read the ``{'overall': metrics, 'imgToEval': {key: caption_dict}}`` format."""
    with open(cap_file) as f:
        data = json.load(f)
    try:
        metrics = data['overall']
        caps = list(data['imgToEval'].values())
        imids = {cap['image_id'] for cap in caps}
    except (KeyError, TypeError, AttributeError) as err:
        raise Exception("Expect caption file to consist of a dictionary with sentences correspdonding to the key 'imgToEval'") from err
    return caps, imids, metrics


def _load_answer_jsonl(cap_file, org_dir):
    """Read one-answer-per-line ``.jsonl`` predictions.

    Lines without ``'image_id'`` take it from the same line of the original
    question file. ``org_dir`` is either that ``.jsonl`` file itself or a
    directory containing ``<prefix>.jsonl``. The prefix is derived from the
    prediction file name (text before ``"__"``, with ``"_answer"`` replaced).
    """
    caps = read_jsonl(cap_file)

    org_data = None
    if any("image_id" not in item for item in caps):
        if org_dir is None or org_dir.strip() == "":
            raise Exception("Expect origin test input file directory for .jsonl cap file")
        if org_dir.endswith(".jsonl"):
            org_data_path = org_dir
        else:
            cap_name = cap_file.split("/")[-1]
            org_name = cap_name.split("__")[0].replace("_answer", ".jsonl")
            org_data_path = os.path.join(org_dir, org_name)
        org_data = read_jsonl(org_data_path)

    new_caps = []
    imids = set()
    for i, item in enumerate(caps):
        if "image_id" in item:
            imgid = int(item["image_id"])
        else:
            imgid = int(org_data[i]["image_id"])
        imids.add(imgid)

        if "prompt" in item:
            question = item["prompt"]
        elif "question" in item:
            question = item["question"]
        else:
            raise Exception("Expect 'question' or 'prompt' in generated file")

        answer_key = next((key for key in _ANSWER_KEYS if key in item), None)
        if answer_key is None:
            raise Exception("Expect 'answer' or 'text' in generated file")
        answer = item[answer_key].replace("Assistant:", "").strip()

        new_caps.append({"image_id": imgid, "question": question, "caption": answer})
    return new_caps, imids, {}


def _load_raw_question_json(cap_file):
    """Read a JSON list of ``{'question_id': '<id>.jpg', 'raw_question', 'answer'}`` dicts."""
    with open(cap_file) as f:
        caps = json.load(f)
    if not caps or not isinstance(caps[0], dict) or 'raw_question' not in caps[0]:
        raise Exception("Expect origin test input file directory for .jsonl cap file")

    new_caps = []
    for item in caps:
        new_caps.append({
            # The original called .replace('.jpg') with one argument (TypeError).
            "image_id": int(item["question_id"].replace('.jpg', '')),
            "question": item["raw_question"],
            "caption": item["answer"].replace("Assistant:", "").strip()
        })
    imids = {cap["image_id"] for cap in new_caps}
    return new_caps, imids, {}


def load_generated_captions(cap_file, org_dir=None):
    """Load model answers and normalise them for evaluation.

    The format is chosen from ``cap_file``'s name:

    * ``*.json``: dict with ``'overall'`` metrics and an ``'imgToEval'``
      mapping whose values are caption dicts carrying ``'image_id'``.
    * ``*.jsonl``: one answer per line. The question comes from ``'prompt'``
      or ``'question'``. The answer comes from the first present key among
      ``'text'``, ``'response'``, ``'answer'``, ``'model_answer'``, with
      ``"Assistant:"`` removed. ``org_dir`` supplies image ids for lines
      lacking ``'image_id'``.
    * a path containing no ``.`` at all: JSON list with ``'question_id'``
      (``"<id>.jpg"``), ``'raw_question'`` and ``'answer'``.

    If ``"coh"`` occurs anywhere in ``cap_file``, captions are passed through
    ``preprocess_coh_results``.

    Returns:
        (caps, imids, metrics): list of caption dicts, the set of image ids,
        and a metrics dict (``{}`` unless the ``.json`` format provides one).

    Raises:
        Exception: for malformed input (the message names what was expected).
        ValueError: if ``cap_file`` matches none of the supported formats.
    """
    if cap_file.endswith(".json"):
        caps, imids, metrics = _load_eval_json(cap_file)
    elif cap_file.endswith(".jsonl"):
        caps, imids, metrics = _load_answer_jsonl(cap_file, org_dir)
    elif "." not in cap_file:
        caps, imids, metrics = _load_raw_question_json(cap_file)
    else:
        raise ValueError("Unsupported caption file format: %s" % cap_file)

    if "coh" in cap_file:
        caps = preprocess_coh_results(caps)

    return caps, imids, metrics

def print_metrics(hallucination_cap_dict, quiet=False):
    """Format the headline CHAIR numbers as one tab-separated row.

    CHAIRs, CHAIRs_refine and CHAIRi are shown as percentages and the average
    answer length as a number, all with one decimal. The two sentence counts
    are shown as integers. When ``quiet`` is false, a header line and the row
    are printed and ``None`` is returned. Otherwise the row string is
    returned without printing.
    """
    overall = hallucination_cap_dict['overall_metrics']
    values = (
        overall['CHAIRs'] * 100,
        overall['CHAIRs_refine'] * 100,
        overall['CHAIRi'] * 100,
        overall['sentence_num'],
        overall['coco_sentence_num'],
        overall['avg_word_len'],
    )
    row = "%0.001f\t%0.001f\t%0.001f\t%d\t%d\t%0.01f" % values

    if quiet:
        return row
    print("CHAIRs\tCHAIRsr\tCHAIRi\tsent_num\tcoco_num\tavg_len")
    print(row)

if __name__ == '__main__':
    # Script entry point. It loads predictions, collects MSCOCO ground truth
    # for their images, runs GPT-based CHAIR scoring, and writes
    # hall_<tag>_<sample_num>.json next to the prediction file.
    time_start = time.time()
    parser = argparse.ArgumentParser()
    parser.add_argument("--cap_file", type=str, default="")
    parser.add_argument("--q_file", type=str, default='/opt/tiger/vtcl/benchmark/objectHal/obj_halbench_300_with_image.jsonl')
    parser.add_argument("--coco_path", type=str, default='/opt/tiger/vtcl/benchmark/objectHal/annotations')
    parser.add_argument("--sample_num", type=int, default=-1)
    parser.add_argument("--openai_key", type=str, default='')
    parser.add_argument("--api_base", type=str, default="")
    args = parser.parse_args()

    print("=======load prediction=======")
    # A ``.json`` prediction file is read as JSON Lines. It is renamed in
    # place to ``<name>.jsonl`` (this modifies the user's file!) and loaded
    # from the new name.
    if args.cap_file.endswith('.json'):
        try:
            os.rename(args.cap_file, args.cap_file + 'l')
        except FileNotFoundError:
            # Already renamed by an earlier run; the existing .jsonl is used.
            # Other errors (e.g. permissions) now surface instead of being hidden.
            pass
        args.cap_file += 'l'

    tag = args.cap_file.split('/')[-1].replace(".jsonl", "")
    save_dir = pathlib.Path(args.cap_file).absolute().parent
    target_save_path_new = save_dir / f'hall_{tag}_{args.sample_num}.json'
    print('target_save_path_new: ', target_save_path_new)
    # Existing results are never overwritten. Check before the expensive
    # annotation loading below.
    if target_save_path_new.exists():
        print("\teval file already exists!")
        sys.exit(0)

    print("load imgids file:", args.cap_file)
    # Only the image ids are needed here, to restrict ground-truth collection.
    _, imids, _ = load_generated_captions(args.cap_file, org_dir=args.q_file)

    print("=======init evaluator=======")
    evaluator = CHAIR(imids, args.coco_path, args.openai_key, args.api_base)
    evaluator.get_annotations()

    print("========compute=========")
    print("***do process***", flush=True)
    cap_dict = evaluator.compute_chair(args.cap_file, args.sample_num, gpt_process=True, org_dir=args.q_file)
    print('Done computing')

    with open(target_save_path_new, 'w', encoding='utf-8') as f:
        json.dump(cap_dict, f, indent=4, ensure_ascii=False)

    print('Done Saving')

    print_metrics(cap_dict)
    time_end = time.time()

    print("eval time:", time_end - time_start)
    print('target_save_path_new: ', target_save_path_new)