import os
import time
import json
import torch
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter
from typing import List, Dict
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from utils import read_json, write_json, get_keywords, get_alphabet_choice, remove_boxed, last_boxed_only_string, is_math_equiv
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
import contextlib
import gc
import subprocess
import math

# Generation settings shared by every model call in this script.
MAX_TOKENS = 32768  # passed to SamplingParams as max_tokens
N_SAMPLES = 1  # not referenced by the code visible in this file
TEMPERATURE = 0.7  # passed to SamplingParams as temperature
THRESHOLD = 4000  # not referenced by the code visible in this file
TOP_P = 1.0  # only appears in a commented-out SamplingParams call

# Maps a short agent name to the local snapshot directory of its checkpoint.
# get_model_responses() looks agents up here and hands the path to both
# AutoTokenizer and vLLM. The trailing comment on each entry is the
# Hugging Face repo id the snapshot path corresponds to.
agent_map = {
    "Llama": "/datasets/ai/llama3/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/", # meta-llama/Llama-3.1-8B-Instruct
    "Qwen": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28/", # Qwen/Qwen2.5-7B-Instruct
    "Mistral": "/datasets/ai/mixtral/hub/models--mistralai--Mistral-Nemo-Instruct-2407/snapshots/8aedd450f2583e9c67fae1929f6936b8fc5aef9c/", # mistralai/Mistral-Nemo-Instruct-2407
    "Phi": "/datasets/ai/phi/hub/models--microsoft--Phi-3.5-mini-instruct/snapshots/3145e03a9fd4cdd7cd953c34d9bbf7ad606122ca/", # microsoft/Phi-3.5-mini-instruct
    "Gemma": "/datasets/ai/gemma/hub/models--google--gemma-2-9b-it/snapshots/1937c70277fcc5f7fb0fc772fc5bc69378996e71/", # google/gemma-2-9b-it
    "GLM": "/datasets/ai/glm/hub/models--THUDM--glm-4-9b-chat/snapshots/bd8234fe5e0c09c48637a92abb0c797cb5fa0e73/", # THUDM/glm-4-9b-chat
    "Exaone": "/datasets/ai/lg/hub/models--LGAI-EXAONE--EXAONE-3.5-7.8B-Instruct/snapshots/0ff6b5ec7c13b049b253a16a889aa269e6b79a94/", # LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct
    "Granite": "/datasets/ai/ibm-granite/hub/models--ibm-granite--granite-3.1-8b-instruct/snapshots/3f05a1d007b2484bbf17593efe110bd5b9d67655/", # ibm-granite/granite-3.1-8b-instruct
    "QwenMath": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d/", # Qwen/Qwen2.5-Math-7B-Instruct
    "QwenCode": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Coder-7B-Instruct/snapshots/c03e6d358207e414f1eca0bb1891e29f1db0e242/", # Qwen/Qwen2.5-Coder-7B-Instruct
    "DeepSeekMath": "/datasets/ai/deepseek/hub/models--deepseek-ai--deepseek-math-7b-instruct/snapshots/0a5828f800a36df0fd7f0ed581b983246c0677ff/", # deepseek-ai/deepseek-math-7b-instruct
    "QwenR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/6602cadec947dbb53e64f3d8d6425320b2197247/", # deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
    "LlamaR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Llama-8B/snapshots/74fbf131a939963dd1e244389bb61ad0d0440a4d/", # deepseek-ai/DeepSeek-R1-Distill-Llama-8B
    "InternLM": "/datasets/ai/internlm/hub/models--internlm--internlm3-8b-instruct/snapshots/28c99415adaf61767bd1c619f4f99f308fdfd223/", # internlm/internlm3-8b-instruct
    "Mathstral": "/datasets/ai/mixtral/hub/models--mistralai--Mathstral-7B-v0.1/snapshots/b6408c37979d6805935973ab06468089ff72ce95/", # mistralai/Mathstral-7B-v0.1
    "BioLlama": "/datasets/ai/contactdoctor/hub/models--ContactDoctor--Bio-Medical-Llama-3-8B/snapshots/b42b41f30767e43b6a636490bde14d82e5bad0c1/", # ContactDoctor/Bio-Medical-Llama-3-8B
    "Qwen72B": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-72B-Instruct/snapshots/495f39366efef23836d0cfae4fbe635880d2be31/", # Qwen/Qwen2.5-72B-Instruct
    "Llama70B": "/datasets/ai/llama3/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b/", # meta-llama/Llama-3.3-70B-Instruct
    "NLlama": "/datasets/ai/nvidia/hub/models--nvidia--Llama-3.1-Nemotron-Nano-8B-v1/snapshots/a22e1c57330633cd3522903f9bb82480bf3192a6/", # nvidia/Llama-3.1-Nemotron-Nano-8B-v1
    "tinyQwen": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-1.5B-Instruct/snapshots/989aa7980e4cf806f80c7fef2b1adb7bc71aa306/", # Qwen/Qwen2.5-1.5B-Instruct
    "tinyQwenR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-1.5B/snapshots/530ca3e1ad39d440e182c2e4317aa40f012512fa/" # deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
}

# Candidate agent names (keys of agent_map).
# NOTE(review): this list is not referenced by the code visible in this file;
# the models that actually run are taken from task_to_weights.
MODELS = [
    "Llama", #
    "Qwen", #
    "Gemma", #
]

# task_to_weights = {
#     # "MATH500": {"Llama": 0.5464, "Qwen": 0.7528, "Gemma": 0.5179},
#     # "MATH500": {"Qwen": 0.7528},
#     "MATH500": {"tinyQwen": 0.7528},
#     # "AIME24": {"Llama": 0.065, "Qwen": 0.1243, "Gemma": 0.0141},
#     # "AIME24": {"NLlama": 0.6921, "QwenR1": 0.5367},
#     # "AIME24": {"NLlama": 0.6921},
#     "AIME24": {"QwenR1": 0.5367},
#     "AIME25": {"Llama": 0.065, "Qwen": 0.1243, "Gemma": 0.0141},
#     # "GSM8K": {"Llama": 0.8723, "Qwen": 0.9445, "Gemma": 0.9133},
#     "GSM8K": {"Qwen": 0.9445},
#     "MMLU_Pro": {"Llama": 0.3457, "Qwen": 0.5343, "Gemma": 0.5314},
#     "GPQA": {"QwenR1": 0.5343},
# }

# Per-task configuration: task name -> {agent name: weight}.
# The main script runs every agent listed for the selected task, visiting them
# in order of descending weight. Commented-out entries are alternative
# configurations kept for reference.
# NOTE(review): the weights look like per-task accuracies of each agent —
# confirm where these numbers come from.
task_to_weights = {
    # "MATH500": {"Llama": 0.5464, "Qwen": 0.7528, "Gemma": 0.5179},
    # "MATH500": {"tinyQwen": 0.54},
    "MATH500": {"Qwen": 0.7528},
    # "MATH500": {"Qwen": 0.7528},
    # "AIME24": {"Llama": 0.065, "Qwen": 0.1243, "Gemma": 0.0141},
    # "AIME24": {"NLlama": 0.6921, "QwenR1": 0.5347},
    # "AIME24": {"NLlama": 0.5921, "QwenR1": 0.4347},
    "AIME24": {"QwenR1": 0.5347},
    # "AIME24": {"tinyQwenR1": 0.33},
    # "AIME25": {"Llama": 0.065, "Qwen": 0.1243, "Gemma": 0.0141},
    # "GSM8K": {"Llama": 0.8723, "Qwen": 0.9445, "Gemma": 0.9133},
    "GSM8K": {"Qwen": 0.9445},
    # "GSM8K": {"tinyQwen": 0.81},
    # "MMLU_Pro": {"Llama": 0.3457, "Qwen": 0.5343, "Gemma": 0.5314},
    # "MMLU_Pro": {"Qwen": 0.5343, "Gemma": 0.5314},
    "MMLU_Pro": {"Qwen": 0.5343},
    # "MMLU_Pro": {"LlamaR1": 0.5343},
    # "MMLU_Pro": {"tinyQwen": 0.34},
    "GPQA": {"QwenR1": 0.5343},
    # "GPQA": {"tinyQwenR1": 0.31},
}

# Set OMP_NUM_THREADS for this process (and any child processes) at import time.
# NOTE(review): this runs after numpy/torch have already been imported above —
# confirm the setting still takes effect for those libraries.
os.environ["OMP_NUM_THREADS"] = "20"

def parse_args():
    """Parse the command-line options of the self-consistency script.

    ``--task`` is mandatory: without it the script would only fail later with
    an opaque ``KeyError`` when looking the task up in ``task_to_weights``.
    ``--gpus`` defaults to 1 so that a valid tensor-parallel size is always
    forwarded to vLLM.

    Returns:
        argparse.Namespace: with attributes ``task``, ``gpus``, ``seed``,
        ``n_sampling`` and ``n_run``.

    Raises:
        SystemExit: raised by argparse when ``--task`` is missing or an
        option value cannot be parsed.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--task',
        type=str,
        required=True,
        help="Benchmark name; selects ./Datasets/test/<task>_test.json"
    )
    parser.add_argument(
        '--gpus',
        type=int,
        default=1,
        help="Number of GPUs, used as the vLLM tensor-parallel size"
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help="Random seed for Python, NumPy, torch and vLLM sampling"
    )
    parser.add_argument(
        '--n_sampling',
        type=int,
        default=5,
        help="Number of samples to get from each model"
    )
    parser.add_argument(
        '--n_run',
        type=int,
        default=1,
        help="Run counter (not read by the code in this script)"
    )
    return parser.parse_args()

def seed_everything(seed=42):
    """Seed all random number generators used by this script.

    Seeds Python's ``random``, NumPy, and torch (CPU and the current CUDA
    device), exports ``PYTHONHASHSEED``, and switches cuDNN to deterministic
    mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The individual generators are independent, so they are seeded in one pass.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True

def _format_prompt(task, question, is_reasoning):
    """Build the user prompt for one question.

    Math-style tasks ask for a ``\\boxed{X}`` answer; every other task asks for
    a single capital-letter choice in parentheses. ``is_reasoning`` selects
    between the step-by-step variant and the direct-answer variant.
    """
    if task in ["MATH500", "AIME24", "AIME25", "GSM8K"]:
        if is_reasoning:
            return (
                f"Question: {question}\n"
                f"Provide your step-by-step reasoning first, and then print \"The answer is \\boxed{{X}}\", "
                f"where X is the final answer, at the end of your response."
            )
        return (
            f"Question: {question}\n"
            f"Directly provide the correct answer choice without explanation, formatted as: \"The answer is \\boxed{{X}}\", "
            f"where X is the final answer, at the end of your response."
        )

    if is_reasoning:
        return (
            f"Question: {question}\n"
            f"Provide your step-by-step reasoning first, and then print \"The answer is (X)\", "
            f"where X is the answer choice (one capital letter), at the end of your response."
        )
    return (
        f"Question: {question}\n"
        f"Directly provide the correct answer choice without explanation, formatted as: \"The answer is (X)\", "
        f"where X is the answer choice (one capital letter), at the end of your response."
    )


def get_model_responses(task, agent, n_gpus, indices, n_sampling, seed, is_reasoning=False):
    """Generate completions from one model for a subset of a task's test set.

    Loads ``./Datasets/test/{task}_test.json``, keeps the samples at
    ``indices``, wraps each question in the model's chat template, samples
    ``n_sampling`` completions per prompt with vLLM, and then tears the engine
    down so that another model can be loaded afterwards.

    Args:
        task: Task name; selects the dataset file and the prompt format.
        agent: Key of ``agent_map`` naming the model to load.
        n_gpus: Tensor-parallel size passed to vLLM.
        indices: Positions of the samples to query within the dataset.
        n_sampling: Number of completions requested per prompt.
        seed: Seed passed to the vLLM sampling parameters.
        is_reasoning: If true, use the step-by-step prompt; otherwise the
            direct-answer prompt.

    Returns:
        A tuple ``(responses, n_tokens, logprobs)``. Each element is a list
        with one entry per index, and each entry is a list with one item per
        returned completion: the generated text, its number of generated
        tokens, and the vLLM logprobs object respectively.

    Raises:
        KeyError: if ``agent`` is not a key of ``agent_map``.
    """
    # Fail fast with a readable message instead of handing None to the loaders.
    if agent not in agent_map:
        raise KeyError(f"Unknown agent {agent!r}; expected one of {sorted(agent_map)}")
    model_id = agent_map[agent]

    samples = read_json(f"./Datasets/test/{task}_test.json")
    samples = [samples[i] for i in indices]

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # Only the prompt variant that will actually be sent to the model is built.
    chat_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": _format_prompt(task, sample['question'], is_reasoning)}],
            tokenize=False,
            add_generation_prompt=True
        )
        for sample in samples
    ]

    llm_kwargs = {
        "model": model_id,
        "download_dir": "/nas-ssd2/cychen/saved_models",
        "tensor_parallel_size": n_gpus,
        "trust_remote_code": True,
    }
    # Some agents are loaded with an explicit context-length cap.
    if agent in ["Phi", "Mistral"]:
        llm_kwargs["max_model_len"] = 16000
    elif agent in ["DeepSeekMath"]:
        llm_kwargs["max_model_len"] = 4096
    llm = LLM(**llm_kwargs)

    sampling_params = SamplingParams(temperature=TEMPERATURE, max_tokens=MAX_TOKENS, n=n_sampling, seed=seed, logprobs=1)
    outputs = llm.generate(chat_prompts, sampling_params)
    responses = [[completion.text for completion in output.outputs] for output in outputs]
    n_tokens = [[len(completion.token_ids) for completion in output.outputs] for output in outputs]
    logprobs = [[completion.logprobs for completion in output.outputs] for output in outputs]

    # Tear down the engine and its distributed state so GPU memory is released
    # before the next model is loaded. `llm` must be deleted in this frame:
    # any surviving reference would keep the weights alive through gc.collect().
    destroy_model_parallel()
    destroy_distributed_environment()
    del llm.llm_engine.model_executor
    if hasattr(llm, 'engine'):
        del llm.engine
    del llm
    gc.collect()
    torch.cuda.empty_cache()

    import torch.distributed as dist
    with contextlib.suppress(AssertionError):
        if dist.is_initialized():
            dist.destroy_process_group()
    time.sleep(10)

    # Diagnostic GPU memory dump only: a failing or missing nvidia-smi must not
    # throw away the completions that were just generated.
    with contextlib.suppress(OSError, subprocess.CalledProcessError):
        subprocess.run(["nvidia-smi"], check=True)

    return responses, n_tokens, logprobs

def calculate_consistency_ratio(answers):
    """Return the share of ``answers`` taken by the most frequent answer.

    The result is ``count_of_most_common / len(answers)``; an empty (or
    otherwise falsy) input yields ``0.0``.
    """
    if answers:
        top_count = max(Counter(answers).values())
        return top_count / len(answers)
    return 0.0

def most_common_math(
    answers,
    n = None
):
    """
    Group answers that ``is_math_equiv`` considers equivalent and count them.

    Each answer joins the first existing group whose representative it is
    equivalent to; otherwise it starts a new group and becomes that group's
    representative (the answer is stored as given). Returns a list of
    ``(representative, count)`` tuples sorted by descending count, with ties
    kept in first-seen order, truncated to the first ``n`` groups when ``n``
    is not None.
    """
    representatives = []
    tallies = []

    for answer in answers:
        for slot, rep in enumerate(representatives):
            if is_math_equiv(rep, answer):
                tallies[slot] += 1
                break
        else:
            # No existing group matched: this answer opens a new one.
            representatives.append(answer)
            tallies.append(1)

    # sorted() is stable, so equally frequent groups keep their first-seen order.
    ranked = sorted(zip(representatives, tallies), key=lambda pair: pair[1], reverse=True)
    if n is None:
        return ranked
    return ranked[:n]

def _extract_final_answers(task, responses, num_choice):
    """Extract the final answer from each raw completion of one question."""
    if task in ["MATH500", "AIME24", "AIME25", "GSM8K"]:
        return [remove_boxed(last_boxed_only_string(response)) for response in responses]
    return [get_alphabet_choice(response, num_choice=num_choice) for response in responses]


def _token_probabilities(sample_logprobs):
    """Convert one completion's per-token logprob dicts into probabilities.

    For every generated position the first entry of the logprob dict is used.
    NOTE(review): this presumes the first entry belongs to the sampled token —
    confirm against the vLLM version in use.
    """
    return [math.exp(list(val.values())[0].logprob) for val in sample_logprobs]


if __name__ == "__main__":
    # Self-consistency evaluation: every configured model answers every test
    # question n_sampling times, and the most frequent (math-equivalent)
    # answer per question is taken as the prediction.
    args = parse_args()
    seed_everything(args.seed)

    if args.task not in task_to_weights:
        raise SystemExit(
            f"No models configured for task {args.task!r}; "
            f"known tasks: {sorted(task_to_weights)}"
        )

    num_choice = 10 if args.task == "MMLU_Pro" else 4
    test_samples = read_json(f"./Datasets/test/{args.task}_test.json")

    # Models configured for this task, ordered from highest to lowest weight.
    task_weights = task_to_weights[args.task]
    available_models = sorted(task_weights, key=task_weights.get, reverse=True)

    # Per-sample accumulators, filled across all models.
    all_answers = [[] for _ in test_samples]
    all_predictions = [[] for _ in test_samples]
    all_logprobs = [[] for _ in test_samples]
    num_llm_calls = 0

    # Early stopping is disabled: every model is queried on every sample.
    current_indices = np.arange(len(test_samples))

    for model_idx, model in enumerate(available_models):
        print(f"Processing model {model} ({model_idx + 1}/{len(available_models)}) with {len(current_indices)} samples")

        reasoning_responses, _n_tokens_reasoning, logprobs_reasoning = get_model_responses(
            args.task, model, args.gpus, current_indices, args.n_sampling, args.seed, is_reasoning=True
        )
        num_llm_calls += len(current_indices) * args.n_sampling

        # Answers, raw completions and token probabilities are accumulated
        # together so the three columns always describe the same completions,
        # whichever number of models ran.
        for sample_idx, curr_responses, curr_logprobs in zip(current_indices, reasoning_responses, logprobs_reasoning):
            all_answers[sample_idx].extend(_extract_final_answers(args.task, curr_responses, num_choice))
            all_predictions[sample_idx].extend(curr_responses)
            all_logprobs[sample_idx].extend(_token_probabilities(sample_logprobs) for sample_logprobs in curr_logprobs)

        print(f"After model {model}: {len(current_indices)} samples still need more models")

    # Majority vote per sample, grouping math-equivalent answers together.
    final_predictions = [
        most_common_math(answers)[0][0] if answers else ""
        for answers in all_answers
    ]

    # Calculate accuracy
    df = pd.DataFrame(test_samples)
    correctness = [is_math_equiv(final_answer, str(df['gold_answer'][i])) for i, final_answer in enumerate(final_predictions)]
    acc = round(sum(correctness) / len(correctness) * 100, 2)

    # Save results
    df['pred'] = final_predictions
    df['correctness'] = correctness
    df['all_answers'] = all_answers
    df['logprobs'] = all_logprobs
    df['all_predictions'] = all_predictions

    # Make sure the target directory exists so the results of a long run are
    # not lost to a missing folder.
    output_dir = f"./Results/skills_3/{args.task}"
    os.makedirs(output_dir, exist_ok=True)
    df.to_csv(f"{output_dir}/self_con_seed{args.seed}_budget{args.n_sampling*len(available_models)}_acc{acc}_models{available_models}_with_logprobs.csv", index=False)

    print(f"Self-Consistency completed with accuracy: {acc}%")
    print(f"Average LLM calls per sample: {num_llm_calls / len(test_samples):.2f}")
    print(f"Total LLM calls: {num_llm_calls}")
    print(f"Models used: {available_models}")