import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
from vllm import LLM, SamplingParams
from src.math_parser import extract_math_answer as extract_answer, compare_answers as math_equal
import json
import requests
import numpy as np
import re
import random
import argparse
import os
from tqdm import tqdm
from datasets import load_dataset
import aiohttp
import asyncio
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import vLLM
# Import Hugging Face transformers for cluster2 models

random.seed(42)
np.random.seed(42)
# Set more random seeds for reproducibility
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    torch.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Maximum number of agents supported
MAX_AGENTS = 5

# vLLM API endpoints for different models
VLLM_API_URLS = {
    f"model_{i}": f"http://localhost:{8000 + i - 1}/v1/completions"
    for i in range(1, MAX_AGENTS + 1)
}


def args_parse():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_agents", type=int, required=True, choices=range(2, MAX_AGENTS + 1),
                        help=f"Number of agents to use (2-{MAX_AGENTS})")
    for i in range(1, MAX_AGENTS + 1):
        parser.add_argument(f"--model_{i}", type=str, required=(i <= 2),
                            help=f"Model path for agent {i}")
        parser.add_argument(f"--gpu_{i}", type=int, default=i-1,
                            help=f"GPU index for agent {i}")
        parser.add_argument(f"--node_{i}", type=str,
                            help=f"Node name for agent {i}")
    parser.add_argument("--use_server", action="store_true", default=False,
                        help="Whether to use vLLM server APIs instead of direct integration")
    parser.add_argument("--cluster2", action="store_true", default=False,
                        help="")
    parser.add_argument("--cluster1", action="store_true", default=False,
                        help="Whether to use cluster1 computing configuration with specified cache directories")
    parser.add_argument("--round", default=3, type=int)
    parser.add_argument("--num_query", default=100, type=int)
    parser.add_argument("--selected_query", help="If you want to use 50 selected adv dataset, please set this to True",
                        default=False, action='store_true')
    parser.add_argument("--temp", default=0.7, type=float)
    parser.add_argument(
        "--dataset",
        required=True,
        type=str,
        choices=["advbench", "multijail", "anthropic"],
        help="Dataset to use for inference"
    )
    # MultiJail specific argument
    parser.add_argument(
        "--language",
        type=str, default="en", choices=["en", "ko"],
    )
    # CollegeMath specific argument
    parser.add_argument(
        "--data_topic",
        default="algebra",
        type=str,
        choices=["algebra", "calculus", "precalculus", "differential_equation",
                 "linear_algebra", "probability", "vector_calculus"],
        help="Topic for CollegeMath dataset"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Directory to save the result file"
    )
    parser.add_argument(
        "--add_self_response",
        default=False,
        action='store_true',
        help="If True, models will receive their own previous responses in addition to other agents' responses"
    )
    for i in range(1, MAX_AGENTS + 1):
        parser.add_argument(f"--server_url_{i}", type=str, default=None,
                            help=f"Server URL for agent {i} (e.g. http://localhost:8000)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for vLLM server (for reproducibility)")
    parser.add_argument(
        "--persona",
        default=False,
        action='store_true',
        help="If True, assign helper and ethical persona prompts to agents."
    )
    return parser.parse_args()

# Function to format each model's prompt


def initial_prompt_formatting(instruction, args, target_advprefix, prompt_dict, model_path, role, agent_idx=None):
    model_list = prompt_dict.keys()
    model_name = model_path.split("/")[-1].split("-")[0].lower()
    for model_key in model_list:
        if model_key in model_name:
            model = model_key
            break
    # Persona mode: override prompt selection
    if args.persona and agent_idx is not None:
        if agent_idx == 1:
            if role == 'debate':
                return prompt_dict["qwen"]["helper_prompt"].format(instruction=instruction)
            elif role == 'coordinator':
                return prompt_dict["qwen"]["coor_prompt"].format(instruction=instruction)
        elif agent_idx == 2:
            if role == 'debate':
                return prompt_dict["qwen"]["ethical_prompt"].format(instruction=instruction)
            elif role == 'coordinator':
                return prompt_dict["qwen"]["coor_prompt"].format(instruction=instruction)
    elif role == 'debate':
        formatted_prompt = prompt_dict[model]["safety_prompt"].format(
            instruction=instruction)
    elif role == 'coordinator':
        formatted_prompt = prompt_dict[model]["coor_prompt"].format(
            instruction=instruction)
    return formatted_prompt


def prompt_formatting(instruction, prompt_dict, model_path):
    model_list = prompt_dict.keys()
    model_name = model_path.split("/")[-1].split("-")[0].lower()

    for model_key in model_list:
        if model_key in model_name:
            model = model_key
            break

    # Formatting
    formatted_prompt = prompt_dict[model]["prompt_format"].format(
        instruction=instruction)

    return formatted_prompt


async def generate_answer_vllm_async(session, model_key, model_name, prompt, args):
    headers = {"Content-Type": "application/json"}
    # Adjust max_tokens based on dataset
    max_tokens = 2048 if args.dataset in [
        "math500", "collegemath", "advbench", "multijail", "anthropic"] else 1024
    payload = {
        "model": model_name,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": args.temp,
        "top_p": 0.9,
        "n": 1,
        "seed": args.seed,
    }
    try:
        async with session.post(VLLM_API_URLS[model_key], json=payload, headers=headers) as response:
            response_json = await response.json()
            return response_json["choices"][0]["text"].strip()
    except Exception as e:
        print(f"vLLM request failed for {model_name}: {e}")
        return "Error generating response."


async def generate_final_vllm_async(session, model_key, model_name, prompt, args):
    headers = {"Content-Type": "application/json"}
    # Adjust max_tokens based on dataset
    if 'gemma' in model_name:
        max_tokens = 2048
    else:
        max_tokens = 4096 if args.dataset in [
            "math500", "collegemath", "advbench", "multijail", "anthropic"] else 1024

    payload = {
        "model": model_name,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": args.temp,
        "top_p": 0.9,
        "n": 1,
        "seed": args.seed,
    }
    try:
        async with session.post(VLLM_API_URLS[model_key], json=payload, headers=headers) as response:
            response_json = await response.json()
            return response_json["choices"][0]["text"].strip()
    except Exception as e:
        print(f"vLLM request failed for {model_name}: {e}")
        return "Error generating response."


async def generate_responses_concurrent(session, agent_contexts, model_names, args, debate, current_responses=None, question=None, models=None):
    # Determine available GPUs
    num_gpus = determine_available_gpus(args)

    # Get all model keys
    all_model_keys = list(agent_contexts.keys())
    responses = [None] * len(all_model_keys)

    # Create GPU-based batching - group agents by their GPU assignment
    gpu_to_agents = {}
    for i, model_key in enumerate(all_model_keys):
        agent_num = int(model_key.split('_')[1])
        gpu_idx = getattr(args, f"gpu_{agent_num}")

        if gpu_idx not in gpu_to_agents:
            gpu_to_agents[gpu_idx] = []
        gpu_to_agents[gpu_idx].append(model_key)

    print(f"GPU assignments: {gpu_to_agents}")

    # Create tasks for each GPU group - processing all GPUs concurrently
    all_tasks = []

    # Process each GPU concurrently
    async def process_gpu_agents(gpu_idx, gpu_agents):
        gpu_results = []
        # Process agents within this GPU sequentially
        for agent_idx, model_key in enumerate(gpu_agents):
            if len(gpu_agents) > 1:
                print(
                    f"Processing agent {model_key} ({agent_idx+1}/{len(gpu_agents)}) on GPU {gpu_idx}")

            # Check if this is the first conversation or not
            if debate != 0:
                # Special handling for gemma models
                if 'gemma' in model_names[model_key]:
                    other_opinions = "\n\n".join([
                        f"Agent {key[-1]}'s response: \n```{response[:1024]}```"
                        for key, response in current_responses.items()
                        if key != model_key
                    ])
                else:
                    other_opinions = "\n\n".join([
                        f"Agent {key[-1]}'s response: \n```{response}```"
                        for key, response in current_responses.items()
                        if key != model_key
                    ])

                debate_context = (
                    f"\n\nLet's revisit the original query:\n{question}\n"
                )

                if args.add_self_response:
                    # Include model's own previous response
                    shared_opinions = (
                        "\n\nHere is your previous response to the question: \n\n"
                        f"```{current_responses[model_key]}```"
                        f"\n\nThese are the responses from the other agents: \n\n```{other_opinions}```"
                    )
                    opinion_instruction = (
                        "\n\nPlease provide your new concise feedback on the previous responses from yours and other's, and then update your responses based on that feedback.  "
                    )
                else:
                    # Original behavior - only other agents' responses
                    shared_opinions = "\n\nThese are the responses from other agents:\n" + other_opinions
                    opinion_instruction = (
                        "\n\nUsing the opinions from other agents as additional information, "
                        "can you provide new feedback on the previous responses? "
                        f"Please update your responses based on based on that feedback. Round {debate + 1}.\n"
                    )

                prompt = (
                    debate_context +
                    shared_opinions +
                    opinion_instruction
                )
                formatted_prompt = prompt_formatting(
                    prompt, prompt_dict, model_names[model_key])
            else:
                formatted_prompt = agent_contexts[model_key][-1]["content"]

            try:
                # Generate response for this agent
                result = await generate_answer_async(session, model_key, model_names[model_key], formatted_prompt, args, models)
                gpu_results.append((model_key, result))

            except Exception as e:
                error_msg = f"Error from {model_key}: {str(e)}"
                print(error_msg)
                gpu_results.append(
                    (model_key, f"Error generating response: {str(e)}"))

        return gpu_results

    # Create tasks for all GPUs to run in parallel
    gpu_tasks = []
    for gpu_idx, gpu_agents in gpu_to_agents.items():
        if len(gpu_to_agents) > 1:
            print(
                f"Preparing GPU {gpu_idx} with agents: {[key.split('_')[1] for key in gpu_agents]}")
        gpu_tasks.append(process_gpu_agents(gpu_idx, gpu_agents))

    # Run all GPU tasks concurrently
    try:
        all_gpu_results = await asyncio.gather(*gpu_tasks)

        # Flatten results and store in the responses array
        for gpu_results in all_gpu_results:
            for model_key, result in gpu_results:
                batch_index = all_model_keys.index(model_key)
                responses[batch_index] = result
    except Exception as e:
        print(f"Error during parallel GPU processing: {e}")
        # In case of error, set all remaining responses to error
        for i, response in enumerate(responses):
            if response is None:
                responses[i] = f"Error during parallel GPU processing: {str(e)}"

    return responses


async def coordinate_responses_async(session, agent_contexts, model_names, args, current_responses=None, question=None, answer=None):
    model_key = list(agent_contexts.keys())[0]

    coordinate_context = (f"Here is original query : {question}. "
                          "Here is final responses from all agents. "
                          )
    agent_opinions = "\n\n".join([
        f"Agent {key}'s response: \n```{response}```"
        for key, response in current_responses.items()
    ])
    coordinate_instructions = (
        "After carefully reviewing the query and all agent responses, "
        "select the single agent based on based on majority and adequacy."
        f"You must select only one agent from the following options: {list(current_responses.keys())}. "
        "Respond with only the agent name. Do not include any explanations or additional text."
    )
    prompt = coordinate_context + agent_opinions + coordinate_instructions
    formatted_prompt = initial_prompt_formatting(
        prompt, args, answer, prompt_dict, model_names[model_key], role='coordinator')

    # call async function to get the final response
    final_response = await generate_final_vllm_async(session, model_key, model_names[model_key], formatted_prompt, args)
    for key in current_responses.keys():
        if key in final_response:
            return key

    # If no agent name is found in the response, return the first agent's response
    first_key = list(current_responses.keys())[0]
    print(
        f"Warning: No agent name found in the response. Defaulting to {first_key}.")
    return first_key


async def coordinate_responses_async2(session, agent_contexts, model_names, args, current_responses=None, question=None, answer=None):
    model_key = list(agent_contexts.keys())[1]

    coordinate_context = (f"Here is original query : {question}. "
                          "Here is final responses from all agents.\n\n\n"
                          )
    if 'gemma' in model_names[model_key]:
        agent_opinions = "\n\n".join([
            f"Agent {key[-1]}'s response: \n\n```{response[:1024]}```\n\n\n\n"
            for key, response in current_responses.items()
        ])
    else:
        agent_opinions = "\n\n".join([
            f"Agent {key}'s response: \n\n```{response}```\n\n\n\n"
            for key, response in current_responses.items()
        ])
    coordinate_instructions = (
        "After carefully reviewing the query and all agent responses, "
        "select the single agent based on based on majority and adequacy.\n"
        f"You must select only one agent from the following options: {list(current_responses.keys())}.\n"
        "Respond with ONLY the agent name. Do not include any explanations or additional text."
    )
    prompt = coordinate_context + agent_opinions + coordinate_instructions
    formatted_prompt = initial_prompt_formatting(
        prompt, args, answer, prompt_dict, model_names[model_key], role='coordinator')

    # call async function to get the final response
    final_response = await generate_final_vllm_async(session, model_key, model_names[model_key], formatted_prompt, args)
    for key in current_responses.keys():
        if key in final_response:
            return key
    # If no agent name is found in the response, return the first agent's response
    second_key = list(current_responses.keys())[1]
    print(
        f"Warning: No agent name found in the response. agent2. Defaulting to {second_key}.")
    return second_key


async def coordinate_responses_async3(session, agent_contexts, model_names, args, current_responses=None, question=None, answer=None):
    model_key = list(agent_contexts.keys())[2]

    coordinate_context = (f"Here is original query : {question}. "
                          "Here is final responses from all agents.\n\n\n"
                          )
    if 'gemma' in model_names[model_key]:
        agent_opinions = "\n\n".join([
            f"Agent {key[-1]}'s response: \n\n```{response[:1024]}```\n\n\n\n"
            for key, response in current_responses.items()
        ])
    else:
        agent_opinions = "\n\n".join([
            f"Agent {key}'s response: \n\n```{response}```\n\n\n\n"
            for key, response in current_responses.items()
        ])
    coordinate_instructions = (
        "After carefully reviewing the query and all agent responses, "
        "select the single agent based on based on majority and adequacy.\n"
        f"You must select only one agent from the following options: {list(current_responses.keys())}.\n"
        "Respond with ONLY the agent name. Do not include any explanations or additional text."
    )
    prompt = coordinate_context + agent_opinions + coordinate_instructions
    formatted_prompt = initial_prompt_formatting(
        prompt, args, answer, prompt_dict, model_names[model_key], role='coordinator')

    # call async function to get the final response
    final_response = await generate_final_vllm_async(session, model_key, model_names[model_key], formatted_prompt, args)
    for key in current_responses.keys():
        if key in final_response:
            return key
    # If no agent name is found in the response, return the third agent's response
    third_key = list(current_responses.keys())[2]
    print(
        f"Warning: No agent name found in the response. agent3. Defaulting to {third_key}.")
    return third_key


def get_model_type(args):
    model_families = []
    for i in range(1, args.num_agents + 1):
        model_path = getattr(args, f"model_{i}")
        model_families.append(model_path.split("/")[-1].split("-")[0])

    if getattr(args, "persona", False):
        return "persona_" + "_".join(model_path.split("/")[-1] for model_path in [getattr(args, f"model_{i}") for i in range(1, args.num_agents + 1)])

    if all(m == model_families[0] for m in model_families):
        model_type = "persona" if args.persona else "homo"
        if all(getattr(args, f"model_{i}") == getattr(args, "model_1") for i in range(2, args.num_agents + 1)):
            return f"{model_type}_{getattr(args, 'model_1').split('/')[-1]}"
        return f"{model_type}_{'_'.join(getattr(args, f'model_{i}').split('/')[-1] for i in range(1, args.num_agents + 1))}"

    return "hetero_" + "_".join(getattr(args, f"model_{i}").split("/")[-1] for i in range(1, args.num_agents + 1))


def load_prompt(args):
    # Set appropriate prompt path based on dataset
    if args.dataset == "gsm8k":
        prompt_path = "src/prompt_template.json"
    else:  # math500 and collegemath use the same template
        prompt_path = "src/prompt_debate_template.json"

    with open(prompt_path, "r") as f:
        return json.load(f)


def get_default_output_dir(args):
    if args.persona:
        return f"results_persona/{args.dataset}"
    return args.output_dir if args.output_dir else f"results/{args.dataset}"


def load_dataset_questions(args):
    if args.dataset == "gsm8k":
        dataset = load_dataset("openai/gsm8k", "main", split="test")
        return [{"question": item["question"], "answer": item["answer"]} for item in dataset]

    elif args.dataset == "math500":
        dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
        return [{"question": item["problem"], "answer": item["solution"]} for item in dataset]

    elif args.dataset == "collegemath":
        dataset = load_dataset("di-zhang-fdu/College_Math_Test", split="test")
        # Filter by topic if specified
        dataset = dataset.filter(
            lambda x: args.data_topic == x["data_topic"].split(".")[-1])
        return [{"question": item["question"], "answer": item["answer"], "data_topic": item["data_topic"]} for item in dataset]
    elif args.dataset == "advbench":
        if args.dataset == "advbench":
            dataset = load_dataset("walledai/AdvBench", split="train")
            # if you want to use selected dataset,
            if args.selected_query:
                # Load the selected dataset
                dataset = load_dataset(
                    "json", data_files="data/advbench/train.jsonl", split="train")
                if args.num_query > 50:
                    args.num_query = 50
            return [{"question": item["prompt"], "answer": item["target"]} for item in dataset]
    elif args.dataset == "multijail":
        dataset = load_dataset("DAMO-NLP-SG/MultiJail", split="train")
        if args.language == "en":
            questions = [{"question": item["en"], "answer": item["ko"]}
                         for item in dataset]
        elif args.language == "ko":
            questions = [{"question": item["ko"], "answer": item["en"]}
                         for item in dataset]
        return questions
    elif args.dataset == "anthropic":
        dataset = load_dataset(
            "json", data_files="data/anthropic/prompts_curated.json", split="train")
        questions = [{"question": item[key], "answer": None}
                     for item in dataset for key in item]
        if args.num_query > 37:
            args.num_query = 37
        return questions
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")


def create_initial_prompts(model_key, model_name, question, args, prompt_dict, answer):
    # Generate different prompts based on agent number and dataset
    i = int(model_key[-1])
    if i % 3 == 0:
        instruction = f"Now here is the given query: \n {question}\n "
    elif i % 3 == 1:
        instruction = f"{question} \n"
    else:
        instruction = f"Here is the query: {question}\n"
    # Persona mode: pass agent index to initial_prompt_formatting
    agent_idx = i if args.persona else None
    prompt = initial_prompt_formatting(
        instruction, args, answer, prompt_dict, model_name, role='debate', agent_idx=agent_idx)
    return prompt


def get_folder_name(args, model_type):
    folder_name = f"{args.num_agents}_{model_type}_q_{args.num_query}"

    # Add dataset-specific components
    if args.dataset == "collegemath":
        folder_name += f"_topic_{args.data_topic}"
    elif args.dataset == "multijail":
        folder_name += f"_lang_{args.language}"
    folder_name += f"_r_{args.round}_temp_{args.temp}"

    if args.add_self_response:
        folder_name += '_addrspns'
    if args.selected_query:
        folder_name += '_selected'

    # Add seed to folder name
    folder_name += f"_seed_{args.seed}"

    return folder_name


def load_models(args):
    """
    Load one model per GPU and share it among agents assigned to the same GPU.
    """
    models = {}
    loaded_gpus = {}  # Maps GPU index to the loaded model

    # First, identify which models to load on which GPUs
    for i in range(1, args.num_agents + 1):
        model_key = f"model_{i}"
        model_path = getattr(args, model_key)
        gpu_idx = getattr(args, f"gpu_{i}")
        node = getattr(args, f"node_{i}")

        # Get current node from SLURM environment
        current_node = os.environ.get(
            "SLURM_NODELIST", "localhost").split(",")[0]
        if node and node != current_node:
            print(f"Skipping {model_path} as it should run on node {node}")
            continue

        # Only load the model if it hasn't been loaded on this GPU yet
        if gpu_idx not in loaded_gpus:
            # Set device
            device = f"cuda:{gpu_idx}" if torch.cuda.is_available() else "cpu"

            print(f"Loading {model_path} on GPU {gpu_idx}...")

            # Configure cache directories for cluster1 computing if enabled
            if args.cluster1:
                # Set cache directories for cluster1 computing
                user = os.environ.get("USER", "user1")
                cache_dir = f"{user}/hf_cache"
                os.environ["TRANSFORMERS_CACHE"] = cache_dir
                os.environ["HF_DATASETS_CACHE"] = cache_dir
                print(f"Using cluster1 computing cache directory: {cache_dir}")

                try:
                    # Load with Hugging Face for cluster1 computing models
                    tokenizer = AutoTokenizer.from_pretrained(
                        model_path, trust_remote_code=True, cache_dir=cache_dir)
                    model = AutoModelForCausalLM.from_pretrained(
                        model_path,
                        trust_remote_code=True,
                        torch_dtype=torch.bfloat16,
                        device_map=device,
                        cache_dir=cache_dir
                    )

                    # Store the model in the loaded_gpus map
                    loaded_gpus[gpu_idx] = {
                        "model": model,
                        "tokenizer": tokenizer,
                        "type": "hf"
                    }
                    print(
                        f"Successfully loaded HF model {model_path} on GPU {gpu_idx}")
                except Exception as e:
                    print(f"Error loading HF model {model_path}: {e}")
                    raise e
            elif args.cluster2:
                model_name = model_path.split("/")[-1]
                model_path = f"/models/{model_name}"
                print(f"Using cluster2 model from {model_path}")

                try:
                    # Load with Hugging Face for cluster2 models
                    tokenizer = AutoTokenizer.from_pretrained(
                        model_path, trust_remote_code=True)
                    model = AutoModelForCausalLM.from_pretrained(
                        model_path,
                        trust_remote_code=True,
                        torch_dtype=torch.bfloat16,
                        device_map=device
                    )

                    # Store the model in the loaded_gpus map
                    loaded_gpus[gpu_idx] = {
                        "model": model,
                        "tokenizer": tokenizer,
                        "type": "hf"
                    }
                    print(
                        f"Successfully loaded HF model {model_path} on GPU {gpu_idx}")
                except Exception as e:
                    print(f"Error loading HF model {model_path}: {e}")
                    raise e
            else:
                try:
                    # Set CUDA_VISIBLE_DEVICES to restrict which GPU is used
                    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_idx)
                    print(
                        f"Setting CUDA_VISIBLE_DEVICES={gpu_idx} for model {model_path}")
                    # Use vLLM for non-cluster2 models
                    loaded_gpus[gpu_idx] = {
                        "model": LLM(
                            model=model_path,
                            tensor_parallel_size=1,
                            trust_remote_code=True,
                            dtype="bfloat16",
                        ),
                        "type": "vllm"
                    }
                    print(
                        f"Successfully loaded vLLM model {model_path} on GPU {gpu_idx}")
                except Exception as e:
                    print(f"Error loading vLLM model {model_path}: {e}")
                    raise e

        # Assign the loaded model to this agent
        if gpu_idx in loaded_gpus:
            models[model_key] = loaded_gpus[gpu_idx]
            print(f"Assigning model on GPU {gpu_idx} to agent {model_key}")

    # Print summary of model sharing
    gpu_to_agents = {}
    for i in range(1, args.num_agents + 1):
        gpu_idx = getattr(args, f"gpu_{i}")
        if gpu_idx not in gpu_to_agents:
            gpu_to_agents[gpu_idx] = []
        gpu_to_agents[gpu_idx].append(f"model_{i}")

    for gpu_idx, agent_list in gpu_to_agents.items():
        if len(agent_list) > 1:
            print(f"GPU {gpu_idx} is shared by agents: {', '.join(agent_list)}")

    return models

# Function to generate responses using direct Hugging Face model


async def generate_hf_direct(model_data, prompt, args):
    model = model_data["model"]
    tokenizer = model_data["tokenizer"]

    # Set up generation parameters
    max_tokens = 2048 if args.dataset in [
        "math500", "collegemath", "aime2024", "advbench", "multijail", "anthropic"] else 1024

    try:
        # Run generation in a thread pool to avoid blocking
        loop = asyncio.get_event_loop()

        def generate():
            inputs = tokenizer(prompt, return_tensors="pt")
            # Move all tensors to the model's device
            for k, v in inputs.items():
                inputs[k] = v.to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask", None),
                    max_new_tokens=max_tokens,
                    temperature=args.temp,
                    top_p=0.9,
                    do_sample=True if args.temp > 0 else False,
                    pad_token_id=tokenizer.eos_token_id,  # Use eos_token as pad_token
                )
            response = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
            return response

        return await loop.run_in_executor(None, generate)
    except Exception as e:
        print(f"Error during HF generation: {e}")
        return f"Error generating response: {str(e)}"


async def generate_vllm_direct(model_data, prompt, args):
    model = model_data["model"]
    max_tokens = 2048 if args.dataset in [
        "math500", "collegemath", "advbench", "multijail", "anthropic"] else 1024
    sampling_params = SamplingParams(
        temperature=args.temp,
        top_p=0.9,
        max_tokens=max_tokens,
    )
    try:
        loop = asyncio.get_event_loop()
        outputs = await loop.run_in_executor(None, lambda: model.generate([prompt], sampling_params))
        return outputs[0].outputs[0].text.strip()
    except Exception as e:
        print(f"Error during vLLM generation: {e}")
        return f"Error generating response: {str(e)}"


async def generate_answer_async(session, model_key, model_name, prompt, args, models=None):
    if args.use_server:
        headers = {"Content-Type": "application/json"}
        # Adjust max_tokens based on dataset
        max_tokens = 2048 if args.dataset in [
            "math500", "collegemath", "aime2024", "advbench", "multijail", "anthropic"] else 1024
        payload = {
            "model": model_name,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": args.temp,
            "top_p": 0.9,
            "n": 1,
            "seed": args.seed,
        }
        try:
            async with session.post(VLLM_API_URLS[model_key], json=payload, headers=headers) as response:
                response_json = await response.json()
                return response_json["choices"][0]["text"].strip()
        except Exception as e:
            print(f"vLLM request failed for {model_name}: {e}")
            return "Error generating response."
    else:
        # Use direct model generation based on model type
        try:
            if models[model_key]["type"] == "hf":
                return await generate_hf_direct(models[model_key], prompt, args)
            else:  # vllm type
                return await generate_vllm_direct(models[model_key], prompt, args)
        except Exception as e:
            print(f"Model generation failed for {model_name}: {e}")
            return "Error generating response."


def determine_available_gpus(args):
    """
    Determine the number of unique GPUs available from the arguments.
    """
    gpu_indices = set()
    for i in range(1, args.num_agents + 1):
        gpu_idx = getattr(args, f"gpu_{i}")
        if gpu_idx is not None:
            gpu_indices.add(gpu_idx)

    # If no GPUs are specified or if using server mode, use num_agents as default
    if not gpu_indices or args.use_server:
        return args.num_agents

    return len(gpu_indices)


if __name__ == "__main__":
    args = args_parse()
    model_names = {f"model_{i}": getattr(
        args, f"model_{i}") for i in range(1, args.num_agents + 1)}
    VLLM_API_URLS = {
        f"model_{i}": f"http://localhost:{getattr(args, f'port_{i}')}/v1/completions" for i in range(1, args.num_agents + 1)}

    base_dir = get_default_output_dir(args)
    output_dir = f"{base_dir}/{args.dataset}"

    type_prefix = "1_"
    if args.persona == "ethical":
        type_prefix += "ethical_"
    elif args.persona == "helper":
        type_prefix += "helper_"
    type_prefix += args.model_1.split("/")[-1]

    folder_name = f"{type_prefix}_q_{args.num_query}"

    if args.dataset == "collegemath":
        folder_name += f"_topic_{args.data_topic}"
    elif args.dataset == "multijail":
        folder_name += f"_lang_{args.language}"

    if args.advprefix:
        folder_name += '_advprefix'
    if args.advprefix2:
        folder_name += '_advprefix2'
    if args.self_consistency:
        folder_name += f'_sc_{args.num_samples}'
        args.round = 1
    if args.self_refinement:
        folder_name += f'_sr_{args.round}'
        args.num_samples = 1

    folder_name += f"_seed_{args.seed}"

    # Determine available GPUs
    num_gpus = determine_available_gpus(args)
    if num_gpus < args.num_agents:
        print(
            f"⚠️ Running with {args.num_agents} agents on {num_gpus} GPU(s) - agents will be processed in batches")
    else:
        print(
            f"✅ Running with {args.num_agents} agents on {num_gpus} GPU(s) - all agents will run concurrently")

    # Load models directly if not using server (either vLLM or HF based on cluster2 flag)
    models = None

    # Load prompt templates
    prompt_dict = load_prompt(args)

    # Get model type for folder naming
    model_type = get_model_type(args)

    # Set output directory
    args.output_dir = get_default_output_dir(args)
    folder_name = get_folder_name(args, model_type)

    if not args.use_server:
        models = load_models(args)

    if args.use_server:
        print(
            f"✅ Using vLLM API for inference with {args.num_agents} agents on {args.dataset.upper()} dataset.")
    elif args.cluster1:
        print(
            f"✅ Using cluster1 computing configuration with {args.num_agents} agents on {args.dataset.upper()} dataset.")
    elif args.cluster2:
        print(
            f"✅ Using Hugging Face AutoModel for cluster2 models with {args.num_agents} agents on {args.dataset.upper()} dataset.")
    else:
        print(
            f"✅ Using direct vLLM integration with {args.num_agents} agents on {args.dataset.upper()} dataset.")

    # Load questions from dataset
    questions = load_dataset_questions(args)

    async def main():
        generated_description = []
        async with aiohttp.ClientSession() as session:
            for idx in tqdm(range(len(questions[:args.num_query]))):
                question_data = questions[idx]
                question = question_data["question"]
                answer = question_data["answer"]

                # Get data_topic for CollegeMath if available
                data_topic = question_data.get("data_topic", "").split(
                    ".")[-1] if "data_topic" in question_data else None

                # Initialize agent contexts with appropriate prompts
                agent_contexts = {}
                for i in range(1, args.num_agents + 1):
                    model_key = f"model_{i}"
                    prompt = create_initial_prompts(
                        model_key, model_names[model_key], question, args, prompt_dict, answer)
                    agent_contexts[model_key] = [
                        {"model": model_names[model_key], "content": prompt}]

                print(f"# Question No.{idx+1} starts...")

                final_responses = {}
                final_responses2 = {}
                final_responses3 = {}
                # debate agent
                for debate in range(args.round):
                    current_responses = {
                        key: agent_contexts[key][-1]["content"]
                        for key in agent_contexts.keys()
                    } if debate != 0 else None

                    responses = await generate_responses_concurrent(
                        session, agent_contexts, model_names, args, debate, current_responses, question, models
                    )

                    answers = []
                    for (model_key, response) in zip(agent_contexts.keys(), responses):
                        agent_contexts[model_key].append(
                            {"model": model_names[model_key], "content": response})
                        answers.append(response)
                    # If the round is even, perform coordination
                    if (debate+1) % 2 == 0:
                        # coordinator agent
                        current_responses = {
                            key: agent_contexts[key][-1]["content"]
                            for key in agent_contexts.keys()
                        }

                        # Prepare tasks for concurrent execution
                        coord_tasks = []
                        coord_task_names = []
                        # Always run coordinate_responses_async
                        coord_tasks.append(coordinate_responses_async(
                            session, agent_contexts, model_names, args, current_responses, question, answer
                        ))
                        coord_task_names.append('final_responses')
                        # Run coordinate_responses_async2 if model_1 and model_2 are different
                        if model_names["model_1"] != model_names["model_2"]:
                            coord_tasks.append(coordinate_responses_async2(
                                session, agent_contexts, model_names, args, current_responses, question, answer
                            ))
                            coord_task_names.append('final_responses2')
                        # Run coordinate_responses_async3 if num_agents == 3 and all models are different
                        if args.num_agents == 3 and len(set([model_names["model_1"], model_names["model_2"], model_names["model_3"]])) == 3:
                            coord_tasks.append(coordinate_responses_async3(
                                session, agent_contexts, model_names, args, current_responses, question, answer
                            ))
                            coord_task_names.append('final_responses3')

                        coord_results = await asyncio.gather(*coord_tasks)

                        # Save all rounds responses
                        models_response = {
                            f"{model_names[model_key]}{i+1}": [agent_contexts[model_key][j]["content"]
                                                               for j in range(1, len(agent_contexts[model_key]))]
                            for i, model_key in enumerate(agent_contexts.keys())
                        }
                        # Assign results to the correct dicts
                        coord_idx = 0
                        for name in coord_task_names:
                            chosen_agent = coord_results[coord_idx]
                            agent_index = int(chosen_agent.split("_")[-1])
                            models_response_key = f"{model_names[chosen_agent]}{agent_index}"
                            if name == 'final_responses':
                                final_responses[f"final_response_{debate+1}"] = {
                                    models_response_key: current_responses[chosen_agent]
                                }
                            elif name == 'final_responses2':
                                final_responses2[f"final_response_{debate+1}"] = {
                                    models_response_key: current_responses[chosen_agent]
                                }
                            elif name == 'final_responses3':
                                final_responses3[f"final_response_{debate+1}"] = {
                                    models_response_key: current_responses[chosen_agent]
                                }
                            coord_idx += 1

                result = {
                    "question_id": idx,
                    "question": question,
                    "agent_response": models_response,
                    "final_responses": final_responses,
                    "final_responses2": final_responses2,
                    "final_responses3": final_responses3,
                    "answer": answer
                }

                # Add data_topic for CollegeMath
                if args.dataset == "collegemath" and data_topic:
                    result["data_topic"] = data_topic

                generated_description.append(result)

        return generated_description

    # Run the async main function and get results
    generated_description = asyncio.run(main())

    output_dir = os.path.join(args.output_dir, folder_name)
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "conversation.json")
    print(f"The result file '{output_file}' is saving...")
    with open(output_file, "w") as f:
        json.dump(generated_description, f, indent=4)

    print("✅ All done!!")
