import os
import time
import json
import torch
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter
from typing import List, Dict
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from utils import read_json, write_json, get_keywords, get_alphabet_choice, remove_boxed, last_boxed_only_string, is_math_equiv
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
# from agent import *
import contextlib
import gc
import os
import subprocess
import time

# from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
# import ray, gc, torch

# Sampling configuration. MAX_TOKENS, N_SAMPLES, TEMPERATURE and TOP_P are
# forwarded to vLLM's SamplingParams inside get_model_responses.

# Passed as SamplingParams(max_tokens=...).
MAX_TOKENS = 32768
# Passed as SamplingParams(n=...); only the first completion (outputs[0]) of
# each prompt is read back by get_model_responses.
N_SAMPLES = 1
# Passed as SamplingParams(temperature=...).
TEMPERATURE = 1.0
# NOTE(review): not referenced anywhere in the visible code of this file --
# confirm whether it is still needed.
THRESHOLD = 4000
# Passed as SamplingParams(top_p=...).
TOP_P = 1.0

# Maps a short agent name to the local snapshot directory of its checkpoint.
# get_model_responses hands the path to both AutoTokenizer.from_pretrained and
# vLLM's LLM(model=...). The trailing comment on each entry is the
# corresponding Hugging Face repo id.
agent_map = {
    "Llama": "/datasets/ai/llama3/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/", # meta-llama/Meta-Llama-3.1-8B-Instruct
    "Qwen": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28/", # Qwen/Qwen2.5-7B-Instruct
    "Mistral": "/datasets/ai/mixtral/hub/models--mistralai--Mistral-Nemo-Instruct-2407/snapshots/8aedd450f2583e9c67fae1929f6936b8fc5aef9c/", # mistralai/Mistral-Nemo-Instruct-2407
    "Phi": "/datasets/ai/phi/hub/models--microsoft--Phi-3.5-mini-instruct/snapshots/3145e03a9fd4cdd7cd953c34d9bbf7ad606122ca/", # microsoft/Phi-3.5-mini-instruct
    "Gemma": "/datasets/ai/gemma/hub/models--google--gemma-2-9b-it/snapshots/1937c70277fcc5f7fb0fc772fc5bc69378996e71/", # google/gemma-2-9b-it
    "GLM": "/datasets/ai/glm/hub/models--THUDM--glm-4-9b-chat/snapshots/bd8234fe5e0c09c48637a92abb0c797cb5fa0e73/", # THUDM/glm-4-9b-chat
    "Exaone": "/datasets/ai/lg/hub/models--LGAI-EXAONE--EXAONE-3.5-7.8B-Instruct/snapshots/0ff6b5ec7c13b049b253a16a889aa269e6b79a94/", # LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct
    "Granite": "/datasets/ai/ibm-granite/hub/models--ibm-granite--granite-3.1-8b-instruct/snapshots/3f05a1d007b2484bbf17593efe110bd5b9d67655/", # ibm-granite/granite-3.1-8b-instruct
    "QwenMath": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d/", # Qwen/Qwen2.5-Math-7B-Instruct
    "QwenCode": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Coder-7B-Instruct/snapshots/c03e6d358207e414f1eca0bb1891e29f1db0e242/", # Qwen/Qwen2.5-Coder-7B-Instruct
    "DeepSeekMath": "/datasets/ai/deepseek/hub/models--deepseek-ai--deepseek-math-7b-instruct/snapshots/0a5828f800a36df0fd7f0ed581b983246c0677ff/", # deepseek-ai/deepseek-math-7b-instruct
    "QwenR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/6602cadec947dbb53e64f3d8d6425320b2197247/", # deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
    "LlamaR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Llama-8B/snapshots/74fbf131a939963dd1e244389bb61ad0d0440a4d/", # deepseek-ai/DeepSeek-R1-Distill-Llama-8B
    "InternLM": "/datasets/ai/internlm/hub/models--internlm--internlm3-8b-instruct/snapshots/28c99415adaf61767bd1c619f4f99f308fdfd223/", # internlm/internlm3-8b-instruct
    "Mathstral": "/datasets/ai/mixtral/hub/models--mistralai--Mathstral-7B-v0.1/snapshots/b6408c37979d6805935973ab06468089ff72ce95/", # mistralai/Mathstral-7B-v0.1
    "BioLlama": "/datasets/ai/contactdoctor/hub/models--ContactDoctor--Bio-Medical-Llama-3-8B/snapshots/b42b41f30767e43b6a636490bde14d82e5bad0c1/", # ContactDoctor/Bio-Medical-Llama-3-8B
    "Qwen72B": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-72B-Instruct/snapshots/495f39366efef23836d0cfae4fbe635880d2be31/", # Qwen/Qwen2.5-72B-Instruct
    "Llama70B": "/datasets/ai/llama3/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b/", # meta-llama/Llama-3.3-70B-Instruct
    "NLlama": "/datasets/ai/nvidia/hub/models--nvidia--Llama-3.1-Nemotron-Nano-8B-v1/snapshots/a22e1c57330633cd3522903f9bb82480bf3192a6/", # nvidia/Llama-3.1-Nemotron-Nano-8B-v1
}

# Agent names (keys of agent_map) that the __main__ section iterates over, in
# this order. Comment an entry in or out to change which models are run.
MODELS = [
    "NLlama", #
    "QwenR1", #
    # "Gemma", #
]

# Sets OMP_NUM_THREADS for this process and any child processes it spawns.
# NOTE(review): this runs after torch/numpy have already been imported above;
# confirm that the value is still picked up by the libraries it is meant for.
os.environ["OMP_NUM_THREADS"] = "20"

def parse_args():
    """Parse the command-line options of this script.

    Options:
        --task: Name of the dataset to run; it selects
            ``./Datasets/train/<task>_train.json``. Required, so that a
            forgotten flag fails here with a usage message instead of later
            with a lookup of ``None_train.json``.
        --gpus: Number of GPUs, used as vLLM's ``tensor_parallel_size``.
            Defaults to 1 so that ``None`` is never handed to vLLM.
        --seed: Random seed. Defaults to 42.

    Returns:
        argparse.Namespace: with attributes ``task`` (str), ``gpus`` (int)
        and ``seed`` (int).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate model responses for the training split of a task."
    )
    parser.add_argument(
        '--task',
        type=str,
        required=True,
        help="dataset name, e.g. MATH500, GSM8K or MMLU_Pro",
    )
    parser.add_argument(
        '--gpus',
        type=int,
        default=1,
        help="number of GPUs used for tensor parallelism (default: 1)",
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help="random seed (default: 42)",
    )
    return parser.parse_args()

def seed_everything(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``,
    stores the seed in the ``PYTHONHASHSEED`` environment variable and puts
    cuDNN into deterministic mode.

    Args:
        seed: The seed value shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True

# Tasks whose prompts ask for the final answer inside \boxed{...}; every other
# task is prompted for a single capital-letter choice in parentheses.
_BOXED_ANSWER_TASKS = ("MATH500", "AIME24", "AIME25", "GSM8K")

# Per-agent max_model_len overrides passed to vLLM; agents that are not listed
# are loaded without an explicit max_model_len.
_MAX_MODEL_LEN = {"Phi": 16000, "Mistral": 16000, "DeepSeekMath": 4096}


def _build_prompts(task, samples, is_reasoning):
    """Return one user prompt per sample for the requested prompting mode."""
    if task in _BOXED_ANSWER_TASKS:
        answer_format = "\"The answer is \\boxed{X}\", "
        closing = "where X is the final answer, at the end of your response."
    else:
        answer_format = "\"The answer is (X)\", "
        closing = "where X is the answer choice (one capital letter), at the end of your response."

    if is_reasoning:
        opening = "Provide your step-by-step reasoning first, and then print "
    else:
        opening = "Directly provide the correct answer choice without explanation, formatted as: "

    instruction = opening + answer_format + closing
    return [f"Question: {sample['question']}\n{instruction}" for sample in samples]


def _shutdown_engine(llm):
    """Tear down vLLM's parallel state and drop the engine's executor."""
    destroy_model_parallel()
    destroy_distributed_environment()
    # Teardown is best-effort: if the attribute chain is missing, that must
    # not throw away generations that have already been produced.
    with contextlib.suppress(AttributeError):
        del llm.llm_engine.model_executor
    if hasattr(llm, 'engine'):
        del llm.engine


def _free_gpu_memory():
    """Collect garbage, empty the CUDA cache and print the GPU state."""
    import torch.distributed as dist

    gc.collect()
    torch.cuda.empty_cache()
    if dist.is_initialized():
        dist.destroy_process_group()
    # Pause before the next model is loaded by the caller's loop.
    time.sleep(10)
    # nvidia-smi is purely diagnostic output; a missing or failing binary must
    # not abort the run.
    with contextlib.suppress(OSError, subprocess.SubprocessError):
        subprocess.run(["nvidia-smi"], check=False)


def get_model_responses(task, agent, seed, n_gpus, is_reasoning=False):
    """Generate one response per training sample of ``task`` with ``agent``.

    The samples are read from ``./Datasets/train/{task}_train.json``; each
    sample's ``question`` is turned into a prompt, wrapped in the model's chat
    template and sent to vLLM. The engine is torn down again before returning,
    even when generation fails.

    Args:
        task: Dataset name. Tasks in ``_BOXED_ANSWER_TASKS`` are prompted for
            a ``\\boxed{X}`` answer, all others for a ``(X)`` letter choice.
        agent: Key of ``agent_map`` selecting the checkpoint.
        seed: Seed passed to the vLLM engine, so that runs with different
            seeds actually sample differently.
        n_gpus: Passed to vLLM as ``tensor_parallel_size``.
        is_reasoning: If true, ask for step-by-step reasoning before the
            answer; otherwise ask for the answer directly.

    Returns:
        tuple[list[str], list[int]]: for every sample, the text of the first
        completion and the number of tokens in it.

    Raises:
        ValueError: If ``agent`` is not a key of ``agent_map``.
    """
    if agent not in agent_map:
        raise ValueError(
            f"Unknown agent {agent!r}; expected one of {sorted(agent_map)}"
        )
    model_id = agent_map[agent]

    train_samples = read_json(f"./Datasets/train/{task}_train.json")
    prompts = _build_prompts(task, train_samples, is_reasoning)

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    chat_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for prompt in prompts
    ]

    llm_kwargs = {
        "model": model_id,
        "download_dir": "/nas-ssd2/cychen/saved_models",
        "tensor_parallel_size": n_gpus,
        "trust_remote_code": True,
        "seed": seed,
    }
    if agent in _MAX_MODEL_LEN:
        llm_kwargs["max_model_len"] = _MAX_MODEL_LEN[agent]
    llm = LLM(**llm_kwargs)

    sampling_params = SamplingParams(
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        n=N_SAMPLES,
        top_p=TOP_P,
    )
    try:
        outputs = llm.generate(chat_prompts, sampling_params)
        responses = [output.outputs[0].text for output in outputs]
        n_tokens = [len(output.outputs[0].token_ids) for output in outputs]
    finally:
        # Release the engine on every exit path so a failed run does not keep
        # GPU memory pinned.
        _shutdown_engine(llm)
        del llm
        _free_gpu_memory()

    return responses, n_tokens




def main():
    """Run every model in MODELS on the task's training split and save a CSV.

    For each model the reasoning-mode responses are generated, the final
    answer is extracted (``\\boxed{...}`` content for the math tasks, a letter
    choice otherwise) and compared with ``gold_answer`` via ``is_math_equiv``.
    One CSV per model is written to ``./Results/skills/<task>/``, with the
    accuracy (in percent) as the last part of the file name.
    """
    args = parse_args()
    seed_everything(args.seed)

    boxed_answer_tasks = ("MATH500", "AIME24", "AIME25", "GSM8K")
    num_choice = 10 if args.task == "MMLU_Pro" else 4

    # The dataset does not change between models, so read it once.
    train_samples = read_json(f"./Datasets/train/{args.task}_train.json")

    # Create the output directory up front: a missing directory would
    # otherwise only surface in to_csv, after the expensive generation.
    output_dir = f"./Results/skills/{args.task}"
    os.makedirs(output_dir, exist_ok=True)

    for model in MODELS:
        responses, n_tokens = get_model_responses(
            args.task, model, args.seed, args.gpus, is_reasoning=True
        )
        if args.task in boxed_answer_tasks:
            final_answers = [
                remove_boxed(last_boxed_only_string(response))
                for response in responses
            ]
        else:
            final_answers = [
                get_alphabet_choice(response, num_choice=num_choice)
                for response in responses
            ]

        df = pd.DataFrame(train_samples)
        df['responses'] = responses
        df['n_tokens'] = n_tokens
        correctness = [
            is_math_equiv(final_answer, str(gold_answer))
            for final_answer, gold_answer in zip(final_answers, df['gold_answer'])
        ]
        # An empty dataset has no accuracy; report 0.0 instead of dividing by zero.
        acc = round(sum(correctness) / len(correctness) * 100, 2) if correctness else 0.0
        df['correctness'] = correctness
        df.to_csv(
            f"{output_dir}/{model}_responses_train_samples_seed{args.seed}_{acc}.csv",
            index=False,
        )


if __name__ == "__main__":
    main()
        

