import jax
import jax.numpy as jnp
import jax.random as jrnd


def so_matrix(key: jax.Array, n: int) -> jnp.ndarray:
    """Sample a random rotation matrix distributed uniformly (Haar) on SO(n).

    Args:
        key: JAX PRNG key.
        n: Matrix dimension.

    Returns:
        An ``(n, n)`` orthogonal matrix with determinant ``+1``.
    """
    a = jrnd.normal(key, (n, n))
    q, r = jnp.linalg.qr(a)
    # QR is only unique up to the signs of R's diagonal, and the sign choice
    # made by the factorisation biases Q away from the Haar measure.
    # Normalising so that diag(R) > 0 (rescaling Q's columns to match) makes
    # Q Haar-distributed on O(n). Zero signs (probability zero) are mapped to
    # +1 so a column can never be wiped out.
    col_signs = jnp.sign(jnp.diagonal(r))
    col_signs = jnp.where(col_signs == 0, 1.0, col_signs)
    q = q * col_signs[None, :]
    # q now has det = +/-1. Flipping the last column when det = -1 maps the
    # reflection coset bijectively onto SO(n). Haar measure is right-invariant,
    # so the result is uniform on SO(n).
    det_sign = jnp.linalg.slogdet(q).sign
    return q.at[:, -1].multiply(det_sign)


def commutation_matrix(m: int, n: int) -> jnp.ndarray:
    """Create the commutation matrix ``K`` for ``m x n`` matrices.

    ``K @ vec(A) == vec(A.T)`` for every ``A`` of shape ``(m, n)``, where
    ``vec`` stacks columns (column-major / Fortran order, i.e.
    ``vec(A) == A.T.reshape(-1)``). With row-major flattening the roles swap:
    ``K.T @ A.reshape(-1) == A.T.reshape(-1)``.

    Args:
        m: Number of rows of ``A``.
        n: Number of columns of ``A``.

    Returns:
        An ``(m * n, m * n)`` 0/1 permutation matrix. A zero dimension yields
        a ``(0, 0)`` matrix.
    """
    # Element A[i, j] sits at position j*m + i in vec(A) and at position
    # i*n + j in vec(A.T), so K has a single 1 at [i*n + j, j*m + i].
    # Integer aranges keep the indexers integer-typed even when m or n is 0.
    i, j = jnp.meshgrid(jnp.arange(m), jnp.arange(n), indexing="ij")
    row_indices = (i * n + j).ravel()
    col_indices = (j * m + i).ravel()

    size = m * n
    return jnp.zeros((size, size)).at[row_indices, col_indices].set(1.0)


def circulant_matrix(row: jnp.ndarray) -> jnp.ndarray:
    """Create a circulant matrix whose first row is ``row``.

    Entry ``[r, c]`` of the result is ``row[(c - r) % n]``, so each row is the
    previous one cyclically shifted one position to the right.

    Args:
        row: Array of shape ``(..., n)``. Any leading axes are treated as
            batch dimensions, producing one circulant matrix per vector.

    Returns:
        Array of shape ``(..., n, n)`` with the same dtype as ``row``.
    """
    n = row.shape[-1]
    # (n, n) table of source positions; broadcasting replaces an explicit repeat.
    offsets = jnp.arange(n)
    circular_indices = (offsets[None, :] - offsets[:, None]) % n
    # Gather along the LAST axis. A plain ``row[circular_indices]`` gathers
    # along axis 0, which is the batch axis for >1-D input, and JAX clamps the
    # resulting out-of-range indices silently instead of raising.
    return row[..., circular_indices]


def orthogonal_matrix(key: jax.Array, size: int) -> jnp.ndarray:
    """Sample a random ``size x size`` orthogonal matrix via QR.

    A Gaussian matrix is factorised as ``Q R``. Each column of ``Q`` is then
    multiplied by the sign of the matching diagonal entry of ``R``, which
    makes the factorisation unique (positive diagonal in ``R``) and the
    returned ``Q`` Haar-distributed on O(size).
    """
    gaussian = jrnd.normal(key, (size, size))
    q_factor, r_factor = jnp.linalg.qr(gaussian)
    column_signs = jnp.sign(jnp.diag(r_factor))
    return q_factor * column_signs[None, :]


def permutation_matrix(key: jax.Array, size: int) -> jnp.ndarray:
    """Sample a random ``size x size`` permutation matrix.

    With ``perm = jrnd.permutation(key, size)``, row ``i`` of the result is the
    standard basis vector ``e_{perm[i]}``, so ``mat @ x == x[perm]``.
    """
    shuffled = jrnd.permutation(key, size)
    # Place a single 1 per row, at column shuffled[i].
    return jnp.zeros((size, size)).at[jnp.arange(size), shuffled].set(1.0)


def signed_permutation_matrix(
    key: jax.Array,
    size: int,
) -> jnp.ndarray:
    """Sample a random ``size x size`` signed permutation matrix.

    The key is split in two. The first half drives ``permutation_matrix``; the
    second draws one sign from ``{-1, 1}`` per row, and each row of the
    permutation matrix is scaled by its sign.
    """
    perm_key, sign_key = jrnd.split(key)
    # Column vector of per-row signs; broadcasts across each row.
    row_signs = jrnd.choice(sign_key, jnp.array([-1, 1]), (size, 1))
    return permutation_matrix(perm_key, size) * row_signs
