              
                                                      
                                                                 

import os
import threading
import gc
import importlib
from datetime import datetime
from typing import List, Dict, Any, Optional

import torch
import pynvml

from megatron.core import parallel_state


def clear_memory():
    """Synchronize CUDA, run Python garbage collection, then empty the CUDA cache.

    The three calls always run in this fixed order:
    ``torch.cuda.synchronize()``, ``gc.collect()``, ``torch.cuda.empty_cache()``.
    Collecting before ``empty_cache`` presumably lets storage freed by the GC be
    released as well.
    """
    for step in (torch.cuda.synchronize, gc.collect, torch.cuda.empty_cache):
        step()


def to_cuda_if_not_none(x, non_blocking=False):
    """Return ``x.cuda(non_blocking=non_blocking)``, passing ``None`` through unchanged."""
    return None if x is None else x.cuda(non_blocking=non_blocking)


def to_cpu_if_not_none(x):
    """Return ``x.cpu()``, or ``None`` unchanged when ``x`` is ``None``."""
    if x is not None:
        x = x.cpu()
    return x


def print_rank_0(message):
    """Print ``message`` (flushed) only on global rank 0.

    When torch.distributed is not initialized, every caller prints, because
    there is no rank to filter on.
    """
    # Short-circuit: get_rank() is only consulted when a process group exists.
    should_print = (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0
    if should_print:
        print(message, flush=True)


def print_with_rank_and_datetime(message, rank=None):
    """Print ``message`` prefixed with this process's rank and a local timestamp.

    Output format: ``[RANK <rank, left-aligned width 4>][YYYY-mm-dd HH:MM:SS] <message>``,
    printed with ``flush=True``.

    Args:
        message: Object to print; formatted through an f-string.
        rank: If given, only the process whose rank equals this value prints.
            If ``None`` (default), every process prints.

    When torch.distributed is not initialized, the process is treated as
    rank 0 (as ``print_rank_0`` does), instead of letting ``get_rank()`` raise.
    """
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    my_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    if rank is not None and my_rank != rank:
        return
    print(f'[RANK {my_rank:<4}][{time_str}] {message}', flush=True)


def get_any_tensor_element_size(data: Any) -> int:
    """Return the total number of bytes held by tensors nested inside ``data``.

    Recurses into lists, tuples and dict values. For each ``torch.Tensor``
    found it adds ``element_size() * numel()``. Any other leaf (scalars,
    strings, ``None``, ...) contributes 0. Despite the name, the result is a
    byte count summed over all tensors, not a per-element size.
    """
    if isinstance(data, torch.Tensor):
        return data.element_size() * data.numel()
    if isinstance(data, (list, tuple)):
        return sum(get_any_tensor_element_size(item) for item in data)
    if isinstance(data, dict):
        return sum(get_any_tensor_element_size(value) for value in data.values())
    return 0


                                                                 
              
def split_data_cp_rank(val: torch.Tensor, cp_size: int, seq_dim: int, cp_rank: int = None):
    """Return the context-parallel shard of ``val`` owned by ``cp_rank``.

    The ``seq_dim`` axis is cut into ``2 * cp_size`` equal chunks. Rank ``r``
    keeps chunk ``r`` and chunk ``2 * cp_size - r - 1``, concatenated in that
    order along ``seq_dim``. Every rank receives one chunk from the front half
    and one from the back half.

    Args:
        val: Tensor to split, or ``None`` (returned as-is).
        cp_size: Context-parallel world size; must be > 1.
        seq_dim: Axis to split; negative values count from the end.
        cp_rank: Rank to select for. Defaults to
            ``parallel_state.get_context_parallel_rank()``.

    Returns:
        A tensor whose ``seq_dim`` length is ``val.shape[seq_dim] // cp_size``,
        or ``None`` when ``val`` is ``None``.

    Raises:
        AssertionError: if ``cp_size <= 1`` or the ``seq_dim`` length is not
            divisible by ``2 * cp_size``.
    """
    assert cp_size > 1
    # Must come before any ``val.shape`` access, otherwise ``None`` can never reach it.
    if val is None:
        return val
    # Normalize so the shape slicing below works for negative axes too.
    if seq_dim < 0:
        seq_dim += val.dim()
    seq_len = val.shape[seq_dim]
    assert 0 == seq_len % (2 * cp_size), f'{val.shape=} {cp_size=}'
    if cp_rank is None:
        cp_rank = parallel_state.get_context_parallel_rank()

    num_chunks = 2 * cp_size
    # reshape (rather than view) also accepts non-contiguous inputs.
    chunked = val.reshape(
        *val.shape[:seq_dim],
        num_chunks,
        seq_len // num_chunks,
        *val.shape[seq_dim + 1:],
    )

    index = torch.tensor([cp_rank, num_chunks - cp_rank - 1], device=val.device)
    picked = chunked.index_select(seq_dim, index)
    # Merge the (2, chunk_len) axes back into a single sequence axis.
    return picked.reshape(*picked.shape[:seq_dim], -1, *picked.shape[seq_dim + 2:])


                 
def split_data_ulysses_cp_rank(val: torch.Tensor, cp_size: int, seq_dim: int, cp_rank: int = None):
    """Return the ``cp_rank``-th of ``cp_size`` equal, contiguous slices of ``val`` along ``seq_dim``.

    Unlike ``split_data_cp_rank``, no load-balancing reordering is applied:
    rank ``r`` simply gets the ``r``-th chunk from ``torch.chunk``.

    Args:
        val: Tensor to split.
        cp_size: Number of slices; must be > 1 and divide ``val.shape[seq_dim]``.
        seq_dim: Axis to split along.
        cp_rank: Slice to return. Defaults to
            ``parallel_state.get_context_parallel_rank()``.
    """
    assert cp_size > 1
    assert 0 == val.shape[seq_dim] % cp_size, f'{val.shape=} {cp_size=}'
    rank = parallel_state.get_context_parallel_rank() if cp_rank is None else cp_rank
    pieces = torch.chunk(val, cp_size, dim=seq_dim)
    return pieces[rank]


# Process-wide counter backing gen_unique_id(); guarded by unique_id_lock.
unique_id_lock = threading.Lock()
unique_id = 0


def gen_unique_id() -> str:
    """Return a string ID unique within this process, tagged with the rank.

    Format: ``rank_<rank>_unique_id_<n>``, where ``n`` increments by one per
    call. The increment is done under a lock so concurrent threads never get
    the same ``n``. When torch.distributed is not initialized, the rank part
    is 0.
    """
    global unique_id  # rebinding the module-level counter; the lock is only read
    with unique_id_lock:
        current = unique_id
        unique_id += 1
    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    return f"rank_{rank}_unique_id_{current}"


def get_nvml_memory_info(gpu_id=0) -> Optional[str]:
    """Return a human-readable NVML memory summary for one GPU, or ``None``.

    Args:
        gpu_id: Index passed to ``pynvml.nvmlDeviceGetHandleByIndex``.
            NOTE(review): NVML enumerates devices independently of
            CUDA_VISIBLE_DEVICES, so a torch CUDA ordinal may not name the
            same GPU here. Verify for callers that pass
            ``torch.cuda.current_device()``.

    Returns:
        ``'total=<x> GB, used=<y> GB, free=<z> GB'`` (bytes divided by 1024**3),
        or ``None`` when ``torch.cuda.is_available()`` is False.
    """
    if not torch.cuda.is_available():
        return None
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
    finally:
        # Always balance nvmlInit, even if the device query raises.
        pynvml.nvmlShutdown()
    gib = 1024**3
    return f'total={meminfo.total/gib} GB, used={meminfo.used/gib} GB, free={meminfo.free/gib} GB'


def check_rollout_batch(rb: Dict[str, List[Any]]) -> bool:
    """Return True if ``rb`` is a well-formed rollout batch, else False.

    Well-formed means all of the following hold:
      * ``rb`` is a dict with ``str`` keys and ``list`` values;
      * every list has the same length;
      * within a list, every element has exactly the type of the first
        element (subclasses count as different types);
      * lists of tensors contain only CPU tensors.
    Empty lists, and an empty dict, are accepted.
    """
    if not isinstance(rb, dict):
        return False
    batch_size = None
    for k, v in rb.items():
        if not isinstance(k, str) or not isinstance(v, list):
            return False
        if batch_size is None:
            batch_size = len(v)
        elif batch_size != len(v):
            return False
        if not v:
            # Nothing left to check; also avoids IndexError on v[0] below.
            continue
        first_type = type(v[0])
        if any(type(e) != first_type for e in v[1:]):
            return False
        if torch.is_tensor(v[0]) and not all(e.is_cpu for e in v):
            return False
    return True


def check_rollout_batches(rbs: List[Dict[str, List[Any]]]) -> bool:
    """Validate a list of rollout batches, logging the first failure found.

    ``rbs`` must be a list of dicts with ``str`` keys and ``list`` values.
    Within each list, every element must have exactly the type of the first
    element, and lists of tensors must hold only CPU tensors. Empty lists are
    accepted. On the first violation, a numbered message (``error1`` ..
    ``error7``) is printed via ``print_with_rank_and_datetime`` and ``False``
    is returned. Otherwise ``True`` is returned.

    ``batch_size`` is tracked across *all* dicts in ``rbs``, so every list in
    every batch must have the same length.
    NOTE(review): confirm that this cross-batch length requirement is intended.
    """
    if not isinstance(rbs, list):
        print_with_rank_and_datetime(f"error1 {type(rbs)}")
        return False
    batch_size = None
    for rb in rbs:
        if not isinstance(rb, dict):
            print_with_rank_and_datetime(f"error2 {type(rb)}")
            return False
        for k, v in rb.items():
            if not isinstance(k, str):
                print_with_rank_and_datetime(f"error3 {k} {type(k)}")
                return False
            if not isinstance(v, list):
                print_with_rank_and_datetime(f"error4 {k} {type(v)}")
                return False
            if batch_size is None:
                batch_size = len(v)
            if batch_size != len(v):
                print_with_rank_and_datetime(f"error5 {batch_size} {len(v)}")
                return False
            if not v:
                # Empty list: nothing left to check (v[0] below would raise IndexError).
                continue
            first = v[0]
            for e in v[1:]:
                if type(e) != type(first):
                    print_with_rank_and_datetime(f"error6 {type(e)} {type(first)}")
                    return False
            if torch.is_tensor(first):
                for e in v:
                    if not e.is_cpu:
                        print_with_rank_and_datetime(f"error7 {e.is_cpu=}")
                        return False
    return True


def list_for_tensor_tolist(
    data: List[torch.Tensor],
    flatten: bool,
    to_float32: bool = False,
) -> List[Any]:
    """Convert a list of tensors into plain Python values.

    Each tensor is first cast with ``.float()`` when ``to_float32`` is set.
    0-dim tensors become Python scalars via ``.item()``. Other tensors are
    either flattened and spliced element-by-element into the result
    (``flatten=True``) or appended as one nested list (``flatten=False``).
    """
    out: List[Any] = []
    for tensor in data:
        t = tensor.float() if to_float32 else tensor
        if t.ndim == 0:
            out.append(t.item())
        elif flatten:
            out.extend(t.flatten().tolist())
        else:
            out.append(t.tolist())
    return out


def print_memory_tracking(log_info: str, verbose: bool = False, rank: int = None):
    """Append one CUDA memory snapshot line to ``./debug-tmp/mem-track-<rank>-<pid>.txt``.

    Despite the name, nothing is printed to stdout; the line is appended to a
    per-rank, per-process file. The line reports ``torch.cuda.memory_allocated``
    and ``torch.cuda.memory_reserved`` (bytes / 1024**3, labelled GB) plus the
    string from ``get_nvml_memory_info`` for ``torch.cuda.current_device()``.

    Args:
        log_info: Label written at the start of the memory report.
        verbose: Accepted for interface compatibility; currently unused.
        rank: Accepted for interface compatibility; currently ignored, so every
            rank writes its own file. NOTE(review): it may have been meant as a
            filter like in ``print_with_rank_and_datetime``; confirm before
            relying on it.

    Requires CUDA and an initialized torch.distributed process group.
    """
    # Separate local so the ``rank`` argument is not silently overwritten.
    my_rank = torch.distributed.get_rank()
    pid = os.getpid()
    debug_dir = "./debug-tmp"
    outfname = f'{debug_dir}/mem-track-{my_rank}-{pid}.txt'

    # Wait for queued CUDA work before sampling the memory counters.
    torch.cuda.synchronize()

    gib = 1024**3
    # NOTE(review): current_device() is a CUDA ordinal, while get_nvml_memory_info
    # takes an NVML index; these can differ under CUDA_VISIBLE_DEVICES.
    message = (
        f"{log_info} allocated {torch.cuda.memory_allocated() / gib} GB"
        f" reserved {torch.cuda.memory_reserved() / gib} GB"
        f" in process and {get_nvml_memory_info(torch.cuda.current_device())}"
    )

    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    line = f'[RANK {my_rank:<4}][{time_str}] {message}\n'

    os.makedirs(debug_dir, exist_ok=True)
    # Closing the file at the end of the ``with`` flushes the write.
    with open(outfname, 'a', encoding='utf-8') as outf:
        outf.write(line)


def load_and_call(module_path: str, function_name: str, *args, **kwargs):
    """Execute the Python file at ``module_path`` and call its ``function_name``.

    The file is executed as a fresh module on every call. This function does
    not register it in ``sys.modules``. The module name is the file's base
    name up to its first ``.`` (e.g. ``/a/reward.py`` -> ``reward``).

    Args:
        module_path: Filesystem path to the module file.
        function_name: Attribute of the loaded module to call.
        *args, **kwargs: Forwarded to the function.

    Returns:
        Whatever the called function returns.

    Raises:
        ImportError: if no loader is available for ``module_path`` (for
            example, an unrecognized file suffix).
        AssertionError: if the module has no attribute ``function_name``.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
    import importlib.util

    module_name = os.path.basename(module_path).split('.')[0]

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a module from path: {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    assert hasattr(module, function_name), f"Path: {module_path} should have func: {function_name}"
    func = getattr(module, function_name)
    return func(*args, **kwargs)
