#!/usr/bin/env python3
"""
vLLM Server for Qwen 2.5 Models
Serves the model through vLLM's OpenAI-compatible API, using tensor parallelism across 4 GPUs by default (override with --gpus)

Usage examples:
1. Use local model (recommended):
   python vllm_server.py --model /path/to/local/qwen-2.5-1.5b

2. Use HuggingFace model with internet access:
   python vllm_server.py --model Qwen/Qwen2.5-1.5B --online

3. Custom configuration:
   python vllm_server.py --model /path/to/model --gpus 2 --port 8001

Note: By default, the server runs in offline mode (no HuggingFace downloads are attempted).
A model that is already present in the local HuggingFace cache is located and used automatically.
Clients must request the model by its served name, "base_qwen_model".
"""

import argparse
import os
import shlex
import subprocess
import sys
from pathlib import Path
from typing import Optional

def start_vllm_server(
    model_name: str = "Qwen/Qwen2.5-1.5B",
    host: str = "0.0.0.0",
    port: int = 8000,
    tensor_parallel_size: int = 4,
    max_model_len: int = 4096,
    gpu_memory_utilization: float = 0.9,
    offline: bool = True
):
    """
    Start a vLLM server with an OpenAI-compatible API and block until it exits.

    The server runs as a child process of the same Python interpreter that is
    executing this script (``sys.executable``). It therefore uses the same
    environment in which ``vllm`` was found, rather than whatever ``python``
    happens to be first on PATH.

    Args:
        model_name: HuggingFace model name or local path to model
        host: Interface the server binds to
        port: Server port
        tensor_parallel_size: Number of GPUs to shard the model across
        max_model_len: Value passed to vLLM's --max-model-len
        gpu_memory_utilization: Fraction of GPU memory vLLM may use (0-1]
        offline: If True, the child process is started with
            HF_HUB_OFFLINE/TRANSFORMERS_OFFLINE (no HuggingFace downloads) and
            VLLM_NO_USAGE_STATS set. The current process's environment is
            left untouched.

    Raises:
        SystemExit: with status 1 if the server process exits with a
            non-zero status.
    """

    cmd = [
        # Same interpreter/venv as this script, not whatever "python" is on PATH
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", model_name,
        "--host", host,
        "--port", str(port),
        "--tensor-parallel-size", str(tensor_parallel_size),
        "--max-model-len", str(max_model_len),
        "--gpu-memory-utilization", str(gpu_memory_utilization),
        "--trust-remote-code",
        "--dtype", "float16",
        "--disable-log-requests",  # Reduce log verbosity
        "--served-model-name", "base_qwen_model",  # Custom name for API calls
        "--disable-custom-all-reduce",  # Helps with multi-GPU stability
        "--enforce-eager",  # Avoids compilation issues
    ]

    # Build the child's environment from a copy so that offline settings do
    # not leak into this process or anything else it spawns later.
    env = os.environ.copy()
    if offline:
        # Only silences vLLM's periodic stats logging (it is not a network
        # switch); the environment variables below are what keep it offline.
        cmd.append("--disable-log-stats")
        env["HF_HUB_OFFLINE"] = "1"
        env["TRANSFORMERS_OFFLINE"] = "1"
        # Opt out of vLLM's anonymous usage-stats reporting (a network call)
        env["VLLM_NO_USAGE_STATS"] = "1"

    print("Starting vLLM server with the following configuration:")
    print(f"Model: {model_name}")
    print(f"Host: {host}")
    print(f"Port: {port}")
    print(f"Tensor Parallel Size: {tensor_parallel_size} GPUs")
    print(f"Max Model Length: {max_model_len}")
    print(f"GPU Memory Utilization: {gpu_memory_utilization}")
    print(f"Offline Mode: {offline}")
    # shlex.join quotes arguments so the printed command can be copy-pasted
    print(f"Command: {shlex.join(cmd)}")
    print("\n" + "="*60)
    print("Starting server...")
    print("="*60)

    try:
        # Blocks until the server exits; check=True turns a failure into an exception
        subprocess.run(cmd, check=True, env=env)
    except KeyboardInterrupt:
        print("\nServer stopped by user")
    except subprocess.CalledProcessError as e:
        print(f"Error starting server: {e}", file=sys.stderr)
        sys.exit(1)

def _hf_hub_cache_dir() -> Path:
    """Return the HuggingFace hub cache directory.

    Lookup order: HF_HUB_CACHE (or the legacy HUGGINGFACE_HUB_CACHE), then
    $HF_HOME/hub, then $XDG_CACHE_HOME/huggingface/hub, and finally
    ~/.cache/huggingface/hub.
    """
    explicit = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if explicit:
        return Path(explicit).expanduser()
    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        return Path(hf_home).expanduser() / "hub"
    cache_root = os.environ.get("XDG_CACHE_HOME") or os.path.expanduser("~/.cache")
    return Path(cache_root) / "huggingface" / "hub"


def _find_hf_snapshot(repo_dir: Path, revision: str = "main") -> Optional[Path]:
    """Return the snapshot directory for *revision* in a cached hub repo, or None.

    The hub cache stores the commit hash that a branch points to in
    ``refs/<revision>``, and that snapshot is preferred. If the ref file is
    missing or points to a snapshot that is not present, the most recently
    modified snapshot is used instead. Returns None when *repo_dir* has no
    snapshots at all (including when it does not exist).
    """
    snapshots_dir = repo_dir / "snapshots"
    if not snapshots_dir.is_dir():
        return None

    ref_file = repo_dir / "refs" / revision
    if ref_file.is_file():
        commit = ref_file.read_text(encoding="utf-8").strip()
        if commit and (snapshots_dir / commit).is_dir():
            return snapshots_dir / commit

    # iterdir() order is arbitrary; choose deterministically by recency
    snapshots = [d for d in snapshots_dir.iterdir() if d.is_dir()]
    if not snapshots:
        return None
    return max(snapshots, key=lambda d: d.stat().st_mtime)


def _resolve_model(model: str, online: bool) -> str:
    """Turn the --model argument into the value passed to vLLM.

    The lookup order is:
    1. If *model* is an existing path, return it as an absolute path.
    2. If *model* is a hub repo id with a snapshot in the local HuggingFace
       cache, return that snapshot directory.
    3. Otherwise return *model* unchanged, to be treated as a hub
       identifier. A warning is printed in this case when *online* is False.
    """
    model_path = Path(model).expanduser()
    if model_path.exists():
        print(f"✓ Found local model at: {model_path.absolute()}")
        return str(model_path.absolute())

    hub_cache = _hf_hub_cache_dir()
    snapshot = _find_hf_snapshot(hub_cache / f"models--{model.replace('/', '--')}")
    if snapshot is not None:
        print(f"✓ Found model in HuggingFace cache: {snapshot}")
        return str(snapshot)

    print(f"Using model identifier: {model}")
    if not online:
        print("WARNING: Model not found locally and offline mode is enabled.")
        print("Possible solutions:")
        print("1. Provide full path: --model /path/to/your/model")
        print(f"2. Check HuggingFace cache: {hub_cache}")
        print("3. Enable online mode: --online")
    return model


def main():
    """Parse CLI arguments, verify vLLM is installed, locate the model, and start the server."""
    parser = argparse.ArgumentParser(description="Start vLLM server for Qwen 2.5 models")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B", help="Model name or local path")
    parser.add_argument("--host", default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    parser.add_argument("--gpus", type=int, default=4, help="Number of GPUs to use")
    parser.add_argument("--max-len", type=int, default=4096, help="Maximum model length")
    parser.add_argument("--gpu-util", type=float, default=0.9, help="GPU memory utilization")
    parser.add_argument("--online", action="store_true", help="Allow online downloads (default: offline)")

    args = parser.parse_args()

    # Reject obviously invalid values up front instead of after vLLM's slow startup
    if args.gpus < 1:
        parser.error("--gpus must be at least 1")
    if args.max_len < 1:
        parser.error("--max-len must be at least 1")
    if not 0.0 < args.gpu_util <= 1.0:
        parser.error("--gpu-util must be in the range (0, 1]")
    if not 0 < args.port < 65536:
        parser.error("--port must be between 1 and 65535")

    # Check if vLLM is installed
    try:
        import vllm
    except ImportError:
        print("vLLM not found. Please install it with: pip install vllm", file=sys.stderr)
        sys.exit(1)
    print(f"vLLM version: {vllm.__version__}")

    model_to_use = _resolve_model(args.model, online=args.online)

    start_vllm_server(
        model_name=model_to_use,
        host=args.host,
        port=args.port,
        tensor_parallel_size=args.gpus,
        max_model_len=args.max_len,
        gpu_memory_utilization=args.gpu_util,
        offline=not args.online
    )

if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not on import.
    # main() returns None, so a normal return exits with status 0 as before.
    sys.exit(main())