# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field

import torch
from torch import nn

from swift.utils.logger import get_logger
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class NEFTuneConfig(SwiftConfig):
    """
    The configuration class for the NEFTune module.

    NEFTune adds slightly noises to embedding outputs.
    See https://arxiv.org/abs/2310.05914

    Args:
        noise_alpha(`float`): The noise alpha value used for the NEFTune, default 5.0
    """

    noise_alpha: float = field(
        default=5.0, metadata={"help": "The noise alpha value used for the NEFTune"}
    )

    def __post_init__(self):
        from .mapping import SwiftTuners

        self.swift_type = SwiftTuners.NEFTUNE


class NEFTune(SwiftAdapter):

    @staticmethod
    def prepare_model(
        model: nn.Module, config: NEFTuneConfig, adapter_name: str
    ) -> SwiftOutput:
        """Prepare a model with `NEFTuneConfig`"""
        for sub_module in model.modules():
            if isinstance(sub_module, torch.nn.Embedding):

                def neftune_hook(module, args, output):
                    if module.training and getattr(module, "nef_activated"):
                        dims = torch.tensor(output.size(-1) * output.size(-2))
                        mag_norm = config.noise_alpha / torch.sqrt(dims)
                        output = output + torch.zeros_like(output).uniform_(
                            -mag_norm, mag_norm
                        )
                    return output

                if hasattr(sub_module, "nef_activated"):
                    raise ValueError("NEFTune does not support a second tuner.")

                sub_module.register_forward_hook(neftune_hook)
                sub_module.nef_activated = True

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            return state_dict

        def mark_trainable_callback(model):
            return

        return SwiftOutput(
            config=config,
            state_dict_callback=state_dict_callback,
            mark_trainable_callback=mark_trainable_callback,
        )

    @staticmethod
    def activate_adapter(
        module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None
    ):
        for sub_module in module.modules():
            if isinstance(sub_module, torch.nn.Embedding):
                sub_module.nef_activated = activate

    @staticmethod
    def freeze_model():
        return False

    @staticmethod
    def has_additional_modules():
        return False
