import jax.numpy as jnp


def _suffix_lp_norms(x, p):
    x = jnp.asarray(x)
    n = x.shape[-1]
    mask = jnp.triu(jnp.ones((n, n), dtype=x.dtype))
    batch_x = jnp.einsum('...i, ij-> ...ij', x, mask)
    return jnp.linalg.norm(batch_x, ord=p, axis=-1)


import jax.numpy as jnp


def _suffix_lp_norms(x, p):
    x = jnp.asarray(x)
    if p == jnp.inf:
        return jnp.maximum.accumulate(jnp.abs(x)[..., ::-1], axis=-1)[..., ::-1]
    if p == -jnp.inf:
        return jnp.minimum.accumulate(jnp.abs(x)[..., ::-1], axis=-1)[..., ::-1]
    ap = jnp.abs(x) ** p
    rev_cumsum = jnp.cumsum(ap[..., ::-1], axis=-1)[..., ::-1]
    return rev_cumsum ** (1.0 / p)


def cartesian_to_lp_spherical(x, p):
    """Convert Cartesian points to l_p "spherical" coordinates.

    Args:
        x: array of shape ``(..., n)`` with ``n >= 2``. Non-floating inputs
            are promoted to at least float32.
        p: order of the norm, passed through to ``_suffix_lp_norms``.

    Returns:
        ``(r, angles)``:
        - ``r = ||x||_p`` has shape ``(...)``.
        - ``angles`` has shape ``(..., n - 1)``.
        - For ``i < n - 2``, ``angles[..., i] = arctan2(||x[..., i+1:]||_p, x[..., i])``.
        - The last angle is ``arctan2(x[..., -1], x[..., -2])``.
        - All angles are set to 0 where ``r == 0``.

    Raises:
        ValueError: if ``x`` is 0-d or its last axis has fewer than 2 entries.
    """
    x = jnp.asarray(x)
    x = x.astype(jnp.result_type(x, jnp.float32))
    if x.ndim == 0 or x.shape[-1] < 2:
        # With n < 2 there is no (x[-2], x[-1]) pair. The x[..., -2] lookup
        # below would be out of range, and JAX may clamp such reads instead of
        # raising, which would yield a meaningless angle.
        raise ValueError(
            "cartesian_to_lp_spherical needs at least 2 coordinates on the "
            f"last axis, got shape {x.shape}"
        )
    n = x.shape[-1]
    # tail[..., i] = ||x[..., i:]||_p; tail[..., 0] is the full radius.
    tail = _suffix_lp_norms(x, p)
    r = tail[..., 0]
    if n > 2:
        # Each leading angle separates x_i from the norm of everything after it.
        theta = jnp.arctan2(tail[..., 1:-1], x[..., : n - 2])
    else:
        theta = jnp.zeros(x.shape[:-1] + (0,), x.dtype)
    phi = jnp.arctan2(x[..., -1], x[..., -2])
    angles = jnp.concatenate([theta, phi[..., None]], axis=-1)
    # The origin has no direction; report all-zero angles there.
    angles = jnp.where(r[..., None] > 0, angles, jnp.zeros_like(angles))
    return r, angles


def _lp_unit_pair(angle, p):
    """Return ``(cos(angle), sin(angle))`` rescaled to unit l_p norm.

    The pair is stacked on a new trailing axis, so the result has shape
    ``angle.shape + (2,)``.
    """
    pair = jnp.stack([jnp.cos(angle), jnp.sin(angle)], axis=-1)
    norm = jnp.linalg.norm(pair, ord=p, axis=-1, keepdims=True)
    # Defensive guard: never divide by a zero norm.
    return jnp.where(norm > 0, pair / norm, 0.0)


def lp_spherical_to_cartesian(r, angles, p):
    """Map l_p "spherical" coordinates back to Cartesian coordinates.

    Args:
        r: radius (or radii), broadcastable against ``angles[..., 0]``.
            A scalar radius with batched angles, or per-point radii with one
            shared angle vector, are both accepted.
        angles: array of shape ``(..., n - 1)`` with ``n >= 2``. The first
            ``n - 2`` entries are the leading angles ``theta``; the last entry
            is ``phi``, the angle of the final coordinate pair.
        p: order of the norm used to rescale each ``(cos, sin)`` pair.

    Returns:
        Array of shape ``broadcast(r.shape, angles.shape[:-1]) + (n,)``.
        Each ``(cos, sin)`` pair is rescaled to unit l_p norm before use.

    Raises:
        ValueError: if ``angles`` is 0-d or has an empty last axis.
    """
    r = jnp.asarray(r).astype(jnp.result_type(r, jnp.float32))
    ang = jnp.asarray(angles).astype(jnp.result_type(angles, jnp.float32))
    if ang.ndim == 0 or ang.shape[-1] == 0:
        raise ValueError(
            "lp_spherical_to_cartesian needs at least one angle on the last "
            f"axis, got shape {ang.shape}"
        )
    theta = ang[..., :-1]
    phi = ang[..., -1]
    last_pair = _lp_unit_pair(phi, p)  # (..., 2)
    if theta.shape[-1] == 0:
        # n == 2: the only coordinates are the final (x_{n-1}, x_n) pair.
        return r[..., None] * last_pair
    cs = _lp_unit_pair(theta, p)  # (..., n - 2, 2)
    c = cs[..., 0]
    s = cs[..., 1]
    s_cum = jnp.cumprod(s, axis=-1)
    # x_i = r * s_0 * ... * s_{i-1} * c_i. The leading 1 is built from `s`
    # (not from `r`) so that r is free to broadcast against the angle batch.
    s_prefix = jnp.concatenate([jnp.ones_like(s[..., :1]), s_cum[..., :-1]], axis=-1)
    head = r[..., None] * s_prefix * c
    # The final pair carries the product of every leading sine.
    tail = (r * s_cum[..., -1])[..., None] * last_pair
    return jnp.concatenate([head, tail], axis=-1)

if __name__ == '__main__':
    import jax
    import matplotlib.pyplot as plt

    def wang(x, alpha):
        """Wang transform ``Phi(Phi^{-1}(x) + alpha)``, Phi = standard normal CDF.

        ``x`` is clipped to ``[1e-6, 1 - 1e-6]`` first so the inverse CDF
        (``ppf``) stays finite at the endpoints 0 and 1.
        """
        std_normal = jax.scipy.stats.norm
        return std_normal.cdf(std_normal.ppf(x.clip(1e-6, 1 - 1e-6)) + alpha)

    def risk_measure(x):
        """Weighted sum of the identity (weight 0.5) and ``wang(x, -1)``."""
        return 0.5 * x + wang(x, -1)

    # 5000 uniform samples on the unit square, pushed through risk_measure.
    key = jax.random.PRNGKey(0)
    u = jax.random.uniform(key, shape=(5000, 2))
    u_out = risk_measure(u)

    plt.scatter(u_out[..., 0], u_out[..., 1])
    plt.show()



