# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field

import torch
from torch import nn

from swift.utils.logger import get_logger
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput

# Module-level logger obtained from `swift.utils.logger.get_logger`.
logger = get_logger()


@dataclass
class NEFTuneConfig(SwiftConfig):
    """
    Configuration for the NEFTune tuner.

    NEFTune adds a small amount of noise to the outputs of embedding layers.
    See https://arxiv.org/abs/2310.05914

    Args:
        noise_alpha(`float`): The noise alpha value used by NEFTune, default 5.0
    """
    noise_alpha: float = field(
        default=5.0,
        metadata=dict(help='The noise alpha value used for the NEFTune'),
    )

    def __post_init__(self):
        # The tuner registry is imported at call time rather than at module import
        # (presumably to avoid a circular import with `.mapping` -- confirm).
        from .mapping import SwiftTuners as tuner_types
        self.swift_type = tuner_types.NEFTUNE


class NEFTune(SwiftAdapter):
    """Tuner that injects NEFTune noise into the outputs of embedding layers.

    A forward hook is registered on every `torch.nn.Embedding` found in the model.
    While the embedding is in training mode and its `nef_activated` attribute is
    truthy, the hook adds uniform noise drawn from `[-mag_norm, mag_norm]` to the
    output, with `mag_norm = noise_alpha / sqrt(output.size(-2) * output.size(-1))`.
    See https://arxiv.org/abs/2310.05914
    """

    @staticmethod
    def prepare_model(model: nn.Module, config: NEFTuneConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `NEFTuneConfig`

        Args:
            model(`nn.Module`): The model whose `torch.nn.Embedding` sub-modules get the noise hook.
            config(`NEFTuneConfig`): Supplies `noise_alpha`; it is read on every hooked forward call.
            adapter_name(`str`): Not used by this tuner.

        Returns:
            A `SwiftOutput` carrying the config, a state-dict callback that returns the
            state dict unchanged, and a mark-trainable callback that does nothing.

        Raises:
            ValueError: If any embedding already has a `nef_activated` attribute, i.e. the
                model has been prepared (or activated) for NEFTune before.
        """

        # Defined once: the hook does not depend on which embedding it is attached to.
        def neftune_hook(module, args, output):
            # A missing flag is treated as "not activated" instead of raising AttributeError.
            if not (module.training and getattr(module, 'nef_activated', False)):
                return output
            if output.numel() == 0:
                # Nothing to perturb; this also avoids dividing by sqrt(0), which would
                # produce infinite bounds that `uniform_` rejects.
                return output
            # Embedding a 0-dim index tensor yields a 1-D output with no sequence axis:
            # treat it as a single token instead of failing on `size(-2)`.
            seq_len = output.size(-2) if output.dim() > 1 else 1
            dims = torch.tensor(output.size(-1) * seq_len)
            mag_norm = config.noise_alpha / torch.sqrt(dims)
            return output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)

        embeddings = [sub_module for sub_module in model.modules() if isinstance(sub_module, torch.nn.Embedding)]

        # Validate before touching anything so a failure never leaves the model half-hooked.
        if any(hasattr(embedding, 'nef_activated') for embedding in embeddings):
            raise ValueError('NEFTune does not support a second tuner.')

        for embedding in embeddings:
            embedding.register_forward_hook(neftune_hook)
            embedding.nef_activated = True

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # Returned as-is: this tuner does not filter or rewrite the state dict.
            return state_dict

        def mark_trainable_callback(model):
            # Intentionally a no-op: this tuner does not change which parameters are trainable.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """Switch the NEFTune noise on or off for every embedding under `module`.

        Args:
            module(`torch.nn.Module`): Root module whose `torch.nn.Embedding` sub-modules are updated.
            adapter_name(`str`): Not used by this tuner.
            activate(`bool`): Value stored in each embedding's `nef_activated` attribute.
            offload(`str`): Not used by this tuner.
        """
        for sub_module in module.modules():
            if isinstance(sub_module, torch.nn.Embedding):
                sub_module.nef_activated = activate

    @staticmethod
    def freeze_model():
        """Always returns False."""
        return False

    @staticmethod
    def has_additional_modules():
        """Always returns False."""
        return False
