import math
import warnings
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from transformers import Conv1D

from peft.tuners.grouplora.config import GroupLoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.other import transpose


class GroupLoraLayer(BaseTunerLayer):
    """Per-adapter state shared by GroupLoRA layer implementations.

    For every adapter name this mixin stores the hyper-parameters ``r``,
    ``lora_alpha``, ``scaling`` (``lora_alpha / r``) and a dropout module. It
    also stores three trainable tensors that are private to the wrapped layer:

    * ``lora_E``: an ``r x r`` matrix;
    * ``lora_gateA`` / ``lora_gateB``: one gate logit per group.
    """

    # Attribute names of the per-adapter trainable containers created in __init__.
    adapter_layer_names = ("lora_E", "lora_gateA", "lora_gateB")
    # Attribute names of the remaining per-adapter (non-trainable) state.
    other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout")

    def __init__(self, base_layer: nn.Module) -> None:
        """Record ``base_layer`` and its feature sizes; start with no adapters.

        Args:
            base_layer: The layer being adapted. Its ``in_features`` /
                ``out_features`` are stored when they can be determined,
                otherwise both are set to ``None``.
        """
        self.base_layer = base_layer
        self.r = {}
        self.lora_alpha = {}
        self.scaling = {}
        self.lora_dropout = nn.ModuleDict({})
        self.lora_E = nn.ParameterDict({})
        self.lora_gateA = nn.ParameterDict({})
        self.lora_gateB = nn.ParameterDict({})

        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            in_features, out_features = base_layer.in_features, base_layer.out_features
        elif isinstance(base_layer, Conv1D):
            # Conv1D keeps its weight as (in_features, out_features). `ds_shape`
            # is presumably DeepSpeed's full shape for partitioned weights — TODO confirm.
            in_features, out_features = (
                base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
            )
        else:
            if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
                in_features, out_features = base_layer.in_features, base_layer.out_features
            else:
                in_features, out_features = None, None
        self.in_features = in_features
        self.out_features = out_features
        self.merged_adapters = []

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, group_id, num_groups):
        """Create (or overwrite) the per-layer state of ``adapter_name``.

        Args:
            adapter_name: Key under which all per-adapter state is stored.
            r: Rank of the adapter; must be a positive integer.
            lora_alpha: Numerator of ``scaling = lora_alpha / r``.
            lora_dropout: Dropout probability on the adapter input; ``0`` (or
                less) uses ``nn.Identity`` instead.
            group_id: Group this layer belongs to; its gate logits start
                one-hot on this index.
            num_groups: Number of gate logits per gate vector.

        Raises:
            ValueError: If ``r`` is not positive or ``group_id`` is not in
                ``[0, num_groups)``.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
        if not 0 <= group_id < num_groups:
            raise ValueError(f"`group_id` must be in [0, {num_groups}), but the value passed is {group_id}")

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        self.scaling[adapter_name] = lora_alpha / r
        if lora_dropout > 0.0:
            self.lora_dropout[adapter_name] = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout[adapter_name] = nn.Identity()

        # Gate logits start one-hot on this layer's own group, so that group
        # gets the largest share once the logits are normalised.
        gate_init = torch.zeros(num_groups)
        gate_init[group_id] = 1.0
        self.lora_gateA[adapter_name] = nn.Parameter(gate_init.clone())
        self.lora_gateB[adapter_name] = nn.Parameter(gate_init.clone())

        self.lora_E[adapter_name] = nn.Parameter(torch.empty(r, r))
        nn.init.kaiming_uniform_(self.lora_E[adapter_name], a=math.sqrt(5))


class GroupLoRALinear(nn.Module):
    """Low-rank factor pair ``lora_A`` (in x r) and ``lora_B`` (r x out) for one group.

    ``lora_A`` is Kaiming-uniform initialised and ``lora_B`` starts at zero,
    so ``lora_A @ lora_B`` is initially the zero matrix. With ``r <= 0`` no
    parameters are created.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
    ):
        """Create the factors for an ``in_features -> out_features`` map of rank ``r``.

        Args:
            in_features: Number of rows of ``lora_A``.
            out_features: Number of columns of ``lora_B``.
            r: Rank; when not positive, the module holds no parameters.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        if r > 0:
            self.lora_A = nn.Parameter(torch.empty(in_features, r))
            self.lora_B = nn.Parameter(torch.empty(r, out_features))
            self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise the factors; a no-op when the module has none (``r <= 0``)."""
        if self.r <= 0:
            # No factors were created, so there is nothing to initialise.
            return
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)


class Linear(nn.Module, GroupLoraLayer):
    """GroupLoRA adapter wrapped around a dense base layer.

    For each active adapter the forward pass adds
    ``scaling * dropout(x) @ A @ E @ B`` to the base-layer output. Here ``A``
    and ``B`` are softmax-gated mixtures of the ``lora_A`` / ``lora_B`` factors
    of the modules in ``shared_group``, and ``E`` is this layer's own
    ``r x r`` matrix.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str,
        config: GroupLoraConfig = None,
        shared_group: nn.ModuleList = None,
        layer_idx: int = 0,
    ):
        """Wrap ``base_layer`` and create the adapter ``adapter_name``.

        Args:
            base_layer: Layer whose output the adapter augments.
            adapter_name: Name under which the per-layer adapter state is stored.
            config: Supplies ``r``, ``alpha``, ``dropout``, ``fan_in_fan_out``,
                ``num_layers`` and ``group_size``. It is required; the ``None``
                default is kept only for signature compatibility.
            shared_group: Modules exposing ``lora_A`` / ``lora_B``, presumably
                one per group and shared across layers. NOTE(review): expected
                length is ``ceil(num_layers / group_size)``; the gating code
                uses ``zip`` and would silently drop extra gates or modules.
            layer_idx: Index of the wrapped layer; its group is
                ``layer_idx // group_size``.

        Raises:
            ValueError: If ``config`` is None.
        """
        if config is None:
            raise ValueError("A GroupLoraConfig must be passed as `config` to build a GroupLoRA Linear layer.")
        super().__init__()
        GroupLoraLayer.__init__(self, base_layer)
        self.fan_in_fan_out = config.fan_in_fan_out

        group_id = layer_idx // config.group_size
        num_groups = math.ceil(config.num_layers / config.group_size)
        self.shared_group = shared_group

        self._active_adapter = adapter_name
        self.config = config
        self.update_layer(
            adapter_name=adapter_name,
            r=config.r,
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
            group_id=group_id,
            num_groups=num_groups,
        )

    def _get_gated_weights(self, adapter_name):
        """Return the gate-weighted mixtures ``(A, B)`` of the shared group factors.

        Both gate-logit vectors are softmax-normalised, so each mixture is a
        convex combination of the groups' ``lora_A`` (resp. ``lora_B``) tensors.
        """
        lora_gateA = F.softmax(self.lora_gateA[adapter_name], dim=-1)
        lora_gateB = F.softmax(self.lora_gateB[adapter_name], dim=-1)
        combined_A = sum(g * group.lora_A for g, group in zip(lora_gateA, self.shared_group))
        combined_B = sum(g * group.lora_B for g, group in zip(lora_gateB, self.shared_group))

        return combined_A, combined_B

    def get_delta_weight(self, adapter):
        """Return the weight update ``scaling * (A @ E @ B)^T`` for ``adapter``.

        Assuming ``A`` is ``(in, r)`` and ``B`` is ``(r, out)``, ``A @ E @ B`` is
        ``(in, out)``. Its transpose is then passed through ``transpose(...,
        fan_in_fan_out)`` so it matches the base weight's layout — TODO confirm
        ``transpose`` flips only when ``fan_in_fan_out`` is true.
        """
        combined_A, combined_B = self._get_gated_weights(adapter)
        # NOTE(review): only E is detached here (`.data`); the gated A/B still
        # carry autograd history. Callers that merge should run under no_grad.
        weight_E = self.lora_E[adapter].data
        delta = combined_A @ weight_E @ combined_B
        output_tensor = transpose(delta.T, self.fan_in_fan_out) * self.scaling[adapter]
        return output_tensor

    def _delta_like(self, adapter, weight):
        """Return ``adapter``'s delta without autograd history, on ``weight``'s device/dtype."""
        with torch.no_grad():
            delta = self.get_delta_weight(adapter)
        return delta.to(device=weight.device, dtype=weight.dtype)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """Add the selected adapters' deltas into the base layer's weight.

        Args:
            safe_merge: Merge into a copy first and raise if it contains
                non-finite values, leaving the base weight unchanged in that case.
            adapter_names: Adapters to merge. Selection is delegated to
                ``check_adapters_to_merge``, which (per PEFT) defaults to the
                active adapters and drops those that are already merged, so
                repeated calls do not add a delta twice.

        Raises:
            ValueError: If ``safe_merge`` is set and the merged weight is not finite.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return

        base_layer = self.get_base_layer()
        for active_adapter in adapter_names:
            if active_adapter not in self.lora_E.keys():
                continue
            delta = self._delta_like(active_adapter, base_layer.weight)
            if safe_merge:
                # Note that safe_merge will be slower than the normal merge
                # because of the copy operation.
                orig_weights = base_layer.weight.data.clone()
                orig_weights += delta

                if not torch.isfinite(orig_weights).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                base_layer.weight.data = orig_weights
            else:
                base_layer.weight.data += delta
            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """Subtract every merged adapter's delta from the base weight, last merged first."""
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        base_layer = self.get_base_layer()
        while self.merged_adapters:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.lora_E.keys():
                base_layer.weight.data -= self._delta_like(active_adapter, base_layer.weight)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run the base layer and, unless merged or disabled, add the adapter outputs.

        When adapters are disabled, any merged weights are unmerged first. The
        result has the same dtype as the base layer's output.
        """
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            return self.base_layer(x, *args, **kwargs)

        result = self.base_layer(x, *args, **kwargs)
        if self.merged:
            # Adapter deltas are already folded into the base weight.
            return result

        result_dtype = result.dtype
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_E.keys():
                continue
            combined_A, combined_B = self._get_gated_weights(active_adapter)
            lora_E = self.lora_E[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            # Run the adapter path in the adapter's dtype (e.g. half-precision
            # input with fp32 adapter params), and add out of place so the base
            # layer's output tensor is not mutated.
            adapter_in = dropout(x.to(lora_E.dtype))
            result = result + (adapter_in @ combined_A @ lora_E @ combined_B) * scaling

        return result.to(result_dtype)
