

import importlib
from packaging.version import Version

from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD

from megatron.core.optimizer import OptimizerConfig

from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native


def get_megatron_optimizer(
    model,
    config: OptimizerConfig,
    no_weight_decay_cond=None,
    scale_lr_cond=None,
    lr_mult=1.0,
    check_for_nan_in_loss_and_grad=False,
    overlap_param_gather=False,  # add for verl
):
    """Build an optimizer by delegating to the native Megatron-Core factory.

    This is a thin adapter around ``megatron.core.optimizer.get_megatron_optimizer``
    (imported here as ``get_megatron_optimizer_native``).

    Args:
        model: Forwarded to the native factory as its ``model_chunks`` argument.
        config: Optimizer configuration, forwarded unchanged.
        no_weight_decay_cond: Forwarded unchanged.
        scale_lr_cond: Forwarded unchanged.
        lr_mult: Forwarded unchanged.
        check_for_nan_in_loss_and_grad: Accepted but NOT forwarded; it has no
            effect inside this function.
        overlap_param_gather: Accepted but NOT forwarded; it has no effect
            inside this function.

    Returns:
        Whatever the native factory returns.
    """
    # Only these arguments are handed to the native factory; note the rename
    # of ``model`` to ``model_chunks``.
    forwarded = {
        "config": config,
        "model_chunks": model,
        "no_weight_decay_cond": no_weight_decay_cond,
        "scale_lr_cond": scale_lr_cond,
        "lr_mult": lr_mult,
    }
    return get_megatron_optimizer_native(**forwarded)
