# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from swift.utils.logger import get_logger

# Module-level logger obtained from swift's ``get_logger`` utility.
logger = get_logger()


class ResTuner(nn.Module):
    """Wrapper that builds a single residual tuner from ``tuner_cfg``.

    ``tuner_cfg`` is either a selector string or a dict keyed by selector.
    Selectors are tried in this order and the first match wins:

    * ``"res_adapter"``       -> ``ResAdapter``
    * ``"res_group_adapter"`` -> ``ResGroupAdapter``
    * ``"upsample"``          -> ``Upsample``
    * no match                -> ``Identity``

    For a dict, the value stored under the matching key is forwarded as the
    tuner's own ``tuner_cfg``; for a string, the string itself is forwarded.
    Extra ``kwargs`` are passed through to the selected tuner constructor.
    """

    def __init__(
        self,
        dim=None,
        layer_num=-1,
        depth=-1,
        zero_init_last=False,
        stage="",
        tuner_cfg=None,
        **kwargs
    ):
        super().__init__()
        if tuner_cfg is None:
            # ``None`` stands in for an empty config (avoids a shared mutable
            # ``{}`` default); an empty config selects Identity.
            tuner_cfg = {}
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth
        self.stage = stage
        self.tuner_cfg = tuner_cfg

        # Constructor arguments shared by both adapter variants.
        adapter_args = dict(
            dim=dim,
            layer_num=layer_num,
            depth=depth,
            zero_init_last=zero_init_last,
            stage=stage,
        )

        matched, sub_cfg = self._match_cfg(tuner_cfg, "res_adapter")
        if matched:
            self.tuner = ResAdapter(tuner_cfg=sub_cfg, **adapter_args, **kwargs)
            return

        matched, sub_cfg = self._match_cfg(tuner_cfg, "res_group_adapter")
        if matched:
            self.tuner = ResGroupAdapter(tuner_cfg=sub_cfg, **adapter_args, **kwargs)
            return

        matched, sub_cfg = self._match_cfg(tuner_cfg, "upsample")
        if matched:
            # A present and truthy ``upsample_out_channels`` enables the conv.
            # The key stays in ``kwargs``; Upsample swallows unknown kwargs.
            has_out = "upsample_out_channels" in kwargs
            out_channels = kwargs["upsample_out_channels"] if has_out else dim
            use_conv = bool(out_channels) if has_out else False
            self.tuner = Upsample(
                channels=dim,
                use_conv=use_conv,
                out_channels=out_channels,
                tuner_cfg=sub_cfg,
                **kwargs
            )
            return

        self.tuner = Identity()

    @staticmethod
    def _match_cfg(tuner_cfg, key):
        """Return ``(matched, sub_cfg)`` for selector ``key``.

        A string matches only when equal to ``key`` and is its own sub-config;
        a dict matches when it contains ``key`` and the sub-config is
        ``tuner_cfg[key]`` (which may be any value, including ``None``).
        """
        if isinstance(tuner_cfg, str):
            return tuner_cfg == key, tuner_cfg
        if isinstance(tuner_cfg, dict) and key in tuner_cfg:
            return True, tuner_cfg[key]
        return False, None

    def forward(self, x, *args, **kwargs):
        """Run the selected tuner, or return ``0.0`` when the config requests "zero".

        "zero" is detected when ``tuner_cfg == "zero"`` or ``"zero" in tuner_cfg``
        (a dict key, or a substring when ``tuner_cfg`` is a string).
        """
        if self.tuner_cfg == "zero" or "zero" in self.tuner_cfg:
            return 0.0
        return self.tuner(x, *args, **kwargs)


class ResAdapter(nn.Module):
    """Bottleneck MLP adapter added residually to its input.

    ``forward`` returns ``x + ln2(act(ln1(x)))``, optionally re-weighted
    according to ``adapter_weight``. 4-D inputs ``(B, C, H, W)`` are flattened
    to ``(B, H*W, C)`` tokens for the MLP and reshaped back afterwards.

    ``tuner_cfg`` keys (all optional):

    * ``adapter_length``: hidden width (int, default 32) or an
      ``(in, hidden, out)`` tuple; a list of these is indexed by ``layer_num``.
    * ``adapter_weight``: weighting mode passed to ``init_weight_type`` and
      ``apply_data_weight`` (default ``None``: no weighting).
    * ``adapter_type``: stored on the instance; not read by this class.
    """

    def __init__(
        self,
        dim,
        layer_num=-1,
        depth=-1,
        zero_init_last=False,
        stage="",
        tuner_cfg=None,
        act_layer=nn.GELU,
        **kwargs
    ):
        super(ResAdapter, self).__init__()
        if tuner_cfg is None:
            # ``None`` (the declared default) means an empty config.
            tuner_cfg = {}
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth

        def cfg(key, default):
            # ``tuner_cfg`` may be a bare selector string (ResTuner forwards
            # e.g. "res_adapter"), so test membership instead of using ``.get``.
            return tuner_cfg[key] if key in tuner_cfg else default

        self.adapter_length = cfg("adapter_length", 32)
        self.adapter_type = cfg("adapter_type", None)
        self.adapter_weight = cfg("adapter_weight", None)

        if isinstance(self.adapter_length, list):
            # Per-layer configuration.
            self.adapter_length = self.adapter_length[self.layer_num]
        assert isinstance(self.adapter_length, int) or (
            isinstance(self.adapter_length, tuple) and len(self.adapter_length) == 3
        )
        if isinstance(self.adapter_length, int):
            in_dim, hidden_dim, out_dim = dim, self.adapter_length, dim
        else:
            in_dim, hidden_dim, out_dim = self.adapter_length

        self.ln1 = nn.Linear(in_dim, hidden_dim)
        self.activate = act_layer()
        self.ln2 = nn.Linear(hidden_dim, out_dim)

        self._xavier_init_weights(self.ln1)
        if zero_init_last and layer_num == depth - 1:
            # Zero the final projection of the last layer so the adapter
            # initially contributes nothing (output == input).
            self._zero_init_weights(self.ln2)
        else:
            self._xavier_init_weights(self.ln2)

        # Weighting is applied to the adapter output, whose width is out_dim.
        self.scaling = init_weight_type(out_dim, self.adapter_weight)
        self._prepared = False

    def _zero_init_weights(self, m):
        """Set weight and bias of an ``nn.Linear`` to zero."""
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        """Kaiming-uniform weight (a=sqrt(5)) and standard-normal bias for an ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            nn.init.normal_(m.bias)

    def _xavier_init_weights(self, m):
        """Xavier-uniform weight and tiny normal bias (std 1e-6) for an ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        """Return ``x + adapter(x)`` in the dtype of ``x``.

        The computation runs in the dtype of ``ln1``'s weight.
        """
        if not self._prepared:
            # Lazily move the whole adapter -- including ``scaling`` when it is
            # a gate Linear or a Parameter -- to the input's device on first use.
            self.to(x.device)
            self._prepared = True

        x_dtype = x.dtype
        x = x.to(self.ln1.weight.dtype)
        x_shortcut = x
        is_spatial = x.dim() == 4
        if is_spatial:
            batch, channels, height, width = x.shape
            # (B, C, H, W) -> (B, H*W, C) so the Linear layers act on channels.
            x = x.reshape(batch, channels, height * width).permute(0, 2, 1)

        x_adapter = self.ln2(self.activate(self.ln1(x)))

        if self.adapter_weight:
            x_adapter = apply_data_weight(x_adapter, self.scaling, self.adapter_weight)

        if is_spatial:
            # (B, H*W, C') -> (B, C', H, W); C' is the adapter's output width.
            out_channels = x_adapter.shape[-1]
            x_adapter = x_adapter.permute(0, 2, 1).reshape(
                batch, out_channels, height, width
            )
        return (x_shortcut + x_adapter).to(x_dtype)


class ResGroupAdapter(nn.Module):
    """Grouped (multi-head) MLP adapter added residually to a 4-D feature map.

    The input ``x`` is ``(B, C, H, W)`` with ``C == adapter_dim``. Channels are
    split into ``head`` groups of ``C // head`` channels, with channel ``k``
    going to group ``k % head``. Every group goes through the same
    ``Linear -> act -> Linear`` MLP of hidden width
    ``(C // head) * scale_factor``. The groups are then merged back, optionally
    weighted, and added to ``x``.

    ``tuner_cfg`` keys (all optional):

    * ``dim``: adapter channel count (defaults to ``dim``).
    * ``head``: number of groups (default 4).
    * ``scale_factor``: MLP expansion ratio (default 2).
    * ``adapter_weight``: weighting mode for ``init_weight_type`` and
      ``apply_data_weight``.
    * ``adapter_type``: stored only.
    """

    def __init__(
        self,
        dim,
        layer_num=-1,
        depth=-1,
        zero_init_last=False,
        stage="",
        tuner_cfg=None,
        act_layer=nn.GELU,
        **kwargs
    ):
        super(ResGroupAdapter, self).__init__()
        if tuner_cfg is None:
            # ``None`` (the declared default) means an empty config.
            tuner_cfg = {}
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth

        def cfg(key, default):
            # ``tuner_cfg`` may be a bare selector string (ResTuner forwards
            # e.g. "res_group_adapter"), so test membership instead of ``.get``.
            return tuner_cfg[key] if key in tuner_cfg else default

        self.adapter_type = cfg("adapter_type", None)
        self.adapter_weight = cfg("adapter_weight", None)

        self.adapter_dim = cfg("dim", dim)
        self.adapter_head = cfg("head", 4)
        self.adapter_scale_factor = cfg("scale_factor", 2)

        assert (
            self.adapter_dim % self.adapter_head == 0
        ), "adapter dim should be divisible by adapter head"
        self.dim_mlp = self.adapter_dim // self.adapter_head

        hidden_dim = self.dim_mlp * self.adapter_scale_factor
        self.ln1 = nn.Linear(self.dim_mlp, hidden_dim)
        self.ln2 = nn.Linear(hidden_dim, self.dim_mlp)
        self.activate = act_layer()

        self._kaiming_init_weights(self.ln1)
        if zero_init_last and layer_num == depth - 1:
            # Zero the final projection of the last layer so the adapter
            # initially contributes nothing (output == input).
            self._zero_init_weights(self.ln2)
        else:
            self._kaiming_init_weights(self.ln2)
        # Weighting acts on the merged adapter output, whose channel width is
        # adapter_dim (which may differ from ``dim`` when tuner_cfg sets "dim").
        self.scaling = init_weight_type(self.adapter_dim, self.adapter_weight)
        self._prepared = False

    def _zero_init_weights(self, m):
        """Set weight and bias of an ``nn.Linear`` to zero."""
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        """Kaiming-uniform weight (a=sqrt(5)) and standard-normal bias for an ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            nn.init.normal_(m.bias)

    def _xavier_init_weights(self, m):
        """Xavier-uniform weight and tiny normal bias (std 1e-6) for an ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        """Return ``x + grouped_adapter(x)`` for a ``(B, C, H, W)`` input, in the dtype of ``x``."""
        if not self._prepared:
            # Lazily move the whole adapter -- including ``scaling`` when it is
            # a gate Linear or a Parameter -- to the input's device on first use.
            self.to(x.device)
            self._prepared = True

        x_dtype = x.dtype
        x = x.to(self.ln1.weight.dtype)
        x_shortcut = x

        batch, inner_dim, height, width = x.shape

        # (B, C, H, W) -> (B, H*W, C)
        x_adapter = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

        # Fold the channel groups into the batch so one shared MLP handles them all.
        x_adapter = rearrange(x_adapter, "b n (c h) -> (b h) n c", h=self.adapter_head)
        x_adapter = self.ln2(self.activate(self.ln1(x_adapter)))
        x_adapter = rearrange(x_adapter, "(b h) n c -> b n (c h)", h=self.adapter_head)

        if self.adapter_weight:
            x_adapter = apply_data_weight(x_adapter, self.scaling, self.adapter_weight)

        # (B, H*W, C) -> (B, C, H, W)
        x_adapter = (
            x_adapter.reshape(batch, height, width, -1).permute(0, 3, 1, 2).contiguous()
        )
        return (x_shortcut + x_adapter).to(x_dtype)


class Identity(nn.Module):
    """No-op tuner: returns its first argument untouched and ignores the rest."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, inputs, *args, **kwargs):
        # Extra arguments exist only so this matches how ResTuner calls its
        # tuner (``self.tuner(x, *args, **kwargs)``); they are discarded.
        del args, kwargs
        passthrough = inputs
        return passthrough


class Upsample(nn.Module):
    """
    Nearest-neighbour upsampling layer with an optional zero-initialised convolution.
    :param channels: number of channels the input must have (checked in ``forward``).
    :param use_conv: if True, apply a 3x3 ``nn.Conv2d`` after upsampling.
    :param out_channels: output channels of that convolution; defaults to ``channels``.
                         It does not affect the output when ``use_conv`` is False.
    :param padding: padding of the convolution.
    :param kwargs: accepted for call compatibility (e.g. ``tuner_cfg``) and ignored.
    """

    def __init__(
        self, channels, use_conv=False, out_channels=None, padding=1, **kwargs
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
        self.init_weights()

    def init_weights(self):
        """Zero the weight and bias of every ``nn.Conv2d``, so the conv branch initially outputs zeros."""

        def _init_weights(m):
            if isinstance(m, nn.Conv2d):
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)

        self.apply(_init_weights)

    def forward(self, x, target_size=None, *args, **kwargs):
        """Upsample ``x`` 2x (or to ``target_size``) with nearest interpolation, then apply the optional conv.

        The output keeps the dtype of ``x``. ``x`` is assumed to be
        ``(B, C, H, W)`` -- NOTE(review): confirm no other ranks reach here.
        """
        assert x.shape[1] == self.channels, "expected {} input channels, got {}".format(
            self.channels, x.shape[1]
        )
        # Other dtypes (half, bfloat16, integer) are interpolated in float32
        # and cast back. float64 is kept as-is, because down-casting it would
        # lose precision even though nearest interpolation only copies values.
        if x.dtype in (torch.float32, torch.float64):
            work = x
        else:
            work = x.float()
        if target_size is None:
            upsampled = F.interpolate(work, scale_factor=2, mode="nearest")
        else:
            upsampled = F.interpolate(work, target_size, mode="nearest")
        x = upsampled.type_as(x)
        if self.use_conv:
            x = self.conv(x)
        return x


def init_weight_type(dim, weight_type):
    """Create the scaling object that matches ``weight_type``.

    :param dim: feature width used by ``"gate"`` and the ``*_channel`` variants.
    :param weight_type: ``"gate"``, ``"scale"``, ``"scale_kv"``, ``"scale_channel"``,
        ``"scale_kv_channel"`` or ``"scalar_<float>"``.
    :return: one of the following:

        * ``nn.Linear(dim, 1)`` for ``"gate"``.
        * A trainable ``nn.Parameter`` of ones, with shape ``(1,)`` or ``(dim,)``,
          for the scale variants. The ``*_kv*`` variants return a ``(k, v)`` pair of them.
        * The ``float`` after the last ``"_"`` for ``"scalar*"``.
        * ``None`` for ``None`` or any unrecognised value.
    """

    def ones_param(size):
        # Trainable tensor of ones, so scaling starts out as multiplication by 1.
        return nn.Parameter(torch.ones(size))

    if weight_type is None:
        return None
    if weight_type == "gate":
        return nn.Linear(dim, 1)
    if weight_type == "scale":
        return ones_param(1)
    if weight_type == "scale_kv":
        return ones_param(1), ones_param(1)
    if weight_type == "scale_channel":
        return ones_param(dim)
    if weight_type == "scale_kv_channel":
        return ones_param(dim), ones_param(dim)
    if weight_type and weight_type.startswith("scalar"):
        return float(weight_type.split("_")[-1])
    return None


def apply_data_weight(data, scaling, weight_type):
    """Scale ``data`` according to ``weight_type``.

    * ``"gate"``: ``scaling`` is a callable module, such as the Linear built by
      ``init_weight_type``. The gate is ``sigmoid(scaling(data))`` averaged over
      dim 1 and reshaped with ``view(-1, 1, 1)``. This presumably assumes
      ``data`` is ``(B, N, C)``.
    * ``"scale"``, ``"scale_channel"``, ``"scalar*"``: ``data`` is multiplied by
      ``scaling`` as given.
    * Anything else, including the ``*_kv*`` variants, returns ``data`` unchanged.

    ``data`` is also returned unchanged when the resulting factor is ``None``.
    """
    if weight_type == "gate":
        gate = torch.sigmoid(scaling(data))
        factor = gate.mean(dim=1).view(-1, 1, 1)
    elif weight_type in ("scale", "scale_channel") or weight_type.startswith("scalar"):
        factor = scaling
    else:
        factor = None
    return data if factor is None else data * factor


def detach_tensors(feats):
    """Recursively detach tensors in ``feats`` from the autograd graph.

    * Lists and tuples, including tuple subclasses, come back as new **lists**.
    * Dicts come back as new dicts with the same keys.
    * Any other object with a callable ``detach`` (e.g. ``torch.Tensor``) is
      replaced by ``obj.detach()``.
    * Objects without one (``None``, numbers, strings, ...) are returned
      unchanged rather than raising ``AttributeError``.
    """
    if isinstance(feats, torch.Tensor):
        return feats.detach()
    if isinstance(feats, (list, tuple)):
        return [detach_tensors(feat) for feat in feats]
    if isinstance(feats, dict):
        return {key: detach_tensors(val) for key, val in feats.items()}
    detach = getattr(feats, "detach", None)
    if callable(detach):
        return detach()
    # Non-detachable leaf: nothing to cut from the graph.
    return feats


def probe_tensors(module, feats, name):
    """Attach the detached form of ``feats`` (via ``detach_tensors``) to ``module`` as attribute ``name``."""
    setattr(module, name, detach_tensors(feats))


def probe_input_pre_hook(self, args):
    """Record the first positional input on ``self.probe_input_data`` (detached).

    Shaped like a forward pre-hook ``(module, args)``. Returns ``args`` unchanged.
    """
    probe_tensors(self, args[0], "probe_input_data")
    return args


def probe_output_hook(self, args, result):
    """Record ``result`` on ``self.probe_output_data`` (detached) and hand ``result`` back.

    Shaped like a forward hook ``(module, args, output)``. The detached copy is
    what gets stored; the returned value is the original object.
    """
    probe_tensors(self, result, "probe_output_data")
    return result
