#!/usr/bin/env python3
"""
Flexible vLLM inference script with automatic model configuration and incremental saving
Updated to support Qwen3 and Qwen2.5-VL models with environment-based configuration
Now includes incremental saving and resume capability to prevent loss of progress
Enhanced with cleaner logging and better progress reporting
"""

import json
import argparse
import os
import time
import logging
import sys
from vllm import LLM, SamplingParams

# Root logger: timestamped records, WARNING and above only.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.WARNING,
)

# Cap the chattiest third-party loggers at WARNING as well.
for _noisy_logger in ("vllm", "transformers", "torch"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)


def print_status(message, file=sys.stderr):
    """Write *message*, tagged with ``[INFERENCE]``, to *file* and flush it.

    *file* defaults to the ``sys.stderr`` object that exists when this
    function is defined.
    """
    tagged_line = f"[INFERENCE] {message}"
    print(tagged_line, file=file, flush=True)


def get_model_config(model_name: str) -> dict:
    """
    Assemble vLLM engine settings from environment variables.

    The values are read from ``VLLM_GPU_MEMORY_UTILIZATION`` (default
    ``0.85``), ``VLLM_MAX_MODEL_LEN`` (default ``4096``) and
    ``VLLM_TENSOR_PARALLEL_SIZE`` (default ``1``).  ``model_name`` is only
    used to label the status output; it does not influence the values.

    Returns a dict holding those three settings under the keys
    ``gpu_memory_utilization``, ``max_model_len`` and
    ``tensor_parallel_size``, plus the constant flags ``enable_reasoning``
    and ``supports_thinking`` (both ``True``).

    Raises ``ValueError`` when a variable cannot be parsed as a number.
    """
    # Parse in a fixed order so the first malformed variable is the one reported.
    memory_fraction = float(os.environ.get("VLLM_GPU_MEMORY_UTILIZATION", "0.85"))
    context_length = int(os.environ.get("VLLM_MAX_MODEL_LEN", "4096"))
    parallel_degree = int(os.environ.get("VLLM_TENSOR_PARALLEL_SIZE", "1"))

    settings = dict(
        tensor_parallel_size=parallel_degree,
        gpu_memory_utilization=memory_fraction,
        max_model_len=context_length,
        enable_reasoning=True,
        supports_thinking=True,
    )

    # Echo the effective settings so they show up in the run log.
    print_status(f"Model configuration for {model_name}:")
    print_status(f"  GPU memory utilization: {memory_fraction}")
    print_status(f"  Max model length: {context_length}")
    print_status(f"  Tensor parallel size: {parallel_degree}")

    return settings


def load_prompts(json_path: str) -> list[dict]:
    """
    Read and validate the prompt list stored in a JSON file.

    The file must contain a JSON array of objects.  Each object needs an
    "id" and a string "text"; "system_prompt" (string) and "metadata" are
    optional.  A non-empty system prompt is prepended to the text,
    separated by a blank line.

    Returns one dict per prompt with the keys "id", "text" (the combined
    prompt) and "metadata" (``{}`` when absent).

    Raises ``FileNotFoundError`` / ``json.JSONDecodeError`` when the file
    cannot be read or parsed (after logging the problem), and
    ``ValueError`` when the structure is not as described above.
    """
    try:
        with open(json_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError) as exc:
        print_status(f"Error loading prompts from {json_path}: {exc}")
        raise

    if not isinstance(payload, list):
        raise ValueError("Invalid JSON structure. Expected a list of prompt objects.")

    loaded = []
    for position, entry in enumerate(payload):
        if not isinstance(entry, dict):
            raise ValueError(
                f"Invalid prompt at index {position}: expected a JSON object."
            )

        # Report the first mandatory key that is absent, if any.
        absent = next((key for key in ("id", "text") if key not in entry), None)
        if absent is not None:
            raise ValueError(
                f"Missing required field '{absent}' in prompt at index {position}"
            )

        system_part = entry.get("system_prompt", "")
        user_part = entry["text"]
        if not (isinstance(system_part, str) and isinstance(user_part, str)):
            raise ValueError(
                f"Invalid prompt at index {position}: 'system_prompt' and 'text' must be strings."
            )

        # An empty system prompt contributes nothing, not even the separator.
        if system_part:
            full_text = f"{system_part}\n\n{user_part}"
        else:
            full_text = user_part

        loaded.append(
            {
                "id": entry["id"],
                "text": full_text,
                "metadata": entry.get("metadata", {}),
            }
        )

    print_status(f"Successfully loaded {len(loaded)} prompts")
    return loaded


def load_existing_results(output_path: str) -> list[dict]:
    """
    Load previously saved results, best effort.

    Returns the parsed list when ``output_path`` exists and holds a JSON
    array of objects.  Returns ``[]`` when the file is missing, cannot be
    read or decoded, or holds some other JSON structure; every case except
    the missing file is reported through ``print_status``.  This function
    never raises for a bad file.
    """
    if not os.path.exists(output_path):
        return []

    try:
        with open(output_path, "r", encoding="utf-8") as f:
            results = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print_status(f"Could not load existing results from {output_path}: {e}")
        return []

    # Valid JSON of the wrong shape (e.g. a dict) would otherwise be handed
    # back and break the caller when it indexes each entry as a result dict.
    if not isinstance(results, list) or not all(
        isinstance(entry, dict) for entry in results
    ):
        print_status(
            f"Could not load existing results from {output_path}: "
            "expected a JSON list of result objects"
        )
        return []

    print_status(f"Loaded {len(results)} existing results from {output_path}")
    return results


def save_results_incrementally(results: list[dict], output_path: str) -> None:
    """
    Write ``results`` to ``output_path`` as indented JSON via a temp file.

    The data is first written to ``output_path + ".tmp"`` and then moved
    over the destination, so a failed write leaves any existing output
    file untouched.

    Raises ``OSError`` on I/O failure and ``TypeError`` / ``ValueError``
    when ``results`` cannot be serialised; the error is logged and the
    temp file is removed before the exception is re-raised.
    """
    temp_path = output_path + ".tmp"
    try:
        with open(temp_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        # os.replace overwrites an existing destination on every platform;
        # os.rename raises on Windows when the destination already exists.
        os.replace(temp_path, output_path)
    except (OSError, TypeError, ValueError) as e:
        print_status(f"Failed to save results: {e}")
        # Best-effort cleanup: a failure here must not hide the original error.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise


def run_inference_incremental(
    model_name: str,
    prompts: list[dict],
    output_path: str,
    batch_size: int = 100,
    resume: bool = True,
) -> list[dict]:
    """
    Run inference in batches, saving after every batch so a run can resume.

    Args:
        model_name: Model identifier passed to ``LLM``.
        prompts: Prompt dicts with "id", "text" and "metadata" keys.
        output_path: JSON file the accumulated results are written to.
        batch_size: Number of prompts per ``generate`` call; must be >= 1.
        resume: When true, results already in ``output_path`` are loaded and
            prompts whose "id" appears there are skipped.

    Returns:
        All results (previously saved ones first, then new ones), each a
        dict with "id", "prompt", "completion", "metadata" and "model".

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        Exception: Failures while loading the model, generating, or saving
            are logged and re-raised unchanged.
    """
    # A zero batch size used to fail with ZeroDivisionError only after the
    # model had been loaded, and a negative one silently processed nothing.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Load existing results if resuming
    existing_results = []
    processed_ids = set()

    if resume:
        existing_results = load_existing_results(output_path)
        processed_ids = {result["id"] for result in existing_results}
        if existing_results:
            print_status(f"Resume mode: Found {len(existing_results)} existing results")

    # Filter out already processed prompts
    remaining_prompts = [p for p in prompts if p["id"] not in processed_ids]

    if not remaining_prompts:
        print_status("All prompts already processed!")
        return existing_results

    print_status(
        f"Processing {len(remaining_prompts)} remaining prompts (of {len(prompts)} total)"
    )
    print_status(f"Batch size: {batch_size}")

    # Get model configuration
    config = get_model_config(model_name)

    # Initialize model once
    print_status(f"Loading model: {model_name}")
    print_status("This may take several minutes for large models...")

    llm_kwargs = {
        "model": model_name,
        "tensor_parallel_size": config["tensor_parallel_size"],
        "gpu_memory_utilization": config["gpu_memory_utilization"],
        "max_model_len": config["max_model_len"],
        "enable_prefix_caching": True,
        "disable_log_stats": True,  # Reduce logging
        "swap_space": 4,
        "cpu_offload_gb": 0,
        "enforce_eager": False,
        "trust_remote_code": True,
    }

    try:
        llm = LLM(**llm_kwargs)
        print_status("Model loaded successfully!")
    except Exception as e:
        print_status(f"Failed to load model: {e}")
        raise

    # Get sampling parameters from environment
    temperature = float(os.getenv("TEMPERATURE", "0.3"))
    top_p = float(os.getenv("TOP_P", "0.9"))
    max_tokens = int(os.getenv("MAX_LENGTH", "4096"))

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        repetition_penalty=1.1,
        # Each stop string is a tag repeated after a blank line.
        stop=[
            "<thinking>\n\n<thinking>",
            "<answer>\n\n<answer>",
            "</thinking>\n\n</thinking>",
            "</answer>\n\n</answer>",
        ],
        stop_token_ids=None,
        skip_special_tokens=True,
    )

    print_status(
        f"Sampling params: temp={temperature}, top_p={top_p}, max_tokens={max_tokens}"
    )

    # Process in batches
    all_results = existing_results.copy()
    total_batches = (len(remaining_prompts) + batch_size - 1) // batch_size
    total_prompts = len(prompts)
    vllm_logger = logging.getLogger("vllm")
    elapsed_total = 0.0  # wall-clock seconds spent on the batches of this run

    print_status(f"Starting batch processing: {total_batches} batches")

    for i in range(0, len(remaining_prompts), batch_size):
        batch = remaining_prompts[i : i + batch_size]
        batch_num = i // batch_size + 1

        print_status(
            f"Batch {batch_num}/{total_batches}: Processing {len(batch)} prompts..."
        )
        start_time = time.time()

        # Extract prompts for this batch
        batch_texts = [p["text"] for p in batch]

        # Quieten vLLM while generating; the finally clause restores the
        # previous level even when generation fails.
        old_level = vllm_logger.level
        vllm_logger.setLevel(logging.ERROR)
        try:
            batch_outputs = llm.generate(batch_texts, sampling_params)
        except Exception as e:
            # Nothing to save here: every completed batch was already
            # written to output_path at the end of its own iteration.
            print_status(f"Error during inference for batch {batch_num}: {e}")
            raise
        finally:
            vllm_logger.setLevel(old_level)

        # Collect results for this batch
        for prompt, output in zip(batch, batch_outputs):
            all_results.append(
                {
                    "id": prompt["id"],
                    "prompt": prompt["text"],
                    "completion": output.outputs[0].text,
                    "metadata": prompt["metadata"],
                    "model": model_name,
                }
            )

        # Save incrementally after each batch
        try:
            save_results_incrementally(all_results, output_path)
        except Exception as e:
            print_status(f"Failed to save results for batch {batch_num}: {e}")
            raise

        batch_time = time.time() - start_time
        elapsed_total += batch_time

        # Calculate current statistics
        completed_prompts = len(all_results)
        completion_rate = (completed_prompts / total_prompts) * 100

        print_status(f"Batch {batch_num} complete in {batch_time:.1f}s")
        print_status(
            f"Progress: {completed_prompts}/{total_prompts} prompts ({completion_rate:.1f}%)"
        )

        # Estimate remaining time from the mean duration of the batches so far
        if batch_num > 1:  # Need at least 2 batches for estimation
            avg_batch_time = elapsed_total / batch_num
            remaining_batches = total_batches - batch_num
            eta_minutes = (remaining_batches * avg_batch_time) / 60
            print_status(f"Estimated time remaining: {eta_minutes:.1f} minutes")

    final_count = len(all_results)
    print_status(f"Inference complete! Processed {len(remaining_prompts)} new prompts")
    print_status(f"Final results: {final_count}/{total_prompts} prompts completed")
    print_status(f"Results saved to: {output_path}")

    return all_results


def run_inference(model_name: str, prompts: list[dict]) -> list[str]:
    """
    Generate completions for all prompts in a single ``generate`` call.

    Kept for backward compatibility: nothing is saved here and there is no
    resume support.

    Args:
        model_name: Model identifier passed to ``LLM``.
        prompts: Prompt dicts; only the "text" key is read.

    Returns:
        The text of the first output for each prompt, in input order.
    """

    config = get_model_config(model_name)

    print_status(f"Loading {model_name}...")

    llm_kwargs = {
        "model": model_name,
        "tensor_parallel_size": config["tensor_parallel_size"],
        "gpu_memory_utilization": config["gpu_memory_utilization"],
        "max_model_len": config["max_model_len"],
        "enable_prefix_caching": True,
        "disable_log_stats": True,
        "swap_space": 4,
        "cpu_offload_gb": 0,
        "enforce_eager": False,
        "trust_remote_code": True,
    }

    llm = LLM(**llm_kwargs)

    # Sampling settings come from the environment, with defaults.
    temperature = float(os.getenv("TEMPERATURE", "0.3"))
    top_p = float(os.getenv("TOP_P", "0.9"))
    max_tokens = int(os.getenv("MAX_LENGTH", "4096"))

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        repetition_penalty=1.1,
        # Each stop string is a tag repeated after a blank line.
        stop=[
            "<thinking>\n\n<thinking>",
            "<answer>\n\n<answer>",
            "</thinking>\n\n</thinking>",
            "</answer>\n\n</answer>",
        ],
        stop_token_ids=None,
        skip_special_tokens=True,
    )

    print_status(f"Running batch inference on {len(prompts)} prompts...")
    print_status(
        f"Sampling: temp={temperature}, top_p={top_p}, max_tokens={max_tokens}"
    )

    prompt_texts = [prompt["text"] for prompt in prompts]

    # Suppress verbose logging during generation; the finally clause puts
    # the previous level back even if generate() raises.
    vllm_logger = logging.getLogger("vllm")
    old_level = vllm_logger.level
    vllm_logger.setLevel(logging.ERROR)
    try:
        outputs = llm.generate(prompt_texts, sampling_params)
    finally:
        vllm_logger.setLevel(old_level)

    return [output.outputs[0].text for output in outputs]


def main():
    """
    Command-line entry point: parse arguments, load prompts, run inference.

    In the default (incremental) mode ``--output`` is required and results
    are written by ``run_inference_incremental``.  With ``--legacy_mode``
    all prompts are generated at once and the results are written to
    ``--output`` if given, otherwise printed to stdout as JSON.

    Exits with status 1, after a message on stderr, when the prompts cannot
    be loaded, when incremental mode lacks ``--output`` or is given a
    non-positive ``--batch_size``, when inference fails, or when the legacy
    results cannot be written.
    """
    parser = argparse.ArgumentParser(
        description="Flexible vLLM inference script with automatic model configuration and incremental saving"
    )
    parser.add_argument("--input", required=True, help="Path to input JSON file")
    parser.add_argument(
        "--model",
        default=os.getenv("MODEL_REPO_ID", "Qwen/Qwen3-8B"),
        help="Model to use for inference (default from MODEL_REPO_ID env var)",
    )
    parser.add_argument("--output", help="Output JSON file (optional)")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=100,
        help="Number of prompts to process in each batch (default: 100)",
    )
    parser.add_argument(
        "--no_resume",
        action="store_true",
        help="Start from scratch instead of resuming from existing results",
    )
    parser.add_argument(
        "--legacy_mode",
        action="store_true",
        help="Use original batch processing (all at once, no incremental saving)",
    )

    args = parser.parse_args()

    print_status(f"Starting vLLM inference with model: {args.model}")
    print_status(f"Input file: {args.input}")
    if args.output:
        print_status(f"Output file: {args.output}")

    # Load prompts
    try:
        prompts = load_prompts(args.input)
    except Exception as e:
        print_status(f"Failed to load prompts: {e}")
        sys.exit(1)

    if args.legacy_mode:
        print_status("Running in legacy mode (no incremental saving)")
        try:
            results = run_inference(args.model, prompts)
        except Exception as e:
            print_status(f"Inference failed: {e}")
            sys.exit(1)

        # Prepare output with preserved metadata
        output_data = [
            {
                "id": prompt["id"],
                "prompt": prompt["text"],
                "completion": result,
                "metadata": prompt["metadata"],
                "model": args.model,
            }
            for prompt, result in zip(prompts, results)
        ]
    else:
        if not args.output:
            print_status("Error: Output path is required for incremental saving mode")
            sys.exit(1)

        # Reject an unusable batch size before the model is loaded; the
        # option is ignored in legacy mode, so it is only checked here.
        if args.batch_size < 1:
            print_status("Error: --batch_size must be a positive integer")
            sys.exit(1)

        print_status("Running with incremental saving enabled")
        try:
            output_data = run_inference_incremental(
                model_name=args.model,
                prompts=prompts,
                output_path=args.output,
                batch_size=args.batch_size,
                resume=not args.no_resume,
            )
        except Exception as e:
            print_status(f"Inference failed: {e}")
            sys.exit(1)

    # Incremental mode has already written its file; only legacy mode still
    # needs to save (or, without --output, print) the results here.
    if args.output and args.legacy_mode:
        try:
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(output_data, f, indent=2)
            print_status(f"Results saved to {args.output}")
        except Exception as e:
            print_status(f"Failed to save results: {e}")
            sys.exit(1)
    elif not args.output:
        print(json.dumps(output_data, indent=2))

    print_status("Inference completed successfully")


if __name__ == "__main__":
    main()
