from contextlib import contextmanager
import random
from functools import partial
from typing import List, Iterator, Any, Tuple, Dict, Union
import itertools
import warnings

from einops import rearrange
from torch import Tensor
import torch
import torch.nn.functional as F

from megatron.core import mpu, parallel_state, tensor_parallel
from megatron.core.parallel_state import (
    get_model_parallel_group,
    get_model_parallel_src_rank,
    get_pipeline_model_parallel_world_size,
    get_pipeline_model_parallel_last_rank,
    get_pipeline_model_parallel_group,
)

from gpatch.core.device_type import is_wxacc1
from gpatch.core.tensor_parallel.mappings import (
    all_gather_to_context_parallel_region,
)
from gpatch.core.parallel_state import (
    get_model_and_context_parallel_group,
    get_model_and_context_parallel_src_rank,
    get_model_parallel_with_cp_src_rank,
)
from gpatch.core.utils import clear_memory                
from gpatch.core.swap import (
    copy_megatron_model_to_cpu,
    copy_megatron_model_to_gpu,
)
from gpatch.core.utils import print_memory_tracking


def apply_func_to_dict(func, dictionary):
    """Return a new dict with the same keys and ``func`` applied to every value."""
    mapped = {}
    for key, value in dictionary.items():
        mapped[key] = func(value)
    return mapped


def move_to_device_if_tensor(device, item):
    """Move ``item`` to ``device`` when it is a tensor; return anything else untouched."""
    return item.to(device) if torch.is_tensor(item) else item


# Dict helpers built from the two functions above: each takes a dict and returns a
# new dict in which every tensor value has been moved to the named device, while
# non-tensor values are passed through unchanged.
cuda_dict = partial(apply_func_to_dict, partial(move_to_device_if_tensor, "cuda"))
cpu_dict = partial(apply_func_to_dict, partial(move_to_device_if_tensor, "cpu"))


def average_losses_across_data_parallel_group(losses):
    """Average every loss in ``losses`` over the data-parallel group.

    Each loss is cloned, detached and reshaped to a single element, the pieces are
    concatenated, all-reduced (default reduction op) across the data-parallel
    group, and divided by that group's world size.

    Returns:
        A 1-D tensor with one averaged entry per input loss.
    """
    flat_losses = [loss.clone().detach().view(1) for loss in losses]
    reduced = torch.cat(flat_losses)
    torch.distributed.all_reduce(reduced, group=mpu.get_data_parallel_group())
    dp_world_size = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
    return reduced / dp_world_size


def broadcast_2d_tensor(tensor, src, group, dtype=torch.float32):
    """Broadcast a 2-D tensor (or ``None``) from ``src`` to every rank in ``group``.

    All the ranks that send or receive data must call this function. Two
    collectives are issued: first a 3-element header ``[is_none, rows, cols]`` so
    that receivers can allocate a matching buffer, then (unless the source passed
    ``None``) the payload itself.

    The header is sent as int64. A float32 header represents integers exactly
    only up to 2**24, so a larger dimension would be rounded and receivers would
    allocate a buffer of the wrong size.

    Args:
        tensor: On the source rank, the 2-D tensor to send, or ``None``. On other
            ranks the value is only returned as-is when the source sent ``None``.
        src: Global rank of the sender.
        group: Process group to broadcast within.
        dtype: Dtype the payload is cast to on the sender and allocated as on
            the receivers.

    Returns:
        On the source rank, ``tensor`` moved to CUDA and cast to ``dtype`` (or
        ``None``). On other ranks, the received tensor on the current CUDA
        device, or the unchanged ``tensor`` argument if the source sent ``None``.

    Raises:
        AssertionError: on the source rank if ``tensor`` is not 2-D.
    """
    if torch.distributed.get_rank() == src:
        if tensor is None:
            header = [1, 0, 0]
        else:
            assert tensor.ndim == 2, f"tensor dims is not 2 but is {tensor.ndim} with shape {tensor.shape}"
            tensor = tensor.cuda().to(dtype)
            header = [0, tensor.size(0), tensor.size(1)]
        header_tensor = torch.tensor(header,
                                     dtype=torch.int64,
                                     device=torch.cuda.current_device())

        torch.distributed.broadcast(header_tensor, src, group)
        if tensor is not None:
            torch.distributed.broadcast(tensor, src, group)
    else:
        header_tensor = torch.empty(3, dtype=torch.int64, device=torch.cuda.current_device())
        torch.distributed.broadcast(header_tensor, src, group)

        # One host transfer for all three header fields.
        is_none, dim1, dim2 = header_tensor.tolist()

        if not is_none:
            tensor = torch.empty(dim1, dim2, dtype=dtype, device=torch.cuda.current_device())
            torch.distributed.broadcast(tensor, src, group)
    return tensor


def broadcast_2d_tensor_within_mp(tensor, dtype=torch.float32):
    """Broadcast a 2-D tensor from the model-parallel source rank to its group.

    When the model-parallel group has at most one rank no collective is issued;
    the tensor is only cast to ``dtype`` (``None`` stays ``None``).
    """
    mp_group = get_model_parallel_group()
    if torch.distributed.get_world_size(mp_group) <= 1:
        return None if tensor is None else tensor.to(dtype)
    return broadcast_2d_tensor(tensor, get_model_parallel_src_rank(), mp_group, dtype=dtype)


def broadcast_object_within_mp(obj: Any):
    """Broadcast a Python object from the model-parallel source rank to its group.

    Returns the broadcast object, or ``obj`` unchanged when the model-parallel
    group has at most one rank. A memory-tracking line is emitted after the
    collective.
    """
    mp_group = get_model_parallel_group()
    if torch.distributed.get_world_size(mp_group) <= 1:
        return obj

    # broadcast_object_list fills the list in place on receiving ranks.
    payload = [obj]
    torch.distributed.broadcast_object_list(
        payload,
        src=get_model_parallel_src_rank(),
        group=mp_group,
    )
    print_memory_tracking(f"Memory tracking: after bcast obj_list", rank=0)
    return payload[0]


def broadcast_2d_tensor_within_pp(tensor, dtype=torch.float32):
    """Broadcast a 2-D tensor from the last pipeline rank to the pipeline group.

    With a pipeline world size of at most one no collective is issued; the
    tensor is only cast to ``dtype`` (``None`` stays ``None``).
    """
    if get_pipeline_model_parallel_world_size() <= 1:
        return None if tensor is None else tensor.to(dtype)

    src_rank = get_pipeline_model_parallel_last_rank()
    pp_group = get_pipeline_model_parallel_group()
    return broadcast_2d_tensor(tensor, src_rank, pp_group, dtype=dtype)


def broadcast_object_within_pp(obj: Any) -> Any:
    """Broadcast a Python object from the last pipeline rank to the pipeline group.

    Returns the broadcast object, or ``obj`` unchanged when the pipeline group
    has at most one rank.
    """
    pp_group = get_pipeline_model_parallel_group()
    if torch.distributed.get_world_size(pp_group) <= 1:
        return obj

    # broadcast_object_list fills the list in place on receiving ranks.
    payload = [obj]
    torch.distributed.broadcast_object_list(
        payload,
        src=get_pipeline_model_parallel_last_rank(),
        group=pp_group,
    )
    return payload[0]


def broadcast_2d_tensor_within_mp_and_cp(tensor, dtype=torch.float32):
    """Broadcast a 2-D tensor across the combined model- and context-parallel group.

    With a context-parallel size of 1 this defers to
    ``broadcast_2d_tensor_within_mp``. Otherwise the two source-rank helpers are
    asserted to agree and the tensor is broadcast from that rank; a group with
    at most one rank only gets the dtype cast.
    """
    if mpu.get_context_parallel_world_size() == 1:
        return broadcast_2d_tensor_within_mp(tensor, dtype=dtype)

    mp_cp_group = get_model_and_context_parallel_group()

    # Sanity check: both ways of computing the source rank must give the same answer.
    assert get_model_and_context_parallel_src_rank() == get_model_parallel_with_cp_src_rank(), \
        f"{get_model_and_context_parallel_src_rank()} != {get_model_parallel_with_cp_src_rank()}"

    if torch.distributed.get_world_size(mp_cp_group) <= 1:
        return None if tensor is None else tensor.to(dtype)

    src_rank = get_model_parallel_with_cp_src_rank()
    return broadcast_2d_tensor(tensor, src_rank, mp_cp_group, dtype=dtype)


def broadcast_object_within_mp_and_cp(obj: Any):
    """Broadcast a Python object across the combined model- and context-parallel group.

    With a context-parallel size of 1 this defers to
    ``broadcast_object_within_mp``. Otherwise the two source-rank helpers are
    asserted to agree and the object is broadcast from that rank; a group with
    at most one rank returns ``obj`` unchanged.
    """
    if mpu.get_context_parallel_world_size() == 1:
        return broadcast_object_within_mp(obj)

    mp_cp_group = get_model_and_context_parallel_group()

    # Sanity check: both ways of computing the source rank must give the same answer.
    assert get_model_and_context_parallel_src_rank() == get_model_parallel_with_cp_src_rank(), \
        f"{get_model_and_context_parallel_src_rank()} != {get_model_parallel_with_cp_src_rank()}"

    if torch.distributed.get_world_size(mp_cp_group) <= 1:
        return obj

    # broadcast_object_list fills the list in place on receiving ranks.
    payload = [obj]
    torch.distributed.broadcast_object_list(
        payload,
        src=get_model_and_context_parallel_src_rank(),
        group=mp_cp_group,
    )
    return payload[0]


def gather_tensor(tensor, dst, group, dtype=torch.float32):
    """Gather ``tensor`` from every rank of ``group`` onto rank ``dst``.

    The tensor is first moved to the current CUDA device and cast to ``dtype``.

    Returns:
        ``None`` if ``tensor`` is ``None`` (no collective is issued in that case);
        on ``dst`` a list with one tensor per rank of ``group``; ``None`` on all
        other ranks.
    """
    if tensor is None:
        return None

    local = tensor.to(device=torch.cuda.current_device(), dtype=dtype)

    # Only the destination rank needs receive buffers.
    receive_buffers = None
    if torch.distributed.get_rank() == dst:
        group_size = torch.distributed.get_world_size(group)
        receive_buffers = [torch.empty_like(local) for _ in range(group_size)]

    torch.distributed.gather(local, gather_list=receive_buffers, dst=dst, group=group)
    return receive_buffers


def from_parallel_logits_to_logprobs(vocab_parallel_logits,
                                     target,
                                     inference_only=False,
                                     higher_stability=False):
    """Get log probs out of a [B, S//CP, V//TP] logits tensor.

    NOTE: this function shifts the target left by one position itself, which
    means you must give it the unmodified targets.

    The shifted target is re-ordered for context parallelism, the slice
    belonging to this context-parallel rank is scored with the vocab-parallel
    cross entropy (negated to obtain log probs), and for CP > 1 the per-rank
    results are all-gathered. ``inference_only`` and ``higher_stability`` are
    accepted but not used by this implementation.

    Returns:
        The log-prob tensor with the last position along dim 1 dropped.
    """
    cp_rank = mpu.get_context_parallel_rank()
    cp_size = mpu.get_context_parallel_world_size()

    s = target.shape[1]
    assert s % cp_size == 0, f'{s=} {cp_size=}'
    shard_len = s // cp_size

    # Next-token targets, laid out in the order the CP ranks hold their chunks.
    shifted = reorder_target_for_cp(target.roll(shifts=-1, dims=-1))

    shard_start = cp_rank * shard_len
    labels = shifted[:, shard_start:shard_start + shard_len]
    labels = labels.to(vocab_parallel_logits.device)

    # The cross-entropy kernel is fed sequence-first tensors.
    labels_sb = rearrange(labels, 'b s -> s b').contiguous()
    logits_sbh = rearrange(vocab_parallel_logits, 'b s h -> s b h').contiguous()
    cross_entropy = tensor_parallel.vocab_parallel_cross_entropy(logits_sbh, labels_sb)
    log_probs = rearrange(-1 * cross_entropy, 's b -> b s').contiguous()

    if cp_size > 1:
        log_probs = all_gather_to_context_parallel_region(log_probs)
    return log_probs[:, :-1].contiguous()


def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss,
                                    compute_attention_mask=True):
    """Build masks and position id for left to right model.

    Args:
        data: 2-D token tensor of shape (micro_batch_size, seq_length).
        eod_token: Token id that marks an end-of-document boundary.
        reset_position_ids: Restart position ids at 0 after every EOD token.
        reset_attention_mask: Set mask entries to ``False`` where a position
            would attend across an EOD boundary. Has no effect on the mask when
            ``compute_attention_mask`` is ``False`` (there is no mask to edit);
            previously that combination raised ``TypeError``.
        eod_mask_loss: Zero the loss mask at EOD positions.
        compute_attention_mask: Build the attention mask at all.

    Returns:
        ``(attention_mask, loss_mask, position_ids)``. ``attention_mask`` is a
        boolean lower-triangular tensor of shape
        ``(micro_batch_size or 1, 1, seq_length, seq_length)`` on ``data.device``,
        or ``None`` when ``compute_attention_mask`` is ``False``.
    """
    micro_batch_size, seq_length = data.size()

    # A per-sample mask is only needed when samples get individual EOD edits.
    att_mask_batch = micro_batch_size if reset_attention_mask else 1

    attention_mask = None
    if compute_attention_mask:
        # Built on CPU and transferred to data.device once all edits are done.
        attention_mask = torch.tril(
            torch.ones(
                (att_mask_batch, seq_length, seq_length), dtype=torch.bool, device="cpu"
            ),
        ).view(att_mask_batch, 1, seq_length, seq_length)

    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).repeat(micro_batch_size, 1)
    # The ids are modified in place per sample below, so work on a private copy.
    if reset_position_ids:
        position_ids = position_ids.clone()

    # Only edit the mask if one was actually built.
    edit_attention_mask = reset_attention_mask and attention_mask is not None

    if reset_position_ids or edit_attention_mask:
        for b in range(micro_batch_size):

            # Positions of the EOD tokens in this sample.
            eod_index = position_ids[b, data[b] == eod_token]
            # Decouple the indices from position_ids before the latter is mutated.
            if reset_position_ids:
                eod_index = eod_index.clone()

            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Positions after the EOD must not see positions up to and including it.
                if edit_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = False
                # Restart counting right after the EOD.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= i + 1 - prev_index
                    prev_index = i + 1

    if compute_attention_mask:
        # Use a blocking copy on wxacc1 devices, non-blocking everywhere else.
        non_blocking = not is_wxacc1()
        attention_mask = attention_mask.to(data.device, non_blocking=non_blocking)

    return attention_mask, loss_mask, position_ids


def retrieve_model_state_dict_in_cpu(model, cpu_dict=None):
    """Copy ``model.state_dict()`` into a dict of CPU tensors.

    Args:
        model: Object exposing ``state_dict()``.
        cpu_dict: Optional dict from a previous call. Tensor entries already
            present are overwritten in place (shape and dtype must match); it is
            also the object that is returned.

    Returns:
        The dict holding CPU copies of all tensor entries; non-tensor entries
        are stored as-is. CUDA is synchronized before returning so the
        non-blocking copies have completed.
    """
    snapshot = {} if cpu_dict is None else cpu_dict

    for name, item in model.state_dict().items():
        if not isinstance(item, torch.Tensor):
            snapshot[name] = item
            continue

        if name not in snapshot:
            snapshot[name] = item.detach().to(device="cpu", non_blocking=True, copy=True)
            continue

        # Reuse the existing CPU buffer instead of allocating a new one.
        buffer = snapshot[name]
        assert buffer.shape == item.shape
        assert buffer.dtype == item.dtype
        buffer.copy_(item, non_blocking=True)

    torch.cuda.synchronize()
    return snapshot


@torch.no_grad()
def swap_dict(
    resident_model,
    cpu_weights,
    offload_onto_cpu=True,
    offloaded_weights=None,
):
    """Load ``cpu_weights`` into ``resident_model``, optionally saving its current weights.

    Args:
        resident_model: Model whose weights are replaced.
        cpu_weights: State dict to load into the model.
        offload_onto_cpu: If true, the model's current state dict is copied to
            CPU before loading (reusing ``offloaded_weights`` as buffers if given).
        offloaded_weights: Optional dict of CPU buffers to reuse / return.

    Returns:
        The dict of offloaded weights; an empty dict if nothing was offloaded
        and no ``offloaded_weights`` were passed.
    """
    saved = offloaded_weights
    if offload_onto_cpu:
        saved = retrieve_model_state_dict_in_cpu(resident_model, saved)

    resident_model.load_state_dict(cpu_weights)
    return {} if saved is None else saved


@contextmanager
def cpu_weight_swap(resident_model, cpu_weights):
    """Context manager that runs its body with ``cpu_weights`` loaded into the model.

    On entry the model's current weights are copied to CPU and ``cpu_weights``
    are loaded; on exit (also on error) the saved weights are loaded back.
    """
    original_weights = swap_dict(resident_model, cpu_weights)
    try:
        yield
    finally:
        # Restore only; there is no need to offload the temporary weights.
        swap_dict(resident_model, original_weights, offload_onto_cpu=False)


@contextmanager
def cpu_weight_swap_v2(resident_model, unwrap_model_func, cpu_weights):
    """Context manager that runs its body with ``cpu_weights`` loaded into the model.

    On entry ``copy_megatron_model_to_cpu`` is called on ``resident_model``
    (presumably saving its current weights on the host; confirm against
    ``gpatch.core.swap``) and ``cpu_weights`` are loaded into the single
    unwrapped model. On exit ``copy_megatron_model_to_gpu`` is called.

    Unwrapping and loading happen inside the ``try`` so that a failure there
    (for example the single-model assertion or a state-dict mismatch) still
    triggers ``copy_megatron_model_to_gpu`` instead of leaving the model
    offloaded.

    Args:
        resident_model: The (wrapped) model currently on the device.
        unwrap_model_func: Callable returning the list of unwrapped models;
            exactly one is expected.
        cpu_weights: State dict to load for the duration of the context.

    Raises:
        AssertionError: if ``unwrap_model_func`` does not return exactly one model.
    """
    copy_megatron_model_to_cpu(resident_model)
    try:
        unwrap_models = unwrap_model_func(resident_model)
        assert len(unwrap_models) == 1
        unwrap_models[0].load_state_dict(cpu_weights)
        yield
    finally:
        copy_megatron_model_to_gpu(resident_model)


def get_iterator_k_split(batch, num_microbatches: int) -> Iterator:
    """Split a batch into ``num_microbatches`` equally sized microbatches.

    Args:
        batch: Either a dict whose values are tensors or lists, or a sequence
            whose items are tensors, lists or ``None``. Tensors are split along
            dim 0, lists are sliced, ``None`` items (sequence form only) are
            repeated in every microbatch.
        num_microbatches: Number of microbatches to produce.

    Returns:
        An iterator over the microbatches: dicts for a dict batch, lists for a
        sequence batch.

    Raises:
        AssertionError: if a tensor/list length is not divisible by
            ``num_microbatches`` or a value has an unsupported type.
    """
    if isinstance(batch, dict):
        items = list(batch.items())
        split_batch = []
        for item in items:
            if isinstance(item[1], torch.Tensor):
                assert item[1].shape[
                    0] % num_microbatches == 0, f"{item[0]=} {item[1].shape=} {num_microbatches=}"
                split_batch.append(torch.tensor_split(item[1], num_microbatches, dim=0))
            elif isinstance(item[1], list):
                # The message must not inspect the elements: they need not be
                # tensors, and touching `.shape` would replace the intended
                # AssertionError with an AttributeError.
                assert len(item[1]) % num_microbatches == 0, \
                    f"{item[0]=} {len(item[1])=} {num_microbatches=}"
                mbs = len(item[1]) // num_microbatches
                split_batch.append(
                    tuple(item[1][i * mbs:(i + 1) * mbs] for i in range(num_microbatches)))
            else:
                assert False, f"Not supported type {type(item[1])=}"

        # Regroup: microbatch j takes the j-th piece of every key.
        microbatches = [{items[i][0]: split_batch[i][j]
                         for i in range(len(items))}
                        for j in range(num_microbatches)]
    else:
        assert len(batch[0]) % num_microbatches == 0, "Issue with batch size configuration!"
        split_batch = []
        for item in batch:
            if isinstance(item, torch.Tensor):
                assert item.shape[0] % num_microbatches == 0, f"{item.shape=} {num_microbatches=}"
                split_batch.append(torch.tensor_split(item, num_microbatches, dim=0))
            elif isinstance(item, list):
                assert len(item) % num_microbatches == 0, f"{len(item)=} {num_microbatches=}"
                mbs = len(item) // num_microbatches
                split_batch.append(
                    tuple(item[i * mbs:(i + 1) * mbs] for i in range(num_microbatches)))
            else:
                assert item is None
                split_batch.append(item)

        # ``None`` entries are carried into every microbatch unchanged.
        microbatches = [[elem[i] if elem is not None else elem for elem in split_batch]
                        for i in range(num_microbatches)]

    return itertools.chain(microbatches)


def get_iterator_k_split_list(batches: List[Dict[str, Any]], num_microbatches: int) -> Iterator:
    """Split a list of samples into ``num_microbatches`` consecutive, equally sized slices.

    Raises:
        AssertionError: if ``len(batches)`` is not divisible by ``num_microbatches``.
    """
    assert len(batches) % num_microbatches == 0
    per_microbatch = len(batches) // num_microbatches
    slices = [
        batches[k * per_microbatch:(k + 1) * per_microbatch] for k in range(num_microbatches)
    ]
    return itertools.chain(slices)


def masked_mean(values: Tensor, mask: Tensor) -> Tensor:
    """
    Selects the entries of values where mask is truthy and returns their mean.
    """
    keep = mask.bool()
    selected = values[keep]
    return selected.mean()


def masked_mean_list(values: List[Tensor], mask: List[Tensor], dim=None) -> Tensor:
    """
    For every (values, mask) pair, selects the entries where the mask is truthy and
    takes their mean over ``dim``; the per-pair results are stacked into one tensor.
    """
    per_pair = [v[m.bool()].mean(dim=dim) for v, m in zip(values, mask)]
    return torch.stack(per_pair)


def masked_std(values, mask, dim=None):
    """
    Selects the entries of values where mask is truthy and returns their std over ``dim``.
    """
    keep = mask.bool()
    selected = values[keep]
    return selected.std(dim=dim)


def pad_tensors_to_max_global_seq_len(list_of_tensors,
                                      pad_value,
                                      group,
                                      sequence_length_to_pad_to=None):
    """Pad a list of tensors to the largest sequence length found across ``group``.

    The tensors are first stacked with ``pad_sequence`` (batch first), then the
    size of the last dimension is MAX-reduced over ``group`` and the stacked
    tensor is right-padded to that size with ``pad_value``. If
    ``sequence_length_to_pad_to`` is given, the larger of it and the global
    maximum is used; a warning is emitted when the global maximum exceeds it.
    """
    # Pad to the longest sequence on this rank.
    stacked = torch.nn.utils.rnn.pad_sequence(list_of_tensors,
                                              batch_first=True,
                                              padding_value=pad_value)
    local_length = stacked.size(-1)

    # Agree on the longest sequence across the group.
    length_buffer = torch.tensor([local_length],
                                 dtype=torch.float32,
                                 device=torch.cuda.current_device())
    torch.distributed.all_reduce(length_buffer, op=torch.distributed.ReduceOp.MAX, group=group)
    max_seq_length = int(length_buffer)

    if sequence_length_to_pad_to is not None:
        if max_seq_length > sequence_length_to_pad_to:
            warnings.warn(
                f"{max_seq_length=} is bigger than the provided {sequence_length_to_pad_to=}, overwriting the padding"
                f" to {max_seq_length}")
        # Never pad to less than what the data actually needs.
        max_seq_length = max(sequence_length_to_pad_to, max_seq_length)

    missing = max_seq_length - local_length
    return torch.nn.functional.pad(stacked, (0, missing), value=pad_value)


def normalize_tensor(tensor, mask, group=None):
    """Normalize ``tensor`` with the masked global mean and variance over ``group``.

    Both inputs are moved to the current CUDA device. The result is
    ``(tensor - mean) * rsqrt(var + 1e-8)``.
    """
    device = torch.cuda.current_device()
    tensor = tensor.to(device=device)
    mask = mask.to(device=device)

    global_mean, global_var = masked_global_mean_var(tensor, mask, group=group)
    centered = tensor - global_mean
    return centered * torch.rsqrt(global_var + 1e-8)


def masked_global_mean_var(values: Tensor, mask: List, group=None) -> Tuple[Tensor, Tensor]:
    """Compute the mean and variance of the masked values across ``group``.

    NOTE: the variance here is uncorrected (divided by the global count).

    ``mask`` and ``values`` must have the same shape, with mask being {0,1} and 1
    marking the values to keep. Although annotated as ``List``, ``mask`` is used
    as a tensor here (``.shape`` and ``.to`` are called on it).

    Returns:
        ``(global_mean, global_variance)``.
    """
    assert values.shape == mask.shape, (values.shape, mask.shape)
    device = torch.cuda.current_device()
    values = values.to(device=device)
    mask = mask.to(device=device)

    masked_values = values * mask

    # Sum and count travel in one tensor so a single all-reduce covers both.
    totals = torch.tensor([masked_values.sum(), mask.sum()],
                          dtype=torch.float32,
                          device=device)
    torch.distributed.all_reduce(totals, group=group)
    global_sum, global_count = totals
    global_mean = global_sum / global_count

    squared_deviation = ((masked_values - global_mean)**2) * mask
    variance_summed = squared_deviation.sum().to(device=device, dtype=torch.float32)
    torch.distributed.all_reduce(variance_summed, group=group)

    return global_mean, variance_summed / global_count


def masked_global_statistics(values: Tensor,
                             mask: Tensor,
                             group=None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Compute the mean, variance, min and max of the masked values across ``group``.

    NOTE: the variance here is uncorrected (divided by the global count).

    ``mask`` and ``values`` must have the same shape, with mask being {0,1} and 1
    marking the values to keep.

    The extrema are taken over the kept entries only. Masked-out positions are
    replaced by -inf / +inf before the local max / min, so they can never win;
    taking the extrema of ``values * mask`` instead would let the zeros of
    masked-out positions leak in (e.g. reporting max == 0 when every kept value
    is negative). If no rank keeps any entry, min is +inf and max is -inf.

    Returns:
        ``(global_mean, global_variance, global_min, global_max)``.
    """
    assert values.shape == mask.shape, (values.shape, mask.shape)
    device = torch.cuda.current_device()
    values = values.to(device=device)
    mask = mask.to(device=device)

    # Extrema over kept entries only; float32 so +/-inf is representable for any input dtype.
    dropped = ~mask.bool()
    values_f32 = values.to(dtype=torch.float32)
    local_max = values_f32.masked_fill(dropped, float("-inf")).max()
    local_min = values_f32.masked_fill(dropped, float("inf")).min()

    # The min is negated so that one MAX all-reduce yields both extrema.
    max_min = torch.stack([local_max, -local_min])
    torch.distributed.all_reduce(max_min, group=group, op=torch.distributed.ReduceOp.MAX)
    max_v, min_v = max_min
    min_v = -min_v

    values = values * mask

    # Sum and count travel in one tensor so a single all-reduce covers both.
    sum_and_count = torch.tensor([values.sum(), mask.sum()],
                                 dtype=torch.float32,
                                 device=device)
    torch.distributed.all_reduce(sum_and_count, group=group)
    global_sum, global_count = sum_and_count
    global_mean = global_sum / global_count
    variance_summed = ((((values - global_mean)**2) * mask).sum().to(device=device,
                                                                      dtype=torch.float32))

    torch.distributed.all_reduce(variance_summed, group=group)

    return global_mean, variance_summed / global_count, min_v, max_v


def masked_global_statistics_list(values: List[Tensor],
                                  mask: List[Tensor],
                                  group=None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """List variant of ``masked_global_statistics``.

    Every tensor in ``values`` and ``mask`` is flattened, the pieces are
    concatenated, and the two resulting 1-D tensors are passed to
    ``masked_global_statistics``. Both lists must have the same length.
    """
    assert len(values) == len(mask)
    flat_values = torch.cat([v.view(-1) for v in values])
    flat_mask = torch.cat([m.view(-1) for m in mask])
    return masked_global_statistics(flat_values, flat_mask, group)


def get_last_rank():
    """Return ``world_size - 1`` for the default process group."""
    world_size = torch.distributed.get_world_size()
    return world_size - 1


def dist_adam_load_state_bucket_into_device(state_bucket, device):
    """Move the shard tensors of ``state_bucket`` onto ``device``.

    The four shard attributes are replaced by non-blocking copies on ``device``;
    attributes that are ``None`` are left alone.
    """
    shard_attrs = ("params_shard", "param_remainders_shard", "exp_avg_shard", "exp_avg_sq_shard")

    for attr in shard_attrs:
        shard = getattr(state_bucket, attr)
        if shard is None:
            continue
        setattr(state_bucket, attr, shard.to(device=device, non_blocking=True))


@contextmanager
def offload_distributed_adam(state_dict):
    """Context manager that keeps the distributed-Adam state buckets on CPU.

    On entry every bucket in ``state_dict["state"]["buckets"]`` is moved to CPU;
    on exit (also on error) the buckets are moved to the current CUDA device.
    CUDA is synchronized after each move because the copies are non-blocking.
    """

    def _move_buckets(device):
        # The bucket list is looked up on every call, as the state may be edited in between.
        for bucket in state_dict["state"]["buckets"]:
            dist_adam_load_state_bucket_into_device(bucket, device=device)
        torch.cuda.synchronize()

    _move_buckets("cpu")
    try:
        yield
    finally:
        _move_buckets(torch.cuda.current_device())


def pad_batches_to_multiple_of_within_dp(rollout_batches, multi_of, key, pad_value):
    """Right-pad ``batch[key]`` along dim 1 to a length shared by the data-parallel group.

    The local target is the longest ``batch[key].shape[1]`` rounded up to a
    multiple of ``multi_of``; the MAX over the data-parallel group is then used.
    Batches are modified in place. A single dict is accepted in place of a list.
    Returns ``None``; nothing happens (and no collective is issued) if the first
    batch has no value for ``key``.
    """
    if isinstance(rollout_batches, dict):
        rollout_batches = [rollout_batches]
    assert isinstance(rollout_batches, list) and isinstance(rollout_batches[0], dict)

    if rollout_batches[0].get(key, None) is None:
        return

    local_len = max(batch[key].shape[1] for batch in rollout_batches)
    local_len = (local_len + multi_of - 1) // multi_of * multi_of

    # Agree on one length across the data-parallel group.
    len_buffer = torch.tensor([local_len],
                              dtype=torch.float32,
                              device=torch.cuda.current_device())
    torch.distributed.all_reduce(len_buffer,
                                 op=torch.distributed.ReduceOp.MAX,
                                 group=parallel_state.get_data_parallel_group())
    target_len = int(len_buffer.item())

    for batch in rollout_batches:
        current = batch[key]
        missing = target_len - current.shape[1]
        batch[key] = torch.nn.functional.pad(current, (0, missing), value=pad_value)


def pad_batches_to_multiple_of(rollout_batches, multi_of, key, pad_value):
    """Right-pad ``batch[key]`` along dim 1 to a common multiple of ``multi_of``.

    The target length is the longest ``batch[key].shape[1]`` rounded up to a
    multiple of ``multi_of``. Batches are modified in place. A single dict is
    accepted in place of a list. Returns ``None``; nothing happens if the first
    batch has no value for ``key``.
    """
    if isinstance(rollout_batches, dict):
        rollout_batches = [rollout_batches]
    assert isinstance(rollout_batches, list) and isinstance(rollout_batches[0], dict)

    if rollout_batches[0].get(key, None) is None:
        return

    longest = max(batch[key].shape[1] for batch in rollout_batches)
    target_len = (longest + multi_of - 1) // multi_of * multi_of

    for batch in rollout_batches:
        current = batch[key]
        missing = target_len - current.shape[1]
        batch[key] = torch.nn.functional.pad(current, (0, missing), value=pad_value)


def get_tensor_on_this_cp_rank(val, seq_dim, key_name=None):
    """Return the slice of ``val`` along ``seq_dim`` owned by this context-parallel rank.

    The sequence dimension is cut into ``2 * cp_size`` equal chunks and rank
    ``r`` keeps chunks ``r`` and ``2 * cp_size - r - 1``, concatenated along
    ``seq_dim``. ``val`` is returned unchanged when it is ``None`` or the
    context-parallel size is 1.

    When ``key_name`` is given it is only used for a sanity check:
    ``"attention_mask"`` requires ``seq_dim == 2``, anything else ``seq_dim == 1``.
    """
    if key_name is not None:
        expected_seq_dim = 2 if key_name == "attention_mask" else 1
        assert seq_dim == expected_seq_dim

    cp_rank = mpu.get_context_parallel_rank()
    cp_size = mpu.get_context_parallel_world_size()
    assert cp_size >= 1
    if cp_size == 1 or val is None:
        return val

    num_chunks = 2 * cp_size

    # Split seq_dim into (num_chunks, chunk_len).
    chunked = val.view(
        *val.shape[0:seq_dim],
        num_chunks,
        val.shape[seq_dim] // num_chunks,
        *val.shape[(seq_dim + 1):],
    )

    # Built in pinned host memory so the copy to the GPU can be non-blocking.
    chunk_ids = torch.tensor([cp_rank, (num_chunks - cp_rank - 1)], device="cpu", pin_memory=True)
    chunk_ids = chunk_ids.cuda(non_blocking=True)

    picked = chunked.index_select(seq_dim, chunk_ids)

    # Merge (2, chunk_len) back into a single sequence dimension.
    return picked.view(*picked.shape[0:seq_dim], -1, *picked.shape[(seq_dim + 2):])


def reorder_target_for_cp(target, seq_dim=1):
    """Re-order ``target`` along ``seq_dim`` into context-parallel rank order.

    The sequence dimension is cut into ``2 * cp_size`` equal chunks, which are
    rearranged to ``[0, 2cp-1, 1, 2cp-2, ...]``: the two chunks of rank 0 come
    first, then those of rank 1, and so on. Afterwards the contiguous slice
    ``[r * S/cp, (r + 1) * S/cp)`` holds exactly the chunks of rank ``r``.

    The chunk gather is performed along ``seq_dim`` itself, so values other
    than the default 1 are honoured (indexing a fixed axis 1 would permute the
    wrong dimension for any other ``seq_dim``).

    Args:
        target: Tensor whose size along ``seq_dim`` is divisible by ``2 * cp_size``.
        seq_dim: Non-negative index of the sequence dimension.

    Returns:
        ``target`` itself when the context-parallel size is 1, otherwise a new
        tensor of the same shape with the chunks re-ordered.
    """
    cp_size = mpu.get_context_parallel_world_size()
    assert cp_size >= 1
    if cp_size == 1:
        return target

    num_chunks = 2 * cp_size

    # Split seq_dim into (num_chunks, chunk_len).
    target = target.view(
        *target.shape[0:seq_dim],
        num_chunks,
        target.shape[seq_dim] // num_chunks,
        *target.shape[(seq_dim + 1):],
    )

    reordered_indices = []
    for rank in range(cp_size):
        reordered_indices.extend((rank, num_chunks - rank - 1))

    index = torch.tensor(reordered_indices, dtype=torch.long, device=target.device)
    target = target.index_select(seq_dim, index)

    # Merge (num_chunks, chunk_len) back into a single sequence dimension.
    target = target.view(*target.shape[0:seq_dim], -1, *target.shape[(seq_dim + 2):])
    return target


def pad_batches_to_multiple_of_within_ep(rollout_batches, multi_of, key, pad_value):
    """Right-pad ``batch[key]`` along dim 1 to a length shared by the expert-parallel group.

    The local target is the longest ``batch[key].shape[1]`` rounded up to a
    multiple of ``multi_of``; the MAX over the expert-model-parallel group is
    then used. Batches are modified in place. A single dict is accepted in
    place of a list. Returns ``None``; nothing happens (and no collective is
    issued) if the first batch has no value for ``key``.
    """
    if isinstance(rollout_batches, dict):
        rollout_batches = [rollout_batches]
    assert isinstance(rollout_batches, list) and isinstance(rollout_batches[0], dict)

    if rollout_batches[0].get(key, None) is None:
        return

    local_len = max(batch[key].shape[1] for batch in rollout_batches)
    local_len = (local_len + multi_of - 1) // multi_of * multi_of

    # Agree on one length across the expert-parallel group.
    len_buffer = torch.tensor([local_len], dtype=torch.int, device=torch.cuda.current_device())
    torch.distributed.all_reduce(len_buffer,
                                 op=torch.distributed.ReduceOp.MAX,
                                 group=parallel_state.get_expert_model_parallel_group())
    target_len = len_buffer.item()

    for batch in rollout_batches:
        current = batch[key]
        missing = target_len - current.shape[1]
        batch[key] = torch.nn.functional.pad(current, (0, missing), value=pad_value)


def get_max_seqlen_within_ep(seqlen: int):
    """Return the maximum of ``seqlen`` over the expert-model-parallel group."""
    device = torch.cuda.current_device()
    buffer = torch.tensor([seqlen], dtype=torch.int, device=device)
    ep_group = parallel_state.get_expert_model_parallel_group()
    torch.distributed.all_reduce(buffer, op=torch.distributed.ReduceOp.MAX, group=ep_group)
    return buffer.item()


def get_max_seqlen_within_dp(seqlen: int):
    """Return the maximum of ``seqlen`` over the data-parallel group."""
    device = torch.cuda.current_device()
    buffer = torch.tensor([seqlen], dtype=torch.int, device=device)
    dp_group = parallel_state.get_data_parallel_group()
    torch.distributed.all_reduce(buffer, op=torch.distributed.ReduceOp.MAX, group=dp_group)
    return buffer.item()


def build_key_size_numel_dict(names, weights):
    """Collect the shape and element count of every named weight.

    Shapes are read directly from each weight. The earlier zero-terminated
    encoding treated a zero-length dimension as the end of the shape, so e.g. a
    ``(0, 5)`` weight was reported with size ``[]`` and numel 1, and the total
    no longer matched the number of elements in the concatenated weights. It
    also limited weights to fewer than 6 dimensions, which is no longer needed.

    Args:
        names: Iterable of keys, paired positionally with ``weights``.
        weights: Iterable of tensors. Pairs are formed with ``zip``, so surplus
            entries on either side are ignored.

    Returns:
        ``(key_size, key_numel, total_numel)`` where ``key_size`` maps each name
        to its shape as a list of ints, ``key_numel`` maps each name to its
        element count, and ``total_numel`` is the sum of all element counts.
    """
    key_size = {}
    key_numel = {}
    total_numel = 0
    for name, weight in zip(names, weights):
        numel = weight.numel()
        key_size[name] = list(weight.size())
        key_numel[name] = numel
        total_numel += numel
    return key_size, key_numel, total_numel


def flatten_weights(names, weights):
    """Concatenate all weights into one flat tensor plus per-key shape metadata.

    Returns:
        ``(flatten_weight, key_size, key_numel, total_numel)``: the 1-D
        concatenation of every weight (on the device of ``weights[0]``) and the
        metadata produced by ``build_key_size_numel_dict``.
    """
    key_size, key_numel, total_numel = build_key_size_numel_dict(names, weights)
    target_device = weights[0].device
    flat_pieces = [weight.contiguous().view(-1) for weight in weights]
    flatten_weight = torch.cat(flat_pieces, dim=0).to(device=target_device)

    return flatten_weight, key_size, key_numel, total_numel


def random_pad_list(lst, pad_len):
    """Extend ``lst`` by ``pad_len`` elements drawn at random (with replacement) from it.

    Returns ``lst`` itself when ``pad_len`` is 0, otherwise a new list.
    """
    assert pad_len >= 0, f'maybe max_seq_len calc wrong {pad_len}'
    if pad_len == 0:
        return lst
    return lst + random.choices(lst, k=pad_len)


def expand_rollout_batch(
    rollout_batch: Dict[str, Union[int, List[Any]]],
    allow_nolist=False,
) -> List[Dict[str, Any]]:
    """Turn a dict of per-sample lists into a list of per-sample dicts.

    Every list value must have the same length; element ``i`` of each list goes
    into output dict ``i``. With ``allow_nolist`` non-list values are permitted
    and copied into every output dict; without it they trigger an assertion.
    """
    batch_list = []
    for k, vs in rollout_batch.items():
        if allow_nolist and not isinstance(vs, list):
            continue
        assert isinstance(vs, list)
        if not batch_list:
            batch_list = [{} for _ in vs]
        assert len(vs) == len(batch_list)
        for sample, v in zip(batch_list, vs):
            sample[k] = v

    if allow_nolist:
        # Second pass: share the non-list values among all samples.
        for k, vs in rollout_batch.items():
            if isinstance(vs, list):
                continue
            for sample in batch_list:
                sample[k] = vs
    return batch_list


def expand_rollout_batches(
    rollout_batches: List[Dict[str, Union[int, List[Any]]]],
    allow_nolist=False,
) -> List[Dict[str, Any]]:
    """Expand every rollout batch with ``expand_rollout_batch`` and concatenate the results."""
    return [
        sample
        for rollout_batch in rollout_batches
        for sample in expand_rollout_batch(rollout_batch, allow_nolist)
    ]


def get_gbs_batches_seqlen(gbs_batches: List[Dict[str, Any]], pad_to_multi_of: int) -> int:
    """Return the longest ``'tokens'`` last-dim size, rounded up to a multiple of ``pad_to_multi_of``."""
    longest = max(e['tokens'].shape[-1] for e in gbs_batches)
    num_blocks = (longest + pad_to_multi_of - 1) // pad_to_multi_of
    return num_blocks * pad_to_multi_of


def pad_or_truncate_last_dim(t: torch.Tensor, len: int, value):
    """Make the last dimension of ``t`` exactly ``len`` long.

    Shorter tensors are right-padded with ``value``, longer ones are cut off,
    and a tensor that already has the requested size is returned as-is.
    """
    current = t.shape[-1]
    if current < len:
        return torch.nn.functional.pad(t, (0, len - current), value=value)
    if current > len:
        return t[..., :len]
    return t
