import math
import torch

from typing import Tuple, Optional


def _chi(df: torch.Tensor) -> torch.Tensor:
    """
    Sample Chi(df) = sqrt(ChiSquare(df)).
    Supports non-integer df > 0 via Gamma sampling:
        ChiSquare(df) ~ Gamma(df/2, rate=1/2)
    """
    if torch.any(df <= 0):
        raise ValueError("All degrees of freedom must be > 0.")
    gamma = torch.distributions.Gamma(concentration=df / 2.0,
                                      rate=torch.full_like(df, 0.5))
    return gamma.sample().sqrt()


def dumitriu_edelman_beta_laguerre_bidiagonal(
    N: int,
    D: int,
    beta: float,
    *,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float64,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dumitriu–Edelman bidiagonal model for the beta-Laguerre ensemble.

    Returns vectors (diag, sub) describing the LOWER-bidiagonal matrix B of size N×N:
        B[i, i]   = Chi(beta*(D - i)) / sqrt(beta)          for i=0..N-1
        B[i, i-1] = Chi(beta*(N - i)) / sqrt(beta)          for i=1..N-1

    With these conventions, the (unnormalized) Laguerre/Wishart tridiagonal is:
        L = B @ B.T
    and scaling by 1/D gives the covariance-normalized version aligned with (E.1).

    Args:
        N: Matrix size; must be >= 1.
        D: Second dimension of the (D, N) Laguerre/Wishart setup; must be >= N.
        beta: Dyson index; must be > 0 (non-integer values are allowed).
        device: Device for the returned tensors (defaults to CPU).
        dtype: Floating dtype for the returned tensors.

    Returns:
        (diag, sub): ``diag`` has shape (N,), ``sub`` has shape (N-1,)
        (empty when N == 1).

    Raises:
        ValueError: If N < 1, D < N, or beta <= 0.
    """
    # Checked first: N == 0 would otherwise make torch.arange(-1, 0, -1) raise an
    # obscure step-sign RuntimeError, and a negative N would slip past D >= N.
    if N < 1:
        raise ValueError(f"Require N >= 1; got N={N}.")
    if D < N:
        raise ValueError(f"Require D >= N for the (D, N) Laguerre/Wishart setup; got D={D}, N={N}.")
    if beta <= 0:
        raise ValueError("beta must be > 0.")

    device = device or torch.device("cpu")
    beta_t = torch.tensor(beta, device=device, dtype=dtype)

    # diag dof: beta*D, beta*(D-1), ..., beta*(D-N+1). All are > 0 because D >= N.
    i = torch.arange(N, device=device, dtype=dtype)
    df_diag = beta_t * (D - i)

    # subdiag dof: beta*(N-1), ..., beta*1. Empty when N == 1.
    df_sub = beta_t * torch.arange(N - 1, 0, -1, device=device, dtype=dtype)

    # Sample diag before sub to keep the RNG consumption order stable.
    diag = _chi(df_diag) / beta_t.sqrt()
    sub = _chi(df_sub) / beta_t.sqrt()
    return diag, sub


def _haar_orthogonal(
    n: int,
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Haar(ish) random orthogonal matrix via QR of a Gaussian matrix, with sign correction.
    """
    A = torch.randn(n, n, device=device, dtype=dtype)
    Q, R = torch.linalg.qr(A)
    d = torch.sign(torch.diag(R))
    d[d == 0] = 1
    return Q * d  # column-wise sign fix


def _haar_stiefel(
    D: int,
    N: int,
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Random D×N matrix with orthonormal columns (uniform on Stiefel manifold) via reduced QR.
    """
    A = torch.randn(D, N, device=device, dtype=dtype)
    Q, R = torch.linalg.qr(A, mode="reduced")
    d = torch.sign(torch.diag(R))
    d[d == 0] = 1
    return Q * d  # column-wise sign fix


def sample_htmp_rectangular_matrix(
    D: int,
    N: int,
    kappa: float,
    *,
    randomize_singular_vectors: bool = False,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """
    Generate a D×N rectangular matrix X whose squared singular values approximate
    the HTMP_{gamma, kappa} law (in the large N,D limit), by sampling a beta-Laguerre
    ensemble in the high-temperature regime beta = kappa / N.

    - D >= N >= 1 (so gamma = N/D is in (0, 1])
    - beta = kappa / N (high-temperature scaling)
    - Builds the Dumitriu–Edelman lower-bidiagonal B (N×N), then scales it by 1/sqrt(D)
    - Returns X such that X^T X has the beta-Laguerre eigenvalues (normalized by D)

    If randomize_singular_vectors=False, X is the scaled B stacked on top of D-N rows
    of zeros. This is structured and sparse, but has the correct singular values.
    If randomize_singular_vectors=True, X = U @ B @ V^T is dense with the SAME
    singular values. Here U is a random D×N matrix with orthonormal columns and V is
    a random N×N orthogonal matrix.

    Args:
        D: Number of rows; must satisfy D >= N.
        N: Number of columns; must be >= 1 (it is the denominator of beta).
        kappa: HTMP shape parameter; must be finite and > 0.
        randomize_singular_vectors: Whether to rotate by random orthogonal factors.
        device: Device for the result (defaults to CPU).
        dtype: Floating dtype for the result.

    Returns:
        Tensor of shape (D, N).

    Raises:
        ValueError: If N < 1, D < N, kappa <= 0, or kappa is not finite.
    """
    # Checked first: N == 0 would otherwise raise ZeroDivisionError when forming beta,
    # and a negative N would slip past D >= N and fail later with a misleading message.
    if N < 1:
        raise ValueError(f"Require N >= 1; got N={N}.")
    if D < N:
        raise ValueError("Require D >= N so that gamma=N/D is in (0,1).")
    if kappa <= 0:
        raise ValueError("kappa must be > 0 (smaller kappa -> heavier-tailed HTMP).")
    # NaN passes the `<= 0` test above, and inf would give an infinite beta.
    if not math.isfinite(kappa):
        raise ValueError(f"kappa must be finite; got kappa={kappa}.")

    device = device or torch.device("cpu")

    beta = float(kappa) / float(N)  # high-temperature beta
    diag, sub = dumitriu_edelman_beta_laguerre_bidiagonal(
        N=N, D=D, beta=beta, device=device, dtype=dtype
    )

    # Scale by 1/sqrt(D) so eigenvalues of X^T X match the (E.1) scaling with exp(-beta*D/2 * lambda)
    inv_sqrt_D = 1.0 / math.sqrt(D)
    diag = diag * inv_sqrt_D
    sub = sub * inv_sqrt_D

    # Build the lower-bidiagonal core B (N×N): main diagonal plus first subdiagonal.
    B = torch.zeros((N, N), device=device, dtype=dtype)
    idx = torch.arange(N, device=device)
    B[idx, idx] = diag
    B[idx[1:], idx[:-1]] = sub

    if not randomize_singular_vectors:
        # Embed B in the top N rows. Then X^T X = B^T B, whose eigenvalues equal
        # those of B B^T.
        X = torch.zeros((D, N), device=device, dtype=dtype)
        X[:N, :] = B
        return X

    # Randomize the singular vectors without changing the singular values.
    # U is drawn before V to keep the RNG consumption order stable.
    U = _haar_stiefel(D, N, device=device, dtype=dtype)   # D×N, orthonormal columns
    V = _haar_orthogonal(N, device=device, dtype=dtype)   # N×N orthogonal
    return U @ B @ V.T