# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import Any

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)


class CPUModelRunner(GPUModelRunner):

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tenosrs()

    def _postprocess_tenosrs(self) -> None:
        # Note: replace device tensors with cpu tensors
        def replace_tensor(obj: Any, cpu_attr_name: str,
                           device_attr_name) -> None:
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if cpu_tensor is not None and device_tensor is not None:
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(obj, device_attr_name, cpu_tensor)

        for k, v in vars(self).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self, k, k[:-4])

        for k, v in vars(self.input_batch).items():
            if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch, k, k[:-11])

        for k, v in vars(self.input_batch.block_table).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch.block_table, k, k[:-4])

    def load_model(self) -> None:
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.model_config,
                                              self.scheduler_config,
                                              self.lora_config, self.device)

    def warming_up_model(self) -> None:
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape
        self._dummy_run(max(16, self.max_num_reqs))
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        pass

    def _sync_device(self) -> None:
        pass


@contextmanager
def _set_global_compilation_settings():
    import torch._inductor.config

    # Note: The CPPGEMM backend requires freezing parameters.
    freezing_value = torch._inductor.config.freezing
    torch._inductor.config.freezing = True
    # Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
    # including object type dict"
    force_disable_caches = torch._inductor.config.force_disable_caches
    torch._inductor.config.force_disable_caches = True
    yield
    torch._inductor.config.freezing = freezing_value
    torch._inductor.config.force_disable_caches = force_disable_caches
