import os
from functools import partial
from typing import Any, Optional, Dict, Callable
from tqdm import tqdm
import numpy as np
from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

from . import transformer
from . import layers


def _model_step(state, cache, encodings, actions, latent_idx, decode_mode):
    """Apply ``model._step`` once in decoding mode and return the new cache.

    Relies on the module-level ``model`` (bound by ``sample``).

    Args:
        state: object whose ``params`` attribute holds the model parameters.
        cache: mapping of extra variable collections merged next to
            ``'params'`` (``sample`` passes ``{'cache': ...}``).
        encodings: token codes forwarded to ``model._step``.
        actions: actions forwarded to ``model._step``.
        latent_idx: start index forwarded to ``model._step``.
        decode_mode: decode mode forwarded to ``model._step``
            (``'full'`` or ``'slice'`` in this file).

    Returns:
        ``(logits, cache)`` where ``cache`` is the mutated ``'cache'``
        collection returned by ``model.apply``.
    """
    # 'params' goes in first; any collection present in ``cache`` is layered
    # on top of it.
    variables = {'params': state.params}
    variables.update(cache)

    # Dropout is disabled (deterministic=True) and only the 'cache'
    # collection may be written during the call.
    step_logits, updated_cache = model.apply(
        variables,
        encodings,
        actions,
        latent_idx,
        deterministic=True,
        decode_mode=decode_mode,
        method=model._step,
        mutable=['cache'],
    )
    return step_logits, updated_cache

# Variants of _model_step with decode_mode pre-bound, so the remaining
# positional arguments can be mapped directly by jax.pmap in sample():
# 'full' is called at the start of each frame, 'slice' for every sampled token.
_model_step_full = partial(_model_step, decode_mode='full')
_model_step_slice = partial(_model_step, decode_mode='slice')

    
def _sample_step(logits, rng):
    """Draw categorical samples from ``logits`` and advance the PRNG key.

    Args:
        logits: unnormalized log-probabilities; the last axis indexes classes.
        rng: JAX PRNG key.

    Returns:
        ``(samples, new_rng)``: sampled class indices (``logits`` without its
        last axis) and a fresh key to use for the next call.
    """
    # First half of the split is handed back to the caller, second half is
    # consumed by this draw.
    carry_key, draw_key = jax.random.split(rng)
    drawn = jax.random.categorical(draw_key, logits, axis=-1)
    return drawn, carry_key


def _decode(x):
    """Decode ``x`` with the module-level ``model``'s VQ decoder.

    A singleton axis is inserted at position 1 before calling
    ``model.vq_fns['decode']`` and removed again from the result
    (presumably the decoder expects a time axis there - TODO confirm).
    """
    decode_fn = model.vq_fns['decode']
    expanded = x[:, None]
    decoded = decode_fn(expanded)
    return decoded[:, 0]


def sample(sample_model, state, video, actions, seed=0, return_cond_frames=True,
           extra_frames=0, log_output=False, return_real=True):
    """Autoregressively sample video tokens on all local devices and decode them.

    The tokens of the first ``config.open_loop_ctx`` frames of ``video`` are
    kept as conditioning; every later token, up to ``config.eval_seq_len``
    frames, is zeroed and then sampled one at a time from the model.

    Side effect: rebinds the module-level ``model`` to ``sample_model``, since
    the pmapped helpers ``_model_step`` and ``_decode`` read it as a global.

    Args:
        sample_model: model exposing ``config``, ``vqvae``, ``vq_fns``,
            ``init``, ``apply`` and ``_step``.
        state: object with a ``params`` attribute. It is passed straight to
            the pmapped step, so it is presumably already replicated across
            devices - TODO confirm against the caller.
        video: array whose leading axis equals ``jax.local_device_count()``;
            the first three axes are treated as (device, batch, time).
        actions: per-frame actions indexed as (device, batch, time). May be
            ``None`` only when ``config.use_actions`` is false, in which case
            zeros are used; when provided and actions are disabled they are
            replaced by zeros as well.
        seed: integer seed for the per-device sampling PRNG keys.
        return_cond_frames: if false, the first ``open_loop_ctx`` frames are
            dropped from both returned arrays.
        extra_frames: currently unused; kept for interface compatibility.
        log_output: if true, wrap the token loop in a ``tqdm`` progress bar.
        return_real: if true, also return the (possibly decoded) input video.

    Returns:
        ``samples`` or ``(samples, video)``. ``samples`` is a NumPy array
        clipped to ``[-1, 1]`` with leading axes (device, batch, time).
    """
    global model
    model = sample_model
    config = model.config
    n_devices = jax.local_device_count()

    rngs = jax.random.PRNGKey(seed)
    rngs = jax.random.split(rngs, n_devices)

    assert video.shape[0] == n_devices, (
        'leading axis of video must equal jax.local_device_count()')
    assert config.open_loop_ctx >= config.latent_size - 1, (
        'open_loop_ctx must be at least latent_size - 1')

    if not config.use_actions:
        if actions is None:
            actions = jnp.zeros(video.shape[:3], dtype=jnp.int32)
        else:
            actions = jnp.zeros_like(actions)

    n_tokens_per_frame = np.prod(model.vqvae.latent_shape)
    _, encodings = jax.pmap(model.vq_fns['encode'], axis_name='batch')(video)
    # Flatten (time, *latent_shape) into one token axis and blank out every
    # token after the conditioning frames; those are filled in by sampling.
    encodings = encodings.reshape(*encodings.shape[:2], -1)
    encodings = encodings.at[:, :, config.open_loop_ctx * n_tokens_per_frame:].set(0)

    # Build a zero-initialised decoding cache from the abstract shapes of the
    # model's variables (no real initialisation is run).
    variable_shapes = jax.eval_shape(
        partial(model.init, method=model._step, decode_mode='full'),
        rngs={k: jax.random.PRNGKey(0)
              for k in ['params', *config.rng_keys]},
        x=encodings[0, :, :config.seq_len * n_tokens_per_frame],
        actions=actions[0, :, :config.seq_len],
        latent_idx=0
    )
    cache = {'cache': variable_shapes['cache']}
    cache = jax.tree_util.tree_map(lambda x: np.zeros(x.shape, dtype=x.dtype), cache)
    cache = jax_utils.replicate(cache)

    # The pmapped callables are loop-invariant: create them once instead of
    # re-wrapping on every token.
    step_full = jax.pmap(_model_step_full)
    step_slice = jax.pmap(_model_step_slice)
    sample_step = jax.pmap(_sample_step)
    pmapped_decode = jax.pmap(_decode)

    # Sampling: one iteration per token position after the conditioning frames.
    itr = range(config.open_loop_ctx * n_tokens_per_frame,
                config.eval_seq_len * n_tokens_per_frame)
    if log_output:
        itr = tqdm(itr)

    for i in itr:
        frame_id = i // n_tokens_per_frame
        if frame_id < config.seq_len:
            # Still inside the first window: feed the first seq_len frames.
            enc = encodings[:, :, :config.seq_len * n_tokens_per_frame]
            act = actions[:, :, :config.seq_len]
        else:
            # Sliding window of seq_len frames ending at the current frame.
            enc = encodings[:, :, (frame_id - config.seq_len + 1) * n_tokens_per_frame:(frame_id + 1) * n_tokens_per_frame]
            act = actions[:, :, frame_id - config.seq_len + 1:frame_id + 1]

        if i % n_tokens_per_frame == 0:
            # First token of a frame: run a 'full' step over the latent window.
            # Its logits are discarded; only the updated cache is kept.
            if frame_id < config.seq_len:
                latent_idx = (frame_id - (config.latent_size - 1)) * n_tokens_per_frame
            else:
                latent_idx = (config.seq_len - config.latent_size) * n_tokens_per_frame

            _, cache = step_full(
                state, cache, enc, act,
                np.full((n_devices,), latent_idx, dtype=np.int32)
            )

        # Position of token i inside the window passed to the model.
        if frame_id < config.seq_len:
            latent_idx = i
        else:
            latent_idx = (config.seq_len - 1) * n_tokens_per_frame + i % n_tokens_per_frame
        latent_idx = np.full((n_devices,), latent_idx, dtype=np.int32)
        logits, cache = step_slice(
            state, cache, enc, act, latent_idx
        )
        s, rngs = sample_step(logits, rngs)

        encodings = encodings.at[:, :, i, None].set(s)
    encodings = encodings.reshape(*encodings.shape[:2], config.eval_seq_len, *model.vqvae.latent_shape)

    def decode(samples):
        """Decode codes laid out as (device, batch, time, *latent) in chunks."""
        N, B, T = samples.shape[:3]
        samples = jax.device_get(samples)
        samples = np.reshape(samples, (-1, *samples.shape[3:]))

        # Each chunk is reshaped to (N, -1, ...), so its length must be a
        # multiple of N. This stays at 64 whenever N divides 64 and rounds to
        # a multiple of N otherwise (a fixed 64 could not be reshaped there).
        chunk = max(N, (64 // N) * N)

        recons = []
        for start in range(0, N * B * T, chunk):
            inp = samples[start:start + chunk]
            inp = np.reshape(inp, (N, -1, *inp.shape[1:]))
            recon = jax.device_get(pmapped_decode(inp))
            recons.append(np.reshape(recon, (-1, *recon.shape[2:])))
        recons = np.concatenate(recons, axis=0)
        recons = np.reshape(recons, (N, B, T, *recons.shape[1:]))
        recons = np.clip(recons, -1, 1)
        return recons  # (N, B, T, *decoded frame dims)
    samples = decode(encodings)

    # NOTE(review): 16 looks like a hard-coded latent height used to detect
    # that ``video`` already holds codes rather than frames - confirm.
    if video.shape[3] == 16:
        video = decode(video)

    if not return_cond_frames:
        video = video[:, :, model.config.open_loop_ctx:]
        samples = samples[:, :, model.config.open_loop_ctx:]

    if return_real:
        return samples, video
    else:
        return samples


class PerceiverAR(nn.Module):
    """Autoregressive model over VQ codes with a cross-attention front end.

    A window of latent positions (``latent_size`` frames of tokens) attends
    causally to the full ``seq_len``-frame token sequence; a Transformer then
    runs over those latents and outputs logits over the ``vqvae.n_codes``
    codebook entries.
    """
    # Hyper-parameter container (seq_len, latent_size, transformer, ...).
    config: Any
    # Callables keyed by name; 'encode' is used here, 'decode' by the sampler.
    vq_fns: Dict[str, Callable]
    # Provides latent_shape, n_codes and embedding_dim.
    vqvae: Any
    dtype: Optional[Any] = jnp.float32

    @property
    def metrics(self):
        """Names of the scalar entries returned by ``__call__``."""
        return ['loss']

    def setup(self):
        """Create the embeddings, the cross-attention block and the Transformer."""
        cfg = self.config
        tfm_cfg = cfg.transformer
        codebook = self.vqvae

        # Token grids of the full context and of the latent window.
        self.shape = (cfg.seq_len, *codebook.latent_shape)
        self.latent_shape = (cfg.latent_size, *codebook.latent_shape)

        self.action_embeds = nn.Embed(
            cfg.action_dim, cfg.action_embed_dim, dtype=self.dtype)
        self.pos_embeds = layers.SinusoidalPositionBiases(self.shape, self.dtype)

        # One row more than the codebook: the last embedding is the SOS token.
        self.token_embeds = nn.Embed(codebook.n_codes + 1, codebook.embedding_dim)

        per_head_dim = tfm_cfg['embed_dim'] // tfm_cfg['num_heads']
        self.rotary_embeds = layers.RotaryPositionBiases(
            seq_len=np.prod(self.shape),
            dim=max(32, per_head_dim // 2),
        )
        self.fc_in = nn.Dense(tfm_cfg['embed_dim'], dtype=self.dtype)

        # Cross attention: latents (z) query the full sequence (x).
        self.ca_z_norm = nn.LayerNorm(dtype=self.dtype)
        self.ca_x_norm = nn.LayerNorm(dtype=self.dtype)
        self.cross_attn = transformer.MultiHeadDotProductAttention(
            **cfg.cross_attn,
            max_heads_processed=2,
            dtype=self.dtype,
        )

        # Feed-forward block applied after the cross attention.
        self.ca_ffw_norm = nn.LayerNorm(dtype=self.dtype)
        self.ca_ffw = layers.MlpBlock(
            intermediate_dim=tfm_cfg['mlp_dim'],
            activations=(transformer.gelu2,),
            intermediate_dropout_rate=tfm_cfg['dropout'],
            dtype=self.dtype,
        )

        # Self attention over the latents, ending in codebook logits.
        self.net = transformer.Transformer(
            **tfm_cfg,
            shape=self.latent_shape,
            pos_embed_type='none',
            out_dim=codebook.n_codes,
            use_fc_in=False,
            dtype=self.dtype,
        )

    def _step(self, x, actions, latent_idx, deterministic=False, decode_mode=None):
        """Compute logits for the latent positions starting at ``latent_idx``.

        Args:
            x: integer token codes, indexed as (batch, token).
            actions: integer actions, indexed as (batch, frame).
            latent_idx: first token position used as a latent/query.
            deterministic: forwarded to the attention and MLP blocks.
            decode_mode: ``'slice'`` queries a single position, anything else
                queries a whole latent window; ``'full'`` and ``'slice'``
                additionally pass a non-``None`` ``decode_step`` to the
                Transformer.

        Returns:
            Logits produced by the latent Transformer.
        """
        n_context, n_latents = np.prod(self.shape), np.prod(self.latent_shape)
        slice_mode = decode_mode == 'slice'
        # One query token when decoding token by token, a full window otherwise.
        n_queries = 1 if slice_mode else n_latents

        # Repeat every frame's action embedding over that frame's latent grid,
        # then flatten to one embedding per token.
        act = self.action_embeds(actions)
        act = jnp.tile(act[:, :, None, None, :],
                       (1, 1, *self.vqvae.latent_shape, 1))
        act = act.reshape(act.shape[0], -1, act.shape[-1])

        rotary_embeds = self.rotary_embeds()

        # Right shift: prepend -1 and drop the final token (-1 is presumably
        # resolved to the last, SOS, row of token_embeds - TODO confirm).
        shifted = jnp.pad(x, ((0, 0), (1, 0)), constant_values=-1)[:, :-1]
        h = self.token_embeds(shifted)
        h += self.pos_embeds(h)

        # Fuse token and action features.
        h = self.fc_in(jnp.concatenate([h, act], axis=-1))

        z = jax.lax.dynamic_slice_in_dim(h, latent_idx, n_queries, axis=1)

        # Cross attention: keep only the causal-mask rows of the query positions.
        causal = jnp.tril(jnp.ones((n_context, n_context), dtype=bool))
        ca_mask = jax.lax.dynamic_slice_in_dim(causal, latent_idx, n_queries, axis=0)
        q_rotary_idxs = jnp.arange(n_queries, dtype=jnp.int32) + latent_idx

        attended = self.cross_attn(
            self.ca_z_norm(z),
            self.ca_x_norm(h),
            mask=ca_mask,
            rotary_embeds=rotary_embeds,
            q_rotary_idxs=q_rotary_idxs,
            kv_rotary_idxs=None,
            deterministic=deterministic,
        )
        z = attended + z
        z = self.ca_ffw(self.ca_ffw_norm(z), deterministic=deterministic) + z

        # Self attention in latents
        latent_mask = jnp.tril(jnp.ones((n_latents, n_latents), dtype=bool))
        if slice_mode:
            tokens_per_frame = np.prod(self.vqvae.latent_shape)
            # Offset within the frame, placed in the last frame of the window.
            decode_step = (latent_idx % tokens_per_frame
                           + (self.config.latent_size - 1) * tokens_per_frame)
        elif decode_mode == 'full':
            # Any non-None value works here; only "is not None" matters.
            decode_step = 0
        else:
            decode_step = None

        return self.net(z, mask=latent_mask, deterministic=deterministic,
                        decode_step=decode_step)

    def __call__(self, video, actions, deterministic=False, dropout_actions=None):
        """Compute the training loss on a randomly placed latent window.

        Args:
            video: input passed to ``vq_fns['encode']``; its first two axes
                are used as (batch, time).
            actions: per-frame integer actions, or ``None`` when
                ``config.use_actions`` is false.
            deterministic: forwarded to ``_step``.
            dropout_actions: optional per-example boolean mask; drawn with
                probability 0.5 from the ``'sample'`` RNG stream when ``None``.

        Returns:
            ``dict(loss=...)`` with the mean softmax cross entropy between the
            logits and the codes inside the latent window.
        """
        cfg = self.config

        if not cfg.use_actions:
            actions = (jnp.zeros(video.shape[:2], dtype=jnp.int32)
                       if actions is None else jnp.zeros_like(actions))

        if dropout_actions is None:
            dropout_actions = jax.random.bernoulli(
                self.make_rng('sample'), p=0.5, shape=(video.shape[0],))  # B

        if cfg.dropout_actions:
            # Masked examples get action index -1 for every frame.
            actions = jnp.where(dropout_actions[:, None], -1, actions)

        _, codes = self.vq_fns['encode'](video)
        codes = codes.reshape(codes.shape[0], -1)

        n_context, n_latents = np.prod(self.shape), np.prod(self.latent_shape)
        # Random start such that the latent window fits inside the sequence.
        latent_idx = jax.random.randint(
            self.make_rng('sample'), shape=(),
            minval=0, maxval=n_context - n_latents + 1)
        logits = self._step(codes, actions, latent_idx, deterministic=deterministic)

        targets = jax.lax.dynamic_slice_in_dim(codes, latent_idx, n_latents, axis=1)
        one_hot_targets = jax.nn.one_hot(targets, num_classes=self.vqvae.n_codes)
        loss = optax.softmax_cross_entropy(logits, one_hot_targets).mean()
        return {'loss': loss}
