"""Contains the PolySketchFormer definition."""
import functools

from flax import linen as nn
import jax
import jax.numpy as jnp

import utils


def construct_hadamard_matrix(dim):
  """Builds the dim x dim Sylvester-ordered Hadamard matrix.

  Uses the recursion H_1 = [[1]] and H_{2k} = [[H_k, H_k], [H_k, -H_k]],
  so every entry is +1.0 or -1.0 and H @ H.T == dim * I.

  Args:
    dim: Matrix size; must be a positive power of two.

  Returns:
    A (dim, dim) float array.

  Raises:
    ValueError: If dim is not a positive power of two. (The previous
      assert-based check let dim == 0 recurse forever and was stripped
      entirely under `python -O`.)
  """
  if dim < 1 or dim & (dim - 1):
    raise ValueError(f'dim must be a positive power of two, got {dim}.')
  matrix = jnp.array([[1.0]])
  # Double the size until we reach dim; iterative, so no recursion depth.
  while matrix.shape[0] < dim:
    top = jnp.concatenate([matrix, matrix], axis=-1)
    bottom = jnp.concatenate([matrix, -matrix], axis=-1)
    matrix = jnp.concatenate([top, bottom], axis=0)
  return matrix


@jax.custom_vjp
def hadamard_transform(x):
  """Applies hadamard transform to each row of the matrix x."""
  original_shape = x.shape
  if len(original_shape) == 3:
    batch_size, context_length, d = x.shape
    x = x.reshape(batch_size * context_length, d)
  _, d = x.shape
  h_2x2 = jnp.array([[1.0, 1.0], [1.0, -1.0]], dtype=x.dtype)

  def _hadamard_step(x):
    x = x.reshape(-1, 2)
    x = x @ h_2x2
    x = x.reshape(-1, d // 2, 2)
    x = x.transpose(0, 2, 1)
    x = x.reshape(-1, d)
    return x

  i = 0
  cond = lambda arg: (2 ** (arg[0]) < d)
  body = lambda arg: (arg[0] + 1, _hadamard_step(arg[1]))
  _, x = jax.lax.while_loop(cond, body, (i, x))
  return x.reshape(original_shape)


def hadamard_transform_fwd(x):
  return hadamard_transform(x), None


def hadamard_transform_bwd(_, y_bar):
  return (hadamard_transform(y_bar),)


hadamard_transform.defvjp(hadamard_transform_fwd, hadamard_transform_bwd)


@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0)
def scale_rotate_select(x, signs, select):
  """Applies a subsampled randomized Hadamard transform (SRHT) to x.

  Vectorized with `jax.vmap` over the leading axis of `x`; `signs` and
  `select` are shared across that axis.

  Args:
    x: Array whose last axis has size d (expected to be a power of two for
      the Hadamard transform). Its leading axis is the vmapped one.
    signs: (d,) array of +1/-1 multipliers applied along the last axis.
    select: (m,) integer indices of the rotated coordinates to keep.

  Returns:
    Array whose last axis is replaced by the m selected coordinates, scaled
    by 1/sqrt(m).
  """
  m = select.shape[0]
  x = x * signs[None, :]
  x = hadamard_transform(x)
  # Select along the feature (last) axis. `x[:, select]` would index axis 1,
  # i.e. the wrong axis, whenever a vmapped slice has more than two axes.
  return x[..., select] / jnp.sqrt(m)


@functools.partial(jax.vmap, in_axes=(0, 0, None, None, None, None), out_axes=0)
def scale_rotate_tensor_select(x, y, signs1, signs2, select1, select2):
  """Applies a TensorSRHT to the pair (x, y).

  Vectorized with `jax.vmap` over the leading axis of both `x` and `y`; the
  sign and selection arrays are shared across that axis.

  Both inputs are sign-flipped and Hadamard-rotated. Output coordinate k is
  then the (select1[k], select2[k]) entry of the tensor product of the
  rotated rows, scaled by 1/sqrt(m).

  Args:
    x, y: Arrays sharing their leading (vmapped) axis size, with last axis
      of size d.
    signs1, signs2: (d,) arrays of +1/-1 multipliers for x and y.
    select1, select2: (m,) integer indices into the rotated last axis.

  Returns:
    Array whose last axis has size m.
  """
  m = select1.shape[0]
  x = x * signs1[None, :]  # (n, d) * (1, d) --> (n, d) * (n, d) --> (n, d)
  x = hadamard_transform(x)
  y = y * signs2[None, :]
  y = hadamard_transform(y)
  # Index the feature (last) axis explicitly; `[:, select]` would pick the
  # wrong axis whenever a vmapped slice has more than two axes.
  return x[..., select1] * y[..., select2] / jnp.sqrt(m)


def naive_create_SRHT(random_key, d, m):
  """Builds the dense (d, m) matrix form of an SRHT sketch (reference only).

  Row i of the Hadamard matrix is multiplied by a random sign, and m of its
  columns are chosen uniformly with replacement. The 1/sqrt(m) scaling is not
  included here; `naive_apply_SRHT` applies it.
  """
  sign_key, select_key = jax.random.split(random_key)
  signs = 2 * jax.random.randint(sign_key, (d,), 0, 2) - 1
  columns = jax.random.randint(select_key, (m,), 0, d)
  hadamard_columns = construct_hadamard_matrix(d)[:, columns]
  return signs[..., None] * hadamard_columns


def naive_apply_SRHT(m, x):
  """Applies a dense SRHT matrix `m` (shape (d, m')) to x, scaling by 1/sqrt(m')."""
  sketch_dim = m.shape[-1]
  projected = x @ m
  return projected / jnp.sqrt(sketch_dim)


def create_SRHT(random_key, d, m):
  """Samples the random state of an SRHT sketch from dimension d to m.

  Args:
    random_key: JAX PRNG key. It is split once: one half drives the signs,
      the other drives the selected coordinates.
    d: Original (input) dimension.
    m: Sketch (output) dimension.

  Returns:
    A dict with 'D', a (d,) array of random +1/-1 signs, and 'P', an (m,)
    array of indices drawn independently and uniformly from [0, d).
  """
  sign_key, select_key = jax.random.split(random_key)
  random_bits = jax.random.randint(sign_key, (d,), 0, 2)
  return {
      'D': 2 * random_bits - 1,
      'P': jax.random.randint(select_key, (m,), 0, d),
  }


def apply_SRHT(SRHT, x):
  """Applies an SRHT sketch built by `create_SRHT` to x.

  x carries a leading batch axis (the one `scale_rotate_select` vmaps over)
  and a last axis of size d, e.g. [batch, n, d]. The result has the same
  leading axes and a last axis of size m, where n can be arbitrary.
  """
  signs, selected = SRHT['D'], SRHT['P']
  return scale_rotate_select(x, signs, selected)


def naive_create_TensorSRHT(random_key, d, m):
  """Builds the dense matrix form of a TensorSRHT (reference only).

  Uses the same key split and sampling as `create_TensorSRHT`, so the two
  should agree up to floating-point error.

  Args:
    random_key: JAX PRNG key, split into four (two sign keys, two select keys).
    d: Original dimension (a power of two, for `construct_hadamard_matrix`).
    m: Sketch dimension.

  Returns:
    (m1, m2): two (d, m) matrices. Each is a sign-scaled Hadamard matrix
    restricted to m sampled columns, for use with `naive_apply_TensorSRHT`.
  """
  key1, key2, key3, key4 = jax.random.split(random_key, 4)
  D1 = 2 * jax.random.randint(key1, (d,), 0, 2) - 1
  D2 = 2 * jax.random.randint(key2, (d,), 0, 2) - 1
  select1 = jax.random.randint(key3, (m,), 0, d)
  select2 = jax.random.randint(key4, (m,), 0, d)

  # Build the d x d Hadamard matrix once and reuse it for both factors.
  hadamard = construct_hadamard_matrix(d)
  m1 = D1[..., None] * hadamard[:, select1]
  m2 = D2[..., None] * hadamard[:, select2]
  return (m1, m2)


def naive_apply_TensorSRHT(m1, m2, A, B):
  """Dense TensorSRHT: elementwise product of A @ m1 and B @ m2, over sqrt(m)."""
  left = A @ m1
  right = B @ m2
  return left * right / jnp.sqrt(m1.shape[-1])


def create_TensorSRHT(random_key, d, m):
  """Instantiates a TensorSRHT sketch.

  Initialized with (random_key, d, m), where d is the original dimension and
  m is the sketch dimension.

  We instantiate two +1/-1 scaling vectors. Given A and B, we first scale the
  columns of A and B with these sign vectors respectively. Then we rotate the
  rows of the scaled matrices with the Hadamard transform. Finally, we select
  m coordinates of the tensor products of corresponding rows of the rotated
  matrices. To do this efficiently, we select m columns of the rotated A
  (via 'select1') and m columns of the rotated B (via 'select2'), then take
  the elementwise (Hadamard) product of the two selections.

  Args:
    random_key: JAX PRNG key, split into four independent keys.
    d: Original dimension.
    m: Sketch dimension.

  Returns:
    A dict with 'D1' and 'D2', (d,) arrays of random +1/-1 signs, and
    'select1' and 'select2', (m,) arrays of indices drawn uniformly from
    [0, d).
  """
  key1, key2, key3, key4 = jax.random.split(random_key, num=4)
  TensorSRHT = {}
  TensorSRHT['D1'] = 2 * jax.random.randint(key1, (d,), 0, 2) - 1
  TensorSRHT['D2'] = 2 * jax.random.randint(key2, (d,), 0, 2) - 1
  TensorSRHT['select1'] = jax.random.randint(key3, (m,), 0, d)
  TensorSRHT['select2'] = jax.random.randint(key4, (m,), 0, d)
  return TensorSRHT


def apply_TensorSRHT(TensorSRHT, A, B):
  """Applies a TensorSRHT built by `create_TensorSRHT` to the pair (A, B).

  A and B share a leading batch axis (the one `scale_rotate_tensor_select`
  vmaps over) and have a last axis of size d, e.g. [batch, n, d]. The result
  keeps the leading axes and has a last axis of size m; n can be arbitrary.
  """
  sign_arrays = (TensorSRHT['D1'], TensorSRHT['D2'])
  select_arrays = (TensorSRHT['select1'], TensorSRHT['select2'])
  return scale_rotate_tensor_select(A, B, *sign_arrays, *select_arrays)


def naive_create_TensorSketch_degree_2(random_key, d, m):
  """Dense reference version of `create_TensorSketch_degree_2`.

  Returns a dict with 'base_transforms', two dense (d, m) SRHT matrices, and
  'upper_transform', a (m1, m2) pair of dense (m, m) TensorSRHT matrices.
  """
  key_base0, key_base1, key_upper = jax.random.split(random_key, num=3)
  return {
      'base_transforms': [
          naive_create_SRHT(key_base0, d, m),
          naive_create_SRHT(key_base1, d, m),
      ],
      'upper_transform': naive_create_TensorSRHT(key_upper, m, m),
  }


def naive_apply_TensorSketch_degree_2(TensorSketch, x):
  """Applies a dense degree-2 TensorSketch to x.

  Both base SRHTs are applied to x, and the two results are combined with the
  dense upper TensorSRHT.
  """
  bases = TensorSketch['base_transforms']
  sketched = [naive_apply_SRHT(bases[k], x) for k in (0, 1)]
  return naive_apply_TensorSRHT(*TensorSketch['upper_transform'], *sketched)


def create_TensorSketch_degree_2(random_key, d, m):
  """Instantiates a TensorSketch for degree-2 polynomials.

  Args:
    random_key: JAX PRNG key, split three ways (two SRHTs, one TensorSRHT).
    d: Original dimension.
    m: Sketch dimension.

  Returns:
    A dict with 'base_transforms', a list of two `create_SRHT(., d, m)`
    sketches, and 'upper_transform', a `create_TensorSRHT(., m, m)` sketch.
  """
  first_key, second_key, upper_key = jax.random.split(random_key, num=3)
  base_transforms = [create_SRHT(k, d, m) for k in (first_key, second_key)]
  return {
      'base_transforms': base_transforms,
      'upper_transform': create_TensorSRHT(upper_key, m, m),
  }


def apply_TensorSketch_degree_2(TensorSketch, x):
  """Applies a TensorSketch from `create_TensorSketch_degree_2` to x.

  x goes through both base SRHTs, and the two sketches are combined by the
  upper TensorSRHT.
  """
  bases = TensorSketch['base_transforms']
  left = apply_SRHT(bases[0], x)
  right = apply_SRHT(bases[1], x)
  return apply_TensorSRHT(TensorSketch['upper_transform'], left, right)


def naive_instantiate_all_tensor_sketches(
    config: utils.TransformerConfig, random_key
):
  """Instantiates dense TensorSketch matrices for all attention heads and layers.

  Every (layer i, head j) pair gets an independent degree-2 TensorSketch,
  sampled with the key `fold_in(fold_in(random_key, i), j)`. Each sketch has
  two dense SRHT base transforms and a dense two-matrix TensorSRHT upper
  transform.

  Args:
    config: Transformer configuration; reads `num_layers`, `num_heads`,
      `emb_dim` and `sketch_size`. head_size = emb_dim // num_heads and
      sketch_size must both be powers of two (required by
      `construct_hadamard_matrix`).
    random_key: JAX PRNG key.

  Returns:
    A dict with
      'base':  (num_layers, num_heads, 2, head_size, sketch_size), the two
               base transforms of each head;
      'upper': (num_layers, num_heads, 2, sketch_size, sketch_size), the two
               upper-transform matrices of each head.
  """
  head_size = config.emb_dim // config.num_heads
  base = []
  upper = []
  for i in range(config.num_layers):
    key_i = jax.random.fold_in(random_key, i)
    base_i = []
    upper_i = []
    for j in range(config.num_heads):
      key_ij = jax.random.fold_in(key_i, j)
      tensor_sketch = naive_create_TensorSketch_degree_2(
          key_ij, head_size, config.sketch_size
      )
      base_transforms = tensor_sketch['base_transforms']
      upper_transform = tensor_sketch['upper_transform']
      # Per head: (2, head_size, sketch_size) and (2, sketch_size, sketch_size).
      base_i.append(jnp.stack([base_transforms[0], base_transforms[1]]))
      upper_i.append(jnp.stack([upper_transform[0], upper_transform[1]]))
    base.append(jnp.stack(base_i))
    upper.append(jnp.stack(upper_i))

  return {'base': jnp.stack(base), 'upper': jnp.stack(upper)}


def naive_apply_tensorsketch_from_components(sketch, x):
  """Applies one head's dense TensorSketch, stored as stacked arrays, to x.

  Computes ((x @ B0 / sqrt(m)) @ U0) * ((x @ B1 / sqrt(m)) @ U1) / sqrt(m'),
  where B = sketch['base'], U = sketch['upper'], m = B[0].shape[-1] and
  m' = U[0].shape[-1].
  """
  base, upper = sketch['base'], sketch['upper']
  sketch_dim = base[0].shape[-1]
  left = x @ base[0] / jnp.sqrt(sketch_dim)
  right = x @ base[1] / jnp.sqrt(sketch_dim)
  out_dim = upper[0].shape[-1]
  return (left @ upper[0]) * (right @ upper[1]) / jnp.sqrt(out_dim)


def instantiate_all_tensor_sketches(
    config: utils.TransformerConfig, random_key
):
  """Instantiates the sketch components for every attention head of every layer.

  Each (layer i, head j) pair gets one degree-2 TensorSketch, sampled with
  the key `fold_in(fold_in(random_key, i), j)`. A TensorSketch consists of
  two SRHTs (each a select array and a sign/scale array) and one TensorSRHT
  (two select arrays and two sign/scale arrays).

  Args:
    config: Transformer configuration; reads `num_layers`, `num_heads`,
      `emb_dim` and `sketch_size`.
    random_key: JAX PRNG key.

  Returns:
    A dict of stacked arrays, with head_size = emb_dim // num_heads:
      'SRHT_select':       (num_layers, num_heads, 2, sketch_size)
      'SRHT_scale':        (num_layers, num_heads, 2, head_size)
      'TensorSRHT_select': (num_layers, num_heads, 2, sketch_size)
      'TensorSRHT_scale':  (num_layers, num_heads, 2, sketch_size)
    The TensorSRHT acts on sketch_size-dimensional inputs, so its scale
    arrays have length sketch_size.
  """
  head_size = config.emb_dim // config.num_heads
  SRHT_select = []
  SRHT_scale = []
  TensorSRHT_select = []
  TensorSRHT_scale = []

  for i in range(config.num_layers):
    key_i = jax.random.fold_in(random_key, i)
    layer_srht_select = []
    layer_srht_scale = []
    layer_tensor_select = []
    layer_tensor_scale = []
    for j in range(config.num_heads):
      key_ij = jax.random.fold_in(key_i, j)
      tensor_sketch = create_TensorSketch_degree_2(
          key_ij, head_size, config.sketch_size
      )
      srht_0, srht_1 = tensor_sketch['base_transforms']
      upper = tensor_sketch['upper_transform']
      layer_srht_select.append(jnp.stack([srht_0['P'], srht_1['P']]))
      layer_srht_scale.append(jnp.stack([srht_0['D'], srht_1['D']]))
      layer_tensor_select.append(
          jnp.stack([upper['select1'], upper['select2']])
      )
      layer_tensor_scale.append(jnp.stack([upper['D1'], upper['D2']]))

    SRHT_select.append(jnp.stack(layer_srht_select))
    SRHT_scale.append(jnp.stack(layer_srht_scale))
    TensorSRHT_select.append(jnp.stack(layer_tensor_select))
    TensorSRHT_scale.append(jnp.stack(layer_tensor_scale))

  return {
      'SRHT_select': jnp.stack(SRHT_select),
      'SRHT_scale': jnp.stack(SRHT_scale),
      'TensorSRHT_select': jnp.stack(TensorSRHT_select),
      'TensorSRHT_scale': jnp.stack(TensorSRHT_scale),
  }


def apply_TensorSketch_degree_2_from_components(sketch, x):
  """Applies one head's TensorSketch, stored as stacked component arrays, to x.

  `sketch` holds 'SRHT_scale', 'SRHT_select', 'TensorSRHT_scale' and
  'TensorSRHT_select', each with a leading axis of size 2. These are
  presumably a per-(layer, head) slice of `instantiate_all_tensor_sketches`
  output; confirm with the caller.
  """
  srht_signs = sketch['SRHT_scale']
  srht_cols = sketch['SRHT_select']
  first = scale_rotate_select(x, srht_signs[0, :], srht_cols[0, :])
  second = scale_rotate_select(x, srht_signs[1, :], srht_cols[1, :])

  tensor_signs = sketch['TensorSRHT_scale']
  tensor_cols = sketch['TensorSRHT_select']
  return scale_rotate_tensor_select(
      first,
      second,
      tensor_signs[0, :],
      tensor_signs[1, :],
      tensor_cols[0, :],
      tensor_cols[1, :],
  )


class SketchedAttention(nn.Module):
  """Single causal attention head using sketched polynomial attention.

  q and k are mapped through a dense degree-2 TensorSketch (the `sketch`
  argument), and the attention weights are built from the sketched features
  instead of from a softmax.
  """
  config: utils.TransformerConfig

  def setup(self):
    # Per-head width; presumably emb_dim is divisible by num_heads.
    self.head_size = self.config.emb_dim // self.config.num_heads
    # Rotary position embedding helper from utils, applied to q and k below.
    self.rope = utils.RoPE(self.config.context_length, self.head_size)

  # nn.remat recomputes this head's forward activations during backprop
  # instead of storing them, trading compute for memory.
  @nn.remat
  @nn.compact
  def __call__(self, x, sketch):
    """Applies the attention head.

    Args:
      x: Input activations; the first three dims are unpacked as (B, T, _).
      sketch: Dense TensorSketch components for this head, read by
        `naive_apply_tensorsketch_from_components` through the keys 'base'
        and 'upper'. Presumably a per-head slice of
        `naive_instantiate_all_tensor_sketches`; confirm with the caller.

    Returns:
      The head output: (B, T, head_size) when attention is disabled and on
      the explicit (non-lt_multiply) path. The lt_multiply path returns
      whatever `utils.batched_stable_linear_attention` produces.
    """
    B, T, _ = x.shape
    projection = nn.Dense(3 * self.head_size, use_bias=False)(
        x
    )  # (B, T, 3*head_size)
    q, k, v = jnp.array_split(
        projection, 3, axis=2
    )  # Each of size (B, T, head_size)

    # With attention disabled, the head returns the value projection as is.
    if self.config.disable_attention:
      return v

    q = self.rope.apply(q)
    k = self.rope.apply(k)

    # Scale q and k by 1/sqrt(head_size) each before sketching.
    q = q / jnp.sqrt(self.head_size)
    k = k / jnp.sqrt(self.head_size)

    sketched_q = naive_apply_tensorsketch_from_components(
        sketch, q
    )  # Each of size (B, T, sketch_size)
    sketched_k = naive_apply_tensorsketch_from_components(
        sketch, k
    )  # Each of size (B, T, sketch_size)

    if self.config.lt_multiply:
      # Delegates the causal attention computation to utils (presumably a
      # blocked lower-triangular product, given grain_size; see utils).
      output = utils.batched_stable_linear_attention(
          sketched_q,
          sketched_k,
          v,
          self.config.grain_size,
          self.config.precision,
          self.config.normalization)
    else:
      # Explicit path: materializes the full (B, T, T) attention matrix.
      tensored_q = utils.batch_rows_tensored(sketched_q)  # (B, T, sketch_size^2)
      tensored_k = utils.batch_rows_tensored(sketched_k)  # (B, T, sketch_size^2)
      # tensored_q, tensored_k <- (B, T, sketch_size^2)
      attn_matrix = jnp.tril(tensored_q @ tensored_k.transpose(0, 2, 1))  # (B, T, T)
      # attn_matrix <- (B, T, T) and lower triangular
      qkv = attn_matrix @ v  # (B, T, head_size)
      # Row sums of the causal attention weights (total weight per query).
      scaling = jnp.einsum('...qd -> ...q', attn_matrix)  # (B, T)
      # Denominator = normalization * (1-based query position) + row sum.
      # The row sum is floored at 1e-10, so the denominator stays positive
      # even when config.normalization is 0.
      normalization_vector = self.config.normalization * jnp.arange(1, T+1, dtype=jnp.float32)
      scaling = normalization_vector[..., None] + jnp.maximum(scaling[..., None], 1e-10)  # (B, T, 1)
      output = qkv / scaling  # (B, T, head_size)
    return output


class MultiHeadAttention(nn.Module):
  """Multi-head sketched attention followed by an output projection.

  One SketchedAttention head runs per entry along the leading axis of
  `sketches`. `nn.vmap` gives each head its own parameters.
  """
  config: utils.TransformerConfig

  @nn.compact
  def __call__(self, x, sketches, training: bool):
    """Runs all heads, concatenates their outputs and projects to emb_dim.

    Args:
      x: Input activations, shared by all heads.
      sketches: Per-head sketch components; every leaf is mapped over its
        leading (head) axis.
      training: When True, dropout is applied to the projected output.

    Returns:
      Array whose last axis has size config.emb_dim.
    """
    def body_fn(cell, x, sketch):
      # Takes an input x and a sketch and applies a single attention head
      return cell(x, sketch)

    cell = SketchedAttention(self.config)
    # variable_axes: each head's params are stacked along a new axis 0.
    # split_rngs: each head gets an independent parameter-init RNG.
    mha_module = nn.vmap(
        body_fn,
        variable_axes={'params': 0},
        split_rngs={'params': True},
        in_axes=(None, 0),
        out_axes=0,
    )  # Share the  same input and split input sketches
    x = mha_module(cell, x, sketches)
    # Head outputs are stacked on axis 0 (out_axes=0); concatenating over
    # that axis joins them along the feature (last) axis.
    x = jnp.concatenate(x, axis=-1)
    x = nn.Dense(self.config.emb_dim, use_bias=False)(x)
    x = nn.Dropout(rate=self.config.dropout_rate)(x, deterministic=not training)
    return x


class FeedForward(nn.Module):
  """Position-wise MLP: Dense(4 * emb_dim) -> GELU -> Dense(emb_dim) -> Dropout."""
  config: utils.TransformerConfig

  @nn.compact
  def __call__(self, x, training: bool):
    width = self.config.emb_dim
    hidden = nn.gelu(nn.Dense(4 * width, use_bias=False)(x))
    out = nn.Dense(width, use_bias=False)(hidden)
    # Dropout is active only while training.
    return nn.Dropout(rate=self.config.dropout_rate)(
        out, deterministic=not training
    )


class GLU(nn.Module):
  """Gated MLP: value * GELU(gate), then Dense(emb_dim) and Dropout."""
  config: utils.TransformerConfig

  @nn.compact
  def __call__(self, x, training: bool):
    # Round 8/3 * emb_dim down to a multiple of 8. This keeps the parameter
    # count close to a FeedForward with expansion factor 4 while keeping the
    # hidden width divisible by 8.
    hidden_dim = ((8 * self.config.emb_dim // 3) // 8) * 8
    value, gate = jnp.split(
        nn.Dense(2 * hidden_dim, use_bias=False)(x), 2, axis=-1
    )
    gated = value * nn.gelu(gate)
    out = nn.Dense(self.config.emb_dim, use_bias=False)(gated)
    return nn.Dropout(rate=self.config.dropout_rate)(
        out, deterministic=not training
    )


class Block(nn.Module):
  """Pre-norm transformer layer: sketched multi-head attention, then an MLP.

  Both sub-layers are residual: x <- x + attn(LN(x)), then x <- x + ff(LN(x)).
  """
  config: utils.TransformerConfig

  def setup(self):
    self.ln_1 = nn.LayerNorm(epsilon=1e-5)
    self.attn = MultiHeadAttention(self.config)
    self.ln_2 = nn.LayerNorm(epsilon=1e-5)
    if self.config.use_glu:
      self.ff = GLU(self.config)
    else:
      self.ff = FeedForward(self.config)

  def __call__(self, x, training, sketch):
    attn_out = self.attn(self.ln_1(x), sketch, training=training)
    x = x + attn_out
    ff_out = self.ff(self.ln_2(x), training=training)
    return x + ff_out


class Transformer(nn.Module):
  """Causal language model built from sketched-attention Blocks.

  Token embeddings plus sinusoidal position embeddings pass through
  `num_layers` Blocks, applied with `nn.scan`. The result is projected back
  onto the vocabulary with the (tied) token-embedding matrix plus a learned
  bias.
  """
  config: utils.TransformerConfig
  # Per-layer sketch components. nn.scan slices every leaf along its leading
  # axis, which must therefore have size num_layers. Presumably the output of
  # naive_instantiate_all_tensor_sketches; confirm with the caller.
  tree_of_sketches: dict[str, jnp.ndarray]
  # random_key: jax.random.PRNGKeyArray  # used for generating sketches

  def setup(self):
    self.wte = nn.Embed(self.config.vocab_size, self.config.emb_dim)
    # Fixed (non-learned) position embeddings for context_length positions.
    self.pe = utils.sinusoidal_position_embedding(
        self.config.context_length, self.config.emb_dim
    )
    self.ln = nn.LayerNorm(epsilon=1e-5)
    # self.tree_of_sketches = naive_instantiate_all_tensor_sketches(
    #     self.config, self.random_key
    # )
    self.softmax_bias = self.param(
        'softmax_bias', nn.initializers.zeros_init(), (self.config.vocab_size,)
    )

  @nn.compact
  def __call__(self, idx, training: bool):
    """Computes next-token logits.

    Args:
      idx: Integer token ids, presumably (B, T). NOTE(review): `self.pe` is
        added without slicing, so T must broadcast against its position axis.
        Likely only T == context_length works; confirm with the caller.
      training: When True, dropout is enabled inside the Blocks.

    Returns:
      Logits whose last axis has size vocab_size.
    """
    token_embd = self.wte(idx)
    x = token_embd + self.pe

    block = Block(self.config)  # Defines the block that is to be applied

    def body_fn_sketched(cell, carry, sketch):
      # Takes the input from previous layer, sketches corresponding to that block
      # and returns the result of that block as carry.
      return cell(carry, training, sketch), None

    # nn.scan applies `block` num_layers times. Params are stacked per layer
    # on axis 0, and each layer gets its own params/dropout RNG streams.
    sequential_blocks = nn.scan(
        body_fn_sketched,
        variable_axes={'params': 0},
        split_rngs={'params': True, 'dropout': True},
        length=self.config.num_layers,
    )

    x, _ = sequential_blocks(block, x, self.tree_of_sketches)
    x = self.ln(x)  # B, T, C
    # logits = self.lm_head(x)  # B, T, V
    # Output projection reuses (ties) the token-embedding matrix.
    logits = self.wte.attend(x) + self.softmax_bias
    return logits
