from megatron.core import parallel_state
from megatron.core.enums import ModelType


def get_tensor_shapes(
    *,
    rank: int,
    model_type: ModelType,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int,
    config,
    encoder_decoder_xattn: bool,
):
    """
    Determine the shapes of the tensors this pipeline rank sends/receives.

    Shapes depend on the position of ``rank`` with respect to the encoder/decoder split
    and on the model size (``config.hidden_size``). Every base shape is
    ``(seq, micro_batch_size, config.hidden_size)``. The sequence lengths are first
    floor-divided by the context-parallel world size and then, if
    ``config.sequence_parallel`` is set, by the tensor-model-parallel world size.

    Selection rules:

    - Encoder-and-decoder model, rank inside the encoder but not the decoder:
      one tensor of length ``seq_length``.
    - Encoder-and-decoder model, any other rank, with ``encoder_decoder_xattn``:
      two tensors. The first is the decoder tensor (``decoder_seq_length``); the
      second is the encoder tensor (``seq_length``), needed by the decoder's
      cross-attention.
    - Encoder-and-decoder model, any other rank, without cross-attention:
      one tensor of length ``decoder_seq_length``.
    - Any other model type: one tensor of length ``seq_length``.

    If ``config.dpo`` exists and is truthy, every shape gets an extra leading dimension
    of size ``config.dpo_policy_ref_model_cnt + config.dpo_reward_models_cnt``.

    Args:
        rank: Pipeline rank whose tensor shapes are being computed.
        model_type: Model type; only ``ModelType.encoder_and_decoder`` is special-cased.
        seq_length: Full (encoder) sequence length before any parallel splitting.
        micro_batch_size: Micro-batch size (second dimension of every base shape).
        decoder_seq_length: Full decoder sequence length; only read for
            encoder-and-decoder models.
        config: Model config. Reads ``sequence_parallel`` and ``hidden_size``, and
            optionally ``dpo``, ``dpo_policy_ref_model_cnt`` and ``dpo_reward_models_cnt``.
        encoder_decoder_xattn: Whether decoder ranks also need the encoder tensor.

    Returns:
        list of shape tuples: one entry normally, two for decoder ranks with
        cross-attention.

    Raises:
        ValueError: if DPO is enabled and the combined DPO dimension is not positive.
    """
    is_encoder_and_decoder = model_type == ModelType.encoder_and_decoder

    # The sequence dimension is split across context-parallel ranks.
    cp_world_size = parallel_state.get_context_parallel_world_size()
    seq_length = seq_length // cp_world_size
    if is_encoder_and_decoder:
        decoder_seq_length = decoder_seq_length // cp_world_size

    # With sequence parallelism it is further split across tensor-model-parallel ranks.
    if config.sequence_parallel:
        tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
        seq_length = seq_length // tp_world_size
        if is_encoder_and_decoder:
            decoder_seq_length = decoder_seq_length // tp_world_size

    if is_encoder_and_decoder:
        if parallel_state.is_inside_encoder(rank) and not parallel_state.is_inside_decoder(rank):
            seq_lengths = [seq_length]
        elif encoder_decoder_xattn:
            # Order matters: decoder tensor first, encoder tensor second.
            seq_lengths = [decoder_seq_length, seq_length]
        else:
            seq_lengths = [decoder_seq_length]
    else:  # any model type other than ModelType.encoder_and_decoder
        seq_lengths = [seq_length]

    tensor_shapes = [
        (length, micro_batch_size, config.hidden_size) for length in seq_lengths
    ]

    if getattr(config, 'dpo', False):
        # NOTE(review): presumably one slot per policy/reference/reward model forwarded
        # together in a single batch — confirm against the DPO model code.
        dpo_dim = config.dpo_policy_ref_model_cnt + config.dpo_reward_models_cnt
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if dpo_dim <= 0:
            raise ValueError(
                f"DPO is enabled but dpo_policy_ref_model_cnt + dpo_reward_models_cnt "
                f"= {dpo_dim}; expected a positive value"
            )
        tensor_shapes = [(dpo_dim,) + shape for shape in tensor_shapes]
    return tensor_shapes
