import os
import time
from typing import Optional, Tuple

from vllm import LLM

from src.utils.generation import is_api_model_auto


def _build_llm(
    model_name: str,
    max_model_len: int,
    gpu_memory_utilization: float,
    tensor_parallel_size: int,
    enforce_eager: bool,
) -> LLM:
    """Create one vLLM engine with the settings shared by mutator and solver.

    Every engine built here uses float16 dtype, bitsandbytes for both the
    quantization method and the checkpoint load format, and has LoRA enabled.
    """
    return LLM(
        model=model_name,
        dtype="float16",
        quantization="bitsandbytes",
        load_format="bitsandbytes",
        enable_lora=True,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        tensor_parallel_size=tensor_parallel_size,
        enforce_eager=enforce_eager,
    )


def get_vllm_models(
    mutator_model_name: str,
    solver_model_name: str,
    max_model_len: int,
    enforce_eager: bool = False,
    num_gpus: int = 1,
    gpu_memory_utilization: Optional[float] = None,
) -> Tuple[Optional[LLM], Optional[LLM]]:
    """
    Initialize vLLM models for mutator and solver, optimizing GPU memory usage when possible.

    Models that ``is_api_model_auto`` classifies as API models get no local
    engine and are returned as ``None``. When mutator and solver name the same
    local model, a single engine is created and returned in both positions.

    Args:
        mutator_model_name: Name of mutator model
        solver_model_name: Name of solver model
        max_model_len: Maximum sequence length
        enforce_eager: Whether to enforce eager execution
        num_gpus: Number of available GPUs; passed to vLLM as
            ``tensor_parallel_size``
        gpu_memory_utilization: Fraction of GPU memory passed to *each* vLLM
            engine. If None, 0.9 is split evenly across the distinct local
            engines created (0.9 for one engine, 0.45 each for two). An
            explicit value is NOT divided, so two engines together request
            twice this fraction.

    Returns:
        Tuple of (mutator_vllm_model, solver_vllm_model). An entry is None for
        an API model; both entries are the same object when the names match.
    """
    start = time.perf_counter()
    mutator_is_api = is_api_model_auto(mutator_model_name)
    solver_is_api = is_api_model_auto(solver_model_name)

    # Nothing needs to be loaded locally.
    if mutator_is_api and solver_is_api:
        return None, None

    same_model = mutator_model_name == solver_model_name

    # Count distinct local engines so the default memory budget can be split.
    num_models = int(not mutator_is_api) + int(not solver_is_api and not same_model)

    if gpu_memory_utilization is None:
        # Leave ~10% of GPU memory as headroom for overhead outside vLLM.
        memory_per_model = 0.9 / max(num_models, 1)
    else:
        memory_per_model = gpu_memory_utilization

    def _load(model_name: str) -> LLM:
        # All engines share the same length, memory, parallelism and eager settings.
        return _build_llm(
            model_name,
            max_model_len=max_model_len,
            gpu_memory_utilization=memory_per_model,
            tensor_parallel_size=num_gpus,
            enforce_eager=enforce_eager,
        )

    if same_model and not mutator_is_api:
        # One engine serves both roles, avoiding a second copy of the weights.
        print(f"[rank {os.environ.get('RANK', 0)}] Initializing vLLM model...")
        shared_model = _load(mutator_model_name)
        mutator_vllm_model, solver_vllm_model = shared_model, shared_model
    else:
        mutator_vllm_model = None if mutator_is_api else _load(mutator_model_name)
        solver_vllm_model = None if solver_is_api else _load(solver_model_name)

    end = time.perf_counter()
    print(f"Time taken to initialize vLLM models: {end - start} seconds")

    return mutator_vllm_model, solver_vllm_model

