import time
import jax
import jax.numpy as jnp
import numpy as np

# Sequence lengths (encoder source, decoder target) and batch size used by
# the benchmarks below.
src_len, tgt_len = 128, 128
bsz = 32

# Attention geometry; note that num_heads * head_dim equals embed_dim.
num_heads, head_dim = 8, 64
embed_dim = 512

# Hidden width of the feed-forward block.
ffn_dim = 2048

def causal_mask(step, length=None):
  """Build the additive causal mask for a single decoding step.

  Args:
    step: Index of the current target position (Python int or traced scalar).
    length: Number of key positions the mask covers. When omitted, the
      module-level `tgt_len` is used, which matches the previous behavior.

  Returns:
    Float array of shape [length] holding 0 at positions `<= step` and
    `-inf` at later positions, so it can be added to attention logits.
  """
  if length is None:
    length = tgt_len
  positions = jnp.arange(length)
  # The comparison already yields a boolean condition for `where`; no float
  # round-trip is needed.
  return jnp.where(positions <= step, 0., -jnp.inf)

def attn_func(q, k, v, mask=None):
  """Scaled dot-product attention for a single query position.

  Args:
    q: Queries for one step, [bsz, num_heads, head_dim].
    k: Keys, [bsz, src_len, num_heads, head_dim].
    v: Values, [bsz, src_len, num_heads, head_dim].
    mask: Optional additive mask of shape [src_len] or [bsz, src_len]. It is
      broadcast over heads and added to the logits before the softmax
      (use 0 to keep a position and -inf to drop it).

  Returns:
    Attention output of shape [bsz, num_heads, head_dim].
  """
  # [bsz, num_heads, src_len]
  attn_weights = jnp.einsum("bhd,bshd->bhs", q, k)
  # Scale the logits by 1 / sqrt(head_dim). `scale` is already the
  # reciprocal square root, so it has to be multiplied in; dividing by it
  # would blow the logits up by sqrt(head_dim) instead.
  scale = q.shape[-1] ** -0.5
  attn_weights = attn_weights * scale

  if mask is not None:
    if mask.ndim == 1:
      mask = mask[None, None, :]  # [1, 1, src_len]
    else:
      mask = mask[:, None, :]  # [bsz, 1, src_len]
    attn_weights = attn_weights + mask

  attn_weights = jax.nn.softmax(attn_weights, axis=-1)
  # [bsz, num_heads, head_dim]
  attn_out = jnp.einsum("bhs,bshd->bhd", attn_weights, v)
  return attn_out

def loop_attn(inputs, encoder_out, causal_ws, cross_ws):
  """Run decoder attention one target step at a time with `jax.lax.scan`.

  Every step performs causal self-attention against a key/value cache that
  is filled in as the scan advances, then cross-attention against the
  projected encoder output.

  Args:
    inputs: Decoder inputs, [bsz, tgt_len, embed_dim].
    encoder_out: Encoder output, [bsz, src_len, embed_dim].
    causal_ws: Tuple `(wq, wk, wv)` of causal-attention projections, each
      [num_heads, head_dim, embed_dim].
    cross_ws: Tuple `(wq, wk, wv)` of cross-attention projections, each
      [num_heads, head_dim, embed_dim].

  Returns:
    Per-step outputs stacked in time-major layout,
    [tgt_len, bsz, num_heads * head_dim].

  Note:
    The causal-attention output is fed directly into the cross-attention
    query projection (there is no output projection in between), so
    `num_heads * head_dim` must equal `embed_dim`. The caches are sized from
    the module-level `bsz`, `tgt_len`, `num_heads` and `head_dim`.
  """
  causal_wq, causal_wk, causal_wv = causal_ws
  cross_wq, cross_wk, cross_wv = cross_ws
  inputs = jnp.transpose(inputs, [1, 0, 2])  # [tgt_len, bsz, embed_dim]

  # Encoder keys/values are identical for every step, so project them once.
  # [bsz, src_len, num_heads, head_dim]
  cross_ks = jnp.einsum("hdD,bsD->bshd", cross_wk, encoder_out)
  cross_vs = jnp.einsum("hdD,bsD->bshd", cross_wv, encoder_out)

  # cache keys and values for causal attention
  causal_ks = jnp.zeros((bsz, tgt_len, num_heads, head_dim))
  causal_vs = jnp.zeros((bsz, tgt_len, num_heads, head_dim))

  # loop through the target timesteps.
  def scan_body(state, x):
    step, causal_ks, causal_vs = state
    # Cache slots after `step` are still zero-filled; the mask removes them.
    mask = causal_mask(step)

    # causal attention
    # [bsz, num_heads, head_dim]
    q = jnp.einsum("hdD,bD->bhd", causal_wq, x)
    k = jnp.einsum("hdD,bD->bhd", causal_wk, x)
    # Values use their own projection (previously the key weights were
    # applied here by mistake).
    v = jnp.einsum("hdD,bD->bhd", causal_wv, x)
    # `.at[...].set(...)` is the functional indexed update; the older
    # `jax.ops.index_update` helper is not available in current JAX.
    causal_ks = causal_ks.at[:, step].set(k)
    causal_vs = causal_vs.at[:, step].set(v)

    # [bsz, num_heads, head_dim]
    causal_out = attn_func(q, k=causal_ks, v=causal_vs, mask=mask)
    # [bsz, embed_dim]
    causal_out = jnp.reshape(causal_out, [causal_out.shape[0], -1])

    # cross attention
    # Query comes from the causal-attention output so that the causal
    # branch actually contributes to the result instead of being discarded.
    # [bsz, num_heads, head_dim]
    q = jnp.einsum("hdD,bD->bhd", cross_wq, causal_out)

    # [bsz, num_heads, head_dim]
    attn_out = attn_func(q, k=cross_ks, v=cross_vs)
    attn_out = jnp.reshape(attn_out, [attn_out.shape[0], -1])  # [bsz, embed_dim]

    return (step + 1, causal_ks, causal_vs), attn_out

  _, attn_out = jax.lax.scan(
        scan_body,
        (0, causal_ks, causal_vs),
        inputs)
  return attn_out

def decode_attn():
  """Benchmark the jitted `loop_attn` on random data.

  Inputs and weights are drawn with `np.random.rand` using the module-level
  sizes. The function is called 100 times; the first 10 calls are treated
  as warm-up (the first call triggers jit compilation) and excluded.

  Returns:
    Mean wall-clock seconds per call over the measured runs.
  """
  num_runs = 100
  num_warmup = 10

  func = jax.jit(loop_attn)
  inputs = jnp.array(np.random.rand(bsz, tgt_len, embed_dim))
  encoder_out = jnp.array(np.random.rand(bsz, src_len, embed_dim))
  w_shape = (num_heads, head_dim, embed_dim)

  # (wq, wk, wv) for causal attention, then for cross attention.
  causal_ws = tuple(jnp.array(np.random.rand(*w_shape)) for _ in range(3))
  cross_ws = tuple(jnp.array(np.random.rand(*w_shape)) for _ in range(3))

  total_time = 0.
  for s in range(num_runs):
    # perf_counter is the monotonic, high-resolution clock meant for
    # measuring intervals; time.time() can jump with system clock changes.
    start = time.perf_counter()

    # block_until_ready waits for the asynchronous computation to finish so
    # the measured interval covers the actual work.
    func(inputs=inputs,
         encoder_out=encoder_out,
         causal_ws=causal_ws,
         cross_ws=cross_ws).block_until_ready()
    elapsed = time.perf_counter() - start
    if s >= num_warmup:
      total_time += elapsed
  return total_time / (num_runs - num_warmup)

def ffn_func(x, w1, w2):
  """Two-layer feed-forward block without biases.

  Computes `relu(x . w1) . w2` using `jnp.dot` for both products.

  Args:
    x: Input activations whose last axis matches the first axis of `w1`.
    w1: First projection matrix.
    w2: Second projection matrix.

  Returns:
    The projected activations.
  """
  hidden = jax.nn.relu(jnp.dot(x, w1))
  return jnp.dot(hidden, w2)

def loop_ffn(inputs, w1, w2):
  """Apply `ffn_func` one target step at a time with `jax.lax.scan`.

  Args:
    inputs: Activations, [bsz, tgt_len, embed_dim].
    w1: First feed-forward projection, [embed_dim, ffn_dim].
    w2: Second feed-forward projection, [ffn_dim, embed_dim].

  Returns:
    Per-step outputs stacked in time-major layout, [tgt_len, bsz, embed_dim].
  """
  inputs = jnp.transpose(inputs, [1, 0, 2])  # [tgt_len, bsz, embed_dim]

  # loop through the target timesteps.
  def scan_body(carry, x):
    # The weights are closed over, so nothing needs to be carried between
    # steps; the carry is passed through untouched.
    return carry, ffn_func(x, w1, w2)  # [bsz, embed_dim]

  # The number of steps is taken from the leading axis of `inputs`, so the
  # function works for any target length rather than only the global one.
  _, ffn_out = jax.lax.scan(scan_body, None, inputs)
  return ffn_out

def decode_ffn():
  """Benchmark the jitted `loop_ffn` on random data.

  Inputs and weights are drawn with `np.random.rand` using the module-level
  sizes. The function is called 100 times; the first 10 calls are treated
  as warm-up (the first call triggers jit compilation) and excluded.

  Returns:
    Mean wall-clock seconds per call over the measured runs.
  """
  num_runs = 100
  num_warmup = 10

  func = jax.jit(loop_ffn)
  inputs = jnp.array(np.random.rand(bsz, tgt_len, embed_dim))
  w1 = jnp.array(np.random.rand(embed_dim, ffn_dim))
  w2 = jnp.array(np.random.rand(ffn_dim, embed_dim))

  total_time = 0.
  for s in range(num_runs):
    # perf_counter is the monotonic, high-resolution clock meant for
    # measuring intervals; time.time() can jump with system clock changes.
    start = time.perf_counter()

    # block_until_ready waits for the asynchronous computation to finish so
    # the measured interval covers the actual work.
    func(inputs, w1, w2).block_until_ready()
    elapsed = time.perf_counter() - start
    if s >= num_warmup:
      total_time += elapsed
  return total_time / (num_runs - num_warmup)

# Run the attention benchmark and report its mean per-call time in seconds.
attn_time = decode_attn()
print(f"Attention: Average time: {attn_time}")

# Run the feed-forward benchmark and report its mean per-call time the same way.
ffn_time = decode_ffn()
print(f"FFN: Average time: {ffn_time}")
