import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinear(nn.Module):
    """Frozen ``nn.Linear`` plus a trainable low-rank (LoRA) update.

    The output is ``base_layer(x) + scaling * lora_B(h)`` where
    ``h = ln(act(lora_A(dropout(x))))`` (activation and LayerNorm optional) and
    ``scaling = lora_alpha / rank``. ``lora_B`` is zero-initialised, so a freshly
    constructed module reproduces ``base_layer`` exactly.

    Every forward pass also records ``more_last_z``: the mean of ``h`` over all
    leading dimensions, optionally restricted to the trailing
    ``more_target_lens[i]`` positions of each sample ``i``. ``more_update`` uses
    that vector for a recursive-least-squares style update of ``more_P`` and then
    applies ``more_P``-preconditioned gradient steps to the adapter weights.

    The ``more_*`` class attributes are shared by every instance and are read
    from the class itself:

    * ``more_target_lens`` - list/tuple with one entry per sample giving how many
      trailing sequence positions to average; ``None``/``<= 0`` entries select
      nothing for that sample. Any other value disables masked averaging.
    * ``more_use_masked_z`` - master switch for masked averaging.
    * ``more_batch_size`` - batch size used to unflatten 2-D
      ``(batch * seq, rank)`` activations to ``(batch, seq, rank)`` for masking.

    Instances whose ``more_group`` is ``"vision_proj"`` always use the plain mean.
    """

    more_target_lens = None
    more_use_masked_z = True
    more_batch_size = None

    def __init__(
        self,
        base_layer: nn.Linear,
        rank: int,
        lora_alpha: float,
        lora_dropout: float,
        more_nonlinear: str = "none",
        more_layer_norm: bool = False,
        more_layer_norm_eps: float = 1e-5,
    ):
        """Wrap ``base_layer`` with a rank-``rank`` adapter.

        Args:
            base_layer: Linear layer to adapt; its parameters are frozen.
            rank: Inner dimension of the adapter; must be positive.
            lora_alpha: Numerator of the output scaling ``lora_alpha / rank``.
            lora_dropout: Dropout probability on the adapter input; an identity
                is used when it is ``<= 0``.
            more_nonlinear: Case-insensitive activation applied after ``lora_A``:
                ``"none"`` (or ``None``/``""``), ``"gelu"``, ``"relu"``, ``"tanh"``,
                ``"silu"``/``"swish"``.
            more_layer_norm: If true, apply a non-affine LayerNorm over the rank
                dimension after the activation.
            more_layer_norm_eps: Epsilon for that LayerNorm.

        Raises:
            TypeError: If ``base_layer`` is not an ``nn.Linear``.
            ValueError: If ``rank`` is not positive or ``more_nonlinear`` is not
                a supported activation name.
        """
        super().__init__()
        if not isinstance(base_layer, nn.Linear):
            raise TypeError(f"LoRALinear expects nn.Linear, got {type(base_layer)}")
        if rank <= 0:
            # Fail early with a clear message instead of a ZeroDivisionError below.
            raise ValueError(f"LoRALinear rank must be positive, got {rank}")

        device = base_layer.weight.device
        dtype = base_layer.weight.dtype

        self.base_layer = base_layer
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / rank
        self.lora_dropout = nn.Dropout(lora_dropout) if lora_dropout > 0.0 else nn.Identity()
        self.more_nonlinear = (more_nonlinear or "none").lower()
        self.more_activation = self._resolve_activation(self.more_nonlinear)
        self.more_ln = None
        if more_layer_norm:
            self.more_ln = nn.LayerNorm(rank, elementwise_affine=False, eps=more_layer_norm_eps)
            self.more_ln.to(dtype=dtype, device=device)

        self.lora_A = nn.Linear(base_layer.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base_layer.out_features, bias=False)
        self.lora_A.to(dtype=dtype, device=device)
        self.lora_B.to(dtype=dtype, device=device)

        with torch.no_grad():
            if self.lora_A.weight.dtype in (torch.float16, torch.bfloat16):
                # orthogonal_ relies on a QR decomposition, which PyTorch does not
                # generally provide for half precision: initialise in float32, then cast.
                tmp = torch.empty_like(self.lora_A.weight, dtype=torch.float32)
                nn.init.orthogonal_(tmp)
                self.lora_A.weight.copy_(tmp.to(dtype=self.lora_A.weight.dtype))
            else:
                nn.init.orthogonal_(self.lora_A.weight)
        # Zero B => the adapter contributes nothing until it is trained.
        nn.init.zeros_(self.lora_B.weight)

        for p in self.base_layer.parameters():
            p.requires_grad = False

        # Preconditioner used by more_update; starts as the identity.
        self.register_buffer(
            "more_P",
            torch.eye(rank, device=device, dtype=torch.float32),
        )
        # Rank-space activation summary written by every forward pass.
        self.register_buffer(
            "more_last_z",
            torch.zeros(rank, device=device, dtype=torch.float32),
        )

        # Tags assigned by the caller; more_group == "vision_proj" disables
        # masked averaging. more_name is not read inside this class.
        self.more_group = None
        self.more_name = None

    @staticmethod
    def _resolve_activation(name: str):
        """Map an activation name to a callable, or ``None`` for no activation.

        Raises:
            ValueError: If ``name`` is not a supported activation.
        """
        if name in ("none", "", None):
            return None
        if name == "gelu":
            return F.gelu
        if name == "relu":
            return F.relu
        if name == "tanh":
            return torch.tanh
        if name in ("silu", "swish"):
            return F.silu
        raise ValueError(f"Unsupported LoRA nonlinearity: {name}")

    def forward(self, x):
        """Return ``base_layer(x) + scaling * lora_B(h)`` and refresh ``more_last_z``.

        ``h`` is ``lora_A(dropout(x))`` after the optional nonlinearity and
        LayerNorm. Its (optionally masked) mean over every leading dimension is
        computed in float32 without tracking gradients and copied into the
        ``more_last_z`` buffer.
        """
        base_out = self.base_layer(x)
        ax = self.lora_A(self.lora_dropout(x))
        if self.more_activation is not None:
            ax = self.more_activation(ax)
        if self.more_ln is not None:
            ax = self.more_ln(ax)
        bx = self.lora_B(ax)

        with torch.no_grad():
            z = self._summarize_activation(ax.detach())
            self.more_last_z.copy_(z.to(self.more_last_z.dtype))

        return base_out + self.scaling * bx

    def _uses_masked_z(self) -> bool:
        """Whether ``more_last_z`` should average only the target positions."""
        return bool(
            LoRALinear.more_use_masked_z
            and self.more_group != "vision_proj"
            and isinstance(LoRALinear.more_target_lens, (list, tuple))
        )

    @staticmethod
    def _as_batch_seq(z):
        """Return ``z`` arranged as ``(batch, seq, ...)`` for masking, or ``None``."""
        if z.dim() >= 3:
            return z
        if z.dim() == 2:
            # OPT decoder flattens sequence into 2D before MLP.
            bsz = getattr(LoRALinear, "more_batch_size", None)
            if bsz is not None and bsz > 0 and z.size(0) % bsz == 0:
                return z.reshape(bsz, z.size(0) // bsz, z.size(-1))
        return None

    @staticmethod
    def _target_mask(target_lens, seq_len, device):
        """Boolean ``(len(target_lens), seq_len)`` mask of each row's trailing positions."""
        # One start index per sample. A start of ``seq_len`` selects nothing,
        # which is how missing / non-positive lengths are skipped.
        starts = [
            seq_len if (tlen is None or tlen <= 0) else max(0, seq_len - int(tlen))
            for tlen in target_lens
        ]
        positions = torch.arange(seq_len, device=device)
        start_t = torch.tensor(starts, dtype=torch.long, device=device)
        # Single vectorized comparison instead of one slice assignment per sample.
        return positions.unsqueeze(0) >= start_t.unsqueeze(1)

    def _summarize_activation(self, ax):
        """Reduce ``ax`` to a float32 vector of length ``ax.size(-1)``.

        Uses the masked mean over target positions when masking is enabled and
        ``more_target_lens`` matches the batch size. It falls back to the mean
        over all positions when no position is selected. Otherwise it uses the
        plain mean over every leading dimension.
        """
        # Reduce in float32: ``ax`` may be half precision, and the result feeds
        # the float32 statistics used by more_update.
        z = ax.float()
        width = z.size(-1)
        if self._uses_masked_z():
            z_view = self._as_batch_seq(z)
            target_lens = LoRALinear.more_target_lens
            if z_view is not None and len(target_lens) == z_view.size(0):
                mask = self._target_mask(target_lens, z_view.size(1), z.device)
                if mask.any():
                    z_view = z_view[mask]
                return z_view.reshape(-1, width).mean(dim=0)
        return z.reshape(-1, width).mean(dim=0)

    def more_update(self, eta: float, rls_lambda: float):
        """Update ``more_P`` from ``more_last_z`` and take preconditioned steps.

        With ``z = more_last_z`` the preconditioner is downdated as
        ``P <- P - (P z)(P z)^T / (rls_lambda + z^T P z)``. Then the current
        gradients are applied in place:

        * ``lora_A.weight -= eta * P @ grad_A``
        * ``lora_B.weight -= eta * grad_B @ P``

        Nothing happens when both gradients are missing or the denominator's
        magnitude is below ``1e-12``. Gradients are not cleared here.

        Args:
            eta: Step size applied to the preconditioned gradients.
            rls_lambda: Constant added to ``z^T P z`` in the denominator.
        """
        gA = self.lora_A.weight.grad
        gB = self.lora_B.weight.grad
        if gA is None and gB is None:
            return

        z = self.more_last_z
        if z is None:
            return

        with torch.no_grad():
            z = z.to(dtype=torch.float32)
            P = self.more_P.to(dtype=torch.float32)

            Pz = P @ z
            denom = rls_lambda + torch.dot(z, Pz)
            if denom.abs().item() < 1e-12:
                return

            # NOTE(review): exponentially-weighted RLS usually also divides P by
            # rls_lambda after this rank-1 downdate; that is not done here.
            # Confirm whether rls_lambda is a forgetting factor or a regulariser.
            P = P - torch.outer(Pz, Pz) / denom
            self.more_P.copy_(P)

            if gA is not None:
                delta_A = (P @ gA.float()) * eta
                self.lora_A.weight.add_(-delta_A.to(self.lora_A.weight.dtype))
            if gB is not None:
                delta_B = (gB.float() @ P) * eta
                self.lora_B.weight.add_(-delta_B.to(self.lora_B.weight.dtype))
