
from collections import OrderedDict
from typing import Literal, Optional

import torch
from torch import Tensor

from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.legacy.model.gpt_model import GPTModel


def gpt_model_init(
    self,
    config: TransformerConfig,
    transformer_layer_spec: ModuleSpec,
    vocab_size: int,
    max_sequence_length: int,
    pre_process: bool = True,
    post_process: bool = True,
    fp16_lm_cross_entropy: bool = False,
    parallel_output: bool = True,
    share_embeddings_and_output_weights: bool = False,
    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
    rotary_percent: float = 1.0,
    rotary_base: int = 10000,
    rope_scaling: bool = False,
    rope_scaling_factor: float = 8.0,
    scatter_embedding_sequence_parallel: bool = True,
    seq_len_interpolation_factor: Optional[float] = None,
) -> None:
    """Build the embedding, rotary embeddings, decoder and output layer of a GPT model.

    NOTE(review): the name and the ``super(GPTModel, self)`` call suggest this is
    installed as ``GPTModel.__init__`` — confirm against the code that binds it.

    Args:
        config: Transformer configuration, forwarded to the parent initializer and
            to every submodule built here.
        transformer_layer_spec: Layer spec handed to ``TransformerBlock``.
        vocab_size: Vocabulary size for the embedding and the output layer.
        max_sequence_length: Stored as both ``max_sequence_length`` and
            ``max_position_embeddings``; also passed to the embedding.
        pre_process: When true, a ``LanguageModelEmbedding`` is created.
        post_process: When true, the output ``ColumnParallelLinear`` is created.
        fp16_lm_cross_entropy: Stored on the instance; not otherwise used here.
        parallel_output: The output layer gathers its output when this is false.
        share_embeddings_and_output_weights: Stored on the instance; together with
            ``pre_process`` it makes the output layer skip weight allocation.
        position_embedding_type: Passed to the embedding; ``'rope'`` additionally
            creates rotary embeddings unless ``config.multi_latent_attention`` is set.
        rotary_percent: Forwarded to ``RotaryEmbedding``.
        rotary_base: Base of the (global) rotary embedding.
        rope_scaling: Forwarded to ``RotaryEmbedding``; stored as ``rotary_scaling``.
        rope_scaling_factor: Forwarded to ``RotaryEmbedding``. When the config has
            ``rope_local_base_freq`` the global ``inv_freq`` is also divided by it.
        scatter_embedding_sequence_parallel: Forwarded to the embedding as
            ``scatter_to_sequence_parallel``.
        seq_len_interpolation_factor: Forwarded to ``RotaryEmbedding``.
    """
    super(GPTModel, self).__init__(config=config)

    # Snapshot the arguments first: locals() must be read before this function
    # binds any extra local name, otherwise the logged payload would change.
    if has_config_logger_enabled(config):
        log_config_to_disk(config, locals(), prefix=type(self).__name__)

    # Plain bookkeeping attributes.
    self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
    self.vocab_size = vocab_size
    self.max_sequence_length = max_sequence_length
    self.max_position_embeddings = max_sequence_length
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
    self.parallel_output = parallel_output
    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
    self.position_embedding_type = position_embedding_type
    self.model_type = ModelType.encoder_or_decoder
    self.rotary_percent = rotary_percent
    self.rotary_base = rotary_base
    self.rotary_scaling = rope_scaling

    # Token / position embedding.
    if self.pre_process:
        self.embedding = LanguageModelEmbedding(
            config=self.config,
            vocab_size=self.vocab_size,
            max_sequence_length=self.max_sequence_length,
            position_embedding_type=position_embedding_type,
            scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
        )

    # Rotary embeddings.
    if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
        # Everything except the base frequency is shared by the global and the
        # local rotary embedding.
        shared_rope_kwargs = dict(
            kv_channels=self.config.kv_channels,
            rotary_percent=rotary_percent,
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
            rope_scaling=rope_scaling,
            rope_scaling_factor=rope_scaling_factor,
            use_cpu_initialization=self.config.use_cpu_initialization,
        )
        self.rotary_pos_emb = RotaryEmbedding(rotary_base=rotary_base, **shared_rope_kwargs)

        if hasattr(self.config, "rope_local_base_freq"):
            # A config carrying a local base frequency gets a second rotary
            # embedding, and the global one has its inverse frequencies divided
            # by the scaling factor.
            self.rotary_pos_emb.inv_freq /= rope_scaling_factor
            self.rotary_pos_emb_local = RotaryEmbedding(
                rotary_base=self.config.rope_local_base_freq, **shared_rope_kwargs
            )

    # Cos/sin tables keyed by max sequence length; filled lazily by forward.
    self.rotary_pos_emb_cache = {}

    # Transformer stack.
    self.decoder = TransformerBlock(
        config=self.config,
        spec=transformer_layer_spec,
        pre_process=self.pre_process,
        post_process=self.post_process,
    )

    # Output layer.
    if post_process:
        # With deferred embedding wgrad compute the output layer receives two
        # (distinct) lists to fill; otherwise both buffers are disabled.
        defer_wgrad = self.config.defer_embedding_wgrad_compute
        self.embedding_activation_buffer = [] if defer_wgrad else None
        self.grad_output_buffer = [] if defer_wgrad else None

        # No separate weight is allocated when this stage also owns the
        # embedding and the two are tied.
        reuse_embedding_weight = self.pre_process and self.share_embeddings_and_output_weights
        self.output_layer = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            self.vocab_size,
            config=config,
            init_method=config.init_method,
            bias=False,
            skip_bias_add=False,
            gather_output=not self.parallel_output,
            skip_weight_param_allocation=reuse_embedding_weight,
            embedding_activation_buffer=self.embedding_activation_buffer,
            grad_output_buffer=self.grad_output_buffer,
        )

    if self.pre_process or self.post_process:
        self.setup_embeddings_and_output_layer()

    # Dump the freshly initialized weights when config logging is on.
    if has_config_logger_enabled(self.config):
        checkpoint_prefix = f'{type(self).__name__}_init_ckpt'
        log_config_to_disk(self.config, self.state_dict(), prefix=checkpoint_prefix)

def gpt_model_forward(
    self,
    input_ids: Tensor,
    position_ids: Tensor,
    attention_mask: Tensor,
    decoder_input: Optional[Tensor] = None,
    labels: Optional[Tensor] = None,
    inference_params: Optional[InferenceParams] = None,
    packed_seq_params: Optional[PackedSeqParams] = None,
    extra_block_kwargs: Optional[dict] = None,
    runtime_gather_output: Optional[bool] = None,
) -> Tensor:
    """Forward function of the GPT model.

    Passes the input tensors through the embedding layer (when this stage has
    one), then the decoder, and finally the optional post-processing layer.

    Args:
        input_ids: Token ids; embedded when ``decoder_input`` is not given and
            ``self.pre_process`` is set.
        position_ids: Position ids passed to the embedding together with
            ``input_ids``.
        attention_mask: Forwarded to the decoder.
        decoder_input: Precomputed decoder input. When given, the embedding is
            skipped.
        labels: When given, the language-model loss is returned instead of the
            logits.
        inference_params: Inference state. Together with ``config.flash_decode``
            (and eval mode) it selects the cached cos/sin rotary path.
        packed_seq_params: Packed-sequence parameters, forwarded to the decoder;
            a ``qkv_format`` of ``'thd'`` marks the rotary embedding as packed.
        extra_block_kwargs: Extra keyword arguments for the decoder call.
        runtime_gather_output: Gather output at runtime. Default None means the
            ``parallel_output`` arg in the constructor will be used.

    Returns:
        The decoder hidden states when ``self.post_process`` is false; otherwise
        the loss when ``labels`` is given, or the logits with their first two
        dimensions swapped.
    """
    # Embed only when the caller did not supply a decoder input and this stage
    # owns the embedding. Otherwise decoder_input is forwarded as-is, which may
    # be None (presumably the decoder then gets its input elsewhere — confirm).
    if decoder_input is None and self.pre_process:
        decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)

    # Rotary positional embeddings: either cached cos/sin tables (flash decode)
    # or the rotary embedding tensor(s).
    rotary_pos_emb = None
    rotary_pos_cos = None
    rotary_pos_sin = None
    if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
        if not self.training and self.config.flash_decode and inference_params:
            # Compute the cos/sin tables once per max sequence length. An
            # explicit membership test is used because dict.setdefault would
            # evaluate get_cos_sin() on every call, cache hit or not.
            max_seq_len = inference_params.max_sequence_length
            if max_seq_len not in self.rotary_pos_emb_cache:
                self.rotary_pos_emb_cache[max_seq_len] = self.rotary_pos_emb.get_cos_sin(
                    max_seq_len
                )
            rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache[max_seq_len]
        else:
            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                inference_params, self.decoder, decoder_input, self.config, packed_seq_params
            )
            is_packed_thd = (
                packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
            )
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=is_packed_thd)
            if hasattr(self.config, "rope_local_base_freq"):
                # With a local base frequency configured, the decoder receives a
                # (global, local) pair instead of a single embedding.
                rotary_pos_emb_local = self.rotary_pos_emb_local(
                    rotary_seq_len, packed_seq=is_packed_thd
                )
                rotary_pos_emb = (rotary_pos_emb, rotary_pos_emb_local)

    # The cos/sin path also needs one sequence offset per batch entry, built on
    # the same device as the cos table.
    if (
        (self.config.enable_cuda_graph or self.config.flash_decode)
        and rotary_pos_cos is not None
        and inference_params
    ):
        sequence_len_offset = torch.tensor(
            [inference_params.sequence_len_offset] * inference_params.current_batch_size,
            dtype=torch.int32,
            device=rotary_pos_cos.device,
        )
    else:
        sequence_len_offset = None

    # Run the decoder.
    hidden_states = self.decoder(
        hidden_states=decoder_input,
        attention_mask=attention_mask,
        inference_params=inference_params,
        rotary_pos_emb=rotary_pos_emb,
        rotary_pos_cos=rotary_pos_cos,
        rotary_pos_sin=rotary_pos_sin,
        packed_seq_params=packed_seq_params,
        sequence_len_offset=sequence_len_offset,
        **(extra_block_kwargs or {}),
    )
    if not self.post_process:
        return hidden_states

    # Project to logits, reusing the shared weight when weights are tied.
    output_weight = None
    if self.share_embeddings_and_output_weights:
        output_weight = self.shared_embedding_or_output_weight()
    logits, _ = self.output_layer(
        hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
    )

    if has_config_logger_enabled(self.config):
        payload = OrderedDict(
            {
                'input_ids': input_ids,
                'position_ids': position_ids,
                'attention_mask': attention_mask,
                'decoder_input': decoder_input,
                'logits': logits,
            }
        )
        log_config_to_disk(self.config, payload, prefix='input_and_logits')

    if labels is None:
        # Swap the first two dimensions of the logits
        # (presumably [s, b, v] -> [b, s, v] — confirm).
        return logits.transpose(0, 1).contiguous()

    return self.compute_language_model_loss(labels, logits)