import jax
import jax.numpy as jnp
from jax import lax
from jax.flatten_util import ravel_pytree

def pcgrad_pytree(grads, *, key=None, eps=1e-12, reduction="mean"):
    """Combine per-task gradients with PCGrad (projecting conflicting gradients).

    Each task gradient ``g_i`` is compared in turn with every other task's
    *original* gradient ``g_j`` (``j != i``). Whenever the running, already
    projected ``g_i`` has a negative inner product with ``g_j``, the component
    of ``g_i`` along ``g_j`` is removed::

        g_i <- g_i - (g_i . g_j) / (||g_j||^2 + eps) * g_j

    The projected task gradients are then summed or averaged.

    Args:
        grads: Non-empty sequence of gradient pytrees, one per task. All
            pytrees must have the same tree structure and the same leaf shapes.
        key: Optional PRNG key. If given, task ``i`` visits the other tasks in
            a random order drawn from ``jax.random.fold_in(key, i)``. Otherwise
            tasks are visited in index order.
        eps: Added to ``||g_j||^2`` in the projection denominator.
        reduction: ``"mean"`` or ``"sum"`` over the projected task gradients.

    Returns:
        A pytree with the structure of ``grads[0]`` holding the combined
        gradient.

    Raises:
        ValueError: If ``grads`` is empty, ``reduction`` is not ``"mean"`` or
            ``"sum"``, or some ``grads[i]`` does not match ``grads[0]``'s tree
            structure / leaf shapes.
    """
    if len(grads) == 0:
        raise ValueError("grads must be non-empty")
    # Validate the option up front rather than after all projections have
    # been computed (or traced, under jit).
    if reduction not in ("mean", "sum"):
        raise ValueError("reduction must be 'mean' or 'sum'")
    T = len(grads)

    # Every task is flattened on its own but the result is unflattened with
    # grads[0]'s layout. Trees with equal total size but a different layout
    # would otherwise be mixed element-by-element without any error.
    ref_treedef = jax.tree_util.tree_structure(grads[0])
    ref_shapes = [jnp.shape(leaf) for leaf in jax.tree_util.tree_leaves(grads[0])]
    for i in range(1, T):
        treedef_i = jax.tree_util.tree_structure(grads[i])
        shapes_i = [jnp.shape(leaf) for leaf in jax.tree_util.tree_leaves(grads[i])]
        if treedef_i != ref_treedef or shapes_i != ref_shapes:
            raise ValueError(
                f"grads[{i}] does not match the tree structure / leaf shapes of grads[0]"
            )

    g0_vec, unravel = ravel_pytree(grads[0])
    vecs = [g0_vec] + [ravel_pytree(grads[i])[0] for i in range(1, T)]
    G = jnp.stack(vecs, axis=0)  # (T, D)

    def project_if_conflict(gi, gj):
        # Remove gi's component along gj only when the two conflict (dot < 0).
        dot = jnp.vdot(gi, gj)
        denom = jnp.vdot(gj, gj) + eps
        coeff = dot / denom
        return lax.cond(dot < 0.0, lambda _: gi - coeff * gj, lambda _: gi, operand=None)

    def adjust_one(i, G):
        # Project task i's gradient against every other task's original gradient.
        gi = G[i]
        if key is None:
            order = jnp.arange(T, dtype=jnp.int32)
        else:
            # A distinct, reproducible visiting order per task.
            order = jax.random.permutation(jax.random.fold_in(key, i), T)

        def body(t, carry):
            j = order[t]
            gj = G[j]
            # Skip the task itself. The running projected vector is never
            # projected against its own original gradient.
            return lax.cond(j == i, lambda _: carry, lambda _: project_if_conflict(carry, gj), operand=None)

        return lax.fori_loop(0, T, body, gi)

    A = jnp.stack([adjust_one(i, G) for i in range(T)], axis=0)  # (T, D)

    out = A.sum(axis=0)
    if reduction == "mean":
        out = out / jnp.array(T, dtype=out.dtype)

    return unravel(out)

def pcgrad_dict(grads_by_key, keys, *, key=None, eps=1e-12, reduction="mean"):
    """Apply ``pcgrad_pytree`` to gradients stored in a mapping.

    Args:
        grads_by_key: Mapping from task identifier to that task's gradient
            pytree.
        keys: Iterable of task identifiers. Their iteration order determines
            the task order passed to ``pcgrad_pytree``.
        key: Optional PRNG key, forwarded unchanged.
        eps: Projection-denominator epsilon, forwarded unchanged.
        reduction: ``"mean"`` or ``"sum"``, forwarded unchanged.

    Returns:
        Whatever ``pcgrad_pytree`` returns for the gathered gradients.
    """
    ordered_grads = []
    for task_name in keys:
        ordered_grads.append(grads_by_key[task_name])
    return pcgrad_pytree(
        ordered_grads,
        key=key,
        eps=eps,
        reduction=reduction,
    )
