import os
import time
import json
import torch
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter
from typing import List, Dict
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from utils import read_json, write_json, get_keywords, get_alphabet_choice, remove_boxed, last_boxed_only_string, is_math_equiv
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
# from agent import *
import contextlib
import gc
import os
import subprocess
import time
import re
from math import prod
import ast
from copy import deepcopy

# from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
# import ray, gc, torch

# Generation settings. MAX_TOKENS, TEMPERATURE and TOP_P are forwarded to
# vllm.SamplingParams inside get_model_responses().
MAX_TOKENS = 32768
# NOTE(review): N_SAMPLES, THRESHOLD and EPSILON are not referenced anywhere in
# the code visible in this file; confirm whether they are still needed.
N_SAMPLES = 1
TEMPERATURE = 1.0
THRESHOLD = 4000
TOP_P = 1.0
EPSILON = 0.01

# Short agent name -> local snapshot directory of the model weights.
# get_model_responses() looks agents up here and passes the path to both
# AutoTokenizer.from_pretrained and vllm.LLM. The trailing comment on each line
# is presumably the matching Hugging Face repo id (inferred from the path).
agent_map = {
    "Llama": "/datasets/ai/llama3/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/", # meta-llama/Meta-Llama-3.1-8B-Instruct
    "Qwen": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28/", # Qwen/Qwen2.5-7B-Instruct
    "Mistral": "/datasets/ai/mixtral/hub/models--mistralai--Mistral-Nemo-Instruct-2407/snapshots/8aedd450f2583e9c67fae1929f6936b8fc5aef9c/", # mistralai/Mistral-Nemo-Instruct-2407
    "Phi": "/datasets/ai/phi/hub/models--microsoft--Phi-3.5-mini-instruct/snapshots/3145e03a9fd4cdd7cd953c34d9bbf7ad606122ca/", # microsoft/Phi-3.5-mini-instruct
    "Gemma": "/datasets/ai/gemma/hub/models--google--gemma-2-9b-it/snapshots/1937c70277fcc5f7fb0fc772fc5bc69378996e71/", # google/gemma-2-9b-it
    "GLM": "/datasets/ai/glm/hub/models--THUDM--glm-4-9b-chat/snapshots/bd8234fe5e0c09c48637a92abb0c797cb5fa0e73/", # THUDM/glm-4-9b-chat
    "Exaone": "/datasets/ai/lg/hub/models--LGAI-EXAONE--EXAONE-3.5-7.8B-Instruct/snapshots/0ff6b5ec7c13b049b253a16a889aa269e6b79a94/", # LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct
    "Granite": "/datasets/ai/ibm-granite/hub/models--ibm-granite--granite-3.1-8b-instruct/snapshots/3f05a1d007b2484bbf17593efe110bd5b9d67655/", # ibm-granite/granite-3.1-8b-instruct
    "QwenMath": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d/", # Qwen/Qwen2.5-Math-7B-Instruct
    "QwenCode": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Coder-7B-Instruct/snapshots/c03e6d358207e414f1eca0bb1891e29f1db0e242/", # Qwen/Qwen2.5-Coder-7B-Instruct
    "DeepSeekMath": "/datasets/ai/deepseek/hub/models--deepseek-ai--deepseek-math-7b-instruct/snapshots/0a5828f800a36df0fd7f0ed581b983246c0677ff/", # deepseek-ai/deepseek-math-7b-instruct
    "QwenR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/6602cadec947dbb53e64f3d8d6425320b2197247/", # deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
    "LlamaR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Llama-8B/snapshots/74fbf131a939963dd1e244389bb61ad0d0440a4d/", # deepseek-ai/DeepSeek-R1-Distill-Llama-8B
    "InternLM": "/datasets/ai/internlm/hub/models--internlm--internlm3-8b-instruct/snapshots/28c99415adaf61767bd1c619f4f99f308fdfd223/", # internlm/internlm3-8b-instruct
    "Mathstral": "/datasets/ai/mixtral/hub/models--mistralai--Mathstral-7B-v0.1/snapshots/b6408c37979d6805935973ab06468089ff72ce95/", # mistralai/Mathstral-7B-v0.1
    "BioLlama": "/datasets/ai/contactdoctor/hub/models--ContactDoctor--Bio-Medical-Llama-3-8B/snapshots/b42b41f30767e43b6a636490bde14d82e5bad0c1/", # ContactDoctor/Bio-Medical-Llama-3-8B
    "Qwen72B": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-72B-Instruct/snapshots/495f39366efef23836d0cfae4fbe635880d2be31/", # Qwen/Qwen2.5-72B-Instruct
    "Llama70B": "/datasets/ai/llama3/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b/", # meta-llama/Llama-3.3-70B-Instruct
    "NLlama": "/datasets/ai/nvidia/hub/models--nvidia--Llama-3.1-Nemotron-Nano-8B-v1/snapshots/a22e1c57330633cd3522903f9bb82480bf3192a6/", # nvidia/Llama-3.1-Nemotron-Nano-8B-v1
}

# Agent names (keys of agent_map).
# NOTE(review): MODELS is not referenced in the code visible in this file; the
# script takes its model list from --models instead. Confirm it is still needed.
MODELS = [
    "Llama",
    "Qwen",
    "Gemma",
]

# Task name -> {agent name -> weight}.
# In the visible code only the *keys* of the per-task dict are read (to list
# the models available for a task); the numeric weights are used solely by
# commented-out code. Presumably the weights are per-model accuracies on the
# task -- TODO confirm. The commented-out lines are alternative configurations
# kept from earlier experiments.
task_to_weights = {
    "MATH500": {"Llama": 0.5464, "Qwen": 0.7528, "Gemma": 0.5179},
    # "MATH500": {"Qwen": 0.7528},
    # "AIME24": {"Llama": 0.065, "Qwen": 0.1243, "Gemma": 0.0141},
    "AIME24": {"NLlama": 0.6921, "QwenR1": 0.5347},
    # "AIME24": {"NLlama": 0.5921, "QwenR1": 0.4347},
    # "AIME24": {"NLlama": 0.5921},
    "AIME25": {"Llama": 0.065, "Qwen": 0.1243, "Gemma": 0.0141},
    "GSM8K": {"Llama": 0.8723, "Qwen": 0.9445, "Gemma": 0.9133},
    # "MMLU_Pro": {"Llama": 0.3457, "Qwen": 0.5343, "Gemma": 0.5314},
    "MMLU_Pro": {"Qwen": 0.5343, "Gemma": 0.5314},
}

# Task -> agent -> token-count calibration table read by predict_confidence().
#   "thresholds": three ascending token-count cut points (labelled Q1, Q2, Q3).
#   "scores":     four confidence values; predict_confidence() returns
#                 scores[k] for the first k with num_tokens <= thresholds[k],
#                 and scores[3] when num_tokens exceeds every threshold.
# Where a commented-out "scores" line precedes the active one, the active
# values equal the commented values minus 0.1 (AIME24 and MATH500 Qwen/Llama).
#
# NOTE(review): the active MATH500/"Gemma" scores are digit-for-digit identical
# to AIME24/"NLlama" and do NOT follow the "commented value minus 0.1" pattern
# of the neighbouring entries -- this looks like a copy-paste slip; confirm the
# intended values.
# NOTE(review): there is no entry for task "AIME25", nor for GSM8K/"Llama",
# although task_to_weights lists them; predict_confidence() raises KeyError
# for any task/model pair missing here.
confidence_stats = {
    "AIME24": {
        "QwenR1": {
            "thresholds": [4780.0, 8571.5, 17034.75],   # Q1, Q2, Q3
            # "scores":     [0.929266, 0.768927, 0.442034, 0.216836],  # Q1..Q4
            "scores":     [0.829266, 0.668927, 0.342034, 0.116836],  # Q1..Q4
        },
        "NLlama": {
            "thresholds": [4529.75, 7919.0, 13886.25],   # Q1, Q2, Q3
            # "scores":     [0.957966, 0.899322, 0.709831, 0.332316],  # Q1..Q4
            "scores":     [0.857966, 0.799322, 0.609831, 0.232316],  # Q1..Q4
        },
    },
    "MATH500": {
        "Gemma": {
            "thresholds": [222.0, 319.0, 452.0],   # Q1, Q2, Q3
            # "scores":     [0.825434, 0.646975, 0.417618, 0.200922],  # Q1..Q4
            "scores":     [0.857966, 0.799322, 0.609831, 0.232316],  # Q1..Q4 -- NOTE(review): same as AIME24/NLlama, see header
        },
        "Qwen": {
            "thresholds": [342.0, 486.0, 721.0],   # Q1, Q2, Q3
            # "scores":     [0.963733, 0.889632, 0.751458, 0.488898],  # Q1..Q4
            "scores":     [0.863733, 0.789632, 0.651458, 0.388898],  # Q1..Q4
        },
        "Llama": {
            "thresholds": [259.0, 385.0, 647.0],   # Q1, Q2, Q3
            # "scores":     [0.897843, 0.793198, 0.592455, 0.300694],  # Q1..Q4
            "scores":     [0.797843, 0.693198, 0.492455, 0.200694],  # Q1..Q4
        },
    },
    "GSM8K": {
        "Qwen": {
            "thresholds": [205.0, 247.0, 300.0],   # Q1, Q2, Q3
            "scores":     [0.984622, 0.970318, 0.954515, 0.891398],  # Q1..Q4
        },
        "Gemma": {
            "thresholds": [117.0, 151.0, 190.0],   # Q1, Q2, Q3
            "scores":     [0.969756, 0.938133, 0.912407, 0.840706],  # Q1..Q4
        },
    },
    "MMLU_Pro": {
        "Qwen": {
            "thresholds": [362.0, 461.0, 586.0],   # Q1, Q2, Q3
            "scores":     [0.694305, 0.584615, 0.512616, 0.421854],  # Q1..Q4
        },
        "Gemma": {
            "thresholds": [220.0, 302.0, 398.0],   # Q1, Q2, Q3
            "scores":     [0.696693, 0.605562, 0.471659, 0.386842],  # Q1..Q4
        },
    },

}

# Set the OpenMP thread-count variable for this process and any child processes.
# NOTE(review): this runs after numpy/torch were imported at the top of the
# file, so libraries that read the variable at import time may not see it --
# confirm whether it should be set before those imports.
os.environ["OMP_NUM_THREADS"] = "20"

def parse_args():
    """Parse the command-line options of the script.

    Returns:
        argparse.Namespace with the attributes ``task``, ``gpus``, ``seed``,
        ``budget``, ``threshold``, ``n_sampling``, ``models`` and
        ``fixed_window``. ``task`` and ``gpus`` have no default and are
        ``None`` when omitted.
    """
    import argparse

    # One (flag, add_argument keyword arguments) pair per option, in the
    # order they are registered on the parser.
    option_table = (
        ('--task', {'type': str}),
        ('--gpus', {'type': int}),
        ('--seed', {'type': int, 'default': 42}),
        ('--budget', {'type': int, 'default': 4}),
        ('--threshold', {'type': float, 'default': 0.98}),
        ('--n_sampling', {'type': int, 'default': 2}),
        ('--models', {
            'type': str,
            'default': '',
            'help': "Models to use for prediction",
        }),
        ('--fixed_window', {
            'type': int,
            'default': 2,
            'help': "Use fixed window size for early stopping",
        }),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()

def seed_everything(seed=42):
    """Seed every random number generator the script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and the current CUDA
    device) with ``seed``, stores it in the ``PYTHONHASHSEED`` environment
    variable and asks cuDNN to use deterministic algorithms.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Recorded in the environment; processes started afterwards inherit it.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for apply_seed in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

def extract_confidence_score(response: str) -> float:
    """
    Pull the self-reported confidence out of an LLM response.

    The response must contain a marker of the exact form
        **Confidence Score: X**
    where X is either 10 (optionally written 10.0, 10.00, ...) or a single
    digit 0-9 with an optional decimal part (e.g. 7 or 7.5). Whitespace after
    the colon is optional. When several markers are present, the first one
    in the text is used.

    Args:
        response: Raw text produced by the model.
    Returns:
        float: the confidence value X.
    Raises:
        ValueError: if the response contains no well-formed marker.
    """
    found = re.search(
        r"\*\*Confidence Score:\s*((?:10(?:\.0+)?|[0-9](?:\.\d+)?))\*\*",
        response,
    )
    if found is None:
        raise ValueError("No valid confidence score found in response.")
    return float(found.group(1))

def compute_answer_scores(answers, confidences):
    """
    Turn per-response answers and confidences into one normalized score per
    distinct answer.

    Responses whose answers are equivalent according to ``is_math_equiv`` are
    pooled into one group, represented by the first answer seen in that group.
    With U = number of groups:

      score(x)    = (∏ c_i for responses i inside x's group)
                  * (∏ ((1 - c_j) / U) for responses j outside x's group)
      score(None) = ∏ ((1 - c_i) / U) over all responses

    Finally all scores are normalized to sum to 1 (or set uniformly if every
    raw score is zero).

    Args:
        answers: One extracted answer per response. Group representatives
            become dictionary keys, so they must be hashable.
        confidences: One confidence per response, same length as ``answers``.

    Returns:
        dict mapping each group representative, plus the key ``None``, to its
        normalized score. The ``None`` entry is written last, so if ``None``
        itself occurs among ``answers`` its group score is replaced by
        score(None).

    Raises:
        ValueError: if ``answers`` and ``confidences`` differ in length.
    """
    if len(answers) != len(confidences):
        raise ValueError("answers and confidences must be the same length")

    # Pool equivalent answers: each entry is (representative, [confidences]).
    groups = []
    for ans, conf in zip(answers, confidences):
        for rep, pooled in groups:
            if is_math_equiv(rep, ans):
                pooled.append(conf)
                break
        else:
            groups.append((ans, [conf]))

    # Count unique answers *after* equivalence pooling. Counting raw strings
    # would treat equivalent answers written differently as extra unique
    # answers and shrink every penalty term.
    U = len(groups)

    scores = {}
    for idx, (rep, pooled) in enumerate(groups):
        prod_correct = prod(pooled)
        # Penalty from every response that does not belong to this group.
        prod_penalty = prod(
            (1 - c) / U
            for other_idx, (_, other_pooled) in enumerate(groups)
            if other_idx != idx
            for c in other_pooled
        )
        scores[rep] = prod_correct * prod_penalty

    # special None-case: penalty over *all* confidences
    scores[None] = prod((1 - c) / U for c in confidences) if len(confidences) else 1.0

    # normalize
    total = sum(scores.values())
    if total > 0:
        for key in scores:
            scores[key] /= total
    else:
        # if somehow all zero, distribute uniformly
        uniform = 1.0 / len(scores)
        for key in scores:
            scores[key] = uniform

    return scores



def predict_confidence(task: str, model: str, num_tokens: float) -> float:
    """
    Look up the precomputed confidence for a response of ``num_tokens`` tokens.

    The ``confidence_stats`` entry for (task, model) provides three ascending
    token-count thresholds and four scores; the score of the first bucket
    whose upper threshold is >= ``num_tokens`` is returned, and the fourth
    score when ``num_tokens`` exceeds every threshold.

    Raises:
        KeyError: if ``confidence_stats`` has no entry for the task/model pair.
    """
    stats = confidence_stats.get(task, {}).get(model)
    if stats is None:
        raise KeyError(f"No confidence stats found for task={task!r}, model={model!r}")

    q1, q2, q3 = stats["thresholds"]
    s1, s2, s3, s4 = stats["scores"]

    # Walk the buckets from the shortest responses upwards.
    for upper_bound, bucket_score in ((q1, s1), (q2, s2), (q3, s3)):
        if num_tokens <= upper_bound:
            return bucket_score
    return s4

def get_model_responses(prompts, agent, n_gpus, n_sampling):
    """
    Generate ``n_sampling`` completions per prompt with one model, then tear
    the vLLM engine down so that another model can be loaded afterwards.

    Args:
        prompts: User-turn texts; each is wrapped in the model's chat template.
        agent: Key of ``agent_map`` selecting the model.
        n_gpus: Tensor-parallel size passed to vLLM.
        n_sampling: Number of completions sampled per prompt.

    Returns:
        (responses, n_tokens): two lists aligned with ``prompts``. Each element
        is a list with one entry per sampled completion -- the completion text
        in ``responses`` and its number of generated token ids in ``n_tokens``.

    Raises:
        KeyError: if ``agent`` is not a key of ``agent_map``.
    """
    model_id = agent_map.get(agent)
    if model_id is None:
        # Fail with a clear message instead of handing None to the tokenizer loader.
        raise KeyError(f"Unknown agent {agent!r}; expected one of {sorted(agent_map)}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    messages = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": p}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for p in prompts
    ]

    # Shared engine arguments; some agents additionally get an explicit
    # max_model_len.
    llm_kwargs = {
        "model": model_id,
        "download_dir": "/nas-ssd2/cychen/saved_models",
        "tensor_parallel_size": n_gpus,
        "trust_remote_code": True,
    }
    if agent in ("Phi", "Mistral"):
        llm_kwargs["max_model_len"] = 16000
    elif agent == "DeepSeekMath":
        llm_kwargs["max_model_len"] = 4096
    llm = LLM(**llm_kwargs)

    sampling_params = SamplingParams(temperature=TEMPERATURE, max_tokens=MAX_TOKENS, n=n_sampling, top_p=TOP_P)

    outputs = llm.generate(messages, sampling_params)
    responses = [[val.text for val in output.outputs] for output in outputs]
    n_tokens = [[len(val.token_ids) for val in output.outputs] for output in outputs]

    # --- Teardown: statement order kept exactly as in the original code ---
    destroy_model_parallel()
    destroy_distributed_environment()
    del llm.llm_engine.model_executor
    if hasattr(llm, 'engine'):
        del llm.engine
    del llm
    gc.collect()
    torch.cuda.empty_cache()
    import torch.distributed as dist
    if dist.is_initialized():
        dist.destroy_process_group()
    # Second attempt; an AssertionError from this call (presumably raised when
    # no process group is left) is ignored.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    time.sleep(10)  # presumably gives the GPUs time to release memory -- TODO confirm

    # Purely diagnostic GPU snapshot: a missing or failing nvidia-smi must not
    # throw away the responses that were just generated.
    with contextlib.suppress(OSError):
        subprocess.run(["nvidia-smi"], check=False)

    return responses, n_tokens

# Tasks whose final answer is extracted from a boxed expression.
_MATH_TASKS = ("MATH500", "AIME24", "AIME25", "GSM8K")

# Preamble of the aggregation prompt shown to follow-up models.
_AGG_PROMPT = (
    "You have been provided with a set of responses from various open-source models to the latest user query. "
    "Your task is to synthesize these responses into a single, high-quality response. "
    "It is crucial to critically evaluate the information provided in these responses, "
    "recognizing that some of it may be biased or incorrect. "
    "Your response should not simply replicate the given answers but should offer a refined, "
    "accurate, and comprehensive reply to the instruction. "
    "Ensure your response is well-structured, coherent, and adheres "
    "to the highest standards of accuracy and reliability. "
    "Responses from models:\n\n"
)


def _question_prompt(question, math_task):
    """Return the question followed by the answer-format instruction."""
    if math_task:
        return (
            f"Question: {question}\n"
            f"Provide your step-by-step reasoning first, and then print \"The answer is \\boxed{{X}}\", "
            f"where X is the final answer, at the end of your response."
        )
    return (
        f"Question: {question}\n"
        f"Provide your step-by-step reasoning first, and then print \"The answer is (X)\", "
        f"where X is the answer choice (one capital letter), at the end of your response."
    )


def _extract_final_answers(responses, math_task, num_choice):
    """Extract one final answer per response, keeping the per-sample nesting."""
    if math_task:
        return [[remove_boxed(last_boxed_only_string(r)) for r in sample] for sample in responses]
    return [[get_alphabet_choice(r, num_choice=num_choice) for r in sample] for sample in responses]


def main():
    """Run confidence-based early stopping over one task and save the results.

    1. Sample ``--fixed_window`` responses per question from the first model.
    2. While any question's best normalized answer score is below
       ``--threshold``, ask the next model to aggregate the most recent
       ``--fixed_window`` responses of that question and re-score it.
    3. Predict the best non-None answer per question, report accuracy and
       write a CSV under ./Results/skills/<task>/.
    """
    args = parse_args()
    seed_everything(args.seed)

    math_task = args.task in _MATH_TASKS
    if math_task:
        num_choice = 2
    elif args.task == "MMLU_Pro":
        num_choice = 10
    else:
        num_choice = 4

    test_samples = read_json(f"./Datasets/test/{args.task}_test.json")
    if not test_samples:
        raise ValueError(f"No test samples found for task {args.task!r}")

    # --models is a Python list literal, e.g. "['Qwen', 'Gemma']". When it is
    # omitted, fall back to the models registered for the task instead of
    # crashing on ast.literal_eval('').
    if args.models:
        available_models = ast.literal_eval(args.models)
        if isinstance(available_models, str):
            available_models = [available_models]
    else:
        available_models = list(task_to_weights[args.task].keys())

    model = available_models[0]
    num_llm_calls = len(test_samples) * args.fixed_window

    prompts = [_question_prompt(sample['question'], math_task) for sample in test_samples]
    all_responses, all_n_tokens = get_model_responses(prompts, model, args.gpus, args.fixed_window)
    all_final_answers = _extract_final_answers(all_responses, math_task, num_choice)
    confs = [
        [predict_confidence(args.task, model, n) for n in sample_n_tokens]
        for sample_n_tokens in all_n_tokens
    ]
    choices_2_scores = [
        compute_answer_scores(answers, sample_confs)
        for answers, sample_confs in zip(all_final_answers, confs)
    ]

    # NOTE(review): the follow-up schedule repeats available_models[1:]
    # (budget - fixed_window) times, so with more than two models a question
    # can receive more than `budget` calls in total -- confirm this is intended.
    for model in available_models[1:] * (args.budget - args.fixed_window):
        # Questions whose best score has not reached the threshold yet.
        indices = [
            i for i, scores in enumerate(choices_2_scores)
            if max(scores.values()) < args.threshold
        ]
        if not indices:
            break
        num_llm_calls += len(indices)

        prompts = []
        for index in indices:
            parts = [_AGG_PROMPT]
            # Show the aggregator the most recent `fixed_window` responses.
            # (The previous slice `[:-fixed_window]` was empty on the first
            # round and always dropped the newest responses.)
            recent = all_responses[index][-args.fixed_window:]
            for k, res in enumerate(recent):
                # NOTE(review): [0] keeps the text *before* '</think>'; confirm
                # whether the part after the tag (the final answer) was meant.
                parts.append(f"### Model {k+1}'s response:\n{res.split('</think>')[0]}\n\n")
            parts.append(_question_prompt(test_samples[index]['question'], math_task))
            prompts.append("".join(parts))

        responses, n_tokens = get_model_responses(prompts, model, args.gpus, 1)
        final_answers = _extract_final_answers(responses, math_task, num_choice)

        for i, index in enumerate(indices):
            all_responses[index].extend(responses[i])
            all_n_tokens[index].extend(n_tokens[i])
            all_final_answers[index].extend(final_answers[i])
            confs[index].extend(
                predict_confidence(args.task, model, n) for n in n_tokens[i]
            )
            choices_2_scores[index] = compute_answer_scores(all_final_answers[index], confs[index])

    preds = []
    for answer_scores in choices_2_scores:
        candidates = {key: val for key, val in answer_scores.items() if key is not None}
        # No extractable answer at all -> predict None instead of crashing on
        # an argmax over an empty list. Ties resolve to the first-seen answer.
        preds.append(max(candidates, key=candidates.get) if candidates else None)

    df = pd.DataFrame(test_samples)
    correctness = [
        pred is not None and is_math_equiv(pred, str(gold))
        for pred, gold in zip(preds, df['gold_answer'])
    ]
    acc = round(sum(correctness) / len(correctness) * 100, 2)
    df['pred'] = preds
    df['correctness'] = correctness

    out_dir = f"./Results/skills/{args.task}"
    os.makedirs(out_dir, exist_ok=True)  # do not lose the run to a missing directory
    # `model` is the last model that was invoked (loop variable).
    df.to_csv(f"{out_dir}/SEQ_v2_early_stop_seed{args.seed}_budget{args.budget}_{round(num_llm_calls / len(choices_2_scores) ,2)}_acc{acc}_model{model}_threshold{args.threshold}.csv", index=False)

    print(f"EarlyStop completed with accuracy: {acc}%")
    print(f"Average LLM calls per sample: {num_llm_calls / len(test_samples):.2f}")
    print(f"Total LLM calls: {num_llm_calls}")


if __name__ == "__main__":
    main()

