from typing import Any, Optional, Dict, Callable
from functools import cached_property, partial
from tqdm import tqdm
import numpy as np
import flax.linen as nn
from flax import jax_utils
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap
import optax

from . import transformer
from . import sharding
from .utils import f_psum
from ...utils import topk_sample


def _model_step(state, cache, encodings, actions, latent_idx, decode_mode):
    """Apply `model._step` once and return its logits with the updated cache.

    Reads the module-level `model` global, which `sample` assigns before any
    step function is invoked.

    Args:
        state: object whose `.params` attribute holds the model parameters.
        cache: mapping of additional variable collections merged alongside
            'params' (callers in this file pass `{'cache': ...}`).
        encodings: token encodings forwarded to `model._step`.
        actions: actions forwarded to `model._step`.
        latent_idx: index forwarded to `model._step`.
        decode_mode: decode mode forwarded to `model._step`.

    Returns:
        Tuple `(logits, cache)`; `cache` is the mutated-collections result
        that `model.apply` returns because 'cache' is declared mutable.
    """
    # Later entries win, exactly like a dict literal with `**cache` last.
    step_variables = {'params': state.params}
    step_variables.update(cache)

    apply_kwargs = {
        'deterministic': True,
        'decode_mode': decode_mode,
        'method': model._step,
        'mutable': ['cache'],
    }
    logits, updated_cache = model.apply(
        step_variables, encodings, actions, latent_idx, **apply_kwargs
    )
    return logits, updated_cache

# `_model_step` with `decode_mode` pre-bound; `sample` wraps each of these in
# `jax.pmap` / `xmap`, passing only the five remaining positional arguments.
_model_step_full = partial(_model_step, decode_mode='full')
_model_step_slice = partial(_model_step, decode_mode='slice')

    
def _sample_step(logits, rng):
    """Sample from `logits` and hand back a fresh PRNG key.

    Args:
        logits: logits passed to `topk_sample`.
        rng: JAX PRNG key; it is split so that one half drives this draw and
            the other half is returned for the next call.

    Returns:
        Tuple `(samples, new_rng)`.
    """
    # First half is carried forward, second half is consumed here.
    carried_rng, draw_rng = jax.random.split(rng)
    # top_k / top_p are both passed as None, as in every call from this file.
    drawn = topk_sample(draw_rng, logits, top_k=None, top_p=None)
    return drawn, carried_rng


def _decode(x):
    """Decode `x` with the VQ decoder of the module-level `model`.

    A singleton axis is inserted at position 1 before the call and index 0 of
    that axis is taken from the result.
    NOTE(review): presumably the decoder expects a time axis there -- confirm.
    """
    decode_fn = model.vq_fns['decode']
    expanded = x[:, None]
    decoded = decode_fn(expanded)
    return decoded[:, 0]


def sample(sample_model, state, video, actions, seed=0, return_cond_frames=True,
           extra_frames=0, log_output=False, return_real=True, state_spec=None):
    """Autoregressively sample video tokens from `sample_model` and decode them.

    `video` is encoded with `vq_fns['encode']`; the tokens of the first
    `config.open_loop_ctx` frames are kept as conditioning and every later
    token, up to `config.eval_seq_len` frames, is sampled one at a time.
    All tokens are finally decoded with `vq_fns['decode']`.

    Side effect: `sample_model` is stored in the module-level global `model`,
    which `_model_step` and `_decode` read.

    Args:
        sample_model: model exposing `config`, `vqvae`, `vq_fns`, `init`,
            `apply` and `_step`.
        state: object with a `.params` attribute, passed to every model step.
        video: input batch; its leading axis must equal the number of local
            data replicas (asserted below).
        actions: per-frame actions. When `config.use_actions` is false they
            are replaced by zeros, and `None` is accepted in that case.
        seed: integer seed for the sampling PRNG keys.
        return_cond_frames: if false, the first `open_loop_ctx` frames are
            dropped from the returned arrays.
        extra_frames: unused; kept so existing callers keep working.
        log_output: if true, wrap the sampling loop in a `tqdm` progress bar.
        return_real: if true, also return `video`.
        state_spec: xmap axes spec for `state`. When not None the xmap
            (data + model parallel) path is used, otherwise plain `jax.pmap`.

    Returns:
        `samples` clipped to [-1, 1], or `(samples, video)` when
        `return_real` is true.
    """
    global model
    model = sample_model
    config = model.config

    use_xmap = state_spec is not None

    if use_xmap:
        # One data replica per group of `num_shards` local devices.
        num_local_data = max(1, jax.local_device_count() // model.config.num_shards)
    else:
        num_local_data = jax.local_device_count()
    # One PRNG key per data replica.
    rngs = jax.random.PRNGKey(seed)
    rngs = jax.random.split(rngs, num_local_data)

    assert video.shape[0] == num_local_data, (
        'leading axis of video must equal the number of local data replicas')
    assert config.open_loop_ctx >= config.latent_size - 1, (
        'open_loop_ctx must be at least latent_size - 1')

    if not config.use_actions:
        if actions is None:
            actions = jnp.zeros(video.shape[:3], dtype=jnp.int32)
        else:
            actions = jnp.zeros_like(actions)

    # Restrict pmap to as many devices as there are leading-axis entries.
    if video.shape[0] < jax.local_device_count():
        devices = jax.local_devices()[:video.shape[0]]
    else:
        devices = None
    n_tokens_per_frame = np.prod(model.vqvae.latent_shape)
    _, encodings = jax.pmap(model.vq_fns['encode'], devices=devices)(video)
    # Flatten everything after the first two axes into one token axis and
    # blank out every token after the conditioning frames.
    encodings = encodings.reshape(*encodings.shape[:2], -1)
    encodings = encodings.at[:, :, config.open_loop_ctx * n_tokens_per_frame:].set(0)

    # Build an all-zero cache and the parallel step functions.
    if use_xmap:
        # eval_shape only traces `init` to obtain the cache shapes/dtypes.
        variable_shapes = jax.eval_shape(
            xmap(partial(model.init, method=model._step, decode_mode='full'),
                 in_axes=(('model', ...,), (...,), (...,), (...,)),
                 out_axes=('model', ...),
                 axis_resources={'model': 'mp'}),
            {k: jax.random.split(jax.random.PRNGKey(0), config.num_shards)
             for k in ['params', *config.rng_keys]},
            encodings[0, :, :config.seq_len * n_tokens_per_frame],
            actions[0, :, :config.seq_len],
            0
        )
        cache = {'cache': variable_shapes['cache']}
        cache = jax.tree_util.tree_map(
            lambda x: np.zeros((num_local_data, *x.shape), dtype=x.dtype), cache)

        # Both model-step variants share the same axis layout.
        step_axes = dict(
            in_axes=(state_spec, ('data', 'model', ...), ('data', ...), ('data', ...), ('data', ...)),
            out_axes=(('data', ...), ('data', 'model', ...)),
            axis_resources={'data': 'dp', 'model': 'mp'}
        )
        p_model_step_full = xmap(_model_step_full, **step_axes)
        p_model_step_slice = xmap(_model_step_slice, **step_axes)
        p_sample_step = xmap(
            _sample_step,
            in_axes=(('data', ...), ('data', ...)),
            out_axes=('data', ...),
            axis_resources={'data': 'dp', 'model': 'mp'}
        )
    else:
        variable_shapes = jax.eval_shape(
            partial(model.init, method=model._step, decode_mode='full'),
            rngs={k: jax.random.PRNGKey(0)
                  for k in ['params', *config.rng_keys]},
            x=encodings[0, :, :config.seq_len * n_tokens_per_frame],
            actions=actions[0, :, :config.seq_len],
            latent_idx=0
        )
        cache = {'cache': variable_shapes['cache']}
        cache = jax.tree_util.tree_map(lambda x: np.zeros(x.shape, dtype=x.dtype), cache)
        cache = jax_utils.replicate(cache)

        p_model_step_full = jax.pmap(_model_step_full)
        p_model_step_slice = jax.pmap(_model_step_slice)
        p_sample_step = jax.pmap(_sample_step)

    # Sampling: one iteration per token position after the conditioning frames.
    itr = range(config.open_loop_ctx * n_tokens_per_frame,
                config.eval_seq_len * n_tokens_per_frame)
    if log_output:
        itr = tqdm(itr)

    for i in itr:
        frame_id = i // n_tokens_per_frame
        if frame_id < config.seq_len:
            # Still inside the first window of `seq_len` frames.
            enc, act = encodings[:, :, :config.seq_len * n_tokens_per_frame], actions[:, :, :config.seq_len]
        else:
            # Window of `seq_len` frames that ends at the current frame.
            enc = encodings[:, :, (frame_id - config.seq_len + 1) * n_tokens_per_frame:(frame_id + 1) * n_tokens_per_frame]
            act = actions[:, :, frame_id - config.seq_len + 1:frame_id + 1]

        if i % n_tokens_per_frame == 0:
            # First token of a frame: refresh the cache with a 'full' pass
            # (its logits are discarded).
            if frame_id < config.seq_len:
                latent_idx = (frame_id - (config.latent_size - 1)) * n_tokens_per_frame
            else:
                latent_idx = (config.seq_len - config.latent_size) * n_tokens_per_frame

            _, cache = p_model_step_full(
                state, cache, enc, act,
                np.full((num_local_data,), latent_idx, dtype=np.int32)
            )

        # Position of token `i` inside the window passed to the model.
        if frame_id < config.seq_len:
            latent_idx = i
        else:
            latent_idx = (config.seq_len - 1) * n_tokens_per_frame + i % n_tokens_per_frame
        latent_idx = np.full((num_local_data,), latent_idx, dtype=np.int32)
        logits, cache = p_model_step_slice(
            state, cache, enc, act, latent_idx
        )
        s, rngs = p_sample_step(logits, rngs)

        # Write the sampled token back so later steps condition on it.
        encodings = encodings.at[:, :, i, None].set(s)
    encodings = encodings.reshape(*encodings.shape[:2], config.eval_seq_len, *model.vqvae.latent_shape)

    # Built once and reused for every chunk and for both decode() calls.
    p_decode = jax.pmap(_decode, devices=devices)

    def decode(samples):
        # samples: NBTHW
        N, B, T = samples.shape[:3]
        samples = jax.device_get(samples)
        samples = np.reshape(samples, (-1, *samples.shape[3:]))

        # Each chunk is split evenly over the N mapped devices, so its size
        # must be a multiple of N. This is 64 whenever N divides 64.
        chunk = max(N, 64 // N * N)
        recons = []
        for start in range(0, N * B * T, chunk):
            inp = samples[start:start + chunk]
            inp = np.reshape(inp, (N, -1, *inp.shape[1:]))
            recon = p_decode(inp)
            recon = jax.device_get(recon)
            recon = np.reshape(recon, (-1, *recon.shape[2:]))
            recons.append(recon)
        recons = np.concatenate(recons, axis=0)
        recons = np.reshape(recons, (N, B, T, *recons.shape[1:]))
        recons = np.clip(recons, -1, 1)
        return recons  # (N, B, T, *decoded frame shape)
    samples = decode(encodings)

    # NOTE(review): a size of 16 on axis 3 is taken to mean `video` already
    # holds latent codes rather than pixels -- confirm against callers.
    if video.shape[3] == 16:
        video = decode(video)

    if not return_cond_frames:
        video = video[:, :, model.config.open_loop_ctx:]
        samples = samples[:, :, model.config.open_loop_ctx:]

    if return_real:
        return samples, video
    else:
        return samples


class PerceiverARShard(nn.Module):
    """Sharded autoregressive model over VQ video tokens.

    A window of latent tokens cross-attends (causally) to the embedded input
    token sequence, passes through a feed-forward block, and is then processed
    by `transformer.TransformerShard`, which outputs logits over the
    `vqvae.n_codes` codebook entries.
    """
    # Model configuration (seq_len, latent_size, transformer kwargs, ...).
    config: Any
    # VQ helpers; `__call__` uses the 'encode' entry.
    vq_fns: Dict[str, Callable]
    # Provides `latent_shape`, `n_codes` and `embedding_dim`.
    vqvae: Any
    # dtype handed to most submodules.
    dtype: Optional[Any] = jnp.float32

    @property
    def metrics(self):
        """Names of the metrics present in the dict returned by `__call__`."""
        return ['loss']

    def aggregate_metrics(self, metrics):
        """Average `metrics` over the 'data' mapped axis."""
        return jax.lax.pmean(metrics, 'data')

    def setup(self):
        """Create the embedding, cross-attention and latent transformer submodules."""
        cfg = self.config
        tf_kwargs = cfg.transformer
        per_frame = self.vqvae.latent_shape

        full_shape = (cfg.seq_len, *per_frame)
        self.shape = full_shape
        self.latent_shape = (cfg.latent_size, *per_frame)

        self.action_embeds = nn.Embed(
            cfg.action_dim, cfg.action_embed_dim, dtype=self.dtype
        )
        self.pos_embeds = transformer.SinusoidalPositionBiases(full_shape, self.dtype)

        # One extra row: the last embedding serves as the start-of-sequence token.
        self.token_embeds = nn.Embed(self.vqvae.n_codes + 1, self.vqvae.embedding_dim)

        per_head = tf_kwargs['embed_dim'] // tf_kwargs['num_heads']
        self.rotary_embeds = transformer.RotaryPositionBiases(
            seq_len=np.prod(full_shape),
            dim=max(32, per_head // 2),
        )
        self.fc_in = nn.Dense(tf_kwargs['embed_dim'], dtype=self.dtype)

        # --- cross attention: latents attend to the full sequence ---
        self.ca_z_norm = transformer.LayerNorm(dtype=self.dtype)
        self.ca_x_norm = transformer.LayerNorm(dtype=self.dtype)
        self.cross_attn = transformer.MultiHeadAttentionShard(
            num_shards=cfg.num_shards,
            **cfg.cross_attn,
            max_heads_processed=2,
            dtype=self.dtype
        )

        self.ca_ffw_norm = transformer.LayerNorm(dtype=self.dtype)
        self.ca_ffw = transformer.MlpBlockShard(
            num_shards=cfg.num_shards,
            intermediate_dim=tf_kwargs['mlp_dim'],
            intermediate_dropout_rate=tf_kwargs['dropout'],
            dtype=self.dtype
        )

        # --- self attention over the latents ---
        self.net = transformer.TransformerShard(
            num_shards=cfg.num_shards,
            **tf_kwargs,
            shape=self.latent_shape,
            pos_embed_type='none',
            out_dim=self.vqvae.n_codes,
            fc_in_mode=None, dtype=self.dtype
        )

    def _step(self, x, actions, latent_idx, deterministic=False, decode_mode=None):
        """Compute logits for the latent window that starts at `latent_idx`.

        Args:
            x: 2-D integer token array (the padding below fixes rank 2);
                presumably (batch, prod(self.shape)) -- confirm.
            actions: integer actions, one per (batch, frame) entry.
            latent_idx: start position of the query window in the token axis.
            deterministic: forwarded to the attention / MLP submodules.
            decode_mode: 'slice' queries a single position, 'full' queries the
                whole latent window with a decode step set, and any other
                value (e.g. None) queries the whole window without one.

        Returns:
            Output of `self.net` for the queried positions.
        """
        slice_mode = decode_mode == 'slice'
        x_len, z_len = np.prod(self.shape), np.prod(self.latent_shape)
        # Number of query positions taken from the sequence.
        n_query = 1 if slice_mode else z_len

        # Embed the actions and repeat them over every token of their frame.
        act = self.action_embeds(actions)
        act = jnp.tile(act[:, :, None, None, :], (1, 1, *self.vqvae.latent_shape, 1))
        act = act.reshape(act.shape[0], -1, act.shape[-1])

        rotary_embeds = self.rotary_embeds()

        # Shift right by one; the -1 fill selects the last (sos) embedding row.
        shifted = jnp.pad(x, ((0, 0), (1, 0)), constant_values=-1)[:, :-1]
        h = self.token_embeds(shifted)
        h = h + self.pos_embeds(h)

        # Fuse token and action features.
        h = self.fc_in(jnp.concatenate([h, act], axis=-1))

        z = jax.lax.dynamic_slice_in_dim(h, latent_idx, n_query, axis=1)

        # Cross attention: rows of the causal mask matching the query positions.
        causal = jnp.tril(jnp.ones((x_len, x_len), dtype=bool))
        ca_mask = jax.lax.dynamic_slice_in_dim(causal, latent_idx, n_query, axis=0)
        q_rotary_idxs = jnp.arange(n_query, dtype=jnp.int32) + latent_idx

        inp_z, inp_x = f_psum(z), f_psum(h)
        attended = self.cross_attn(
            self.ca_z_norm(inp_z), self.ca_x_norm(inp_x),
            mask=ca_mask, rotary_embeds=rotary_embeds,
            q_rotary_idxs=q_rotary_idxs,
            kv_rotary_idxs=None, deterministic=deterministic
        )
        z = attended + z
        z = self.ca_ffw(self.ca_ffw_norm(f_psum(z)), deterministic=deterministic) + z

        # Self attention in latents.
        sa_mask = jnp.tril(jnp.ones((z_len, z_len), dtype=bool))
        if slice_mode:
            n_per_frame = np.prod(self.vqvae.latent_shape)
            decode_step = latent_idx % n_per_frame + (self.config.latent_size - 1) * n_per_frame
        elif decode_mode == 'full':
            # Any non-None value works here.
            decode_step = 0
        else:
            decode_step = None

        return self.net(z, mask=sa_mask, deterministic=deterministic,
                        decode_step=decode_step)

    def __call__(self, video, actions, deterministic=False, dropout_actions=None):
        """Compute the cross-entropy loss on a random latent window.

        Args:
            video: input passed to `vq_fns['encode']`.
            actions: per-frame actions; zeroed when `config.use_actions` is
                false (None is accepted in that case).
            deterministic: forwarded to `_step`.
            dropout_actions: optional boolean array with one entry per batch
                element; drawn from the 'sample' RNG stream when None.

        Returns:
            Dict with the single key 'loss'.
        """
        if not self.config.use_actions:
            actions = (
                jnp.zeros(video.shape[:2], dtype=jnp.int32)
                if actions is None
                else jnp.zeros_like(actions)
            )

        if dropout_actions is None:
            # One Bernoulli(0.5) draw per batch element.
            dropout_actions = jax.random.bernoulli(
                self.make_rng('sample'), p=0.5, shape=(video.shape[0],)
            )

        if self.config.dropout_actions:
            # Dropped actions become -1.
            actions = jnp.where(dropout_actions[:, None], -1, actions)

        _, codes = self.vq_fns['encode'](video)
        tokens = codes.reshape(codes.shape[0], -1)
        x_len, z_len = np.prod(self.shape), np.prod(self.latent_shape)
        # Random start such that the latent window fits inside the sequence.
        latent_idx = jax.random.randint(
            self.make_rng('sample'), shape=(), minval=0, maxval=x_len - z_len + 1
        )
        logits = self._step(tokens, actions, latent_idx, deterministic=deterministic)

        targets = jax.lax.dynamic_slice_in_dim(tokens, latent_idx, z_len, axis=1)
        labels = jax.nn.one_hot(targets, num_classes=self.vqvae.n_codes)
        loss = optax.softmax_cross_entropy(logits, labels).mean()
        return {'loss': loss}

    @cached_property
    def model_spec(self):
        """Sharding spec with one entry per listed submodule, wrapped in `sharding.GenericDict`."""

        def replicated(reduce_mode):
            return sharding.GenericReplicated(reduce_mode=reduce_mode)

        spec = {}
        spec['action_embeds'] = replicated('identity')
        spec['token_embeds'] = replicated('identity')
        spec['fc_in'] = replicated('identity')
        spec['ca_z_norm'] = replicated('sum')
        spec['ca_x_norm'] = replicated('sum')
        spec['cross_attn'] = transformer.MultiHeadAttentionShard.model_spec()
        spec['ca_ffw_norm'] = replicated('sum')
        spec['ca_ffw'] = transformer.MlpBlockShard.model_spec()
        spec['net'] = transformer.TransformerShard.model_spec(
            **self.config.transformer,
            pos_embed_type='none',
            out_dim=self.vqvae.n_codes,
            fc_in_mode=None
        )
        return sharding.GenericDict(spec)
