from typing import Optional, Any, Tuple
from functools import partial
import numpy as np
import flax.linen as nn
import flax.linen.initializers as nn_init
import jax
import jax.numpy as jnp

from .utils import f_psum, g_psum, create_g_all_gather
from . import sharding


class TransformerShard(nn.Module):
    """Transformer trunk whose dense layers are split over ``num_shards`` shards.

    The forward pass optionally embeds token ids, optionally concatenates a
    conditioning tensor on the feature axis, flattens every axis between batch
    and features into one sequence axis, projects to ``embed_dim``, adds
    position biases, runs ``num_layers`` ``TransformerLayerShard`` blocks and
    finishes with Dense -> GELU -> ``LayerNormShard`` -> Dropout plus an
    optional output projection.

    NOTE(review): ``f_psum``, ``g_psum`` and ``create_g_all_gather`` are
    imported from ``.utils`` and are not visible here; they are presumably the
    collectives used when entering/leaving sharded regions -- confirm there.
    """
    # Number of model shards: divides embed_dim for the sharded Dense widths
    # and scales the init variance of the output projection.
    num_shards: int
    # Model width produced by the input projection.
    embed_dim: int
    # Forwarded unchanged to every TransformerLayerShard.
    num_heads: int
    # How many TransformerLayerShard blocks are stacked.
    num_layers: int
    # Forwarded unchanged to every TransformerLayerShard.
    mlp_dim: int
    # Dropout rate used here and forwarded to the layers.
    dropout: float
    # Forwarded unchanged to every TransformerLayerShard.
    attention_dropout: float
    # Inputs are passed through nn.Embed only when BOTH of these are set.
    vocab_size: Optional[int] = None
    vocab_dim: Optional[int] = None
    # Forwarded to BroadcastPositionBiases and to the layers.
    shape: Optional[Tuple[int]] = None
    attention_type: str = 'full'
    # One of 'absolute', 'broadcast', 'sinusoidal', 'none'.
    pos_embed_type: str = 'absolute'
    # When set, a float32 projection to this width is applied at the end.
    out_dim: Optional[int] = None
    # 'out' or None are usable; 'in' raises NotImplementedError.
    fc_in_mode: str = 'out'
    # When True the sequence is passed through RightShift.
    right_shift: bool = False
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs, mask=None, deterministic=False, cond=None, decode_step=None):
        """Run the sharded transformer on ``inputs``.

        Args:
            inputs: token ids when ``vocab_size`` and ``vocab_dim`` are set,
                otherwise features.  After the optional embedding the array is
                treated as (batch, *positions, features): ``shape[1:-1]`` is
                flattened on entry and restored on exit.
            mask: optional attention mask, forwarded to every layer.
            deterministic: forwarded to every Dropout and layer.
            cond: optional dict of conditioning inputs; ``None`` means none.
                Only the ``'cat'`` entry is consumed in this method (it is
                concatenated to the input on the last axis); the dict is also
                forwarded unchanged to every layer.
            decode_step: ``None`` for a full pass, otherwise the index of the
                position being decoded.

        Returns:
            Array shaped (batch, *positions, out_features).
        """
        # Supported cond modes:
        # - "cat": should be same length as input, concats before attention
        # - "cat_attn": concatenate to KV self-attention with learned LN and projection
        # - "cross_attn": adds new cross attention block with learned parameters
        # NOTE(review): only "cat" is handled in this method.
        g_all_gather = create_g_all_gather(axis=2)

        # ``None`` replaces the former shared mutable ``dict()`` default.
        if cond is None:
            cond = {}

        x = inputs

        if self.vocab_size is not None and self.vocab_dim is not None:
            x = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.vocab_dim,
                dtype=self.dtype,
                embedding_init=nn.initializers.normal(stddev=1.0)
            )(x)

        if 'cat' in cond:
            assert not isinstance(x, tuple), 'This mode does not support multi-modal input'
            x = jnp.concatenate((x, cond['cat']), axis=-1)

        # Remember the positional axes so the output can be un-flattened.
        old_shape = x.shape[1:-1]
        x = x.reshape(x.shape[0], -1, x.shape[-1])

        if self.fc_in_mode == 'in':
            # TODO this option probably not correct, so it stays disabled.
            # The abandoned draft was: bias-free Dense(embed_dim) with
            # variance_scaling(1.0 / num_shards, 'fan_in', 'normal') init,
            # then g_psum, then AddBias -- the layout model_spec still
            # describes for this mode.
            raise NotImplementedError
        elif self.fc_in_mode == 'out':
            x = f_psum(x)
            x = nn.Dense(self.embed_dim // self.num_shards, dtype=self.dtype)(x)
            x = g_all_gather(x)
            # Merge the last two axes of the gathered result into one.
            x = x.reshape(*x.shape[:-2], -1)
        else:
            assert self.fc_in_mode is None

        if self.right_shift:
            if decode_step is None:
                x = RightShift(self.dtype)(x)
            else:
                # The shifted value is only selected for step 0; later steps
                # keep x unchanged.
                x_shift = RightShift(self.dtype)(x)
                x = jax.lax.cond(decode_step > 0, lambda: x, lambda: x_shift)

        if self.pos_embed_type == 'absolute':
            position_bias = AbsolutePositionBiases(dtype=self.dtype)(x)
        elif self.pos_embed_type == 'broadcast':
            position_bias = BroadcastPositionBiases(shape=self.shape,
                                                    dtype=self.dtype)(x)
        elif self.pos_embed_type == 'sinusoidal':
            position_bias = SinusoidalPositionBiases(dtype=self.dtype)(x)
        elif self.pos_embed_type == 'none':
            position_bias = None
        else:
            raise Exception(f'Invalid pos_embed_type: {self.pos_embed_type}')

        if position_bias is not None:
            # Single-position decoding: pick the bias row of the current step.
            if decode_step is not None and x.shape[1] == 1:
                position_bias = position_bias[decode_step]
            x += position_bias

        x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)
        x = x.astype(self.dtype)

        for i in range(self.num_layers):
            x = TransformerLayerShard(
                num_shards=self.num_shards,
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                mlp_dim=self.mlp_dim,
                dropout=self.dropout,
                attention_dropout=self.attention_dropout,
                attention_type=self.attention_type,
                shape=self.shape,
                dtype=self.dtype,
                name=f'TransformerLayer_{i}'
            )(x, mask=mask, cond=cond, deterministic=deterministic, decode_step=decode_step)

        x = f_psum(x)
        x = nn.Dense(
            self.embed_dim // self.num_shards, dtype=self.dtype
        )(x)
        x = nn.gelu(x)
        x = LayerNormShard(dtype=self.dtype, name='LayerNorm_0')(x)
        x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)

        if self.out_dim is not None:
            # Output projection is computed in float32 regardless of self.dtype.
            x = nn.Dense(self.out_dim, dtype=jnp.float32, use_bias=False,
                         kernel_init=nn.initializers.variance_scaling(
                             1.0 / self.num_shards, 'fan_in', 'normal'
                         ))(x)
            x = g_psum(x)
            x = AddBias(dtype=jnp.float32)(x)
        else:
            x = g_all_gather(x)
            x = x.reshape(*x.shape[:-2], -1)

        x = x.reshape(x.shape[0], *old_shape, x.shape[-1])
        return x

    @staticmethod
    def model_spec(vocab_size=None, vocab_dim=None, fc_in_mode='out', right_shift=False,
                   pos_embed_type='absolute', num_layers=1, out_dim=None, **kwargs):
        """Build the sharding spec for the submodules ``__call__`` creates.

        Keys are the Flax submodule names (``Embed_0``, ``Dense_<i>``,
        ``AddBias_<i>``, ``RightShift_0``, ``TransformerLayer_<i>``, ...), with
        the ``Dense``/``AddBias`` counters advanced in creation order.  Extra
        keyword arguments are accepted and ignored.

        Returns:
            A ``sharding.GenericDict`` wrapping the per-submodule specs.
        """
        spec = dict()
        if vocab_size is not None and vocab_dim is not None:
            assert fc_in_mode is None or fc_in_mode == 'out'
            spec['Embed_0'] = sharding.GenericReplicated(reduce_mode='identity')

        dense_idx, bias_idx = 0, 0
        if fc_in_mode == 'in':
            spec[f'Dense_{dense_idx}'] = sharding.Dense(use_bias=False, axis=0)
            spec[f'AddBias_{bias_idx}'] = sharding.GenericReplicated(reduce_mode='identity')
            dense_idx += 1
            bias_idx += 1
        elif fc_in_mode == 'out':
            spec[f'Dense_{dense_idx}'] = sharding.Dense(use_bias=True, axis=1)
            dense_idx += 1

        if right_shift:
            spec['RightShift_0'] = sharding.GenericReplicated(reduce_mode='identity')

        # 'sinusoidal' and 'none' add no entry here.
        if pos_embed_type == 'absolute':
            spec['AbsolutePositionBiases_0'] = sharding.GenericReplicated(reduce_mode='identity')
        elif pos_embed_type == 'broadcast':
            spec['BroadcastPositionBiases_0'] = sharding.GenericReplicated(reduce_mode='identity')

        for i in range(num_layers):
            spec[f'TransformerLayer_{i}'] = TransformerLayerShard.model_spec()

        spec[f'Dense_{dense_idx}'] = sharding.Dense(use_bias=True, axis=1)
        dense_idx += 1
        spec['LayerNorm_0'] = LayerNormShard.model_spec(use_bias=True, use_scale=True)

        if out_dim is not None:
            spec[f'Dense_{dense_idx}'] = sharding.Dense(use_bias=False, axis=0)
            spec[f'AddBias_{bias_idx}'] = sharding.GenericReplicated(reduce_mode='identity')

        return sharding.GenericDict(spec)
        

class TransformerLayerShard(nn.Module):
    """One transformer block: LayerNorm -> self-attention and LayerNorm -> MLP.

    Each half is followed by dropout and a residual connection to its own
    input, and each half's input is passed through ``f_psum`` before the
    LayerNorm.

    NOTE(review): ``f_psum`` is imported from ``.utils`` and not visible
    here -- confirm its forward/backward behaviour there.
    """
    num_shards: int
    embed_dim: int
    num_heads: int
    mlp_dim: int
    # Rate of the two residual dropouts and of the MLP's intermediate dropout.
    dropout: float
    # Dropout rate handed to the attention module.
    attention_dropout: float
    # Only 'full' is implemented; anything else raises.
    attention_type: str = 'full'
    # Not read by this block.
    shape: Optional[Tuple] = None
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs, mask=None, cond=None, deterministic=False, decode_step=None):
        """Apply the block to ``inputs``.

        Args:
            inputs: activations; the output has the same shape because both
                halves are added back onto their inputs.
            mask: optional attention mask passed to the attention module.
            cond: accepted for interface compatibility; not read by this
                block.  Defaults to ``None`` rather than a shared ``dict()``.
            deterministic: disables dropout when True.
            decode_step: forwarded to the attention module.

        Returns:
            The block output.

        Raises:
            Exception: if ``attention_type`` is not ``'full'``.
        """
        x = f_psum(inputs)
        x = LayerNorm(dtype=self.dtype)(x)

        if self.attention_type == 'full':
            # Self-attention: the normalised activations serve as both the
            # query and the key/value input.
            x = MultiHeadAttentionShard(
                num_heads=self.num_heads,
                head_dim=self.embed_dim // self.num_heads,
                num_shards=self.num_shards,
                dropout_rate=self.attention_dropout,
                dtype=self.dtype,
                name='MultiHeadAttention_0'
            )(x, x, mask=mask,
              deterministic=deterministic, decode_step=decode_step)
        else:
            raise Exception(f'Invalid attention_type: {self.attention_type}')
        x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)
        x = x + inputs

        y = f_psum(x)
        y = LayerNorm(dtype=self.dtype)(y)
        y = MlpBlockShard(
            num_shards=self.num_shards,
            intermediate_dim=self.mlp_dim,
            intermediate_dropout_rate=self.dropout,
            dtype=self.dtype,
            name='MlpBlock_0'
        )(y, deterministic=deterministic)
        y = nn.Dropout(rate=self.dropout)(y, deterministic=deterministic)
        y = y + x

        return y

    @staticmethod
    def model_spec():
        """Return the sharding spec keyed by this block's submodule names."""
        return sharding.GenericDict({
            'LayerNorm_0': sharding.GenericReplicated(reduce_mode='sum'),
            'MultiHeadAttention_0': MultiHeadAttentionShard.model_spec(),
            'LayerNorm_1': sharding.GenericReplicated(reduce_mode='sum'),
            'MlpBlock_0': MlpBlockShard.model_spec()
        })


class MlpBlockShard(nn.Module):
    """Two-layer MLP whose hidden width is divided across ``num_shards``.

    ``wi`` expands to ``intermediate_dim // num_shards`` features, ``gelu2``
    and dropout follow, and the bias-free ``wo`` maps back to the input width.
    The result is passed through ``g_psum`` before the ``wo_bias`` is added.

    NOTE(review): ``g_psum`` comes from ``.utils`` and is not visible here.
    """
    num_shards: int
    # Total hidden width; must be divisible by num_shards.
    intermediate_dim: int
    # Initialiser of the first ('wi') projection only.
    kernel_init: Any = nn.initializers.variance_scaling(
        1.0, 'fan_in', 'truncated_normal')
    intermediate_dropout_rate: float = 0.1
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs, deterministic=False):
        """Apply the MLP to ``inputs``; the last axis keeps its size."""
        assert self.intermediate_dim % self.num_shards == 0
        local_width = self.intermediate_dim // self.num_shards
        model_width = inputs.shape[-1]

        expand = nn.Dense(
            local_width,
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            name='wi',
        )
        hidden = gelu2(expand(inputs))

        drop = nn.Dropout(rate=self.intermediate_dropout_rate,
                          broadcast_dims=(-2,))
        hidden = drop(hidden, deterministic=deterministic)

        # The fan-in seen by this layer is only D / num_shards, hence the
        # rescaled variance.
        wo_init = nn.initializers.variance_scaling(
            1.0 / self.num_shards, 'fan_in', 'truncated_normal'
        )
        contract = nn.Dense(
            model_width,
            dtype=self.dtype,
            kernel_init=wo_init,
            use_bias=False,
            name='wo',
        )
        local_out = contract(hidden)

        reduced = g_psum(local_out)
        # Bias is added once, after g_psum.
        return AddBias(name='wo_bias')(reduced)

    @staticmethod
    def model_spec():
        """Return the sharding spec for ``wi``, ``wo`` and ``wo_bias``."""
        layout = {
            'wi': sharding.Dense(use_bias=True, axis=1),
            'wo': sharding.Dense(use_bias=False, axis=0),
            'wo_bias': sharding.GenericReplicated(reduce_mode='identity'),
        }
        return sharding.GenericDict(layout)

 
class MultiHeadAttentionShard(nn.Module):
    """Multi-head attention computing ``num_heads // num_shards`` heads locally.

    Query/key/value are projected to the local heads, attention is evaluated
    in chunks of heads, and the bias-free ``out`` projection is passed through
    ``g_psum`` before ``out_bias`` is added.

    NOTE(review): ``g_psum`` comes from ``.utils`` and is not visible here.
    """
    # Total number of heads; must be divisible by num_shards.
    num_heads: int
    head_dim: int
    num_shards: int
    dtype: Any = jnp.float32
    dropout_rate: float = 0.
    # Heads are split into ``heads_per_shard // max_heads_processed`` equal
    # chunks (at least one).  None processes all local heads in one call.
    max_heads_processed: Optional[int] = None
    # Initialiser of the query/key/value projections.
    kernel_init: Any = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')

    def _rotate_half(self, x):
        """Split the last axis into halves ``(a, b)`` and return ``(-b, a)``."""
        x1, x2 = jnp.split(x, 2, axis=-1)
        return jnp.concatenate([-x2, x1], axis=-1)

    def _apply_rotary_pos_emb(self, t, rotary_embeds, rotary_idxs):
        """Rotate the leading ``rotary_embeds.shape[-1]`` features of ``t``.

        Features beyond that width are passed through untouched.  When
        ``rotary_idxs`` is given it first selects rows of ``rotary_embeds``.
        """
        if rotary_idxs is not None:
            rotary_embeds = rotary_embeds[rotary_idxs]
        rotate_dim = rotary_embeds.shape[-1]
        t, t_pass = t[..., :rotate_dim], t[..., rotate_dim:]
        # [:, None] adds an axis after the first one so the angle table
        # broadcasts against t -- assumes t is (batch, position, heads,
        # features); TODO confirm.
        t = (t * jnp.cos(rotary_embeds)[:, None]) + (self._rotate_half(t) * jnp.sin(rotary_embeds)[:, None])
        return jnp.concatenate([t, t_pass], axis=-1)

    @nn.compact
    def __call__(self, inputs_q, inputs_kv, mask=None, deterministic=False,
                 decode_step=None, rotary_embeds=None, q_rotary_idxs=None, kv_rotary_idxs=None):
        """Attend from ``inputs_q`` to ``inputs_kv``.

        Args:
            inputs_q: query-side activations; the output keeps its last-axis
                size.
            inputs_kv: key/value-side activations.
            mask: optional mask; entries ``> 0`` are attended to, all others
                receive a ``-1e10`` additive bias.
            deterministic: disables attention dropout when True.
            decode_step: ``None`` for a plain pass.  Otherwise keys/values are
                kept in the ``'cache'`` collection: a length-1 query writes
                its key/value at ``decode_step``, a longer query overwrites
                the whole cache.
            rotary_embeds: optional angle table applied to query and key.
            q_rotary_idxs: optional row selection into ``rotary_embeds`` for
                the query.
            kv_rotary_idxs: optional row selection into ``rotary_embeds`` for
                the key.

        Returns:
            The attention output after ``g_psum`` and the output bias.
        """
        assert self.num_heads % self.num_shards == 0
        num_heads_per_shard = self.num_heads // self.num_shards
        max_heads_processed = self.max_heads_processed or num_heads_per_shard

        projection = partial(
            nn.DenseGeneral,
            axis=-1,
            features=(num_heads_per_shard, self.head_dim),
            dtype=self.dtype
        )

        # No manual 1/sqrt(head_dim) scaling of the query:
        # nn.attention.dot_product_attention already scales it, and doing it
        # here as well would scale the logits twice.
        query = projection(kernel_init=self.kernel_init, name='query')(inputs_q)
        key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
        value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv)

        if rotary_embeds is not None:
            query = self._apply_rotary_pos_emb(query, rotary_embeds, q_rotary_idxs)
            key = self._apply_rotary_pos_emb(key, rotary_embeds, kv_rotary_idxs)

        if decode_step is not None:
            # On first use the cache variables are initialised from the
            # freshly projected key/value.
            cached_key = self.variable('cache', 'cached_key', lambda: key)
            cached_value = self.variable('cache', 'cached_value', lambda: value)

            is_slice = inputs_q.shape[1] == 1
            if is_slice:
                key = cached_key.value.at[:, decode_step].set(key[:, 0])
                value = cached_value.value.at[:, decode_step].set(value[:, 0])
                if mask is not None:
                    # Keep only the mask row of the position being decoded.
                    mask = mask[decode_step, None]
            else:
                key = cached_key.value.at[:].set(key)
                value = cached_value.value.at[:].set(value)

            cached_key.value = key
            cached_value.value = value

        if mask is not None:
            attention_bias = jax.lax.select(
                mask > 0,
                jnp.full(mask.shape, 0.).astype(self.dtype),
                jnp.full(mask.shape, -1e10).astype(self.dtype)
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        # At least one chunk: a max_heads_processed larger than the local
        # head count would otherwise ask jnp.split for zero sections.
        n_chunks = max(1, num_heads_per_shard // max_heads_processed)
        x = []
        for q, k, v in zip(jnp.split(query, n_chunks, axis=-2),
                           jnp.split(key, n_chunks, axis=-2),
                           jnp.split(value, n_chunks, axis=-2)):
            x_i = nn.attention.dot_product_attention(
                q, k, v, bias=attention_bias,
                dropout_rng=dropout_rng, dropout_rate=self.dropout_rate,
                deterministic=deterministic, dtype=self.dtype
            )
            x.append(x_i)
        x = jnp.concatenate(x, axis=-2)

        out = nn.DenseGeneral(
            features=inputs_q.shape[-1],
            axis=(-2, -1),
            kernel_init=nn.initializers.variance_scaling(
                1.0 / self.num_shards, 'fan_in', 'normal'
            ),  # to account for only partial activations in fan_in init
            dtype=self.dtype,
            use_bias=False,
            name='out'
        )(x)
        out = g_psum(out)
        # Bias is added once, after g_psum.
        out = AddBias(name='out_bias')(out)

        return out

    @staticmethod
    def model_spec():
        """Return the sharding spec for the projections and the output bias."""
        return sharding.GenericDict({
            'query': sharding.DenseGeneral(use_bias=True, shard_mode='out'),
            'key': sharding.DenseGeneral(use_bias=True, shard_mode='out'),
            'value': sharding.DenseGeneral(use_bias=True, shard_mode='out'),
            'out': sharding.DenseGeneral(use_bias=False, shard_mode='in'),
            'out_bias': sharding.GenericReplicated(reduce_mode='identity')
        })

        
class LayerNormShard(nn.Module):
    """Layer norm whose statistics are averaged over the ``'model'`` axis.

    The first and second moments are taken over the local last axis and then
    combined with ``jax.lax.pmean(..., 'model')``.  All arithmetic is done in
    float32; the result is converted to ``dtype`` at the end.
    """
    epsilon: float = 1e-6
    dtype: Optional[Any] = None
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Any = nn_init.zeros
    scale_init: Any = nn_init.ones
    # The two fields below are not read by __call__.
    reduction_axes: Any = -1
    feature_axes: Any = -1

    @nn.compact
    def __call__(self, x):
        """Normalise ``x`` over its last axis and apply scale and bias."""
        n_features = x.shape[-1]
        x32 = jnp.asarray(x, jnp.float32)

        first_moment = jax.lax.pmean(
            jnp.mean(x32, axis=-1, keepdims=True), 'model')
        second_moment = jax.lax.pmean(
            jnp.mean(jax.lax.square(x32), axis=-1, keepdims=True), 'model')
        # Clamp at zero: E[x^2] - E[x]^2 can dip below it numerically.
        variance = jnp.maximum(0., second_moment - jax.lax.square(first_moment))

        centred = x32 - first_moment
        gain = jax.lax.rsqrt(variance + self.epsilon)
        if self.use_scale:
            gain = gain * self.param('scale', self.scale_init,
                                     (n_features,), jnp.float32)
        result = centred * gain

        if self.use_bias:
            result = result + self.param('bias', self.bias_init,
                                         (n_features,), jnp.float32)

        return jnp.asarray(result, self.dtype)

    @staticmethod
    def model_spec(use_bias, use_scale):
        """Return the ``sharding.LayerNormShard`` spec for these flags."""
        return sharding.LayerNormShard(use_bias=use_bias, use_scale=use_scale)

         
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis, computed in float32.

    The variance is obtained as ``E[x^2] - E[x]^2``; the result is converted
    to ``dtype`` at the end.
    """
    epsilon: float = 1e-6
    dtype: Optional[Any] = None
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Any = nn_init.zeros
    scale_init: Any = nn_init.ones
    # The two fields below are not read by __call__.
    reduction_axes: Any = -1
    feature_axes: Any = -1

    @nn.compact
    def __call__(self, x):
        """Normalise ``x`` over its last axis and apply scale and bias."""
        n_features = x.shape[-1]
        x32 = jnp.asarray(x, jnp.float32)

        first_moment = jnp.mean(x32, axis=-1, keepdims=True)
        second_moment = jnp.mean(jax.lax.square(x32), axis=-1, keepdims=True)
        # Clamp at zero: the difference can dip below it numerically.
        variance = jnp.maximum(0., second_moment - jax.lax.square(first_moment))

        centred = x32 - first_moment
        gain = jax.lax.rsqrt(variance + self.epsilon)
        if self.use_scale:
            gain = gain * self.param('scale', self.scale_init,
                                     (n_features,), jnp.float32)
        result = centred * gain

        if self.use_bias:
            result = result + self.param('bias', self.bias_init,
                                         (n_features,), jnp.float32)

        return jnp.asarray(result, self.dtype)


class AddBias(nn.Module):
    """Add a learned, zero-initialised bias vector along the last axis."""
    # NOTE(review): ``dtype`` is declared but never read in ``__call__``;
    # only ``param_dtype`` has an effect.
    dtype: Any = jnp.float32
    param_dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Return ``x`` plus the ``'bias'`` parameter of size ``x.shape[-1]``."""
        n_features = x.shape[-1]
        offset = self.param(
            'bias', nn.initializers.zeros, (n_features,), self.param_dtype
        )
        x += offset
        return x


class RightShift(nn.Module):
    """Shift a sequence one step along axis 1, prepending a learned vector."""
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Drop the last step of ``x`` and put the ``'sos'`` vector in front."""
        batch_size, n_features = x.shape[0], x.shape[-1]
        start = self.param(
            'sos', nn.initializers.normal(stddev=0.02), (n_features,), self.dtype
        )
        # One copy of the start vector per batch element, with a length-1
        # sequence axis.
        start = jnp.tile(start[None, None], (batch_size, 1, 1))
        start = jnp.asarray(start, self.dtype)
        without_last = x[:, :-1]
        return jnp.concatenate([start, without_last], axis=1)


class RotaryPositionBiases(nn.Module):
    """Angle table of shape ``(seq_len, 2 * ceil(dim / 2))`` for rotary use.

    Entry ``[p, j]`` is ``p / 10000 ** (2j / dim)`` for the first half of the
    columns; the second half repeats the first.
    """
    seq_len: int
    dim: int
    # Not read by __call__; the table is always float32.
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self):
        """Build and return the angle table; there are no parameters."""
        n_positions = self.seq_len
        width = self.dim

        exponents = jnp.arange(0, width, 2).astype(jnp.float32) / width
        inv_freq = 1.0 / (10000 ** exponents)
        positions = jnp.arange(n_positions, dtype=jnp.float32)
        angles = jnp.einsum('i,j->ij', positions, inv_freq)
        return jnp.concatenate([angles, angles], axis=-1)

        
class SinusoidalPositionBiases(nn.Module):
    """Fixed sinusoidal position table: sines in the first half, cosines after."""
    # Positional grid; when unset, ``x.shape[1:-1]`` is used instead.
    shape: Optional[Tuple[int]] = None
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Return the table with one row per position of the grid."""
        n_features = x.shape[-1]
        grid = self.shape or x.shape[1:-1]
        n_positions = np.prod(grid)
        positions = jnp.arange(n_positions, dtype=self.dtype)

        exponents = jnp.arange(0.0, n_features, 2.0) / n_features
        inv_freq = jnp.asarray(1. / (10000 ** exponents), self.dtype)

        angles = jnp.outer(positions, inv_freq)
        return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)

        
class AbsolutePositionBiases(nn.Module):
    """Learned position table shaped like one example of the input."""
    # Not read by __call__; the parameter is always created as float32.
    dtype: Any = jnp.float32
    embedding_init: Any = nn.linear.default_embed_init

    @nn.compact
    def __call__(self, x):
        """Return the ``'abs_embedding'`` parameter of shape ``x.shape[1:]``."""
        per_example_shape = x.shape[1:]
        table = self.param(
            'abs_embedding', self.embedding_init, per_example_shape, jnp.float32
        )
        return table
    

class BroadcastPositionBiases(nn.Module):
    """Learned position biases factorised over the axes of a positional grid.

    Each grid axis ``i`` owns a table ``d_i`` of shape
    ``(shape[i], chunk_i)``; the chunks split the feature width as evenly as
    possible.  Every table is broadcast over the other axes and the results
    are concatenated on the feature axis.
    """
    # Positional grid; when unset, ``x.shape[1:-1]`` is used instead.
    shape: Optional[Tuple[int]] = None
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Return the biases as a ``(prod(shape), x.shape[-1])`` array."""
        shape = self.shape or x.shape[1:-1]
        # Use the resolved grid: ``len(self.shape)`` failed with a TypeError
        # whenever ``shape`` was left at None and the fallback applied.
        n_dim = len(shape)
        embed_dim = x.shape[-1]

        # The first ``embed_dim % n_dim`` axes get one extra feature so the
        # chunks add up to embed_dim exactly.
        chunk_sizes = [embed_dim // n_dim + (i < (embed_dim % n_dim))
                       for i in range(n_dim)]
        assert sum(chunk_sizes) == embed_dim, f'sum({chunk_sizes}) = {sum(chunk_sizes)} != {embed_dim}'

        embs = [
            self.param(f'd_{i}', nn.initializers.normal(stddev=0.02),
                       (shape[i], chunk_sizes[i]), jnp.float32)
            for i in range(n_dim)
        ]

        out = []
        for i in range(n_dim):
            e = embs[i]
            # Size-1 everywhere except on grid axis i and the feature axis,
            # then broadcast to the full grid.
            e = jnp.reshape(e, (1,) + (1,) * i + (shape[i],) + (1,) * (n_dim - i - 1) + (-1,))
            e = jnp.broadcast_to(e, (1, *shape, e.shape[-1]))
            out.append(e)
        out = jnp.concatenate(out, axis=-1)
        out = jnp.asarray(out, self.dtype)

        # Flatten the grid into a single position axis.
        out = jnp.reshape(out, (np.prod(shape), embed_dim))

        return out


def gelu2(x):
    """Sigmoid-gated activation: ``sigmoid(1.702 * x) * x``."""
    gate = nn.sigmoid(1.702 * x)
    return gate * x
