import os
import time
import json
import torch
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter
from typing import List, Dict
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from utils import read_json, write_json, get_keywords, get_alphabet_choice, remove_boxed, last_boxed_only_string, is_math_equiv
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
# from agent import *
import contextlib
import gc
import os
import subprocess
import time
import re
from math import prod
from sklearn.linear_model import LinearRegression, LogisticRegression
# from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
# import ray, gc, torch

# for sample, curr_answers in zip(train_samples, all_answers):
            # agg_prompt = (f"You have been provided with a set of responses from various open-source models to the latest user query. "
            #               f"Your task is to synthesize these responses into a single, high-quality response. "
            #               f"It is crucial to critically evaluate the information provided in these responses, "
            #               f"recognizing that some of it may be biased or incorrect. "
            #               f"Your response should not simply replicate the given answers but should offer a refined, "
            #               f"accurate, and comprehensive reply to the instruction. "
            #               f"Ensure your response is well-structured, coherent, and adheres" 
            #               f"to the highest standards of accuracy and reliability. "
            #               f"Responses from models:\n\n")
            # for idx, ans in enumerate(curr_answers[:-2]):
            #     agg_prompt += f"### Model {idx+1}'s response:\n{ans}\n\n"

            # agg_prompt = (f"Question: {sample['question']}\n"
            #                f"Provide your step-by-step reasoning first, and then print \"The answer is \\boxed{{X}}\", "
            #                f"where X is the final answer, at the end of your response.\n\n"
            #             )

            # train_prompts_reasoning.append(agg_prompt)
                
        # train_prompts_reasoning = [
        #     f"Provide your step-by-step reasoning first, and then print \"The answer is \\boxed{{X}}\", "
        #     f"where X is the final answer, at the end of your response.\n"
        #     # f"After providing your answer, you MUST provide a confidence score on a scale from 0 to 10, where:\n"
        #     # f"   - 0 means you are completely unsure\n"
        #     # f"   - 10 means you are absolutely certain\n"
        #     # f"   - 5 means you are somewhat confident\n"
        #     # f"Format your confidence score EXACTLY as: **Confidence Score: [your score]**\n"
        #     # f"Do not deviate from this format. The confidence score must be a number between 0 and 10.\n"
        #     # f"Example: **Confidence Score: 10**\n"
        #     # f"Example: **Confidence Score: 5**\n"
        #     # f"Example: **Confidence Score: 0**\n"
        #     # f"Do not deviate from this format. The confidence score must be a number between 0 and 10.\n"
        #     f"Question: {sample['question']}\n\n"
        #     for sample in train_samples
        # ]

# Generation settings. MAX_TOKENS, N_SAMPLES and TEMPERATURE are forwarded to
# vllm.SamplingParams inside get_model_responses().
MAX_TOKENS = 32768  # SamplingParams(max_tokens=...)
N_SAMPLES = 1  # SamplingParams(n=...); only outputs[0] of each request is read
TEMPERATURE = 0.7  # SamplingParams(temperature=...)
THRESHOLD = 4000  # NOTE(review): not referenced in the visible code
TOP_P = 1.0  # only referenced by a commented-out SamplingParams call
EPSILON = 0.01  # NOTE(review): not referenced in the visible code

# Short agent name -> local directory holding the model snapshot. The main
# script looks the chosen agent up here and passes the path to both
# AutoTokenizer.from_pretrained and vllm.LLM. The trailing comment on each
# entry records the hub repository id the snapshot presumably came from.
agent_map = {
    "Llama": "/datasets/ai/llama3/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/", # meta-llama/Meta-Llama-3.1-8B-Instruct
    "Qwen": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28/", # Qwen/Qwen2.5-7B-Instruct
    "Mistral": "/datasets/ai/mixtral/hub/models--mistralai--Mistral-Nemo-Instruct-2407/snapshots/8aedd450f2583e9c67fae1929f6936b8fc5aef9c/", # mistralai/Mistral-Nemo-Instruct-2407
    "Phi": "/datasets/ai/phi/hub/models--microsoft--Phi-3.5-mini-instruct/snapshots/3145e03a9fd4cdd7cd953c34d9bbf7ad606122ca/", # microsoft/Phi-3.5-mini-instruct
    "Gemma": "/datasets/ai/gemma/hub/models--google--gemma-2-9b-it/snapshots/1937c70277fcc5f7fb0fc772fc5bc69378996e71/", # google/gemma-2-9b-it
    "GLM": "/datasets/ai/glm/hub/models--THUDM--glm-4-9b-chat/snapshots/bd8234fe5e0c09c48637a92abb0c797cb5fa0e73/", # THUDM/glm-4-9b-chat
    "Exaone": "/datasets/ai/lg/hub/models--LGAI-EXAONE--EXAONE-3.5-7.8B-Instruct/snapshots/0ff6b5ec7c13b049b253a16a889aa269e6b79a94/", # LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct
    "Granite": "/datasets/ai/ibm-granite/hub/models--ibm-granite--granite-3.1-8b-instruct/snapshots/3f05a1d007b2484bbf17593efe110bd5b9d67655/", # ibm-granite/granite-3.1-8b-instruct
    "QwenMath": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d/", # Qwen/Qwen2.5-Math-7B-Instruct (path points at the -Instruct variant)
    "QwenCode": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Coder-7B-Instruct/snapshots/c03e6d358207e414f1eca0bb1891e29f1db0e242/", # Qwen/Qwen2.5-Coder-7B-Instruct
    "DeepSeekMath": "/datasets/ai/deepseek/hub/models--deepseek-ai--deepseek-math-7b-instruct/snapshots/0a5828f800a36df0fd7f0ed581b983246c0677ff/", # deepseek-ai/deepseek-math-7b-instruct
    "QwenR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/6602cadec947dbb53e64f3d8d6425320b2197247/", # deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
    "LlamaR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Llama-8B/snapshots/74fbf131a939963dd1e244389bb61ad0d0440a4d/", # deepseek-ai/DeepSeek-R1-Distill-Llama-8B
    "InternLM": "/datasets/ai/internlm/hub/models--internlm--internlm3-8b-instruct/snapshots/28c99415adaf61767bd1c619f4f99f308fdfd223/", # internlm/internlm3-8b-instruct
    "Mathstral": "/datasets/ai/mixtral/hub/models--mistralai--Mathstral-7B-v0.1/snapshots/b6408c37979d6805935973ab06468089ff72ce95/", # mistralai/Mathstral-7B-v0.1
    "BioLlama": "/datasets/ai/contactdoctor/hub/models--ContactDoctor--Bio-Medical-Llama-3-8B/snapshots/b42b41f30767e43b6a636490bde14d82e5bad0c1/", # ContactDoctor/Bio-Medical-Llama-3-8B
    "Qwen72B": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-72B-Instruct/snapshots/495f39366efef23836d0cfae4fbe635880d2be31/", # Qwen/Qwen2.5-72B-Instruct
    "Llama70B": "/datasets/ai/llama3/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b/", # meta-llama/Llama-3.3-70B-Instruct
    "NLlama": "/datasets/ai/nvidia/hub/models--nvidia--Llama-3.1-Nemotron-Nano-8B-v1/snapshots/a22e1c57330633cd3522903f9bb82480bf3192a6/", # nvidia/Llama-3.1-Nemotron-Nano-8B-v1
    "tinyQwen": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-1.5B-Instruct/snapshots/989aa7980e4cf806f80c7fef2b1adb7bc71aa306/", # Qwen/Qwen2.5-1.5B-Instruct
    "tinyQwenR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-1.5B/snapshots/530ca3e1ad39d440e182c2e4317aa40f012512fa/" # deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
}

# Agent names (keys of agent_map).
# NOTE(review): MODELS is not read anywhere in the visible code; the models
# actually used for a task come from task_to_weights.
MODELS = [
    "Llama", #
    "Qwen", #
    "Gemma", #
]

# Per-task mapping of agent name -> weight. The main script takes the keys for
# the selected task and sorts them by weight, highest first, to decide the
# order in which models are used. The commented-out lines are alternative
# configurations kept for reference.
# NOTE(review): the weights look like per-model accuracies on each task —
# confirm before relying on that interpretation.
task_to_weights = {
    # "MATH500": {"Llama": 0.5464, "Qwen": 0.7528, "Gemma": 0.5179},
    "MATH500": {"tinyQwen": 0.54},
    # "MATH500": {"Qwen": 0.7528},
    # "MATH500": {"Qwen": 0.7528},
    # "AIME24": {"Llama": 0.065, "Qwen": 0.1243, "Gemma": 0.0141},
    # "AIME24": {"NLlama": 0.6921, "QwenR1": 0.5347},
    # "AIME24": {"NLlama": 0.5921, "QwenR1": 0.4347},
    # "AIME24": {"QwenR1": 0.5347},
    "AIME24": {"tinyQwenR1": 0.33},
    "AIME25": {"Llama": 0.065, "Qwen": 0.1243, "Gemma": 0.0141},
    # "GSM8K": {"Llama": 0.8723, "Qwen": 0.9445, "Gemma": 0.9133},
    # "GSM8K": {"Qwen": 0.9445},
    "GSM8K": {"tinyQwen": 0.81},
    # "MMLU_Pro": {"Llama": 0.3457, "Qwen": 0.5343, "Gemma": 0.5314},
    # "MMLU_Pro": {"Qwen": 0.5343, "Gemma": 0.5314},
    # "MMLU_Pro": {"Qwen": 0.5343},
    # "MMLU_Pro": {"LlamaR1": 0.5343},
    "MMLU_Pro": {"tinyQwen": 0.34},
    # "GPQA": {"QwenR1": 0.5343},
    "GPQA": {"tinyQwenR1": 0.31},
}

# Lookup table consumed by predict_confidence(): task -> agent name -> stats.
#   "thresholds": three ascending token-count cut points (labelled Q1, Q2, Q3)
#   "scores":     four values, one per bucket delimited by those cut points;
#                 predict_confidence() returns scores[0] for counts <= Q1,
#                 scores[1] for counts <= Q2, scores[2] for counts <= Q3 and
#                 scores[3] otherwise.
# Where a commented-out "scores" row is present, the active row equals it
# minus 0.1 per entry, with one exception noted below.
confidence_stats = {
    "AIME24": {
        "QwenR1": {
            "thresholds": [4780.0, 8571.5, 17034.75],   # Q1, Q2, Q3
            # "scores":     [0.929266, 0.768927, 0.442034, 0.216836],  # Q1..Q4
            "scores":     [0.829266, 0.668927, 0.342034, 0.116836],  # Q1..Q4
        },
        "NLlama": {
            "thresholds": [4529.75, 7919.0, 13886.25],   # Q1, Q2, Q3
            # "scores":     [0.957966, 0.899322, 0.709831, 0.332316],  # Q1..Q4
            "scores":     [0.857966, 0.799322, 0.609831, 0.232316],  # Q1..Q4
        },
    },
    "MATH500": {
        "Gemma": {
            "thresholds": [222.0, 319.0, 452.0],   # Q1, Q2, Q3
            # "scores":     [0.825434, 0.646975, 0.417618, 0.200922],  # Q1..Q4
            # NOTE(review): the active row below is identical to
            # AIME24/NLlama's scores and does not follow the "commented row
            # minus 0.1" pattern of the other entries — possibly a copy-paste
            # slip; confirm the intended values.
            "scores":     [0.857966, 0.799322, 0.609831, 0.232316],  # Q1..Q4
        },
        "Qwen": {
            "thresholds": [342.0, 486.0, 721.0],   # Q1, Q2, Q3
            # "scores":     [0.963733, 0.889632, 0.751458, 0.488898],  # Q1..Q4
            "scores":     [0.863733, 0.789632, 0.651458, 0.388898],  # Q1..Q4
        },
        "Llama": {
            "thresholds": [259.0, 385.0, 647.0],   # Q1, Q2, Q3
            # "scores":     [0.897843, 0.793198, 0.592455, 0.300694],  # Q1..Q4
            "scores":     [0.797843, 0.693198, 0.492455, 0.200694],  # Q1..Q4
        },
    },
    "GSM8K": {
        "Qwen": {
            "thresholds": [205.0, 247.0, 300.0],   # Q1, Q2, Q3
            "scores":     [0.984622, 0.970318, 0.954515, 0.891398],  # Q1..Q4
        },
        "Gemma": {
            "thresholds": [117.0, 151.0, 190.0],   # Q1, Q2, Q3
            "scores":     [0.969756, 0.938133, 0.912407, 0.840706],  # Q1..Q4
        },
    },
    "MMLU_Pro": {
        "Qwen": {
            "thresholds": [362.0, 461.0, 586.0],   # Q1, Q2, Q3
            "scores":     [0.694305, 0.584615, 0.512616, 0.421854],  # Q1..Q4
        },
        "Gemma": {
            "thresholds": [220.0, 302.0, 398.0],   # Q1, Q2, Q3
            "scores":     [0.696693, 0.605562, 0.471659, 0.386842],  # Q1..Q4
        },
    },
    
}

# Export OMP_NUM_THREADS=20 into this process's environment at import time.
# NOTE(review): numpy/torch are imported above this line, so libraries that
# read the variable at load time may not pick it up — confirm the ordering.
os.environ["OMP_NUM_THREADS"] = "20"

def parse_args():
    """Parse the command-line options of the early-stop evaluation script.

    Returns:
        argparse.Namespace with the attributes ``task``, ``agent``, ``gpus``,
        ``seed``, ``budget``, ``threshold``, ``n_sampling`` and ``n_run``.
        ``task``, ``agent`` and ``gpus`` default to ``None`` when omitted.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description="Early-stop sampling evaluation with a length-based confidence model."
    )
    parser.add_argument(
        '--task',
        type=str,
        help="Task name; selects the dataset file and the entry of task_to_weights."
    )
    # The main script reads args.agent to choose which checkpoint to load, so
    # the option has to exist; it was previously missing from the parser.
    parser.add_argument(
        '--agent',
        type=str,
        default=None,
        help="Key of agent_map naming the model to load."
    )
    parser.add_argument(
        '--gpus',
        type=int,
        help="Number of GPUs, forwarded as the tensor-parallel size of the LLM."
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help="Seed for the random number generators and for sampling."
    )
    parser.add_argument(
        '--budget',
        type=int,
        default=4,
        help="Maximum number of sampling rounds per question."
    )
    parser.add_argument(
        '--threshold',
        type=float,
        default=0.98,
        help="A question is resampled while its best normalized score is below this value."
    )
    parser.add_argument(
        '--n_sampling',
        type=int,
        default=2,
        help="Number of rounds spent on one model before moving to the next."
    )
    parser.add_argument(
        '--n_run',
        type=int,
        default=1,
        help="Run counter (not read by the visible script code)."
    )
    return parser.parse_args()

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator; also exported (as a
            string) through the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # The generators are independent of each other, so one pass over the
    # seeding functions is enough.
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# Accepts "**Confidence Score: X**" where X is 0-9 (optionally with decimals)
# or 10 (optionally followed by ".0", ".00", ...).
_CONFIDENCE_SCORE_RE = re.compile(
    r"\*\*Confidence Score:\s*((?:10(?:\.0+)?|[0-9](?:\.\d+)?))\*\*"
)


def extract_confidence_score(response: str) -> float:
    """Return the self-reported confidence found in an LLM response.

    The first occurrence of ``**Confidence Score: X**`` is used, where X is a
    number from 0 to 10 written as an integer or a decimal.

    Args:
        response: Raw response text to scan.

    Returns:
        The score X as a float.

    Raises:
        ValueError: if the response contains no correctly formatted score.
    """
    found = _CONFIDENCE_SCORE_RE.search(response)
    if found is None:
        raise ValueError("No valid confidence score found in response.")
    return float(found.group(1))

def compute_answer_scores(answers, confidences):
    """Score each distinct free-form answer and normalize the scores.

    Answers are first merged into groups of mutually equivalent answers,
    using ``is_math_equiv`` against the first answer seen in each group (the
    group's representative). With ``U`` the number of groups:

      score(group) = (∏ c_i over answers inside the group)
                   * (∏ ((1 - c_j) / U) over answers outside the group)
      score(None)  = ∏ ((1 - c_i) / U) over all answers

    The scores are then divided by their sum so they add up to 1. If the sum
    is not positive, every key receives the same share instead.

    Args:
        answers: Observed answers, one per response (must be hashable).
        confidences: Confidence attached to each answer, same length.

    Returns:
        Dict mapping each group representative, plus ``None``, to its
        normalized score. If ``answers`` itself contains ``None`` as a
        representative, its entry is replaced by the special ``None`` score.

    Raises:
        ValueError: if ``answers`` and ``confidences`` differ in length.
    """
    if len(answers) != len(confidences):
        raise ValueError("answers and confidences must be the same length")

    # Each group is (representative answer, confidences of its members).
    equiv_groups = []
    for ans, conf in zip(answers, confidences):
        for rep, confs in equiv_groups:
            if is_math_equiv(rep, ans):
                confs.append(conf)
                break
        else:
            equiv_groups.append((ans, [conf]))

    # Count groups rather than raw strings so that U agrees with the grouping
    # used for scoring (equivalent spellings count as one answer).
    U = len(equiv_groups)
    scores = {}

    for group_idx, (rep, confs) in enumerate(equiv_groups):
        # Confidences of every answer that belongs to a different group,
        # selected by group position rather than by comparing representatives.
        others = [
            c
            for other_idx, (_, other_confs) in enumerate(equiv_groups)
            if other_idx != group_idx
            for c in other_confs
        ]
        prod_penalty = prod((1 - c) / U for c in others) if others else 1.0
        scores[rep] = prod(confs) * prod_penalty

    # Special None case: every response is treated as wrong.
    scores[None] = prod((1 - c) / U for c in confidences) if confidences else 1.0

    total = sum(scores.values())
    if total > 0:
        for key in scores:
            scores[key] /= total
    else:
        # Degenerate case (all scores zero): fall back to a uniform split.
        uniform = 1.0 / len(scores)
        for key in scores:
            scores[key] = uniform

    return scores


def compute_mc_answer_scores(
    choices,
    answers,
    confidences
):
    """Compute normalized scores for a fixed set of multiple-choice options.

    With ``U = len(choices)``, for each choice x in ``choices``:
      score(x) = (∏ c_i for i where answers[i] matches x)
               * (∏ ((1 - c_j) / U) for j where answers[j] does not match x)
    where "matches" is decided by ``is_math_equiv(answer, choice)``. The
    scores are then divided by their sum so they add up to 1; if the sum is
    not positive, every choice receives the same share instead.

    Unlike ``compute_answer_scores``, no separate ``None`` entry is produced:
    the result has exactly one key per element of ``choices``.

    Args:
        choices: List of all possible answer options.
        answers: List of observed answers.
        confidences: Confidence score attached to each answer, same length.

    Returns:
        Dict mapping each choice to its normalized score; empty when
        ``choices`` is empty.

    Raises:
        ValueError: if ``answers`` and ``confidences`` differ in length.
    """
    if len(answers) != len(confidences):
        raise ValueError("answers and confidences must be the same length")

    U = len(choices)  # number of options
    if U == 0:
        # Nothing to score; avoids dividing by zero during normalization.
        return {}

    scores = {}
    for choice in choices:
        prod_correct = 1.0
        prod_penalty = 1.0
        # Evaluate the equivalence once per answer and route its confidence
        # to the supporting or to the penalizing product.
        for ans, conf in zip(answers, confidences):
            if is_math_equiv(ans, choice):
                prod_correct *= conf
            else:
                prod_penalty *= (1 - conf) / U
        scores[choice] = prod_correct * prod_penalty

    total = sum(scores.values())
    if total > 0:
        for key in scores:
            scores[key] /= total
    else:
        # Degenerate case (all scores zero): fall back to a uniform split.
        uniform = 1.0 / len(scores)
        for key in scores:
            scores[key] = uniform

    return scores



def predict_confidence(task: str, model: str, num_tokens: float) -> float:
    """Look up the precomputed confidence for a response of a given length.

    The ``confidence_stats`` table stores, per task and model, three
    token-count cut points and four scores. The score of the first bucket
    whose upper cut point is at least ``num_tokens`` is returned; counts above
    the last cut point get the fourth score.

    Raises:
        KeyError: if the table has no entry for this task/model pair.
    """
    per_task = confidence_stats.get(task, {})
    entry = per_task.get(model)
    if entry is None:
        raise KeyError(f"No confidence stats found for task={task!r}, model={model!r}")

    q1, q2, q3 = entry["thresholds"]
    s1, s2, s3, s4 = entry["scores"]

    # Walk the cut points in order and stop at the first bucket that fits.
    for upper_bound, bucket_score in zip((q1, q2, q3), (s1, s2, s3)):
        if num_tokens <= upper_bound:
            return bucket_score
    return s4
    
def fit_confidence_model(data):
    """Fit a linear regression from response length to confidence score.

    Args:
        data: pandas DataFrame with a ``num_tokens`` column (the feature) and
            a ``confidence_score`` column (the target).

    Returns:
        The fitted ``LinearRegression`` estimator. It additionally carries a
        ``bounded_predict(x)`` attribute that accepts a scalar or a sequence
        of token counts and returns the predictions as a NumPy array clipped
        to the range [0, 1].
    """
    # sklearn expects a 2-D feature matrix, hence the single-column reshape.
    features = data['num_tokens'].values.reshape(-1, 1)
    targets = data['confidence_score'].values

    model = LinearRegression().fit(features, targets)

    def _clipped_predict(x):
        column = np.array(x).reshape(-1, 1)
        return np.clip(model.predict(column), 0.0, 1.0)

    model.bounded_predict = _clipped_predict
    return model


def get_model_responses(task, tokenizer, llm, indices, seed, is_reasoning=False):
    """Query the model once for each selected test question.

    Only the prompt variant that is actually requested is built and passed
    through the chat template.

    Args:
        task: Task name; questions are read from
            ``./Datasets/test/{task}_test.json``. MATH500, AIME24, AIME25 and
            GSM8K ask for a ``\\boxed{X}`` answer, every other task for a
            single-letter ``(X)`` answer.
        tokenizer: Object providing ``apply_chat_template(messages,
            tokenize=False, add_generation_prompt=True)``.
        llm: Object providing ``generate(prompts, sampling_params)``.
        indices: Positions, within the test file, of the questions to ask.
        seed: Forwarded to ``SamplingParams(seed=...)``.
        is_reasoning: If True, ask for step-by-step reasoning before the
            answer; otherwise ask for the answer directly.

    Returns:
        Tuple ``(responses, n_tokens)``: the text of the first output of each
        request and the number of token ids in that output, in the order of
        ``indices``.
    """
    samples = read_json(f"./Datasets/test/{task}_test.json")
    # samples = read_json(f"./Datasets/train/{task}_train.json")
    samples = [samples[i] for i in indices]

    if is_reasoning:
        lead = "Provide your step-by-step reasoning first, and then print "
    else:
        lead = "Directly provide the correct answer choice without explanation, formatted as: "

    if task in ["MATH500", "AIME24", "AIME25", "GSM8K"]:
        answer_format = "The answer is \\boxed{X}"
        placeholder = "the final answer"
    else:
        answer_format = "The answer is (X)"
        placeholder = "the answer choice (one capital letter)"

    chat_prompts = []
    for sample in samples:
        content = (
            f"Question: {sample['question']}\n"
            f"{lead}\"{answer_format}\", "
            f"where X is {placeholder}, at the end of your response."
        )
        chat_prompts.append(
            tokenizer.apply_chat_template(
                [{"role": "user", "content": content}],
                tokenize=False,
                add_generation_prompt=True
            )
        )

    # sampling_params = SamplingParams(temperature=TEMPERATURE, max_tokens=MAX_TOKENS, n=N_SAMPLES, top_p=TOP_P)
    sampling_params = SamplingParams(temperature=TEMPERATURE, max_tokens=MAX_TOKENS, n=N_SAMPLES, seed=seed)
    outputs = llm.generate(chat_prompts, sampling_params)

    # Only the first completion of each request is used.
    responses = [output.outputs[0].text for output in outputs]
    n_tokens = [len(output.outputs[0].token_ids) for output in outputs]
    return responses, n_tokens

# Tasks whose final answer is read from a \boxed{...} expression; every other
# task is treated as multiple choice.
_BOXED_ANSWER_TASKS = ("MATH500", "AIME24", "AIME25", "GSM8K")


def _extract_final_answers(task, responses, num_choice):
    """Extract the final answer from each raw response of the given task."""
    if task in _BOXED_ANSWER_TASKS:
        return [remove_boxed(last_boxed_only_string(response)) for response in responses]
    return [get_alphabet_choice(response, num_choice=num_choice) for response in responses]


def _score_answers(task, choices, answers, confs):
    """Score one question's collected answers with the scorer for its task."""
    if task in _BOXED_ANSWER_TASKS:
        return compute_answer_scores(answers, confs)
    return compute_mc_answer_scores(choices, answers, confs)


def main():
    """Run early-stop sampling on a task's test set and save the results.

    Every question is sampled once; questions whose best normalized score is
    still below ``--threshold`` are sampled again, for at most ``--budget``
    rounds in total. Confidences come from a linear regression on the number
    of generated tokens. The per-question predictions are written to a CSV
    file and summary statistics are printed.
    """
    args = parse_args()
    seed_everything(args.seed)

    if args.task in _BOXED_ANSWER_TASKS:
        num_choice = 2
    elif args.task == "MMLU_Pro":
        num_choice = 10
    else:
        num_choice = 4
    choices = [chr(65 + i) for i in range(num_choice)]
    test_samples = read_json(f"./Datasets/test/{args.task}_test.json")
    # test_samples = read_json(f"./Datasets/train/{args.task}_train.json")

    conf_data = pd.read_json(f"./confidence_score_training/data/combined_confidence_score_data_tiny_s{args.seed}_{args.task}.jsonl", lines=True)
    conf_predictor = fit_confidence_model(conf_data)

    # Models configured for this task, best weight first.
    task_weights = task_to_weights[args.task]
    available_models = sorted(task_weights, key=task_weights.get, reverse=True)

    # Use --agent when the parser provides it; otherwise fall back to the
    # best-weighted model of the task, so a missing option no longer fails.
    agent = getattr(args, "agent", None) or available_models[0]
    model_id = agent_map.get(agent)
    if model_id is None:
        raise ValueError(f"Unknown agent {agent!r}; expected one of {sorted(agent_map)}")

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    llm_kwargs = {
        "model": model_id,
        "download_dir": "/nas-ssd2/cychen/saved_models",
        "tensor_parallel_size": args.gpus,
        "trust_remote_code": True,
    }
    if agent in ["Phi", "Mistral"]:
        llm_kwargs["max_model_len"] = 16000
    elif agent in ["DeepSeekMath"]:
        llm_kwargs["max_model_len"] = 4096
    llm = LLM(**llm_kwargs)

    def predicted_confidence(n_tokens):
        # bounded_predict returns a 1-element array; keep a plain float so the
        # scores, the threshold test and the arg-max all work on scalars.
        return float(conf_predictor.bounded_predict(n_tokens)[0])

    k = 1  # rounds already spent on the current model
    model_index = 0
    # NOTE(review): `model` only labels the output file; generation always
    # uses the single `llm` loaded above, whichever name is current.
    model = available_models[model_index]

    # Round 1: every question is asked once.
    indices = np.arange(len(test_samples))
    num_llm_calls = len(indices)
    reasoning_responses, n_tokens_reasoning = get_model_responses(
        args.task, tokenizer, llm, indices, args.seed, is_reasoning=True
    )
    reasoning_final_answers = _extract_final_answers(args.task, reasoning_responses, num_choice)

    all_final_answers = [[ans] for ans in reasoning_final_answers]
    all_confs = [[predicted_confidence(n_tokens)] for n_tokens in n_tokens_reasoning]
    choices_2_scores = [
        _score_answers(args.task, choices, curr_final_ans, curr_confs)
        for curr_final_ans, curr_confs in zip(all_final_answers, all_confs)
    ]

    for t in range(args.budget - 1):
        # Resample only the questions whose best score is still below threshold.
        mask = [max(scores.values()) < args.threshold for scores in choices_2_scores]
        indices = np.flatnonzero(np.array(mask, dtype=bool))
        if len(indices) == 0:
            break
        num_llm_calls += len(indices)
        if k == args.n_sampling:
            # Advance to the next model, staying on the last one when the
            # task has no further models (previously an IndexError).
            model_index = min(model_index + 1, len(available_models) - 1)
            k = 0
        model = available_models[model_index]
        k += 1

        # Offset the seed per round: reusing the same per-request seed for the
        # same prompt would presumably reproduce the first-round response.
        round_seed = args.seed + t + 1
        reasoning_responses, n_tokens_reasoning = get_model_responses(
            args.task, tokenizer, llm, indices, round_seed, is_reasoning=True
        )
        reasoning_final_answers = _extract_final_answers(args.task, reasoning_responses, num_choice)

        for i, index in enumerate(indices):
            all_final_answers[index].append(reasoning_final_answers[i])
            all_confs[index].append(predicted_confidence(n_tokens_reasoning[i]))
            choices_2_scores[index] = _score_answers(
                args.task, choices, all_final_answers[index], all_confs[index]
            )

    preds = []
    for answer_scores in choices_2_scores:
        candidates = [(key, score) for key, score in answer_scores.items() if key is not None]
        if not candidates:
            # No usable answer was extracted for this question.
            preds.append(None)
            continue
        # max() keeps the first of equally scored candidates, like np.argmax.
        preds.append(max(candidates, key=lambda item: item[1])[0])

    df = pd.DataFrame(test_samples)
    correctness = [
        final_answer is not None and is_math_equiv(final_answer, str(df['gold_answer'][i]))
        for i, final_answer in enumerate(preds)
    ]
    acc = round(sum(correctness) / len(correctness) * 100, 2)
    df['pred'] = preds
    df['correctness'] = correctness

    # Create the output directory first so results are not lost after the run.
    out_dir = f"./Results/skills_2/{args.task}"
    os.makedirs(out_dir, exist_ok=True)
    avg_calls = round(num_llm_calls / len(choices_2_scores), 2)
    df.to_csv(f"{out_dir}/v6_early_stop_seed{args.seed}_budget{args.budget}_{avg_calls}_acc{acc}_model{model}_threshold{args.threshold}.csv", index=False)

    print("Resutls of V5: A linear regression model is used to predict the confidence score.")    
    print(f"EarlyStop completed with accuracy: {acc}%")
    print(f"Average LLM calls per sample: {num_llm_calls / len(test_samples):.2f}")
    print(f"Total LLM calls: {num_llm_calls}")

    # Best-effort release of the engine and GPU memory.
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AttributeError):
        # The attribute may be absent depending on the engine in use.
        del llm.llm_engine.model_executor
    if hasattr(llm, 'engine'):
        del llm.engine
    del llm
    gc.collect()
    torch.cuda.empty_cache()
    import torch.distributed as dist
    if dist.is_initialized():
        dist.destroy_process_group()
    time.sleep(10)
    subprocess.run(["nvidia-smi"], check=True)


if __name__ == "__main__":
    main()
