from dataclasses import asdict
import inspect
import copy

import torch

from megatron.core.jit import jit_fuser

from gpatch.core.transformer.transformer_config import GpatchTransformerConfig
from gpatch.core.models.vision.qwen2vl_vit_model import (
    Qwen2VLTransformerConfig,
    Qwen2P5VLTransformerConfig,
)


@jit_fuser
def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    """Return the sigmoid-gated "QuickGELU" activation ``x * sigmoid(1.702 * x)``.

    Used as ``activation_func`` for the Qwen2-VL vision encoder config.
    Wrapped with Megatron's ``jit_fuser`` decorator.
    """
    gate = torch.sigmoid(1.702 * x)
    return x * gate


def get_qwen2vl_vision_model_config(config: GpatchTransformerConfig,
                                    seq_length) -> Qwen2VLTransformerConfig:
    """Build the Qwen2-VL vision-encoder config from a base transformer config.

    Every field of ``config`` is deep-copied (via ``dataclasses.asdict``) into a
    new ``Qwen2VLTransformerConfig``. The vision-tower architecture fields are
    then overwritten with the fixed values below.

    Args:
        config: Base (language-model) transformer config to start from.
        seq_length: Stored as ``seq_length`` on the returned config.

    Returns:
        The populated ``Qwen2VLTransformerConfig``.
    """
    vision_config = Qwen2VLTransformerConfig(**asdict(config))

    width = 1280
    heads = 16
    overrides = {
        # Transformer body of the vision tower.
        'num_layers': 32,
        'num_attention_heads': heads,
        'add_bias_linear': True,
        'add_qkv_bias': True,
        'hidden_size': width,
        'hidden_dropout': 0.0,
        'attention_dropout': 0.0,
        'ffn_hidden_size': width * 4,
        'gated_linear_unit': False,
        'activation_func': quick_gelu,
        'kv_channels': width // heads,
        # One query group per head, i.e. no grouped-query sharing.
        'num_query_groups': heads,
        'layernorm_zero_centered_gamma': False,
        'apply_query_key_layer_scaling': False,
        'bias_activation_fusion': False,
        'bias_dropout_fusion': False,
        'attention_softmax_in_fp32': True,
        'normalization': 'LayerNorm',
        # The vision tower runs with a pipeline size of 1, so any uneven
        # first/last-stage layer split inherited from the base config is cleared.
        'pipeline_model_parallel_size': 1,
        'num_layers_in_first_pipeline_stage': None,
        'num_layers_in_last_pipeline_stage': None,
        'packed_freqs': True,
        # Input / patch geometry.
        'seq_length': seq_length,
        'temporal_patch_size': 2,
        'patch_size': 14,
        'in_channels': 3,
        'spatial_merge_size': 2,
    }
    for field_name, field_value in overrides.items():
        setattr(vision_config, field_name, field_value)
    return vision_config


def get_qwen2vl_vision_projection_config(config: GpatchTransformerConfig, embed_dim,
                                         spatial_merge_size) -> GpatchTransformerConfig:
    """Build the config for the Qwen2-VL vision-to-LM projection MLP.

    ``config`` is deep-copied (via ``dataclasses.asdict``) into a new
    ``GpatchTransformerConfig``, and only the MLP-related fields are overridden.

    Args:
        config: Base (language-model) transformer config to start from.
        embed_dim: Vision feature width; multiplied by ``spatial_merge_size**2``
            to form the MLP intermediate width.
        spatial_merge_size: Spatial merge factor (presumably per side of a
            merged patch group; confirm against the caller).

    Returns:
        The populated ``GpatchTransformerConfig``.
    """
    proj_config = GpatchTransformerConfig(**asdict(config))

    # Drop any uneven pipeline-stage layer split inherited from the base config.
    for stage_field in ('num_layers_in_first_pipeline_stage',
                        'num_layers_in_last_pipeline_stage'):
        setattr(proj_config, stage_field, None)

    # Non-gated, biased MLP with exact (erf) GELU.
    merged_width = embed_dim * spatial_merge_size ** 2
    proj_config.gated_linear_unit = False
    proj_config.bias_activation_fusion = False
    proj_config.add_bias_linear = True
    proj_config.ffn_hidden_size = merged_width
    proj_config.activation_func = torch.nn.functional.gelu
    return proj_config


def get_qwen2p5vl_vision_model_config(config: GpatchTransformerConfig,
                                      seq_length) -> Qwen2P5VLTransformerConfig:
    """Build the Qwen2.5-VL vision-encoder config from a base transformer config.

    Only the fields of ``config`` that ``Qwen2P5VLTransformerConfig.__init__``
    accepts are copied over. The vision-tower architecture fields are then
    overwritten with fixed values. The MLP width (``ffn_hidden_size``) is chosen
    from a lookup table keyed by ``config.hidden_size``.

    Args:
        config: Base (language-model) transformer config to start from.
        seq_length: Stored as ``seq_length`` on the returned config.

    Returns:
        The populated ``Qwen2P5VLTransformerConfig``.

    Raises:
        NotImplementedError: If ``config.hidden_size`` has no entry in the
            lookup table below.
    """
    # Vision MLP width, selected by the base config's hidden size.
    ffn_hidden_size_by_hidden_size = {
        2048: 3420,
        3584: 3420,
        5120: 3456,
        8192: 3456,
    }
    # Fail fast, before doing any config construction work.
    ffn_hidden_size = ffn_hidden_size_by_hidden_size.get(config.hidden_size)
    if ffn_hidden_size is None:
        raise NotImplementedError(
            f"Qwen2.5-VL vision config is not defined for hidden_size="
            f"{config.hidden_size}; supported values: "
            f"{sorted(ffn_hidden_size_by_hidden_size)}")

    # Forward only the fields the vision config's constructor accepts; other
    # fields of the base config would be rejected as unexpected kwargs.
    valid_keys = inspect.signature(Qwen2P5VLTransformerConfig.__init__).parameters.keys()
    config_dict = {k: v for k, v in asdict(config).items() if k in valid_keys}
    vision_config = Qwen2P5VLTransformerConfig(**config_dict)

    # Transformer body of the vision tower.
    vision_config.num_layers = 32
    vision_config.num_attention_heads = 16
    vision_config.add_bias_linear = True
    vision_config.add_qkv_bias = True
    vision_config.hidden_size = 1280
    vision_config.hidden_dropout = 0.0
    vision_config.attention_dropout = 0.0
    vision_config.gated_linear_unit = True
    vision_config.activation_func = torch.nn.functional.silu
    vision_config.kv_channels = vision_config.hidden_size // vision_config.num_attention_heads
    # One query group per head, i.e. no grouped-query sharing.
    vision_config.num_query_groups = vision_config.num_attention_heads
    vision_config.layernorm_zero_centered_gamma = False
    vision_config.apply_query_key_layer_scaling = False
    vision_config.bias_activation_fusion = False
    vision_config.bias_dropout_fusion = False
    vision_config.attention_softmax_in_fp32 = True
    vision_config.normalization = 'RMSNorm'
    # The vision tower runs with a pipeline size of 1, so any uneven
    # first/last-stage layer split inherited from the base config is cleared.
    vision_config.pipeline_model_parallel_size = 1
    vision_config.num_layers_in_first_pipeline_stage = None
    vision_config.num_layers_in_last_pipeline_stage = None

    vision_config.packed_freqs = True

    # Input / patch geometry.
    vision_config.seq_length = seq_length
    vision_config.temporal_patch_size = 2
    vision_config.patch_size = 14
    vision_config.in_channels = 3
    vision_config.spatial_merge_size = 2
    # Window-attention settings; fullatt_block_indexes presumably lists the
    # layers that use full rather than windowed attention (confirm in the model).
    vision_config.qwen2_window_size = 112
    vision_config.fullatt_block_indexes = [7, 15, 23, 31]

    vision_config.ffn_hidden_size = ffn_hidden_size
    return vision_config


def get_qwen2p5vl_vision_projection_config(config: GpatchTransformerConfig, embed_dim,
                                           spatial_merge_size) -> GpatchTransformerConfig:
    """Build the config for the Qwen2.5-VL vision-to-LM projection MLP.

    Only the fields of ``config`` that ``Qwen2P5VLTransformerConfig.__init__``
    accepts are copied over, and the MLP-related fields are overridden. The
    concrete return type is ``Qwen2P5VLTransformerConfig``, even though the
    annotation says ``GpatchTransformerConfig``.

    Args:
        config: Base (language-model) transformer config to start from.
        embed_dim: Vision feature width; multiplied by ``spatial_merge_size**2``
            to form the MLP intermediate width.
        spatial_merge_size: Spatial merge factor (presumably per side of a
            merged patch group; confirm against the caller).

    Returns:
        The populated projection config, with ``hidden_size`` equal to
        ``config.hidden_size``.

    Raises:
        NotImplementedError: If ``config.hidden_size`` is not one of the
            supported sizes listed below.
    """
    supported_hidden_sizes = (2048, 3584, 5120, 8192)
    # Fail fast, before doing any config construction work.
    if config.hidden_size not in supported_hidden_sizes:
        raise NotImplementedError(
            f"Qwen2.5-VL projection config is not defined for hidden_size="
            f"{config.hidden_size}; supported values: {list(supported_hidden_sizes)}")

    # Forward only the fields the target constructor accepts; other fields of
    # the base config would be rejected as unexpected kwargs.
    valid_keys = inspect.signature(Qwen2P5VLTransformerConfig.__init__).parameters.keys()
    config_dict = {k: v for k, v in asdict(config).items() if k in valid_keys}
    proj_config = Qwen2P5VLTransformerConfig(**config_dict)
    # Drop any uneven pipeline-stage layer split inherited from the base config.
    proj_config.num_layers_in_first_pipeline_stage = None
    proj_config.num_layers_in_last_pipeline_stage = None

    # Non-gated, biased MLP with exact (erf) GELU.
    proj_config.gated_linear_unit = False
    proj_config.bias_activation_fusion = False
    proj_config.add_bias_linear = True
    proj_config.ffn_hidden_size = embed_dim * (spatial_merge_size**2)
    proj_config.activation_func = torch.nn.functional.gelu

    # Keep the projection width equal to the base config's hidden size. The
    # previous per-size if/elif chain assigned exactly this value.
    proj_config.hidden_size = config.hidden_size
    return proj_config
