
import torch
import torch.nn as nn
import torch.nn.functional as F

def qr_orthonormalize(W: torch.Tensor) -> torch.Tensor:
    """
    Return Q with orthonormal columns spanning the column space of W.

    W: (d x k) matrix, or a batch (..., d x k), with full column rank.
    Returns Q of the same batch shape with Q^T Q = I.

    Q comes from a reduced QR factorization. QR is only unique up to the sign
    of each diagonal entry of R, so every column of Q is flipped where needed
    to make the matching diagonal entry of R non-negative. This pins down a
    deterministic orientation for Q.
    """
    Q, R = torch.linalg.qr(W, mode='reduced')
    # Take the diagonal over the last two dims so batched input works too.
    signs = torch.sign(torch.diagonal(R, dim1=-2, dim2=-1))
    # A zero diagonal entry (rank-deficient column) carries no sign; leave
    # that column of Q untouched instead of zeroing it out.
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    # Scaling column j by signs[j] equals Q @ diag(signs) without building
    # the dense k x k diagonal matrix.
    return Q * signs.unsqueeze(-2)

def orthonormal_complement(Q: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Compute an orthonormal complement R such that [Q, R] is an orthonormal basis of R^d.

    Q: (d x k) with orthonormal columns (Q^T Q = I)
    eps: not used by the computation; kept so callers that pass it keep working.
    Returns R: (d x (d-k)) with orthonormal columns and Q^T R = 0.
    Raises ValueError if Q is not a 2-D matrix.

    The complement is read off a complete QR factorization of Q itself: the
    first k columns of the d x d orthogonal factor span range(Q), so the
    remaining d-k columns span its orthogonal complement. Q has full column
    rank, so this factorization is well-defined. Factorizing the rank-deficient
    projector I - Q Q^T is not, and its trailing columns can fall back into
    span(Q).

    Note: a complete QR of a tall matrix is, as far as PyTorch documents it,
    not differentiable, so do not rely on gradients flowing from the returned
    basis back to Q.
    """
    if Q.dim() != 2:
        raise ValueError(
            f"Q must be a 2-D (d x k) matrix, got shape {tuple(Q.shape)}"
        )
    d, k = Q.shape
    if k >= d:
        # Q already spans all of R^d: the complement is empty.
        return Q.new_empty((d, 0))
    if k == 0:
        # Nothing to complement: any orthonormal basis of R^d works.
        return torch.eye(d, device=Q.device, dtype=Q.dtype)
    Q_full, _ = torch.linalg.qr(Q, mode='complete')  # Q_full: d x d
    return Q_full[:, k:]


class LearnableQ(nn.Module):
    """
    Small MLP that maps a conditioning vector to a matrix with orthonormal columns.

    The network emits d*k numbers per sample; these are reshaped to a (d x k)
    matrix and orthonormalized with a reduced QR factorization.
    """

    def __init__(self, s_dim: int, d: int, k: int, hidden: int = 128):
        """
        s_dim: size of the conditioning input.
        d, k: row / column counts of the matrix produced per sample.
        hidden: width of the single hidden layer.
        """
        super().__init__()
        self.d = d
        self.k = k
        # Linear -> ReLU -> Linear, in this order, so sub-module indices
        # (and therefore parameter names) stay the same.
        layers = [
            nn.Linear(s_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d * k),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, s: torch.Tensor, return_R: bool = False):
        """
        s: (batch, s_dim)
        return_R: accepted but not consulted here; a single tensor is always returned.
        return: Q: (batch, d, k), the orthogonal factor of the reduced QR
        """
        flat = self.mlp(s)                            # (batch, d*k)
        stacked = flat.view(-1, self.d, self.k)       # (batch, d, k)
        # Keep only the orthogonal factor; the triangular factor is discarded.
        return torch.linalg.qr(stacked, mode='reduced')[0]