import numpy as np
from t5x.examples.t5.layers import *
import flax.linen.partitioning as nn_partitioning


class SinusoidalPositionBiases(nn.Module):
  """Fixed (non-learned) sinusoidal position embeddings.

  Produces one row per flattened position: the first half of the columns are
  sin(pos * inv_freq), the second half cos(pos * inv_freq), with
  inv_freq = 10000 ** (-2k / embed_dim).

  Attributes:
    shape: Positional shape whose product is the number of positions. When
      None (or empty), it is inferred from ``x.shape[1:-1]``.
    dtype: Output dtype. The computation itself is always done in float32.
  """
  shape: Optional[Tuple[int]] = None
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, x):
    """Returns a (num_positions, embed_dim) array of position embeddings.

    Args:
      x: Input array; only its shape is used (last axis = embedding dim,
        middle axes = positional shape when ``self.shape`` is unset).
    """
    embed_dim = x.shape[-1]
    length = int(np.prod(self.shape or x.shape[1:-1]))

    # Compute in float32: low-precision dtypes such as bfloat16 cannot
    # represent large integer positions exactly, which would corrupt the
    # encoding. Cast only the final result to self.dtype.
    pos_seq = jnp.arange(length, dtype=jnp.float32)
    inv_freq = jnp.arange(0.0, embed_dim, 2.0, dtype=jnp.float32) / embed_dim
    inv_freq = 1. / (10000 ** inv_freq)

    sinusoid_inp = jnp.outer(pos_seq, inv_freq)
    position_bias = jnp.concatenate([jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)], axis=-1)
    # For odd embed_dim the concatenation has embed_dim + 1 columns; trim so
    # the result can be added to embed_dim-wide activations.
    position_bias = position_bias[:, :embed_dim]
    return jnp.asarray(position_bias, self.dtype)

    
class RotaryPositionBiases(nn.Module):
  """Table of rotary position angles, one row per position.

  Row p holds p * inv_freq, where inv_freq = 10000 ** (-2k / dim), and the
  frequency vector appears twice along the last axis ([f, f]). The result is
  presumably consumed by a rotary-embedding application elsewhere — confirm
  against callers.

  Attributes:
    seq_len: Number of positions (rows).
    dim: Embedding dim used to derive the frequencies.
    dtype: Currently unused; the table is always computed and returned in
      float32.
  """
  seq_len: int
  dim: int
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self):
    exponents = jnp.arange(0, self.dim, 2).astype(jnp.float32) / self.dim
    frequencies = 1.0 / (10000 ** exponents)
    positions = jnp.arange(self.seq_len, dtype=jnp.float32)
    # Outer product: angles[p, k] = positions[p] * frequencies[k].
    angles = jnp.einsum('i,j->ij', positions, frequencies)
    # Repeat the frequency block once along the last axis: [angles, angles].
    return jnp.tile(angles, (1, 2))
    
    
class AbsolutePositionBiases(nn.Module):
  """Learned absolute position embeddings (one learned vector per position).

  Attributes:
    dtype: Output dtype. The parameter itself is stored in float32.
    embedding_init: Initializer for the embedding table.
  """
  dtype: Any = jnp.float32
  embedding_init: Callable[..., Array] = nn.linear.default_embed_init

  @nn.compact
  def __call__(self, x):
    """Returns the learned table, shaped like ``x.shape[1:]``.

    Args:
      x: Input array; only its shape is used. Assumed to be
        (batch, length, embed), since the parameter is annotated with two
        logical axes ('length', 'embed') — TODO confirm for other ranks.
    """
    position_bias = nn_partitioning.param_with_axes('abs_embedding', self.embedding_init,
                                    x.shape[1:], jnp.float32, axes=('length', 'embed'))
    # Parameter is kept in float32; cast the output to the module dtype, as
    # the other position-bias modules in this file do.
    return jnp.asarray(position_bias, self.dtype)

    
class BroadcastPositionBiases(nn.Module):
  """Learned factorized position embeddings over an N-D positional grid.

  The embedding dim is split into one chunk per positional axis. Axis i owns a
  learned (shape[i], chunk_i) table, which is broadcast across all other grid
  axes. The per-axis chunks are concatenated along the last axis and then
  flattened to one row per grid position.

  Attributes:
    shape: Positional grid shape. When None (or empty), it is inferred from
      ``x.shape[1:-1]``.
    dtype: Output dtype (parameters are stored in float32).
  """
  shape: Optional[Tuple[int]] = None
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, x):
    """Returns a (prod(shape), embed_dim) array of position embeddings.

    Args:
      x: Input array; only its shape is used (last axis = embedding dim,
        middle axes = grid shape when ``self.shape`` is unset).

    Raises:
      ValueError: if the resolved grid shape has no axes.
    """
    shape = tuple(self.shape or x.shape[1:-1])
    # Use the resolved shape: self.shape may be None (its default).
    n_dim = len(shape)
    if n_dim == 0:
      raise ValueError('BroadcastPositionBiases needs at least one positional axis.')
    embed_dim = x.shape[-1]

    # Split embed_dim as evenly as possible; the first (embed_dim % n_dim)
    # axes each get one extra channel.
    chunk_sizes = [embed_dim // n_dim + (i < (embed_dim % n_dim))
                   for i in range(n_dim)]
    assert sum(chunk_sizes) == embed_dim, f'sum({chunk_sizes}) = {sum(chunk_sizes)} != {embed_dim}'

    embs = [
      self.param(f'd_{i}', nn.initializers.normal(stddev=0.02),
                 (shape[i], chunk_sizes[i]), jnp.float32)
      for i in range(n_dim)
    ]

    out = []
    for i, e in enumerate(embs):
      # Put axis i's table on grid axis i (singleton elsewhere), then
      # broadcast it to (1, *shape, chunk_sizes[i]).
      e = jnp.reshape(e, (1,) + (1,) * i + (shape[i],) + (1,) * (n_dim - i - 1) + (-1,))
      e = jnp.broadcast_to(e, (1, *shape, e.shape[-1]))
      out.append(e)
    out = jnp.concatenate(out, axis=-1)
    out = jnp.asarray(out, self.dtype)

    return jnp.reshape(out, (int(np.prod(shape)), embed_dim))

    
class FrameBroadcastPositionBiases(nn.Module):
  """Position embeddings for a (time, *frame) grid.

  The leading (time) axis gets fixed sinusoidal embeddings. The remaining
  frame axes get learned per-axis broadcast embeddings. The two are summed
  over every (time, frame-position) pair and flattened to one row per cell.

  Attributes:
    shape: Grid shape; shape[0] is the time axis, shape[1:] the frame axes.
    dtype: dtype forwarded to both sub-modules.
  """
  shape: Tuple[int]
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, x):
    grid = self.shape
    n_cells = np.prod(grid)

    temporal = SinusoidalPositionBiases(shape=grid[:1], dtype=self.dtype)(x)  # (T, D)
    spatial = BroadcastPositionBiases(shape=grid[1:], dtype=self.dtype)(x)  # (prod(frame), D)

    # Outer sum over time x frame positions: (T, prod(frame), D).
    combined = jnp.expand_dims(temporal, 1) + jnp.expand_dims(spatial, 0)
    return combined.reshape(n_cells, x.shape[-1])  # (T * prod(frame), D)
  
    
class RightShift(nn.Module):
  """Shift a sequence right by one, prepending a learned start vector.

  All axes between the first and last are flattened into one sequence axis.
  Every position moves one step later, the final position is dropped, and a
  learned 'sos' vector fills the first slot. The output has the same shape
  as the input.

  Attributes:
    dtype: dtype of the 'sos' parameter and of the prepended vector.
  """
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, x):
    batch, features = x.shape[0], x.shape[-1]
    start = nn_partitioning.param_with_axes('sos', nn.initializers.normal(stddev=0.02),
                                            (features,), self.dtype, axes=('embed',))
    # Assumes x is (batch, ..., features); collapse the middle axes.
    flat = jnp.reshape(x, (batch, -1, features))
    start = jnp.broadcast_to(jnp.reshape(start, (1, 1, features)), (batch, 1, features))
    start = jnp.asarray(start, self.dtype)
    shifted = jnp.concatenate([start, flat[:, :-1]], axis=1)
    return jnp.reshape(shifted, x.shape)
