from jax import numpy as jnp
import jax

# Short alias for JAX's associative scan; mix_seq2 and mix_seq3 use it to
# evaluate their linear recurrences over the leading (time) axis.
parallel_scan = jax.lax.associative_scan


def accumulate(carry, args):
    """Advance the running, max-stabilised accumulators by one time step.

    Written as a ``jax.lax.scan`` body (see ``mix_seq1``).

    Args:
        carry: ``(nu, alpha, prev_max)`` from the previous step, where ``nu``
            has shape (L, D), ``alpha`` has shape (L,) and ``prev_max`` is the
            reference maximum both were scaled by.
        args: ``(Qs_t, curr_alph, V_t, c_mx)`` for the current step:
            weights of shape (L,), current logits, a value vector of shape
            (D,), and the new reference maximum.

    Returns:
        ``((nu, alpha, c_mx), (y, alpha))``: the updated carry, plus the
        step output ``y`` of shape (D,) together with the updated ``alpha``.
    """
    weighted_sum, normaliser, old_peak = carry
    weights, logits, values, new_peak = args

    # Rescale what was accumulated relative to `old_peak` so that it is
    # expressed relative to `new_peak`.
    rescale = jnp.exp(old_peak - new_peak)
    # Contribution of the current step, relative to `new_peak`.
    contribution = jnp.exp(logits - new_peak)

    normaliser = jnp.einsum("L,L->L", normaliser, rescale) + contribution
    weighted_sum = jnp.einsum("LD,L->LD", weighted_sum, rescale) + jnp.einsum(
        "L,D->LD", contribution, values
    )

    # Mix the per-slot normalised sums with the query weights.
    output = jnp.einsum("L,LD->D", weights / normaliser, weighted_sum)

    next_carry = (weighted_sum, normaliser, new_peak)
    return (next_carry, (output, normaliser))


def mix_seq1(V, Q, K):
    """Reference implementation: a single sequential scan over time.

    Runs ``accumulate`` with ``jax.lax.scan``, carrying the weighted sum
    ``nu``, the normaliser ``alpha`` and the running maximum of ``K`` that
    both are scaled by.

    Args:
        V: values of shape (T, C).
        Q: query logits of shape (T, L); a softmax is taken over the last
            axis.
        K: key logits of shape (T, L).

    Returns:
        ``(y, alpha)`` stacked over time: ``y`` of shape (T, C) and
        ``alpha`` of shape (T, L).
    """
    _, C = V.shape
    L = Q.shape[-1]
    # Running maximum of K along time, used as the stabilising reference.
    maxi = jax.lax.cummax(K, axis=0)
    Qs = jax.nn.softmax(Q, axis=-1)

    # lax.scan requires the carry to keep the same dtype from step to step.
    # Inside `accumulate` the accumulators are combined with terms derived
    # from both K and V, so they must start in the promoted dtype; using
    # V.dtype alone fails whenever K promotes to something wider than V.
    acc_dtype = jnp.promote_types(V.dtype, K.dtype)
    init_alpha = jnp.zeros((L,), dtype=acc_dtype)
    init_nu = jnp.zeros((L, C), dtype=acc_dtype)

    # K[0] equals maxi[0], so the first step rescales by exp(0) = 1.
    _, (y, alpha) = jax.lax.scan(
        accumulate,
        init=(init_nu, init_alpha, K[0]),
        xs=(Qs, K, V, maxi),
        unroll=10,
    )
    return y, alpha


def mix_seq2(V, Q, K):
    """Variant of ``mix_seq1`` with the normaliser computed by a parallel scan.

    The normaliser follows the recurrence
    ``alpha[t] = alpha[t-1] * revert_maxi[t] + add_maxi[t]``, which is
    evaluated with ``parallel_scan``. Only the weighted sum ``nu`` is still
    accumulated sequentially with ``jax.lax.scan``.

    Args:
        V: values of shape (T, C).
        Q: query logits of shape (T, L); a softmax is taken over the last
            axis.
        K: key logits of shape (T, L).

    Returns:
        ``(y, alpha)``: ``y`` of shape (T, C) and ``alpha`` of shape (T, L).
    """
    _, C = V.shape
    L = Q.shape[-1]
    # Running maximum of K along time, used as the stabilising reference.
    maxi = jax.lax.cummax(K, axis=0)
    Qs = jax.nn.softmax(Q, axis=-1)

    # exp(maxi[t-1] - maxi[t]): rescales sums from the previous reference to
    # the current one. Row 0 stays exp(0) = 1 since there is nothing before.
    revert_maxi = jnp.zeros_like(maxi)
    revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
    revert_maxi = jnp.exp(revert_maxi)  # (T, L)
    add_maxi = jnp.exp(K - maxi)  # (T, L)

    def bin_alpha(A, B):
        # Composition of two affine steps x -> x * rm + am (A first, then B).
        rmA, amA = A
        rmB, amB = B
        return (rmA * rmB, amA * rmB + amB)

    _, alpha = parallel_scan(bin_alpha, (revert_maxi, add_maxi))

    # alpha is fully known before the sequential scan starts, so the
    # normalised weights are computed once instead of once per step.
    scaled_Qs = Qs / alpha

    def accumulate2(nu, args):
        w_t, V_t, revert_t, add_t = args
        nu = jnp.einsum("LD,L->LD", nu, revert_t)
        nu += jnp.einsum("L,D->LD", add_t, V_t)
        y_t = jnp.einsum("L,LD->D", w_t, nu)
        return nu, y_t

    # lax.scan requires the carry dtype to stay fixed; nu is combined with
    # terms derived from both K and V, so start it in their promoted dtype
    # (V.dtype alone fails when K promotes to something wider than V).
    init_nu = jnp.zeros((L, C), dtype=jnp.promote_types(V.dtype, K.dtype))
    _, y = jax.lax.scan(
        accumulate2,
        init=init_nu,
        xs=(scaled_Qs, V, revert_maxi, add_maxi),
        unroll=10,
    )
    return y, alpha


def mix_seq3(V, Q, K):
    """Fully parallel variant: normaliser and weighted sum in one scan.

    Both recurrences are evaluated together by a single ``parallel_scan``
    call. Unlike the sequential variants, this materialises the weighted
    sum for every time step as a (T, L, D) array.

    Args:
        V: values of shape (T, D).
        Q: query logits of shape (T, L); a softmax is taken over the last
            axis.
        K: key logits of shape (T, L).

    Returns:
        ``(y, alpha)``: ``y`` of shape (T, D) and ``alpha`` of shape (T, L).
    """
    # Running maximum of K along time, used as the stabilising reference.
    maxi = jax.lax.cummax(K, axis=0)
    Qs = jax.nn.softmax(Q, axis=-1)

    # exp(maxi[t-1] - maxi[t]): rescales sums from the previous reference to
    # the current one. Row 0 stays exp(0) = 1 since there is nothing before.
    revert_maxi = jnp.zeros_like(maxi)
    revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
    revert_maxi = jnp.exp(revert_maxi)  # (T, L)
    add_maxi = jnp.exp(K - maxi)  # (T, L)

    # Per-step contribution to the weighted sum, before accumulation.
    nu = jnp.einsum("TL,TD->TLD", add_maxi, V)

    def bin_V(A, B):
        # Composition of two affine steps (A first, then B). alpha and nu
        # share the same multiplicative factor rm; nu has a trailing D axis.
        rmA, amA, nuA = A
        rmB, amB, nuB = B
        nu_AB = nuA * rmB[..., None] + nuB
        alpha_AB = amA * rmB + amB
        return (rmA * rmB, alpha_AB, nu_AB)

    _, alpha, nu = parallel_scan(bin_V, (revert_maxi, add_maxi, nu))

    # Normalise per slot, weight by the queries and sum over the L axis.
    y = jnp.einsum("TL,TLD->TD", Qs / alpha, nu)
    return y, alpha


def mix_seq4(V, Q, K):
    """Unfinished placeholder for a further variant.

    Currently only applies a softmax over the last axis of ``Q`` and throws
    the result away; ``V`` and ``K`` are not used and nothing is returned.
    """
    # NOTE(review): the normalised queries are computed but never used.
    _normalised_queries = jax.nn.softmax(Q, axis=-1)
    return None


def main():
    """Cross-check the three implementations on a small fixed example.

    Uses ``mix_seq1`` as the reference and prints whether the ``alpha`` and
    ``y`` returned by ``mix_seq2`` and ``mix_seq3`` are element-wise close
    to it. ``K`` contains very large entries, presumably to exercise the
    running-max stabilisation.
    """

    def agree(expected, actual):
        return jnp.all(jnp.isclose(expected, actual))

    queries = [[4, 5], [1, 1], [10, 3], [2, 2], [1, 10]]
    keys = [[4, 5], [1, 100000], [10, 3], [1, 200000000], [1, 10]]
    values = [[7, 4, 5], [2, 1, 1], [7, 10, 3], [7, 1, 2], [8, 1, 10]]
    Q = jnp.array(queries).astype(jnp.float32)
    K = jnp.array(keys).astype(jnp.float32)
    V = jnp.array(values).astype(jnp.float32)

    # Sequential reference vs. parallel-normaliser variant.
    y_ref, alpha_ref = mix_seq1(V, Q, K)
    y_two, alpha_two = mix_seq2(V, Q, K)
    print("Correct Alpha: ", agree(alpha_ref, alpha_two))

    print("*" * 100)
    print("Correct Y2: ", agree(y_ref, y_two))
    print("-" * 100)

    # Sequential reference vs. fully parallel variant.
    y_three, alpha_three = mix_seq3(V, Q, K)
    print("Correct Alpha3: ", agree(alpha_ref, alpha_three))
    print("Correct Y3: ", agree(y_ref, y_three))

    # print(y1)
    # print(y3)


# def mix_sequence3(self, Q, K, V, Q_drop, V_drop):
#     """Faster version of mix_sequence2 by applying parallel scans to the normalisation as well
#     Still in O(TL + LD)
#     Args:
#         Q: jax.Array(T,B,H,L)
#         K: jax.Array(T,B,H,L)
#         V: jax.Array(T,B,H,D)
#     """
#     T, B, H, C = V.shape
#     L = Q.shape[-1]
#     # calc R^{-s}x_s
#     # V = apply_rotation(sinusoidal_enc=sin_pos, mat=V, neg=True)
#     if isinstance(self.rot_embeds, XPos):
#         # T, B, self.config.nheads, -1 -> BHTD
#         V = self.rot_embeds(
#             V.transpose(1, 2, 0, 3), offset=0, downscale=True
#         ).transpose(2, 0, 1, 3)
#     else:
#         V = self.rot_embeds.apply_vapor(mat=V, neg=True)
#     # V = V_drop(V)
#     Qs = jax.nn.softmax(Q, axis=-1)
#     Qs = Q_drop(Qs)

#     maxi = jax.lax.cummax(K, axis=0)
#     # maxi for stability should be treated as a constant - no grad is faster
#     maxi = jax.lax.stop_gradient(maxi)
#     # revert maxi
#     revert_maxi = jnp.zeros_like(maxi)
#     revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
#     revert_maxi = jnp.exp(revert_maxi)  # TBHL
#     add_maxi = jnp.exp(K - maxi)

#     def bin_alpha(A, B):
#         rmA, amA = A
#         rmB, amB = B
#         return (rmA * rmB, amA * rmB + amB)

#     _, alpha = parallel_scan(bin_alpha, (revert_maxi, add_maxi))
#     Qs = jnp.einsum("TBHL,TBHL->TBHL", Qs, 1 / alpha)
#     # Calculate gamma - sequential operation (not parallel) for O(TL + LD) mem
#     # dominates the practical time
#     init_nu = jnp.zeros((B, H, L, C), dtype=self.dtype)
#     _, y = jax.lax.scan(
#         self.accumulate3,
#         unroll=self.unroll,
#         init=init_nu,
#         xs=[Qs, V, revert_maxi, add_maxi],
#     )
#     # calc R^t \sum_l ...
#     # y = apply_rotation(sinusoidal_enc=sin_pos, mat=y, neg=False)
#     if isinstance(self.rot_embeds, XPos):
#         # TBHD -> BHTD
#         y = y.transpose(1, 2, 0, 3)
#         y = self.rot_embeds(y, offset=0, downscale=False)
#         # BHTD -> BTHD
#         y = y.transpose(0, 2, 1, 3).reshape(B, T, -1)
#         return y, Qs
#     else:
#         # TBHD -> BHTD
#         # y = y.transpose(1, 2, 0, 3)
#         y = self.rot_embeds.apply_vapor(mat=y, neg=False)
#     # TBHD -> BTHD
#     y = y.transpose(1, 0, 2, 3)
#     y = y.reshape(B, T, -1)
#     return y, Qs


# def mix_sequence3_latte(self, Q, K, V, Q_drop, V_drop):
#     """Faster version of mix_sequence2 by applying parallel scans to the normalisation as well
#     Still in O(TL + LD)
#     Args:
#         Q: jax.Array(T,B,H,L)
#         K: jax.Array(T,B,H,L)
#         V: jax.Array(T,B,H,D)
#     """
#     T, B, H, C = V.shape
#     L = Q.shape[-1]
#     # calc R^{-s}x_s
#     # V = V_drop(V)
#     Qs = jax.nn.softmax(Q, axis=-1)
#     Qs = Q_drop(Qs)

#     maxi = jax.lax.cummax(K, axis=0)
#     # maxi for stability should be treated as a constant - no grad is faster
#     maxi = jax.lax.stop_gradient(maxi)
#     # revert maxi
#     revert_maxi = jnp.zeros_like(maxi)
#     revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
#     revert_maxi = jnp.exp(revert_maxi)  # TBHL
#     add_maxi = jnp.exp(K - maxi)

#     def bin_alpha(A, B):
#         rmA, amA = A
#         rmB, amB = B
#         return (rmA * rmB, amA * rmB + amB)

#     _, alpha = parallel_scan(bin_alpha, (revert_maxi, add_maxi))
#     Qs = jnp.einsum("TBHL,TBHL->TBHL", Qs, 1 / alpha)
#     # Calculate gamma - sequential operation (not parallel) for O(TL + LD) mem
#     # dominates the practical time
#     init_nu = jnp.zeros((B, H, L, C), dtype=self.dtype)
#     _, y = jax.lax.scan(
#         self.accumulate3,
#         unroll=self.unroll,
#         init=init_nu,
#         xs=[Qs, V, revert_maxi, add_maxi],
#     )
#     # TBHD -> BTHD
#     y = y.transpose(1, 0, 2, 3)
#     y = y.reshape(B, T, -1)
#     return y, Qs


# def latte_attention3(self, Q, K, V):
#     """Faster version of latte_attention by applying parallel scans to the normalisation as well
#     Still in O(TL + LD)
#     Args:
#         Q: jax.Array(T,B,H,L)
#         K: jax.Array(T,B,H,L)
#         V: jax.Array(T,B,H,D)
#     """
#     T, B, H, C = V.shape
#     L = Q.shape[-1]
#     # calc R^{-s}x_s
#     if isinstance(self.rot_embeds, XPos):
#         # T, B, self.config.nheads, -1 -> BHTD
#         V = self.rot_embeds(
#             V.transpose(1, 2, 0, 3), offset=0, downscale=True
#         ).transpose(2, 0, 1, 3)
#     else:
#         V = self.rot_embeds.apply_vapor(mat=V, neg=True)
#     # V = V_drop(V)
#     Qs = jax.nn.softmax(Q, axis=-1)

#     maxi = jax.lax.cummax(K, axis=0)
#     # maxi for stability should be treated as a constant - no grad is faster
#     maxi = jax.lax.stop_gradient(maxi)
#     # revert maxi
#     revert_maxi = jnp.zeros_like(maxi)
#     revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
#     revert_maxi = jnp.exp(revert_maxi)  # TBHL
#     add_maxi = jnp.exp(K - maxi)

#     def bin_alpha(A, B):
#         rmA, amA = A
#         rmB, amB = B
#         return (rmA * rmB, amA * rmB + amB)

#     _, alpha = parallel_scan(bin_alpha, (revert_maxi, add_maxi))
#     Qs = jnp.einsum("TBHL,TBHL->TBHL", Qs, 1 / alpha)
#     # Calculate gamma - sequential operation (not parallel) for O(TL + LD) mem
#     # dominates the practical time
#     init_nu = jnp.zeros((B, H, L, C), dtype=self.dtype)
#     _, y = jax.lax.scan(
#         accumulate3,
#         unroll=self.unroll,
#         init=init_nu,
#         xs=[Qs, V, revert_maxi, add_maxi],
#     )
#     # calc R^t \sum_l ...
#     if isinstance(self.rot_embeds, XPos):
#         y = y.transpose(1, 2, 0, 3)
#         y = self.rot_embeds(y, offset=0, downscale=False)
#         return y
#     else:
#         # TBHD -> BHTD
#         y = self.rot_embeds.apply_vapor(mat=y, neg=False)
#     # TBHD -> BHTD
#     return y.transpose(1, 2, 0, 3)


# Run the implementation cross-checks only when executed as a script.
if __name__ == "__main__":
    main()
