

import torch
from megatron.core import parallel_state
from megatron.core.transformer import TransformerConfig

def get_vision_model_config(config: TransformerConfig) -> TransformerConfig:
    """Overwrite ``config`` in place with the vision model settings.

    The incoming ``config.num_layers`` is consulted exactly once, before it is
    replaced, to pick the feed-forward width. Everything else written here is
    a constant, except ``num_layers``, which is scaled by the virtual pipeline
    model parallel world size when one is set.

    Args:
        config: Transformer config to mutate. No copy is made.

    Returns:
        The same ``config`` object that was passed in.
    """
    # Must happen before num_layers is overwritten further down.
    # NOTE(review): 28/36 presumably identify specific model variants by the
    # depth the caller passes in -- confirm against the call sites.
    config.ffn_hidden_size = 3420 if config.num_layers in [28, 36] else 3456

    # Base depth is 32; it is multiplied by the virtual pipeline world size
    # whenever parallel_state reports one.
    vpp_world_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
    config.num_layers = 32 if vpp_world_size is None else 32 * vpp_world_size

    # Width and head layout; the two derived fields are read back from config.
    config.hidden_size = 1280
    config.num_attention_heads = 16
    config.kv_channels = config.hidden_size // config.num_attention_heads
    config.num_query_groups = config.num_attention_heads

    # Remaining fields are plain constants with no interdependencies.
    constant_fields = {
        # Biases and dropout.
        "add_bias_linear": True,
        "add_qkv_bias": True,
        "hidden_dropout": 0.0,
        "attention_dropout": 0.0,
        # Normalisation, scaling and fusion switches.
        "layernorm_zero_centered_gamma": False,
        "apply_query_key_layer_scaling": False,
        "bias_activation_fusion": False,
        "bias_dropout_fusion": False,
        "attention_softmax_in_fp32": True,
        "seq_length": 1,
        # Parallelism features switched off for this config.
        "tp_comm_overlap": False,
        "sequence_parallel": False,
        # Patch / input geometry.
        "temporal_patch_size": 2,
        "patch_size": 14,
        "in_channels": 3,
        "spatial_merge_size": 2,
        # Block indexes listed for full attention, plus the window size.
        "fullatt_block_indexes": [7, 15, 23, 31],
        "_qwen2_5_vl_window_size": 112,
    }
    for field_name, field_value in constant_fields.items():
        setattr(config, field_name, field_value)

    return config

def get_vision_projection_config(
    config: TransformerConfig, embed_dim: int, spatial_merge_size: int
) -> TransformerConfig:
    """Overwrite ``config`` in place with the vision projection settings.

    Args:
        config: Transformer config to mutate. No copy is made.
        embed_dim: Base width that is scaled to obtain ``ffn_hidden_size``.
        spatial_merge_size: Merge factor; its square multiplies ``embed_dim``.

    Returns:
        The same ``config`` object that was passed in.
    """
    # Constant switches for the projection MLP.
    for flag_name, flag_value in (
        ("gated_linear_unit", False),
        ("bias_activation_fusion", False),
        ("add_bias_linear", True),
    ):
        setattr(config, flag_name, flag_value)

    # Hidden width is embed_dim times the square of the merge size.
    merge_factor = spatial_merge_size**2
    config.ffn_hidden_size = embed_dim * merge_factor
    config.activation_func = torch.nn.functional.gelu

    # Parallelism features switched off for this config.
    for flag_name in ("tp_comm_overlap", "sequence_parallel"):
        setattr(config, flag_name, False)

    return config
