import asyncio
import math
import os
import types
from typing import List

import torch

try:
    import vllm
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.v1.engine.async_llm import AsyncLLM
except ImportError:
    pass

try:
    import sglang as sgl
    from sglang.srt.managers.io_struct import (
        ReleaseMemoryOccupationReqInput,
        ResumeMemoryOccupationReqInput,
        UpdateWeightFromDiskReqInput,
    )
except ImportError:
    pass

from gpatch.core.utils import print_with_rank_and_datetime

class InferEngine:
    """Adapter exposing one generation / sleep / wake-up API over vLLM or SGLang.

    The backend is chosen by ``infer_engine_impl`` (``'vllm'`` or ``'sglang'``).
    For SGLang the wrapper also records which memory tags (``"weights"``,
    ``"kv_cache"``) are currently resident, so repeated ``sleep`` / ``wake_up``
    calls only release or resume the tags whose state actually changes.
    """

    def __init__(self, infer_engine_impl, infer_engine, model_path, infer_engine_role):
        self.infer_engine_impl = infer_engine_impl
        self.infer_engine = infer_engine
        # A single path, or (for the gen-rm role) a list of candidate model paths.
        self.model_path = model_path
        self.infer_engine_role = infer_engine_role
        # Index into ``model_path`` of the weights currently loaded (gen-rm only).
        self.model_index = 0

        # Residency state per SGLang memory tag; both are assumed resident at
        # construction time, so the first wake_up() is a no-op until sleep() runs.
        self.wake_up_tag = {
            "weights": True,
            "kv_cache": True,
        }

        if self.infer_engine_impl == 'sglang':
            self.sglang_version = sgl.__version__
            assert self.sglang_version in ['0.4.6.post5', '0.4.10.post2']
            self.all_supported_tags = ["weights", "kv_cache"]

        print_with_rank_and_datetime(f"Init InferEngine {self.infer_engine_role=} "
                                     f"{self.infer_engine_impl=}")

    @staticmethod
    def _to_sglang_sampling_dict(params):
        """Return a new dict of ``params``' attributes with ``seed`` removed.

        A copy is built so the caller's sampling-params object is never mutated.
        ``seed`` is not forwarded to SGLang (presumably unsupported by the pinned
        SGLang versions -- TODO confirm).
        """
        return {key: value for key, value in vars(params).items() if key != 'seed'}

    def async_generate(self, inp, sampling_params, req_id: str):
        """Submit one prompt for generation and return the backend's pending result.

        Args:
            inp: Prompt dict containing ``'prompt_token_ids'`` and, optionally,
                ``'multi_modal_data'`` whose ``'image'`` entry is forwarded to
                SGLang as ``image_data``.
            sampling_params: Value from :meth:`get_sampling_params`, or (SGLang
                path) a list of such values. Neither the object(s) nor the list
                are modified.
            req_id: Request id; only used by the vLLM backend.

        Returns:
            vLLM: the async iterable returned by ``AsyncLLM.generate``.
            SGLang: the awaitable returned by ``Engine.async_generate``.
            Either can be passed to :meth:`wait_and_get_async_generate_output`.
        """
        assert self.infer_engine is not None, "Not initialized InferEngine class"

        if self.infer_engine_impl == "vllm":
            return self.infer_engine.generate(inp, sampling_params, req_id)

        # SGLang expects plain dicts; build fresh ones instead of rewriting inputs.
        if isinstance(sampling_params, list):
            params = [self._to_sglang_sampling_dict(p) for p in sampling_params]
        else:
            params = [self._to_sglang_sampling_dict(sampling_params)]
        # A single entry is passed as a bare dict rather than a one-element list.
        sglang_params = params[0] if len(params) == 1 else params

        input_kwargs = dict(sampling_params=sglang_params,
                            input_ids=inp['prompt_token_ids'])
        if "multi_modal_data" in inp:
            input_kwargs["image_data"] = inp["multi_modal_data"]["image"]
        return self.infer_engine.async_generate(**input_kwargs)

    def get_sampling_params(
        self,
        n=1,
        temperature=1.,
        top_k=-1,
        top_p=1.,
        max_tokens=128,
        stop_token_ids=None,
        seed=None,
        repetition_penalty=1.0,
    ):
        """Build backend-specific sampling parameters.

        vLLM: returns ``vllm.SamplingParams``; only ``n == 1`` is accepted.
        SGLang: returns a ``types.SimpleNamespace`` whose attributes are later
        turned into a dict by :meth:`async_generate`; ``max_tokens`` is stored
        under SGLang's name ``max_new_tokens``.
        """
        if self.infer_engine_impl == 'vllm':
            assert n == 1
            return vllm.SamplingParams(
                n=n,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_tokens=max_tokens,
                stop_token_ids=stop_token_ids,
                seed=seed,
                repetition_penalty=repetition_penalty,
            )
        return types.SimpleNamespace(
            n=n,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_new_tokens=max_tokens,
            stop_token_ids=stop_token_ids,
            seed=seed,
            repetition_penalty=repetition_penalty,
        )

    @staticmethod
    async def _last_output(gen):
        """Drain an async iterable and return the last item it produced.

        Raises:
            RuntimeError: if the iterable produced no items at all.
        """
        sentinel = object()
        last = sentinel
        async for item in gen:
            last = item
        if last is sentinel:
            raise RuntimeError("vLLM generator finished without producing any output")
        return last

    async def wait_and_get_async_generate_output(self, async_generators):
        """Await all pending requests concurrently and return their results in order.

        Args:
            async_generators: Values returned by :meth:`async_generate`.

        Returns:
            vLLM: the last item yielded by each generator.
            SGLang: one ``SimpleNamespace(outputs=[...])`` per request, where each
            inner entry is ``SimpleNamespace(token_ids=..., prompt_len=...)``.

        Raises:
            RuntimeError: (vLLM) if a generator finishes without yielding anything.
        """
        if self.infer_engine_impl == "vllm":
            return await asyncio.gather(
                *(self._last_output(gen) for gen in async_generators))

        assert self.infer_engine_impl == 'sglang'
        async_outputs = await asyncio.gather(*async_generators)

        outputs = []
        for async_out in async_outputs:
            # SGLang may return a list of results (presumably when several
            # sampling params were submitted) or a single result dict.
            samples = async_out if isinstance(async_out, list) else [async_out]
            rep_outs = [
                types.SimpleNamespace(
                    token_ids=sample['output_ids'],
                    prompt_len=sample['meta_info']['prompt_tokens'],
                ) for sample in samples
            ]
            outputs.append(types.SimpleNamespace(outputs=rep_outs))
        return outputs

    def get_engine(self):
        """Return the wrapped backend engine object."""
        assert self.infer_engine is not None, "Not initialized InferEngine class"
        return self.infer_engine

    @staticmethod
    def from_engine_args(
        infer_engine_impl,
        model=None,
        dtype='bfloat16',
        distributed_executor_backend='ray',
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
        gpu_memory_utilization=0.67,
        enforce_eager=True,
        tp_rank=None,
        pp_rank=None,
        dp_rank=None,
        dist_init_addr=None,
        num_gpus_per_node=8,
        infer_engine_role=None,
        use_fast=False,
    ):
        """Create the backend engine and wrap it in an :class:`InferEngine`.

        Args:
            infer_engine_impl: ``'vllm'`` or ``'sglang'``.
            model: Model path, or a list of paths. The first path is loaded; the
                original value is kept on the wrapper (gen-rm uses the list).
            tp_rank, pp_rank, dp_rank: All required. SGLang derives its node
                rank from ``tp_rank`` and its first GPU id from ``dp_rank``.
            dist_init_addr: Required for SGLang.
            infer_engine_role: ``None``, ``"sampler"`` or ``"gen-rm"``.
            use_fast: vLLM only; forwarded as ``mm_processor_kwargs={'use_fast': True}``.
        """
        assert infer_engine_role in [None, "sampler", "gen-rm"]
        infer_engine = None
        assert model is not None
        assert tp_rank is not None and pp_rank is not None and dp_rank is not None
        model_path = model if isinstance(model, str) else model[0]
        if infer_engine_impl == 'vllm':
            mm_processor_kwargs = {}
            if use_fast:
                mm_processor_kwargs = dict(use_fast=True)
            engine_args = AsyncEngineArgs(
                model=model_path,
                dtype=dtype,
                distributed_executor_backend=distributed_executor_backend,
                tensor_parallel_size=tensor_parallel_size,
                pipeline_parallel_size=pipeline_parallel_size,
                gpu_memory_utilization=gpu_memory_utilization,
                enforce_eager=enforce_eager,
                trust_remote_code=True,
                enable_sleep_mode=True,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            infer_engine = AsyncLLM.from_engine_args(engine_args)
        else:
            assert infer_engine_impl == 'sglang'
            assert dist_init_addr is not None
            # A TP group may span several nodes.
            nnodes = int(math.ceil(tensor_parallel_size / num_gpus_per_node))
            node_rank = tp_rank // num_gpus_per_node
            # Assumes DP replicas are packed contiguously on a node,
            # tensor_parallel_size GPUs each -- TODO confirm against the launcher.
            base_gpu_id = (dp_rank * tensor_parallel_size) % num_gpus_per_node

            server_args = sgl.ServerArgs(
                model_path=model_path,
                tp_size=tensor_parallel_size,
                dist_init_addr=dist_init_addr,
                nnodes=nnodes,
                node_rank=node_rank,
                base_gpu_id=base_gpu_id,
                mem_fraction_static=gpu_memory_utilization,
                trust_remote_code=True,
                enable_memory_saver=True,
                skip_tokenizer_init=True,
            )
            infer_engine = sgl.Engine(server_args=server_args)

        assert infer_engine is not None, f"infer engine not initialized {infer_engine_impl=}"
        return InferEngine(infer_engine_impl, infer_engine, model, infer_engine_role)

    def _requested_sglang_tags(self, kwargs) -> List[str]:
        """Return the memory tags a sleep/wake_up call targets (SGLang only).

        For 0.4.6.post5 the request objects are built without arguments, so every
        supported tag is targeted. Otherwise the optional ``tags`` kwarg is used,
        where absence or ``None`` means all supported tags.
        """
        if self.sglang_version == "0.4.6.post5":
            return list(self.all_supported_tags)
        tags = kwargs.get("tags")
        return list(self.all_supported_tags) if tags is None else list(tags)

    def wake_up(self, *args, **kwargs):
        """Resume engine memory and reset ``model_index`` to 0.

        vLLM: forwards to ``AsyncLLM.wake_up`` and returns its result.
        SGLang: returns a coroutine that resumes only the requested tags that are
        currently released (it does nothing if there are none). For the gen-rm
        role, the weights of the first model path are then reloaded from disk.
        """
        self.model_index = 0
        if self.infer_engine_impl == 'vllm':
            return self.infer_engine.wake_up(*args, **kwargs)
        assert self.infer_engine_impl == 'sglang'

        async def fn():
            should_wake_up_tags = [
                tag for tag in self._requested_sglang_tags(kwargs)
                if not self.wake_up_tag[tag]
            ]
            if not should_wake_up_tags:
                return
            kwargs["tags"] = should_wake_up_tags
            for tag in should_wake_up_tags:
                self.wake_up_tag[tag] = True

            if self.sglang_version == '0.4.6.post5':
                obj = ResumeMemoryOccupationReqInput()
            else:
                obj = ResumeMemoryOccupationReqInput(*args, **kwargs)

            tokenizer_manager = self.infer_engine.tokenizer_manager
            if self.infer_engine_role != "gen-rm":
                return await tokenizer_manager.resume_memory_occupation(obj, None)

            if isinstance(self.model_path, list):
                model_path = self.model_path[0]
            else:
                model_path = self.model_path
            await tokenizer_manager.resume_memory_occupation(obj, None)
            # The reward model reloads its weights from disk after resuming,
            # presumably because released weight memory is not restored with its
            # contents -- TODO confirm.
            update_obj = UpdateWeightFromDiskReqInput(model_path=model_path)
            return await tokenizer_manager.update_weights_from_disk(update_obj, None)

        return fn()

    def update_gen_rm_weight_by_model_idx(self, idx):
        """Return a coroutine that loads ``model_path[idx]`` into the SGLang gen-rm engine.

        ``model_index`` is updated immediately, before the coroutine is awaited.
        """
        assert self.infer_engine_impl == 'sglang'
        assert self.infer_engine_role == 'gen-rm'
        assert isinstance(self.model_path, list) and 0 <= idx < len(self.model_path)
        self.model_index = idx

        async def fn():
            obj = UpdateWeightFromDiskReqInput(model_path=self.model_path[idx])
            return await self.infer_engine.tokenizer_manager.update_weights_from_disk(obj, None)

        return fn()

    def sleep(self, *args, **kwargs):
        """Return a coroutine that releases engine memory.

        vLLM: resets the prefix cache, then awaits ``AsyncLLM.sleep``.
        SGLang: releases only the requested tags that are currently resident
        (it does nothing if there are none).
        """
        if self.infer_engine_impl == 'vllm':

            async def fn1():
                # Drop cached prefixes first; presumably they would be stale once
                # the engine is woken with new weights -- TODO confirm.
                await self.infer_engine.reset_prefix_cache()
                return await self.infer_engine.sleep(*args, **kwargs)

            return fn1()

        assert self.infer_engine_impl == 'sglang'

        async def fn2():
            should_sleep_tags = [
                tag for tag in self._requested_sglang_tags(kwargs)
                if self.wake_up_tag[tag]
            ]
            if not should_sleep_tags:
                return
            kwargs["tags"] = should_sleep_tags
            for tag in should_sleep_tags:
                self.wake_up_tag[tag] = False

            if self.sglang_version == '0.4.6.post5':
                obj = ReleaseMemoryOccupationReqInput()
            else:
                obj = ReleaseMemoryOccupationReqInput(*args, **kwargs)
            return await self.infer_engine.tokenizer_manager.release_memory_occupation(
                obj, None)

        return fn2()
