import argparse
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

from openai import OpenAI
from tqdm import tqdm

# Shared API client used by judge().
# Credentials come from the environment instead of being hard-coded in the
# source, so secrets never need to be edited into (and committed with) this file.
# Passing None lets the SDK apply its own defaults (it reads OPENAI_API_KEY /
# OPENAI_BASE_URL itself and falls back to the public endpoint) and fail fast
# at startup when no key is configured, instead of sending unauthenticated
# requests that would be retried indefinitely.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY") or None,
    base_url=os.environ.get("OPENAI_BASE_URL") or None,
)

def judge(question: str, target: str, gen_opts) -> int:
    """Ask the LLM judge how many of ``gen_opts`` match the reference answer.

    Args:
        question: Question text shown to the judge.
        target: Reference (ground-truth) answer text.
        gen_opts: Generated options; interpolated into the prompt via its
            ``str()`` form (presumably a list of strings — TODO confirm).

    Returns:
        The number of correct options reported by the model, capped at
        ``len(gen_opts)``. Empty or ``None`` ``gen_opts`` return 0 without
        making an API call.

    Raises:
        ValueError: If the model reply contains no integer.
        Exceptions raised by the API call (e.g. timeouts) propagate unchanged.
    """
    if not gen_opts:
        return 0

    prompt = f"""You are an intelligent chatbot designed to evaluate the correctness of given options. 
You will be provided with a question, a reference answer, and a set of options. 
Without relying on external knowledge, determine whether the correct answer is included in the provided options, and count the number of correct answers in the options.
```
Question: {question} 
Reference Answer: {target} 
Options: {gen_opts} 
``` 
Please directly return the number of correct options, without any additional text. If no option is correct, return 0.
""".strip()

    response = client.chat.completions.create(
        # model="gpt-4o-2024-08-06",
        # model="gpt-5-mini-2025-08-07",
        model="gpt-4.1-2025-04-14",
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=5,
        timeout=20,
    )
    # message.content is optional in the SDK response; normalize before parsing.
    answer = (response.choices[0].message.content or "").strip()
    # Tolerate decorated replies such as "2." or "Answer: 2". A bare int() would
    # raise on them, and with temperature=0 the same reply tends to come back on
    # every retry.
    match = re.search(r"\d+", answer)
    if match is None:
        raise ValueError(f"Judge reply contains no integer: {answer!r}")
    # More correct options than options is impossible; cap it so the per-item
    # coverage ratio correct / len(gen_opts) stays within [0, 1].
    return min(int(match.group()), len(gen_opts))


def process_item(item, max_retries: int = 8):
    """Judge one item, retrying failures with exponential backoff.

    Sets ``item['correct']`` to the value returned by ``judge`` and returns the
    same (mutated) dict.

    Args:
        item: Dict with 'ques', 'gt' and 'gen_opts' keys.
        max_retries: Total attempts before giving up (at least one is made).

    Raises:
        The last exception raised by ``judge`` once ``max_retries`` attempts
        have failed. The original version retried forever, which hung a worker
        thread on deterministic failures such as an unparseable reply.
    """
    attempts = max(1, max_retries)
    delay = 1.0
    for attempt in range(1, attempts + 1):
        try:
            item['correct'] = judge(item['ques'], item['gt'], item['gen_opts'])
            return item
        except Exception as e:  # retry boundary: API, network and parse errors
            print(f"[attempt {attempt}/{attempts}]", e, item)
            if attempt == attempts:
                raise
            time.sleep(delay)
            # Back off exponentially (capped) so rate limits get time to clear.
            delay = min(delay * 2, 30.0)

# Command-line interface:
#   --num_workers  number of threads used to call the judge concurrently
#   --input_dir    run directory holding results.json; the gen_opts_check_*
#                  outputs are written next to it
# NOTE(review): the --input_dir default is a machine-specific path; pass the
# flag explicitly when running elsewhere.
parser = argparse.ArgumentParser(
    description="Run GPT judging on Generated Options",
)
parser.add_argument(
    "--num_workers",
    type=int,
    default=16,
)
parser.add_argument(
    "--input_dir",
    type=str,
    default=(
        "/mnt/bn/wzr/code/VTR-VLM/exp/main/4_generated_opts/llava/videoevalpro/"
        "2025_07_16_18_31_51"
    ),
)
args = parser.parse_args()


# Resume support: every finished judgment is appended as one JSON object per
# line to this file, keyed here by item id so those items are skipped below.
results_path = os.path.join(args.input_dir, "gen_opts_check_results.json")
judge_results = {}
if os.path.exists(results_path):
    with open(results_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                # A run killed mid-write can leave a truncated last line. Drop
                # it so that item is simply re-judged instead of every future
                # resume aborting here.
                print(f"Skipping malformed line {line_no} in {results_path}")
                continue
            judge_results[item['id']] = item
print("Already Judged : ", len(judge_results))

# Build the work list from results.json (one JSON object per line), skipping
# items whose id is already present in judge_results.
data_to_judge = []
with open(os.path.join(args.input_dir, "results.json"), "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        qa = json.loads(line)['QA']
        if qa['id'] in judge_results:
            continue
        # The first character of gt_ans is treated as an option letter
        # ('A' -> opts[0], 'B' -> opts[1], ...).
        gt_idx = ord(qa['gt_ans'][0]) - ord('A')
        # Without this check a non-letter prefix (e.g. '@') gives a negative
        # index that silently selects an option from the end of the list.
        if not 0 <= gt_idx < len(qa['opts']):
            raise ValueError(
                f"id={qa['id']}: gt_ans {qa['gt_ans']!r} does not name one of "
                f"the {len(qa['opts'])} options"
            )
        data_to_judge.append({
            "id": qa['id'],
            "ques": qa['ques'],
            "gt": qa['opts'][gt_idx],
            "gen_opts": qa['gen_opts'],
        })

print("Remain To Judge : ", len(data_to_judge))

results_path = os.path.join(args.input_dir, "gen_opts_check_results.json")
metrics_path = os.path.join(args.input_dir, "gen_opts_check_metrics.json")


def _write_metrics(results, path):
    """Recompute summary metrics over every judged item and overwrite ``path``.

    accuracy is the fraction of items with at least one option judged correct;
    correct_answer_cover is the mean of correct / len(gen_opts) over those items.
    """
    covers = [
        x['correct'] / len(x['gen_opts'])
        for x in results.values()
        if x['correct'] > 0
    ]
    total_count = len(results)
    true_count = len(covers)
    metrics = [
        {"accuracy": true_count / total_count if total_count > 0 else 0,
         "total_count": total_count,
         "true_count": true_count},
        {"correct_answer_cover": sum(covers) / len(covers) if covers else 0},
    ]
    # Compute before opening the file so a failure cannot leave it truncated.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=4, ensure_ascii=False)


with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
    # Map each future back to its item id so failures can be reported.
    future_to_id = {
        executor.submit(process_item, entry): entry['id'] for entry in data_to_judge
    }
    with tqdm(total=len(future_to_id), desc="Judge") as pbar:
        with open(results_path, "a", encoding="utf-8") as results_file:
            for future in as_completed(future_to_id):
                pbar.update(1)
                try:
                    item = future.result()
                except Exception as e:
                    # Not persisted, so the next run will pick this item up again.
                    print(f"Judging failed for id={future_to_id[future]}: {e!r}")
                    continue
                judge_results[item['id']] = item
                # Append and flush each result immediately so a crash loses at
                # most the in-flight items (the resume logic skips saved ones).
                results_file.write(json.dumps(item, ensure_ascii=False) + "\n")
                results_file.flush()
                # Keep the metrics file current while the run is in progress.
                _write_metrics(judge_results, metrics_path)

# Refresh metrics unconditionally, so they are also produced when nothing was
# left to judge on this run.
_write_metrics(judge_results, metrics_path)
