# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import re
import types
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from swift.utils import get_logger
from swift.utils.torch_utils import find_sub_module
from .restuning_components import (
    ResTuner,
    detach_tensors,
    probe_input_pre_hook,
    probe_output_hook,
)
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class ResTuningConfig(SwiftConfig):
    """
    The configuration class for the ResTuning module.

    ResTuning is a flexible parameter-efficient and memory-efficient tuning paradigm framework.
    'Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone'
    by Jiang et al.(2023)
    See

    Args:
        dims(`Union[List[int], int]`): The dimensions of the hidden states
        root_modules(`str`): The root module to be replaced, can a regex string
        root_modules_hook(`str`): The hook type of root modules, can be "input" or "output"
        stem_modules(`Union[List[str], str]`): The stem modules to be replaced,
            can a regex string or name list of full match format
        stem_modules_hook(`Union[List[str], str]`): The hook type of stem modules, can be "input" or "output"
        target_modules(`str`): The target module to be replaced, can a regex string
        target_modules_hook(`str`): The hook type of target modules, can be "input" or "output"
        tuner_cfg(`Union[List[Dict], Dict, str]`): The configuration of the tuning module,
            can a string or customized config
        use_upsample(bool): Whether to use auxiliary upsample module
        upsample_out_channels(List[int]): The channels if `use_upsample`
        zero_init_last(bool): Use zero to initialize the last Linear in every sub tuner.

    """

    dims: Optional[Union[List[int], int]] = field(
        default=None, metadata={"help": "The dimensions of the hidden states"}
    )

    root_modules: str = field(
        default=None,
        metadata={
            "help": "The root module to be replaced, can a regex string (use the first matching module) or full match format"
        },
    )

    root_modules_hook: str = field(
        default="input",
        metadata={"help": 'The hook type of root modules, can be "input" or "output"'},
    )

    stem_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={
            "help": "The stem modules to be replaced, can a regex string or name list of full match format"
        },
    )

    stem_modules_hook: str = field(
        default="output",
        metadata={"help": 'The hook type of stem modules, can be "input" or "output"'},
    )

    target_modules: str = field(
        default=None,
        metadata={
            "help": "The target module to be replaced, can a regex string (use the first matching module) or full match format"
        },
    )

    target_modules_hook: str = field(
        default="input",
        metadata={
            "help": 'The hook type of target modules, can be "input" or "output"'
        },
    )

    target_hidden_pos: Union[int, str] = field(
        default=None,
        metadata={"help": "The position of the hidden state for target modules output"},
    )

    tuner_cfg: Optional[Union[List[Dict], Dict, str]] = field(
        default=None,
        metadata={
            "help": "The configuration of the tuning module, can a string or customized config"
        },
    )

    use_upsample: bool = field(
        default=False, metadata={"help": "Whether to use auxiliary upsample module"}
    )

    upsample_out_channels: List[int] = field(
        default=None,
        metadata={
            "help": 'The number of output channels when "use_upsample" is set to "True"'
        },
    )

    zero_init_last: bool = field(
        default=False, metadata={"help": "Zero init last weight"}
    )

    use_bypass: bool = field(default=True, metadata={"help": "Whether to use bypass"})

    def __post_init__(self):
        from .mapping import SwiftTuners

        self.swift_type = SwiftTuners.RESTUNING
        self.target_hidden_pos = (
            0 if self.target_hidden_pos is None else self.target_hidden_pos
        )


class ResTuning(SwiftAdapter):

    @staticmethod
    def prepare_model(
        model: nn.Module, config: ResTuningConfig, adapter_name: str
    ) -> SwiftOutput:
        """Prepare a model with `ResTuningConfig`"""

        def _forward_seq(self, input, *args, **kwargs):
            for idx, module in enumerate(self):
                if idx >= len(self.origin_module_keys):
                    continue
                input = module(input)
            return input

        def _forward_target(self, *args, **kwargs):
            if self.target_modules_hook == "input":
                if isinstance(self.target_hidden_pos, int):
                    args = list(args)
                    _arg = args[self.target_hidden_pos]
                else:
                    _arg = kwargs[self.target_hidden_pos]
                args_main = _forward_restuning(self, _arg)
                if isinstance(self.target_hidden_pos, int):
                    args[self.target_hidden_pos] = args_main
                else:
                    kwargs[self.target_hidden_pos] = args_main
                args_main = getattr(self, f"forward_origin_{adapter_name}")(
                    *args, **kwargs
                )
            else:
                _args_main = getattr(self, f"forward_origin_{adapter_name}")(
                    *args, **kwargs
                )
                _arg = (
                    _args_main[self.target_hidden_pos]
                    if isinstance(_args_main, (tuple, list, dict))
                    else _args_main
                )
                args_main = _forward_restuning(self, _arg)
                if type(_args_main) != type(args_main):
                    _args_main[self.target_hidden_pos] = args_main
                    args_main = _args_main
            return args_main

        def _forward_restuning(self, origin_arg):
            probe_results = []
            root_module_ins = self.root_module_ins_list[0]
            stem_module_ins_list = self.stem_module_ins_list
            top_module = model.get_submodule("")
            if root_module_ins:
                if root_module_ins.root_modules_hook == "input":
                    probe_results.append(root_module_ins.probe_input_data)
                else:
                    probe_results.append(root_module_ins.probe_output_data)
            for i, st_mod in enumerate(stem_module_ins_list):
                if i == 0 and root_module_ins is None:
                    probe_results.append(st_mod.probe_input_data)
                if st_mod.stem_modules_hook == "input":
                    probe_results.append(st_mod.probe_input_data)
                else:
                    probe_results.append(st_mod.probe_output_data)
            args_main = getattr(top_module, f"restuning_{adapter_name}")(
                probe_results, origin_arg
            )
            return args_main

        # 1. Matching the root module
        module_keys = [key for key, _ in model.named_modules()]
        root_module_ins_list = []
        if config.root_modules:
            for module_key in module_keys:
                if re.fullmatch(config.root_modules, module_key):
                    root_module = model.get_submodule(module_key)
                    logger.info(
                        f"Matching root module [{module_key}] of type {type(root_module)}"
                    )
                    if isinstance(root_module, (nn.ModuleList, nn.ModuleDict)):
                        logger.warning(
                            f"Type of {type(root_module)} may not be supported because of its customized forward"
                        )
                    if config.root_modules_hook == "input":
                        root_module.register_forward_pre_hook(probe_input_pre_hook)
                    else:
                        root_module.register_forward_hook(probe_output_hook)
                    root_module.root_modules_hook = config.root_modules_hook
                    root_module_ins_list.append(root_module)
                    break
            if len(root_module_ins_list) == 0:
                logger.error("Cannot match root modules")

        # 2. Matching the stem module
        stem_module_ins_list = []
        stem_module_ins_index = []
        for module_key in module_keys:
            if (
                isinstance(config.stem_modules, str)
                and re.fullmatch(config.stem_modules, module_key)
            ) or (
                isinstance(config.stem_modules, list)
                and module_key in config.stem_modules
            ):
                stem_module = model.get_submodule(module_key)
                if isinstance(config.stem_modules, list):
                    stem_module_ins_index.append(config.stem_modules.index(module_key))
                logger.info(
                    f"Matching stem module [{module_key}] of type {type(stem_module)}"
                )
                if isinstance(stem_module, (nn.ModuleList, nn.ModuleDict)):
                    logger.warning(
                        f"Type of {type(stem_module)} may not be supported because of its customized forward"
                    )
                if len(root_module_ins_list) == 0 and len(stem_module_ins_list) == 0:
                    stem_module.register_forward_pre_hook(probe_input_pre_hook)
                if config.stem_modules_hook == "input":
                    stem_module.register_forward_pre_hook(probe_input_pre_hook)
                else:
                    stem_module.register_forward_hook(probe_output_hook)
                stem_module.stem_modules_hook = config.stem_modules_hook
                stem_module_ins_list.append(stem_module)
        if isinstance(config.stem_modules, list):
            stem_module_ins_list = [
                stem_module_ins_list[stem_module_ins_index.index(i)]
                for i in range(len(stem_module_ins_index))
            ]
        depth = len(stem_module_ins_list)
        if len(stem_module_ins_list) == 0:
            raise Exception("Cannot match source modules")

        # 3. Init restuning module
        if len(stem_module_ins_list) != 0:
            top_module = model.get_submodule("")
            restuning_module = ResTuningBypassModule(
                config.dims,
                depth,
                adapter_name,
                config.use_upsample,
                config.upsample_out_channels,
                config.zero_init_last,
                config.tuner_cfg,
            )
            setattr(top_module, f"restuning_{adapter_name}", restuning_module)

        # 4. Matching the target module
        target_module_ins = None
        for module_key in module_keys:
            if re.fullmatch(config.target_modules, module_key):
                tgt_module = model.get_submodule(module_key)
                logger.info(
                    f"Matching target module [{module_key}] of type {type(tgt_module)}"
                )
                if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
                    raise Exception(
                        f"Type of {type(tgt_module)} may not be supported because of its customized forward"
                    )

                tgt_module.target_modules_hook = config.target_modules_hook
                tgt_module.target_hidden_pos = config.target_hidden_pos
                tgt_module.root_module_ins_list = root_module_ins_list
                tgt_module.stem_module_ins_list = stem_module_ins_list
                target_module_ins = tgt_module

                if isinstance(tgt_module, nn.Sequential) and not hasattr(
                    tgt_module, "origin_module_keys"
                ):
                    tgt_module.origin_module_keys = copy.deepcopy(
                        list(tgt_module._modules.keys())
                    )

                    setattr(
                        tgt_module,
                        f"forward_origin_{adapter_name}",
                        types.MethodType(_forward_seq, tgt_module),
                    )
                else:
                    setattr(
                        tgt_module, f"forward_origin_{adapter_name}", tgt_module.forward
                    )
                tgt_module.forward = types.MethodType(_forward_target, tgt_module)
        if target_module_ins is None:
            raise Exception("Cannot match target modules")

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            return {
                key: value
                for key, value in state_dict.items()
                if f"restuning_{adapter_name}" in key
            }

        def mark_trainable_callback(model):
            return

        return SwiftOutput(
            config=config,
            state_dict_callback=state_dict_callback,
            mark_trainable_callback=mark_trainable_callback,
        )

    @staticmethod
    def activate_adapter(
        module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None
    ):
        modules = find_sub_module(module, f"restuning_{adapter_name}")
        for _module in modules:
            _module: ActivationMixin
            _module: nn.Module
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(
                _module, adapter_name, _module.module_key, activate, offload
            )


class ResTuningBypassModule(nn.Module, ActivationMixin):
    """The implementation of ResTuningBypass method."""

    def __init__(
        self,
        dims,
        depth,
        adapter_name,
        use_upsample=False,
        upsample_out_channels=None,
        zero_init_last=False,
        tuner_cfg=None,
    ):
        super(ResTuningBypassModule, self).__init__()
        super(nn.Module, self).__init__("")
        self.adapter_name = adapter_name

        self.bypass_blocks = nn.Sequential(
            *[
                ResTunerBypassBlock(
                    dim=dims[i] if isinstance(dims, list) else dims,
                    layer_num=i,
                    depth=depth,
                    use_upsample=use_upsample,
                    upsample_out_channels=(
                        upsample_out_channels[i]
                        if isinstance(upsample_out_channels, list)
                        else upsample_out_channels
                    ),
                    zero_init_last=zero_init_last,
                    tuner_cfg=(
                        tuner_cfg[i] if isinstance(tuner_cfg, list) else tuner_cfg
                    ),
                )
                for i in range(depth)
            ]
        )
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x_list, origin_arg, **kwargs):
        if not self.is_activated(self.adapter_name):
            return origin_arg
        x_bypass = detach_tensors(x_list.pop(0))
        x_bypass = x_bypass[0] if isinstance(x_bypass, (list, tuple)) else x_bypass
        x_list = detach_tensors(x_list)
        x_list = [_x[0] if isinstance(_x, (list, tuple)) else _x for _x in x_list]
        for i, (bp_blk, x_stem) in enumerate(zip(self.bypass_blocks, x_list)):
            target_size = x_list[i + 1].shape[2:] if i < len(x_list) - 1 else None
            x_bypass = bp_blk(x_stem, x_bypass, target_size, **kwargs)
        return x_bypass


class ResTunerBypassBlock(nn.Module):

    def __init__(
        self,
        dim,
        layer_num=-1,
        depth=-1,
        use_upsample=False,
        zero_init_last=False,
        tuner_cfg=None,
        **kwargs,
    ):
        super().__init__()
        self.layer_num = layer_num
        self.depth = depth

        if isinstance(tuner_cfg, str):
            lateral_cfg = tuner_cfg
            vertical_cfg = tuner_cfg
            aux_cfg = "upsample" if use_upsample and layer_num != depth - 1 else None
        elif isinstance(tuner_cfg, dict):
            lateral_cfg = (
                tuner_cfg["lateral_cfg"] if "lateral_cfg" in tuner_cfg else None
            )
            vertical_cfg = (
                tuner_cfg["vertical_cfg"] if "vertical_cfg" in tuner_cfg else None
            )
            aux_cfg = tuner_cfg["aux_cfg"] if "aux_cfg" in tuner_cfg else None

        self.lateral_tuner = ResTuner(
            dim, layer_num, depth, zero_init_last, "lateral", lateral_cfg, **kwargs
        )
        self.vertical_tuner = ResTuner(
            dim, layer_num, depth, zero_init_last, "vertical", vertical_cfg, **kwargs
        )
        if aux_cfg and len(aux_cfg) != 0:
            self.aux_tuner = ResTuner(
                dim, layer_num, depth, zero_init_last, "aux", aux_cfg, **kwargs
            )

    def forward(self, x_stem, x_bypass, target_size=None, **kwargs):
        x_lateral = self.lateral_tuner(x_stem)
        x_vertical = self.vertical_tuner(x_bypass)

        x_bypass_out = x_lateral + x_vertical
        if hasattr(self, "aux_tuner"):
            x_bypass_out = self.aux_tuner(x_bypass_out, target_size)
        return x_bypass_out
