import re
import os
import copy
import json
import numpy as np
import argparse
from tqdm import tqdm
from collections import defaultdict

# Chat messages sent to the fallback language model in eval_cvbench. The user
# message carries the {question} and {answer} placeholders that are filled in
# with str.format before querying.
template = [
    {
        "role": "system",
        "content": (
            "You are a helpful assistant who receives question-answer pairs and "
            "extracts the corresponding option letter from the answer. "
            "Your response should be a single uppercase letter (A, B, C, D, ...) "
            "corresponding to the option the answer selects explicitly or implicitly. "
            "If the answer is empty or invalid, return 'Other'. "
            "Do not include any additional text."
        ),
    },
    {
        "role": "user",
        "content": (
            "Question: {question}\n"
            "Answer: {answer}\n"
            "Extract the corresponding option letter from the answer."
        ),
    },
]

# Filled in on first use by eval_cvbench, so the extraction model is loaded at
# most once per process and only when the regex shortcut fails.
eval_model = None
eval_tokenizer = None

def eval_cvbench(answer:str, pred:str, prompt:str):
    """Score one multiple-choice prediction against its ground-truth option.

    The option letter is taken from ``pred`` in one of three ways:

    1. If ``pred`` (after trimming whitespace) is shorter than two characters
       it is used as-is.
    2. Otherwise the last parenthesised uppercase letter, e.g. ``(B)``, wins.
    3. Otherwise a language model is asked to extract the letter. The model
       ``Qwen/Qwen2.5-3B-Instruct`` is loaded on first use and cached in the
       module-level ``eval_model`` / ``eval_tokenizer``.

    Args:
        answer: Ground-truth option; surrounding whitespace, punctuation and
            brackets are stripped and the comparison is case-insensitive.
        pred: Raw model output. ``None`` is treated as an empty prediction.
        prompt: The question text, only used when querying the fallback LM.

    Returns:
        A ``(correct, valid, use_lm)`` tuple of booleans: whether the extracted
        letter equals the answer, whether it is one of ``a``-``f``, and whether
        the fallback LM was used.
    """
    global eval_model, eval_tokenizer

    answer = answer.strip(" \n,;.?!()[]").lower()
    # Trim whitespace so a bare letter such as "A\n" is scored directly instead
    # of being sent to the fallback LM; a missing prediction counts as empty.
    pred = (pred or "").strip()
    use_lm = False

    bracketed = re.findall(r'\(([A-Z])\)', pred)
    if len(pred) < 2:
        # Empty or single-character prediction: nothing to extract.
        pass
    elif bracketed:
        # Use the last "(X)" occurrence, which is typically the final choice.
        pred = bracketed[-1]
    else:
        # Imported lazily so the regex-only paths do not require llava.
        from llava.eval.eval_utils import load_eval_model, query_model
        if eval_model is None:
            eval_model, eval_tokenizer = load_eval_model("Qwen/Qwen2.5-3B-Instruct")
        use_lm = True
        messages = copy.deepcopy(template)
        messages[1]['content'] = messages[1]['content'].format(question=prompt, answer=pred)
        # NOTE(review): query_model is assumed to return the generated text as str.
        reply = query_model(eval_model, eval_tokenizer, messages, max_new_tokens=3)
        letters = re.findall(r'[A-Z]', reply)
        # A reply without any uppercase letter becomes an invalid (empty)
        # prediction rather than raising IndexError and aborting the run.
        pred = letters[-1] if letters else ''

    pred = pred.lower()
    valid = pred in ('a', 'b', 'c', 'd', 'e', 'f')
    return answer == pred, valid, use_lm

if __name__ == "__main__":
    # Command-line entry point: score a JSONL prediction file against a JSONL
    # question file and write the aggregated accuracies to a JSON file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--question-file", type=str, required=True)
    parser.add_argument("--result-file", type=str, required=True)
    parser.add_argument("--result-save", type=str, required=True)
    args = parser.parse_args()

    # Both inputs are JSON Lines keyed by 'question_id'. Blank lines are
    # skipped, and the files are closed as soon as they have been parsed.
    with open(args.question_file, encoding="utf-8") as f:
        questions = {
            record['question_id']: record
            for record in (json.loads(line) for line in f if line.strip())
        }
    with open(args.result_file, encoding="utf-8") as f:
        predictions = {
            record['question_id']: record
            for record in (json.loads(line) for line in f if line.strip())
        }

    all_results = {
        'by_task': defaultdict(list),
        'by_source': defaultdict(list),
        'valid': [],
        'lm_use': [],
    }
    for question_id, question in tqdm(questions.items(), desc="Evaluating"):
        # A question without a prediction is scored as an empty answer.
        record = predictions.get(question_id, {})
        prompt = record.get('prompt', '')
        prediction = record.get('text', '')
        # The ground truth is mandatory: fail with a clear KeyError instead of
        # an AttributeError on None deep inside eval_cvbench.
        answer = question['answer']
        result, valid, use_lm = eval_cvbench(answer, prediction, prompt)
        all_results['by_task'][question['task']].append(result)
        all_results['by_source'][question['source']].append(result)
        all_results['valid'].append(valid)
        all_results['lm_use'].append(use_lm)

    # Turn the per-sample booleans into percentages.
    all_results['by_task'] = {k: np.mean(v)*100 for k, v in all_results['by_task'].items()}
    all_results['by_source'] = {k: np.mean(v)*100 for k, v in all_results['by_source'].items()}
    all_results['valid'] = np.mean(all_results['valid']) * 100
    all_results['lm_use'] = np.mean(all_results['lm_use']) * 100
    # 2D is the mean of the ADE20K and COCO source accuracies, 3D is Omni3D,
    # and the overall score is the mean of the 2D and 3D scores. All three
    # sources must be present in the question file.
    all_results['2d'] = (all_results['by_source']["ADE20K"] + all_results['by_source']["COCO"]) / 2
    all_results['3d'] = all_results['by_source']["Omni3D"]
    all_results['all'] = (all_results['2d'] + all_results['3d']) / 2

    print(f"CV-Bench Accuracy: {all_results['all']:.4f}")
    print(f"2D Accuracy: {all_results['2d']:.4f}")
    print(f"3D Accuracy: {all_results['3d']:.4f}")
    print(f"Answer Validity: {all_results['valid']:.4f}, LM Use: {all_results['lm_use']:.4f}")
    print("Task Accuracies:")
    for k, v in all_results['by_task'].items():
        print(f"  {k}: {v:.4f}")
    print("Source Accuracies:")
    for k, v in all_results['by_source'].items():
        print(f"  {k}: {v:.4f}")

    results = {
        "name": args.result_file,
        "all": all_results['all'],
        "2d": all_results['2d'],
        "3d": all_results['3d'],
        "valid": all_results['valid'],
        "lm_use": all_results['lm_use'],
        "by_task": all_results['by_task'],
        "by_source": all_results['by_source'],
    }
    # os.path.dirname returns '' for a bare filename, and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is given.
    save_dir = os.path.dirname(args.result_save)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(args.result_save, 'w') as f:
        json.dump(results, f, indent=4)
