

from dataclasses import dataclass
from typing import Dict, Tuple, Optional, Sequence

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def _safe_norm(x, axis=None, eps=1e-12):
    return jnp.sqrt(jnp.maximum(jnp.sum(x * x, axis=axis), eps))


def _normalize(x, eps=1e-12):
    """Scale every vector along the last axis of ``x`` to unit Euclidean length.

    The denominator comes from ``_safe_norm``. An all-zero vector is therefore
    divided by ``sqrt(eps)`` instead of 0, so it stays zero rather than turning
    into NaN.

    Args:
        x: Array whose last axis holds the vectors to normalize.
        eps: Floor on the squared norm, forwarded to ``_safe_norm``.

    Returns:
        Array with the same shape as ``x``.
    """
    lengths = _safe_norm(x, axis=-1, eps=eps)
    return x / jnp.expand_dims(lengths, axis=-1)


def _oja_update(U, g, lr):
    U_new = U + lr * jnp.outer(g, g @ U)
    Q, _ = jnp.linalg.qr(U_new, mode="reduced")
    return Q


def _pairwise_overlap(U_all):
    S = jnp.einsum("ipr,jps->ijrs", U_all, U_all)
    fro2 = jnp.einsum("ijrs,ijrs->ij", S, S)
    r = U_all.shape[-1]
    return fro2 / r



def _pairwise_cosine(G, eps=1e-12):
    """Return the cosine-similarity matrix between the rows of ``G``.

    Args:
        G: Matrix with one vector per row, e.g. ``(k, p)`` task gradients.
        eps: Norm floor forwarded to ``_normalize``. All-zero rows yield
            cosine 0 with everything.

    Returns:
        ``(k, k)`` matrix of pairwise cosines.
    """
    unit_rows = _normalize(G, eps=eps)
    return jnp.matmul(unit_rows, unit_rows.T)


def _consensus_subspace(U_all, r=None):
    k, p, r0 = U_all.shape
    r = r0 if r is None else r
    Q = jnp.reshape(jnp.swapaxes(U_all, 0, 1), (p, k * r0))
    U, S, _ = jnp.linalg.svd(Q, full_matrices=False)
    U_star = U[:, :r]
    mu = S * S
    gap = jnp.where(k * r0 > r, mu[r - 1] - mu[r], mu[r - 1])
    return U_star, gap, mu


def _global_alignment(U_all, U_star):
    B = jnp.einsum("pa,kpb->kab", U_star, U_all)
    fro2 = jnp.einsum("kab,kab->k", B, B)
    r = U_star.shape[1]
    return jnp.mean(fro2) / r


def _captured_energy(G, U, eps=1e-12):
    proj = G @ U
    num = jnp.sum(proj * proj, axis=-1)
    den = jnp.sum(G * G, axis=-1) + eps
    return num / den


def grads_dict_to_matrix(grads_dict: Dict[str, object], keys: Sequence[str]) -> jnp.ndarray:
    """Flatten each task's gradient pytree into one row of a matrix.

    Args:
        grads_dict: Mapping from task name to a gradient pytree.
        keys: Task names in the desired row order. Any iterable of names is
            accepted, e.g. a list, tuple or ``dict.keys()``.

    Returns:
        Array of shape ``(len(keys), p)``. Row ``i`` is
        ``ravel_pytree(grads_dict[keys[i]])[0]``.

    Raises:
        KeyError: If a name in ``keys`` is missing from ``grads_dict``.
        ValueError: If ``keys`` is empty, or the flattened gradients do not
            all have the same length.
    """
    rows = [ravel_pytree(grads_dict[name])[0] for name in keys]
    if not rows:
        raise ValueError("grads_dict_to_matrix requires at least one key")
    sizes = [int(row.shape[0]) for row in rows]
    if len(set(sizes)) != 1:
        raise ValueError(f"flattened gradients have mismatched sizes: {sizes}")
    return jnp.stack(rows, axis=0)


def init_bases(key: jax.Array, params: object, keys: Sequence[str], r: int) -> jnp.ndarray:
    """Draw one random orthonormal basis per task.

    Each basis is the reduced-QR ``Q`` factor of a standard-normal
    ``(p, r)`` matrix. Here ``p`` is the flattened size of ``params``, and
    each task gets its own key from ``jax.random.split``. Only ``len(keys)``
    is used; the names themselves are not read.

    Args:
        key: PRNG key.
        params: Parameter pytree; only its flattened length matters.
        keys: Task names; one basis is produced per entry.
        r: Requested number of basis columns.

    Returns:
        Array of shape ``(len(keys), p, min(p, r))``. The last axis size
        follows from reduced QR.
    """
    num_params = ravel_pytree(params)[0].shape[0]
    per_task_keys = jax.random.split(key, len(keys))

    def _random_orthonormal(subkey):
        gaussian = jax.random.normal(subkey, (num_params, r))
        basis, _ = jnp.linalg.qr(gaussian, mode="reduced")
        return basis

    return jax.vmap(_random_orthonormal)(per_task_keys)


@dataclass
class SubspaceAlignState:
    """Per-task subspace bases plus bookkeeping for the Oja tracker.

    Built by ``init_state`` and replaced (not mutated) by ``update_state``.

    NOTE(review): no pytree registration is visible here. If none exists
    elsewhere, instances cannot be passed directly into or out of
    ``jax.jit``-compiled functions. Confirm whether that is intended.
    """

    # Task names. They fix the row order of the gradient matrix built in
    # update_state and the leading axis of U_all.
    keys: Tuple[str, ...]
    # Rank requested at init_state time; compute_metrics uses it as the
    # consensus rank.
    r: int
    # Stacked per-task bases, shape (len(keys), p, min(p, r)) as produced by
    # init_bases. After update_state, each basis is the Q factor of a QR.
    U_all: jnp.ndarray
    # Number of update_state calls applied so far (int32 scalar, starts at 0).
    step: jnp.ndarray


def init_state(
    key: jax.Array,
    params: object,
    keys: Sequence[str],
    r: int,
) -> SubspaceAlignState:
    """Create the initial tracker state with random orthonormal bases.

    Args:
        key: PRNG key forwarded to ``init_bases``.
        params: Parameter pytree whose flattened size sets the basis height.
        keys: Task names; stored as a tuple in the returned state.
        r: Requested basis rank; stored as given.

    Returns:
        A ``SubspaceAlignState`` with ``step`` set to an int32 zero.
    """
    task_keys = tuple(keys)
    return SubspaceAlignState(
        keys=task_keys,
        r=r,
        U_all=init_bases(key, params, task_keys, r),
        step=jnp.array(0, dtype=jnp.int32),
    )


def update_state(
    st: SubspaceAlignState,
    grads_dict: Dict[str, object],
    lr: float,
    normalize_grads: bool = True,
    eps: float = 1e-12,
) -> Tuple[SubspaceAlignState, jnp.ndarray]:
    """Advance every task basis by one Oja step using the current gradients.

    Args:
        st: Current tracker state.
        grads_dict: Mapping from task name to gradient pytree; read in
            ``st.keys`` order.
        lr: Oja step size, shared by all tasks.
        normalize_grads: If true, each gradient row is scaled to unit length
            before the Oja step.
        eps: Norm floor used when normalizing.

    Returns:
        ``(new_state, G)``. ``new_state`` has updated bases and ``step + 1``.
        ``G`` is the raw (never normalized) ``(len(st.keys), p)`` gradient
        matrix; normalization only affects the Oja step.
    """
    G = grads_dict_to_matrix(grads_dict, st.keys)
    if normalize_grads:
        G_use = _normalize(G, eps=eps)
    else:
        G_use = G
    # Map over tasks: U_all[t] and G_use[t] per task, lr broadcast to all.
    oja_per_task = jax.vmap(_oja_update, in_axes=(0, 0, None))
    new_bases = oja_per_task(st.U_all, G_use, lr)
    next_state = SubspaceAlignState(
        keys=st.keys,
        r=st.r,
        U_all=new_bases,
        step=st.step + 1,
    )
    return next_state, G


def _offdiag_entries(M):
    """Return the off-diagonal entries of the square matrix ``M`` in row-major order.

    The gather indices are built from Python ints, so the output shape is
    static. It yields the same entries, in the same order, as
    ``M[~jnp.eye(k, dtype=bool)]``. Unlike that boolean-mask form, it also
    works when traced under ``jax.jit``.
    """
    k = M.shape[0]
    pairs = [(i, j) for i in range(k) for j in range(k) if i != j]
    rows = jnp.asarray([i for i, _ in pairs], dtype=jnp.int32)
    cols = jnp.asarray([j for _, j in pairs], dtype=jnp.int32)
    return M[rows, cols]


def compute_metrics(
    st: SubspaceAlignState,
    G: jnp.ndarray,
    eps: float = 1e-12,
) -> Dict[str, jnp.ndarray]:
    """Compute pairwise and consensus alignment diagnostics for the tracked bases.

    Args:
        st: Tracker state. Its bases ``st.U_all`` and rank ``st.r`` are used.
        G: Per-task gradient matrix with one row per entry of ``st.keys``,
            e.g. the ``G`` returned by ``update_state``.
        eps: Numerical floor for norms and denominators.

    Returns:
        Dict with these entries:

        * ``A_pairwise``: ``(k, k)`` basis overlap ``||U_i^T U_j||_F^2 / r``.
        * ``cos_pairwise``: ``(k, k)`` cosine similarity between gradient rows.
        * ``offA`` / ``offC``: the off-diagonal entries of those matrices, in
          row-major order.
        * ``A_mean_offdiag`` / ``A_min_offdiag``: summaries of ``offA``.
        * ``cos_mean_offdiag`` / ``cos_min_offdiag``: summaries of ``offC``.
          All four summaries are NaN when there is only one task.
        * ``consensus_gap`` / ``consensus_mu``: eigen-gap and squared
          singular values of the concatenated bases.
        * ``align_to_consensus``: mean normalized overlap of each basis with
          the consensus basis.
        * ``cap_to_consensus_per_task``: fraction of each gradient's energy
          captured by the consensus basis.
        * ``grad_norms``: floored Euclidean norm of each gradient row.
    """
    A = _pairwise_overlap(st.U_all)
    C = _pairwise_cosine(G, eps=eps)

    U_star, gap, mu = _consensus_subspace(st.U_all, r=st.r)
    align = _global_alignment(st.U_all, U_star)

    cap_star = _captured_energy(G, U_star, eps=eps)

    offA = _offdiag_entries(A)
    offC = _offdiag_entries(C)

    if A.shape[0] > 1:
        A_mean, A_min = jnp.mean(offA), jnp.min(offA)
        cos_mean, cos_min = jnp.mean(offC), jnp.min(offC)
    else:
        # A single task has no pairs. jnp.min over an empty array raises
        # ValueError, so report NaN for the pairwise summaries instead.
        A_mean = A_min = jnp.full((), jnp.nan, dtype=A.dtype)
        cos_mean = cos_min = jnp.full((), jnp.nan, dtype=C.dtype)

    out = {
        "A_pairwise": A,
        "cos_pairwise": C,
        "offA": offA,
        "offC": offC,
        "A_mean_offdiag": A_mean,
        "A_min_offdiag": A_min,
        "cos_mean_offdiag": cos_mean,
        "cos_min_offdiag": cos_min,
        "consensus_gap": gap,
        "align_to_consensus": align,
        "cap_to_consensus_per_task": cap_star,
        "grad_norms": _safe_norm(G, axis=-1, eps=eps),
        "consensus_mu": mu,
    }
    return out


def pack_pairwise_logs(
    keys: Sequence[str],
    A: jnp.ndarray,
    C: jnp.ndarray,
    prefix: str = "subspace",
) -> Dict[str, jnp.ndarray]:
    """Flatten the upper triangle of two pairwise matrices into named log entries.

    For every pair ``i < j`` of task names, two entries are emitted in this
    order: ``"{prefix}/A_{a}_{b}"`` -> ``A[i, j]`` and
    ``"{prefix}/cos_{a}_{b}"`` -> ``C[i, j]``. Diagonal entries are skipped.

    Args:
        keys: Task names, indexing rows/columns of ``A`` and ``C``.
        A: Square matrix of basis overlaps.
        C: Square matrix of gradient cosines.
        prefix: Leading path component of every log name.

    Returns:
        Insertion-ordered dict of scalar entries (empty for fewer than two keys).
    """
    names = list(keys)
    logs = {}
    for i, first in enumerate(names):
        for j, second in enumerate(names[i + 1:], start=i + 1):
            logs[f"{prefix}/A_{first}_{second}"] = A[i, j]
            logs[f"{prefix}/cos_{first}_{second}"] = C[i, j]
    return logs
