import jax.numpy as jnp
import jax

def compute_transforms(K, Q, eps=1e-6):
    """Build a pair of transforms that bring K and Q to a common matrix.

    With every matrix square root regularised by adding ``eps * I`` to its
    argument and keeping only the real part of the result, this computes::

        rootQ = sqrt(Q)
        T     = sqrt(rootQ @ K @ rootQ)
        rootT = sqrt(T)

    and returns ``(inv(rootT) @ rootQ, rootT @ inv(rootQ))``.

    NOTE(review): assuming K and Q are symmetric positive semi-definite
    (so the square roots are symmetric), the returned pair ``(A, B)``
    satisfies ``A @ K @ A.T ~= B @ Q @ B.T ~= T`` up to the ``eps``
    regularisation. ``A`` and ``B`` are themselves symmetric, and ``B`` is
    the plain inverse of ``A``, only when the factors commute (e.g. diagonal
    inputs); in general ``B`` is the inverse-transpose of ``A``. Confirm
    against how callers apply the two outputs.

    Args:
        K: square matrix, same size as ``Q``.
        Q: square matrix.
        eps: diagonal regularisation added before each square root.

    Returns:
        Tuple ``(A, B)`` as described above.
    """

    def _regularised_sqrt(mat):
        # sqrtm may return a complex array; only the real part is kept.
        shifted = mat + eps * jnp.eye(mat.shape[0])
        return jnp.real(jax.scipy.linalg.sqrtm(shifted))

    q_half = _regularised_sqrt(Q)
    # T is the (regularised) square root of Q^{1/2} K Q^{1/2}.
    t_mat = _regularised_sqrt(q_half @ K @ q_half)
    t_half = _regularised_sqrt(t_mat)

    q_half_inv = jnp.linalg.inv(q_half)
    t_half_inv = jnp.linalg.inv(t_half)

    forward = t_half_inv @ q_half
    backward = t_half @ q_half_inv
    return forward, backward
