# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import json
import os
import sys
import torch
import types
from functools import partial

from schema_core import get_model_schema
from loader_base import MegatronCheckpointLoaderBase


def add_arguments(parser):
    """Register the Megatron loader's command-line options on ``parser``.

    Every option is added to a dedicated argument group titled
    'Megatron loader'. The options are:

    * ``--true-vocab-size`` (int, default ``None``)
    * ``--vocab-file`` (str, default ``None``)
    * ``--megatron-path`` (str, default ``None``)
    * ``--position-embedding-type`` ('learned_absolute' or 'rope',
      default 'learned_absolute')
    * ``--loader-transformer-impl`` ('local' or 'transformer_engine',
      default 'transformer_engine')

    Args:
        parser: The argparse parser (or compatible object) to extend.
    """
    loader_group = parser.add_argument_group(title='Megatron loader')

    # (flag, add_argument keyword options), listed in registration order so
    # the generated --help output keeps the same ordering.
    option_table = (
        ('--true-vocab-size', {
            'type': int,
            'default': None,
            'help': 'Original size of vocab; if specified, trims padding from embedding table.',
        }),
        ('--vocab-file', {
            'type': str,
            'default': None,
            'help': 'Path to a vocab file. If specified, determines vocab size to trim padding.',
        }),
        ('--megatron-path', {
            'type': str,
            'default': None,
            'help': 'Base directory of Megatron repository',
        }),
        ('--position-embedding-type', {
            'type': str,
            'default': 'learned_absolute',
            'choices': ['learned_absolute', 'rope'],
            'help': 'Type of position embedding.',
        }),
        ('--loader-transformer-impl', {
            'default': 'transformer_engine',
            'choices': ['local', 'transformer_engine'],
            'help': 'Which Transformer implementation to use.',
        }),
    )
    for flag, options in option_table:
        loader_group.add_argument(flag, **options)


class MegatronCheckpointLoaderLLM(MegatronCheckpointLoaderBase):
    """
    Orchestrates loading a Megatron checkpoint and sending
    model parameters over a given multiprocessing queue.

    Args:
        args: argparse Namespace with Megatron checkpoint configurations.
        queue: A multiprocessing.Queue (or similar) used to send out loaded tensors.
    """

    def build_sys_argv(self):
        """
        Construct a sys.argv list for Megatron's argument parser.

        Extends the list produced by the base class with the
        ``--position-embedding-type`` option taken from ``self.args``.

        Returns:
            list: The base-class argv entries followed by the
            position-embedding flag and its value.
        """
        return [
            *super().build_sys_argv(),
            '--position-embedding-type', self.args.position_embedding_type,
        ]

    def import_model_provider(self):
        """
        Return the correct model_provider function depending on GPT vs. BERT.

        For GPT the generic ``model_provider`` is bound to ``gpt_builder``
        with ``functools.partial``; the bound callable is both stored on
        ``self.model_provider`` and returned. For BERT the provider from
        ``pretrain_bert`` is returned as-is.

        Returns:
            callable: The model provider for ``self.args.model_type``.

        Raises:
            ValueError: If ``self.args.model_type`` is neither 'GPT' nor 'BERT'.
        """
        model_type = self.args.model_type
        if model_type == 'GPT':
            # Imported lazily: these modules live in the Megatron repository.
            from model_provider import model_provider
            from gpt_builders import gpt_builder
            # Return the builder-bound provider, not the raw function: the raw
            # model_provider still expects the builder as an argument, so
            # handing it back would discard the binding created here.
            bound_provider = partial(model_provider, gpt_builder)
            self.model_provider = bound_provider
            return bound_provider
        if model_type == 'BERT':
            from pretrain_bert import model_provider
            return model_provider
        raise ValueError(f"Unrecognized model type: {self.args.model_type}")

    def send_model_over_queue(self):
        """
        Send metadata, then the model itself, then a "done" marker.

        The schema is looked up from the model type recorded in ``self.md``
        and the transformer implementation / expert settings in
        ``self.margs``. NOTE(review): ``self.md``, ``self.margs`` and
        ``self.queue`` are not set in this class; presumably the base class
        populates them before this method runs - confirm.
        """
        self.send_metadata_over_queue()
        # Model schema.
        schema = get_model_schema(
            self.md.model_type,
            self.margs.transformer_impl,
            self.margs.num_experts,
            self.margs.expert_model_parallel_size,
        )
        self.send_llm_over_queue(schema)
        # Sentinel telling the consumer that no more messages follow.
        self.queue.put("done")


def load_checkpoint(queue, args):
    """
    Required top-level function that creates the loader,
    calls its .load(), and handles exceptions by signaling 'exit'.

    Args:
        queue: Queue-like object; receives the string "exit" if loading fails.
        args: argparse Namespace forwarded to the loader.

    Raises:
        Exception: Any exception raised while constructing or running the
            loader is re-raised after "exit" has been put on the queue.
    """
    try:
        # Construct inside the try so that a failure in the constructor also
        # signals "exit" instead of leaving the queue without any message.
        loader = MegatronCheckpointLoaderLLM(args, queue)
        loader.load()
    except Exception:
        queue.put("exit")
        # Bare raise re-raises the active exception with its original traceback.
        raise
