"""
Sine LoRA implementation based on the paper "Stable Forgetting: Bounded Parameter-Efficient Unlearning in LLMs"
and the yipingji/Sine-Low-Rank repository.

Mathematical formulation: sin(ωAB^T) where:
- ω (omega) is the frequency parameter
- A and B are low-rank matrices
- sine function provides bounded activation for stable gradient ascent
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, List, Any
from peft.tuners.lora.layer import LoraLayer
from peft.utils import transpose
from peft import LoraConfig


class SineLoraLinear(nn.Module, LoraLayer):
    """
    Sine LoRA wrapper around a linear layer.

    For every active adapter the layer adds

        x @ sin(freq * A @ B^T) / s * scaling

    to the output of the wrapped ``base_layer``, where ``A`` is
    ``lora_A.weight.T`` with shape ``(in_features, r)`` and ``B`` is
    ``lora_B.weight`` with shape ``(out_features, r)``.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        freq: int = 1,
        s: float = 1.0,
        **kwargs,
    ) -> None:
        """
        Wrap ``base_layer`` and create the first adapter.

        Args:
            base_layer: Layer whose output the adapter is added to.
            adapter_name: Name under which the first adapter is registered.
            r: Rank of the A/B factors; must be positive.
            lora_alpha: Numerator of the adapter scaling factor.
            lora_dropout: Dropout probability applied to the adapter input.
            init_lora_weights: ``True`` (Kaiming-uniform A, zero B),
                ``"gaussian"``, ``"loftq"``, or ``False`` to keep the default
                ``nn.Linear`` initialisation.
            use_rslora: Use ``lora_alpha / sqrt(r)`` instead of
                ``lora_alpha / r`` as the scaling factor.
            use_dora: Accepted for signature compatibility; this layer does
                not use it.
            freq: Frequency multiplier applied inside the sine.
            s: Divisor applied to the adapter output; must be non-zero.
            **kwargs: Forwarded to ``LoraLayer.__init__``.

        Raises:
            ValueError: If ``s`` is zero or ``r`` is not positive.
        """
        super().__init__()
        LoraLayer.__init__(self, base_layer, **kwargs)

        # Dividing by s == 0 would silently fill the output with inf/nan.
        if s == 0:
            raise ValueError(f"s must be non-zero, got {s}")

        # Sine-specific parameters
        self.freq = freq  # ω (omega) frequency parameter
        self.s = s        # divisor applied to the adapter output

        # Constructor arguments kept for reference; `_scaling` is the plain
        # alpha / r ratio, the per-adapter value lives in `self.scaling`.
        self._r = r
        self._lora_alpha = lora_alpha
        self._scaling = lora_alpha / r if r > 0 else 1.0
        self._lora_dropout = lora_dropout

        # Per-adapter containers, keyed by adapter name
        self.lora_A = nn.ModuleDict({})
        self.lora_B = nn.ModuleDict({})
        self.r = {}
        self.lora_alpha = {}
        self.scaling = {}
        self.lora_dropout = nn.ModuleDict({})

        # Prefer the layer's declared feature sizes, else read them off the weight
        self.in_features = getattr(base_layer, "in_features", None)
        self.out_features = getattr(base_layer, "out_features", None)

        if self.in_features is None:
            self.in_features = base_layer.weight.shape[1]
        if self.out_features is None:
            self.out_features = base_layer.weight.shape[0]

        # PEFT releases whose LoraLayer takes `base_layer` expose
        # `active_adapters` as a read-only property backed by
        # `_active_adapter`; assigning to the property raises AttributeError,
        # so the backing field is set instead.
        self._active_adapter = adapter_name

        # Initialize the adapter
        self.update_layer(
            adapter_name=adapter_name,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )

    def update_layer(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float,
                     init_lora_weights: Union[bool, str], use_rslora: bool, use_dora: bool):
        """
        Create (or replace) the A/B factors and settings of one adapter.

        Args:
            adapter_name: Key under which the adapter is stored.
            r: Rank of the factors; must be positive.
            lora_alpha: Numerator of the scaling factor.
            lora_dropout: Dropout probability for the adapter input.
            init_lora_weights: Initialisation mode, see ``__init__``.
            use_rslora: Scale by ``lora_alpha / sqrt(r)`` instead of
                ``lora_alpha / r``.
            use_dora: Accepted for signature compatibility; not used here.

        Raises:
            ValueError: If ``r`` is not positive.
        """
        if r <= 0:
            raise ValueError(f"r must be positive, got {r}")

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        # Honour the rank-stabilised variant instead of silently ignoring it.
        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
        else:
            self.scaling[adapter_name] = lora_alpha / r
        self.lora_dropout[adapter_name] = nn.Dropout(p=lora_dropout)

        # A maps in_features -> r, B maps r -> out_features
        self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
        self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)

        if init_lora_weights == "loftq":
            self.loftq_init(adapter_name)
        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        self._match_base_layer_device(adapter_name)

    def _match_base_layer_device(self, adapter_name: str) -> None:
        """Move one adapter's A/B factors to the device of the base layer's weight."""
        weight = getattr(self.base_layer, "weight", None)
        # Nothing to align with, or the base weight is only a meta placeholder.
        if not isinstance(weight, torch.Tensor) or weight.device.type == "meta":
            return
        # Device only: the adapter keeps its own dtype, forward() casts the input.
        self.lora_A[adapter_name].to(weight.device)
        self.lora_B[adapter_name].to(weight.device)

    def reset_lora_parameters(self, adapter_name: str, init_lora_weights: Union[bool, str]):
        """
        Re-initialise the factors of one adapter.

        ``True`` draws A from a Kaiming-uniform distribution, ``"gaussian"``
        draws A from a normal distribution with ``std = 1 / r``. B is set to
        zero in both cases, so ``sin(freq * A @ B^T)`` starts as the zero
        matrix. ``False`` leaves the weights untouched.

        Raises:
            ValueError: If ``init_lora_weights`` is not a recognised mode.
        """
        if init_lora_weights is False:
            return

        if adapter_name in self.lora_A:
            a_weight = self.lora_A[adapter_name].weight
            if init_lora_weights is True:
                nn.init.kaiming_uniform_(a_weight, a=math.sqrt(5))
            elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "gaussian":
                nn.init.normal_(a_weight, std=1 / self.r[adapter_name])
            else:
                # Also reached for non-string values, which previously failed
                # on `.lower()` with an AttributeError.
                raise ValueError(f"Unknown init_lora_weights: {init_lora_weights}")

        if adapter_name in self.lora_B:
            # B matrix initialized to zero
            nn.init.zeros_(self.lora_B[adapter_name].weight)

    def sine_transformation(self, A_weight: torch.Tensor, B_weight: torch.Tensor) -> torch.Tensor:
        """
        Apply the sine transformation: sin(ωAB^T)

        Args:
            A_weight: Matrix A of shape (in_features, r)
            B_weight: Matrix B of shape (out_features, r)

        Returns:
            Transformed weight matrix of shape (in_features, out_features)
        """
        # (in_features, r) @ (out_features, r)^T -> (in_features, out_features)
        AB_T = A_weight @ B_weight.T

        # Element-wise sine of the frequency-scaled product
        return torch.sin(self.freq * AB_T)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Run the base layer and add the contribution of every active adapter.

        When the layer reports ``disable_adapters`` as true, only the base
        layer output is returned. The result is cast to the dtype ``x`` had on
        entry.
        """
        previous_dtype = x.dtype

        # Base layer forward pass
        result = self.base_layer(x, *args, **kwargs)

        # Respect the adapter-disable switch so that disabling adapters
        # really yields the untouched base model output.
        if not getattr(self, "disable_adapters", False):
            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A:
                    continue

                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                # Match the adapter's parameter dtype before the matmul
                x = x.to(lora_A.weight.dtype)
                x_dropped = dropout(x)

                A_weight = lora_A.weight.T  # Shape: (in_features, r)
                B_weight = lora_B.weight    # Shape: (out_features, r)

                # sin(ωAB^T), shape (in_features, out_features)
                sine_weight = self.sine_transformation(A_weight, B_weight)

                # x @ sin(ωAB^T), then normalise by s and apply the LoRA scaling
                sine_output = x_dropped @ sine_weight
                sine_output = (sine_output / self.s) * scaling

                result = result + sine_output

        return result.to(previous_dtype)

    def __repr__(self) -> str:
        """Return a short description with the layer sizes and sine settings."""
        rep = f"SineLoraLinear(in_features={self.in_features}, out_features={self.out_features}, "
        rep += f"r={self.r}, freq={self.freq}, s={self.s})"
        return rep


def dispatch_sine_lora(
    target: nn.Module,
    adapter_name: str,
    lora_config: LoraConfig,
    **kwargs: Any,
) -> Optional[nn.Module]:
    """
    Build a sine LoRA replacement for ``target`` when its type is supported.

    Args:
        target: Module that is a candidate for wrapping.
        adapter_name: Name of the adapter to create on the new layer.
        lora_config: Part of the dispatcher signature; not read here.
        **kwargs: Passed through unchanged to ``SineLoraLinear``.

    Returns:
        A ``SineLoraLinear`` wrapping ``target`` if it is an ``nn.Linear``,
        otherwise ``None``.
    """
    # Only plain linear layers are handled by this dispatcher.
    if not isinstance(target, nn.Linear):
        return None

    return SineLoraLinear(target, adapter_name, **kwargs)