import re
import string

import argparse
import json
import importlib
import os
import pandas as pd
import multiprocessing as mp
from multiprocessing import Pool, Manager, Queue, Process, JoinableQueue
from openai import OpenAI
from typing import Dict, List, Tuple, Any
import time
from pathlib import Path
import traceback
import sys
import numpy as np
from transformers import AutoTokenizer
import re
from collections import defaultdict
import threading
from queue import Empty
import subprocess
import requests
import signal
import atexit
import random

import os
import sys

# Put this script's own directory first on sys.path (absolute, so independent of the CWD).
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# NOTE(review): "./futuremind" is relative to the current working directory, not
# to this file, so the ``futuremind.*`` imports below presumably assume the
# script is launched from the repository root — confirm.
project_root = "./futuremind"
sys.path.insert(0, project_root)

from futuremind.tool.tools import _default_tool
from futuremind.tool.envs.nous import NousToolEnv
from futuremind import config as default_config
from futuremind.data_file_config import FILE_CONFIGS
from futuremind.src.metric_utils import ( process_validation_metrics,)
from futuremind.eval import compute_format_and_answer_score_using_gpt4omini 


import logging
# Quiet the OpenAI client and its HTTP stack (httpx/httpcore) so per-request
# INFO/DEBUG records do not flood the output; only warnings (errors for the
# base client) are emitted.
logging.getLogger("openai").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("openai._base_client").setLevel(logging.ERROR)


# Keep all the scoring functions unchanged
def normalize_answer(s):
    """Normalize a string for lenient answer comparison.

    Steps, in order: lower-case, drop every character in
    ``string.punctuation``, replace the whole words "a"/"an"/"the" with a
    space, then collapse all whitespace runs into single spaces.
    """
    punctuation = set(string.punctuation)
    lowered = s.lower()
    without_punc = "".join(ch for ch in lowered if ch not in punctuation)
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punc)
    return " ".join(without_articles.split())


def em_check(prediction, golden_answers):
    """Exact match after normalization.

    Args:
        prediction: predicted answer text.
        golden_answers: a single gold string or an iterable of gold strings.

    Returns:
        1.0 if ``normalize_answer(prediction)`` equals the normalized form of
        any gold answer, else 0.0.
    """
    candidates = [golden_answers] if isinstance(golden_answers, str) else golden_answers
    target = normalize_answer(prediction)
    matched = any(normalize_answer(gold) == target for gold in candidates)
    return 1.0 if matched else 0.0


def subem_check(prediction, golden_answers):
    """Substring match after normalization.

    Args:
        prediction: predicted text to search in.
        golden_answers: a single gold string or an iterable of gold strings.

    Returns:
        1.0 if the normalized form of any gold answer occurs as a substring of
        ``normalize_answer(prediction)``, else 0.0. Note that an empty gold
        answer normalizes to "" and therefore matches any prediction.
    """
    candidates = [golden_answers] if isinstance(golden_answers, str) else golden_answers
    haystack = normalize_answer(prediction)
    matched = any(normalize_answer(gold) in haystack for gold in candidates)
    return 1.0 if matched else 0.0


def extract_solution(solution_str):
    """Return the stripped contents of the first ``<answer>...</answer>`` span.

    The span may cross newlines. Returns None when no such span exists.
    """
    found = re.search(r'<answer>(.*?)</answer>', solution_str, re.DOTALL)
    return found.group(1).strip() if found else None

def compute_score_format(solution_str):
    """Score how closely a Llama-3-formatted transcript follows the tool-use format.

    Assistant turns are taken from every
    ``<|start_header_id|>assistant<|end_header_id|> ... <|eot_id|>`` span.

    * Each assistant turn except the last earns 0.5 when, after stripping
      surrounding whitespace, it is exactly one ``<think>...</think>``
      followed (optionally after whitespace) by exactly one
      ``<tool_call>...</tool_call>``.
    * The last assistant turn earns 0.5 when, after stripping, it is
      ``<think>...</think>`` followed by ``<answer>...</answer>``.

    The total is not capped here (several tool-call turns can push it above
    1.0); callers clamp it.

    Args:
        solution_str: the full formatted conversation text.

    Returns:
        float: the accumulated format reward; 0.0 when ``solution_str`` is
        None, contains no assistant turn, or scoring raises.
    """
    if solution_str is None:
        return 0.0

    try:
        assistant_blocks = re.findall(r'<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)<\|eot_id\|>',
                                      solution_str, re.DOTALL)
        if not assistant_blocks:
            return 0.0

        # manual_format_conversation emits "\n\n" right after the header, so the
        # captured turn starts with whitespace; without stripping, the anchored
        # ``^<think>`` patterns below could never match.
        blocks = [block.strip() for block in assistant_blocks]

        format_reward = 0.0

        # Intermediate turns: one think section followed by one tool call.
        for block in blocks[:-1]:
            has_single_tags = (block.count('<think>') == 1 and block.count('</think>') == 1
                               and block.count('<tool_call>') == 1 and block.count('</tool_call>') == 1)
            if has_single_tags and re.search(r'^<think>(.*?)</think>(\s*)<tool_call>(.*?)</tool_call>$',
                                             block, re.DOTALL):
                format_reward += 0.5

        # Final turn: think section followed by the answer.
        if re.search(r'^<think>(.*?)</think>(.*?)<answer>(.*?)</answer>$', blocks[-1], re.DOTALL):
            format_reward += 0.5
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_format: {e}")
        return 0.0

    return format_reward


def compute_score_answer(solution_str, ground_truth):
    """Rule-based answer reward computed from the last assistant turn.

    Only the answer component is returned; format is scored separately by
    ``compute_score_format``.

    Args:
        solution_str: the full formatted conversation text.
        ground_truth: a gold answer string or an iterable of acceptable answers.

    Returns:
        float: 1.0 if a gold answer is a normalized substring of the
        ``<answer>`` text in the last assistant turn; otherwise 0.2 if a gold
        answer appears anywhere in that turn; otherwise 0.0. Also 0.0 when
        ``solution_str`` is None, has no assistant turn, or scoring raises.
    """
    if solution_str is None:
        return 0.0

    try:
        assistant_blocks = re.findall(r'<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)<\|eot_id\|>',
                                      solution_str, re.DOTALL)
        if not assistant_blocks:
            return 0.0
        last_block = assistant_blocks[-1]
        answer = extract_solution(last_block)

        answer_reward = 0.0
        if answer is not None and subem_check(answer, ground_truth):
            answer_reward = 1.0

        # Partial credit: the gold answer is mentioned in the final turn
        # (e.g. inside <think>) but not inside <answer>.
        if answer_reward == 0.0 and subem_check(last_block, ground_truth):
            answer_reward = 0.2
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_answer: {e}")
        return 0.0

    return answer_reward


def compute_score_llm(solution_str, ground_truth, question=None):
    """Score the final answer with the external GPT-4o-mini judge.

    The judged text is the ``<answer>`` content of the last assistant turn.
    If that turn has no ``<answer>`` tags, the whole turn is judged. If there
    are no assistant turns, the ``<answer>`` content of the whole text (or the
    whole text itself) is judged. When ``question`` is given, it is reduced
    with ``extract_question`` before being passed to the judge; otherwise an
    empty question is sent.

    Args:
        solution_str: the full formatted conversation text.
        ground_truth: the gold answer.
        question: the original prompt text (optional).

    Returns:
        float: the judge's score as a float; 0.0 if either of the first two
        inputs is None or anything raises.
    """
    if solution_str is None or ground_truth is None:
        return 0.0

    try:
        blocks = re.findall(r'<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)<\|eot_id\|>',
                            solution_str, re.DOTALL)
        source = blocks[-1] if blocks else solution_str
        answer = extract_solution(source)
        if answer is None:
            answer = source

        question_text = "" if question is None else extract_question(question)
        judged = compute_format_and_answer_score_using_gpt4omini.compute_score_answer(
            question_text, answer, ground_truth
        )
        return float(judged)

    except Exception as e:
        print(f"[DEBUG] Error in compute_score_llm: {e}")
        return 0.0


def compute_score_format_answer(solution_str, ground_truth):
    """Combine the format reward and the rule-based answer reward.

    The result is ``-1.0 + min(format, 1.0)``. The answer reward is added on
    top only when the clamped format reward is at least 0.5. Both component
    scorers are always evaluated.

    Args:
        solution_str: the full formatted conversation text.
        ground_truth: the gold answer(s).

    Returns:
        float: the combined score; 0.0 if either input is None, -1.0 if
        scoring raises.
    """
    if solution_str is None or ground_truth is None:
        return 0.0

    try:
        fmt = min(compute_score_format(solution_str), 1.0)
        ans = compute_score_answer(solution_str, ground_truth)

        total = -1.0 + fmt
        if fmt >= 0.5:
            total += ans
        return total
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_format_answer: {e}")
        return -1.0


def compute_score_format_answer_llm(solution_str, ground_truth, question=None):
    """Combine the format reward with the LLM-judge answer reward.

    Mirrors ``compute_score_format_answer`` but uses ``compute_score_llm``
    for the answer part. Returns ``-1.0 + min(format, 1.0)``, plus the judge
    score when the clamped format reward is at least 0.5.

    Args:
        solution_str: the full formatted conversation text.
        ground_truth: the gold answer.
        question: the original prompt text (optional).

    Returns:
        float: the combined score; 0.0 if either of the first two inputs is
        None, -1.0 if scoring raises.
    """
    if solution_str is None or ground_truth is None:
        return 0.0

    try:
        fmt = min(compute_score_format(solution_str), 1.0)
        judged = compute_score_llm(solution_str, ground_truth, question)

        total = -1.0 + fmt
        if fmt >= 0.5:
            total += judged
        return total
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_format_answer_llm: {e}")
        return -1.0


def compute_score_em(solution_str, ground_truth):
    """Answer accuracy of the last assistant turn.

    Despite the name, the check is ``subem_check`` (normalized substring
    match) on the ``<answer>`` contents of the last assistant turn.

    Args:
        solution_str: the full formatted conversation text.
        ground_truth: the gold answer(s).

    Returns:
        float: 1.0 on match, else 0.0. Also 0.0 when an input is None, there
        is no assistant turn or no ``<answer>`` span, or scoring raises.
    """
    if solution_str is None or ground_truth is None:
        return 0.0

    try:
        blocks = re.findall(r'<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)<\|eot_id\|>',
                            solution_str, re.DOTALL)
        if not blocks:
            return 0.0
        answer = extract_solution(blocks[-1])
        return 0.0 if answer is None else float(subem_check(answer, ground_truth))
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_em: {e}")
        return 0.0


# Manual tool calling system
def build_tools_system_message(tool_names: List[str]) -> str:
    """Build the system prompt that lists the available tools and the call format.

    For each name, ``_default_tool`` is instantiated and its ``name``,
    ``description`` and JSON-dumped ``parameters`` are listed, followed by
    fixed instructions for the ``<tool_call>``/``<think>``/``<answer>`` format.

    No longer appends "./futuremind" to ``sys.path``. That append ran on every
    call (i.e. once per task) and grew ``sys.path`` without bound. The same
    path is already inserted once at module import.

    Args:
        tool_names: names of the tools to describe.

    Returns:
        str: the system message text.
    """
    sections = ["You are a helpful assistant. You have access to the following tools:\n\n"]

    for tool_name in tool_names:
        tool = _default_tool(tool_name)
        sections.append(
            f" {tool.name}\n"
            f"Description: {tool.description}\n"
            f"Parameters: {json.dumps(tool.parameters, indent=2)}\n"
            "\n"
        )

    sections.append("""When you need to use a tool, format your tool call as follows:
<tool_call>{"name": "tool_name", "parameters": {"param1": "value1", "param2": "value2"}}</tool_call>

Important:
1. Use tools when necessary to answer questions accurately
2. Format tool calls exactly as shown above
3. Wait for tool results before continuing
4. Provide your reasoning in the <think></think> tags
5. Provide your final answer in the <answer></answer> tag

""")

    return "".join(sections)


def create_manual_messages(user_prompt: str, tool_names: List[str]) -> List[Dict]:
    """Return the initial two-message chat.

    The first message is the tool-describing system prompt for
    ``tool_names``; the second is ``user_prompt`` as the user turn.
    """
    return [
        {"role": "system", "content": build_tools_system_message(tool_names)},
        {"role": "user", "content": user_prompt},
    ]


def manual_format_conversation(messages: List[Dict]) -> str:
    """Render messages as Llama-3 turns.

    Each turn is ``<|start_header_id|>{role}<|end_header_id|>\\n\\n{content}<|eot_id|>``.
    Only the roles "system", "user" and "assistant" are rendered; other roles
    are skipped (their "role" and "content" keys are still read). No
    trailing assistant header is added.
    """
    rendered_roles = ("system", "user", "assistant")
    pieces = []
    for msg in messages:
        role, content = msg["role"], msg["content"]
        if role in rendered_roles:
            pieces.append(f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>")
    return "".join(pieces)


def extract_tool_calls_manual(text: str) -> List[Dict]:
    """Parse every ``<tool_call>...</tool_call>`` body in ``text`` as JSON.

    A body that fails to parse is retried once with all whitespace runs
    collapsed to single spaces; this repairs e.g. raw newlines inside JSON
    strings. Bodies that still fail are reported on stdout and skipped.

    Returns:
        list: the successfully decoded values, in order of appearance.
    """
    parsed_calls = []
    for raw in re.findall(r'<tool_call>\s*(.*?)\s*</tool_call>', text, re.DOTALL):
        candidate = raw.strip()
        try:
            parsed_calls.append(json.loads(candidate))
            continue
        except json.JSONDecodeError as err:
            print(f"Failed to parse tool call JSON: {raw}")
            print(f"Error: {err}")

        cleaned = re.sub(r'\s+', ' ', candidate)
        try:
            parsed_calls.append(json.loads(cleaned))
        except json.JSONDecodeError:
            print(f"Could not parse tool call after cleaning: {cleaned}")
    return parsed_calls


def execute_tool_calls_manual(tool_calls: List[Dict], tool_names: List[str]) -> List[Dict]:
    """Execute parsed tool calls one by one through a fresh NousToolEnv.

    No longer appends "./futuremind" to ``sys.path`` on every call. That path
    is already inserted once at module import, and the per-call append grew
    ``sys.path`` without bound.

    Args:
        tool_calls: decoded tool-call objects (e.g. from extract_tool_calls_manual).
        tool_names: names passed to ``_default_tool`` to build the tool set.

    Returns:
        list of dicts, one per call, with keys ``result`` (tool output or an
        error message), ``success`` and ``tool_call`` (the input call).
        Per-call exceptions are caught and reported as failed results.
    """
    agent_tools = [_default_tool(tool_name) for tool_name in tool_names]
    env = NousToolEnv(tools=agent_tools, max_tool_response_length=4000)

    results = []

    for tool_call in tool_calls:
        try:
            # env.step consumes raw text, so re-wrap the call in <tool_call> tags.
            tool_call_text = f"<tool_call>\n{json.dumps(tool_call)}\n</tool_call>"

            tool_responses, tool_successes, _ = env.step(
                tool_call_text,
                step_inference=False,
                arguments_key="parameters"
            )

            if tool_responses:
                result_content = tool_responses[0] if isinstance(tool_responses, list) else tool_responses
                success = tool_successes[0] if tool_successes else True
            else:
                result_content = "No response from tool"
                success = False

            results.append({
                "result": result_content,
                "success": success,
                "tool_call": tool_call
            })

        except Exception as e:
            print(f"Tool execution error: {e}")
            results.append({
                "result": f"Tool execution failed: {str(e)}",
                "success": False,
                "tool_call": tool_call
            })

    return results


def format_tool_response_manual(tool_results: List[Dict]) -> str:
    """Wrap each result's ``"result"`` text in ``<tool_response>`` tags.

    Each result is truncated to 4000 characters (plus "...") when longer.
    The blocks are joined with newlines and the whole string is stripped.
    """
    chunks = []
    for entry in tool_results:
        body = entry["result"]
        if len(body) > 4000:
            body = body[:4000] + "..."
        chunks.append(f"<tool_response>\n{body}\n</tool_response>\n")
    return "".join(chunks).strip()

# Prefix used to build every task's ``exp_name`` / result key in create_all_tasks.
# NOTE(review): empty here; presumably reassigned (e.g. from --exp-name)
# elsewhere in the script — confirm.
EXP_NAME = ""

# Global variables for server management
# Popen handle of the vLLM server: set by start_vllm_server, stopped by cleanup_server.
vllm_process = None
# Tokenizer loaded by init_tokenizer; each worker_process loads its own copy.
tokenizer = None


def cleanup_server():
    """Stop the vLLM server started by this process, if it is still running.

    The server is sent ``terminate()`` and given 10 seconds to exit before
    being killed. Does nothing when no server was started or it has already
    exited.
    """
    global vllm_process
    if not vllm_process or vllm_process.poll() is not None:
        return

    print("\nCleaning up vLLM server...")
    try:
        vllm_process.terminate()
        vllm_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        print("Force killing vLLM server...")
        vllm_process.kill()
        vllm_process.wait()
    print("vLLM server stopped.")


def signal_handler(signum, frame):
    """Signal handler: report the signal, stop the vLLM server, exit with status 0."""
    print(f"\nReceived signal {signum}")
    cleanup_server()
    raise SystemExit(0)


# Register cleanup handlers
# Stop the vLLM server on normal interpreter exit and on SIGINT / SIGTERM.
# NOTE(review): this runs at import time, so every process that imports or
# re-imports this module (possibly including worker processes, depending on
# the multiprocessing start method) installs these handlers too — confirm
# that is intended.
atexit.register(cleanup_server)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)


def init_tokenizer(tokenizer_path):
    """Load the tokenizer at ``tokenizer_path`` into the module-level ``tokenizer``.

    ``trust_remote_code=True`` lets custom tokenizer code shipped with the
    model run.
    """
    global tokenizer
    loaded = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
    tokenizer = loaded


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Options cover the vLLM server launch, sampling parameters (defaults taken
    from ``futuremind.config``), worker/output settings and optional LLM-based
    evaluation. The ``--port`` default is a random integer in [10000, 50000],
    drawn each time this function runs.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description='Manual tool calling with vLLM server')
    add = parser.add_argument

    # --- vLLM server ---
    add('--model-path', type=str, required=True,
        help='Path to local model directory')
    add('--model-name', type=str, default='auto-deployed-model',
        help='Model name for API')
    add('--host', type=str, default='127.0.0.1',
        help='Host address for vLLM server')
    add('--port', type=int, default=random.randint(10000, 50000),
        help='Port number for vLLM server')
    add('--api-key', type=str, default='auto-deploy-key',
        help='API key for vLLM server')

    # --- vLLM performance ---
    add('--tensor-parallel-size', type=int, default=8,
        help='Number of GPUs for tensor parallelism')
    add('--gpu-memory-utilization', type=float, default=0.9,
        help='GPU memory utilization ratio')
    add('--max-model-len', type=int, default=32768,
        help='Maximum model length')
    add('--dtype', type=str, default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float32'],
        help='Data type for model weights')
    add('--max-num-seqs', type=int, default=256,
        help='Maximum number of sequences for vLLM')

    # --- experiment ---
    add('--exp-name', type=str, required=True,
        help='Experiment name')

    # --- sampling ---
    add('--temperature', type=float, default=default_config.TEMPERATURE,
        help='Temperature for sampling')
    add('--top-p', type=float, default=default_config.TOP_P,
        help='Top-p for nucleus sampling')
    add('--max-tokens', type=int, default=default_config.MAX_TOKENS,
        help='Maximum number of tokens to generate')
    add('--repetition-penalty', type=float, default=default_config.REPETITION_PENALTY,
        help='Repetition penalty for generation')

    # --- processing ---
    add('--num-processes', type=int, default=64,
        help='Number of parallel processes for inference')
    add('--output-dir', type=str, default="results",
        help='Output directory for results')

    # --- tokenizer / config ---
    add('--tokenizer-path', type=str, default=None,
        help='Path to tokenizer directory (defaults to model-path)')
    add('--config', type=str, default=None,
        help='Path to custom config file')

    # --- evaluation ---
    add('--enable-llm-eval', action='store_true', default=False,
        help='Enable LLM-based evaluation using GPT-4-mini')

    return parser.parse_args()


def check_model_path(model_path):
    """Verify that the model directory exists and warn if ``config.json`` is missing.

    Args:
        model_path: path to the local model directory.

    Returns:
        str: the path, normalized through ``pathlib.Path``.

    Raises:
        FileNotFoundError: if the path does not exist.
    """
    path = Path(model_path)
    if not path.exists():
        raise FileNotFoundError(f"Model path does not exist: {path}")

    expected = ('config.json',)
    missing_files = [name for name in expected if not (path / name).exists()]
    if missing_files:
        print(f"Warning: Missing files in model directory: {missing_files}")

    print(f"Model path verified: {path}")
    return str(path)


def build_vllm_command(args):
    """Build the argv list that launches vLLM's OpenAI-compatible API server.

    The server is started with ``sys.executable`` (falling back to "python"
    only if that is unavailable). This runs it under the same interpreter,
    and therefore the same installed vllm, as this script. A bare "python"
    is resolved via PATH and may pick a different environment.

    Args:
        args: parsed CLI options (model path/name, host, port, API key and
            vLLM performance settings).

    Returns:
        list[str]: the command line, suitable for ``subprocess.Popen``.

    Raises:
        FileNotFoundError: via check_model_path if the model path is missing.
    """
    model_path = check_model_path(args.model_path)
    python_executable = sys.executable or 'python'

    cmd = [
        python_executable, '-m', 'vllm.entrypoints.openai.api_server',
        '--model', model_path,
        '--served-model-name', args.model_name,
        '--host', args.host,
        '--port', str(args.port),
        '--api-key', args.api_key,
        '--tensor-parallel-size', str(args.tensor_parallel_size),
        '--gpu-memory-utilization', str(args.gpu_memory_utilization),
        '--max-model-len', str(args.max_model_len),
        '--dtype', args.dtype,
        '--trust-remote-code',
        '--max-num-seqs', str(args.max_num_seqs),
    ]

    return cmd


def start_vllm_server(args):
    """Launch the vLLM server subprocess and record it in ``vllm_process``.

    Prints a short banner and the full command. The server's stdout and
    stderr are discarded. The process is returned without waiting for it to
    become ready.
    """
    global vllm_process

    banner = [
        "Starting simplified vLLM server deployment...",
        f"Model: {args.model_path}",
        f"Server: {args.host}:{args.port}",
        f"API Key: {args.api_key}",
        "-" * 50,
    ]
    for line in banner:
        print(line)

    cmd = build_vllm_command(args)
    print(f"Running command: {' '.join(cmd)}")
    print("-" * 50)

    vllm_process = subprocess.Popen(
        cmd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        universal_newlines=True,
    )
    return vllm_process


def extract_question(input_str):
    """Return the text between "Question:" and the first following "<<<".

    The captured text must lie on a single line (no DOTALL) and is stripped.
    When no such span exists, ``input_str`` is returned unchanged.
    """
    found = re.search(r'Question:\s*(.*?)<<<', input_str)
    return found.group(1).strip() if found else input_str


def wait_for_server_ready(args, process, host, port, timeout=600):
    """Poll the vLLM server until its health and model-list endpoints respond.

    Args:
        args: parsed CLI options; only ``args.api_key`` is read.
        process: the server's ``Popen`` handle, or None to skip liveness checks.
        host: server host.
        port: server port.
        timeout: seconds to wait before giving up.

    Returns:
        True once ``/health`` returns 200 and ``/v1/models`` can be listed.

    Raises:
        RuntimeError: if ``process`` has exited before the server became ready.
        TimeoutError: if the server is not ready within ``timeout`` seconds.
    """
    print(f"Waiting for vLLM server to start at http://{host}:{port}...")
    print("Note: vLLM output is suppressed to reduce noise")

    start_time = time.time()
    check_interval = 10
    last_check_time = 0

    while time.time() - start_time < timeout:
        elapsed = time.time() - start_time

        if elapsed - last_check_time >= check_interval:
            print(f"⏳ Waiting for server... ({elapsed:.1f}s elapsed)")
            last_check_time = elapsed

        # Fail fast: the server's output goes to DEVNULL, so a crash would
        # otherwise only surface as a timeout after the full wait.
        exit_code = process.poll() if process is not None else None
        if exit_code is not None:
            message = f"vLLM server process exited with code {exit_code} before becoming ready"
            print(f"❌ {message}")
            raise RuntimeError(message)

        try:
            # Health check
            health_response = requests.get(f"http://{host}:{port}/health", timeout=3)
            if health_response.status_code != 200:
                time.sleep(2)
                continue

            print(f"✅ Health check passed at {elapsed:.1f}s")

            test_client = OpenAI(
                api_key=args.api_key,
                base_url=f"http://{host}:{port}/v1"
            )

            # Test models endpoint
            models_response = test_client.models.list()
            available_models = [model.id for model in models_response.data]
            print(f"✅ Models API working. Available: {available_models}")
            return True

        except requests.exceptions.RequestException:
            time.sleep(2)
            continue

    print(f"❌ Timeout waiting for server to be ready after {timeout} seconds")
    raise TimeoutError(f"vLLM server failed to start within {timeout} seconds")


def load_data(file_path: str, prompt_key: str, ground_truth_key: str) -> List[Dict]:
    """Load records from a .parquet or .jsonl file and pull out prompt and ground truth.

    Args:
        file_path: path to a ``.parquet`` or ``.jsonl`` file.
        prompt_key: dot-separated path to the prompt inside each record.
        ground_truth_key: dot-separated path to the gold answer.

    Returns:
        list of dicts with ``index`` (position among the loaded records),
        ``prompt``, ``ground_truth`` and ``original_data``. Records missing
        either key are skipped with a warning.

    Raises:
        ValueError: for any other file extension.
    """
    file_path = Path(file_path)

    if file_path.suffix == '.parquet':
        df = pd.read_parquet(file_path)
        data = df.to_dict('records')
    elif file_path.suffix == '.jsonl':
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Blank lines (e.g. an extra empty line at the end of the file)
                # are not records; json.loads('') would abort the whole load.
                if not line:
                    continue
                data.append(json.loads(line))
    else:
        raise ValueError(f"Unsupported file format: {file_path.suffix}")

    # Extract prompts and ground truths
    processed_data = []
    for i, item in enumerate(data):
        try:
            prompt = get_nested_value(item, prompt_key)
            ground_truth = get_nested_value(item, ground_truth_key)
        except KeyError as e:
            print(f"Warning: Missing key {e} in item {i}, skipping...")
            continue

        processed_data.append({
            "index": i,
            "prompt": prompt,
            "ground_truth": ground_truth,
            "original_data": item
        })

    return processed_data


def get_nested_value(data: Dict, key_path: str) -> Any:
    """Follow a dot-separated key path (e.g. ``"a.b.c"``) through nested mappings.

    Raises:
        KeyError: if any component is missing.
    """
    node = data
    remaining = key_path
    while True:
        head, sep, remaining = remaining.partition('.')
        node = node[head]
        if not sep:
            return node


def worker_process(worker_id: int, model_params: Dict, tokenizer_path: str,
                   task_queue: JoinableQueue, result_queue: Queue):
    """Consume tasks from ``task_queue`` and push results to ``result_queue``.

    Each worker loads its own tokenizer and OpenAI client (pointed at
    ``model_params['api_base']``). A ``None`` task is the stop signal.
    ``task_done()`` is called exactly once for every task taken from the
    queue, including failed ones, so a ``join()`` on the queue can complete.

    Errors raised while processing a single task are printed and the worker
    moves on. Errors from the queue itself (anything other than ``Empty``)
    are no longer swallowed. Previously they caused an endless
    error-print/retry loop.
    """
    # Initialize tokenizer for this worker
    init_tokenizer(tokenizer_path)

    # Initialize OpenAI client for this worker
    client = OpenAI(
        api_key=model_params['api_key'],
        base_url=model_params['api_base'],
    )

    print(f"Worker {worker_id}: Started and ready for manual tool calling")

    while True:
        try:
            task = task_queue.get(timeout=5)
        except Empty:
            continue

        if task is None:  # Poison pill to stop worker
            print(f"Worker {worker_id}: Received stop signal, shutting down")
            task_queue.task_done()
            break

        try:
            result = process_task_manual(task, client, model_params)
            result_queue.put(result)
        except Exception as e:
            print(f"Worker {worker_id}: Error in worker process: {str(e)}")
            traceback.print_exc()
        finally:
            # Exactly one task_done() per fetched task, success or failure.
            task_queue.task_done()

    print(f"Worker {worker_id}: Worker process finished")


def clean_messages(messages):
    """Drop characters that cannot be UTF-8 encoded (e.g. lone surrogates) from string values.

    Only string values of dict messages are rewritten; other values and
    non-dict messages pass through unchanged. Returns a new list.
    """
    def _scrub(value):
        if isinstance(value, str):
            return value.encode('utf-8', errors='ignore').decode('utf-8')
        return value

    return [
        {key: _scrub(value) for key, value in message.items()} if isinstance(message, dict) else message
        for message in messages
    ]


# Llama-3 header that opens an assistant turn. Appended to the prompt so the
# model continues as the assistant instead of having to emit the header itself.
_ASSISTANT_GENERATION_PROMPT = "<|start_header_id|>assistant<|end_header_id|>\n\n"

# Chat-template control sequences removed from tool output before it is fed
# back into the conversation: Qwen-style markers plus the Llama-3 header/turn
# tokens that manual_format_conversation and the scorers parse on.
_TOOL_RESULT_STRIP_TOKENS = (
    "<|im_end|>",
    "<|im_start|>user",
    "<|im_start|>assistant",
    "<|begin_of_text|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|eot_id|>",
)

# Maximum characters of a single tool result kept in the prompt.
_MAX_TOOL_RESULT_CHARS = 4000


def _format_tool_results_for_prompt(tool_results: List[Dict]) -> str:
    """Render tool results as newline-joined, sanitized <tool_response> blocks."""
    parts = []
    for result in tool_results:
        tool_result = str(result["result"])
        if len(tool_result) > _MAX_TOOL_RESULT_CHARS:
            tool_result = tool_result[:_MAX_TOOL_RESULT_CHARS] + "..."
        for token in _TOOL_RESULT_STRIP_TOKENS:
            tool_result = tool_result.replace(token, "")
        parts.append(f"<tool_response>\n{tool_result.strip()}\n</tool_response>")
    return "\n".join(parts)


def process_task_manual(task: Dict, client: OpenAI, model_params: Dict) -> Dict:
    """Run one multi-turn tool-calling rollout and score it.

    The conversation is rendered with manual_format_conversation, followed by
    an open assistant header, and sent to the completions endpoint. Whenever
    the generated text contains a tool call, the call(s) are executed and
    their results are appended as a user turn. The loop ends when the model
    answers without a (parsable) tool call, or after ``max_iterations``
    generations.

    Args:
        task: dict with ``exp_name``, ``run_id``, ``item`` (with ``index``,
            ``prompt`` — a sequence whose first element has ``content`` — and
            ``ground_truth``), ``tool_names`` and optional ``enable_llm_eval``.
        client: OpenAI-compatible client pointing at the vLLM server.
        model_params: reads ``model``, ``temperature``, ``top_p``, ``max_tokens``.

    Returns:
        dict with the formatted transcript (``solution``), ``scores``
        (format / acc / score / llm_acc), ``success``, ``iterations`` and
        either ``messages`` (on success) or ``error``/``traceback``.
    """
    exp_name = task['exp_name']
    run_id = task['run_id']
    item = task['item']
    tool_names = task['tool_names']
    enable_llm_eval = task.get('enable_llm_eval', False)

    try:
        user_prompt = item['prompt'][0]['content']
        messages = create_manual_messages(user_prompt, tool_names)

        max_iterations = 10
        iteration = 0

        while iteration < max_iterations:
            iteration += 1

            # Open an assistant turn. Without it the model has to generate the
            # header itself, and that header text leaks into the stored content.
            conversation_text = manual_format_conversation(messages) + _ASSISTANT_GENERATION_PROMPT

            try:
                response = client.completions.create(
                    model=model_params['model'],
                    prompt=conversation_text,
                    temperature=model_params['temperature'],
                    top_p=model_params['top_p'],
                    max_tokens=model_params['max_tokens'],
                    stop=["</tool_call>", "<|eot_id|>"],  # Stop at tool call end or turn end
                )
            except Exception as e:
                print(f"API call failed: {e}")
                raise

            generated_text = response.choices[0].text

            has_tool_call = "<tool_call>" in generated_text or "</tool_call>" in generated_text
            if not has_tool_call:
                # Regular completion - final answer turn.
                messages.append({"role": "assistant", "content": generated_text})
                break

            # The stop string is normally not returned, so close the call.
            if not generated_text.strip().endswith("</tool_call>"):
                generated_text += "</tool_call>"

            tool_calls = extract_tool_calls_manual(generated_text)
            if not tool_calls:
                # No parsable tool call: treat the turn as the final response.
                messages.append({"role": "assistant", "content": generated_text})
                break

            messages.append({"role": "assistant", "content": generated_text.strip()})

            try:
                tool_results = execute_tool_calls_manual(tool_calls, tool_names)
                tool_response_text = _format_tool_results_for_prompt(tool_results)
            except Exception as e:
                print(f"Tool execution failed: {e}")
                tool_response_text = f"<tool_response>\nError executing tool: {str(e)}\n</tool_response>"

            # Tool results are fed back as a user turn.
            messages.append({"role": "user", "content": tool_response_text})

        # Final transcript, without an open assistant header.
        output = manual_format_conversation(messages)

        format_score = compute_score_format(output)
        em_score = compute_score_em(output, item['ground_truth'])
        format_answer_score = compute_score_format_answer(output, item['ground_truth'])

        llm_acc_score = 0.0
        if enable_llm_eval:
            try:
                # Reuse the prompt extracted above instead of re-testing
                # item['prompt'] for truthiness, which is ambiguous if the
                # prompt is a multi-element array (NOTE(review): e.g. from parquet).
                llm_acc_score = compute_score_llm(output, item['ground_truth'], user_prompt)
            except Exception as e:
                print(f"WARNING: LLM evaluation failed for task {item['index']}: {str(e)}")
                llm_acc_score = 0.0

        def to_float(score):
            if isinstance(score, (int, float, bool)):
                return float(score)
            return float(score[0])

        return {
            "exp_name": exp_name,
            "run_id": run_id,
            "index": item['index'],
            "prompt": user_prompt,
            "ground_truth": item['ground_truth'],
            "solution": output,
            "scores": {
                "format": to_float(format_score),
                "acc": to_float(em_score),
                "score": to_float(format_answer_score),
                "llm_acc": to_float(llm_acc_score)
            },
            "success": True,
            "error": None,
            "iterations": iteration,
            "messages": messages,
        }

    except Exception as e:
        print(f"WARNING: {exp_name} - run_id: {run_id} - data_index: {item['index']} failed, reason: {str(e)}")
        traceback.print_exc()
        return {
            "exp_name": exp_name,
            "run_id": run_id,
            "index": item['index'],
            "prompt": item.get('prompt', [{'content': ''}])[0].get('content', ''),
            "ground_truth": item.get('ground_truth', ''),
            "solution": "",
            "scores": {"format": 0.0, "acc": 0.0, "score": 0.0, "llm_acc": 0.0},
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc(),
            "iterations": 0
        }

def create_all_tasks(file_configs: Dict, enable_llm_eval: bool = False) -> Tuple[List[Dict], Dict[str, int]]:
    """Expand every configured dataset into one task per (run, sample).

    Datasets whose ``path`` does not exist are skipped with a warning.

    Args:
        file_configs: mapping ``dataset_name -> config``. Each config is read for
            ``path``, ``prompt_key``, ``ground_truth_key``, ``runs`` and ``tools``.
        enable_llm_eval: copied into every task so workers know whether to run
            the LLM-based scoring.

    Returns:
        ``(tasks, counts)``, where:
        - ``tasks`` is a flat list of task dicts with the keys ``exp_name``,
          ``run_id``, ``item``, ``tool_names``, ``task_id`` and ``enable_llm_eval``.
        - ``counts`` maps ``"{EXP_NAME}_{dataset_name}_{run_id}"`` to the number
          of samples loaded for that dataset.
    """
    tasks: List[Dict] = []
    counts_per_run: Dict[str, int] = {}

    for dataset_name, cfg in file_configs.items():
        exp_key = f"{EXP_NAME}_{dataset_name}"
        data_path = cfg["path"]
        if not os.path.exists(data_path):
            print(f"Warning: File {data_path} not found, skipping...")
            continue

        print(f"Loading data from {data_path}...")
        samples = load_data(data_path, cfg["prompt_key"], cfg["ground_truth_key"])

        # Every run re-evaluates the full dataset, so each run gets its own copy
        # of the task list, distinguished by run_id.
        for run_id in range(cfg["runs"]):
            counts_per_run[f"{exp_key}_{run_id}"] = len(samples)
            tasks.extend(
                {
                    "exp_name": exp_key,
                    "run_id": run_id,
                    "item": sample,
                    "tool_names": cfg["tools"],
                    "task_id": f"{exp_key}_{run_id}_{sample['index']}",
                    "enable_llm_eval": enable_llm_eval,
                }
                for sample in samples
            )

    return tasks, counts_per_run


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that turns numpy values into built-in Python types.

    Conversions:
    - arrays become (nested) lists;
    - integer scalars become ``int``;
    - floating scalars become ``float``;
    - ``np.bool_`` becomes ``bool``.

    Any other unsupported object falls through to ``json.JSONEncoder.default``,
    which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()

        # (numpy types, built-in converter) pairs, checked in order.
        scalar_converters = (
            ((np.integer, np.int64), int),
            ((np.floating, np.float64), float),
            (np.bool_, bool),
        )
        for numpy_types, to_builtin in scalar_converters:
            if isinstance(obj, numpy_types):
                return to_builtin(obj)

        return super().default(obj)


def _jsonl_path(output_path: Path) -> Path:
    """Return *output_path* with a ``.jsonl`` extension.

    An existing ``.json``/``.jsonl`` extension is replaced. Anything else after
    a dot is treated as part of the file name and kept. For example,
    ``qwen2.5_ds_run0_results`` becomes ``qwen2.5_ds_run0_results.jsonl``, not
    ``qwen2.jsonl``, so dotted experiment names don't collide.
    """
    if output_path.suffix in (".json", ".jsonl"):
        return output_path.with_suffix(".jsonl")
    return output_path.with_name(output_path.name + ".jsonl")


def save_results(results: List[Dict], output_path: str):
    """Write *results* as JSON Lines, one object per line.

    Parent directories are created if missing. Numpy values are serialised via
    ``NumpyEncoder``, and non-ASCII text is written as-is (UTF-8).

    Args:
        results: list of JSON-serialisable result dicts.
        output_path: target path. A ``.jsonl`` extension is added, or replaces
            a ``.json``/``.jsonl`` one.

    Returns:
        The ``Path`` of the file actually written (with the ``.jsonl`` extension).
    """
    jsonl_path = _jsonl_path(Path(output_path))
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)

    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False, cls=NumpyEncoder) + '\n')

    return jsonl_path


def compute_validation_metrics_original_style(results: List[Dict], exp_name: str, run_id: int):
    """Compute metrics for the results of one (exp_name, run_id) pair.

    Args:
        results: result dicts. Each must have ``exp_name``, ``run_id`` and
            ``success``; successful ones also need ``prompt``, ``scores``
            (``score``/``acc``/``format``/``llm_acc``) and ``iterations``.
        exp_name: experiment key to select.
        run_id: run to select.

    Returns:
        ``{}`` if no result matches (exp_name, run_id). Otherwise a dict with:
        - ``basic_stats``: ``total`` (all matching results, failed ones
          included), ``successful`` and ``success_rate``.
        - ``validation_metrics``: output of ``process_validation_metrics`` over
          the successful results, or ``{}`` when none succeeded.
        - ``reward_extra_infos``: per-sample ``reward``/``acc``/``format``/
          ``llm_acc``/``turns`` lists for the successful results.
    """
    run_results = [r for r in results if r["exp_name"] == exp_name and r["run_id"] == run_id]
    if not run_results:
        return {}

    # Metrics are computed over successful samples only; failed samples carry
    # no meaningful scores.
    successful_results = [r for r in run_results if r["success"]]

    reward_extra_infos_dict = defaultdict(list)
    sample_inputs = []
    for result in successful_results:
        sample_inputs.append(result["prompt"])

        scores = result["scores"]
        reward_extra_infos_dict["reward"].append(scores["score"])
        reward_extra_infos_dict["acc"].append(scores["acc"])
        reward_extra_infos_dict["format"].append(scores["format"])
        reward_extra_infos_dict["llm_acc"].append(scores["llm_acc"])
        reward_extra_infos_dict["turns"].append(result["iterations"])

    validation_metrics = {}
    if successful_results:
        # All samples are reported under a single data source named "eval".
        validation_metrics = process_validation_metrics(
            data_sources=["eval"] * len(successful_results),
            sample_inputs=sample_inputs,
            infos_dict=reward_extra_infos_dict,
        )

    # Basic stats use every matching result, so failures lower the success rate.
    total = len(run_results)
    successful = len(successful_results)
    basic_stats = {
        "total": total,
        "successful": successful,
        "success_rate": successful / total if total > 0 else 0.0,
    }

    return {
        "basic_stats": basic_stats,
        "validation_metrics": validation_metrics,
        "reward_extra_infos": dict(reward_extra_infos_dict),
    }


def save_overall_results_csv(overall_results: List[Dict], output_dir: Path):
    """Write the per-run summary rows to ``<output_dir>/<EXP_NAME>_result.csv``.

    Each dict in ``overall_results`` becomes one CSV row, without the pandas
    index column. The file is overwritten on every call. Nothing is written
    when ``overall_results`` is empty.
    """
    if overall_results:
        summary = pd.DataFrame(overall_results)
        target = output_dir / "{}_result.csv".format(EXP_NAME)
        summary.to_csv(target, index=False)
        print(f"Overall results saved to: {target}")


def _finalize_file_run(run_results: Dict[Any, Dict], exp_name: str, run_id: int,
                       raw_output_dir: Path, output_dir: Path,
                       overall_results: List[Dict]) -> None:
    """Save, score and report one completed (exp_name, run_id) group of results.

    Appends a summary row to ``overall_results`` (mutated in place) and
    rewrites the overall CSV.
    """
    # Sort results by index to maintain original order
    sorted_results = [run_results[i] for i in sorted(run_results)]

    # Use Path(exp_name).name, not .stem. exp_name is "<EXP_NAME>_<dataset>",
    # and .stem would drop everything after its last "." ("qwen2.5_x" ->
    # "qwen2"), so different datasets would overwrite the same file. The
    # explicit ".jsonl" suffix likewise keeps Path.with_suffix from truncating
    # at such a dot.
    output_name = f"{Path(exp_name).name}_run{run_id}_results.jsonl"
    saved_path = save_results(sorted_results, raw_output_dir / output_name)

    # Basic stats cover every result of the run, failed ones included, so the
    # success rate reflects real failures.
    total = len(sorted_results)
    successful = sum(1 for r in sorted_results if r["success"])
    success_rate = successful / total if total > 0 else 0.0

    # The metrics helper can return {} (e.g. when no sample succeeded), hence .get.
    metrics_result = compute_validation_metrics_original_style(sorted_results, exp_name, run_id)
    validation_metrics = metrics_result.get("validation_metrics", {})

    print(f"\n{'=' * 60}")
    print(f"Completed {exp_name} - Run {run_id}")
    print(f"Results saved to: {saved_path}")
    print(f"Total samples: {total}")
    print(f"Successful samples: {successful}")
    print(f"Success rate: {success_rate:.1%}")

    overall_result_row = {
        "exp_name": exp_name,
        "run_id": run_id,
        "total_samples": total,
        "successful_samples": successful,
        "success_rate": success_rate,
    }

    print("\nDetailed Metrics:")
    for var_name, metrics in validation_metrics.get("eval", {}).items():
        for metric_name, value in metrics.items():
            print(f"  {var_name}.{metric_name}: {value:.4f}")
            overall_result_row[f"{var_name}.{metric_name}"] = value

    overall_results.append(overall_result_row)
    # Rewrite the overall CSV after every completed run so partial progress is kept.
    save_overall_results_csv(overall_results, output_dir)

    print(f"{'=' * 60}")


def result_collector_thread(result_queue: Queue, file_configs: Dict,
                            output_dir: Path, total_tasks: int,
                            file_task_counts: Dict[str, int]):
    """Consume worker results and persist each (experiment, run) once it is complete.

    Results are grouped by ``"{exp_name}_{run_id}"`` and keyed by ``index``
    within a group. When a group holds as many results as ``file_task_counts``
    expects, it is finalised:
    - written to ``<output_dir>/<EXP_NAME>_raw_output/``;
    - its metrics are printed;
    - a summary row is added to the overall CSV.

    The loop only ends after ``total_tasks`` results have been received. If
    results go missing, it keeps polling the queue indefinitely.

    Args:
        result_queue: queue of result dicts produced by the workers.
        file_configs: not used by this function; kept for interface compatibility.
        output_dir: root output directory.
        total_tasks: number of results to receive before returning.
        file_task_counts: expected result count per ``"{exp_name}_{run_id}"`` key.
    """
    completed_tasks = 0
    all_results: Dict[str, Dict[Any, Dict]] = {}
    overall_results: List[Dict] = []

    # Create raw output subfolder
    raw_output_dir = output_dir / f"{EXP_NAME}_raw_output"
    raw_output_dir.mkdir(parents=True, exist_ok=True)

    while completed_tasks < total_tasks:
        try:
            result = result_queue.get(timeout=10)

            file_run_key = f"{result['exp_name']}_{result['run_id']}"
            run_results = all_results.setdefault(file_run_key, {})
            if result['index'] in run_results:
                # Duplicate indices collapse into one entry. If the source data
                # repeats an index, this run can never reach its expected count
                # and will not be saved.
                print(f"WARNING: duplicate index {result['index']} for {file_run_key}")
            run_results[result['index']] = result
            completed_tasks += 1

            print(f"Progress: {completed_tasks}/{total_tasks} tasks completed")

            expected_count = file_task_counts.get(file_run_key, 0)
            current_count = len(run_results)
            print(f"{file_run_key} - Complete: {current_count} / {expected_count}")

            if expected_count > 0 and current_count == expected_count:
                # Drop the group before finalising so its memory is released
                # even if saving or scoring raises.
                del all_results[file_run_key]
                _finalize_file_run(run_results, result['exp_name'], result['run_id'],
                                   raw_output_dir, output_dir, overall_results)

        except Empty:
            continue
        except Exception as e:
            print(f"Error in result collector: {str(e)}")
            traceback.print_exc()

    print("Result collector thread finished")


def _load_config(config_path):
    """Return the config module to use: ``default_config`` or the file at *config_path*.

    The custom module is returned only if it executes without error. On any
    failure the error is printed and ``default_config`` is returned, so a
    half-initialised module never replaces the defaults.
    """
    if not config_path:
        return default_config
    try:
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is loaded, so import it explicitly.
        import importlib.util

        spec = importlib.util.spec_from_file_location("custom_config", config_path)
        custom_config = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(custom_config)
    except Exception as e:
        print(f"Error loading custom config: {e}")
        print("Falling back to default config")
        return default_config
    print(f"Loaded custom config from {config_path}")
    return custom_config


def main():
    """Run the full evaluation pipeline.

    Steps:
    1. Start the vLLM server and wait until it is ready.
    2. Build one task per (dataset, run, sample) from ``FILE_CONFIGS``.
    3. Fan the tasks out to ``args.num_processes`` worker processes.
    4. Collect and save results in a background thread.

    ``cleanup_server`` is always called on exit.
    """
    args = parse_args()

    # Module-level experiment name, read by the task/result helpers when
    # naming experiments and output files.
    global EXP_NAME
    EXP_NAME = args.exp_name

    # Set tokenizer path to model path if not specified
    if args.tokenizer_path is None:
        args.tokenizer_path = args.model_path

    api_base = f"http://{args.host}:{args.port}/v1"

    # NOTE(review): the loaded config is not used anywhere in this function;
    # confirm whether it is meant to be passed on to the workers.
    config = _load_config(args.config)

    # Print evaluation status
    if args.enable_llm_eval:
        print("LLM-based evaluation using GPT-4-mini is ENABLED")
        print("Warning: This will increase processing time significantly")
    else:
        print("LLM-based evaluation is DISABLED (use --enable-llm-eval to enable)")

    try:
        # Step 1: Start vLLM server
        print("=" * 80)
        print("STEP 1: STARTING SIMPLIFIED VLLM SERVER FOR MANUAL TOOL CALLING")
        print("=" * 80)

        vllm_process = start_vllm_server(args)

        # Step 2: Wait for server to be ready
        print("\n" + "=" * 80)
        print("STEP 2: WAITING FOR SERVER TO BE READY")
        print("=" * 80)

        wait_for_server_ready(args, vllm_process, args.host, args.port)

        # Step 3: Prepare for inference
        print("\n" + "=" * 80)
        print("STEP 3: PREPARING MANUAL TOOL CALLING INFERENCE")
        print("=" * 80)

        # Model parameters
        model_params = {
            'api_key': args.api_key,
            'api_base': api_base,
            'model': args.model_name,
            'temperature': args.temperature,
            'top_p': args.top_p,
            'max_tokens': args.max_tokens,
            'repetition_penalty': args.repetition_penalty
        }

        # Setup output directory
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Create all tasks and get task counts
        print("Creating all tasks...")
        all_tasks, file_task_counts = create_all_tasks(FILE_CONFIGS, args.enable_llm_eval)
        total_tasks = len(all_tasks)
        print(f"Created {total_tasks} tasks total")
        print(f"File-run task counts: {file_task_counts}")

        if total_tasks == 0:
            print("No tasks to process!")
            return

        # Step 4: Run inference
        print("\n" + "=" * 80)
        print("STEP 4: RUNNING MANUAL TOOL CALLING INFERENCE")
        print("=" * 80)

        # Create queues
        task_queue = JoinableQueue()
        result_queue = Queue()

        # Add all tasks to queue
        for task in all_tasks:
            task_queue.put(task)

        # daemon=True: after an interruption the collector may still be
        # waiting for results that will never arrive. A non-daemon thread
        # would then keep the interpreter alive after main() returns.
        collector_thread = threading.Thread(
            target=result_collector_thread,
            args=(result_queue, FILE_CONFIGS, output_dir, total_tasks, file_task_counts),
            daemon=True,
        )
        collector_thread.start()

        # Start worker processes
        processes = []
        for worker_id in range(args.num_processes):
            p = Process(
                target=worker_process,
                args=(
                    worker_id, model_params, args.tokenizer_path,
                    task_queue, result_queue
                )
            )
            p.start()
            processes.append(p)

        print(f"Started {len(processes)} worker processes")

        start_time = time.time()

        try:
            # Wait for all tasks to be completed
            print("Waiting for all tasks to complete...")
            task_queue.join()
            print("All tasks completed!")

            # One None sentinel per worker tells each worker to stop.
            for _ in processes:
                task_queue.put(None)

            # Wait for all workers to finish
            for p in processes:
                p.join()

            # Wait for result collector to finish
            collector_thread.join()

            total_time = time.time() - start_time

            print(f"\n{'=' * 80}")
            print("MANUAL TOOL CALLING INFERENCE COMPLETED SUCCESSFULLY")
            print(f"{'=' * 80}")
            print(f"Total processing time: {total_time / 60:.1f} minutes")
            print(f"Total tasks processed: {total_tasks}")
            print(f"Average time per task: {total_time / total_tasks:.2f} seconds")
            print(f"Worker processes used: {len(processes)}")
            if args.enable_llm_eval:
                print("LLM-based evaluation was enabled and computed")

            print(f"Check the output directory for detailed results: {output_dir}")
            print(f"{'=' * 80}")

        except KeyboardInterrupt:
            print("\nInterrupted by user")
            for p in processes:
                p.terminate()
                p.join()
            collector_thread.join(timeout=5)

    except Exception as e:
        print(f"Error during execution: {str(e)}")
        traceback.print_exc()

    finally:
        # Step 5: Cleanup
        print("\n" + "=" * 80)
        print("STEP 5: CLEANING UP")
        print("=" * 80)
        cleanup_server()
        print("Cleanup completed.")


# Run the pipeline only when executed as a script, not when the module is
# imported (for example by multiprocessing child processes).
if __name__ == "__main__":
    main()
