# BSD 2-CLAUSE LICENSE
# Copyright 2024 LinkedIn Corporation
# All Rights Reserved.
# Redistribution and use in source and binary forms, with or
# without modification, are permitted provided that the following
# conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Optional

from transformers import PretrainedConfig, PreTrainedModel

from src.common.logging import get_project_logger
from src.modeling.liger_kernels.cross_entropy import LigerCrossEntropyLoss
from src.modeling.liger_kernels.geglu import LigerGEGLUMLP
from src.modeling.liger_kernels.rope import liger_rotary_pos_emb

# Project-wide logger used by the functions in this module.
logger = get_project_logger()


def apply_liger_kernel_to_gemma2(
    rope: bool = True,
    cross_entropy: bool = True,
    geglu: bool = True,
    model: Optional[PreTrainedModel] = None,
) -> None:
    """Patch Hugging Face Gemma2 to use Liger kernels.

    Attributes of ``transformers.models.gemma2.modeling_gemma2`` are replaced
    so that code looking these names up afterwards gets the Liger versions.
    Module-level replacement cannot reach modules that were already
    constructed, so when ``model`` is given its decoder-layer MLPs are also
    patched in place.

    Args:
        rope: Replace ``modeling_gemma2.apply_rotary_pos_emb`` with
            ``liger_rotary_pos_emb``.
        cross_entropy: Replace ``modeling_gemma2.CrossEntropyLoss`` with
            ``LigerCrossEntropyLoss``.
        geglu: Replace ``modeling_gemma2.Gemma2MLP`` with ``LigerGEGLUMLP``
            and, when ``model`` is given, switch every existing
            ``decoder_layer.mlp`` to the Liger GEGLU forward while keeping its
            current weights, device and dtype.
        model: Optional already-instantiated model to patch in place. Either
            an object exposing the decoder stack as ``model.model`` or the
            decoder model itself; in both cases it must provide ``.layers``.
    """
    import types

    logger.info('Loading Liger-Kernels for Gemma2...')

    from transformers.models.gemma2 import modeling_gemma2

    if rope:
        modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if cross_entropy:
        # NOTE(review): only effective if the installed transformers version's
        # modeling_gemma2 still resolves CrossEntropyLoss from module globals.
        modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
    if geglu:
        # Affects only Gemma2MLP instances constructed after this point.
        modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

    if model is not None and geglu:
        # Wrapper models keep the decoder stack under `.model`; otherwise the
        # object passed in is assumed to be the decoder itself.
        base_model = getattr(model, 'model', model)

        for decoder_layer in base_model.layers:
            # Rebind `forward` on the existing MLP rather than swapping in a
            # fresh LigerGEGLUMLP(config): a new module would be randomly
            # initialised (discarding the loaded weights) and allocated on the
            # default device instead of the model's.
            # NOTE(review): assumes LigerGEGLUMLP.forward only uses submodules
            # Gemma2MLP also defines (gate_proj / up_proj / down_proj) —
            # confirm against src.modeling.liger_kernels.geglu.
            mlp = decoder_layer.mlp
            mlp.forward = types.MethodType(LigerGEGLUMLP.forward, mlp)

    logger.info('Liger-Kernels have been successfully applied!')
