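"""
Simplified Sine LoRA implementation based on the sin(ωAB^T) update from the
original Sine LoRA paper: the low-rank product of the adapter matrices is
passed through an element-wise sine (frequency ω) before being added to the
output of a frozen linear layer.
"""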

import math
import torch
import torch.nn as nn
from typing import Any


class SineLoraLinear(nn.Module):
    """
    Sine LoRA wrapper around a frozen ``nn.Linear``.

    The adapter adds an element-wise sine of the low-rank product to the
    frozen base layer's output::

        y = base_layer(x) + (x @ sin(freq * A^T B^T)) / s * (lora_alpha / r)

    where ``A = lora_A.weight`` has shape ``(r, in_features)`` and
    ``B = lora_B.weight`` has shape ``(out_features, r)``, so the sine is
    applied element-wise to ``(B @ A)^T`` of shape
    ``(in_features, out_features)``. Because ``B`` is zero-initialised and
    ``sin(0) == 0``, a freshly constructed layer computes exactly
    ``base_layer(x)``.

    Args:
        base_layer: The linear layer to adapt. Its parameters are frozen.
        r: Rank of the low-rank factors; must be positive.
        lora_alpha: Numerator of the ``lora_alpha / r`` scaling factor.
        lora_dropout: Dropout probability applied to the adapter input only.
        freq: Sine frequency (ω) applied to the low-rank product.
        s: Divisor applied to the sine term.

    Raises:
        ValueError: If ``r`` is not positive.
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        r: int = 4,
        lora_alpha: int = 8,
        lora_dropout: float = 0.0,
        freq: int = 100,
        s: float = 45.25,
    ):
        super().__init__()
        if r <= 0:
            raise ValueError(f"r must be a positive integer, got {r!r}")

        self.base_layer = base_layer
        self.r = r
        self.lora_alpha = lora_alpha
        self.freq = freq  # ω (omega): frequency of the sine
        self.s = s        # divisor applied to the sine term
        self.scaling = lora_alpha / r

        # Allocate the adapter on the base layer's device so that wrapping a
        # model which already lives on an accelerator does not leave the
        # adapter on CPU. The adapter keeps torch's default floating dtype even
        # when the base layer is in reduced precision; forward() reconciles
        # the dtypes.
        device = base_layer.weight.device
        self.lora_A = nn.Linear(
            base_layer.in_features, r, bias=False, device=device
        )
        self.lora_B = nn.Linear(
            r, base_layer.out_features, bias=False, device=device
        )
        self.lora_dropout = nn.Dropout(p=lora_dropout)

        self.reset_parameters()

        # Only the adapter is trainable; the wrapped layer stays frozen.
        for param in base_layer.parameters():
            param.requires_grad = False

    def reset_parameters(self):
        """
        Initialise the adapter so that its initial contribution is zero.

        ``lora_A`` gets the same Kaiming-uniform init that ``nn.Linear`` uses
        by default, and ``lora_B`` is zeroed. Since ``sin(0) == 0``, the
        wrapped layer initially reproduces ``base_layer(x)`` exactly.
        """
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return ``base_layer(x) + (dropout(x) @ sin(freq * (B @ A)^T)) / s * scaling``.

        The adapter term is computed in the adapter's dtype and cast back to
        the base layer's output dtype, so reduced-precision base layers work.
        """
        result = self.base_layer(x)

        # lora_A.weight.T: (in_features, r); lora_B.weight.T: (r, out_features)
        # -> AB is (in_features, out_features) == (B @ A)^T, so x @ AB maps
        # in_features -> out_features like the base layer does.
        AB = self.lora_A.weight.T @ self.lora_B.weight.T
        sine_AB = torch.sin(self.freq * AB)

        # Match the adapter's dtype; mixing e.g. bf16 inputs with fp32 adapter
        # weights in a matmul would otherwise raise.
        x_dropped = self.lora_dropout(x.to(sine_AB.dtype))

        sine_output = x_dropped @ sine_AB
        sine_output = (sine_output / self.s) * self.scaling

        # Cast back so the sum keeps the base layer's output dtype.
        return result + sine_output.to(result.dtype)

    def __repr__(self):
        # Overrides nn.Module's default repr, so child modules are not listed.
        return (f"SineLoraLinear(in_features={self.base_layer.in_features}, "
                f"out_features={self.base_layer.out_features}, r={self.r}, "
                f"freq={self.freq}, s={self.s})")


def apply_sine_lora_to_model(model: nn.Module, config):
    """
    Replace every targeted ``nn.Linear`` in ``model`` with a ``SineLoraLinear``.

    A child module is targeted when any entry of ``config.target_modules`` is
    a substring of its dotted path (e.g. ``"q_proj"`` matches
    ``"encoder.layer.0.attn.q_proj"``). Replacement happens in place.
    Existing ``SineLoraLinear`` layers are left untouched and not descended
    into, so calling this function again on an adapted model does not wrap
    the adapter's own ``base_layer`` / ``lora_A`` / ``lora_B`` layers.

    Only the base layers of replaced modules are frozen (by
    ``SineLoraLinear``); other parameters of ``model`` keep their current
    ``requires_grad`` setting.

    Args:
        model: Module tree to modify in place.
        config: Object exposing ``target_modules`` (a str or an iterable of
            str), ``r``, ``lora_alpha``, ``lora_dropout``, ``freq`` and ``s``.

    Returns:
        The same ``model`` object.
    """
    target_modules = config.target_modules
    if isinstance(target_modules, str):
        target_modules = [target_modules]
    # Materialise once: a one-shot iterable would otherwise be exhausted by
    # the first child that is checked. None/empty means "replace nothing".
    target_modules = list(target_modules or ())

    modules_replaced = 0

    def is_target(full_name):
        # The last path component is a suffix of full_name, so an exact
        # component match is already covered by the substring test.
        return any(target in full_name for target in target_modules)

    def replace_modules(module, prefix=""):
        nonlocal modules_replaced

        for name, child in list(module.named_children()):
            full_name = f"{prefix}.{name}" if prefix else name

            # Already adapted: do not wrap the adapter's internal layers.
            if isinstance(child, SineLoraLinear):
                continue

            if isinstance(child, nn.Linear) and is_target(full_name):
                sine_layer = SineLoraLinear(
                    base_layer=child,
                    r=config.r,
                    lora_alpha=config.lora_alpha,
                    lora_dropout=config.lora_dropout,
                    freq=config.freq,
                    s=config.s,
                )
                setattr(module, name, sine_layer)
                modules_replaced += 1
                print(f"Applied SineLoRA to {full_name}")
                # The wrapped layer is now frozen inside the adapter.
                continue

            replace_modules(child, full_name)

    replace_modules(model)
    print(f"Total modules replaced with SineLoRA: {modules_replaced}")

    return model