              
                                                      
                                                                                         

import importlib
import sys
import types

from packaging import version

from megatron.core import __version__

from mpatch.core.transformer.moe.moe_utils import topk_softmax_with_capacity
from mpatch.core.transformer.moe.router import aux_loss_load_balancing
from mpatch.core.models.common.embeddings.language_model_embedding import language_model_embedding_forward
from mpatch.core.models.multimodal.llava_model import llava_model_init
from mpatch.core.models.vision.clip_vit_model import clip_vit_forward, clip_vit_init
from mpatch.core.models.vision.multimodal_projector import multimodal_projector_init
from mpatch.core.pipeline_parallel.schedules import get_tensor_shapes
from mpatch.training.checkpointing import save_checkpoint
from mpatch.training.arguments import parse_args_wrapper
from mpatch.core.tensor_parallel.random import checkpoint
from mpatch.training.initialize import initialize_megatron_wrapper, _compile_dependencies


def get_func_name(func):
    """Return the dotted name of ``func``.

    A string argument is assumed to already be a dotted name and is returned
    unchanged. Any other object is described as ``<__module__>.<__qualname__>``.
    """
    if not isinstance(func, str):
        name_parts = [func.__module__, func.__qualname__]
        return '.'.join(name_parts)
    return func


def dummy_function_wrapper(func_name):
    """Build a placeholder callable standing in for ``func_name``.

    The returned function accepts any arguments and always raises
    ``RuntimeError`` naming ``func_name``, so calling a function that was never
    really provided fails loudly.
    """

    def dummy_function(*args, **kwargs):
        # The message is built at call time, exactly when the error is raised.
        message = 'function {} no exist'.format(func_name)
        raise RuntimeError(message)

    return dummy_function


class Patch:
    """Replace, add, or wrap one attribute addressed by a dotted path.

    Usage notes (registration goes through ``register_patch``):

    If the new code belongs in the middle of the original code:
      1. To replace, implement a new function and register it as a
         replacement, with dummy (``create_dummy``) set to False.
      2. To add, likewise implement a new function and register it as an
         addition, with dummy (``create_dummy``) set to True.
    If the new code belongs at the beginning or end of the original code:
      implement a function whose name ends with ``wrapper`` and whose
      argument is the function ``fn``, then register it with dummy
      (``create_dummy``) set to False.
    """

    def __init__(self, orig_func_name, new_func, create_dummy):
        """Record the patch target and its first replacement or wrapper.

        Args:
            orig_func_name: Dotted path of the target. Everything before the
                last dot is the container path (module or class); the last
                component is the attribute name. A path without a dot is
                treated as a module path with no attribute.
            new_func: Replacement callable, or a wrapper (name ending in
                ``wrapper``/``decorator``). ``None`` installs a placeholder
                that raises when called.
            create_dummy: Whether missing modules/attributes may be created
                while resolving the path.
        """
        split_name = orig_func_name.rsplit('.', 1)
        if len(split_name) == 1:
            self.orig_module_name, self.orig_func_name = orig_func_name, None
        else:
            self.orig_module_name, self.orig_func_name = split_name
        self.orig_module = None
        self.orig_func = None
        # The object apply_patch() installed most recently; lets a later
        # re-application recognise its own earlier result.
        self._applied_func = None

        self.patch_func = None
        self.wrappers = []
        if new_func is None:
            new_func = dummy_function_wrapper(orig_func_name)
        self.set_patch_func(new_func)
        self.is_applied = False
        self.create_dummy = create_dummy

    @property
    def orig_func_id(self):
        """``id()`` of the original attribute captured by ``apply_patch``."""
        return id(self.orig_func)

    @property
    def patch_func_id(self):
        """``id()`` of the registered replacement function (may be ``None``)."""
        return id(self.patch_func)

    def set_patch_func(self, new_func, force_patch=False):
        """Register ``new_func`` as a wrapper or as the replacement.

        A callable whose ``__name__`` ends with ``wrapper`` or ``decorator`` is
        appended to the wrapper list. Anything else becomes the replacement.

        Raises:
            RuntimeError: A replacement is already registered and
                ``force_patch`` is false.
        """
        if hasattr(new_func, '__name__') and new_func.__name__.endswith(('wrapper', 'decorator')):
            self.wrappers.append(new_func)
        else:
            if self.patch_func and not force_patch:
                raise RuntimeError('the patch of {} exist !'.format(self.orig_func_name))
            self.patch_func = new_func
        # Any change invalidates what was installed; allow apply_patch to rerun.
        self.is_applied = False

    def apply_patch(self):
        """Install the patch on its container and on modules that copied it.

        The installed callable is the replacement (or the original when only
        wrappers are registered) passed through every wrapper in registration
        order. Every module in ``sys.modules`` holding a reference to the
        replaced object under the same attribute name is updated as well.
        """
        if self.is_applied:
            return

        self.orig_module, current_func = Patch.parse_path(
            self.orig_module_name, self.orig_func_name, self.create_dummy)

        # If the attribute currently holds what this patch installed earlier,
        # this is a re-application: keep the true original so wrappers are not
        # stacked on top of an already-wrapped function.
        previous_func = self._applied_func
        reapplying = previous_func is not None and current_func is previous_func
        if not reapplying:
            self.orig_func = current_func

        final_patch_func = self.orig_func
        if self.patch_func is not None:
            final_patch_func = self.patch_func

        for wrapper in self.wrappers:
            final_patch_func = wrapper(final_patch_func)

        if self.orig_func_name is not None:
            setattr(self.orig_module, self.orig_func_name, final_patch_func)

            # Objects other modules may still reference under the same name.
            # None is excluded: a missing original must not match unrelated
            # attributes that merely happen to be None.
            stale_funcs = [self.orig_func]
            if reapplying:
                stale_funcs.append(previous_func)
            stale_funcs = [func for func in stale_funcs if func is not None]

            if stale_funcs:
                missing = object()
                # Snapshot: attribute access may import further modules.
                for module in list(sys.modules.values()):
                    held = getattr(module, self.orig_func_name, missing)
                    if any(held is stale for stale in stale_funcs):
                        setattr(module, self.orig_func_name, final_patch_func)

        self._applied_func = final_patch_func
        self.is_applied = True

    @staticmethod
    def parse_path(module_path, function_name, create_dummy):
        """Resolve ``module_path`` and look up ``function_name`` on it.

        Each dotted prefix of ``module_path`` is imported in turn. When a
        prefix cannot be imported:

        * if it has no parent, or the parent module has no attribute with that
          name, an empty placeholder module is created and registered (only
          when ``create_dummy`` is true; otherwise the error is re-raised);
        * otherwise the parent's attribute (e.g. a class) is used as the
          container and the lookup finishes immediately.

        Returns:
            ``(container, attribute)``. ``attribute`` is ``None`` when
            ``function_name`` is ``None`` or when the attribute had to be
            created on an importable module.

        Raises:
            ModuleNotFoundError: A module is missing and ``create_dummy`` is
                false.
            RuntimeError: The container is a non-module attribute lacking
                ``function_name`` and ``create_dummy`` is false.
        """
        from importlib.machinery import ModuleSpec
        modules = module_path.split('.')
        for i in range(1, len(modules) + 1):
            parent = '.'.join(modules[:i - 1])
            path = '.'.join(modules[:i])
            try:
                importlib.import_module(path)
            except ModuleNotFoundError as e:
                if not parent or not hasattr(importlib.import_module(parent), modules[i - 1]):
                    if not create_dummy:
                        raise ModuleNotFoundError(e) from e
                    # Fabricate an empty module so deeper components and the
                    # final attribute have somewhere to live.
                    sys.modules[path] = types.ModuleType(path)
                    sys.modules[path].__file__ = 'gpatch.dummy_module.py'
                    sys.modules[path].__spec__ = ModuleSpec(path, None)
                    if parent:
                        setattr(importlib.import_module(parent), modules[i - 1], sys.modules[path])
                else:
                    # Not a module but an attribute of the parent module;
                    # use it as the container.
                    module = getattr(importlib.import_module(parent), modules[i - 1])
                    if hasattr(module, function_name):
                        return module, getattr(module, function_name)
                    elif create_dummy:
                        return module, dummy_function_wrapper(function_name)
                    else:
                        raise RuntimeError('no exist {} of {}'.format(function_name, module))

        # The whole path is importable. A missing attribute is created as None
        # so the getattr below (and the later setattr) always succeeds.
        if function_name is not None and not hasattr(sys.modules[module_path], function_name):
            setattr(sys.modules[module_path], function_name, None)
        return sys.modules[module_path], getattr(
            sys.modules[module_path], function_name) if function_name is not None else None


class GCorePatchMgr:
    """Registry of ``Patch`` objects keyed by the dotted path they target.

    ``patches_info`` is a class attribute, so all instances and all static
    calls share the same registry.
    """
    patches_info = {}

    @staticmethod
    def register_patch(orig_func_name, new_func=None, force_patch=False, create_dummy=False):
        """Record ``new_func`` as a patch for ``orig_func_name``.

        The first registration for a path creates a ``Patch``. Later
        registrations for the same path are forwarded to that patch's
        ``set_patch_func`` together with ``force_patch`` (``create_dummy`` is
        only used when the ``Patch`` is first created).
        """
        registry = GCorePatchMgr.patches_info
        if orig_func_name in registry:
            registry[orig_func_name].set_patch_func(new_func, force_patch)
            return
        registry[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)

    @staticmethod
    def apply_patches():
        """Call ``apply_patch`` on every registered patch, in registration order."""
        for pending in GCorePatchMgr.patches_info.values():
            pending.apply_patch()


def patch_mcore_func(mgr: GCorePatchMgr):
    """Register every Megatron override with ``mgr``.

    Each entry is recorded through ``mgr.register_patch`` as a pair of
    (dotted Megatron target, ``mpatch`` implementation). A fixed set is
    registered for every version; further entries depend on whether the
    installed Megatron-Core version is below 0.13.0 and below 0.12.0.
    Version-specific implementations are imported lazily inside the branch
    that needs them.
    """
    register = mgr.register_patch

    # Registered regardless of the Megatron-Core version.
    version_independent = (
        ("megatron.core.transformer.moe.router.TopKRouter.aux_loss_load_balancing",
         aux_loss_load_balancing),
        ("megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity",
         topk_softmax_with_capacity),
        ("megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward",
         language_model_embedding_forward),
        ("megatron.core.models.multimodal.llava_model.LLaVAModel.__init__",
         llava_model_init),
        ("megatron.core.models.vision.clip_vit_model.CLIPViTModel.forward",
         clip_vit_forward),
        ("megatron.core.models.vision.clip_vit_model.CLIPViTModel.__init__",
         clip_vit_init),
        ("megatron.core.models.vision.multimodal_projector.MultimodalProjector.__init__",
         multimodal_projector_init),
        ("megatron.core.pipeline_parallel.schedules.get_tensor_shapes",
         get_tensor_shapes),
        ("megatron.training.checkpointing.save_checkpoint",
         save_checkpoint),
        ("megatron.training.arguments.parse_args",
         parse_args_wrapper),
        ("megatron.core.tensor_parallel.random.checkpoint",
         checkpoint),
        ("megatron.training.initialize._compile_dependencies",
         _compile_dependencies),
        ("megatron.training.initialize.initialize_megatron",
         initialize_megatron_wrapper),
    )
    for target, implementation in version_independent:
        register(target, implementation)

    print(f"MCore version {version.parse(__version__)} {version.parse(__version__)<version.parse('0.13.0')}")
    mcore_version = version.parse(__version__)

    if not mcore_version < version.parse('0.13.0'):
        # Megatron-Core 0.13.0 and newer.
        from mpatch.core.parallel_state import initialize_model_parallel_wrapper
        from mpatch.training.checkpointing import load_checkpoint
        from mpatch.training.training import train_step

        current_overrides = (
            ("megatron.training.checkpointing.load_checkpoint",
             load_checkpoint),
            ("megatron.core.parallel_state.initialize_model_parallel",
             initialize_model_parallel_wrapper),
            ("megatron.training.training.train_step",
             train_step),
        )
        for target, implementation in current_overrides:
            register(target, implementation)
        return

    # Megatron-Core older than 0.13.0.
    from mpatch.core.parallel_state import initialize_model_parallel
    from mpatch.training.checkpointing import load_checkpoint_lt_0_13
    from mpatch.core.distributed.param_and_grad_buffer import param_and_grad_buffer_init_lt_0_13
    from mpatch.training.training import train_step_lt_0_13

    pre_0_13_overrides = (
        ("megatron.training.checkpointing.load_checkpoint",
         load_checkpoint_lt_0_13),
        ("megatron.core.parallel_state.initialize_model_parallel",
         initialize_model_parallel),
        ("megatron.core.distributed.param_and_grad_buffer._ParamAndGradBuffer.__init__",
         param_and_grad_buffer_init_lt_0_13),
        ("megatron.training.training.train_step",
         train_step_lt_0_13),
    )
    for target, implementation in pre_0_13_overrides:
        register(target, implementation)

    if not mcore_version < version.parse('0.12.0'):
        return

    # Megatron-Core older than 0.12.0 needs these in addition.
    from mpatch.core.transformer.multi_latent_attention import (
        multi_latent_attention_forward_lt_0_12,
        multi_latent_attention_init_lt_0_12
    )
    from mpatch.core.distributed.finalize_model_grads import _allreduce_word_embedding_grads, _allreduce_layernorm_grads
    from mpatch.training.wandb_utils import on_save_checkpoint_success

    pre_0_12_overrides = (
        ("megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads",
         _allreduce_word_embedding_grads),
        ("megatron.core.distributed.finalize_model_grads._allreduce_layernorm_grads",
         _allreduce_layernorm_grads),
        ("megatron.training.wandb_utils.on_save_checkpoint_success",
         on_save_checkpoint_success),
        ("megatron.core.transformer.multi_latent_attention.MultiLatentAttention.forward",
         multi_latent_attention_forward_lt_0_12),
        ("megatron.core.transformer.multi_latent_attention.MultiLatentAttention.__init__",
         multi_latent_attention_init_lt_0_12),
    )
    for target, implementation in pre_0_12_overrides:
        register(target, implementation)


def init_gpatch_for_mcore():
    """Register and apply all Megatron patches.

    Intended to be called at the very beginning of the programme. Prints the
    Megatron version, registers every patch, applies them, and prints a
    completion message.
    """
    print(f"Megatron version: {__version__}")
    patch_manager = GCorePatchMgr()
    patch_mcore_func(patch_manager)
    patch_manager.apply_patches()
    print("patch_mcore.py: patching mcore ends.")
