import jax
import jax.numpy as jnp
from jaxtyping import PyTree

from egxc.utils.typing import FloatTxT, NpFloatTxT


# TODO: make this a wrapper
def breakpoint_if_nonfinite(x):
    """Open the JAX debugger when ``x`` contains a NaN or an infinity.

    The branch is chosen with ``jax.lax.cond`` rather than a Python ``if``:
    when every entry of ``x`` is finite nothing happens, otherwise
    ``jax.debug.breakpoint()`` is invoked. Nothing is returned.

    Args:
        x: Array-like value accepted by ``jnp.isfinite``.
    """

    def _enter_debugger(x):
        # The operand keeps the name ``x`` so the same name is available
        # in this frame when the debugger opens.
        jax.debug.breakpoint()

    def _carry_on(x):
        # All entries are finite: deliberately do nothing.
        return None

    all_finite = jnp.all(jnp.isfinite(x))
    jax.lax.cond(all_finite, _carry_on, _enter_debugger, x)

def print_keys(x: PyTree):
    """Print the keys of ``x`` and of every dict nested in its values.

    Only ``dict`` instances are inspected: the keys view of ``x`` is printed
    first, then each value is visited in iteration order (depth-first).
    Any non-dict value, including lists or tuples that contain dicts, is
    ignored.
    """
    if not isinstance(x, dict):
        return
    print(x.keys())
    for child in x.values():
        print_keys(child)


def print_condition_number(A: FloatTxT | NpFloatTxT, label: str) -> None:
    singular_values = jnp.linalg.svd(A, compute_uv=False)
    s_min = singular_values[-1]
    s_max = singular_values[0]
    cond_A = s_max / jnp.maximum(s_min, jnp.finfo(A.dtype).eps)
    jax.debug.print(
        '{label}: min_sv={s_min:.3e}, max_sv={s_max:.3e}, cond={cond_A:.3e}',
        label=label,
        s_min=s_min,
        s_max=s_max,
        cond_A=cond_A,
    )
