"""
FastAPI server example for text generation using SGLang Engine and demonstrating client usage.

Starts the server, sends requests to it, and prints responses.

Usage:
python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000
"""

import os
import subprocess
import time
from contextlib import asynccontextmanager

import requests
from fastapi import FastAPI, Request

import sglang as sgl
from sglang.utils import terminate_process

engine = None


# Use FastAPI's lifespan manager to initialize/shutdown the engine
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages SGLang engine initialization during server startup."""
    global engine
    # Initialize the SGLang engine when the server starts
    # Adjust model_path and other engine arguments as needed
    print("Loading SGLang engine...")
    engine = sgl.Engine(
        model_path=os.getenv("MODEL_PATH"), tp_size=int(os.getenv("TP_SIZE"))
    )
    print("SGLang engine loaded.")
    yield
    # Clean up engine resources when the server stops (optional, depends on engine needs)
    print("Shutting down SGLang engine...")
    # engine.shutdown() # Or other cleanup if available/necessary
    print("SGLang engine shutdown.")


app = FastAPI(lifespan=lifespan)


@app.post("/generate")
async def generate_text(request: Request):
    """FastAPI endpoint to handle text generation requests."""
    global engine
    if not engine:
        return {"error": "Engine not initialized"}, 503

    try:
        data = await request.json()
        prompt = data.get("prompt")
        max_new_tokens = data.get("max_new_tokens", 128)
        temperature = data.get("temperature", 0.7)

        if not prompt:
            return {"error": "Prompt is required"}, 400

        # Use async_generate for non-blocking generation
        state = await engine.async_generate(
            prompt,
            sampling_params={
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
            },
            # Add other parameters like stop, top_p etc. as needed
        )

        return {"generated_text": state["text"]}
    except Exception as e:
        return {"error": str(e)}, 500


# Helper function to start the server
def start_server(args, timeout=60):
    """Starts the Uvicorn server as a subprocess and waits for it to be ready."""
    base_url = f"http://{args.host}:{args.port}"
    command = [
        "python",
        "-m",
        "uvicorn",
        "fastapi_engine_inference:app",
        f"--host={args.host}",
        f"--port={args.port}",
    ]

    process = subprocess.Popen(command, stdout=None, stderr=None)

    start_time = time.time()
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            try:
                # Check the /docs endpoint which FastAPI provides by default
                response = session.get(
                    f"{base_url}/docs", timeout=5
                )  # Add a request timeout
                if response.status_code == 200:
                    print(f"Server {base_url} is ready (responded on /docs)")
                    return process
            except requests.ConnectionError:
                # Specific exception for connection refused/DNS error etc.
                pass
            except requests.Timeout:
                # Specific exception for request timeout
                print(f"Health check to {base_url}/docs timed out, retrying...")
                pass
            except requests.RequestException as e:
                # Catch other request exceptions
                print(f"Health check request error: {e}, retrying...")
                pass
            # Use a shorter sleep interval for faster startup detection
            time.sleep(1)

    # If loop finishes, raise the timeout error
    # Attempt to terminate the failed process before raising
    if process:
        print(
            "Server failed to start within timeout, attempting to terminate process..."
        )
        terminate_process(process)  # Use the imported terminate_process
    raise TimeoutError(
        f"Server failed to start at {base_url} within the timeout period."
    )


def send_requests(server_url, prompts, max_new_tokens, temperature):
    """Sends generation requests to the running server for a list of prompts."""
    # Iterate through prompts and send requests
    for i, prompt in enumerate(prompts):
        print(f"\n[{i+1}/{len(prompts)}] Sending prompt: '{prompt}'")
        payload = {
            "prompt": prompt,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }

        try:
            response = requests.post(f"{server_url}/generate", json=payload, timeout=60)

            result = response.json()

            print(f"Prompt: {prompt}\nResponse: {result['generated_text']}")

        except requests.exceptions.Timeout:
            print(f"  Error: Request timed out for prompt '{prompt}'")
        except requests.exceptions.RequestException as e:
            print(f"  Error sending request for prompt '{prompt}': {e}")


if __name__ == "__main__":
    """Main entry point for the script."""

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-0.5B-Instruct")
    parser.add_argument("--tp_size", type=int, default=1)
    args = parser.parse_args()

    # Pass the model to the child uvicorn process via an env var
    os.environ["MODEL_PATH"] = args.model_path
    os.environ["TP_SIZE"] = str(args.tp_size)

    # Start the server
    process = start_server(args)

    # Define the prompts and sampling parameters
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    max_new_tokens = 64
    temperature = 0.1

    # Define server url
    server_url = f"http://{args.host}:{args.port}"

    # Send requests to the server
    send_requests(server_url, prompts, max_new_tokens, temperature)

    # Terminate the server process
    terminate_process(process)
