#!/usr/bin/env python3
"""
Launch an SGLang server for OpenHermes dataset generation.

This script launches an SGLang server with the specified model, waits for it to
become ready, and then keeps it running until interrupted with Ctrl+C.
"""

import argparse
import sys
import subprocess
from sglang.utils import wait_for_server, print_highlight


def parse_args(argv=None):
    """Parse command-line options for the SGLang launcher.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). Defaults to None, which makes
            argparse read the real command line; passing a list is mainly
            useful for testing.

    Returns:
        argparse.Namespace with the attributes ``model_path``, ``host``,
        ``port``, ``tp_size``, ``mem_fraction_static``, ``dtype`` and
        ``log_level``.
    """
    parser = argparse.ArgumentParser(
        description="Launch SGLang server for OpenHermes dataset generation",
        # Append "(default: ...)" to every option's help text.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--model-path",
        type=str,
        # Hugging Face repo id; note the hyphen between "30B" and "A3B".
        default="Qwen/Qwen3-30B-A3B-Instruct-2507",
        help="Path or name of the model to use",
    )

    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host address to bind the server",
    )

    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="Port number for the server",
    )

    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallelism size",
    )

    parser.add_argument(
        "--mem-fraction-static",
        type=float,
        default=0.8,
        help="Memory fraction for static allocation",
    )

    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Data type for model weights",
    )

    parser.add_argument(
        "--log-level",
        type=str,
        default="warning",
        choices=["debug", "info", "warning", "error"],
        help="Log level for the server",
    )

    return parser.parse_args(argv)


def launch_server_subprocess(cmd_args):
    """Start the server command as a child process.

    Args:
        cmd_args: Command line (list of strings) to execute. It must contain
            a ``"--port"`` flag immediately followed by the port number.

    Returns:
        Tuple ``(process, port)``: the ``subprocess.Popen`` handle and the
        integer port parsed from *cmd_args*.

    Raises:
        RuntimeError: if the port cannot be parsed or the process cannot be
            started. RuntimeError subclasses Exception, so ``except Exception``
            handlers still catch it. The underlying error is chained as
            ``__cause__``.
    """
    try:
        # Parse the port before spawning. Otherwise a malformed command would
        # raise only after the child had started, leaving a running process
        # that the caller has no handle to stop.
        port = int(cmd_args[cmd_args.index("--port") + 1])
        process = subprocess.Popen(cmd_args)
    except Exception as e:
        raise RuntimeError(f"Failed to launch server: {e}") from e
    return process, port


def _build_server_command(args):
    """Build the argv list that starts ``sglang.launch_server`` for parsed *args*."""
    # sys.executable (rather than a bare "python3" looked up on PATH) runs the
    # server under the same interpreter/virtualenv that imported sglang here.
    return [
        sys.executable, "-m", "sglang.launch_server",
        "--model-path", args.model_path,
        "--host", args.host,
        "--port", str(args.port),
        "--tp-size", str(args.tp_size),
        "--mem-fraction-static", str(args.mem_fraction_static),
        "--dtype", args.dtype,
        "--log-level", args.log_level,
    ]


def _client_host(bind_host):
    """Return a host a local client can connect to for a server bound to *bind_host*."""
    # Wildcard bind addresses are not valid connect targets; loopback reaches
    # them from this machine.
    if bind_host in ("0.0.0.0", "::", ""):
        return "localhost"
    # IPv6 literals must be wrapped in brackets inside a URL.
    if ":" in bind_host and not bind_host.startswith("["):
        return f"[{bind_host}]"
    return bind_host


def _stop_server(process, timeout=10):
    """Terminate *process*, escalating to kill if it is still alive after *timeout* seconds."""
    if process.poll() is not None:
        return  # already exited; nothing to stop
    process.terminate()
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()


def main():
    """Launch the SGLang server, wait until it is reachable, then supervise it.

    Exits with status 1 if launching or the readiness wait fails. If the server
    later exits on its own with a non-zero status, this function exits
    non-zero too. Ctrl+C while the server is starting or running stops it
    (terminate, then kill after 10 s).
    """
    args = parse_args()

    print_highlight(f"Launching SGLang server with model: {args.model_path}")
    print(f"Host: {args.host}")
    print(f"Port: {args.port}")
    print(f"Tensor parallelism: {args.tp_size}")
    print(f"Memory fraction: {args.mem_fraction_static}")
    print(f"Data type: {args.dtype}")
    print(f"Log level: {args.log_level}")
    print("-" * 60)

    cmd_args = _build_server_command(args)

    server_process = None
    try:
        print_highlight("Starting SGLang server...")
        server_process, actual_port = launch_server_subprocess(cmd_args)

        # Poll an address that is actually reachable for the chosen bind host.
        server_url = f"http://{_client_host(args.host)}:{actual_port}"
        wait_for_server(server_url)

        print_highlight("✅ SGLang server is ready!")
        print(f"Server URL: {server_url}")
        print(f"API endpoint: {server_url}/v1")
        print("-" * 60)
        print("You can now run the OpenHermes dataset generation script.")
        print("Press Ctrl+C to stop the server.")

        # Block until the server exits on its own or the user interrupts.
        returncode = server_process.wait()
    except KeyboardInterrupt:
        # KeyboardInterrupt is not an Exception subclass, so it needs its own
        # handler. Handling it here also covers Ctrl+C during wait_for_server.
        print_highlight("\n🛑 Stopping server...")
        if server_process is not None:
            _stop_server(server_process)
        print("Server stopped.")
        return
    except Exception as e:
        print(f"❌ Error launching server: {e}", file=sys.stderr)
        if server_process is not None:
            # Wait for the child as well, so it is reaped rather than left behind.
            _stop_server(server_process)
        sys.exit(1)

    if returncode != 0:
        # Report a crash instead of silently exiting 0. Negative codes (killed
        # by a signal on POSIX) are mapped to 1 rather than passed through.
        print(f"❌ SGLang server exited with code {returncode}", file=sys.stderr)
        sys.exit(returncode if returncode > 0 else 1)


if __name__ == "__main__":
    # Launch the server only when run as a script, not when this module is imported.
    main()
