# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from typing import TYPE_CHECKING, Optional

import torch
from peft import IA3Config, PeftModel, get_peft_model

from swift.llm import MODEL_ARCH_MAPPING, ModelKeys
from swift.utils import find_all_linears

if TYPE_CHECKING:
    from swift.llm import TrainArguments


class Tuner:
    """Interface for a custom tuner selectable with ``--train_type``.

    A tuner supplies three static hooks: wrap a freshly loaded model, save the
    wrapped model at checkpoint time, and restore it from a checkpoint. On this
    base class every hook raises ``NotImplementedError``; subclasses override
    the ones they need.
    """

    @staticmethod
    def prepare_model(args: "TrainArguments", model: torch.nn.Module) -> torch.nn.Module:
        """Wrap ``model`` with this tuner.

        Args:
            args: Training arguments for the current run.
            model: The model instance to wrap.

        Returns:
            The wrapped model.
        """
        raise NotImplementedError()

    @staticmethod
    def save_pretrained(
        model: torch.nn.Module,
        save_directory: str,
        state_dict: Optional[dict] = None,
        safe_serialization: bool = True,
        **kwargs,
    ) -> None:
        """Save the wrapped model when a checkpoint is written (every ``save_steps``).

        Args:
            model: The model returned by ``prepare_model``.
            save_directory: Destination directory for the checkpoint.
            state_dict: Optional state dict supplied by the caller; how it is
                used is up to the implementation.
            safe_serialization: Whether to write safetensors files.
            **kwargs: Extra implementation-specific save options.
        """
        raise NotImplementedError()

    @staticmethod
    def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
        """Restore a tuner checkpoint onto ``model``.

        Args:
            model: The original (unwrapped) model instance.
            model_id: Model id or checkpoint directory to load from.
            **kwargs: Extra implementation-specific load options.

        Returns:
            The wrapped model instance.
        """
        raise NotImplementedError()


class PeftTuner(Tuner):
    """Tuner whose save/load hooks delegate to the model's PEFT-style API.

    ``save_pretrained`` calls the model's own ``save_pretrained`` and
    ``from_pretrained`` calls ``PeftModel.from_pretrained``; subclasses only
    need to implement ``prepare_model``.
    """

    @staticmethod
    def save_pretrained(
        model: torch.nn.Module,
        save_directory: str,
        state_dict: Optional[dict] = None,
        safe_serialization: bool = True,
        **kwargs,
    ) -> None:
        """Save ``model`` to ``save_directory`` via ``model.save_pretrained``.

        Args:
            model: The model returned by ``prepare_model``.
            save_directory: Destination directory for the checkpoint.
            state_dict: Optional state dict supplied by the caller. It is
                forwarded to ``model.save_pretrained``; with ``None`` the
                model's own weights are used, as before.
            safe_serialization: Whether to write safetensors files.
            **kwargs: Extra options forwarded to ``model.save_pretrained``.
        """
        # Previously ``state_dict`` was accepted but silently dropped. Both PEFT's
        # and transformers' ``save_pretrained`` honour an explicit state dict
        # (e.g. one gathered by the caller), so pass it through.
        model.save_pretrained(
            save_directory,
            state_dict=state_dict,
            safe_serialization=safe_serialization,
            **kwargs,
        )

    @staticmethod
    def from_pretrained(
        model: torch.nn.Module, model_id: str, **kwargs
    ) -> torch.nn.Module:
        """Load the PEFT adapter at ``model_id`` onto ``model``.

        Args:
            model: The original (unwrapped) model instance.
            model_id: Model id or checkpoint directory of the adapter.
            **kwargs: Extra options forwarded to ``PeftModel.from_pretrained``.

        Returns:
            The ``PeftModel`` wrapping ``model``.
        """
        return PeftModel.from_pretrained(model, model_id, **kwargs)


# A simple example tuner: wraps the model with PEFT's IA3 adapter.
class IA3(PeftTuner):
    """Example tuner that wraps the model with a PEFT IA3 adapter."""

    @staticmethod
    def prepare_model(
        args: "TrainArguments", model: torch.nn.Module
    ) -> torch.nn.Module:
        """Wrap ``model`` with an IA3 adapter.

        All linear layers returned by ``find_all_linears`` are targeted. Modules
        whose name contains the architecture's per-layer MLP suffix are treated
        as feed-forward modules.

        Args:
            args: The training arguments (unused by this tuner).
            model: The model to wrap; must expose ``model.model_meta.model_arch``.

        Returns:
            The model wrapped by ``get_peft_model``.

        Raises:
            ValueError: If the architecture's ``mlp`` key is not a string
                template containing ``"{}."`` followed by a module name.
        """
        model_arch: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
        mlp_key = model_arch.mlp
        # The key is a template such as "model.layers.{}.mlp"; keep the module
        # name that follows the layer-index placeholder.
        mlp_suffix = (
            mlp_key.split("{}.")[1]
            if isinstance(mlp_key, str) and "{}." in mlp_key
            else ""
        )
        if not mlp_suffix:
            # An empty suffix would yield ".*.*" and mark every module as
            # feed-forward, so fail loudly instead.
            raise ValueError(
                "IA3 requires the model architecture's `mlp` key to be a template "
                f"like 'model.layers.{{}}.mlp'; got {mlp_key!r} for model_arch "
                f"{model.model_meta.model_arch!r}."
            )
        ia3_config = IA3Config(
            target_modules=find_all_linears(model),
            # A string here is used by PEFT as a regex over full module names.
            # Escape the suffix so characters such as '.' are matched literally.
            feedforward_modules=".*" + re.escape(mlp_suffix) + ".*",
        )
        return get_peft_model(model, ia3_config)


class DummyTuner(PeftTuner):
    """No-op tuner: hands back the model as-is, without attaching any adapter."""

    @staticmethod
    def prepare_model(args: "TrainArguments", model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` unchanged; ``args`` is ignored."""
        # Nothing is wrapped, so saving/loading fall through to the inherited
        # PeftTuner hooks operating on the bare model.
        unchanged = model
        return unchanged


# Registry of extra tuners, keyed by the value passed to `--train_type`.
# Register your own Tuner subclass here, then start training with `--train_type <key>`.
extra_tuners = dict(ia3=IA3, dummy=DummyTuner)
