import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SBVILinear(nn.Module):
    """
    Implementation of the Spectral Bridge Variational Inference (SBVI) Layer.

    This layer implements the Factorized Retraction (Sec 3.3) and the
    Adaptive Spectral Friction mechanism (Sec 3.4).

    The forward pass computes ``y = x @ W0^T + scaling * (x @ B) @ A^T``, so the
    effective weight in the ``[out_features, in_features]`` layout used by
    ``F.linear`` is ``W0 + scaling * A @ B^T``. Rank component ``k`` consists of
    column ``b_k`` of ``lora_B`` and column ``a_k`` of ``lora_A`` and carries its
    own friction coefficient ``lambda_k`` (the ``friction`` buffer).

    Args:
        in_features: Input dimension (d_in).
        out_features: Output dimension (d_out).
        r_init: Number of rank components; must be positive.
        lora_alpha: Scaling numerator; the adapter output is multiplied by
            ``lora_alpha / r_init``.
        lambda_max: Upper clip applied to the per-rank friction target.
        gamma: Multiplier in the friction target ``gamma * D / (E_k + epsilon)``.
        rho: Weight given to the new friction target in the EMA update; must
            lie in [0, 1].
        epsilon: Added to the spectral energy to avoid division by zero.
        init_std: Standard deviation of the Gaussian initialization of
            ``lora_B`` (the ``sigma`` in ``B ~ N(0, sigma^2)``). Defaults to 1.0,
            the value that was previously hard-coded.

    Raises:
        ValueError: If ``r_init`` is not positive or ``rho`` is outside [0, 1].
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            r_init: int = 64,
            lora_alpha: int = 16,
            lambda_max: float = 100.0,
            gamma: float = 1.0,
            rho: float = 0.99,
            epsilon: float = 1e-6,
            init_std: float = 1.0,
    ):
        super().__init__()
        # Validate up front: r_init == 0 would otherwise surface as an opaque
        # ZeroDivisionError in the scaling computation below.
        if r_init <= 0:
            raise ValueError(f"r_init must be a positive integer, got {r_init}")
        # With rho outside [0, 1] the EMA in update_friction_stats stops being
        # a convex combination and can drive friction negative or overshoot.
        if not 0.0 <= rho <= 1.0:
            raise ValueError(f"rho must lie in [0, 1], got {rho}")

        self.r = r_init
        self.in_features = in_features
        self.out_features = out_features
        self.scaling = lora_alpha / r_init

        # Hyperparameters for the Spectral Bridge (Table 4 in Appendix)
        self.lambda_max = lambda_max
        self.gamma = gamma
        self.rho = rho
        self.epsilon = epsilon

        # ====================================================================
        # 1. Initialization (Algorithm 1, Line 1)
        # Consistent with Sec 3.3: B is Down-proj (din -> r), A is Up-proj (r -> dout)
        # B ~ N(0, init_std^2), A = 0  -> the adapter contributes nothing at step 0.
        # ====================================================================
        self.lora_B = nn.Parameter(torch.zeros(in_features, r_init))
        self.lora_A = nn.Parameter(torch.zeros(out_features, r_init))

        # Initialize B with Gaussian, A with Zeros
        nn.init.normal_(self.lora_B, mean=0.0, std=init_std)
        nn.init.zeros_(self.lora_A)

        # Freeze the base weight (simulated here, in practice handled by PEFT lib).
        # Shape [dout, din], matching the weight layout expected by F.linear.
        self.base_weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)

        # Spectral Friction Coefficients (lambda_k), one per rank component.
        # Registered as buffer because they are updated via Empirical Bayes rule, not SGD
        self.register_buffer('friction', torch.zeros(r_init))

    def compute_spectral_energy(self):
        """
        Compute E_k = ||a_k||^2 + ||b_k||^2 for each rank k.

        Corresponds to Eq. (14) denominator / Algorithm 1 Line 4.

        Returns:
            Tensor of shape [r] with the per-rank squared column norms of
            ``lora_A`` plus those of ``lora_B``. Differentiable w.r.t. both.
        """
        # lora_A: [dout, r], lora_B: [din, r]. Summing over dim 0 leaves one value per rank.
        energy_A = torch.sum(self.lora_A ** 2, dim=0)
        energy_B = torch.sum(self.lora_B ** 2, dim=0)
        return energy_A + energy_B

    def update_friction_stats(self):
        """
        Perform the Empirical Bayes update for spectral friction in place.

        Corresponds to Eq. (15) and Algorithm 1 Line 5:

            lambda_k^{t+1} <- (1-rho)*lambda_k^{t} + rho * min(lambda_max, gamma*D / (E_k + eps))

        where ``D = in_features + out_features``. The update runs under
        ``torch.no_grad()`` and mutates the ``friction`` buffer.
        """
        with torch.no_grad():
            E_k = self.compute_spectral_energy()
            D = self.in_features + self.out_features

            # The Empirical Bayes optimal lambda (Eq. 14) with stability clipping.
            # E_k >= 0, so the target is non-negative and only needs an upper clip.
            lambda_optimal = (self.gamma * D) / (E_k + self.epsilon)
            lambda_target = torch.clamp(lambda_optimal, max=self.lambda_max)

            # Temporal Smoothing (EMA). lerp_(end, w) computes
            # (1 - w) * self + w * end, i.e. exactly the formula above, and
            # updates the registered buffer in place rather than swapping its
            # storage through ``.data`` assignment.
            # NOTE(review): rho weights the *new* target here, so the default
            # rho=0.99 gives very little smoothing; confirm against Eq. (15).
            self.friction.lerp_(lambda_target, self.rho)

    def get_regularization_loss(self):
        """
        Compute the surrogate regularization term for the ELBO.

        Corresponds to the regularization term in Eq. (13) / Algorithm 1 Line 7:

            Reg = sum_k (lambda_k / 2) * (||a_k||^2 + ||b_k||^2)

        Returns:
            Scalar tensor; gradients flow into ``lora_A`` / ``lora_B`` only.
        """
        E_k = self.compute_spectral_energy()
        # self.friction is treated as a constant (detached) during backprop
        # because it is updated via the Empirical Bayes step, not gradient descent.
        reg_term = torch.sum(0.5 * self.friction.detach() * E_k)
        return reg_term

    def forward(self, x):
        """
        Apply the frozen base projection plus the scaled low-rank adapter.

        Equivalent to ``F.linear(x, W0 + scaling * A @ B^T)``, where
        ``W0 = base_weight`` has shape [dout, din], ``B`` is [din, r] and
        ``A`` is [dout, r].
        """
        base_out = F.linear(x, self.base_weight)
        # Multiply left-to-right: [..., din] @ [din, r] -> [..., r], then
        # @ [r, dout] -> [..., dout]. This never materializes the dense
        # [dout, din] update.
        lora_out = (x @ self.lora_B @ self.lora_A.T) * self.scaling
        return base_out + lora_out


# ====================================================================
# Usage Example (minimal runnable training loop on dummy data)
# ====================================================================
if __name__ == "__main__":
    # 1. Setup Model
    model = SBVILinear(in_features=4096, out_features=4096, r_init=64)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

    # Dummy Data
    x = torch.randn(16, 4096)
    target = torch.randn(16, 4096)

    # A rank counts as "active" while its friction stays below 90% of the clip
    # value. The friction target is inversely proportional to the rank's
    # spectral energy, so friction near lambda_max means the rank has little
    # energy left. The regularizer then pushes it further toward zero.
    active_threshold = 0.9 * model.lambda_max

    # 2. Training Loop (Algorithm 1)
    for step in range(100):
        optimizer.zero_grad()

        # --- Step A: Update Friction (Algorithm 1, Line 5) ---
        # This aligns the prior geometry with the current signal energy
        model.update_friction_stats()

        # --- Step B: Forward & Loss (Algorithm 1, Line 6-7) ---
        prediction = model(x)
        task_loss = F.mse_loss(prediction, target)

        # Add the Spectral Bridge Regularization
        reg_loss = model.get_regularization_loss()

        total_loss = task_loss + reg_loss

        # --- Step C: Backward & Update (Algorithm 1, Line 8-9) ---
        total_loss.backward()
        optimizer.step()

        if step % 10 == 0:
            # Monitoring: effective rank. .item() turns the 0-dim count tensor
            # into a plain int so the log reads "Active Ranks=64", not "tensor(64)".
            active_ranks = (model.friction < active_threshold).sum().item()
            print(f"Step {step}: Loss={total_loss.item():.4f}, Active Ranks={active_ranks}")