from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import flax.linen.initializers as nn_init
from . import layers


class Transformer(nn.Module):
    """Transformer stack over one or several (multi-modal) token grids.

    The input is a single array or a tuple of arrays. Each array is optionally
    embedded / projected, flattened to ``(batch, length, features)``, and all
    arrays are concatenated along the length axis. After the layer stack and
    output head the result is split back into per-input pieces and reshaped to
    each input's original inner shape. A tuple input yields a list of outputs;
    a single array input yields a single array.

    Attributes:
      embed_dim: model width (input projection and layer width).
      num_heads: attention heads per layer.
      num_layers: number of ``TransformerLayer`` blocks.
      mlp_dim: hidden width of each layer's MLP.
      dropout: dropout rate applied to activations.
      attention_dropout: dropout rate inside attention.
      vocab_size, vocab_dim: when both are set, the (single, non-tuple) input
        is embedded with ``layers.Embed``.
      shape: forwarded to the 'broadcast' / 'frame_broadcast' position biases
        and to every layer (used by 'spatio_temporal' attention).
      attention_type: 'full' or 'spatio_temporal'.
      pos_embed_type: 'absolute', 'broadcast', 'sinusoidal',
        'frame_broadcast' or 'none'.
      out_dim: if set, a final float32 projection to this many features.
      use_fc_in: project every input to ``embed_dim`` before flattening.
      right_shift: apply ``layers.RightShift`` to the joined sequence.
      dtype: computation dtype.
    """
    embed_dim: int
    num_heads: int
    num_layers: int
    mlp_dim: int
    dropout: float
    attention_dropout: float
    vocab_size: Optional[int] = None
    vocab_dim: Optional[int] = None
    shape: Optional[Tuple[int]] = None
    attention_type: str = 'full'
    pos_embed_type: str = 'absolute'
    out_dim: Optional[int] = None
    use_fc_in: bool = True
    right_shift: bool = False
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs, mask=None, deterministic=False, cond=None, decode_step=None):
        """Run the transformer.

        Args:
          inputs: an array, or a tuple of arrays, each laid out (after the
            optional embedding) as ``(batch, *inner_shape, features)``.
          mask: attention mask forwarded to every layer.
          deterministic: disables dropout when True.
          cond: optional conditioning dict; None is treated as empty.
            Keys read in this method:
              - 'cat': concatenated to the (embedded) input on the last axis,
                before the input projection; single-array input only.
              - 'frame_indices': required when pos_embed_type is
                'frame_broadcast'.
            The whole dict is forwarded to every TransformerLayer, whose full
            attention uses the 'cat_cond' entry as extra keys/values.
            NOTE(review): an older comment also listed 'cat_attn' and
            'cross_attn' modes; no 'cross_attn' handling is visible in this
            file and the key actually read is 'cat_cond' -- confirm naming.
          decode_step: when not None, enables cached incremental decoding; a
            length-1 joined sequence is treated as the token at this step.

        Returns:
          One array per input shaped ``(batch, *inner_shape, out_features)``
          (a list for tuple input, a single array otherwise).

        Raises:
          ValueError: unknown ``pos_embed_type``.
        """
        if cond is None:
            cond = dict()

        x = inputs

        if self.vocab_size is not None and self.vocab_dim is not None:
            assert not isinstance(x, tuple), 'This mode does not support multi-modal input'
            x = layers.Embed(
                num_embeddings=self.vocab_size,
                features=self.vocab_dim,
                attend_dtype=jnp.float32,
                embedding_init=nn.initializers.normal(stddev=1.0),
                dtype=self.dtype
            )(x)

        if 'cat' in cond:
            assert not isinstance(x, tuple), 'This mode does not support multi-modal input'
            x = jnp.concatenate((x, cond['cat']), axis=-1)

        is_tuple = isinstance(x, tuple)
        if not is_tuple:
            x = (x,)

        if self.use_fc_in:
            x = [
                layers.DenseGeneral(
                    features=self.embed_dim, dtype=self.dtype,
                    kernel_axes=('vocab', 'embed')
                )(x_i)
                for x_i in x
            ]

        # Remember each input's inner (spatial) shape, flatten every input to
        # (batch, length, features) and join them along the length axis.
        old_shapes = [x_i.shape[1:-1] for x_i in x]
        x = [x_i.reshape(x_i.shape[0], -1, x_i.shape[-1]) for x_i in x]
        x = jnp.concatenate(x, axis=1)

        if self.right_shift:
            if decode_step is None:
                x = layers.RightShift(self.dtype)(x)
            else:
                # While decoding, only step 0 (decode_step <= 0) uses the
                # shifted input; later steps pass x through unchanged.
                x_shift = layers.RightShift(self.dtype)(x)
                x = jax.lax.cond(decode_step > 0, lambda: x, lambda: x_shift)

        if self.pos_embed_type == 'absolute':
            position_bias = layers.AbsolutePositionBiases(dtype=self.dtype)(x)
        elif self.pos_embed_type == 'broadcast':
            position_bias = layers.BroadcastPositionBiases(shape=self.shape,
                                                           dtype=self.dtype)(x)
        elif self.pos_embed_type == 'sinusoidal':
            position_bias = layers.SinusoidalPositionBiases(dtype=self.dtype)(x)
        elif self.pos_embed_type == 'frame_broadcast':
            position_bias = layers.FrameBroadcastPositionBiases(shape=self.shape,
                                                                dtype=self.dtype)(x, cond['frame_indices'])
        elif self.pos_embed_type == 'none':
            position_bias = None
        else:
            raise ValueError(f'Invalid pos_embed_type: {self.pos_embed_type}')

        if position_bias is not None:
            if decode_step is not None and x.shape[1] == 1:
                # Single-token decode: keep only the bias for this position
                # (assumes the bias is indexed by position on its leading
                # axis -- confirm against layers.*PositionBiases).
                position_bias = position_bias[decode_step]

            x += position_bias

        x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)
        x = x.astype(self.dtype)

        for _ in range(self.num_layers):
            x = TransformerLayer(
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                mlp_dim=self.mlp_dim,
                shape=self.shape,
                dropout=self.dropout,
                attention_dropout=self.attention_dropout,
                attention_type=self.attention_type,
                dtype=self.dtype
            )(x, mask=mask, cond=cond, deterministic=deterministic, decode_step=decode_step)

        # Output head: dense -> gelu -> layer norm -> dropout.
        x = layers.DenseGeneral(
            self.embed_dim,
            dtype=self.dtype,
            kernel_axes=('embed', 'mlp')
        )(x)
        x = nn.gelu(x)
        x = LayerNorm(dtype=self.dtype)(x)
        x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)

        if self.out_dim is not None:
            x = layers.DenseGeneral(
                self.out_dim,
                dtype=jnp.float32,
                kernel_axes=('embed', 'vocab')
            )(x)

        # Split the joined sequence back into per-input segments. int() is
        # needed because np.prod(()) is the float 1.0 for an input without
        # inner axes, and jnp.split requires integer split indices.
        segment_lengths = [int(np.prod(s)) for s in old_shapes]
        split_idxs = np.cumsum(segment_lengths)[:-1]
        x = jnp.split(x, split_idxs, axis=1)
        assert len(x) == len(old_shapes)
        x = [x_i.reshape(x_i.shape[0], *s, x_i.shape[-1])
             for x_i, s in zip(x, old_shapes)]

        if not is_tuple:
            x = x[0]

        return x
    

class TransformerLayer(nn.Module):
    """One pre-LayerNorm transformer block: attention then MLP, each residual.

    ``x = inputs + Dropout(Attn(LN(inputs)))``
    ``y = x + Dropout(Mlp(LN(x)))`` (MLP activation: ``gelu2``).

    Attributes:
      embed_dim: model width; per-head size is ``embed_dim // num_heads``.
      num_heads: number of attention heads.
      mlp_dim: hidden width of the MLP.
      dropout: dropout rate on the attention and MLP outputs (and inside MLP).
      attention_dropout: dropout rate inside attention.
      attention_type: 'full' (MultiHeadDotProductAttention) or
        'spatio_temporal' (MultiHeadSpatioTemporalAttention).
      shape: grid shape, used only by 'spatio_temporal' attention.
      dtype: computation dtype.
    """
    embed_dim: int
    num_heads: int
    mlp_dim: int
    dropout: float
    attention_dropout: float
    attention_type: str = 'full'
    shape: Optional[Tuple] = None
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs, mask=None, cond=None, deterministic=False, decode_step=None):
        """Apply the block.

        Args:
          inputs: sequence activations.
          mask: attention mask for 'full' attention; ignored by
            'spatio_temporal' attention.
          cond: optional dict (None is treated as empty). Only 'full'
            attention reads it, using ``cond['cat_cond']`` as extra
            keys/values; 'spatio_temporal' attention ignores it.
          deterministic: disables dropout when True.
          decode_step: forwarded to attention for cached decoding.

        Returns:
          Array with the same shape as ``inputs``.

        Raises:
          ValueError: unknown ``attention_type``.
        """
        if cond is None:
            cond = {}

        x = LayerNorm(dtype=self.dtype)(inputs)

        if self.attention_type == 'full':
            x = MultiHeadDotProductAttention(
                num_heads=self.num_heads,
                head_dim=self.embed_dim // self.num_heads,
                dropout_rate=self.attention_dropout,
                dtype=self.dtype
            )(x, x, mask=mask, cond=cond.get('cat_cond', None),
              deterministic=deterministic, decode_step=decode_step)
        elif self.attention_type == 'spatio_temporal':
            x = MultiHeadSpatioTemporalAttention(
                shape=self.shape,
                num_heads=self.num_heads,
                head_dim=self.embed_dim // self.num_heads,
                dropout_rate=self.attention_dropout,
                dtype=self.dtype
            )(x, x, deterministic=deterministic, decode_step=decode_step)
        else:
            raise ValueError(f'Invalid attention_type: {self.attention_type}')
        x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)
        x = x + inputs

        y = LayerNorm(dtype=self.dtype)(x)
        y = layers.MlpBlock(
            intermediate_dim=self.mlp_dim,
            activations=(gelu2,),
            intermediate_dropout_rate=self.dropout,
            dtype=self.dtype
        )(y, deterministic=deterministic)
        y = nn.Dropout(rate=self.dropout)(y, deterministic=deterministic)
        y = y + x

        return y

        
class MultiHeadDotProductAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Supports an autoregressive key/value cache (``decode_step``), extra
    conditioning keys/values (``cond``), rotary position embeddings and
    processing heads in groups (``max_heads_processed``).

    Attributes:
      num_heads: number of attention heads.
      head_dim: per-head feature size; queries are divided by sqrt(head_dim).
      dtype: computation dtype.
      dropout_rate: attention-weight dropout rate.
      max_heads_processed: if set, heads are processed in groups of at most
        this many; None processes all heads in one call.
      kernel_init: initializer for every projection kernel.
    """
    num_heads: int
    head_dim: int
    dtype: Any = jnp.float32
    dropout_rate: float = 0.
    max_heads_processed: Optional[int] = None
    kernel_init: Any = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')

    def _rotate_half(self, x):
        """Map the two halves ``(a, b)`` of the last axis to ``(-b, a)``."""
        x1, x2 = jnp.split(x, 2, axis=-1)
        return jnp.concatenate([-x2, x1], axis=-1)

    def _apply_rotary_pos_emb(self, t, rotary_embeds, rotary_idxs):
        """Rotate the first ``rotary_embeds.shape[-1]`` features of ``t``.

        ``rotary_embeds`` is gathered with ``rotary_idxs`` first when given;
        the remaining features of ``t`` pass through unchanged. The
        ``[:, None]`` inserts an axis so the angles broadcast over heads --
        assumes ``t`` is (batch, length, heads, dim) and the angles are
        (length, rot_dim); TODO confirm against callers.
        """
        if rotary_idxs is not None:
            rotary_embeds = rotary_embeds[rotary_idxs]
        rotate_dim = rotary_embeds.shape[-1]
        t, t_pass = t[..., :rotate_dim], t[..., rotate_dim:]
        t = (t * jnp.cos(rotary_embeds)[:, None]) + (self._rotate_half(t) * jnp.sin(rotary_embeds)[:, None])
        return jnp.concatenate([t, t_pass], axis=-1)

    @nn.compact
    def __call__(self, inputs_q, inputs_kv, mask=None, cond=None,
                 deterministic=False, decode_step=None, rotary_embeds=None,
                 q_rotary_idxs=None, kv_rotary_idxs=None):
        """Attend from ``inputs_q`` to ``inputs_kv`` (plus optional ``cond``).

        Args:
          inputs_q: query inputs; axis 1 is the sequence axis.
          inputs_kv: key/value inputs.
          mask: entries > 0 may be attended; others get a -1e10 bias. If
            ``cond`` is given and the mask's last axis covers only the
            non-cond keys, it is extended so every query may attend to all
            cond keys.
          cond: optional extra tokens; layer-normed, projected with their own
            'key_cond' / 'value_cond' weights and appended after the keys and
            values along axis 1.
          deterministic: disables attention dropout when True.
          decode_step: enables the 'cache' collection. A length-1 query writes
            its key/value at this index of the cache and uses row
            ``decode_step`` of ``mask``; longer inputs overwrite the cache.
          rotary_embeds: optional rotary angles applied to queries and keys
            (after cond keys are appended).
          q_rotary_idxs, kv_rotary_idxs: optional gathers into rotary_embeds.

        Returns:
          Attention output projected back to ``inputs_q.shape[-1]`` features.
        """
        max_heads_processed = self.max_heads_processed or self.num_heads

        projection = partial(
            layers.DenseGeneral,
            axis=-1,
            features=(self.num_heads, self.head_dim),
            kernel_axes=('embed', 'joined_kv'),
            dtype=self.dtype
        )

        depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)

        query = projection(kernel_init=self.kernel_init, name='query')(inputs_q) / depth_scaling
        key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
        value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv)

        if decode_step is not None:
            cached_key = self.variable('cache', 'cached_key', lambda: key)
            cached_value = self.variable('cache', 'cached_value', lambda: value)

            is_slice = inputs_q.shape[1] == 1
            if is_slice:
                key = cached_key.value.at[:, decode_step].set(key[:, 0])
            else:
                key = cached_key.value.at[:].set(key)

            if is_slice:
                value = cached_value.value.at[:, decode_step].set(value[:, 0])
            else:
                value = cached_value.value.at[:].set(value)

            if mask is not None and is_slice:
                # Keep only the mask row for the position being decoded.
                mask = mask[decode_step, None]

            cached_key.value = key
            cached_value.value = value

        num_cond = 0
        if cond is not None:
            cond = LayerNorm(dtype=self.dtype)(cond)
            key_cond = projection(kernel_init=self.kernel_init, name='key_cond')(cond)
            key = jnp.concatenate([key, key_cond], axis=1)
            value_cond = projection(kernel_init=self.kernel_init, name='value_cond')(cond)
            value = jnp.concatenate([value, value_cond], axis=1)
            num_cond = key_cond.shape[1]

        query = layers.with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv'))
        key = layers.with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv'))
        value = layers.with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv'))

        if rotary_embeds is not None:
            query = self._apply_rotary_pos_emb(query, rotary_embeds, q_rotary_idxs)
            key = self._apply_rotary_pos_emb(key, rotary_embeds, kv_rotary_idxs)

        if mask is not None:
            attention_bias = jax.lax.select(
                mask > 0,
                jnp.full(mask.shape, 0.).astype(self.dtype),
                jnp.full(mask.shape, -1e10).astype(self.dtype)
            )
            # The cond keys were appended after the mask was built. If the
            # mask's last (key) axis is short by exactly those keys, pad it
            # with a zero bias so every query may attend to them; otherwise
            # the bias would not match the key length.
            if num_cond and attention_bias.shape[-1] + num_cond == key.shape[1]:
                pad_width = [(0, 0)] * (attention_bias.ndim - 1) + [(0, num_cond)]
                attention_bias = jnp.pad(attention_bias, pad_width)
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        # Process heads in groups of at most max_heads_processed (presumably
        # to bound peak memory). Ceil-division plus array_split also handles
        # head counts that are not a multiple of max_heads_processed, and
        # values larger than num_heads; for exact multiples this is the same
        # as an even jnp.split.
        n_chunks = -(-self.num_heads // max_heads_processed)
        outputs = []
        for q, k, v in zip(jnp.array_split(query, n_chunks, axis=-2),
                           jnp.array_split(key, n_chunks, axis=-2),
                           jnp.array_split(value, n_chunks, axis=-2)):
            outputs.append(layers.dot_product_attention(
                q, k, v,
                bias=attention_bias,
                dropout_rng=dropout_rng,
                dropout_rate=self.dropout_rate,
                deterministic=deterministic,
                dtype=self.dtype
            ))
        x = jnp.concatenate(outputs, axis=-2)

        # Merge (heads, head_dim) back to the query feature size.
        out = layers.DenseGeneral(
            features=inputs_q.shape[-1],
            axis=(-2, -1),
            kernel_init=self.kernel_init,
            kernel_axes=('joined_kv', 'embed'),
            dtype=self.dtype,
            name='out'
        )(x)
        return out
            

class MultiHeadSpatioTemporalAttention(nn.Module):
    """Factorized causal attention over a (time, *space) token grid.

    The flattened sequence of ``prod(shape)`` tokens is viewed as
    ``(n_time, n_space)`` with ``n_time = shape[0]`` and
    ``n_space = shape[1] * shape[2]``. Each query's output is the sum of
      * space attention over the keys of [previous frame | current frame], and
      * temporal attention over its own spatial position in frames <= t,
    followed by a shared output projection.

    Attributes:
      shape: 3-tuple; the first entry is the number of frames.
      num_heads: number of attention heads.
      head_dim: per-head feature size; queries are divided by sqrt(head_dim).
      dtype: computation dtype.
      dropout_rate: attention-weight dropout rate.
      kernel_init: initializer for every projection kernel.
    """
    shape: Tuple[int]
    num_heads: int
    head_dim: int
    dtype: Any = jnp.float32
    dropout_rate: float = 0.
    kernel_init: Any = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')

    def _attention(self, q, k, v, bias, deterministic):
        """Run ``layers.dot_product_attention`` with this module's dropout.

        NOTE(review): in the full-sequence path this is called under
        ``jax.vmap``, so ``make_rng`` runs once per vmap and every mapped
        slice shares one dropout key -- confirm this is acceptable.
        """
        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        assert k.shape == v.shape
        out = layers.dot_product_attention(
            q, k, v, bias=bias, dropout_rng=dropout_rng,
            dropout_rate=self.dropout_rate, deterministic=deterministic,
            dtype=self.dtype
        )
        return out

    @nn.compact
    def __call__(self, inputs_q, inputs_kv, deterministic=False, decode_step=None):
        """Apply factorized space + time attention.

        Args:
          inputs_q: query inputs of length ``prod(shape)``, or length 1 when
            decoding a single position at ``decode_step``.
          inputs_kv: key/value inputs.
          deterministic: disables attention dropout when True.
          decode_step: enables the 'cache' collection; with a length-1 query
            the key/value are written at this flattened index.

        Returns:
          Array of shape ``(batch, length, inputs_q.shape[-1])``.
        """
        assert len(self.shape) == 3, f'Invalid shape: {self.shape}'
        # Single-position mode derives (t, s) from decode_step, so it is only
        # used while decoding; otherwise a length-1 input takes the
        # full-sequence path instead of failing on `None // n_space`.
        is_slice = decode_step is not None and inputs_q.shape[1] == 1

        projection = partial(
            layers.DenseGeneral,
            axis=-1,
            features=(self.num_heads, self.head_dim),
            kernel_axes=('embed', 'joined_kv'),
            dtype=self.dtype
        )

        depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)

        query = projection(kernel_init=self.kernel_init, name='query')(inputs_q) / depth_scaling
        key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
        value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv)

        if decode_step is not None:
            cached_key = self.variable('cache', 'cached_key', lambda: key)
            cached_value = self.variable('cache', 'cached_value', lambda: value)

            if is_slice:
                key = cached_key.value.at[:, decode_step].set(key[:, 0])
                value = cached_value.value.at[:, decode_step].set(value[:, 0])
            else:
                key = cached_key.value.at[:].set(key)
                value = cached_value.value.at[:].set(value)

            cached_key.value = key
            cached_value.value = value

        query = layers.with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv'))
        key = layers.with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv'))
        value = layers.with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv'))

        # View keys/values (and the full query) as (batch, time, space, heads, dim).
        n_space, n_time = np.prod(self.shape[1:]), self.shape[0]
        key, value = list(
            map(lambda x: x.reshape(x.shape[0], n_time, n_space, *x.shape[-2:]),
                [key, value])
        )
        if is_slice:
            t_idx = decode_step // n_space
            s_idx = decode_step % n_space
        else:
            query = query.reshape(query.shape[0], n_time, n_space, *query.shape[-2:])

        # Space attention: per frame, keys are [previous frame | current frame]
        # along the spatial axis, so the mask is n_space x (2 * n_space).
        space_mask = jnp.tril(jnp.ones((n_space, 2 * n_space), dtype=bool))
        # Current-frame half: causal over spatial positions. It has to come
        # from a fresh lower-triangular block -- the tril above leaves this
        # half all False, and tril() of that slice stays all False (the
        # current frame would never be attended, and frame 0, whose
        # previous-frame half is cleared below, would be fully masked).
        space_mask = space_mask.at[:, n_space:].set(
            jnp.tril(jnp.ones((n_space, n_space), dtype=bool)))
        # NOTE(review): the previous-frame half stays lower-triangular (keys at
        # spatial index <= the query's). If the intent is to see the whole
        # previous frame, set that half to all True -- confirm.
        space_mask = jnp.tile(space_mask[None], (n_time, 1, 1))
        # Frame 0 has no previous frame; those keys are zero padding.
        space_mask = space_mask.at[0, :, :n_space].set(False)
        space_bias = jnp.where(space_mask, 0., -1e10)

        # Shift keys/values one frame forward (zero-padding frame 0) and place
        # them before the current frame's keys/values.
        key_space = jnp.concatenate([
            jnp.pad(key, ((0, 0), (1, 0), (0, 0), (0, 0), (0, 0)))[:, :-1],
            key
        ], axis=2)
        value_space = jnp.concatenate([
            jnp.pad(value, ((0, 0), (1, 0), (0, 0), (0, 0), (0, 0)))[:, :-1],
            value
        ], axis=2)

        if is_slice:
            space_bias = space_bias[t_idx, s_idx, None]
            x_space = self._attention(query, key_space[:, t_idx], value_space[:, t_idx],
                                      bias=space_bias, deterministic=deterministic)
        else:
            # Map over frames (axis 1 of q/k/v, axis 0 of the bias).
            x_space = jax.vmap(
                partial(self._attention, deterministic=deterministic),
                (1, 1, 1, 0), 1
            )(query, key_space, value_space, space_bias)

        # Temporal attention: causal over frames at a fixed spatial position.
        temporal_mask = jnp.tril(jnp.ones((n_time, n_time), dtype=bool))
        temporal_bias = jnp.where(temporal_mask, 0., -1e10)

        if is_slice:
            temporal_bias = temporal_bias[t_idx, None]
            x_temporal = self._attention(query, key[:, :, s_idx], value[:, :, s_idx],
                                         bias=temporal_bias, deterministic=deterministic)
        else:
            # Map over spatial positions (axis 2).
            x_temporal = jax.vmap(
                partial(self._attention, bias=temporal_bias, deterministic=deterministic),
                2, 2
            )(query, key, value)

        x = x_space + x_temporal

        out = layers.DenseGeneral(
            features=inputs_q.shape[-1],
            axis=(-2, -1),
            kernel_init=self.kernel_init,
            kernel_axes=('joined_kv', 'embed'),
            dtype=self.dtype,
            name='out'
        )(x)
        # Flatten (time, space) back into a single sequence axis.
        out = out.reshape(out.shape[0], -1, out.shape[-1])
        return out
        


class LayerNorm(nn.Module):
    """Layer normalization over the last axis, computed in float32.

    Attributes:
      epsilon: added to the variance before the reciprocal square root.
      dtype: dtype of the returned array; None keeps float32.
      use_bias: add a learned per-feature 'bias' parameter.
      use_scale: multiply by a learned per-feature 'scale' parameter.
      bias_init: initializer for 'bias' (default zeros).
      scale_init: initializer for 'scale' (default ones).
      reduction_axes: declared but NOT used -- statistics always use axis -1.
      feature_axes: declared but NOT used -- parameters always span axis -1.
    """
    epsilon: float = 1e-6
    dtype: Optional[Any] = None
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Any = nn_init.zeros
    scale_init: Any = nn_init.ones
    reduction_axes: Any = -1
    feature_axes: Any = -1

    @nn.compact
    def __call__(self, x):
        """Normalize ``x`` over its last axis, then apply optional scale/bias."""
        num_features = x.shape[-1]
        x32 = jnp.asarray(x, jnp.float32)

        # Variance as E[x^2] - E[x]^2, clamped at zero against rounding error.
        mu = jnp.mean(x32, axis=-1, keepdims=True)
        mean_of_squares = jnp.mean(jax.lax.square(x32), axis=-1, keepdims=True)
        variance = jnp.maximum(0., mean_of_squares - jax.lax.square(mu))

        # Fold the learned scale into the normalizing factor before applying it.
        factor = jax.lax.rsqrt(variance + self.epsilon)
        if self.use_scale:
            gamma = layers.param_with_axes('scale', self.scale_init,
                                           (num_features,), jnp.float32, axes=('embed',))
            factor = factor * gamma
        normed = (x32 - mu) * factor

        if self.use_bias:
            beta = layers.param_with_axes('bias', self.bias_init, (num_features,),
                                          jnp.float32, axes=('embed',))
            normed = normed + beta

        return jnp.asarray(normed, self.dtype)


def gelu2(x):
    """Sigmoid approximation of GELU: ``sigmoid(1.702 * x) * x``."""
    gate = nn.sigmoid(1.702 * x)
    return gate * x
