import torch
from megatron.core import parallel_state as mpu
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.training import initialize_megatron, get_args
from megatron.training.global_vars import set_global_variables
from megatron.training.arguments import parse_args

from mbridge import AutoBridge
from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model


def init_distributed(tp=2, pp=1, cp=1, vpp=1, ep=1, etp=None):
    """Start the NCCL process group and build Megatron's model-parallel groups.

    The current CUDA device is pinned to the node-local rank before the
    parallel groups are created, and the model-parallel CUDA RNG is seeded
    with 0 afterwards.

    Args:
        tp: Tensor model parallel size.
        pp: Pipeline model parallel size.
        cp: Context parallel size.
        vpp: Virtual pipeline parallel size. Only forwarded when ``pp > 1``;
            otherwise ``None`` is passed instead.
        ep: Expert model parallel size.
        etp: Expert tensor parallel size, forwarded as-is.
    """
    torch.distributed.init_process_group("nccl")
    global_rank = torch.distributed.get_rank()
    local_rank = torch.distributed.get_node_local_rank()
    print(f"[Init] global_rank {global_rank} local_rank {local_rank}")
    torch.cuda.set_device(local_rank)

    # Virtual pipeline stages are only meaningful with more than one pipeline stage.
    virtual_pp = vpp if pp > 1 else None
    mpu.initialize_model_parallel(
        tensor_model_parallel_size=tp,
        pipeline_model_parallel_size=pp,
        virtual_pipeline_model_parallel_size=virtual_pp,
        context_parallel_size=cp,
        expert_model_parallel_size=ep,
        expert_tensor_parallel_size=etp,
    )
    model_parallel_cuda_manual_seed(0)


def add_ckpt_args(parser):
    """Register this script's conversion options on an argparse parser.

    Intended to be used as Megatron's ``extra_args_provider``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or argument group) to extend.

    Returns:
        The same ``parser`` object, after the options have been added.
    """
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the HuggingFace model directory",
    )

    # (flag, default, help) for every integer option, in registration order.
    integer_options = (
        ("--num_layers_in_first_pipeline_stage", 14,
         "Number of layers in the first pipeline stage"),
        ("--num_layers_in_last_pipeline_stage", 15,
         "Number of layers in the last pipeline stage"),
        ("--tp", 8, "Tensor parallel size"),
        ("--pp", 4, "Pipeline parallel size"),
        ("--cp", 1, "Context parallel size"),
        ("--vpp", 1, "Virtual pipeline parallel size"),
        ("--ep", 1, "Expert parallel size"),
        ("--etp", None, "Expert tensor parallel size"),
    )
    for flag, default_value, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)
    return parser


def _maintain_router_bias_dtype(layers, verbose=False):
    """Call ``_maintain_float32_expert_bias()`` on every MoE router in ``layers``.

    Layers whose ``mlp`` has no ``router``, or whose router lacks
    ``_maintain_float32_expert_bias``, are skipped.

    Args:
        layers: Iterable of transformer layers (e.g. ``decoder.layers``).
        verbose: When True, print each router before fixing it up.
    """
    for layer in layers:
        if not (
            hasattr(layer, "mlp")
            and hasattr(layer.mlp, "router")
            and hasattr(layer.mlp.router, "_maintain_float32_expert_bias")
        ):
            continue
        router = layer.mlp.router
        if verbose:
            print(f"maintain router bias dtype for {router}")
        router._maintain_float32_expert_bias()


def _verify_exported_weights(bridge, model, rank):
    """Compare each exported weight against the original HF checkpoint.

    Every rank iterates ``bridge.export_weights(model)``.
    NOTE(review): non-zero ranks are kept in the loop, presumably because the
    export needs all model-parallel ranks to participate. Confirm this before
    short-circuiting them.
    Only rank 0 loads the reference tensor, compares it, and prints the result.
    ``lm_head.weight`` is not compared. If its first dimension is 1, the model
    is reported as a value model instead.

    Raises:
        AssertionError: on rank 0, if a weight's shape or element sum differs
            from the reference tensor.
    """
    for k, v in bridge.export_weights(model):
        if rank != 0:
            continue
        gt = bridge.safetensor_io.load_one_hf_weight(k).to(v.device)
        if k == "lm_head.weight":
            if v.shape[0] == 1:
                print(f"this is a value model, {k} {v.shape=} {gt.shape=}")
        else:
            # Explicit raises (not ``assert``) so the check still runs under ``python -O``.
            if v.shape != gt.shape:
                raise AssertionError(f"mismatch of {k} {v.shape=} {gt.shape=}")
            # Exact float equality of the sums is deliberate: the round trip
            # is expected to reproduce the weights bit-for-bit.
            if v.sum().item() != gt.sum().item():
                raise AssertionError(f"mismatch of {k} {v=} {gt=}")
        print(k, "export ok")


def main():
    """Load an HF checkpoint into Megatron, save it, and verify the re-export.

    Steps:
        1. Parse Megatron args plus this script's options (``add_ckpt_args``).
        2. Initialize the process group and the model-parallel groups.
        3. Build the Megatron model with mbridge and load the HF weights.
        4. Save a Megatron checkpoint via ``save_checkpoint``.
        5. Export the weights back to HF naming and compare them on rank 0.

    Returns:
        A ``"rank <n> done"`` status string.
    """
    from megatron.training.checkpointing import save_checkpoint
    from mbridge.core.util import unwrap_model

    args = parse_args(extra_args_provider=add_ckpt_args)

    # Conversion runs a single replica with the smallest possible batch.
    args.data_parallel_size = 1
    args.micro_batch_size = 1
    args.global_batch_size = 1
    args.use_dist_ckpt = True

    set_global_variables(args=args, build_tokenizer=False)
    init_distributed(
        tp=args.tp,
        pp=args.pp,
        cp=args.cp,
        vpp=args.vpp,
        ep=args.ep,
        etp=args.etp,  # previously dropped, so --etp had no effect
    )
    # Capture the rank now: it cannot be queried after destroy_process_group().
    rank = torch.distributed.get_rank()

    bridge = AutoBridge.from_pretrained(args.model_path)
    bridge.set_extra_args(
        num_layers_in_first_pipeline_stage=args.num_layers_in_first_pipeline_stage,
        num_layers_in_last_pipeline_stage=args.num_layers_in_last_pipeline_stage,
    )
    model = bridge.get_model(post_model_creation_callbacks=[], wrap_with_ddp=False)

    # NOTE(review): done before load_weights, presumably so the router expert
    # bias is held in float32 when the checkpoint values are copied in — confirm.
    for chunk in model:
        chunk = unwrap_model(chunk)
        if hasattr(chunk, "decoder"):
            _maintain_router_bias_dtype(chunk.decoder.layers)
        if hasattr(chunk, "mtp"):
            _maintain_router_bias_dtype(chunk.mtp.layers, verbose=True)

    bridge.load_weights(model, args.model_path, memory_efficient=True)
    print(f"[rank {rank}] Model loaded, proceeding with post-processing ...")

    # Hard-coded toggle. NOTE(review): the output location comes from Megatron's
    # global args (e.g. --save), not from this script — confirm it is set.
    save_distributed_checkpoint = True
    if save_distributed_checkpoint:
        save_checkpoint(
            iteration=1,
            model=model,
            optimizer=None,
            opt_param_scheduler=None,
            num_floating_point_operations_so_far=0,
        )

    _verify_exported_weights(bridge, model, rank)

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
    return f"rank {rank} done"


if __name__ == "__main__":
    # Script entry point. main()'s return value is discarded.
    main()