import jax
import jax.numpy as jnp


def mlstm_siging_parallel_bw(
    matDeltaHtilde: jax.Array,
    matQ: jax.Array,
    matK: jax.Array,
    matV: jax.Array,
    vecI: jax.Array,
    vecF: jax.Array,
    vecN: jax.Array,
    eps: float = 1e-6,
    stable_fgate: bool = True,
    normalize: bool = True,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """Compute input gradients of the parallel (S x S) sigmoid-gated mLSTM form.

    The (S, S) gate matrix D is rebuilt from the gate pre-activations with
    ``D[i, j] = exp(sum_{l=j+1..i} log_sigmoid(vecF[l]) + log_sigmoid(vecI[j]))``
    for ``j <= i`` and ``0`` above the diagonal. The incoming error is then
    pushed back through ``(Q K^T * DHQK**-0.5) * D`` and the product with V.

    Args:
        matDeltaHtilde: Incoming error signal. Must be matmul-compatible with
            ``matV^T`` (presumably (B, NH, S, DHV) -- not checked here).
        matQ: Queries, shape (B, NH, S, DHQK).
        matK: Keys, shape (B, NH, S, DHQK) (asserted).
        matV: Values (presumably (B, NH, S, DHV) -- not checked here).
        vecI: Input gate pre-activations, shape (B, NH, S) (asserted).
        vecF: Forget gate pre-activations, shape (B, NH, S) (asserted).
        vecN: Normaliser values, indexed as a rank-3 array (presumably
            (B, NH, S)). Only read when ``normalize`` is true. It is treated
            as a constant: no gradient is computed for or through it.
        eps: Added to ``vecN`` before dividing.
        stable_fgate: If true, accumulate the forget-gate log-products with a
            cumulative sum over a strictly lower-triangular matrix; otherwise
            use differences of a 1-D cumulative sum.
        normalize: If true, divide the incoming error by ``vecN + eps``.

    Returns:
        ``(matDeltaQ, matDeltaK, matDeltaV, vecDeltaI, vecDeltaF)``: gradients
        with respect to ``matQ``, ``matK``, ``matV``, ``vecI`` and ``vecF``.
    """
    B, NH, S, DHQK = matQ.shape
    assert matK.shape == (B, NH, S, DHQK)
    assert vecI.shape == (B, NH, S)
    assert vecF.shape == (B, NH, S)

    # --- rebuild the gate matrix D in log space ---------------------------
    log_fgate = jax.nn.log_sigmoid(vecF)  # (B, NH, S)

    if stable_fgate:
        # Entry [i, j] of the tiled matrix is log_fgate[i]; keeping only i > j
        # and summing down the rows yields sum_{l=j+1..i} log_fgate[l].
        tiled_log_fgate = jnp.repeat(log_fgate[:, :, :, None], S, axis=-1)
        log_decay = jnp.cumsum(jnp.tril(tiled_log_fgate, k=-1), axis=-2)
    else:
        # Same quantity as the difference of prefix sums.
        prefix_sum = jnp.cumsum(log_fgate, axis=-1)
        log_decay = prefix_sum[:, :, :, None] - prefix_sum[:, :, None, :]

    # Causal mask: position i may only attend to positions j <= i.
    positions = jnp.arange(S)
    causal = positions[:, None] >= positions[None, :]
    masked_log_decay = jnp.where(causal, log_decay, -jnp.inf)

    log_igate = jax.nn.log_sigmoid(vecI)
    gate_matrix = jnp.exp(masked_log_decay + log_igate[:, :, None, :])  # (B, NH, S, S)

    # --- intermediate delta-errors ----------------------------------------
    values_t = jnp.swapaxes(matV, -2, -1)
    if normalize:
        denom = vecN[:, :, :, None] + eps
        delta_c = (matDeltaHtilde @ values_t) / denom
        delta_h_scaled = matDeltaHtilde / denom
    else:
        delta_c = matDeltaHtilde @ values_t
        delta_h_scaled = matDeltaHtilde

    qk_scale = DHQK**-0.5
    scores = (matQ @ jnp.swapaxes(matK, -2, -1)) * qk_scale

    # Error arriving at the (scaled) scores, and at log D respectively.
    delta_scores = delta_c * gate_matrix
    delta_log_d = delta_scores * scores

    # log_igate[j] enters every row of column j, so sum over rows.
    delta_ibar = jnp.sum(delta_log_d, axis=-2)

    # --- output delta-errors / gradients ----------------------------------
    matDeltaQ = (delta_scores @ matK) * qk_scale
    matDeltaK = (jnp.swapaxes(delta_scores, -2, -1) @ matQ) * qk_scale

    gated_scores: jax.Array = scores * gate_matrix
    matDeltaV = jnp.swapaxes(gated_scores, -2, -1) @ delta_h_scaled

    # Forget-gate error: dfbar = rev_cumsum((q*dq - k*dk).sum(-1)).
    # (q*dq).sum(-1) is the row sum and (k*dk).sum(-1) the column sum of
    # delta_log_d, so the reversed running sum collects all entries [i, j]
    # with i >= l > j, i.e. those whose decay contains forget gate l.
    row_minus_col = jnp.sum(matQ * matDeltaQ - matK * matDeltaK, axis=-1)
    delta_fbar = jnp.cumsum(row_minus_col[..., ::-1], axis=-1)[..., ::-1]

    # d/dx log_sigmoid(x) = sigmoid(-x)
    vecDeltaF = delta_fbar * jax.nn.sigmoid(-vecF)
    vecDeltaI = delta_ibar * jax.nn.sigmoid(-vecI)

    return matDeltaQ, matDeltaK, matDeltaV, vecDeltaI, vecDeltaF
