from typing import Any, Optional, Dict, Callable
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

from . import transformer


class PerceiverAR(nn.Module):
    """Autoregressive model over the discrete codes produced by ``vq_fns['encode']``.

    A window of positions taken from the flattened code sequence acts as the
    queries of a single causally-masked cross-attention over the whole
    sequence; the result is then processed by ``transformer.Transformer``,
    which emits ``vqvae.n_codes`` outputs per position.
    """

    # Hyper-parameters; this class reads seq_len, latent_size, action_dim,
    # action_embed_dim, transformer, cross_attn, use_actions, dropout_actions.
    config: Any
    # Mapping of callables; only the 'encode' entry is used here.
    vq_fns: Dict[str, Callable]
    # Object exposing latent_shape, n_codes and embedding_dim.
    vqvae: Any
    dtype: Optional[Any] = jnp.float32

    @property
    def metrics(self):
        """Names of the entries found in the dict returned by ``__call__``."""
        return ['loss']

    def setup(self):
        """Create the embeddings, the cross-attention block and the latent transformer."""
        cfg = self.config
        frame_shape = self.vqvae.latent_shape

        # Full context grid and the (shorter) latent grid, both (time, *frame).
        self.shape = (cfg.seq_len, *frame_shape)
        self.latent_shape = (cfg.latent_size, *frame_shape)

        self.action_embeds = nn.Embed(cfg.action_dim, cfg.action_embed_dim,
                                      dtype=self.dtype)
        self.pos_embeds = transformer.SinusoidalPositionBiases(self.shape, self.dtype)

        # One extra row on top of the n_codes codebook entries: the last
        # embedding is the start-of-sequence token.
        self.token_embeds = nn.Embed(self.vqvae.n_codes + 1, self.vqvae.embedding_dim)

        per_head_dim = cfg.transformer['embed_dim'] // cfg.transformer['num_heads']
        self.rotary_embeds = transformer.RotaryPositionBiases(
            seq_len=np.prod(self.shape),
            dim=max(32, per_head_dim // 2),
        )
        self.fc_in = nn.Dense(cfg.transformer['embed_dim'], dtype=self.dtype)

        # --- Cross attention (window queries -> full sequence) ---
        self.ca_z_norm = nn.LayerNorm(dtype=self.dtype)
        self.ca_x_norm = nn.LayerNorm(dtype=self.dtype)
        self.cross_attn = transformer.MultiHeadAttention(
            **cfg.cross_attn,
            max_heads_processed=1,
            dtype=self.dtype,
        )

        self.ca_ffw_norm = nn.LayerNorm(dtype=self.dtype)
        self.ca_ffw = transformer.MlpBlock(
            intermediate_dim=cfg.transformer['mlp_dim'],
            intermediate_dropout_rate=cfg.transformer['dropout'],
            dtype=self.dtype,
        )

        # --- Self attention over the latent window ---
        self.net = transformer.Transformer(
            **cfg.transformer,
            shape=self.latent_shape,
            pos_embed_type='none',
            out_dim=self.vqvae.n_codes,
            fc_in_mode=None,
            dtype=self.dtype,
        )

    def _step(self, x, actions, latent_idx, deterministic=False, decode_mode=None):
        """Compute logits for the window of positions starting at ``latent_idx``.

        Args:
            x: integer code indices, padded along axis 1 as a 2-D array
                (assumed (batch, prod(self.shape)) -- TODO confirm with callers).
            actions: integer action ids fed to the action embedding; assumed
                (batch, seq_len) -- TODO confirm.
            latent_idx: start offset of the query window along the flattened
                sequence axis.
            deterministic: forwarded to the attention / MLP / transformer
                sub-modules.
            decode_mode: ``'slice'`` uses a single query position and passes a
                computed ``decode_step`` to the transformer; ``'full'`` uses
                the whole window and passes ``decode_step=0``; anything else
                uses the whole window with ``decode_step=None``.

        Returns:
            Whatever ``self.net`` returns for the attended window (logits over
            ``vqvae.n_codes``, per its ``out_dim``).
        """
        x_len = np.prod(self.shape)
        z_len = np.prod(self.latent_shape)
        single_query = decode_mode == 'slice'
        # Number of query positions taken from the sequence.
        n_queries = 1 if single_query else z_len

        # Embed the actions and repeat each one over its frame's latent grid,
        # then flatten (time, *frame) into one sequence axis.
        action_feats = self.action_embeds(actions)[:, :, None, None, :]
        action_feats = jnp.tile(action_feats, (1, 1, *self.vqvae.latent_shape, 1))
        action_feats = action_feats.reshape(
            action_feats.shape[0], -1, action_feats.shape[-1])

        rotary = self.rotary_embeds()

        # Right shift: prepend -1 and drop the final token.
        # NOTE(review): relies on the embedding lookup mapping index -1 to the
        # last (start-of-sequence) row -- confirm for the flax version in use.
        shifted = jnp.pad(x, ((0, 0), (1, 0)), constant_values=-1)[:, :-1]
        tokens = self.token_embeds(shifted)
        tokens = tokens + self.pos_embeds(tokens)

        # Fuse token and action features along the channel axis.
        context = self.fc_in(jnp.concatenate([tokens, action_feats], axis=-1))

        # Queries are a contiguous slice of the projected sequence.
        queries = jax.lax.dynamic_slice_in_dim(context, latent_idx, n_queries, axis=1)

        # Causal mask over the whole sequence, restricted to the query rows.
        causal = jnp.tril(jnp.ones((x_len, x_len), dtype=bool))
        cross_mask = jax.lax.dynamic_slice_in_dim(causal, latent_idx, n_queries, axis=0)
        q_positions = jnp.arange(n_queries, dtype=jnp.int32) + latent_idx

        attended = self.cross_attn(
            self.ca_z_norm(queries),
            self.ca_x_norm(context),
            mask=cross_mask,
            rotary_embeds=rotary,
            q_rotary_idxs=q_positions,
            kv_rotary_idxs=None,
            deterministic=deterministic,
        )
        latents = attended + queries
        latents = self.ca_ffw(self.ca_ffw_norm(latents),
                              deterministic=deterministic) + latents

        # Causal self-attention inside the latent window.
        latent_mask = jnp.tril(jnp.ones((z_len, z_len), dtype=bool))

        decode_step = None
        if single_query:
            tokens_per_frame = np.prod(self.vqvae.latent_shape)
            # Position within the frame, offset into the last latent frame.
            decode_step = (latent_idx % tokens_per_frame
                           + (self.config.latent_size - 1) * tokens_per_frame)
        elif decode_mode == 'full':
            # Any non-None value works here; only "not None" matters.
            decode_step = 0

        return self.net(latents, mask=latent_mask, deterministic=deterministic,
                        decode_step=decode_step)

    def __call__(self, video, actions, deterministic=False, dropout_actions=None):
        """Compute the training loss for a randomly placed latent window.

        Args:
            video: input passed to ``vq_fns['encode']``; its first two
                dimensions are used as (batch, time) when building default
                actions.
            actions: integer action ids, or ``None`` when
                ``config.use_actions`` is falsy (zeros are substituted).
            deterministic: forwarded to ``_step``.
            dropout_actions: optional per-example boolean mask. When ``None``
                it is sampled as Bernoulli(0.5) from the ``'sample'`` RNG
                stream. It is applied only if ``config.dropout_actions`` is
                truthy.

        Returns:
            ``dict`` with a single key ``'loss'``: the mean softmax
            cross-entropy between the window logits and the encoded targets.
        """
        cfg = self.config

        if not cfg.use_actions:
            # Actions are ignored: replace them with all-zero ids.
            actions = (jnp.zeros(video.shape[:2], dtype=jnp.int32)
                       if actions is None else jnp.zeros_like(actions))

        # The mask is drawn here regardless of cfg.dropout_actions, so the
        # 'sample' stream is consumed before the window offset is sampled.
        if dropout_actions is None:
            dropout_actions = jax.random.bernoulli(
                self.make_rng('sample'), p=0.5, shape=(video.shape[0],))  # B

        if cfg.dropout_actions:
            # Dropped examples get action id -1 at every timestep.
            actions = jnp.where(dropout_actions[:, None], -1, actions)

        _, codes = self.vq_fns['encode'](video)
        codes = codes.reshape(codes.shape[0], -1)

        x_len = np.prod(self.shape)
        z_len = np.prod(self.latent_shape)
        # Uniform window start in [0, x_len - z_len].
        latent_idx = jax.random.randint(
            self.make_rng('sample'), shape=(), minval=0, maxval=x_len - z_len + 1)

        logits = self._step(codes, actions, latent_idx, deterministic=deterministic)

        # Targets are the (unshifted) codes covered by the window.
        targets = jax.lax.dynamic_slice_in_dim(codes, latent_idx, z_len, axis=1)
        one_hot_targets = jax.nn.one_hot(targets, num_classes=self.vqvae.n_codes)
        loss = optax.softmax_cross_entropy(logits, one_hot_targets).mean()
        return {'loss': loss}
