# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from dvlab-research/LongLoRA.
import re
from dataclasses import dataclass, field
from typing import List, Tuple, Union

import torch.nn as nn

from swift.tuners.lora import lora_state_dict, mark_lora_as_trainable
from swift.tuners.lora_layers import LoraModel
from .. import LoRA, LoRAConfig, SwiftOutput


class LongLoRAModelType:
    LLAMA = "llama"


@dataclass
class LongLoRAConfig(LoRAConfig):
    """
    The Config for the LongLoRA adapter.
    LongLoRA:[Efficient Fine-tuning of Long-Context Large Language Models](https://arxiv.org/abs/2309.12307)
    This adapter uses S2-attention to shorten the attention window for long context training scenarios.
    Args:
        embedder_and_normalizer: LongLoRA allows the embedder and normalizer to be trainable, this parameter specifies
            the names of the embedders and normalizers.
        model_type: The model type, now support llama only
        group_size_ratio: The group size window ratio of the sequence length.
            Note: The sequence length should be split to smaller sequences by the ratio.
    """

    embedder_and_normalizer: Union[str, List[str], Tuple[str]] = field(
        default=("embed", "norm"),
        metadata={
            "help": "The names of embedder and normalizer, regex format if is a str, else will match with sub sequences"
        },
    )

    model_type: str = field(
        default=None,
        metadata={"help": "The model type, now only support `llama` structure."},
    )

    group_size_ratio: float = field(
        default=0.25, metadata={"help": "The S2 attention group ratio"}
    )

    def __post_init__(self):
        from swift.tuners.mapping import SwiftTuners

        self.swift_type = SwiftTuners.LONGLORA


class LongLoRA(LoRA):

    @staticmethod
    def prepare_model(model: nn.Module, config: LongLoRAConfig, adapter_name: str):
        """Prepare a model with `LongLoRAConfig`"""
        LoraModel(model, config, adapter_name)

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            _state_dict = lora_state_dict(state_dict, adapter_name, config.bias)
            for name, value in state_dict.items():
                if isinstance(config.embedder_and_normalizer, str):
                    target_module_found = re.fullmatch(
                        config.embedder_and_normalizer, name
                    )
                else:
                    target_module_found = any(
                        target_key in name
                        for target_key in config.embedder_and_normalizer
                    )
                if target_module_found and name not in _state_dict:  # noqa
                    _state_dict[name] = value
            return _state_dict

        def mark_trainable_callback(model):
            mark_lora_as_trainable(model, adapter_name, config.bias)
            mark_embedding_normalizer_as_trainable(
                model, config.embedder_and_normalizer
            )

        if config.model_type == LongLoRAModelType.LLAMA:
            from .llama import replace_llama_attn

            replace_llama_attn(model)
            # only support code base from transformers
            model.config.group_size_ratio = config.group_size_ratio

        return SwiftOutput(
            config=config,
            state_dict_callback=state_dict_callback,
            mark_trainable_callback=mark_trainable_callback,
        )


def mark_embedding_normalizer_as_trainable(
    model: nn.Module, extra_parameters: Union[str, List[str], Tuple[str]]
) -> None:
    for name, sub_module in model.named_parameters():
        if isinstance(extra_parameters, str):
            target_module_found = re.fullmatch(extra_parameters, name)
        else:
            target_module_found = any(
                target_key in name for target_key in extra_parameters
            )
        if target_module_found:  # noqa
            sub_module.requires_grad = True
