"""Defines utilities such as position embeddings required for Transformers."""
import functools
from flax import struct
import jax
import jax.numpy as jnp


@struct.dataclass
class TransformerConfig:
  """Hyperparameters for the transformer model.

  Declared as a flax `struct.dataclass`. The model code that consumes these
  fields is not in this module, so the per-field notes marked "presumably"
  are assumptions to confirm against the caller.
  """

  # The only required field; every other field has a default.
  vocab_size: int
  # presumably the maximum sequence length (e.g. size of RoPE tables) — confirm.
  context_length: int = 1024
  emb_dim: int = 768
  num_heads: int = 12
  num_layers: int = 12
  dropout_rate: float = 0.1
  # presumably the sketch width r for 'sketched' attention — TODO confirm.
  sketch_size: int = 32
  # Attention variant selector; accepted values are defined by the model code.
  attention: str = 'sketched'
  # NOTE(review): presumably toggles the causal lower-triangular multiply path.
  lt_multiply: bool = True
  # presumably selects a gated (GLU) feed-forward block — confirm.
  use_glu: bool = True
  # presumably forwarded as `multi_precision` to the attention einsums.
  precision: jax.lax.Precision = jax.lax.Precision.DEFAULT
  # presumably the additive `normalization` term of the attention denominator.
  normalization: float = 1.0
  # presumably the block length passed to stable_tensor_lt_multiply; sequence
  # lengths would then need to be multiples of it — confirm.
  grain_size: int = 256
  # NOTE(review): presumably skips attention entirely when True.
  disable_attention: bool = False
  # NOTE(review): presumably enables rematerialization of attention.
  checkpoint_attention: bool = False


class RoPE:
  """Applies rotary position embeddings (RoPE) over the last two axes of x.

  Features are rotated in interleaved pairs (x[..., 2k], x[..., 2k + 1]).
  The rotated components are written back in split-half layout: all first
  components, then all second components. Queries and keys receive the same
  permutation, so their dot products are unaffected by it.
  """

  def __init__(self, context_length, d):
    """Precomputes the rotation tables.

    Args:
      context_length: maximum number of positions `apply` supports.
      d: feature dimension; must be even.

    Raises:
      ValueError: if `d` is odd.
    """
    if d % 2 != 0:
      raise ValueError(f'RoPE requires an even feature dimension, got d={d}.')
    thetas = 10000 ** (-2 * jnp.arange(d // 2) / d)
    positions = jnp.arange(context_length)
    positions_times_thetas = jnp.outer(positions, thetas)
    self.sin = jnp.sin(positions_times_thetas)  # (context_length, d // 2)
    self.cos = jnp.cos(positions_times_thetas)  # (context_length, d // 2)

  def apply(self, x):
    """Rotates `x` according to each row's position.

    Args:
      x: array of shape (..., seq_len, d) with seq_len <= context_length.

    Returns:
      Array of the same shape as `x`, in split-half feature layout.

    Raises:
      ValueError: if seq_len exceeds the precomputed context_length.
    """
    seq_len, d = x.shape[-2:]
    batch_shape = x.shape[:-2]
    if seq_len > self.cos.shape[0]:
      raise ValueError(
          f'Sequence length {seq_len} exceeds RoPE context length '
          f'{self.cos.shape[0]}.'
      )
    # Only the first seq_len positions are needed. Slicing lets sequences
    # shorter than context_length broadcast against the tables.
    cos = self.cos[:seq_len]
    sin = self.sin[:seq_len]
    pairs = x.reshape(batch_shape + (seq_len, d // 2, 2))
    x0, x1 = pairs[..., 0], pairs[..., 1]  # each (..., seq_len, d // 2)
    return jnp.concatenate(
        [x0 * cos - x1 * sin, x1 * cos + x0 * sin], axis=-1
    )


class MultiHeadRoPE:
  """Applies rotary position embeddings (RoPE) to multi-head activations.

  Unlike `RoPE`, which rotates interleaved feature pairs, this class pairs
  feature k with feature k + d // 2 (the two halves of the last axis). The
  output keeps that split-half layout.
  """

  def __init__(self, context_length, d):
    """Precomputes the rotation tables.

    Args:
      context_length: maximum number of positions `apply` supports.
      d: per-head feature dimension; must be even.

    Raises:
      ValueError: if `d` is odd.
    """
    if d % 2 != 0:
      raise ValueError(
          f'MultiHeadRoPE requires an even per-head dimension, got d={d}.'
      )
    thetas = 10000 ** (-2 * jnp.arange(d // 2) / d)
    positions = jnp.arange(context_length)
    positions_times_thetas = jnp.outer(positions, thetas)
    # The singleton axis broadcasts the tables across the heads axis.
    self.sin = jnp.expand_dims(
        jnp.sin(positions_times_thetas), axis=-2
    )  # context_length x 1 x d//2
    self.cos = jnp.expand_dims(
        jnp.cos(positions_times_thetas), axis=-2
    )  # context_length x 1 x d//2

  def apply(self, x):
    """Rotates `x` according to each position.

    Args:
      x: array of shape (..., seq_len, num_heads, d) with
        seq_len <= context_length. The 4-D
        (batch_size, seq_len, num_heads, d) layout is one instance of this.

    Returns:
      Array with the same shape as `x`.

    Raises:
      ValueError: if seq_len exceeds the precomputed context_length.
    """
    seq_len = x.shape[-3]
    if seq_len > self.cos.shape[0]:
      raise ValueError(
          f'Sequence length {seq_len} exceeds RoPE context length '
          f'{self.cos.shape[0]}.'
      )
    # Slicing lets sequences shorter than context_length broadcast.
    cos = self.cos[:seq_len]
    sin = self.sin[:seq_len]
    x0, x1 = jnp.split(x, 2, axis=-1)  # each (..., seq_len, num_heads, d // 2)
    return jnp.concatenate(
        [x0 * cos - x1 * sin, x1 * cos + x0 * sin], axis=-1
    )


def sinusoidal_position_embedding(T, d):
  """Builds a fixed sinusoidal position-embedding table.

  Column 2k holds sin(t * w_k) and column 2k + 1 holds cos(t * w_k), where
  w_k = 10000 ** (-k / (d // 2)). The whole table is scaled by 1 / sqrt(d).

  Args:
    T: number of positions.
    d: embedding dimension. The final reshape only succeeds for even d.

  Returns:
    Array of shape (T, d).
  """
  half = d // 2
  frequencies = 10000 ** (-jnp.arange(0, half) / half)
  angles = jnp.arange(0, T)[:, None] * frequencies[None, :]  # (T, half)
  # Stack sin/cos on a trailing axis so that flattening interleaves them.
  interleaved = jnp.stack([jnp.sin(angles), jnp.cos(angles)], axis=-1)
  return interleaved.reshape((T, d)) / jnp.sqrt(d)


@jax.jit
def self_tensor(a):
  """Returns the Kronecker product of `a` with itself."""
  return jnp.kron(a, a)


def rows_tensored(x):
  """Replaces each vector along the last axis of `x` by its self-Kronecker product.

  For x of shape (..., r) the result has shape (..., r * r), and entry
  [..., i * r + j] equals x[..., i] * x[..., j]. This matches applying
  `self_tensor` to every row.
  """
  width = x.shape[-1]
  outer = x[..., :, None] * x[..., None, :]
  return outer.reshape(x.shape[:-1] + (width * width,))


@functools.partial(jax.remat, static_argnums=(3, 4, 5))
def stable_linear_attention(
    A, B, C, grain_size, multi_precision, normalization
    ):
  """Causal linear attention for a single sequence, with normalization.

  Computes lt(tensor(A) @ tensor(B).T) @ [C, 1] via stable_tensor_lt_multiply,
  where lt keeps the lower triangle. Let w[t, s] be entry (t, s) of that
  masked matrix. The appended ones column yields the row sums of w, which
  normalize the other columns:

    out[t] = sum_{s<=t} w[t, s] C[s]
             / (normalization + max(sum_{s<=t} w[t, s], 1e-10))

  The block is wrapped in `jax.remat`, so its intermediates are recomputed
  in the backward pass. `grain_size`, `multi_precision` and `normalization`
  are static arguments there and must be hashable Python values.

  Args:
    A: (T, r) array.
    B: (T, r) array.
    C: (T, d) array of values.
    grain_size: block length forwarded to stable_tensor_lt_multiply. The
      reshape there requires T to be a multiple of it.
    multi_precision: `jax.lax.Precision` forwarded to the einsums.
    normalization: constant added to every denominator.

  Returns:
    (T, d) array.
  """
  context, _ = C.shape
  # The ones column makes the last output column the row sums of the
  # masked weight matrix, i.e. the attention normalizer.
  C_concatenated = jnp.concatenate(
      [C, jnp.ones((context, 1), dtype=jnp.float32)], axis=-1
  )
  result = stable_tensor_lt_multiply(
      A, B, C_concatenated, grain_size, multi_precision
  )
  numerator = result[:, :-1]  # (T, d)
  denominator = result[:, -1]  # (T,)
  # The 1e-10 floor avoids division by zero when normalization is 0.
  denominator = normalization + jnp.maximum(denominator, 1e-10)
  return numerator / denominator[..., None]


# Maps stable_linear_attention over a leading batch axis of A, B and C. The
# three trailing configuration arguments are shared unbatched across the batch.
batched_stable_linear_attention = jax.vmap(
    stable_linear_attention,
    in_axes=(0,) * 3 + (None,) * 3,
    out_axes=0,
)


def stable_tensor_lt_multiply(A,
                              B,
                              C,
                              grain_size,
                              multi_precision=jax.lax.Precision.DEFAULT):
  """Computes lt(tensor(A) @ tensor(B).T) @ C without forming the tensors.

  tensor(.) replaces each row v by kron(v, v), and lt(.) keeps the lower
  triangle including the diagonal. Since kron(a, a) . kron(b, b) == (a . b)**2,
  row t of the result is sum_{s <= t} (A[t] . B[s]) ** 2 * C[s].

  Rows are processed in blocks of `grain_size`:
  - Contributions from earlier blocks come from a cumulative sum of per-block
    statistics sum_s B[s, i] B[s, j] C[s, :].
  - Contributions within a block use a masked (g x g) score matrix.

  Args:
    A: (n, r) array.
    B: (n, r) array with the same shape as A.
    C: (n, d) array.
    grain_size: block length g. n must be a multiple of g.
    multi_precision: `jax.lax.Precision` forwarded to every einsum.

  Returns:
    (n, d) array.

  Raises:
    ValueError: if shapes are inconsistent, grain_size is not positive, or
      n is not a multiple of grain_size.
  """
  n, r = A.shape
  if B.shape != A.shape or C.ndim != 2 or C.shape[0] != n:
    raise ValueError(
        f'Expected A, B of shape (n, r) and C of shape (n, d); got '
        f'A={A.shape}, B={B.shape}, C={C.shape}.'
    )
  _, d = C.shape
  if grain_size <= 0:
    raise ValueError(f'grain_size must be positive, got {grain_size}.')
  if n % grain_size != 0:
    raise ValueError(
        f'Sequence length {n} is not a multiple of grain_size {grain_size}.'
    )
  if n == 0:
    return jnp.zeros((0, d), dtype=jnp.result_type(A, B, C))

  A_view = A.reshape(-1, grain_size, r)
  B_view = B.reshape(-1, grain_size, r)
  C_view = C.reshape(-1, grain_size, d)

  # Per-block statistics: S_k[i, j, :] = sum_{s in block k} B[s,i] B[s,j] C[s,:].
  block_stats = jnp.einsum('...ti, ...tj, ...td -> ...ijd',
                           B_view,
                           B_view,
                           C_view,
                           precision=multi_precision)
  # prefix_stats[k] aggregates blocks 0..k.
  prefix_stats = jnp.cumsum(block_stats, axis=0)
  # Block k >= 1 attends fully to blocks 0..k-1, i.e. prefix_stats[k - 1].
  inter_block = jnp.einsum('...ti, ...tj, ...ijd -> ...td',
                           A_view[1:],
                           A_view[1:],
                           prefix_stats[:-1],
                           precision=multi_precision)
  # Block 0 has no earlier blocks, so its inter-block contribution is zero.
  inter_block = jnp.pad(inter_block.reshape(-1, d), ((grain_size, 0), (0, 0)))

  # Within-block causal part: scores[t, s] = (A[t] . B[s]) ** 2 for s <= t.
  scores = jnp.einsum('...ti, ...si -> ...ts',
                      A_view,
                      B_view,
                      precision=multi_precision) ** 2
  scores = jnp.tril(scores)
  intra_block = jnp.einsum('...ts, ...sd -> ...td',
                           scores,
                           C_view,
                           precision=multi_precision)
  return inter_block + intra_block.reshape(-1, d)

