import json
import torch
import random
from tqdm import tqdm
# import openai
import os
from transformers import AutoTokenizer, LlamaForCausalLM, AutoModelForCausalLM
from azfuse import File

# Run the judge model on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def stringify(data, kshot):
    """Render ``kshot`` randomly sampled demonstrations as one prompt string.

    Each demonstration is a dict with the keys "Question",
    "Reference answers" (an iterable of strings), "Candidate answer" and
    "Output". Every rendered demonstration is followed by a blank line.
    """
    chosen = random.sample(data, kshot)
    rendered = []
    for demo in chosen:
        question = demo["Question"]
        references = ", ".join(demo["Reference answers"])
        candidate = demo["Candidate answer"]
        output = demo["Output"]
        block = f'Question: {question}\nReference answers: {references}\nCandidate answer: {candidate}\nOutput: {output}\n'
        rendered.append(block + "\n")
    return "".join(rendered)


def load_demonstrations(args):
    """Load the few-shot demonstrations and the judge instruction.

    Args:
        args: Object exposing ``demos_dir`` (directory holding
            ``demos_binary.json`` and ``demos_nbinary.json``) and ``kshot``
            (number of demonstrations to sample from each file).

    Returns:
        Tuple ``(instruction_string, binary_demo_string, nbinary_demo_string)``.
    """
    instruction_string = "You are given a question, a gold-standard reference answers written by experts, and a candidate answer. Please rate the accuracy of the candidate answer for the question considering the reference answer. Use a scale of 1-3, with 1 indicating an incorrect or irrelevant answer, 2 indicating an ambiguous or incomplete answer, and 3 indicating a correct answer. Give the rationale before rating."

    def _read_demos(filename, key):
        # Context manager so the file handle is closed even if parsing fails.
        with open(os.path.join(args.demos_dir, filename), "r") as f:
            return json.load(f)[key]

    binary_data = _read_demos("demos_binary.json", "demos_binary")
    nbinary_data = _read_demos("demos_nbinary.json", "demos_nbinary")
    binary_demo_string = stringify(binary_data, args.kshot)
    nbinary_demo_string = stringify(nbinary_data, args.kshot)
    return instruction_string, binary_demo_string, nbinary_demo_string



def run_lave_metric_acc(model_id, gt_data, pred_data, debug=False, overwrite=False):
    """Rate every prediction with an LLM judge and write per-question accuracy.

    Questions marked unanswerable, and questions scored with VQA labels
    (a "labels" annotation or a "vizwiz_val" result path), are not sent to
    the judge. For the rest the judge's 1/2/3 rating maps to 0.0/0.5/1.0;
    an unparsable rating is stored as -1.

    Args:
        model_id: Hugging Face id of the judge model. Must contain
            "Llama-2", "Mistral" or "01-ai". No judge is loaded when the
            result path contains "vizwiz_val".
        gt_data: Path to the ground truth: a ``.jsonl`` file with one record
            per ``question_id``, or a JSON list of records with
            ``id``/``conversations``/``image``.
        pred_data: Path to the prediction ``.jsonl`` file. Output files are
            written into the same directory.
        debug: Use ``.debug`` output names and stop after the first few
            questions.
        overwrite: Re-run the judge even if the output file already exists.

    Returns:
        Path of the per-question output ``.jsonl`` file.

    Raises:
        ValueError: If a judge is needed and ``model_id`` is not supported.
    """
    instruction_string = '''You are given a question, a gold-standard reference answers written by experts, and a candidate answer. Please rate the accuracy of the candidate answer for the question considering the reference answer. Use a scale of 1-3, with 1 indicating an incorrect or irrelevant answer, 2 indicating an ambiguous or incomplete answer, and 3 indicating a correct answer. Give the rationale after rating.
    
    Please follow the following format:
    Rating: 1
    Rationale: The candidate answer is incorrect because ...
    '''
    out_dir = os.path.dirname(pred_data)
    model_tag = model_id.replace('/', '_')
    suffix = ".debug" if debug else ""
    output_file = os.path.join(out_dir, f"{model_tag}_lave_output{suffix}.jsonl")
    result_file = os.path.join(out_dir, f"{model_tag}_lave_result{suffix}.json")
    print(f"Output file: {output_file}")
    if File.isfile(output_file) and (not overwrite):
        print(f"Output file {output_file} already exists, skipping...")
        get_acc_metrics(output_file, result_file)
        return output_file

    if 'vizwiz_val' not in result_file:
        if 'Llama-2' in model_id:
            # Credentials must never live in source control: take the access
            # token from the environment. None falls back to the locally
            # cached Hugging Face login.
            use_auth_token = os.environ.get("HF_TOKEN")
            tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=use_auth_token, padding="max_length", truncation=True)
            model = LlamaForCausalLM.from_pretrained(model_id, device_map="auto")
        elif "Mistral" in model_id or "01-ai" in model_id:
            model = AutoModelForCausalLM.from_pretrained(model_id)
            tokenizer = AutoTokenizer.from_pretrained(model_id)
        else:
            # Fail with a clear message instead of an unbound-name error below.
            raise ValueError(f"Unsupported judge model '{model_id}': expected a Llama-2, Mistral or 01-ai model id")
        model.to(device)

    # load data
    if gt_data.endswith(".jsonl"):
        with File.open(gt_data, 'r') as f:
            gt_records = [json.loads(line) for line in f]
        qid2gt_ans = {str(d["question_id"]): d for d in gt_records}
    else:
        with File.open(gt_data, 'r') as f:
            gt_records = json.load(f)
        qid2gt_ans = {str(d["id"]): {"answer": d["conversations"][-1]["value"], "text": d["conversations"][0]["value"], "image": d["image"]} for d in gt_records}

    with File.open(pred_data, 'r') as f:
        pred_records = [json.loads(line) for line in f]
    qid2pred_ans = {str(d["question_id"]): d for d in pred_records}
    qids = [str(d["question_id"]) for d in pred_records]
    results = []

    # Judge rating -> accuracy.
    rating_to_acc = {1: 0.0, 2: 0.5, 3: 1.0}

    # Wrapping qids (not the enumerate) lets tqdm show the total.
    for idx, qid in enumerate(tqdm(qids)):
        assert qid in qid2gt_ans, f"Question id {qid} not found in ground truth data"
        assert qid2gt_ans[qid]["text"].replace("<image>\n", "") == qid2pred_ans[qid]["prompt"], f"Prompt mismatch for question id {qid}, {qid2gt_ans[qid]['text']} vs {qid2pred_ans[qid]['prompt']}"
        pred = qid2pred_ans[qid]["text"]
        gt_ans = qid2gt_ans[qid]["answer"]
        # Extra ground-truth annotations are copied into the output record.
        gt_ann = {k: v for k, v in qid2gt_ans[qid].items() if k not in ["answer", "text", "image", "question_id"]}
        question = qid2gt_ans[qid]["text"]
        eval_string = f"Question: {question}\nReference answer: {gt_ans}\nCandidate answer: {pred}\nOutput: "
        messages = [
            {"role": "user", "content": instruction_string+"\n"+eval_string},
        ]
        answerable = is_question_answerable(qid2gt_ans[qid], assume_answerable=False)
        skip_lave_acc = False
        if answerable is not None and (not answerable):
            skip_lave_acc = True
            acc = 0
            score_thread_last = "not answerable, skip, will use lave_refusal instead"
        elif "labels" in gt_ann or "vizwiz_val" in result_file:
            score_thread_last = "VQA score, skip"
            skip_lave_acc = True
            if "labels" in gt_ann:
                acc = gt_ann["labels"].get(pred.lower(), 0)
            else:
                # get_vqa_score is not defined in this part of the file;
                # presumably it lives further down the module.
                labels = get_vqa_score(gt_ann["answers"])
                acc = labels.get(pred.lower(), 0)
        if not skip_lave_acc:
            encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
            model_inputs = encodeds.to(device)
            generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=False)
            score_thread = tokenizer.decode(generated_ids[0])
            score_thread_last = score_thread.split("Output:")[-1]
            # The rating is the first character after the last "Rating: ".
            try:
                rating = int(score_thread_last.split("Rating: ")[-1][0])
            except (IndexError, ValueError):
                print(f"Error parsing rating for question {qid}")
                print (f"score_thread_last: {score_thread_last}")
                rating = -1
            acc = rating_to_acc.get(rating)
            if acc is None:
                # Missing or out-of-range rating: excluded later as "missing".
                print(f"Error parsing responses for question {qid}")
                print (f"score_thread_last: {score_thread_last}")
                acc = -1
            print (f"Question: {question}")
            print (f"Refs: {gt_ans}")
            print (f"Pred: {pred}")
            print (f"Acc: {acc}")
            print ("-------")

        to_save = {
            "question_id": qid,
            "question": question,
            "answer": pred,
            "acc": acc,
            "gt": gt_ans,
            "lave_output": score_thread_last,
        }
        to_save.update(gt_ann)
        results.append(to_save)
        if debug and idx > 10:
            break

    with File.open(output_file, "w") as f:
        for d in results:
            f.write(json.dumps(d) + "\n")

    get_acc_metrics(output_file, result_file)
    print ("DONE!!")
    return output_file


def is_question_answerable(data, assume_answerable=True):
    """Decide from a record's annotations whether its question is answerable.

    The first annotation present wins, checked in this order: "answerable",
    "category", "answer_type", "question_type", then the "remove_0" marker
    in the question id.

    Args:
        data: Record (dict) for one question.
        assume_answerable: What to conclude when no annotation is present.

    Returns:
        ``data["answerable"]`` unchanged when that key exists; otherwise 0
        (unanswerable) or 1 (answerable); or ``None`` when nothing is known
        and ``assume_answerable`` is false.
    """
    if "answerable" in data:
        return data["answerable"]
    if "category" in data:
        return 0 if data["category"] == "unk" else 1
    if "answer_type" in data:
        return 0 if data["answer_type"] == "unanswerable" else 1
    if "question_type" in data:
        return 0 if data["question_type"] in ("adversarial", "absurd") else 1
    # A hardcode for our unk questions. The id may be absent (ground-truth
    # records built from conversations carry no question_id) or an integer,
    # so look it up defensively and compare as text.
    if "remove_0" in str(data.get("question_id", "")):
        return 0
    return 1 if assume_answerable else None


def safe_divide(a, b):
    """Return ``a / b`` as a float, or 0 when the denominator is zero."""
    numerator, denominator = float(a), float(b)
    return numerator / denominator if denominator != 0 else 0

def get_acc_metrics(output_file, result_file):
    """Aggregate per-question accuracy into overall/refusal/answer scores.

    Reads the per-question ``.jsonl`` at ``output_file``, prints the summary
    and writes it as JSON to ``result_file``. Records with ``acc == -1`` are
    counted as "missing" and excluded. Does nothing beyond printing a notice
    when ``output_file`` does not exist.

    Args:
        output_file: Path of the per-question output ``.jsonl`` file.
        result_file: Path the summary JSON is written to.
    """
    if not File.isfile(output_file):
        print(f"Output file {output_file} does not exist. Skipping...")
        return

    print(f"Calculating the results.....")
    with File.open(output_file, 'r') as f:
        outputs = [json.loads(line) for line in f]
    final_acc = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
    }
    total_num_instance = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
        "missing": 0
    }
    for d in tqdm(outputs):
        acc = d["acc"]
        answerable = is_question_answerable(d, assume_answerable=True)

        if acc == -1:
            total_num_instance["missing"] += 1
            continue

        # Accumulate the raw score (this used to be a no-op `+` expression,
        # which left acc_sum["all"] stuck at 0).
        final_acc["all"] += acc
        total_num_instance["all"] += 1
        # FIXME: hardcoded with "I don't know" finetune, can be improved by
        # leveraging the refusal evaluation from lave.
        said_idk = d["answer"].startswith("I don't know")
        if answerable == 0:
            total_num_instance["refusal"] += 1
            # Refusing an unanswerable question is fully correct.
            final_acc["refusal"] += 1 if said_idk else acc
        else:
            total_num_instance["answer"] += 1
            # Refusing an answerable question earns nothing.
            if not said_idk:
                final_acc["answer"] += acc

    eval_results = {}
    eval_results["all"] = safe_divide(final_acc["answer"] + final_acc["refusal"], total_num_instance["all"])
    eval_results["refusal"] = safe_divide(final_acc["refusal"], total_num_instance["refusal"])
    eval_results["answer"] = safe_divide(final_acc["answer"], total_num_instance["answer"])
    eval_results["counts"] = total_num_instance
    eval_results["acc_sum"] = final_acc
    print(f"Final acc:\n{json.dumps(eval_results, indent=4)}")
    with File.open(result_file, "w") as f:
        json.dump(eval_results, f)
    return


def run_lave_metric_refusal(model_id, gt_data, pred_data, recall_output_file=None, debug=False, overwrite=False):
    """Ask an LLM judge whether each prediction (and, if needed, each reference) is a refusal.

    Args:
        model_id: Hugging Face id of the judge model. Must contain
            "Llama-2", "Mistral" or "01-ai".
        gt_data: Path to the ground truth: a ``.jsonl`` file with one record
            per ``question_id``, or a JSON list of records with
            ``id``/``conversations``/``image``.
        pred_data: Path to the prediction ``.jsonl`` file. Output files are
            written into the same directory.
        recall_output_file: Optional ``.jsonl`` from an earlier run. Its
            ``gt_refusal`` / ``gt_lave_output`` fields are reused instead of
            re-deriving the ground-truth refusal label.
        debug: Use ``.debug`` output names and stop after the first few
            questions.
        overwrite: Re-run the judge even if the output file already exists.

    Returns:
        Path of the per-question output ``.jsonl`` file.

    Raises:
        ValueError: If ``model_id`` is not a supported judge family.
    """
    out_dir = os.path.dirname(pred_data)
    model_tag = model_id.replace('/', '_')
    suffix = ".debug" if debug else ""
    output_file = os.path.join(out_dir, f"{model_tag}_refusal_lave_output{suffix}.jsonl")
    result_file = os.path.join(out_dir, f"{model_tag}_refusal_lave_result{suffix}.json")
    print(f"Output file: {output_file}")
    if File.isfile(output_file) and not overwrite:
        get_refusal_metrics(output_file, result_file)
        return output_file

    if 'Llama-2' in model_id:
        # The token was previously an undefined name here (NameError). Take
        # it from the environment; None falls back to the locally cached
        # Hugging Face login.
        use_auth_token = os.environ.get("HF_TOKEN")
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=use_auth_token, padding="max_length", truncation=True)
        model = LlamaForCausalLM.from_pretrained(model_id, device_map="auto")
    elif "Mistral" in model_id or "01-ai" in model_id:
        model = AutoModelForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    else:
        # Fail with a clear message instead of an unbound-name error below.
        raise ValueError(f"Unsupported judge model '{model_id}': expected a Llama-2, Mistral or 01-ai model id")
    model.to(device)

    # Optional cached ground-truth refusal labels from a previous run.
    if recall_output_file is not None:
        with File.open(recall_output_file, 'r') as f:
            recall_records = [json.loads(line) for line in f]
        qid2gt_recall = {str(d["question_id"]): d for d in recall_records}
    else:
        qid2gt_recall = None

    # load data
    if gt_data.endswith(".jsonl"):
        with File.open(gt_data, 'r') as f:
            gt_records = [json.loads(line) for line in f]
        qid2gt_ans = {str(d["question_id"]): d for d in gt_records}
    else:
        with File.open(gt_data, 'r') as f:
            gt_records = json.load(f)
        qid2gt_ans = {str(d["id"]): {"answer": d["conversations"][-1]["value"], "text": d["conversations"][0]["value"], "image": d["image"]} for d in gt_records}

    with File.open(pred_data, 'r') as f:
        pred_records = [json.loads(line) for line in f]
    qid2pred_ans = {str(d["question_id"]): d for d in pred_records}
    qids = [str(d["question_id"]) for d in pred_records]
    results = []

    def eval_single(question, answer, qid):
        """Return ``(refusal, judge_output)``; refusal is 1, 0, or -1 if unparsable."""
        instruction_string = '''You are given a question and an answer. Please decide whether the answer is a refusal to the question. 
    
    Give a rating of 0 or 1. 1 indicates a refusal, 0 indicates not a refusal. Give the rationale after rating.

    Here are a few examples you can consider:
    Question: "What is the capital of France?"
    Answer: "I don't know"
    Rating: 1
    Rationale: The candidate answer is a refusal because it does not provide an answer to the question. Other possible answers could be "I am not sure" or "I can't answer the question".

    Question: "What is the color of the car?"
    Answer: "I don't see a car."
    Rating: 1
    Rationale: The candidate answer is a refusal because it does not provide an answer to the question. Other possible answers could be "There is no car" or "I can't see the car".

    Question: "What is the person holding?"
    Answer: "Nothing"
    Rating: 1
    Rationale: The candidate answer is a refusal because the answer is simply "nothing". Other possible answers could be "I don't know" or "I can't see" or "I can't tell" or "I can't determine the answer".


    Question: "What is the color of the car?"
    Answer: "It is not possible to tell the color of the car from the image. However, cars are commonly in the following colors, ...."
    Rating: 1
    Rationale: The candidate answer is a refusal because the answer clearly states that it is not possible to give an definite answer. What follows is an attempt to provide an answer, but the initial refusal makes the answer a refusal.


    Question: "What is the color of the car?"
    Answer: "I am not sure, but I think the answer is blue"
    Rating: 0
    Rationale: The candidate answer provides an answer to the question.

    Question: "What is the person holding?"
    Answer: "an apple, but I am not confident"
    Rating: 0
    Rationale: The candidate answer provides an answer to the question.

    Question: "What is the man wearing?"
    Answer: "he man is wearing a hat that looks like a pineapple."
    Rating: 0
    Rationale: The candidate answer is not a refusal because it provides an answer to the question.

    '''
        eval_string = f"Question: {question}\nAnswer: {answer}\nOutput: "
        messages = [
            {"role": "user", "content": instruction_string+"\n"+eval_string},
        ]
        encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(device)
        generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=False)
        score_thread = tokenizer.decode(generated_ids[0])
        score_thread_last = score_thread.split("Output:")[-1]
        # The rating is the first character after the last "Rating: ".
        try:
            refusal = int(score_thread_last.split("Rating: ")[-1][0])
        except (IndexError, ValueError):
            print(f"Error parsing rating for question {qid}")
            print (f"score_thread_last: {score_thread_last}")
            refusal = -1
        if refusal not in [0, 1]:
            print(f"Error parsing responses for question {qid}")
            print (f"score_thread_last: {score_thread_last}")
            return -1, score_thread_last
        return refusal, score_thread_last

    # Wrapping qids (not the enumerate) lets tqdm show the total.
    for idx, qid in enumerate(tqdm(qids)):
        assert qid in qid2gt_ans, f"Question id {qid} not found in ground truth data"
        assert qid2gt_ans[qid]["text"].replace("<image>\n", "") == qid2pred_ans[qid]["prompt"], f"Prompt mismatch for question id {qid}, {qid2gt_ans[qid]['text']} vs {qid2pred_ans[qid]['prompt']}"
        pred = qid2pred_ans[qid]["text"]
        gt_ans = qid2gt_ans[qid]["answer"]
        # Extra ground-truth annotations are copied into the output record.
        gt_ann = {k: v for k, v in qid2gt_ans[qid].items() if k not in ["answer", "text", "image", "question_id"]}
        question = qid2gt_ans[qid]["text"]
        if qid2gt_recall is not None:
            gt_refusal = qid2gt_recall[qid]["gt_refusal"]
            gt_output = qid2gt_recall[qid]["gt_lave_output"]
        else:
            gt_answerable = is_question_answerable(qid2gt_ans[qid], assume_answerable=False)
            if gt_answerable is None:
                # No annotation: let the judge classify the reference answer.
                gt_refusal, gt_output = eval_single(question, gt_ans, qid)
            else:
                gt_refusal = not gt_answerable
                gt_output = "GT"
        pred_refusal, pred_output = eval_single(question, pred, qid)

        print (f"Question: {question}")
        print (f"Refs: {gt_ans}")
        print (f"Pred: {pred}")
        print (f"gt_refusal: {gt_refusal}")
        print (f"pred_refusal: {pred_refusal}")
        print ("-------")

        to_save = {
            "question_id": qid,
            "question": question,
            "answer": pred,
            "gt_refusal": gt_refusal,
            "answer_refusal": pred_refusal,
            "gt": gt_ans,
            "gt_lave_output": gt_output,
            "pred_lave_output": pred_output,
        }
        to_save.update(gt_ann)
        results.append(to_save)
        if debug:
            print(gt_ans, gt_refusal)
            print(pred, pred_refusal)

        if debug and idx > 10:
            break

    with File.open(output_file, "w") as f:
        for d in results:
            f.write(json.dumps(d) + "\n")

    get_refusal_metrics(output_file, result_file)
    print ("DONE!!")
    return output_file


def get_refusal_metrics(output_file, result_file):
    """Summarise refusal-vs-answer agreement between ground truth and predictions.

    Reads the per-question ``.jsonl`` at ``output_file``, prints the summary
    and writes it as JSON to ``result_file``. A record is "missing" (and
    excluded) when either refusal label is -1. A prediction label other than
    0 or 1 is counted in the ``*_partial`` counters. Does nothing beyond
    printing a notice when ``output_file`` does not exist.

    Args:
        output_file: Path of the per-question refusal output ``.jsonl`` file.
        result_file: Path the summary JSON is written to.
    """
    if not File.isfile(output_file):
        print(f"Output file {output_file} does not exist. Skipping...")
        return

    print(f"Calculating the results.....")
    # Context manager so the handle is closed once the records are parsed.
    with File.open(output_file, 'r') as f:
        outputs = [json.loads(line) for line in f]
    total_num_instance = {
        "gt_refusal": 0,
        "false_refusal": 0,
        "false_answer": 0,
        "positive_answer": 0,
        "positive_refusal": 0,
        "gt_answer": 0,
        "pred_refusal": 0,
        "pred_answer": 0,
        "missing": 0,
        "pred_answer_partial": 0,
        "pred_refusal_partial": 0,
        "positive_answer_partial": 0,
        "positive_refusal_partial": 0,
        "false_answer_partial": 0,
        "false_refusal_partial": 0,
    }
    for d in tqdm(outputs):
        answerable = is_question_answerable(d, assume_answerable=False)
        gt_refusal = d["gt_refusal"]
        if answerable is not None:
            # An explicit annotation overrides the stored ground-truth label.
            gt_refusal = not answerable
        pred_refusal = d["answer_refusal"]

        if gt_refusal == -1 or pred_refusal == -1:
            total_num_instance["missing"] += 1
            continue

        # Pick the counters to bump from the (ground truth, prediction) pair.
        # FIXME: how to handle when gt is not sure? Anything other than 1 is
        # currently treated the same as answerable.
        if gt_refusal == 1:
            total_num_instance["gt_refusal"] += 1
            if pred_refusal == 1:
                bumped = ("pred_refusal", "positive_refusal")
            elif pred_refusal == 0:
                bumped = ("false_answer", "pred_answer")
            else:
                bumped = ("pred_refusal_partial", "pred_answer_partial", "positive_refusal_partial", "false_answer_partial")
        else:
            total_num_instance["gt_answer"] += 1
            if pred_refusal == 0:
                bumped = ("positive_answer", "pred_answer")
            elif pred_refusal == 1:
                bumped = ("false_refusal", "pred_refusal")
            else:
                bumped = ("pred_answer_partial", "pred_refusal_partial", "positive_answer_partial", "false_refusal_partial")
        for key in bumped:
            total_num_instance[key] += 1

    counts = total_num_instance
    n_gt = counts["gt_refusal"] + counts["gt_answer"]
    n_correct = counts["positive_refusal"] + counts["positive_answer"]

    # safe_divide returns 0 whenever a denominator is empty.
    eval_results = {}
    eval_results["refusal"] = safe_divide(counts["pred_refusal"], n_gt)
    eval_results["answer"] = safe_divide(counts["pred_answer"], n_gt)
    eval_results["positive_refusal"] = safe_divide(counts["positive_refusal"], counts["pred_refusal"])
    eval_results["false_refusal"] = safe_divide(counts["false_refusal"], counts["pred_refusal"])
    eval_results["positive_answer"] = safe_divide(counts["positive_answer"], counts["pred_answer"])
    eval_results["false_answer"] = safe_divide(counts["false_answer"], counts["pred_answer"])
    eval_results["precision_all"] = safe_divide(n_correct, n_gt)
    eval_results["recall_all"] = safe_divide(n_correct, counts["pred_refusal"] + counts["pred_answer"])
    eval_results["f1_all"] = safe_divide(2 * eval_results["precision_all"] * eval_results["recall_all"], eval_results["precision_all"] + eval_results["recall_all"])

    # The same ratios for the partial (neither 0 nor 1) predictions.
    eval_results["refusal_partial"] = safe_divide(counts["pred_refusal_partial"], n_gt)
    eval_results["answer_partial"] = safe_divide(counts["pred_answer_partial"], n_gt)
    eval_results["positive_refusal_partial"] = safe_divide(counts["positive_refusal_partial"], counts["pred_refusal_partial"])
    eval_results["false_refusal_partial"] = safe_divide(counts["false_refusal_partial"], counts["pred_refusal_partial"])
    eval_results["positive_answer_partial"] = safe_divide(counts["positive_answer_partial"], counts["pred_answer_partial"])
    eval_results["false_answer_partial"] = safe_divide(counts["false_answer_partial"], counts["pred_answer_partial"])
    eval_results["counts"] = total_num_instance
    print(f"Final acc: {eval_results}")

    with File.open(result_file, "w") as f:
        json.dump(eval_results, f)
    return


def run_lave_metric(model_id, gt_data, pred_data, debug=False, overwrite=False):
    """Run the accuracy pass, the refusal pass, and then both combined summaries.

    All result files are written into the directory of ``pred_data``.
    """
    acc_path = run_lave_metric_acc(model_id, gt_data, pred_data, debug=debug, overwrite=overwrite)

    # No cached ground-truth refusal labels are passed to the refusal pass.
    refusal_path = run_lave_metric_refusal(model_id, gt_data, pred_data, None, debug=debug, overwrite=overwrite)

    model_tag = model_id.replace('/', '_')
    overall_path = os.path.join(os.path.dirname(pred_data), f"{model_tag}_lave_overall_result_w_refusal.json")
    refusal_overall_path = overall_path.replace("overall_result", "overall_refusal_result")

    get_overall_lave_metrics(acc_path, refusal_path, overall_path)
    get_overall_lave_refusal_metrics(acc_path, refusal_path, refusal_overall_path)


def run_lave_metric_vqa(model_id, gt_data, pred_data, debug=False, overwrite=False):
    """Run the accuracy pass, the refusal pass, and then both combined summaries.

    All result files are written into the directory of ``pred_data``.
    """
    acc_file = run_lave_metric_acc(model_id, gt_data, pred_data, debug=debug, overwrite=overwrite)

    # No cached ground-truth refusal labels are passed to the refusal pass.
    refusal_file = run_lave_metric_refusal(model_id, gt_data, pred_data, None, debug=debug, overwrite=overwrite)

    pred_dir = os.path.dirname(pred_data)
    summary_name = f"{model_id.replace('/', '_')}_lave_overall_result_w_refusal.json"
    summary_file = os.path.join(pred_dir, summary_name)

    get_overall_lave_metrics(acc_file, refusal_file, summary_file)
    get_overall_lave_refusal_metrics(
        acc_file,
        refusal_file,
        summary_file.replace("overall_result", "overall_refusal_result"),
    )




def get_overall_lave_metrics(acc_output_file, recall_output_file, overall_results_file):
    """Combine the accuracy and refusal outputs into one overall accuracy summary.

    The two ``.jsonl`` files are read in lockstep and must list the same
    questions in the same order. Records whose ground-truth refusal label or
    accuracy is -1 are counted as "missing" but still count towards "all".

    Args:
        acc_output_file: Per-question output of the accuracy pass.
        recall_output_file: Per-question output of the refusal pass.
        overall_results_file: Path the summary JSON is written to. A path
            containing "vizwiz_val" switches scoring to VQA labels.
    """
    with File.open(acc_output_file, 'r') as f:
        acc_output = [json.loads(line) for line in f]
    with File.open(recall_output_file, 'r') as f:
        recall_output = [json.loads(line) for line in f]
    total_num_instance = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
        "missing": 0
    }
    final_acc = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
    }
    evaluator_acc_on_gt_refusal = 0.
    total_num_to_evaluate = 0

    def label_score(record):
        """Return the VQA-label score for ``record``, or None if no labels apply."""
        if "labels" in record:
            return record["labels"].get(record["answer"].lower(), 0)
        if "vizwiz_val" in overall_results_file:
            # get_vqa_score is not defined in this part of the file;
            # presumably it lives further down the module.
            labels = get_vqa_score(record["answers"])
            return labels.get(record["answer"].lower(), 0)
        return None

    for acc_d, recall_d in tqdm(zip(acc_output, recall_output)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"
        if acc_d["answer"] != recall_d["answer"]:
            # Reported but tolerated: the record is still scored.
            print(f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}")
        total_num_instance["all"] += 1
        gt_refusal = recall_d["gt_refusal"]
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            # Where an annotation exists, measure how often the stored
            # ground-truth refusal label agrees with it, then trust it.
            gt_refusal_labeled = not answerable
            if int(gt_refusal) == int(gt_refusal_labeled):
                evaluator_acc_on_gt_refusal += 1
            total_num_to_evaluate += 1
            gt_refusal = gt_refusal_labeled
        if gt_refusal == -1 or acc_d["acc"] == -1:
            total_num_instance["missing"] += 1
            continue

        vqa_score = label_score(acc_d)
        if gt_refusal == 1:
            total_num_instance["refusal"] += 1
            # A stray, over-indented `score = labels.get(...)` line (leftover
            # of a commented-out vizwiz block) used to sit here and made the
            # module fail to parse. It has been removed.
            refused = recall_d["answer_refusal"] == 1
            if refused:
                final_acc["refusal"] += 1
            score = 1 if refused else 0
            # VQA labels, when available, override the refusal-based score.
            if vqa_score is not None:
                score = vqa_score
            final_acc["all"] += score
        else:
            total_num_instance["answer"] += 1
            if vqa_score is not None:
                score = vqa_score
            elif recall_d["answer_refusal"] == 0:
                score = acc_d["acc"]
            else:
                # Refusing an answerable question earns nothing.
                score = 0
            final_acc["answer"] += score
            final_acc["all"] += score

    eval_results = {}
    eval_results["all"] = safe_divide(final_acc["all"], total_num_instance["all"])
    eval_results["refusal"] = safe_divide(final_acc["refusal"], total_num_instance["refusal"])
    eval_results["answer"] = safe_divide(final_acc["answer"], total_num_instance["answer"])
    eval_results["counts"] = total_num_instance
    eval_results["acc_sum"] = final_acc
    if total_num_to_evaluate > 0:
        eval_results["evaluator_acc_on_gt_refusal"] = safe_divide(evaluator_acc_on_gt_refusal, total_num_to_evaluate)
    print(f"Final acc:\n{json.dumps(eval_results, indent=4)}")
    with File.open(overall_results_file, "w") as f:
        json.dump(eval_results, f)
    return


def get_overall_lave_refusal_metrics(acc_output_file, recall_output_file, overall_refusl_results_file):
    """Count refusal/answer outcomes per question and write the derived rates as JSON.

    The two input files are JSONL and are paired record-by-record in file
    order.  Each pair is bucketed by ground-truth refusal (``gt_refusal``,
    overridden by ``is_question_answerable`` when that returns a non-None
    value) and by predicted refusal (``answer_refusal``).  A predicted refusal
    value that is neither 0 nor 1 is tallied in the ``*_partial`` counters.

    Args:
        acc_output_file: JSONL path; records are read for ``question_id``,
            ``question``, ``gt``, ``answer``, ``acc`` and, when present,
            ``labels`` / ``answers``.
        recall_output_file: JSONL path; records are read for ``question_id``,
            ``question``, ``gt``, ``answer``, ``gt_refusal`` and
            ``answer_refusal``.
        overall_refusl_results_file: Destination JSON path.  When it contains
            ``"vizwiz_val"`` (and the record has no ``labels``), answerable
            questions are scored with ``get_vqa_score``.

    Returns:
        None.  The metrics are printed and written to
        ``overall_refusl_results_file``.

    Raises:
        AssertionError: If a paired record disagrees on ``question_id``,
            ``question`` or ``gt``.
    """
    # Read inside context managers so the input handles are closed promptly
    # instead of being left to the garbage collector.
    with File.open(acc_output_file, 'r') as f:
        acc_output = [json.loads(el) for el in f]
    with File.open(recall_output_file, 'r') as f:
        recall_output = [json.loads(el) for el in f]

    total_num_instance = {
        "gt_refusal": 0,
        "false_refusal": 0,
        "false_answer": 0,
        "positive_answer": 0,
        "positive_refusal": 0,
        "gt_answer": 0,
        "pred_refusal": 0,
        "pred_answer": 0,
        "missing": 0,
        "pred_answer_partial": 0,
        "pred_refusal_partial": 0,
        "positive_answer_partial": 0,
        "positive_refusal_partial": 0,
        "false_answer_partial": 0,
        "false_refusal_partial": 0,
    }

    # NOTE: zip() stops at the shorter input, so surplus records in the longer
    # file are ignored without any message.
    for acc_d, recall_d in tqdm(zip(acc_output, recall_output)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"

        # A differing prediction is only reported; the record is still counted.
        if acc_d["answer"] != recall_d["answer"]:
            print(f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}")

        gt_refusal = recall_d["gt_refusal"]
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            # An explicit answerability label takes precedence over the judged gt_refusal.
            gt_refusal = not answerable
        pred_refusal = recall_d["answer_refusal"]

        # -1 in any of the three judgements marks the record as unusable.
        if gt_refusal == -1 or pred_refusal == -1 or acc_d["acc"] == -1:
            total_num_instance["missing"] += 1
            continue

        if gt_refusal == 1:
            total_num_instance["gt_refusal"] += 1
            if pred_refusal == 1:
                total_num_instance["pred_refusal"] += 1
                total_num_instance["positive_refusal"] += 1
            elif pred_refusal == 0:
                total_num_instance["false_answer"] += 1
                total_num_instance["pred_answer"] += 1
            else:
                total_num_instance["pred_refusal_partial"] += 1
                total_num_instance["pred_answer_partial"] += 1
                total_num_instance["positive_refusal_partial"] += 1
                total_num_instance["false_answer_partial"] += 1
        else:
            # FIXME: any gt_refusal other than 1 (e.g. a "not sure" 0.5) is
            # currently treated the same as answerable.
            total_num_instance["gt_answer"] += 1
            if "labels" in acc_d:
                score = acc_d["labels"].get(acc_d["answer"].lower(), 0)
            elif "vizwiz_val" in overall_refusl_results_file:
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(acc_d["answer"].lower(), 0)
            elif pred_refusal == 0:
                score = acc_d["acc"]
            else:
                score = 0

            if pred_refusal == 0:
                total_num_instance["pred_answer"] += 1
                if score > 0:
                    total_num_instance["positive_answer"] += 1
                else:
                    total_num_instance["false_answer"] += 1
            elif pred_refusal == 1:
                total_num_instance["false_refusal"] += 1
                total_num_instance["pred_refusal"] += 1
            else:
                total_num_instance["pred_answer_partial"] += 1
                total_num_instance["pred_refusal_partial"] += 1
                total_num_instance["positive_answer_partial"] += 1
                total_num_instance["false_refusal_partial"] += 1

    counts = total_num_instance
    num_gt = counts["gt_refusal"] + counts["gt_answer"]
    num_correct = counts["positive_refusal"] + counts["positive_answer"]

    # Rates for definite (0/1) predictions; safe_divide handles zero denominators.
    eval_results = {}
    eval_results["refusal"] = safe_divide(counts["pred_refusal"], num_gt)
    eval_results["answer"] = safe_divide(counts["pred_answer"], num_gt)
    eval_results["positive_refusal"] = safe_divide(counts["positive_refusal"], counts["pred_refusal"])
    eval_results["false_refusal"] = safe_divide(counts["false_refusal"], counts["pred_refusal"])
    eval_results["positive_answer"] = safe_divide(counts["positive_answer"], counts["pred_answer"])
    eval_results["false_answer"] = safe_divide(counts["false_answer"], counts["pred_answer"])
    # NOTE(review): "precision_all" divides by the ground-truth total and
    # "recall_all" by the prediction total, which looks swapped relative to the
    # usual definitions.  Left as-is because consumers may rely on the current
    # values (f1_all is symmetric and unaffected) -- confirm the intent.
    eval_results["precision_all"] = safe_divide(num_correct, num_gt)
    eval_results["recall_all"] = safe_divide(num_correct, counts["pred_refusal"] + counts["pred_answer"])
    eval_results["f1_all"] = safe_divide(2 * eval_results["precision_all"] * eval_results["recall_all"], eval_results["precision_all"] + eval_results["recall_all"])

    # Same rates for the partial (neither 0 nor 1) predictions.
    eval_results["refusal_partial"] = safe_divide(counts["pred_refusal_partial"], num_gt)
    eval_results["answer_partial"] = safe_divide(counts["pred_answer_partial"], num_gt)
    eval_results["positive_refusal_partial"] = safe_divide(counts["positive_refusal_partial"], counts["pred_refusal_partial"])
    eval_results["false_refusal_partial"] = safe_divide(counts["false_refusal_partial"], counts["pred_refusal_partial"])
    eval_results["positive_answer_partial"] = safe_divide(counts["positive_answer_partial"], counts["pred_answer_partial"])
    eval_results["false_answer_partial"] = safe_divide(counts["false_answer_partial"], counts["pred_answer_partial"])
    eval_results["counts"] = total_num_instance
    print(f"Final acc: {eval_results}")

    with File.open(overall_refusl_results_file, "w") as f:
        json.dump(eval_results, f)
    return


def get_vqa_score(answers):
    """Map each distinct answer to ``min(1, occurrences / 3)``.

    Args:
        answers: Iterable of hashable answers (duplicates expected).

    Returns:
        A ``defaultdict(float)`` keyed by answer; answers that never occurred
        read back as ``0.0`` when indexed.
    """
    from collections import defaultdict

    # Tally how often each answer was given.
    occurrences = {}
    for ans in answers:
        occurrences[ans] = occurrences.get(ans, 0) + 1

    # Three or more identical answers saturate the score at 1.
    scores = defaultdict(float)
    scores.update((ans, min(1, n / 3.)) for ans, n in occurrences.items())
    return scores


def get_confidence_weighted_lave_metrics(acc_output_file, recall_output_file, gt_prob_file, conf_weighted_output_file, refusal_reward=False, debug=False):
    """Compute confidence-weighted scores from accuracy, refusal and probability files.

    The three JSONL inputs are paired record-by-record in file order.  For
    each record a ``score`` is derived (refusal correctness for questions
    whose ground truth is a refusal, answer correctness otherwise) and then
    weighted: a positive score earns ``score * yes_prob``, a zero score costs
    ``no_prob``.  With ``refusal_reward=True`` the penalty is applied only
    when ``answer_refusal < 1``.  Coverage and risk are accumulated over
    answerable questions only.

    Args:
        acc_output_file: JSONL path with ``question_id``, ``question``, ``gt``,
            ``answer``, ``acc`` and optionally ``labels`` / ``answers``.
        recall_output_file: JSONL path with ``question_id``, ``question``,
            ``gt``, ``answer``, ``gt_refusal`` and ``answer_refusal``.
        gt_prob_file: JSONL path with ``question_id``, ``question``, ``text``,
            ``yes_prob`` and ``no_prob``.
        conf_weighted_output_file: Destination JSON path.  When it contains
            ``"vizwiz_val"``, scores come from ``get_vqa_score``.
        refusal_reward: Use the variant that does not penalise refusals.
        debug: Print a per-record trace.

    Returns:
        None.  Results are printed and written to ``conf_weighted_output_file``.

    Raises:
        AssertionError: If paired records disagree on ``question_id``,
            ``question`` (acc vs. recall) or ``gt``.
    """
    # Read inside context managers so the input handles are closed promptly.
    with File.open(acc_output_file, 'r') as f:
        acc_output = [json.loads(el) for el in f]
    with File.open(recall_output_file, 'r') as f:
        recall_output = [json.loads(el) for el in f]
    with File.open(gt_prob_file, 'r') as f:
        gt_probs = [json.loads(el) for el in f]

    total_num_instance = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
        "missing": 0,
        "gt_not_yes_or_no": 0,
        "coverage": 0,
        "risk": 0,
    }
    final_acc = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
    }

    # NOTE(review): records whose acc or gt_refusal is -1 are not filtered out
    # in this function -- confirm whether that is intended.
    for acc_d, recall_d, gt_prob_d in tqdm(zip(acc_output, recall_output, gt_probs)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"

        # A differing prediction is only reported; the record is still scored.
        if acc_d["answer"] != recall_d["answer"]:
            print(f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}")

        assert str(acc_d["question_id"]) == str(gt_prob_d["question_id"]), f"Question id mismatch {acc_d['question_id']} vs {gt_prob_d['question_id']}"
        # The probability file stores the question without the image placeholder;
        # a text mismatch is counted as missing rather than treated as fatal.
        if acc_d["question"].replace("<image>\n", "").strip() != gt_prob_d["question"].strip():
            print(f"Question mismatch {acc_d['question']} vs {gt_prob_d['question']}")
            total_num_instance["missing"] += 1
            continue

        total_num_instance["all"] += 1
        gt_refusal = recall_d["gt_refusal"]
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            # An explicit answerability label takes precedence over the judged gt_refusal.
            gt_refusal = not answerable
        pred_refusal = recall_d["answer_refusal"]

        if gt_refusal == 1:
            total_num_instance["refusal"] += 1
            # Credit is given only for an outright refusal.
            score = 1
            if "vizwiz_val" in conf_weighted_output_file:
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get("unanswerable", 0)
            if pred_refusal != 1:
                score = 0
        else:
            total_num_instance["answer"] += 1

            if "labels" in acc_d:
                score = acc_d["labels"].get(acc_d["answer"].lower(), 0)
            elif "vizwiz_val" in conf_weighted_output_file:
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(acc_d["answer"].lower(), 0)
                if pred_refusal == 1:
                    score = labels.get("unanswerable", 0)
            elif pred_refusal == 0:
                score = acc_d["acc"]
            else:
                score = 0

            # Coverage counts attempted answers; risk accumulates their error.
            if pred_refusal == 0:
                total_num_instance["coverage"] += 1
            curr_risk = (1 - score) * int(pred_refusal == 0)
            total_num_instance["risk"] += curr_risk

        if gt_prob_d["text"].lower() not in ["yes", "no"]:
            total_num_instance["gt_not_yes_or_no"] += 1

        if refusal_reward:
            # The "+ 0 * (...)" term is the (currently zero) reward for refusing.
            conf_weighted_score = (score > 0) * score * gt_prob_d["yes_prob"] + 0 * (pred_refusal == 1) - (score == 0 and pred_refusal < 1) * gt_prob_d["no_prob"]
        else:
            conf_weighted_score = (score > 0) * score * gt_prob_d["yes_prob"] - (score == 0) * gt_prob_d["no_prob"]

        if debug:
            print("=============================================================")
            print(f"Question: {acc_d['question']}")
            print(f"Refs: {acc_d['gt']}")
            print(f"Pred: {acc_d['answer']}")
            print(f"GT_refusal: {gt_refusal}")
            print(f"Pred_refusal: {recall_d['answer_refusal']}")
            print(f"Score: {score}")
            print(f"Gt_yes_prob: {gt_prob_d['yes_prob']}")
            print(f"Conf weighted score: {conf_weighted_score}")
            print("=============================================================")

        if gt_refusal == 1:
            final_acc["refusal"] += conf_weighted_score
        else:
            final_acc["answer"] += conf_weighted_score
        final_acc["all"] += conf_weighted_score

    eval_results = {}
    eval_results["all"] = safe_divide(final_acc["all"], total_num_instance["all"])
    eval_results["refusal"] = safe_divide(final_acc["refusal"], total_num_instance["refusal"])
    eval_results["answer"] = safe_divide(final_acc["answer"], total_num_instance["answer"])
    eval_results["coverage"] = safe_divide(total_num_instance["coverage"], total_num_instance["answer"])
    eval_results["risk"] = safe_divide(total_num_instance["risk"], total_num_instance["coverage"])
    eval_results["counts"] = total_num_instance
    eval_results["acc_sum"] = final_acc
    eval_results["gt_not_yes_or_no"] = safe_divide(total_num_instance["gt_not_yes_or_no"], total_num_instance["all"])
    print(f"Final acc:\n{json.dumps(eval_results, indent=4)}")
    with File.open(conf_weighted_output_file, "w") as f:
        json.dump(eval_results, f)
    return


def get_pred_confidence_weighted_lave_metrics(acc_output_file, recall_output_file, pred_prob_file, conf_weighted_output_file, refusal_reward=False, debug=False):
    """Compute scores weighted by the confidence attached to each prediction.

    The three JSONL inputs are paired record-by-record in file order.  For
    each record a ``score`` is derived (refusal correctness for questions
    whose ground truth is a refusal, answer correctness otherwise) and then
    weighted by ``yes_prob`` from ``pred_prob_file``: a positive score earns
    ``score * yes_prob`` and a zero score costs ``yes_prob``.  With
    ``refusal_reward=True`` the penalty is applied only when
    ``answer_refusal < 1``.  Coverage and risk are accumulated over
    answerable questions only.

    Args:
        acc_output_file: JSONL path with ``question_id``, ``question``, ``gt``,
            ``answer``, ``acc`` and optionally ``labels`` / ``answers``.
        recall_output_file: JSONL path with ``question_id``, ``question``,
            ``gt``, ``answer``, ``gt_refusal`` and ``answer_refusal``.
        pred_prob_file: JSONL path with ``question_id``, ``question``,
            ``text`` and ``yes_prob``.
        conf_weighted_output_file: Destination JSON path.  When it contains
            ``"vizwiz_val"`` (and the record has no ``labels``), answerable
            questions are scored with ``get_vqa_score``.
        refusal_reward: Use the variant that does not penalise refusals.
        debug: Print a per-record trace.

    Returns:
        None.  Results are printed and written to ``conf_weighted_output_file``.

    Raises:
        AssertionError: If paired records disagree on ``question_id``,
            ``question`` (acc vs. recall) or ``gt``.
    """
    # Read inside context managers so the input handles are closed promptly.
    with File.open(acc_output_file, 'r') as f:
        acc_output = [json.loads(el) for el in f]
    with File.open(recall_output_file, 'r') as f:
        recall_output = [json.loads(el) for el in f]
    with File.open(pred_prob_file, 'r') as f:
        pred_probs = [json.loads(el) for el in f]

    total_num_instance = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
        "missing": 0,
        "pred_not_yes_or_no": 0,
        "coverage": 0,
        "risk": 0,
    }
    final_acc = {
        "refusal": 0,
        "answer": 0,
        "all": 0,
    }

    # NOTE(review): records whose acc or gt_refusal is -1 are not filtered out
    # in this function -- confirm whether that is intended.
    for acc_d, recall_d, pred_prob_d in tqdm(zip(acc_output, recall_output, pred_probs)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"

        # A differing prediction is only reported; the record is still scored.
        if acc_d["answer"] != recall_d["answer"]:
            print(f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}")

        assert str(acc_d["question_id"]) == str(pred_prob_d["question_id"]), f"Question id mismatch {acc_d['question_id']} vs {pred_prob_d['question_id']}"
        # The probability file stores the question without the image placeholder;
        # a text mismatch is counted as missing rather than treated as fatal.
        if acc_d["question"].replace("<image>\n", "").strip() != pred_prob_d["question"].strip():
            print(f"Question mismatch {acc_d['question']} vs {pred_prob_d['question']}")
            total_num_instance["missing"] += 1
            continue

        total_num_instance["all"] += 1
        gt_refusal = recall_d["gt_refusal"]
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            # An explicit answerability label takes precedence over the judged gt_refusal.
            gt_refusal = not answerable
        pred_refusal = recall_d["answer_refusal"]

        if gt_refusal == 1:
            total_num_instance["refusal"] += 1
            # Credit is given only for an outright refusal.
            score = 1 if pred_refusal == 1 else 0
        else:
            total_num_instance["answer"] += 1

            if "labels" in acc_d:
                score = acc_d["labels"].get(acc_d["answer"].lower(), 0)
            elif "vizwiz_val" in conf_weighted_output_file:
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(acc_d["answer"].lower(), 0)
            elif pred_refusal == 0:
                score = acc_d["acc"]
            else:
                score = 0

            # Coverage counts attempted answers; risk accumulates their error.
            if pred_refusal == 0:
                total_num_instance["coverage"] += 1
            curr_risk = (1 - score) * int(pred_refusal == 0)
            total_num_instance["risk"] += curr_risk

        if pred_prob_d["text"].lower() not in ["yes", "no"]:
            total_num_instance["pred_not_yes_or_no"] += 1

        if refusal_reward:
            # The "+ 0 * (...)" term is the (currently zero) reward for refusing.
            conf_weighted_score = (score > 0) * score * pred_prob_d["yes_prob"] + 0 * (pred_refusal == 1) - (score == 0 and pred_refusal < 1) * pred_prob_d["yes_prob"]
        else:
            conf_weighted_score = (score > 0) * score * pred_prob_d["yes_prob"] - (score == 0) * pred_prob_d["yes_prob"]

        if debug:
            print("=============================================================")
            print(f"Question: {acc_d['question']}")
            print(f"Refs: {acc_d['gt']}")
            print(f"Pred: {acc_d['answer']}")
            print(f"GT_refusal: {gt_refusal}")
            print(f"Pred_refusal: {recall_d['answer_refusal']}")
            print(f"Score: {score}")
            print(f"Pred_yes_prob: {pred_prob_d['yes_prob']}")
            print(f"Conf weighted score: {conf_weighted_score}")
            print("=============================================================")

        if gt_refusal == 1:
            final_acc["refusal"] += conf_weighted_score
        else:
            final_acc["answer"] += conf_weighted_score
        final_acc["all"] += conf_weighted_score

    eval_results = {}
    eval_results["all"] = safe_divide(final_acc["all"], total_num_instance["all"])
    eval_results["refusal"] = safe_divide(final_acc["refusal"], total_num_instance["refusal"])
    eval_results["answer"] = safe_divide(final_acc["answer"], total_num_instance["answer"])
    eval_results["coverage"] = safe_divide(total_num_instance["coverage"], total_num_instance["answer"])
    eval_results["risk"] = safe_divide(total_num_instance["risk"], total_num_instance["coverage"])
    eval_results["counts"] = total_num_instance
    eval_results["acc_sum"] = final_acc
    eval_results["pred_not_yes_or_no"] = safe_divide(total_num_instance["pred_not_yes_or_no"], total_num_instance["all"])
    print(f"Final acc:\n{json.dumps(eval_results, indent=4)}")
    with File.open(conf_weighted_output_file, "w") as f:
        json.dump(eval_results, f)
    return


def run_confidence_weighted_lave_metric(model_id, gt_data, pred_data, gt_prob_file, debug=False, overwrite=False):
    """Run the accuracy and refusal judges, then write every aggregate metric file.

    All result files are written next to ``pred_data`` and are prefixed with
    ``model_id`` (``/`` replaced by ``_``); each name carries a ``_w_refusal``
    suffix.

    Args:
        model_id: Identifier of the judge model; forwarded to the judge runs.
        gt_data: Ground-truth data, forwarded to the judge runs.
        pred_data: Path of the prediction file; must exist.
        gt_prob_file: JSONL file with per-question probabilities used for the
            confidence-weighted metrics.
        debug: Accepted for interface compatibility; the judge runs are always
            invoked with ``debug=False`` and the confidence-weighted metrics
            with ``debug=True``.
        overwrite: Forwarded to the judge runs.

    Raises:
        AssertionError: If ``pred_data`` does not exist.
    """
    assert File.isfile(pred_data), f"Prediction file {pred_data} does not exist. Skipping..."
    acc_output_file = run_lave_metric_acc(model_id, gt_data, pred_data, debug=False, overwrite=overwrite)

    # No earlier refusal output is supplied to the refusal judge.
    recall_output_file = None
    recall01_output_file = run_lave_metric_refusal(model_id, gt_data, pred_data, recall_output_file, debug=False, overwrite=overwrite)

    output_dir = os.path.dirname(pred_data)
    model_tag = model_id.replace('/', '_')

    overall_results_file = os.path.join(output_dir, f"{model_tag}_lave_overall_result_w_refusal.json")
    get_overall_lave_metrics(acc_output_file, recall01_output_file, overall_results_file)
    # Must receive the judge output produced above: recall_output_file is None
    # here and the callee opens its argument for reading.
    get_overall_lave_refusal_metrics(acc_output_file, recall01_output_file, overall_results_file.replace("overall_result", "overall_refusal_result"))

    conf_weighted_results_file = os.path.join(output_dir, f"{model_tag}_lave_conf_weighted_result_w_refusal.json")
    get_confidence_weighted_lave_metrics(acc_output_file, recall01_output_file, gt_prob_file, conf_weighted_results_file, refusal_reward=False, debug=True)

    conf_weighted_results_file = os.path.join(output_dir, f"{model_tag}_lave_conf_weighted_reward_refusal_result_w_refusal.json")
    get_confidence_weighted_lave_metrics(acc_output_file, recall01_output_file, gt_prob_file, conf_weighted_results_file, refusal_reward=True, debug=True)


def calibration_curve_data(correctness, confidence_scores):
    """Split samples into ten equal-width confidence bins covering [0, 1].

    Args:
        correctness: Sequence of per-sample correctness values.
        confidence_scores: Sequence of per-sample confidences, aligned with
            ``correctness``.

    Returns:
        ``(all_bins_correct, all_bins)``: two lists of ten NumPy arrays each,
        holding the correctness values and the confidences that fall in each
        bin.  Bins are ``[0, 0.1), [0.1, 0.2), ..., [0.9, 1.0]``; empty bins
        yield empty arrays.  Confidences outside [0, 1] are not assigned to
        any bin.
    """
    import numpy as np

    correctness = np.array(correctness)
    confidence_scores = np.array(confidence_scores)

    bins = np.linspace(0, 1, 11)  # 11 edges -> 10 bins
    last_bin = len(bins) - 1
    digitized = np.digitize(confidence_scores, bins)
    # np.digitize uses half-open intervals, so a confidence of exactly 1.0 is
    # given the overflow index and would be dropped; fold it into the last bin.
    digitized = np.where(confidence_scores == bins[-1], last_bin, digitized)

    all_bins_correct = []
    all_bins = []
    for i in range(1, len(bins)):
        in_bin = digitized == i
        all_bins_correct.append(correctness[in_bin])
        all_bins.append(confidence_scores[in_bin])

    return all_bins_correct, all_bins


def brier_score(correctness, confidence_scores):
    """Return the mean squared gap between confidence and correctness.

    Args:
        correctness: Sequence of per-sample correctness values.
        confidence_scores: Sequence of per-sample confidences, aligned with
            ``correctness``.

    Returns:
        The mean of ``(confidence - correctness) ** 2`` over all samples.
    """
    import numpy as np

    predicted = np.array(confidence_scores)
    observed = np.array(correctness)
    squared_gap = (predicted - observed) ** 2
    return np.mean(squared_gap)


def expected_calibration_error(all_bins_correct, all_bins):
    """Return the size-weighted average gap between confidence and correctness per bin.

    Args:
        all_bins_correct: List of NumPy arrays, one per bin, holding the
            correctness values of the samples in that bin.
        all_bins: List of NumPy arrays, one per bin, holding the confidences
            of the same samples.

    Returns:
        ``sum(|mean confidence - mean correctness| * bin size) / total size``
        over the bins, or ``0`` when there are no samples at all.  Bins with
        no samples contribute nothing.
    """
    import numpy as np

    weighted_gap_sum = 0
    num_samples = 0
    for bin_correct, bin_confidence in zip(all_bins_correct, all_bins):
        bin_size = len(bin_correct)
        num_samples += bin_size
        if bin_size == 0 or len(bin_confidence) == 0:
            continue
        gap = np.abs(bin_confidence.mean() - bin_correct.mean())
        weighted_gap_sum += gap * bin_size
    return weighted_gap_sum / num_samples if num_samples > 0 else 0


def max_calibration_error(all_bins_correct, all_bins):
    """Return the largest per-bin gap between mean confidence and mean correctness.

    Each bin is summarised by its mean, the same way
    ``expected_calibration_error`` does, and the largest absolute gap is
    returned.  (Previously only the last sample of each bin was compared,
    because the per-sample loop overwrote its result on every iteration.)

    Args:
        all_bins_correct: List of NumPy arrays, one per bin, holding the
            correctness values of the samples in that bin.
        all_bins: List of NumPy arrays, one per bin, holding the confidences
            of the same samples.

    Returns:
        ``max(|mean confidence - mean correctness|)`` over the non-empty bins,
        or ``0`` when every bin is empty.
    """
    import numpy as np

    worst_gap = 0
    for bin_correct, bin_confidence in zip(all_bins_correct, all_bins):
        # Empty bins have no mean and cannot contribute a gap.
        if len(bin_correct) == 0 or len(bin_confidence) == 0:
            continue
        gap = np.abs(bin_confidence.mean() - bin_correct.mean())
        worst_gap = max(worst_gap, gap)
    return worst_gap



def get_calibration_fig(acc_output_file, recall_output_file, pred_prob_file, calibration_results_file):
    # if not (File.isfile(acc_output_file) and File.isfile(recall_output_file)):
    #     print(f"Output files {acc_output_file} or {recall_output_file} do not exist. Skipping...")
    #     return
    # else:
    acc_output = [json.loads(el) for el in File.open(acc_output_file, 'r')]
    recall_output = [json.loads(el) for el in File.open(recall_output_file, 'r')]

    pred_probs = [json.loads(el) for el in File.open(pred_prob_file, 'r')]
    answerable_confidence_scores = []
    answerable_correctness = []
    all_correctness = []
    all_confidence_scores = []
    unk_correctness = []
    unk_confidence_scores = []

    for acc_d, recall_d , pred_prob_d in tqdm(zip(acc_output, recall_output, pred_probs)):
        assert acc_d["question_id"] == recall_d["question_id"], f"Question id mismatch {acc_d['question_id']} vs {recall_d['question_id']}"
        assert acc_d["question"] == recall_d["question"], f"Question mismatch {acc_d['question']} vs {recall_d['question']}"
        assert acc_d["gt"] == recall_d["gt"], f"Answer mismatch {acc_d['gt']} vs {recall_d['gt']}"
        # assert acc_d["answer"] == recall_d["answer"], f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}"
        if acc_d["answer"] != recall_d["answer"]:
            print(f"Prediction mismatch {acc_d['answer']} vs {recall_d['answer']}")
            # total_num_instance["missing"] += 1
            # continue
        gt_refusal = recall_d["gt_refusal"]

        if acc_d["question"].replace("<image>\n", "").strip() != pred_prob_d["question"].strip():
            print(f"Question mismatch {acc_d['question']} vs {pred_prob_d['question']}")
            # total_num_instance["missing"] += 1
            continue
        assert acc_d["question"].replace("<image>\n", "").strip() == pred_prob_d["question"].strip(), f"Question mismatch {acc_d['question']} vs {pred_prob_d['question']}"
        answerable = is_question_answerable(acc_d, assume_answerable=False)
        if answerable is not None:
            # print(f"Answerable: {answerable}")
            gt_refusal_labeled = not answerable
            gt_refusal = gt_refusal_labeled
        if gt_refusal == -1:
            continue
        if acc_d["acc"] == -1:
            continue
        # calculate new refusal metrics based on gt_refusal and answer_refusal in recall_output
        if gt_refusal == 1:
            score = 1
            if "vizwiz_val" in calibration_results_file:
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get("unanswerable", 0)
            if recall_d["answer_refusal"] == 1:
                score = score
            else:
                score = 0
            pred_prob_curr = pred_prob_d["yes_prob"]
            all_confidence_scores.append(pred_prob_curr)
            all_correctness.append(score)
            unk_correctness.append(score)
            unk_confidence_scores.append(pred_prob_curr)
        else:

            if "labels" in acc_d:
                score = acc_d["labels"].get(acc_d["answer"].lower(), 0)
            elif "vizwiz_val" in calibration_results_file:
                labels = get_vqa_score(acc_d["answers"])
                score = labels.get(acc_d["answer"].lower(), 0)
                if recall_d["answer_refusal"] == 1:
                    score = labels.get("unanswerable", 0)
            elif recall_d["answer_refusal"] == 0:
                score = acc_d["acc"]
            else:
                score = 0
            pred_prob_curr = pred_prob_d["yes_prob"]
            answerable_confidence_scores.append(pred_prob_curr)
            answerable_correctness.append(score)
            all_confidence_scores.append(pred_prob_curr)
            all_correctness.append(score)
    bin_means = []
    bin_correct_rate = []
    print(f"Total number of instances: {len(answerable_correctness)}")
    all_bins_correct, all_bins = calibration_curve_data(answerable_correctness, answerable_confidence_scores)
    for curr_bin_correct, curr_bin_confidence in zip(all_bins_correct, all_bins):
        if len(curr_bin_correct) and len(curr_bin_confidence):
            bin_means.append(curr_bin_confidence.mean())
            bin_correct_rate.append(curr_bin_correct.mean())
    output_file = calibration_results_file.replace("calibration_curve.png", "calibration_score.json")
    score = brier_score(answerable_correctness, answerable_confidence_scores)
    with File.open(output_file, "w") as f:
        print(f"Saving to {output_file}")
        json.dump({"brier_score": score}, f)
    output_file = calibration_results_file.replace("calibration_curve.png", "calibration_score_expected_max.json")
    score_ece = expected_calibration_error(all_bins_correct, all_bins)
    score_max = max_calibration_error(all_bins_correct, all_bins)
    print(f"Expected Calibration Error: {score_ece}")
    print(f"Max Calibration Error: {score_max}")
    with File.open(output_file, "w") as f:
        print(f"Saving to {output_file}")
        json.dump({"score_ece": score_ece, "score_max": score_max}, f)
    import matplotlib.pyplot as plt
    # Plotting the calibration curve
    plt.figure(figsize=(8, 6))
    plt.plot(bin_means, bin_correct_rate, "s-", label="Calibration curve")
    plt.plot([0, 1], [0, 1], "k--", label="Perfectly calibrated")
    plt.xlabel("Mean confidence score")
    plt.ylabel("Mean correctness")
    plt.legend()
    plt.savefig("./calibration_curve.png", format="png")
    plt.close()
    with File.open(calibration_results_file, "wb") as f:
        content = File.open("./calibration_curve.png", "rb").read()
        f.write(content)
    # all correctness    
    bin_means = []
    bin_correct_rate = []
    print(f"Total number of instances: {len(all_correctness)}")
    all_bins_correct, all_bins = calibration_curve_data(all_correctness, all_confidence_scores)
    for curr_bin_correct, curr_bin_confidence in zip(all_bins_correct, all_bins):
        if len(curr_bin_correct) and len(curr_bin_confidence):
            bin_means.append(curr_bin_confidence.mean())
            bin_correct_rate.append(curr_bin_correct.mean())
    output_file = calibration_results_file.replace("answerable_calibration_curve.png", "all_calibration_score.json")
    score = brier_score(all_correctness, all_confidence_scores)
    with File.open(output_file, "w") as f:
        print(f"Saving to {output_file}")
        json.dump({"brier_score": score}, f)
    output_file = calibration_results_file.replace("answerable_calibration_curve.png", "all_calibration_score_expected_max.json")
    score_ece = expected_calibration_error(all_bins_correct, all_bins)
    score_max = max_calibration_error(all_bins_correct, all_bins)
    print(f"Expected Calibration Error: {score_ece}")
    print(f"Max Calibration Error: {score_max}")
    with File.open(output_file, "w") as f:
        print(f"Saving to {output_file}")
        json.dump({"score_ece": score_ece, "score_max": score_max}, f)
    import matplotlib.pyplot as plt
    # Plotting the calibration curve
    plt.figure(figsize=(8, 6))
    plt.plot(bin_means, bin_correct_rate, "s-", label="Calibration curve")
    plt.plot([0, 1], [0, 1], "k--", label="Perfectly calibrated")
    plt.xlabel("Mean confidence score")
    plt.ylabel("Mean correctness")
    plt.legend()
    plt.savefig("./all_calibration_curve.png", format="png")
    output_file = calibration_results_file.replace("answerable_calibration_curve.png", "all_calibration_curve.png")
    with File.open(output_file, "wb") as f:
        content = File.open("./all_calibration_curve.png", "rb").read()
        f.write(content)
    
    # unk correctness    
    bin_means = []
    bin_correct_rate = []
    print(f"Total number of instances (unanswerable): {len(unk_correctness)}")
    bins_correct, bins = calibration_curve_data(unk_correctness, unk_confidence_scores)
    for curr_bin_correct, curr_bin_confidence in zip(bins_correct, bins):
        if len(curr_bin_correct) and len(curr_bin_confidence):
            bin_means.append(curr_bin_confidence.mean())
            bin_correct_rate.append(curr_bin_correct.mean())
    output_file = calibration_results_file.replace("answerable_calibration_curve.png", "unk_calibration_score.json")
    score = brier_score(unk_correctness, unk_confidence_scores)
    with File.open(output_file, "w") as f:
        print(f"Saving to {output_file}")
        json.dump({"brier_score": score}, f)
    output_file = calibration_results_file.replace("answerable_calibration_curve.png", "unk_calibration_score_expected_max.json")
    score_ece = expected_calibration_error(bins_correct, bins)
    score_max = max_calibration_error(bins_correct, bins)
    print(f"Expected Calibration Error: {score_ece}")
    print(f"Max Calibration Error: {score_max}")
    with File.open(output_file, "w") as f:
        print(f"Saving to {output_file}")
        json.dump({"score_ece": score_ece, "score_max": score_max}, f)
    import matplotlib.pyplot as plt
    # Plotting the calibration curve
    plt.figure(figsize=(8, 6))
    plt.plot(bin_means, bin_correct_rate, "s-", label="Calibration curve")
    plt.plot([0, 1], [0, 1], "k--", label="Perfectly calibrated")
    plt.xlabel("Mean confidence score")
    plt.ylabel("Mean correctness")
    plt.legend()
    plt.savefig("./unk_calibration_curve.png", format="png")
    output_file = calibration_results_file.replace("answerable_calibration_curve.png", "unk_calibration_curve.png")
    with File.open(output_file, "wb") as f:
        content = File.open("./unk_calibration_curve.png", "rb").read()
        f.write(content)


def plot_calibration_curve(pred_prob_file, gt_data, pred_data, model_id):
    """Run the LAVE evaluation steps for one prediction file, ending with the calibration figure.

    Steps, in order:
      1. ``run_lave_metric_acc`` and ``run_lave_metric_refusal`` are run on
         ``pred_data``; their return values are forwarded to every later step
         (presumably paths of the files they wrote -- confirm against those
         functions).
      2. ``get_overall_lave_metrics`` and ``get_overall_lave_refusal_metrics``.
      3. ``get_confidence_weighted_lave_metrics`` (``refusal_reward=False``,
         ``debug=True``).
      4. ``get_calibration_fig``.

    All result files are placed in the directory of ``pred_data`` and are
    prefixed with ``model_id`` where every ``/`` is replaced by ``_``.

    Args:
        pred_prob_file: passed through to the confidence-weighted and
            calibration steps.
        gt_data: passed through to the two LAVE metric runs.
        pred_data: path of the prediction file; must exist according to
            ``File.isfile``.
        model_id: identifier of the judge model, also used to name outputs.

    Returns:
        None.

    Raises:
        AssertionError: if ``pred_data`` does not exist (despite the
            "Skipping..." wording of the message, execution stops here).
    """
    assert File.isfile(pred_data), f"Prediction file {pred_data} does not exist. Skipping..."

    def result_path(suffix):
        # <dir of pred_data>/<model id with "/" -> "_"><suffix>
        return os.path.join(os.path.dirname(pred_data), f"{model_id.replace('/', '_')}{suffix}")

    accuracy_result = run_lave_metric_acc(model_id, gt_data, pred_data, debug=False, overwrite=False)
    # The fourth positional argument (the recall output file) is deliberately None.
    refusal_result = run_lave_metric_refusal(
        model_id, gt_data, pred_data, None, debug=False, overwrite=False
    )

    overall_path = result_path("_lave_overall_result_w_refusal.json")
    get_overall_lave_metrics(accuracy_result, refusal_result, overall_path)
    # The refusal variant reuses the overall path with the marker swapped.
    overall_refusal_path = overall_path.replace("overall_result", "overall_refusal_result")
    get_overall_lave_refusal_metrics(accuracy_result, refusal_result, overall_refusal_path)

    weighted_path = result_path("_pred_probyn_w_refusal_lave_conf_weighted_result.json")
    get_confidence_weighted_lave_metrics(
        accuracy_result,
        refusal_result,
        pred_prob_file,
        weighted_path,
        refusal_reward=False,
        debug=True,
    )

    # NOTE(review): the code preceding this function derives further output
    # names by replacing "answerable_calibration_curve.png" in this path, but
    # the name built here contains "answerable_w_refusal_calibration_curve.png";
    # confirm those replacements actually match.
    figure_path = result_path("_lave_answerable_w_refusal_calibration_curve.png")
    get_calibration_fig(accuracy_result, refusal_result, pred_prob_file, figure_path)


def main():
    """Command-line entry point: hand control to python-fire.

    ``Fire()`` is called without a component; per python-fire's documentation
    this exposes the calling context (this module's globals plus this
    function's locals) as CLI commands, e.g. ``plot_calibration_curve``.
    """
    # fire is imported here rather than at module level, so importing this
    # module does not by itself import fire.
    from fire import Fire
    Fire()


# Start the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
