import json
import requests
import numpy as np
import random
import argparse
import os
from tqdm import tqdm
from datasets import load_dataset
# Import vLLM
from vllm import LLM, SamplingParams

# Seed Python's stdlib `random` module with a fixed value so that any
# `random`-based choices are repeatable between runs.
# NOTE(review): no call into `random` is visible in this file, and this
# presumably does not control vLLM's own sampling - confirm it is still needed.
random.seed(42)


def args_parse():
    """Define the script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options. ``--model_1`` and
        ``--dataset`` are required; every other option has a default.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Which model to run and where it is placed.
    add("--model_1", required=True, type=str)
    add("--gpu_1", default=0, type=int,
        help="GPU index for the model when not using server")

    # Environment switches: each is False unless its flag is passed.
    environment_switches = (
        ("--use_server",
         "Whether to use vLLM server API instead of direct integration"),
        ("--cluster2",
         "Whether to load model from directory instead of HuggingFace"),
        ("--cluster1",
         "Whether to use cluster1 format for the output"),
    )
    for flag, description in environment_switches:
        add(flag, action="store_true", default=False, help=description)

    # Server port, run size and sampling temperature.
    add("--port", type=int, default=8003)
    add("--round", type=int, default=1)
    add("--num_query", type=int, default=100)
    add("--temp", type=float, default=0.7)

    # Inference strategies: again plain on/off switches.
    strategy_switches = (
        ("--cot",
         "If this is True, you can use Chain-of-Thought during inference."),
        ("--self_consistency",
         "If this is True, apply self-consistency by generating multiple reasoning paths."),
        ("--self_refinement",
         "If this is True, apply self-refinement by giving iterative self-feedback."),
    )
    for flag, description in strategy_switches:
        add(flag, action="store_true", default=False, help=description)

    # Strategy-specific sizes.
    add("--num_samples", type=int, default=5,
        help="Number of reasoning paths to sample for self-consistency.")
    add("--refinement_rounds", type=int, default=16,
        help="Number of refinement rounds for self-refinement.")

    # Benchmark selection.
    add("--dataset", type=str, required=True,
        choices=["gsm8k", "math500", "collegemath", "aime2024", "aime2025"],
        help="Dataset to use for inference")
    add("--data_topic", type=str, default="algebra",
        choices=["algebra", "calculus", "precalculus", "differential_equation",
                 "linear_algebra", "probability", "vector_calculus"],
        help="Topic for CollegeMath dataset (ignored for other datasets)")

    # Where results are written.
    add("--output_dir", type=str, default="results",
        help="Base directory to save the result file")

    return cli.parse_args()

# Function to format each model's prompt


def prompt_formatting(instruction, cot, prompt_dict, model_path):
    """Wrap ``instruction`` in the prompt template that matches the model.

    The model family is derived from ``model_path``: the last path component,
    up to its first ``-``, lower-cased. The first key of ``prompt_dict`` that
    is a substring of that name selects the template.

    Args:
        instruction: Text substituted for ``{instruction}`` in the template.
        cot: When true, ``" Let's think step by step."`` is appended.
        prompt_dict: Mapping of model key -> dict of template strings. The
            ``"prompt_no_input"`` entry is used for the ``alpaca`` and
            ``orca`` keys, ``"prompt"`` for every other key.
        model_path: Model identifier or path, e.g. ``"org/Llama-3-8B"``.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If no key of ``prompt_dict`` matches the model name.
        KeyError: If the selected entry lacks the required template.
    """
    model_name = model_path.split("/")[-1].split("-")[0].lower()

    # First matching key wins, in the dict's iteration order.
    model = next((key for key in prompt_dict if key in model_name), None)
    if model is None:
        # Previously this fell through to an UnboundLocalError on `model`.
        raise ValueError(
            f"No prompt template matches model '{model_name}' "
            f"(from '{model_path}'); available keys: {list(prompt_dict)}"
        )

    template_key = "prompt_no_input" if model in ("alpaca", "orca") else "prompt"
    formatted_prompt = prompt_dict[model][template_key].format(
        instruction=instruction)

    if cot:
        formatted_prompt += " Let's think step by step."

    return formatted_prompt


def refinement_prompt(question, current_response, round, prompt_dict, model_path, dataset):
    """Build the self-refinement prompt for the next round.

    The prompt restates the problem, shows the previous round's response and
    asks the model to review and, if needed, update its solution.

    Args:
        question: The math problem text.
        current_response: The previous round's response. A string, or a
            list/tuple of strings (as returned by the generation helpers),
            whose items are joined with newlines.
        round: Index of the round that produced ``current_response``; the
            prompt labels the upcoming round as ``round + 1``.
        prompt_dict: Mapping of model key -> dict holding a
            ``"prompt_format"`` template with an ``{instruction}`` field.
        model_path: Model identifier or path; its last path component up to
            the first ``-`` (lower-cased) is matched against the keys of
            ``prompt_dict`` by substring.
        dataset: Dataset name; selects the required answer format.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If no key of ``prompt_dict`` matches the model name.
        KeyError: If the selected entry has no ``"prompt_format"`` template.
    """
    model_name = model_path.split("/")[-1].split("-")[0].lower()
    model = next((key for key in prompt_dict if key in model_name), None)
    if model is None:
        # Previously this fell through to an UnboundLocalError on `model`.
        raise ValueError(
            f"No prompt template matches model '{model_name}' "
            f"(from '{model_path}'); available keys: {list(prompt_dict)}"
        )

    # The generation helpers return a list of completions; embed the text
    # itself rather than the list's repr (brackets, quotes, escaped newlines).
    if isinstance(current_response, (list, tuple)):
        current_response = "\n".join(str(item) for item in current_response)

    # Dataset-specific answer format. These are plain (non-f) strings, so the
    # braces in \boxed{answer} and \frac{1}{2} are emitted literally; inside
    # an f-string `{1}{2}` would be evaluated and rendered as "12".
    if dataset == 'gsm8k':
        answer_format = (
            "Your final answer should be a single numerical number, "
            "in the form \\boxed{answer}, at the end of your response."
        )
    elif dataset in ('aime2024', 'aime2025'):
        # Both AIME sets expect an integer, matching the initial prompt.
        answer_format = (
            "Your final answer should be a single integer, "
            "in the form \\boxed{answer}, at the end of your response."
        )
    else:  # math500 and collegemath
        answer_format = (
            "Your final answer should be in the form \\boxed{answer}, "
            "at the end of your response. Please use latex format for your "
            "answer (e.g. \\frac{1}{2} over 1/2)."
        )

    instruction = (
        f"Now here is the given math problem: \n {question}\n"
        "Here are your response from the previous round:\n"
        f"Round {round} response : {current_response}\n"
        "Please review your previous solution carefully. Provide your feedback "
        f"and update your solution if needed. {answer_format} Round {round + 1}.\n"
    )

    # NOTE(review): this reads the "prompt_format" template, whereas
    # prompt_formatting() reads "prompt"/"prompt_no_input" - confirm the
    # template file defines all of these keys.
    formatted_prompt = prompt_dict[model]["prompt_format"].format(
        instruction=instruction)

    return formatted_prompt

# Load vLLM model directly (non-server mode)


def load_vllm_model(model_path, gpu_idx, args):
    """Create a single-GPU vLLM ``LLM`` instance for ``model_path``.

    Side effects: sets ``CUDA_VISIBLE_DEVICES`` to ``gpu_idx`` and, when
    ``args.cluster1`` is set, points ``TRANSFORMERS_CACHE`` and
    ``HF_DATASETS_CACHE`` at ``/<USER>/hf_cache`` (``USER`` falls back to
    ``"user1"``), which is also passed to vLLM as the download directory.

    Args:
        model_path: Model identifier or local path handed to vLLM.
        gpu_idx: GPU index to expose to this process.
        args: Parsed CLI options; only ``args.cluster1`` is read.

    Returns:
        The constructed ``LLM`` object.
    """
    print(f"Loading {model_path} on GPU {gpu_idx}...")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_idx)

    # Stays None (vLLM's default download location) unless cluster1 is enabled.
    cache_dir = None
    if args.cluster1:
        cache_dir = "/{}/hf_cache".format(os.environ.get("USER", "user1"))
        for env_var in ("TRANSFORMERS_CACHE", "HF_DATASETS_CACHE"):
            os.environ[env_var] = cache_dir
        print(f"Using cluster1 computing cache directory: {cache_dir}")

    engine_options = dict(
        model=model_path,
        tensor_parallel_size=1,  # Single GPU per model
        trust_remote_code=True,
        max_model_len=16384,
        dtype="bfloat16",
        download_dir=cache_dir,
    )
    return LLM(**engine_options)

# Function to generate using direct vLLM instance


def generate_vllm_direct(model, prompt, args):
    """Generate completions for one prompt with an in-process vLLM model.

    Args:
        model: Object exposing ``generate(prompts, sampling_params)``.
        prompt: The fully formatted prompt string.
        args: Parsed CLI options; reads ``dataset``, ``temp``,
            ``self_consistency`` and ``num_samples``.

    Returns:
        list[str]: The stripped text of each completion of the single
        request; ``args.num_samples`` completions are requested when
        self-consistency is on, otherwise one.
    """
    # gsm8k gets a 1024-token budget; every other dataset gets 2048.
    # NOTE(review): the server path in generate_answer_vllm also caps
    # "aime2024" at 1024 - confirm which of the two is intended.
    token_budget = 1024 if args.dataset == "gsm8k" else 2048
    sample_count = args.num_samples if args.self_consistency else 1

    params = SamplingParams(
        temperature=args.temp,
        top_p=0.9,
        max_tokens=token_budget,
        n=sample_count,
    )

    # One prompt is submitted, so only the first request output is read.
    request_output = model.generate([prompt], params)[0]
    texts = []
    for completion in request_output.outputs:
        texts.append(completion.text.strip())
    return texts

# Function to generate a response, via the vLLM server API or the direct vLLM instance


def generate_answer_vllm(model_key, model_name, prompt, args, VLLM_API_URLS, vllm_model=None):
    """Generate completions for ``prompt`` via the server API or in-process.

    With ``args.use_server`` the prompt is POSTed to
    ``VLLM_API_URLS[model_key]``; otherwise ``generate_vllm_direct`` is called
    with ``vllm_model``. Generation is best-effort: any failure is printed and
    a placeholder is returned so a long run is not aborted by one bad request.

    Args:
        model_key: Key into ``VLLM_API_URLS`` (server mode only).
        model_name: Model name sent in the payload and used in log messages.
        prompt: The fully formatted prompt string.
        args: Parsed CLI options; reads ``use_server``, ``dataset``, ``temp``,
            ``self_consistency`` and ``num_samples``.
        VLLM_API_URLS: Mapping of model key -> completions endpoint URL.
        vllm_model: In-process model used when not in server mode.

    Returns:
        list[str]: The stripped completion texts, or
        ``["Error generating response."]`` when generation fails.
    """
    if not args.use_server:
        # Use direct vLLM model
        try:
            return generate_vllm_direct(vllm_model, prompt, args)
        except Exception as e:
            print(f"vLLM direct generation failed for {model_name}: {e}")
            return ["Error generating response."]

    # Set max_tokens based on dataset: 1024 for gsm8k/aime2024, else 2048.
    max_tokens = 1024 if args.dataset in ("gsm8k", "aime2024") else 2048

    payload = {
        "model": model_name,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": args.temp,
        "top_p": 0.9,
        "n": args.num_samples if args.self_consistency else 1,
    }
    headers = {"Content-Type": "application/json"}
    try:
        # No timeout is set on purpose: long generations may take minutes.
        response = requests.post(
            VLLM_API_URLS[model_key], json=payload, headers=headers)
        # Surface HTTP errors with their status instead of the misleading
        # "KeyError: 'choices'" an error body would otherwise trigger below.
        response.raise_for_status()
        response_json = response.json()
        return [choice["text"].strip() for choice in response_json["choices"]]
    except Exception as e:
        print(f"vLLM request failed for {model_name}: {e}")
        return ["Error generating response."]

# Read a JSON Lines file (one JSON document per line)


def read_jsonl(path: str) -> list:
    """Read a JSON Lines file and return one parsed value per non-blank line.

    Args:
        path: Path of the ``.jsonl`` file, read as UTF-8.

    Returns:
        The decoded JSON values in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as fh:
        # Lines keep their trailing newline, so a bare truthiness check never
        # filters anything; strip first so blank lines are really skipped.
        return [json.loads(line) for line in fh if line.strip()]


if __name__ == "__main__":
    # Script entry point: parse options, load the model (direct mode), load
    # the dataset and prompt templates, generate one response per question
    # (or a chain of responses when self-refinement is on), then write
    # everything to <output_dir>/<dataset>/<run folder>/conversation.json.
    args = args_parse()

    # Validate option combinations first so a bad invocation fails before the
    # slow model load. An explicit raise (not `assert`) survives `python -O`.
    if args.self_refinement and args.self_consistency:
        raise ValueError(
            "Self-refinement and self-consistency should not be used together.")

    model_names = {
        "model_1": args.model_1
    }

    # Load vLLM model directly if not using server
    vllm_model = None
    if not args.use_server:
        model_path = args.model_1
        if args.cluster2:
            # On cluster2 the model is read from a local /models directory.
            model_name = args.model_1.split("/")[-1]
            model_path = f"/models/{model_name}"
        vllm_model = load_vllm_model(model_path, args.gpu_1, args)

    # vLLM API endpoints for different models
    VLLM_API_URLS = {
        "model_1": f"http://localhost:{args.port}/v1/completions"
    }

    # Set output directory based on dataset
    output_dir = f"{args.output_dir}/{args.dataset}"

    # naming section: the folder name encodes the run configuration
    run_tag = "1_" + args.model_1.split("/")[-1]
    folder_name = f"{run_tag}_q_{args.num_query}"

    # Add data_topic for CollegeMath
    if args.dataset == "collegemath":
        folder_name += f"_topic_{args.data_topic}"

    if args.cot:
        folder_name += '_cot'
    if args.self_consistency:
        folder_name += f'_sc_{args.num_samples}'
    if args.self_refinement:
        folder_name += f'_sr_{args.refinement_rounds}'
        args.num_samples = 1
    if not args.use_server:
        folder_name += '_direct'
    if args.cluster2:
        folder_name += '_cluster2'
    if args.cluster1:
        folder_name += '_cluster1'

    if args.use_server:
        print(
            f"✅ Using vLLM API for inference with model {args.model_1} on {args.dataset} dataset.")
    else:
        print(
            f"✅ Using direct vLLM integration with model {args.model_1} on {args.dataset} dataset.")

    # Load dataset based on dataset argument; every entry is normalised to a
    # dict with at least "question" and "answer".
    if args.dataset == "gsm8k":
        dataset = load_dataset("openai/gsm8k", "main", split="test")
        questions = [{"question": item["question"],
                      "answer": item["answer"]} for item in dataset]
    elif args.dataset == "math500":
        dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
        questions = [{"question": item["problem"], "answer": item["solution"],
                      "level": item["level"]} for item in dataset]
    elif args.dataset == "collegemath":
        dataset = load_dataset("di-zhang-fdu/College_Math_Test", split="test")
        # Keep only rows whose topic (text after the last ".") was requested.
        dataset = dataset.filter(
            lambda x: args.data_topic == x["data_topic"].split(".")[-1])
        questions = [{"question": item["question"], "answer": item["answer"],
                      "data_topic": item["data_topic"]} for item in dataset]
    elif args.dataset == "aime2024":
        dataset = load_dataset("Maxwell-Jia/AIME_2024", split="train")
        questions = [{"question": item["Problem"],
                      "answer": item["Answer"]} for item in dataset]
    elif args.dataset == "aime2025":
        dataset = load_dataset("yentinglin/aime_2025", split="train")
        questions = [{"question": item["problem"],
                      "answer": item["answer"]} for item in dataset]

    # Load prompt templates for math reasoning
    prompt_path = "src/prompt_template.json"
    with open(prompt_path, "r") as prompt_file:
        prompt_dict = json.load(prompt_file)

    generated_description = []
    for idx, item in enumerate(tqdm(questions[:args.num_query])):
        question = item["question"]
        answer = item["answer"]

        # Construct prompt based on dataset
        if args.dataset == 'gsm8k':
            instruction = f"Now here is the given math problem: \n {question}\n Give your reasoning first and then provide your final answer. Your final answer should be a single numerical number, in the form \\boxed{{answer}}, at the end of your response."
        elif args.dataset in ['aime2024', 'aime2025']:
            instruction = f"Now here is the given math problem: \n {question}\n Give your reasoning first and then provide your final answer. Your final answer should be a single integer, in the form \\boxed{{answer}}, at the end of your response."
        else:
            # Braces are doubled so the f-string emits a literal \frac{1}{2};
            # single braces would be evaluated and rendered as \frac12.
            instruction = f"Now here is the given math problem: \n {question}\n Give your reasoning first and then provide your final answer. Your final answer should be in the form \\boxed{{answer}}, at the end of your response. Please use latex format for your answer (e.g. \\frac{{1}}{{2}} over 1/2)"

        # Per-model history: entry 0 is the initial prompt, later entries
        # hold the list of completions returned for each generation call.
        agent_contexts = {
            "model_1": [{"model": model_names["model_1"], "content": prompt_formatting(instruction, args.cot, prompt_dict, model_names["model_1"])}]
        }

        # Generate responses
        for model_key in agent_contexts:
            if args.self_refinement:
                for round_idx in range(args.refinement_rounds):
                    if round_idx == 0:
                        formatted_prompt = agent_contexts[model_key][-1]["content"]
                    else:
                        # The stored content is a list of completions (one
                        # item here, since n == 1 without self-consistency);
                        # pass the text itself so the prompt does not embed
                        # the list's repr.
                        previous = agent_contexts[model_key][-1]["content"]
                        current_response = previous[0] if previous else ""
                        formatted_prompt = refinement_prompt(
                            question, current_response, round_idx, prompt_dict, model_names[model_key], args.dataset)
                    response = generate_answer_vllm(
                        model_key, model_names[model_key], formatted_prompt, args, VLLM_API_URLS, vllm_model)
                    agent_contexts[model_key].append(
                        {"model": model_names[model_key], "content": response})
            else:
                formatted_prompt = agent_contexts[model_key][-1]["content"]
                response = generate_answer_vllm(
                    model_key, model_names[model_key], formatted_prompt, args, VLLM_API_URLS, vllm_model)
                agent_contexts[model_key].append(
                    {"model": model_names[model_key], "content": response})

        # Save all responses (everything after the initial prompt entry)
        models_response = {
            args.model_1+'1': [entry["content"] for entry in agent_contexts["model_1"][1:]]
        }

        # Create response object with different fields based on dataset
        response_obj = {
            "question_id": idx,
            "question": question,
            "agent_response": models_response,
            "answer": answer
        }

        # Add dataset-specific fields
        if args.dataset == "math500":
            response_obj["level"] = item["level"]
        elif args.dataset == "collegemath":
            response_obj["data_topic"] = item["data_topic"].split(".")[-1]

        generated_description.append(response_obj)

    # Save results
    output_path = f"{output_dir}/{folder_name}"
    os.makedirs(output_path, exist_ok=True)
    print(f"The result file '{output_path}/conversation.json' is saving...")
    with open(f"{output_path}/conversation.json", "w") as f:
        json.dump(generated_description, f, indent=4)

    print("✅ All done!!")
