import os
import sys
import logging

# Configure root logging once, before any other module in this file logs.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Prepend the project directory to sys.path. This must stay above the
# `futuremind` imports further down in this file.
# NOTE: only sys.path of this process is changed; the PYTHONPATH environment
# variable mentioned in the log message is not modified.
project_root = "./futuremind"
sys.path.insert(0, project_root)
logger.info(f"Added to PYTHONPATH: {project_root}")

import re
import time
import math
import json
import random
import logging
import numpy as np
import pandas as pd

import signal
import atexit
import requests
import argparse
import traceback
import threading
import importlib
import subprocess
import multiprocessing as mp
from multiprocessing import Pool, Manager, Queue, Process, JoinableQueue

from queue import Empty
from pathlib import Path
from openai import OpenAI
from collections import defaultdict
from transformers import AutoTokenizer
from typing import Dict, List, Tuple, Any

from futuremind.tool.tools import _default_tool
from futuremind.tool.envs.nous import NousToolEnv
from futuremind import config as default_config
from futuremind.data_file_config import FILE_CONFIGS
from futuremind.src.metric_utils import ( process_validation_metrics,)
from futuremind.eval import compute_format_and_answer_score_using_gpt4omini 


# Lower the verbosity of the HTTP/client libraries so that experiment
# progress output is not drowned in request logs.
for _quiet_name, _quiet_level in (
    ("openai", logging.WARNING),
    ("httpx", logging.WARNING),
    ("httpcore", logging.WARNING),
    ("openai._base_client", logging.ERROR),
):
    logging.getLogger(_quiet_name).setLevel(_quiet_level)

# Module-wide state shared by the helpers in this file
EXP_NAME = ""        # experiment name used as a prefix for task keys and output files
tokenizer = None     # set by Muti_Process_Infer.init_tokenizer
vllm_process = None  # set by First_Vllm_Server.start_vllm_server, read by cleanup_server

def cleanup_server():
    """Stop the vLLM server subprocess if it is still running.

    The process is first asked to terminate; if it has not exited after
    10 seconds it is killed. Does nothing when no server process is recorded
    in ``vllm_process`` or when that process has already exited.
    """
    global vllm_process
    server = vllm_process

    # Guard clause: no process recorded, or it already has an exit code.
    if not server or server.poll() is not None:
        return

    print("\nCleaning up vLLM server...")
    try:
        server.terminate()
        server.wait(timeout=10)
    except subprocess.TimeoutExpired:
        # Graceful shutdown took too long; escalate.
        print("Force killing vLLM server...")
        server.kill()
        server.wait()
    print("vLLM server stopped.")


def signal_handler(signum, frame):
    """Stop the vLLM server and end the program when a signal arrives.

    Args:
        signum: Number of the received signal (echoed to stdout).
        frame: Stack frame passed by the signal machinery; not used.
    """
    print(f"\nReceived signal {signum}")
    cleanup_server()
    # Same effect as sys.exit(0): unwind with a zero exit status.
    raise SystemExit(0)

# Tear the vLLM server down on normal interpreter exit, on Ctrl-C (SIGINT)
# and on SIGTERM.
atexit.register(cleanup_server)
for _handled_signal in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_handled_signal, signal_handler)


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that turns numpy arrays and scalars into built-in types."""

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Arrays become (nested) lists; numpy integer, floating and boolean
        scalars become ``int``, ``float`` and ``bool``. Every other object is
        handed to the base class, which raises ``TypeError``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()

        # (numpy abstract scalar type, Python constructor), tried in order.
        scalar_conversions = (
            (np.integer, int),
            (np.floating, float),
            (np.bool_, bool),
        )
        for numpy_type, convert in scalar_conversions:
            if isinstance(obj, numpy_type):
                return convert(obj)

        return super().default(obj)


class First_Vllm_Server:
    """Build the command line for a vLLM OpenAI-compatible server and launch it."""

    def __init__(self, args):
        """Store the parsed arguments used to configure the server.

        Args:
            args: Namespace exposing ``model_path``, ``model_name``, ``host``,
                ``port``, ``api_key``, ``tensor_parallel_size``,
                ``gpu_memory_utilization``, ``max_model_len``, ``dtype``,
                ``max_num_seqs`` and ``exp_name``.
        """
        self.args = args

    @staticmethod
    def _redact_api_key(cmd):
        """Return a copy of ``cmd`` with the value after ``--api-key`` masked."""
        printable = list(cmd)
        for position in range(len(printable) - 1):
            if printable[position] == '--api-key':
                printable[position + 1] = '***'
        return printable

    def check_model_path(self, model_path):
        """Check that the model path exists and warn about missing files.

        Args:
            model_path: Path to the local model directory.

        Returns:
            The model path as a string.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"Model path does not exist: {model_path}")

        # Check for common model files; a missing file is only a warning.
        required_files = ['config.json']
        missing_files = [name for name in required_files if not (model_path / name).exists()]

        if missing_files:
            print(f"Warning: Missing files in model directory: {missing_files}")

        print(f"Model path verified: {model_path}")
        return str(model_path)

    def build_vllm_command(self):
        """Build the vLLM server command as an argument list.

        Returns:
            List of strings suitable for ``subprocess.Popen``.

        Raises:
            FileNotFoundError: If the configured model path does not exist.
        """
        args = self.args
        model_path = self.check_model_path(args.model_path)

        cmd = [
            # Use the interpreter running this script so the server starts in
            # the same environment; a bare 'python' depends on PATH.
            sys.executable or 'python', '-m', 'vllm.entrypoints.openai.api_server',
            '--model', model_path,
            '--served-model-name', args.model_name,
            '--host', args.host,
            '--port', str(args.port),
            '--api-key', args.api_key,
            '--tensor-parallel-size', str(args.tensor_parallel_size),
            '--gpu-memory-utilization', str(args.gpu_memory_utilization),
            '--max-model-len', str(args.max_model_len),
            '--dtype', args.dtype,
            '--enable-auto-tool-choice',
            '--trust-remote-code',
            '--tool-call-parser', 'llama3_json',
            '--max-num-seqs', str(args.max_num_seqs),
        ]
        return cmd

    def start_vllm_server(self):
        """Start the vLLM server as a subprocess.

        The server's stdout and stderr are redirected to
        ``logs/vllm_experiment_logs/vllm_experiment_<exp_name>.log``. The
        process handle is also stored in the module-level ``vllm_process``.

        Returns:
            The ``subprocess.Popen`` handle of the server.
        """
        global vllm_process
        args = self.args

        print("Starting vLLM server deployment...")
        print(f"Model: {args.model_path}")
        print(f"Server: {args.host}:{args.port}")
        # The key itself is deliberately not echoed to the console.
        print("API Key: ***")
        print("-" * 50)

        cmd = self.build_vllm_command()
        print(f"Running command: {' '.join(self._redact_api_key(cmd))}")
        print("-" * 50)

        # Start the server process with its output sent to a log file.
        log_dir = "logs/vllm_experiment_logs"
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, f"vllm_experiment_{args.exp_name}.log")

        # The child keeps its own copy of the file descriptor, so closing the
        # handle here right after the launch does not stop the logging.
        with open(log_file, "w", encoding="utf-8") as f:
            vllm_process = subprocess.Popen(
                cmd,
                stdout=f,
                stderr=subprocess.STDOUT,
                universal_newlines=True
            )
        print(f"vLLM server started. 日志文件: {log_file}")

        return vllm_process


class Wait_Vllm_Server:
    """Poll a freshly started vLLM server until its HTTP API responds."""

    def __init__(self, args, timeout=600):
        """Store connection settings.

        Args:
            args: Namespace exposing ``host``, ``port`` and ``api_key``.
            timeout: Maximum number of seconds ``wait_for_server_ready`` waits.
        """
        self.args = args
        self.timeout = timeout

    def health_check(self, elapsed):
        """Probe the server once.

        Requests ``/health``; when that answers 200 the models endpoint is
        listed through the OpenAI client as a second check.

        Args:
            elapsed: Seconds waited so far, used only in the progress message.

        Returns:
            True when both checks succeeded, False when ``/health`` answered
            with a status other than 200.

        Raises:
            requests.exceptions.RequestException: If the server cannot be
                reached at all.
        """
        args = self.args
        server_url = f"http://{args.host}:{args.port}"

        # Step 1: Health check
        health_response = requests.get(f"{server_url}/health", timeout=3)
        if health_response.status_code != 200:
            return False
        print(f"✅ Health check passed at {elapsed:.1f}s")

        # Step 2: Test models endpoint
        test_client = OpenAI(
            api_key=args.api_key,
            base_url=f"{server_url}/v1"
        )
        models_response = test_client.models.list()
        available_models = [model.id for model in models_response.data]
        print(f"✅ Models API working. Available: {available_models}")
        return True

    def wait_for_server_ready(self, process):
        """Wait for the server to be ready using repeated health checks.

        Args:
            process: Handle of the server process; only ``poll()`` is used.

        Returns:
            True once the server passed ``health_check``.

        Raises:
            RuntimeError: If the server process exits while waiting.
            TimeoutError: If the server is not ready within ``self.timeout``
                seconds.
        """
        args = self.args
        print(f"Waiting for vLLM server to start at http://{args.host}:{args.port}...")
        print("Note: vLLM output is suppressed to reduce noise")

        timeout = self.timeout
        start_time = time.monotonic()  # unaffected by wall-clock adjustments
        check_interval = 10  # Print progress every 10 seconds
        retry_delay = 2  # Pause between two probes
        last_check_time = 0

        while time.monotonic() - start_time < timeout:
            elapsed = time.monotonic() - start_time

            # Check if process is still running
            if process.poll() is not None:
                print("❌ vLLM server process has terminated!")
                raise RuntimeError("vLLM server process terminated during startup")

            # Print progress every check_interval seconds
            if elapsed - last_check_time >= check_interval:
                print(f"⏳ Waiting for server... ({elapsed:.1f}s elapsed)")
                last_check_time = elapsed

            try:
                if self.health_check(elapsed):
                    return True
            except requests.exceptions.RequestException:
                # Server not reachable yet, keep waiting
                pass
            time.sleep(retry_delay)

        # If we get here, timeout was reached
        print(f"❌ Timeout waiting for server to be ready after {timeout} seconds")
        if process.poll() is None:
            print("Process is still running but not responding to API calls")
        else:
            print("Process has terminated")
        raise TimeoutError(f"vLLM server failed to start within {timeout} seconds")


class Submit_Task:
    """Load evaluation data files and expand them into inference tasks."""

    def __init__(self):
        """The class keeps no instance state."""

    def get_nested_value(self, data: Dict, key_path: str) -> Any:
        """Get a value from a nested dictionary using dot notation.

        Args:
            data: Mapping to read from.
            key_path: Keys joined by ``.``, e.g. ``"reward_model.ground_truth"``.

        Returns:
            The value found after following every key in ``key_path``.

        Raises:
            KeyError: If one of the keys is missing.
            TypeError: If an intermediate value cannot be indexed by the key.
        """
        value = data
        for key in key_path.split('.'):
            value = value[key]
        return value

    def load_data(self, file_path: str, prompt_key: str, ground_truth_key: str) -> List[Dict]:
        """Load records from a parquet or jsonl file.

        Args:
            file_path: Path of a ``.parquet`` or ``.jsonl`` file.
            prompt_key: Dotted key path of the prompt inside a record.
            ground_truth_key: Dotted key path of the ground truth.

        Returns:
            One dict per usable record with the keys ``index`` (position of
            the record in the file, blank jsonl lines not counted),
            ``prompt``, ``ground_truth`` and ``original_data``. Records whose
            prompt or ground truth cannot be read are skipped with a warning.

        Raises:
            ValueError: If the file extension is neither ``.parquet`` nor
                ``.jsonl``.
        """
        file_path = Path(file_path)

        if file_path.suffix == '.parquet':
            df = pd.read_parquet(file_path)
            data = df.to_dict('records')
        elif file_path.suffix == '.jsonl':
            with open(file_path, 'r', encoding='utf-8') as f:
                # Blank lines (e.g. a trailing newline) are not records and
                # would make json.loads fail.
                data = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError(f"Unsupported file format: {file_path.suffix}")

        # Extract prompts and ground truths
        processed_data = []
        for i, item in enumerate(data):
            try:
                # Handle nested keys like "reward_model.ground_truth"
                prompt = self.get_nested_value(item, prompt_key)
                ground_truth = self.get_nested_value(item, ground_truth_key)
            except KeyError as e:
                print(f"Warning: Missing key {e} in item {i}, skipping...")
                continue
            except TypeError as e:
                # An intermediate value (e.g. null) is not a mapping.
                print(f"Warning: Cannot follow key path in item {i} ({e}), skipping...")
                continue

            processed_data.append({
                "index": i,
                "prompt": prompt,
                "ground_truth": ground_truth,
                "original_data": item
            })

        return processed_data

    def create_all_tasks(self, file_configs: Dict) -> Tuple[List[Dict], Dict[str, int]]:
        """Create all tasks from all files and runs, and return task counts.

        Args:
            file_configs: Mapping of dataset name to a config dict with the
                keys ``path``, ``prompt_key``, ``ground_truth_key``, ``runs``
                and ``tools``.

        Returns:
            A tuple ``(all_tasks, file_task_counts)``. ``file_task_counts``
            maps ``"<EXP_NAME>_<dataset>_<run_id>"`` to the number of tasks
            expected for that run. Datasets whose file does not exist are
            skipped with a warning.
        """
        all_tasks = []
        file_task_counts = {}  # Track expected task counts

        for dataset_name, file_config in file_configs.items():
            file_key = f"{EXP_NAME}_{dataset_name}"
            if not os.path.exists(file_config["path"]):
                print(f"Warning: File {file_config['path']} not found, skipping...")
                continue

            # Load data for this file
            print(f"Loading data from {file_config['path']}...")
            data = self.load_data(
                file_config["path"],
                file_config["prompt_key"],
                file_config["ground_truth_key"]
            )

            # Create tasks for all runs of this file
            for run_id in range(file_config["runs"]):
                file_run_key = f"{file_key}_{run_id}"
                file_task_counts[file_run_key] = len(data)  # Store expected count

                all_tasks.extend(
                    {
                        "exp_name": file_key,
                        "run_id": run_id,
                        "item": item,
                        "tool_names": file_config["tools"],
                        "task_id": f"{file_key}_{run_id}_{item['index']}"
                    }
                    for item in data
                )

        return all_tasks, file_task_counts


class Construct_Collector_Thread:
    """Create the work queues and collect, score and persist inference results."""

    def __init__(self):
        """Start without queues; ``create_queues`` fills these attributes."""
        self.task_queue = None
        self.result_queue = None

    def create_queues(self, all_tasks):
        """Create the task and result queues and enqueue every task.

        Args:
            all_tasks: Iterable of task objects to put on the task queue.

        Returns:
            The (empty) result queue. The filled task queue is kept on
            ``self.task_queue`` and the result queue on ``self.result_queue``
            so that the task queue stays reachable after this call.
        """
        task_queue = JoinableQueue()
        result_queue = Queue()

        # Add all tasks to queue
        for task in all_tasks:
            task_queue.put(task)

        self.task_queue = task_queue
        self.result_queue = result_queue
        return result_queue

    @staticmethod
    def _jsonl_path(output_path):
        """Return ``output_path`` with a ``.jsonl`` extension appended.

        The extension is appended to the full name instead of replacing the
        suffix, because names containing dots (e.g. ``Llama3.1-8B_run0``)
        would otherwise be cut at their last dot.
        """
        output_path = Path(output_path)
        if output_path.suffix == '.jsonl':
            return output_path
        return output_path.with_name(output_path.name + '.jsonl')

    def save_results(self, results: List[Dict], output_path: str):
        """Save results as JSONL (one JSON document per line).

        Args:
            results: Result dicts to write; numpy values are converted by
                ``NumpyEncoder``.
            output_path: Target path; ``.jsonl`` is appended when missing.
                Parent directories are created as needed.

        Returns:
            Path of the file that was written.
        """
        jsonl_path = self._jsonl_path(output_path)
        jsonl_path.parent.mkdir(parents=True, exist_ok=True)

        with open(jsonl_path, 'w', encoding='utf-8') as f:
            for result in results:
                json_line = json.dumps(result, ensure_ascii=False, cls=NumpyEncoder)
                f.write(json_line + '\n')

        return jsonl_path

    def save_overall_results_csv(self, overall_results: List[Dict], output_dir: Path):
        """Save the per-run summary rows to ``<EXP_NAME>_result.csv``.

        Args:
            overall_results: One dict per completed file-run; nothing is
                written when the list is empty.
            output_dir: Directory that receives the CSV file.
        """
        if not overall_results:
            return

        df = pd.DataFrame(overall_results)
        csv_path = output_dir / "{}_result.csv".format(EXP_NAME)
        df.to_csv(csv_path, index=False)
        print(f"Overall results saved to: {csv_path}")

    def compute_validation_metrics_original_style(self, results: List[Dict], exp_name: str, run_id: int):
        """Compute validation metrics for one experiment/run pair.

        Args:
            results: Result dicts as produced by the inference workers.
            exp_name: Experiment name to select.
            run_id: Run id to select.

        Returns:
            ``{}`` when no result matches, otherwise a dict with the keys
            ``basic_stats`` (total, successful, success_rate),
            ``validation_metrics`` (output of ``process_validation_metrics``)
            and ``reward_extra_infos`` (the per-sample score lists).
        """
        # Failed samples are kept on purpose: they count with zero scores.
        file_results = [r for r in results if r["exp_name"] == exp_name and r["run_id"] == run_id]

        if not file_results:
            return {}

        # Prepare data for validation metrics computation
        reward_extra_infos_dict = defaultdict(list)
        sample_inputs = []

        for result in file_results:
            sample_inputs.append(result["prompt"])
            scores = result["scores"]

            # Build reward_extra_infos_dict
            reward_extra_infos_dict["reward"].append(scores["score"])
            reward_extra_infos_dict["acc"].append(scores["acc"])
            reward_extra_infos_dict["llm_acc"].append(scores["llm_acc"])
            reward_extra_infos_dict["format"].append(scores["format"])
            reward_extra_infos_dict["turns"].append(result["iterations"])

        data_sources = ["eval"] * len(file_results)

        # Process validation metrics
        validation_metrics = process_validation_metrics(
            data_sources=data_sources,
            sample_inputs=sample_inputs,
            infos_dict=reward_extra_infos_dict
        )

        # Compute basic statistics
        total = len(file_results)
        successful = sum(1 for r in file_results if r["success"])

        basic_stats = {
            "total": total,
            "successful": successful,
            "success_rate": successful / total if total > 0 else 0.0,
        }

        return {
            "basic_stats": basic_stats,
            "validation_metrics": validation_metrics,
            "reward_extra_infos": dict(reward_extra_infos_dict)
        }

    def _finalize_file_run(self, sorted_results: List[Dict], exp_name: str, run_id: int,
                           raw_output_dir: Path, output_dir: Path, overall_results: List[Dict]):
        """Save one completed file-run, print its metrics and update the CSV."""
        # Only the last path component is taken. Path(...).stem is avoided
        # because it would cut experiment names at their last dot.
        output_name = f"{Path(exp_name).name}_run{run_id}_results"

        # Save in the raw output subfolder
        saved_path = self.save_results(sorted_results, raw_output_dir / output_name)

        # Compute and print metrics
        metrics_result = self.compute_validation_metrics_original_style(sorted_results, exp_name, run_id)
        basic_stats = metrics_result["basic_stats"]
        validation_metrics = metrics_result["validation_metrics"]

        print(f"\n{'=' * 60}")
        print(f"Completed {exp_name} - Run {run_id}")
        print(f"Results saved to: {saved_path}")
        print(f"Total samples: {basic_stats['total']}")
        print(f"Successful samples: {basic_stats['successful']}")
        print(f"Success rate: {basic_stats['success_rate']:.1%}")

        overall_result_row = {
            "exp_name": exp_name,
            "run_id": run_id,
            "total_samples": basic_stats['total'],
            "successful_samples": basic_stats['successful'],
            "success_rate": basic_stats['success_rate'],
        }

        print(f"\nDetailed Metrics:")
        for var_name, metrics in validation_metrics.get("eval", {}).items():
            for metric_name, value in metrics.items():
                try:
                    shown = f"{value:.4f}"
                except (TypeError, ValueError):
                    # Non-numeric metric values are printed as they are.
                    shown = str(value)
                print(f"  {var_name}.{metric_name}: {shown}")
                overall_result_row[f"{var_name}.{metric_name}"] = value

        # Add to overall results list
        overall_results.append(overall_result_row)

        # Save overall results CSV after each completion (incremental save)
        self.save_overall_results_csv(overall_results, output_dir)

        print(f"{'=' * 60}")

    def result_collector_thread(self, result_queue: Queue, file_configs: Dict,
                                output_dir: Path, total_tasks: int,
                                file_task_counts: Dict[str, int]):
        """Collect results from ``result_queue`` and save each completed file-run.

        Runs until ``total_tasks`` results have been received. Whenever all
        expected results of one ``<exp_name>_<run_id>`` group have arrived,
        they are written (sorted by index) below
        ``<output_dir>/<EXP_NAME>_raw_output`` and the summary CSV is updated.

        Args:
            result_queue: Queue the workers put result dicts on.
            file_configs: Dataset configuration; not used by this method.
            output_dir: Root directory for all outputs.
            total_tasks: Number of results to wait for.
            file_task_counts: Expected result count per
                ``<exp_name>_<run_id>`` key.
        """
        completed_tasks = 0

        # Plain containers are enough: only this thread touches them.
        all_results = {}
        overall_results = []

        # Create raw output subfolder
        raw_output_dir = output_dir / f"{EXP_NAME}_raw_output"
        raw_output_dir.mkdir(parents=True, exist_ok=True)

        while completed_tasks < total_tasks:
            try:
                result = result_queue.get(timeout=10)

                # Store result in appropriate location
                file_run_key = f"{result['exp_name']}_{result['run_id']}"
                run_results = all_results.setdefault(file_run_key, {})
                run_results[result['index']] = result
                completed_tasks += 1

                print(f"Progress: {completed_tasks}/{total_tasks} tasks completed")

                # Check if this file-run is complete
                expected_count = file_task_counts.get(file_run_key, 0)
                current_count = len(run_results)
                print(f"{file_run_key} - Complete: {current_count} / {expected_count}")

                if expected_count > 0 and current_count == expected_count:
                    # Sort results by index to maintain original order
                    sorted_results = [run_results[i] for i in sorted(run_results)]

                    self._finalize_file_run(
                        sorted_results, result['exp_name'], result['run_id'],
                        raw_output_dir, output_dir, overall_results
                    )

                    # Clean up memory
                    del all_results[file_run_key]

            except Empty:
                continue
            except Exception as e:
                print(f"Error in result collector: {str(e)}")
                traceback.print_exc()

        print("Result collector thread finished")


class Muti_Process_Infer():
    """Run tool-augmented inference tasks against the server in worker processes."""

    def __init__(self, args):
        """Store the parsed command-line arguments."""
        self.args = args

    def fix_truncated_tool_call(self, content):
        """Close a ``<tool_call>`` tag that was cut off by a stop sequence.

        Args:
            content: Assistant message text; ``None`` is treated as empty.

        Returns:
            ``content`` with one ``</tool_call>`` appended when there are more
            opening than closing tags, otherwise ``content`` unchanged
            (``""`` for ``None``).
        """
        if content is None:
            # A response message may carry no text content at all.
            return ""

        if content.count("<tool_call>") > content.count("</tool_call>"):
            content += "</tool_call>"

        return content

    def format_messages_with_tokenizer(self, messages: List[Dict], tools=None) -> str:
        """Format messages using the chat template of the global tokenizer.

        Args:
            messages: Conversation as a list of role/content dicts.
            tools: Tool descriptions passed through to the template.

        Returns:
            The rendered conversation as text (no generation prompt appended).
        """
        global tokenizer

        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            tools=tools,
            add_generation_prompt=False
        )

    def setup_tools(self, tool_names: List[str], max_tokens: int) -> "Tuple[List, NousToolEnv]":
        """Instantiate the named tools and the environment that executes them.

        Args:
            tool_names: Names understood by ``_default_tool``.
            max_tokens: Passed to the environment as
                ``max_tool_response_length``.

        Returns:
            A tuple ``(tool_descriptions, env)``.
        """
        tools = [_default_tool(tool_name) for tool_name in tool_names]
        env = NousToolEnv(tools=tools, max_tool_response_length=max_tokens)

        return [tool.tool_description for tool in tools], env

    def to_float(self, score):
        """Convert a score to ``float``.

        Plain and numpy scalars are converted directly; for anything that
        cannot be converted as a whole (e.g. a list or tuple of scores) the
        first element is used.
        """
        try:
            return float(score)
        except (TypeError, ValueError):
            return float(score[0])

    @staticmethod
    def _reference_answer(item: Dict):
        """Return ``original_data.reward_model.ground_truth`` of ``item`` or None."""
        try:
            return item["original_data"]["reward_model"]["ground_truth"]
        except (KeyError, TypeError, IndexError):
            # reward_model.ground_truth may be absent
            return None

    def _append_track_record(self, exp_name: str, record: Dict):
        """Append ``record`` as one line to the experiment's tracking JSONL file.

        Raises:
            Exception: Any error while writing is printed and re-raised.
        """
        results_dir = "./results/experiment_track_data/"
        os.makedirs(results_dir, exist_ok=True)
        jsonl_path = os.path.join(results_dir, f"all_result_{exp_name}.jsonl")

        try:
            with open(jsonl_path, "a", encoding="utf-8") as f:
                # One record per line: no indent, otherwise the file is not
                # valid JSONL.
                f.write(json.dumps(record, ensure_ascii=False, default=str) + "\n")
                f.flush()
                try:
                    os.fsync(f.fileno())
                except OSError:
                    # Syncing to disk is best effort only.
                    pass

            print(f"Appended result to JSONL: {jsonl_path}")

        except Exception as e:
            print("Error saving to jsonl:", e)
            raise

    def process_task(self, task: Dict, client: "OpenAI", model_params: Dict) -> Dict:
        """Process a single inference task.

        Runs up to 10 generate/tool-call rounds, scores the rendered
        conversation and appends a tracking record to disk.

        Args:
            task: Dict with the keys ``exp_name``, ``run_id``, ``item`` and
                ``tool_names``.
            client: OpenAI-compatible client connected to the server.
            model_params: Dict providing ``model``, ``temperature``,
                ``top_p`` and ``max_tokens``.

        Returns:
            A result dict with ``exp_name``, ``run_id``, ``index``,
            ``prompt``, ``ground_truth``, ``solution``, ``scores``,
            ``success``, ``error`` and ``iterations``. Any exception is
            reported as ``success=False`` with zero scores and an additional
            ``traceback`` entry instead of being raised.
        """
        exp_name = task['exp_name']
        run_id = task['run_id']
        item = task['item']
        tool_names = task['tool_names']

        try:
            # Setup tools using NousToolEnv
            tools, env = self.setup_tools(tool_names, model_params['max_tokens'])

            # Create initial messages
            question = item['prompt'][0]['content']
            messages = [{"role": "user", "content": question}]

            max_iterations = 10  # Prevent infinite loops
            iteration = 0

            while iteration < max_iterations:
                iteration += 1

                # Generate response with OpenAI tools format.
                # NOTE(review): model_params['repetition_penalty'] is not
                # forwarded to this request - confirm whether that is intended.
                response = client.chat.completions.create(
                    model=model_params['model'],
                    messages=messages,
                    tools=tools,
                    tool_choice="auto",
                    temperature=model_params['temperature'],
                    top_p=model_params['top_p'],
                    max_tokens=model_params['max_tokens'],
                    stop=["</tool_call>", "</tool_call>\n", "</tool_call> ", "</tool_call>\n\n"],
                )

                # The stop sequences remove the closing tag; put it back.
                response_content = self.fix_truncated_tool_call(response.choices[0].message.content)

                # Add assistant message to conversation
                messages.append({
                    "role": "assistant",
                    "content": response_content
                })

                tool_responses, _tool_successes, has_tool_calls = env.step(response_content, step_inference=True)
                if not has_tool_calls:
                    break

                messages.extend(
                    {"role": "tool", "content": tool_response}
                    for tool_response in tool_responses
                )

            # Format the entire conversation using tokenizer's chat template
            output = self.format_messages_with_tokenizer(messages, tools=tools)

            scorer = compute_format_and_answer_score_using_gpt4omini
            format_score = scorer.compute_score_format(output)
            em_score = scorer.compute_score_em(output, item['ground_truth'])
            format_answer_score, answer_score = scorer.compute_score_format_answer(output, item['ground_truth'])

            track_record = {
                "index": item['index'],
                "query": question,
                "ground_truth": self._reference_answer(item),
                # Copy of the message list to pin the current answer trajectory
                "answer_messages": messages.copy(),
                "formatted_output": output,
                "tools_used": tool_names,
                "meta": {
                    "saved_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                    "iteration_count": iteration
                },
                "acc": {
                    "em_acc": self.to_float(em_score),
                    "llm_acc": self.to_float(answer_score)
                },
            }
            self._append_track_record(exp_name, track_record)

            return {
                "exp_name": exp_name,
                "run_id": run_id,
                "index": item['index'],
                "prompt": question,
                "ground_truth": item['ground_truth'],
                "solution": output,
                "scores": {
                    "format": self.to_float(format_score),
                    "acc": self.to_float(em_score),
                    "llm_acc": self.to_float(answer_score),
                    "score": self.to_float(format_answer_score)
                },
                "success": True,
                "error": None,
                "iterations": iteration
            }

        except Exception as e:
            print(f"WARNING: {exp_name} - run_id: {run_id} - data_index: {item['index']} failed, reason: {str(e)}")
            return {
                "exp_name": exp_name,
                "run_id": run_id,
                "index": item['index'],
                "prompt": item.get('prompt', [{'content': ''}])[0].get('content', ''),
                "ground_truth": item.get('ground_truth', ''),
                "solution": "",
                "scores": {"format": 0.0, "acc": 0.0, "llm_acc": 0.0, "score": 0.0},
                "success": False,
                "error": str(e),
                "traceback": traceback.format_exc(),
                "iterations": 0
            }

    def init_tokenizer(self, tokenizer_path):
        """Load the tokenizer into the module-level ``tokenizer`` variable."""
        global tokenizer
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)

    def worker_process(self, worker_id: int, model_params: Dict, tokenizer_path: str,
                    task_queue: JoinableQueue, result_queue: Queue):
        """Worker loop: take tasks from ``task_queue`` and put results on ``result_queue``.

        The loop ends when ``None`` is received as a stop marker. Every task
        actually taken from the queue, including the stop marker, is
        acknowledged with ``task_done()``.

        Args:
            worker_id: Number used in log messages.
            model_params: Dict providing ``api_key`` and ``api_base`` plus the
                generation settings read by ``process_task``.
            tokenizer_path: Location handed to ``init_tokenizer``.
            task_queue: Queue of task dicts.
            result_queue: Queue that receives one result dict per task.
        """
        # Initialize tokenizer for this worker
        self.init_tokenizer(tokenizer_path)

        # Initialize OpenAI client for this worker
        client = OpenAI(
            api_key=model_params['api_key'],
            base_url=model_params['api_base'],
        )
        print(f"Worker {worker_id}: Started and ready")

        # Process tasks from queue
        while True:
            task_fetched = False
            try:
                # Get task from queue with timeout
                task = task_queue.get(timeout=5)
                task_fetched = True

                if task is None:  # Poison pill to stop worker
                    print(f"Worker {worker_id}: Received stop signal, shutting down")
                    break

                result_queue.put(self.process_task(task, client, model_params))

            except Empty:
                # No task available, continue waiting
                continue
            except Exception as e:
                print(f"Worker {worker_id}: Error in worker process: {str(e)}")
                traceback.print_exc()
            finally:
                # Acknowledge only what was really taken from the queue, even
                # if processing failed.
                if task_fetched:
                    try:
                        task_queue.task_done()
                    except ValueError:
                        # task_done() was called more often than get()
                        pass

        print(f"Worker {worker_id}: Worker process finished")


def parse_args(argv=None):
    """Parse the command-line options of the deployment/inference script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Auto-deploy vLLM server and run batch inference')
    
    # vLLM Server settings
    parser.add_argument('--api-key', type=str, default='auto-deploy-key', help='API key for vLLM server')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='Host address for vLLM server')
    # The default port is drawn at random each time this function runs.
    parser.add_argument('--port', type=int, default=random.randint(10000, 60000), help='Port number for vLLM server')
    parser.add_argument('--model-name', type=str, default='auto-deployed-model', help='Model name for API')
    parser.add_argument('--model-path', type=str, default="./llama3.1/Meta-Llama-3.1-8B-Instruct/", help='Path to local model directory')

    # vLLM Performance settings
    parser.add_argument('--tensor-parallel-size', type=int, default=8, help='Number of GPUs for tensor parallelism')
    parser.add_argument('--max-num-seqs', type=int, default=256, help='Maximum number of sequences for vLLM')
    parser.add_argument('--gpu-memory-utilization', type=float, default=0.9, help='GPU memory utilization ratio')
    parser.add_argument('--max-model-len', type=int, default=32768, help='Maximum model length')
    parser.add_argument('--dtype', type=str, default='auto', choices=['auto', 'half', 'float16', 'bfloat16', 'float32'], help='Data type for model weights')

    # Experiment settings
    parser.add_argument('--exp-name', type=str, default='Online_ToolModel-Llama3.1-70B-For-Llama-3.1-8B-Instruct', help='Experiment name')
    parser.add_argument('--temperature', type=float, default=default_config.TEMPERATURE, help='Temperature for sampling')
    parser.add_argument('--top-p', type=float, default=default_config.TOP_P, help='Top-p for nucleus sampling')
    parser.add_argument('--max-tokens', type=int, default=default_config.MAX_TOKENS, help='Maximum number of tokens to generate')
    parser.add_argument('--repetition-penalty', type=float, default=default_config.REPETITION_PENALTY, help='Repetition penalty for generation')
    
    # Processing settings
    parser.add_argument('--num-processes', type=int, default=1, help='Number of parallel processes for inference')
    parser.add_argument('--output-dir', type=str, default="results", help='Output directory for results')

    # Tokenizer settings
    # NOTE(review): the fallback to model-path promised by the help text is
    # not applied in this function - confirm it happens in the caller.
    parser.add_argument('--tokenizer-path', type=str, default=None, help='Path to tokenizer directory (defaults to model-path)')

    # Config file
    parser.add_argument('--config', type=str, default=None, help='Path to custom config file')

    return parser.parse_args(argv)


##########################################################################################
def get_default_config(args):
    """Return the configuration module to use for this run.

    When ``args.config`` is set, the file at that path is executed as a module
    named ``custom_config`` and that module is returned. If loading fails for
    any reason, the error is printed and the packaged default config is
    returned instead, never a partially executed custom module.

    Args:
        args: Parsed CLI namespace; only ``args.config`` is read.

    Returns:
        The loaded custom config module, or ``default_config`` when no custom
        config was requested or it could not be loaded.
    """
    if args.config:
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        try:
            spec = importlib.util.spec_from_file_location("custom_config", args.config)
            # spec_from_file_location returns None when it cannot pick a
            # loader for the path (e.g. an unrecognised file suffix).
            if spec is None or spec.loader is None:
                raise ImportError(f"cannot build an import spec for {args.config!r}")
            custom_config = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(custom_config)
        except Exception as e:
            print(f"Error loading custom config: {e}")
            print("Falling back to default config")
        else:
            print(f"Loaded custom config from {args.config}")
            return custom_config

    return default_config

def start_vllm_server_first(args):
    """Step 1: print the step banner and launch the vLLM server.

    Args:
        args: Parsed CLI namespace, handed unchanged to ``First_Vllm_Server``.

    Returns:
        Whatever ``First_Vllm_Server.start_vllm_server()`` returns; the caller
        forwards it to the readiness check as the server process.
    """
    rule = "=" * 80
    for line in (rule, "STEP 1: STARTING VLLM SERVER", rule):
        print(line)

    return First_Vllm_Server(args).start_vllm_server()

def wait_for_server_ready_second(args, vllm_process):
    """Step 2: print the step banner and wait for the server to be ready.

    Args:
        args: Parsed CLI namespace, handed unchanged to ``Wait_Vllm_Server``.
        vllm_process: Value returned by the server start step; passed through
            to ``Wait_Vllm_Server.wait_for_server_ready``.
    """
    rule = "=" * 80
    print(f"\n{rule}")
    print("STEP 2: WAITING FOR SERVER TO BE READY")
    print(rule)

    Wait_Vllm_Server(args).wait_for_server_ready(vllm_process)

def prepare_for_inference_third(args):
    """Step 3: collect the generation parameters and create the output directory.

    Args:
        args: Parsed CLI namespace. Reads ``api_key``, ``api_base``,
            ``model_name``, ``temperature``, ``top_p``, ``max_tokens``,
            ``repetition_penalty`` and ``output_dir``.

    Returns:
        A ``(output_dir, model_params)`` tuple: the output directory as a
        ``Path`` (created if missing, parents included) and a dict of the
        request/sampling parameters.
    """
    rule = "=" * 80
    print(f"\n{rule}\nSTEP 3: PREPARING INFERENCE\n{rule}")

    # (key in model_params, attribute on args); only 'model' is renamed.
    param_sources = (
        ('api_key', 'api_key'),
        ('api_base', 'api_base'),
        ('model', 'model_name'),
        ('temperature', 'temperature'),
        ('top_p', 'top_p'),
        ('max_tokens', 'max_tokens'),
        ('repetition_penalty', 'repetition_penalty'),
    )
    model_params = {key: getattr(args, attr) for key, attr in param_sources}

    # Setup output directory
    results_path = Path(args.output_dir)
    results_path.mkdir(parents=True, exist_ok=True)

    return results_path, model_params

def create_all_tasks_fourth(FILE_CONFIGS):
    """Step 4: build the full task list and report how many tasks were created.

    Args:
        FILE_CONFIGS: File configuration object passed straight to
            ``Submit_Task.create_all_tasks``.

    Returns:
        ``(all_tasks, total_tasks, file_task_counts)`` when at least one task
        was created, otherwise ``None`` (after printing a notice). Callers
        must check for ``None`` before unpacking.
    """
    rule = "=" * 80
    print(f"\n{rule}")
    print("STEP 4: Creating all tasks...")
    print(rule)

    all_tasks, file_task_counts = Submit_Task().create_all_tasks(FILE_CONFIGS)
    total_tasks = len(all_tasks)
    print(f"Created {total_tasks} tasks total")
    print(f"File-run task counts: {file_task_counts}")

    if total_tasks != 0:
        return all_tasks, total_tasks, file_task_counts

    print("No tasks to process!")
    return None

def run_inference_fifth(args, output_dir, model_params,  all_tasks, total_tasks, file_task_counts):
    """Step 5: queue every task and start the collector thread and workers.

    Args:
        args: Parsed CLI namespace; ``num_processes`` and ``tokenizer_path``
            are read here and ``args`` is also given to ``Muti_Process_Infer``.
        output_dir: Output directory forwarded to the result collector.
        model_params: Parameter dict forwarded to every worker.
        all_tasks: Iterable of tasks to enqueue.
        total_tasks: Task count forwarded to the result collector.
        file_task_counts: Per-file task counts forwarded to the collector.

    Returns:
        ``(task_queue, collector_thread, processes)``: the joinable task
        queue, the started collector thread, and the list of started worker
        processes. Nothing is joined here.
    """
    rule = "=" * 80
    print(f"\n{rule}")
    print("STEP 5: RUNNING INFERENCE")
    print(rule)

    # Work goes out on a joinable queue; results come back on a plain queue.
    pending_tasks = JoinableQueue()
    finished_results = Queue()

    for item in all_tasks:
        pending_tasks.put(item)

    # The collector runs as a thread in this process. Note that it receives
    # the module-level FILE_CONFIGS, not a parameter of this function.
    collector = Construct_Collector_Thread()
    collector_thread = threading.Thread(
        target=collector.result_collector_thread,
        args=(finished_results, FILE_CONFIGS, output_dir, total_tasks, file_task_counts),
    )
    collector_thread.start()

    inference = Muti_Process_Infer(args)

    def _spawn_worker(worker_id):
        # Create and immediately start one worker process.
        proc = Process(
            target=inference.worker_process,
            args=(worker_id, model_params, args.tokenizer_path, pending_tasks, finished_results),
        )
        proc.start()
        return proc

    processes = [_spawn_worker(worker_id) for worker_id in range(args.num_processes)]

    print(f"Started {len(processes)} worker processes")

    return pending_tasks, collector_thread, processes

def print_info(task_queue, processes, collector_thread, total_tasks, output_dir):
    """Wait for the run to finish, shut the workers down, and print a summary.

    Blocks on ``task_queue.join()``, then puts one ``None`` sentinel on the
    queue per worker, joins every worker, joins the collector thread, and
    prints timing statistics.

    On ``KeyboardInterrupt`` every worker is terminated and joined, and the
    collector thread is given up to 5 seconds to finish.

    Args:
        task_queue: Joinable queue holding the pending tasks.
        processes: Started worker processes.
        collector_thread: Started result-collector thread.
        total_tasks: Number of tasks submitted; may be 0.
        output_dir: Output location, only echoed in the summary.
    """
    # NOTE: the clock starts when this function is entered, not when the
    # workers were started, so the reported time covers the wait only.
    start_time = time.time()
    try:
        # Wait for all tasks to be completed
        print("Waiting for all tasks to complete...")
        task_queue.join()
        print("All tasks completed!")

        # Send stop signals to workers
        for _ in processes:
            task_queue.put(None)  # Poison pill

        # Wait for all workers to finish
        for p in processes:
            p.join()

        # Wait for result collector to finish
        collector_thread.join()

        total_time = time.time() - start_time
        rule = "=" * 80

        print(f"\n{rule}")
        print("INFERENCE COMPLETED SUCCESSFULLY")
        print(rule)
        print(f"Total processing time: {total_time / 60:.1f} minutes")
        print(f"Total tasks processed: {total_tasks}")
        # Guard the average: an empty run must not die with ZeroDivisionError
        # after all the work of shutting down cleanly.
        if total_tasks:
            print(f"Average time per task: {total_time / total_tasks:.2f} seconds")
        else:
            print("Average time per task: n/a (no tasks)")
        print(f"Worker processes used: {len(processes)}")

        print(f"Check the output directory for detailed results: {output_dir}")
        print(rule)

    except KeyboardInterrupt:
        print("\nInterrupted by user")
        # Terminate all processes
        for p in processes:
            p.terminate()
            p.join()
        collector_thread.join(timeout=5)

def main():
    """Run the whole pipeline: start vLLM, run inference, then clean up.

    Steps, in order: parse the CLI, derive ``api_base`` and the tokenizer
    path, load the optional custom config, start the vLLM server, wait for it
    to be ready, prepare the output directory and model parameters, build the
    task list, run inference and wait for completion. ``cleanup_server()`` is
    always called at the end, including after errors or an empty task list.

    Any ``Exception`` raised by the steps is printed with its traceback and
    swallowed so that cleanup still runs.
    """
    global EXP_NAME

    args = parse_args()
    args.api_base = f"http://{args.host}:{args.port}/v1"     # vllm deploy port
    if args.tokenizer_path is None:
        args.tokenizer_path = args.model_path

    EXP_NAME = args.exp_name

    # Load custom config if provided.
    # NOTE(review): the returned module is not consumed anywhere below; the
    # call is kept because loading the file (and its messages) is observable.
    config = get_default_config(args)

    try:
        # Start vLLM server
        # NOTE(review): cleanup_server() reads the module-level vllm_process,
        # which this local does not update - confirm that
        # First_Vllm_Server.start_vllm_server() sets that global itself.
        server_process = start_vllm_server_first(args)

        # Wait for server to be ready
        wait_for_server_ready_second(args, server_process)

        # Prepare for inference
        output_dir, model_params = prepare_for_inference_third(args)

        # Create all tasks and get task counts
        task_bundle = create_all_tasks_fourth(FILE_CONFIGS)
        if task_bundle is None:
            # Nothing to run; create_all_tasks_fourth has already reported it.
            # Returning here still executes the cleanup in ``finally``.
            return
        all_tasks, total_tasks, file_task_counts = task_bundle

        # RUNNING INFERENCE
        task_queue, collector_thread, processes = run_inference_fifth(
            args, output_dir, model_params, all_tasks, total_tasks, file_task_counts
        )

        # Print_Info
        print_info(task_queue, processes, collector_thread, total_tasks, output_dir)

    except Exception as e:
        print(f"Error during execution: {str(e)}")
        traceback.print_exc()

    finally:
        # Step 6: Cleanup ("STEP 5" is already used by the inference step)
        print("\n" + "=" * 80)
        print("STEP 6: CLEANING UP")
        print("=" * 80)
        cleanup_server()
        print("Cleanup completed.")

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()