              
                                                      
                                          

from typing import List, Optional, Tuple, Union
from contextlib import nullcontext
import inspect

import torch

import megatron.legacy.model
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_block_spec,
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.transformer.spec_utils import import_module
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.core.models.gpt import GPTModel
from megatron.core import mpu
from megatron.training.utils import unwrap_model
from megatron.training import get_tokenizer

from gpatch.core.models.gpt.gpt_reward_model import GptRewardModel
from gpatch.core.models.gpt.gpt_dpo_model import GptDpoModel
from gpatch.core.transformer.transformer_config import GpatchTransformerConfig
from gpatch.core.sampler_v3.infer_engine import InferEngine
from gpatch.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec_lora
from gpatch.core.models.gpt import (
    GptPpoActorModel,
    GptPpoRmCriticClientV3,
    GptPpoSamplerClientV3,
    GptPpoGenRmClientV3,
)


def default_sft_model_provider(pre_process=True,
                               post_process=True
                               ) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If you set the use_legacy_models to True, it will return the legacy GPT model and if not the
    mcore GPT model. When ``args.enable_lora`` is set, every parameter whose name contains
    neither ``.lora_a.`` nor ``.lora_b.`` has ``requires_grad`` turned off before returning.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model

    Raises:
        RuntimeError: ``--fp8-param-gather`` is set but ``fp8_model_init`` cannot be imported
            from TransformerEngine.
        AssertionError: LoRA is enabled together with an unsupported option (legacy models,
            MoE, the non-TE local spec, ``freeze_lora``, or without ``mm_freeze_llm``).
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"

    if args.record_memory_history:
        # Turn on CUDA allocator history recording before any model memory is allocated.
        torch.cuda.memory._record_memory_history(
            True,
            trace_alloc_max_entries=100000,
            trace_alloc_record_context=True)

    print_rank_0('building GPT model ...')
    # A YAML config, when given, takes precedence over the command-line derived config.
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args, GpatchTransformerConfig)

    if args.use_legacy_models:
        assert not args.enable_lora
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:
        # Resolve the transformer layer spec: explicit --spec wins, then MoE, then dense.
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if args.num_experts:
                assert not args.enable_lora
                transformer_layer_spec = get_gpt_decoder_block_spec(config,
                                                                    use_transformer_engine=use_te)
            else:
                if use_te:
                    if args.enable_lora:
                        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec_lora(
                            args.num_experts,
                            args.moe_grouped_gemm,
                            args.qk_layernorm,
                            args.multi_latent_attention,
                            None,
                            args.moe_use_legacy_grouped_gemm,
                            config.gated_linear_unit,
                        )
                    else:
                        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                            args.num_experts,
                            args.moe_grouped_gemm,
                            args.qk_layernorm,
                            args.multi_latent_attention,
                            None,
                            args.moe_use_legacy_grouped_gemm,
                        )
                else:
                    assert not args.enable_lora
                    transformer_layer_spec = get_gpt_layer_local_spec(
                        args.num_experts,
                        args.moe_grouped_gemm,
                        args.qk_layernorm,
                        args.multi_latent_attention,
                        None,
                        args.moe_use_legacy_grouped_gemm,
                    )

        build_model_context = nullcontext
        build_model_context_args = {}
        if args.fp8_param_gather:
            # Guard only the import: a bare `except:` around the whole setup would also
            # swallow KeyboardInterrupt/SystemExit and mislabel unrelated failures.
            try:
                from transformer_engine.pytorch import fp8_model_init
            except ImportError as err:
                raise RuntimeError(
                    "--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found."
                ) from err

            build_model_context = fp8_model_init
            build_model_context_args["enabled"] = True

            # Pass the flag only when the installed fp8_model_init accepts it.
            if "preserve_high_precision_init_val" in inspect.signature(
                    fp8_model_init).parameters:
                build_model_context_args["preserve_high_precision_init_val"] = True

        with build_model_context(**build_model_context_args):
            model = GPTModel(
                config=config,
                transformer_layer_spec=transformer_layer_spec,
                vocab_size=args.padded_vocab_size,
                max_sequence_length=args.max_position_embeddings,
                pre_process=pre_process,
                post_process=post_process,
                fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
                parallel_output=True,
                share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
                position_embedding_type=args.position_embedding_type,
                rotary_percent=args.rotary_percent,
                rotary_base=args.rotary_base,
                rope_scaling=args.use_rope_scaling,
                rope_scaling_factor=args.rope_scaling_factor,)
    print_rank_0(f"model arch {model}")

    if args.enable_lora:
        assert (not args.freeze_lora
                ) and args.mm_freeze_llm, "Use LoRA should freeze llm and not freeze LoRA"
        # Leave only the LoRA adapter weights trainable.
        for pname, params in model.named_parameters():
            if not ('.lora_a.' in pname or '.lora_b.' in pname):
                params.requires_grad_(False)

    return model


def default_actor_model_provider(pre_process=True, post_process=True):
    """Construct the mcore ``GPTModel`` for the actor.

    The transformer config is derived from the global args. Legacy models are
    rejected by assertion. The layer spec comes from ``args.spec`` when given,
    from the decoder block spec when ``args.num_experts`` is truthy, and
    otherwise from the TE or local layer spec depending on
    ``args.transformer_impl``.

    Args:
        pre_process: Forwarded unchanged to ``GPTModel``.
        post_process: Forwarded unchanged to ``GPTModel``.

    Returns:
        The constructed ``GPTModel``.
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"
    config = core_transformer_config_from_args(args, GpatchTransformerConfig)

    # Only mcore models are accepted; check whichever flag the args namespace carries.
    if hasattr(args, "use_legacy_models"):
        assert not args.use_legacy_models
    else:
        assert args.use_mcore_models

    if args.spec is not None:
        layer_spec = import_module(args.spec)
    elif args.num_experts:
        layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te)
    else:
        # Both dense-layer builders take the same positional arguments.
        if use_te:
            make_layer_spec = get_gpt_layer_with_transformer_engine_spec
        else:
            make_layer_spec = get_gpt_layer_local_spec
        layer_spec = make_layer_spec(
            args.num_experts,
            args.moe_grouped_gemm,
            args.qk_layernorm,
            args.multi_latent_attention,
            None,
            args.moe_use_legacy_grouped_gemm,
        )

    model_kwargs = dict(
        config=config,
        transformer_layer_spec=layer_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent,
        rotary_base=args.rotary_base,
        rope_scaling=args.use_rope_scaling,
    )
    model = GPTModel(**model_kwargs)
    print_rank_0(f"model arch {model}")
    return model


def default_reward_model_provider(pre_process=True, post_process=True):
    """Construct the ``GptRewardModel`` from the global args.

    Legacy models are rejected by assertion. The layer spec is resolved the
    same way as for the other mcore providers in this file: ``args.spec``
    first, then the decoder block spec when ``args.num_experts`` is truthy,
    otherwise the TE or local layer spec.

    Args:
        pre_process: Forwarded unchanged to ``GptRewardModel``.
        post_process: Forwarded unchanged to ``GptRewardModel``.

    Returns:
        The constructed ``GptRewardModel``.
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"
    config = core_transformer_config_from_args(args, GpatchTransformerConfig)

    # Only mcore models are accepted; check whichever flag the args namespace carries.
    if hasattr(args, "use_legacy_models"):
        assert not args.use_legacy_models
    else:
        assert args.use_mcore_models

    if args.spec is not None:
        layer_spec = import_module(args.spec)
    elif args.num_experts:
        layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te)
    else:
        # Both dense-layer builders take the same positional arguments.
        if use_te:
            make_layer_spec = get_gpt_layer_with_transformer_engine_spec
        else:
            make_layer_spec = get_gpt_layer_local_spec
        layer_spec = make_layer_spec(
            args.num_experts,
            args.moe_grouped_gemm,
            args.qk_layernorm,
            args.multi_latent_attention,
            None,
            args.moe_use_legacy_grouped_gemm,
        )

    # Arguments shared with the plain GPT model.
    model_kwargs = dict(
        config=config,
        transformer_layer_spec=layer_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent,
        rotary_base=args.rotary_base,
        rope_scaling=args.use_rope_scaling,
    )
    # Reward-model specific arguments.
    model_kwargs.update(
        output_sequence=True,
        output_scalar=False,
        use_avg_pool=args.rm_use_avg_pool,
        num_attributes=args.rm_num_attributes,
        merge_attributes=False,
        mask_prompt=args.ppo_rm_mask_prompt,
    )

    model = GptRewardModel(**model_kwargs)
    print_rank_0(f"model arch {model}")
    return model


def default_sampler_model_provider():
    """Create the inference engine that runs in the ``"sampler"`` role.

    Tensor/pipeline parallel sizes are read from the global args and asserted
    to equal the ``ppo_sampler_*`` sizes. Ranks are taken from ``mpu``, and
    ``dist_init_addr`` is the entry of ``args.sampler_dist_init_addrs`` at the
    data-parallel rank. The model is loaded from ``args.load``.

    Returns:
        The result of ``InferEngine.from_engine_args``.
    """
    args = get_args()
    tp_size = args.tensor_model_parallel_size
    pp_size = args.pipeline_model_parallel_size
    use_fast = args.px_use_fast_tokenizer

    # The engine must use the same parallel layout as this process.
    assert tp_size == args.ppo_sampler_tensor_model_parallel_size
    assert pp_size == args.ppo_sampler_pipeline_model_parallel_size

    tp_rank = mpu.get_tensor_model_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    dp_rank = mpu.get_data_parallel_rank()

    engine_kwargs = {
        "model": args.load,
        "dtype": 'bfloat16',
        "distributed_executor_backend": 'ray',
        "tensor_parallel_size": tp_size,
        "pipeline_parallel_size": pp_size,
        "gpu_memory_utilization": args.sampler_gpu_memory_utilization,
        "enforce_eager": True,
        "tp_rank": tp_rank,
        "pp_rank": pp_rank,
        "dp_rank": dp_rank,
        "dist_init_addr": args.sampler_dist_init_addrs[dp_rank],
        "infer_engine_role": "sampler",
        "use_fast": use_fast,
    }
    return InferEngine.from_engine_args(args.infer_engine_impl, **engine_kwargs)


def default_gen_rm_model_provider():
    """Create the inference engine that runs in the ``"gen-rm"`` role.

    Tensor/pipeline parallel sizes are read from the global args and asserted
    to equal the ``ppo_gen_rm_*`` sizes. Ranks are taken from ``mpu``, and
    ``dist_init_addr`` is the entry of ``args.gen_rm_dist_init_addrs`` at the
    data-parallel rank. The model is loaded from ``args.load_ref``.

    Returns:
        The result of ``InferEngine.from_engine_args``.
    """
    args = get_args()
    tp_size = args.tensor_model_parallel_size
    pp_size = args.pipeline_model_parallel_size

    # The engine must use the same parallel layout as this process.
    assert tp_size == args.ppo_gen_rm_tensor_model_parallel_size
    assert pp_size == args.ppo_gen_rm_pipeline_model_parallel_size

    tp_rank = mpu.get_tensor_model_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    dp_rank = mpu.get_data_parallel_rank()

    engine_kwargs = {
        "model": args.load_ref,
        "dtype": 'bfloat16',
        "distributed_executor_backend": 'ray',
        "tensor_parallel_size": tp_size,
        "pipeline_parallel_size": pp_size,
        "gpu_memory_utilization": args.gen_rm_gpu_memory_utilization,
        "enforce_eager": True,
        "tp_rank": tp_rank,
        "pp_rank": pp_rank,
        "dp_rank": dp_rank,
        "dist_init_addr": args.gen_rm_dist_init_addrs[dp_rank],
        "infer_engine_role": "gen-rm",
    }
    return InferEngine.from_engine_args(args.infer_engine_impl, **engine_kwargs)


def default_dpo_model_provider(pre_process=True, post_process=True):
    """Build the model used for DPO training.

    Returns the legacy GPT model when ``args.use_legacy_models`` is set, otherwise a
    ``GptDpoModel`` built with ``forward_without_loss=False``.

    Args:
        pre_process (bool, optional): Forwarded to the model constructor. Defaults to True.
        post_process (bool, optional): Forwarded to the model constructor. Defaults to True.

    Returns:
        The constructed ``GptDpoModel`` (or ``megatron.legacy.model.GPTModel`` on the legacy path).

    Raises:
        RuntimeError: ``--fp8-param-gather`` is set but ``fp8_model_init`` cannot be imported
            from TransformerEngine.
        AssertionError: ``args.dpo_model_using`` is neither ``'both'`` nor ``'policy'``, or
            ``config.dpo_policy_ref_model_cnt`` does not match it.
    """
    args = get_args()
    assert args.dpo_model_using == 'both' or args.dpo_model_using == 'policy'
    use_te = args.transformer_impl == "transformer_engine"

    if args.record_memory_history:
        # Turn on CUDA allocator history recording before any model memory is allocated.
        torch.cuda.memory._record_memory_history(
            True,
            trace_alloc_max_entries=100000,
            trace_alloc_record_context=True)

    print_rank_0('building GPT model ...')
    # A YAML config, when given, takes precedence over the command-line derived config.
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args, GpatchTransformerConfig)

    # 'both' means two models (policy and reference) are expected by the config; otherwise one.
    model_cnt = 2 if args.dpo_model_using == "both" else 1
    assert config.dpo_policy_ref_model_cnt == model_cnt, \
        f"dpo_policy_ref_model_cnt {config.dpo_policy_ref_model_cnt} mismatch {model_cnt}"

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:
        # Resolve the transformer layer spec: explicit --spec wins, then MoE, then dense.
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if args.num_experts:
                transformer_layer_spec = get_gpt_decoder_block_spec(config,
                                                                    use_transformer_engine=use_te)
            else:
                if use_te:
                    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                        args.num_experts,
                        args.moe_grouped_gemm,
                        args.qk_layernorm,
                        args.multi_latent_attention,
                        None,
                        args.moe_use_legacy_grouped_gemm,
                    )
                else:
                    transformer_layer_spec = get_gpt_layer_local_spec(
                        args.num_experts,
                        args.moe_grouped_gemm,
                        args.qk_layernorm,
                        args.multi_latent_attention,
                        None,
                        args.moe_use_legacy_grouped_gemm,
                    )

        build_model_context = nullcontext
        build_model_context_args = {}
        if args.fp8_param_gather:
            # Guard only the import: a bare `except:` around the whole setup would also
            # swallow KeyboardInterrupt/SystemExit and mislabel unrelated failures.
            try:
                from transformer_engine.pytorch import fp8_model_init
            except ImportError as err:
                raise RuntimeError(
                    "--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found."
                ) from err

            build_model_context = fp8_model_init
            build_model_context_args["enabled"] = True

            # Pass the flag only when the installed fp8_model_init accepts it.
            if "preserve_high_precision_init_val" in inspect.signature(
                    fp8_model_init).parameters:
                build_model_context_args["preserve_high_precision_init_val"] = True

        with build_model_context(**build_model_context_args):
            model = GptDpoModel(
                config=config,
                transformer_layer_spec=transformer_layer_spec,
                vocab_size=args.padded_vocab_size,
                max_sequence_length=args.max_position_embeddings,
                pre_process=pre_process,
                post_process=post_process,
                fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
                parallel_output=True,
                share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
                position_embedding_type=args.position_embedding_type,
                rotary_percent=args.rotary_percent,
                rotary_base=args.rotary_base,
                rope_scaling=args.use_rope_scaling,
                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor,
                beta=args.dpo_beta,
                label_smoothing=args.dpo_label_smoothing,
                ftx_gamma=args.dpo_ftx_gamma,
                forward_without_loss=False,
                model_using=args.dpo_model_using,
                pair_gamma=args.dpo_pair_gamma)
    print_rank_0(f"model arch {model}")
    return model


def default_infer_dpo_model_provider(pre_process=True, post_process=True):
    """Build a frozen DPO model for inference.

    Returns the legacy GPT model when ``args.use_legacy_models`` is set, otherwise a
    ``GptDpoModel`` built with ``forward_without_loss=True``. In both cases
    ``requires_grad_(False)`` is applied to the model before it is returned.

    Args:
        pre_process (bool, optional): Forwarded to the model constructor. Defaults to True.
        post_process (bool, optional): Forwarded to the model constructor. Defaults to True.

    Returns:
        The constructed ``GptDpoModel`` (or ``megatron.legacy.model.GPTModel`` on the legacy
        path), with gradients disabled.

    Raises:
        RuntimeError: ``--fp8-param-gather`` is set but ``fp8_model_init`` cannot be imported
            from TransformerEngine.
        AssertionError: ``config.dpo_policy_ref_model_cnt`` does not match
            ``args.dpo_model_using``.
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"
    if args.record_memory_history:
        # Turn on CUDA allocator history recording before any model memory is allocated.
        torch.cuda.memory._record_memory_history(
            True,
            trace_alloc_max_entries=100000,
            trace_alloc_record_context=True)

    print_rank_0('building GPT model ...')
    # A YAML config, when given, takes precedence over the command-line derived config.
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args, GpatchTransformerConfig)

    # 'both' means two models (policy and reference) are expected by the config; otherwise one.
    model_cnt = 2 if args.dpo_model_using == "both" else 1
    assert config.dpo_policy_ref_model_cnt == model_cnt, \
        f"dpo_policy_ref_model_cnt {config.dpo_policy_ref_model_cnt} mismatch {model_cnt}"

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:
        # Resolve the transformer layer spec: explicit --spec wins, then MoE, then dense.
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if args.num_experts:
                transformer_layer_spec = get_gpt_decoder_block_spec(config,
                                                                    use_transformer_engine=use_te)
            else:
                if use_te:
                    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                        args.num_experts,
                        args.moe_grouped_gemm,
                        args.qk_layernorm,
                        args.multi_latent_attention,
                        None,
                        args.moe_use_legacy_grouped_gemm,
                    )
                else:
                    transformer_layer_spec = get_gpt_layer_local_spec(
                        args.num_experts,
                        args.moe_grouped_gemm,
                        args.qk_layernorm,
                        args.multi_latent_attention,
                        None,
                        args.moe_use_legacy_grouped_gemm,
                    )

        build_model_context = nullcontext
        build_model_context_args = {}
        if args.fp8_param_gather:
            # Guard only the import: a bare `except:` around the whole setup would also
            # swallow KeyboardInterrupt/SystemExit and mislabel unrelated failures.
            try:
                from transformer_engine.pytorch import fp8_model_init
            except ImportError as err:
                raise RuntimeError(
                    "--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found."
                ) from err

            build_model_context = fp8_model_init
            build_model_context_args["enabled"] = True

            # Pass the flag only when the installed fp8_model_init accepts it.
            if "preserve_high_precision_init_val" in inspect.signature(
                    fp8_model_init).parameters:
                build_model_context_args["preserve_high_precision_init_val"] = True

        # The mcore build must stay inside this branch: `build_model_context` and
        # `transformer_layer_spec` are only defined here, so running it on the legacy
        # path raised NameError instead of returning the legacy model.
        with build_model_context(**build_model_context_args):
            model = GptDpoModel(
                config=config,
                transformer_layer_spec=transformer_layer_spec,
                vocab_size=args.padded_vocab_size,
                max_sequence_length=args.max_position_embeddings,
                pre_process=pre_process,
                post_process=post_process,
                fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
                parallel_output=True,
                share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
                position_embedding_type=args.position_embedding_type,
                rotary_percent=args.rotary_percent,
                rotary_base=args.rotary_base,
                rope_scaling=args.use_rope_scaling,
                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor,
                beta=args.dpo_beta,
                label_smoothing=args.dpo_label_smoothing,
                ftx_gamma=args.dpo_ftx_gamma,
                forward_without_loss=True,
                model_using=args.dpo_model_using,
                pair_gamma=args.dpo_pair_gamma,
            )

    # Inference only: no parameter of the returned model should receive gradients.
    model.requires_grad_(False)
    print_rank_0(f"model arch {model}")
    return model


def default_actor_provider(model, ref_model_state):
    """Wrap ``model`` in a ``GptPpoActorModel`` configured from the global args.

    Args:
        model: Passed to ``GptPpoActorModel`` as ``model``.
        ref_model_state: Passed to ``GptPpoActorModel`` as ``ref_model_state``.

    Returns:
        The constructed ``GptPpoActorModel``.
    """
    args = get_args()

    actor_kwargs = {
        "model": model,
        "ref_model_state": ref_model_state,
        "unwrap_model_func": unwrap_model,
        # Log-prob forward pass and rollout settings.
        "forward_micro_batch_size": args.ppo_logps_fwd_micro_batch_size,
        "ppo_rollout_temperature": args.ppo_rollout_temperature,
        # Padding and micro-batching settings.
        "pad_to_multi_of": args.ppo_rollout_pad_to_multiple_of,
        "pad_token_id": get_tokenizer()._tokenizer.pad_token_id,
        "dynamic_mbs_target_seqlen": args.ppo_train_dynamic_mbs_target_seq,
        "dynamic_mbs_limit": args.ppo_train_dynamic_mbs_limit,
        "ppo_pack_seq": args.ppo_pack_seq,
    }
    return GptPpoActorModel(**actor_kwargs)


def default_sampler_client_provider():
    """Create a ``GptPpoSamplerClientV3`` configured from the global args.

    Returns:
        The constructed ``GptPpoSamplerClientV3``.
    """
    args = get_args()
    # The tokenizer is fetched but not passed to the client; the call is kept
    # so behavior is unchanged. NOTE(review): likely removable - confirm.
    get_tokenizer()

    client_kwargs = {
        "ep_ips": args.ppo_sampler_ips,
        "ep_ports": args.ppo_sampler_ports,
        "timeout": args.ppo_sampler_client_timeout,
        "update_timeout": args.ppo_sampler_client_update_timeout,
        "rpc_max_retries": args.grpo_rpc_max_retries,
        "unwrap_model_func": unwrap_model,
        "infer_engine_impl": args.infer_engine_impl,
        "update_weight_max_size_mb": args.update_weight_max_size_mb,
    }
    return GptPpoSamplerClientV3(**client_kwargs)


def default_rm_critic_client_provider():
    """Create a ``GptPpoRmCriticClientV3`` configured from the global args.

    The global tokenizer is passed to the client, together with the pad token
    id read from its underlying ``_tokenizer``.

    Returns:
        The constructed ``GptPpoRmCriticClientV3``.
    """
    args = get_args()
    tok = get_tokenizer()

    client_kwargs = {
        "pad_token_id": tok._tokenizer.pad_token_id,
        "ep_ips": args.ppo_critic_ips,
        "ep_ports": args.ppo_critic_ports,
        "combine_rm_and_critic_server": args.combine_rm_and_critic_server,
        "timeout": args.ppo_rm_critic_client_timeout,
        "tokenizer": tok,
        "rpc_max_retries": args.grpo_rpc_max_retries,
        "num_rm": args.ppo_num_rm,
        "ppo_debug_fake_rm_critic": args.ppo_debug_fake_rm_critic,
        "ppo_value_truncate_head": args.ppo_value_truncate_head,
    }
    return GptPpoRmCriticClientV3(**client_kwargs)


def default_gen_rm_client_provider():
    """Create a ``GptPpoGenRmClientV3`` configured from the global args.

    Asserts that ``args.use_gen_rm`` is set before building the client.

    Returns:
        The constructed ``GptPpoGenRmClientV3``.
    """
    args = get_args()
    assert args.use_gen_rm

    client_kwargs = {
        "ep_ips": args.ppo_gen_rm_ips,
        "ep_ports": args.ppo_gen_rm_ports,
        "timeout": args.ppo_gen_rm_client_timeout,
        "rpc_max_retries": args.grpo_rpc_max_retries,
        "unwrap_model_func": unwrap_model,
    }
    return GptPpoGenRmClientV3(**client_kwargs)
