import sys
import os
import pynvml
from typing import List
from dataclasses import dataclass

import torch

from sglang.srt.managers.io_struct import (
    UpdateWeightsFromTensorReqInput,
    InitWeightsUpdateGroupReqInput,
)


def sglang_hack():
    """Apply process-level workarounds before using SGLang.

    In order, this:
      1. removes ``TORCHELASTIC_USE_AGENT_STORE`` from the environment if set,
      2. sets ``SGLANG_BLOCK_NONZERO_RANK_CHILDREN`` to ``"0"``,
      3. applies ``nest_asyncio`` to the running interpreter,
      4. calls SGLang's ``monkey_patch_torch_reductions``.

    The third-party imports are kept local to the function and happen only
    after the environment has been adjusted.
    """
    variables_to_clear = ("TORCHELASTIC_USE_AGENT_STORE",)
    for var_name in variables_to_clear:
        # pop() with a default is a no-op when the variable is not set.
        os.environ.pop(var_name, None)
    os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"

    from nest_asyncio import apply as apply_nest_asyncio
    apply_nest_asyncio()

    from sglang.srt import patch_torch
    patch_torch.monkey_patch_torch_reductions()


@dataclass
class GcoreUpdateWeightsFromTensorReqInput(UpdateWeightsFromTensorReqInput):
    """Weight-update request that extends SGLang's ``UpdateWeightsFromTensorReqInput``.

    Adds six extra fields, all defaulting to ``None``, on top of whatever the
    parent request already carries.

    NOTE(review): the field names suggest they describe one flattened weight
    buffer plus the per-parameter metadata needed to split it back up; the
    code that fills and reads these fields is not in this file, so confirm
    the exact semantics there.
    """

    # NOTE(review): presumably a GPU UUID such as the one returned by
    # get_device_uuid() in this module — confirm against the caller.
    head_device_uuid: str = None
    # NOTE(review): presumably the shape of the flattened weight tensor — confirm.
    flatten_weight_shape: torch.Size = None
    # NOTE(review): annotated as ``type``; likely meant to hold a torch dtype — confirm.
    flatten_weight_dtype: type = None
    # NOTE(review): presumably one size entry per name in ``name_list`` — confirm.
    key_size: List = None
    # NOTE(review): presumably one element count per name in ``name_list`` — confirm.
    key_numel: List = None
    # NOTE(review): presumably the parameter names, in buffer order — confirm.
    name_list: List = None


def get_device_uuid(device_idx):
    """Return the NVML UUID of the GPU at index ``device_idx``.

    Args:
        device_idx: Device index handed to
            ``pynvml.nvmlDeviceGetHandleByIndex``.
            NOTE(review): NVML's index may not match torch's device index
            when ``CUDA_VISIBLE_DEVICES`` restricts or reorders devices —
            confirm which index callers pass.

    Returns:
        Whatever ``pynvml.nvmlDeviceGetUUID`` returns for that device.

    Raises:
        AssertionError: If ``torch.cuda.is_available()`` is false; this path
            is not supported.
    """
    if not torch.cuda.is_available():
        # Raised explicitly instead of `assert False` so the check is not
        # stripped under `python -O`, which would silently return None.
        raise AssertionError("currently not reach")

    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
        return pynvml.nvmlDeviceGetUUID(handle)
    finally:
        # Always pair nvmlInit() with nvmlShutdown(), even when the lookup
        # above raises (e.g. for an invalid index).
        pynvml.nvmlShutdown()
