from vllm.core.scheduler import Scheduler
from packaging import version
from loguru import logger

class LLMActor:
    def __init__(self, *args, **kwargs):
        """Configure vLLM for the installed version and build the wrapped ``vllm.LLM``.

        All positional and keyword arguments are forwarded to ``vllm.LLM``
        after this constructor injects version-specific settings into
        ``kwargs``:

        * vLLM >= 0.7: sets ``worker_extension_cls`` and ``enable_sleep_mode``;
          with ``tensor_parallel_size == 1`` it also sets
          ``enforce_eager=False``.
        * older vLLM with ``tensor_parallel_size == 1``: replaces
          ``vllm.worker.worker.Worker`` with ``OffloadableVLLMWorker`` and sets
          ``enforce_eager=True``.
        * older vLLM otherwise: sets ``worker_use_ray`` and
          ``enforce_eager=True`` and routes worker construction to
          ``OffloadableVLLMWorker`` (through ``worker_cls`` after 0.6.4.post1,
          else by patching ``vllm.executor.ray_utils.RayWorkerWrapper``).

        Raises:
            AssertionError: if the installed vLLM is older than 0.4.1.
        """
        import vllm

        self.__version__ = vllm.__version__
        installed = version.parse(self.__version__)
        sleep_mode_release = version.parse("0.7")

        # Compare parsed versions, not raw strings: lexicographically
        # "0.10.0" < "0.4.1", which would wrongly reject newer releases.
        assert installed >= version.parse("0.4.1"), "OpenRLHF only supports vLLM >= 0.4.1"

        # Fall back to 1 when the caller omits the argument, mirroring the
        # default that vllm.LLM applies itself, instead of raising KeyError.
        self.use_gpu_executor = kwargs.get("tensor_parallel_size", 1) == 1

        if installed >= sleep_mode_release:
            kwargs["worker_extension_cls"] = "thinker_task.exp_engine.accelerators.inference.vllm_worker_wrap.WorkerWrapNew"
            kwargs["enable_sleep_mode"] = True
            if self.use_gpu_executor:
                kwargs["enforce_eager"] = False

        elif self.use_gpu_executor:
            # See https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
            from .vllm_worker_wrap import OffloadableVLLMWorker
            vllm.worker.worker.Worker = OffloadableVLLMWorker
            kwargs["enforce_eager"] = True
        else:
            # RayGPUExecutor
            # See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5
            kwargs["worker_use_ray"] = True
            kwargs["enforce_eager"] = True

            if installed > version.parse("0.6.4.post1"):
                # https://github.com/vllm-project/vllm/pull/10555
                kwargs[
                    "worker_cls"
                ] = "thinker_task.exp_engine.accelerators.inference.vllm_worker_wrap.OffloadableVLLMWorker"
            else:
                RayWorkerWrapperPath = vllm.executor.ray_utils

                # Subclass that pins the worker module/class so every Ray
                # worker is built as an OffloadableVLLMWorker.
                class RayWorkerWrapper(RayWorkerWrapperPath.RayWorkerWrapper):
                    def __init__(self, *args, **kwargs) -> None:
                        kwargs[
                            "worker_module_name"
                        ] = "thinker_task.exp_engine.accelerators.inference.vllm_worker_wrap"
                        kwargs["worker_class_name"] = "OffloadableVLLMWorker"
                        super().__init__(*args, **kwargs)

                RayWorkerWrapperPath.RayWorkerWrapper = RayWorkerWrapper

        logger.info(f"LLMActor args: {args}, kwargs: {kwargs}")
        self.llm = vllm.LLM(*args, **kwargs)
        # Tracks whether the engine has been put to sleep via offload_to_cpu
        # (only consulted on the vLLM >= 0.7 code path).
        self.awake = True

        if installed < sleep_mode_release:
            # Keep the engine configs around; backload_to_gpu uses them to
            # rebuild the scheduler on the legacy code path.
            engine = self.llm.llm_engine
            self.scheduler_config = engine.scheduler_config
            self.model_config = engine.model_config
            self.cache_config = engine.cache_config
            self.lora_config = engine.lora_config
            self.parallel_config = engine.parallel_config

    def generate(self, *args, **kwargs):
        """Run ``self.llm.generate`` with the given arguments.

        On vLLM >= 0.7 ``backload_to_gpu`` is called first, so an engine that
        was put to sleep through this actor is woken up before generating.
        """
        uses_sleep_mode = version.parse(self.__version__) >= version.parse("0.7")
        if uses_sleep_mode:
            self.backload_to_gpu()
        return self.llm.generate(*args, **kwargs)

    def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend):
        """Ask the vLLM worker(s) to run their ``init_process_group`` method.

        The six arguments are passed through unchanged. Dispatch depends on
        the vLLM version recorded at construction:

        * >= 0.7: ``self.llm.collective_rpc``.
        * < 0.7, single GPU: direct call on the executor's ``driver_worker``.
        * < 0.7, otherwise: the executor's ``_run_workers``.

        Returns:
            Whatever the selected dispatch call returns.
        """
        logger.info(f"init_process_group: master_address={master_address}, master_port={master_port}, rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}, use_gpu_executor={self.use_gpu_executor}")

        group_args = (master_address, master_port, rank_offset, world_size, group_name, backend)
        is_legacy = version.parse(self.__version__) < version.parse("0.7")

        if not is_legacy:
            ret = self.llm.collective_rpc("init_process_group", args=group_args)
        elif self.use_gpu_executor:
            ret = self.llm.llm_engine.model_executor.driver_worker.init_process_group(*group_args)
        else:
            ret = self.llm.llm_engine.model_executor._run_workers("init_process_group", *group_args)

        logger.info(f"Finish init_process_group: master_address={master_address}, master_port={master_port}, rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}")
        return ret

    def get_ip_and_port(self):
        """Invoke ``get_ip_and_port`` on the vLLM worker(s) and return the result.

        With a single-GPU engine the executor's ``driver_worker`` is called
        directly; otherwise the call goes through ``_run_workers``.
        """
        # NOTE(review): this method has no vLLM >= 0.7 branch (the other
        # worker calls in this class use collective_rpc there) - confirm the
        # legacy executor attributes are still valid on newer versions.
        single_gpu = self.use_gpu_executor
        executor = self.llm.llm_engine.model_executor
        if single_gpu:
            return executor.driver_worker.get_ip_and_port()
        return executor._run_workers("get_ip_and_port")

    def offload_to_cpu(self):
        """Release the engine's GPU resources until ``backload_to_gpu`` is called.

        On vLLM >= 0.7 this calls ``self.llm.sleep(level=1)``, records the
        engine as not awake, and returns ``None``. On older versions it
        invokes ``offload_cpu`` on the worker(s) and returns that result.
        """
        if version.parse(self.__version__) >= version.parse("0.7"):
            self.llm.sleep(level=1)
            self.awake = False
            return None

        if self.use_gpu_executor:
            return self.llm.llm_engine.model_executor.driver_worker.offload_cpu()
        return self.llm.llm_engine.model_executor._run_workers("offload_cpu")

    def backload_to_gpu(self):
        """Bring the engine back after ``offload_to_cpu``.

        On vLLM < 0.7 this invokes ``load_gpu`` on the worker(s) and then
        replaces the engine's scheduler list with fresh ``Scheduler`` objects,
        one per pipeline-parallel stage, built from the configs captured in
        ``__init__``. On vLLM >= 0.7 it calls ``self.llm.wake_up()`` only when
        this actor has marked the engine as asleep, then marks it awake.
        """
        if version.parse(self.__version__) < version.parse("0.7"):
            engine = self.llm.llm_engine
            if self.use_gpu_executor:
                engine.model_executor.driver_worker.load_gpu()
            else:
                engine.model_executor._run_workers("load_gpu")

            # Rebuild the scheduler. The async output callbacks are owned by
            # the engine; this actor never defines ``async_callbacks``, so
            # reading them from ``self`` would raise AttributeError whenever
            # async output processing is enabled.
            use_async = self.model_config.use_async_output_proc
            engine.scheduler = [
                Scheduler(
                    self.scheduler_config,
                    self.cache_config,
                    self.lora_config,
                    self.parallel_config.pipeline_parallel_size,
                    engine.async_callbacks[v_id] if use_async else None,
                )
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            if not self.awake:
                self.llm.wake_up()
            self.awake = True

    def update_weight(self, name, dtype, shape, empty_cache=False):
        """Forward an ``update_weight`` request to the vLLM worker(s).

        ``stop_remote_worker_execution_loop`` is always called first. On
        vLLM >= 0.7 the engine is then woken via ``backload_to_gpu`` and the
        request is sent with ``collective_rpc``; on older versions it goes to
        the ``driver_worker`` (single GPU) or through ``_run_workers``.

        Returns:
            Whatever the selected dispatch call returns.
        """
        self.stop_remote_worker_execution_loop()

        if version.parse(self.__version__) >= version.parse("0.7"):
            self.backload_to_gpu()
            return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))

        if self.use_gpu_executor:
            return self.llm.llm_engine.model_executor.driver_worker.update_weight(name, dtype, shape, empty_cache)
        return self.llm.llm_engine.model_executor._run_workers("update_weight", name, dtype, shape, empty_cache)

    def update_weight_internal_with_cuda_ipc(self, name, dtype, shape, cudaipc_handler, empty_cache=False):
        """Forward an ``update_weight_internal_with_cuda_ipc`` request to the worker(s).

        All arguments are passed through unchanged. vLLM >= 0.7 uses
        ``collective_rpc``; older versions call the ``driver_worker`` directly
        (single GPU) or go through ``_run_workers``.

        Returns:
            Whatever the selected dispatch call returns.
        """
        method = "update_weight_internal_with_cuda_ipc"
        payload = (name, dtype, shape, cudaipc_handler, empty_cache)

        if version.parse(self.__version__) >= version.parse("0.7"):
            return self.llm.collective_rpc(method, args=payload)

        if self.use_gpu_executor:
            return self.llm.llm_engine.model_executor.driver_worker.update_weight_internal_with_cuda_ipc(*payload)
        return self.llm.llm_engine.model_executor._run_workers(method, *payload)

    def stop_remote_worker_execution_loop(self):
        """Stop the executor's remote worker execution loop on vLLM > 0.4.2.

        On vLLM 0.4.2 and older this is a no-op.
        """
        # Fix error for using 2 communication group
        # https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4
        # Compare parsed versions: as plain strings "0.10.0" > "0.4.2" is
        # False, which would silently skip the call on newer releases.
        if version.parse(self.__version__) > version.parse("0.4.2"):
            self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop()

    def get_gpu_memory(self):
        """Return the GPU memory currently used by this actor, in MB.

        The CUDA cache is emptied first, then the value reported by
        ``torch.cuda.memory_allocated()`` is divided by 1024**2.
        """
        import torch

        torch.cuda.empty_cache()
        allocated_bytes = torch.cuda.memory_allocated()
        # Convert bytes to MB.
        return allocated_bytes / (1024 * 1024)

    def get_weight_statistics(self):
        """Compute lightweight statistics for model weights.

        Iterates ``named_parameters()`` of the model held by the driver
        worker's ``model_runner``.

        Returns:
            A dict mapping each parameter name to a dict with the scalar
            entries ``"mean"``, ``"std"``, ``"norm"``, ``"max"``, ``"min"``
            and a ``"shape"`` tuple.
        """
        driver_worker = self.llm.llm_engine.model_executor.driver_worker
        model = driver_worker.model_runner.model

        def summarize(weight):
            # Key statistics, plus the extreme values of the tensor.
            return {
                "mean": weight.mean().item(),
                "std": weight.std().item(),
                "norm": weight.norm().item(),
                "shape": tuple(weight.shape),
                "max": weight.max().item(),
                "min": weight.min().item(),
            }

        return {name: summarize(weight) for name, weight in model.named_parameters()}
    
    def get_version(self):
        """Return the vLLM version string recorded when this actor was built."""
        recorded_version = self.__version__
        return recorded_version
