from typing import Optional, Any, Dict, Callable
from tqdm import tqdm
import optax
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp

from .transformer import Transformer
from .base import ResNetEncoder, ResNetDecoder, Codebook
from .maskgit import MaskGit


Array = Any
Dtype = Any


def _observe(state, encodings, rng):
    """Run ``model.observe`` on ``encodings``; written to be wrapped in ``jax.pmap``.

    ``model`` is the module-level global bound by :func:`sample`.

    Returns:
        ``(cond, embeddings, next_rng)``: the conditioning features and codebook
        embeddings produced by ``model.observe``, plus a fresh PRNG key for the
        caller's next step.
    """
    sample_key, next_rng = jax.random.split(rng)
    cond, embeddings = model.apply(
        {'params': state.params, **state.model_state},
        encodings,
        method=model.observe,
        rngs={'sample': sample_key},
    )
    return cond, embeddings, next_rng

def _imagine(state, z_embeddings, actions, cond, t, rng):
    """Run ``model.imagine`` for frame ``t``; written to be wrapped in ``jax.pmap``.

    ``model`` is the module-level global bound by :func:`sample`.

    Returns:
        ``(recon, z, next_rng)`` -- note the order differs from
        ``model.imagine``'s ``(z, recon)``.
    """
    sample_key, next_rng = jax.random.split(rng)
    z, recon = model.apply(
        {'params': state.params, **state.model_state},
        z_embeddings, actions, cond, t,
        method=model.imagine,
        rngs={'sample': sample_key},
    )
    return recon, z, next_rng

def _decode(x):
    """Decode ``x`` with the global ``model``'s VQ ``decode`` function.

    A singleton axis is inserted at position 1 before decoding and removed
    from the result, so callers pass and receive arrays without a time axis.
    """
    with_time_axis = x[:, None]
    decoded = model.vq_fns['decode'](with_time_axis)
    return decoded[:, 0]


def sample(sample_model, state, video, actions, seed=0,
           return_cond_frames=True, extra_frames=0, log_output=False,
           return_real=True):
    """Autoregressively roll out frames with a TECO-style model on all local devices.

    The first ``open_loop_ctx`` frames of ``video`` are given as context and
    frames ``open_loop_ctx .. eval_seq_len + extra_frames - 1`` are generated
    one at a time. Once the rollout index reaches ``seq_len``, the model is
    re-conditioned on a sliding window made of the latest ``seq_len - 1``
    frames plus a zero placeholder for the frame being generated.

    Args:
        sample_model: model to sample from. It is bound to the module-level
            global ``model`` so the pmapped helpers ``_observe``, ``_imagine``
            and ``_decode`` can reach it.
        state: object exposing ``params`` and ``model_state``.
        video: array whose leading axis equals ``jax.local_device_count()``;
            axis 2 is time. It is fed to ``model.vq_fns['encode']``.
        actions: per-frame actions indexed as ``actions[:, :, t]``, or ``None``
            when ``model.config.use_actions`` is False.
        seed: PRNG seed.
        return_cond_frames: if False, the first ``open_loop_ctx`` frames are
            dropped from both outputs.
        extra_frames: number of frames to generate past ``eval_seq_len``.
        log_output: wrap the rollout loop in a tqdm progress bar.
        return_real: also return ``video``. It is decoded first when
            ``video.shape[3] == 16`` (presumably meaning it already holds
            latent codes -- confirm).

    Returns:
        ``(samples, video)`` if ``return_real`` else ``samples``. ``samples``
        is a NumPy array of shape ``(N, B, T, *frame_shape)`` clipped to
        ``[-1, 1]``.

    NOTE(review): the helpers read ``model`` as a global. ``jax.pmap`` caches
    traces per helper function, so sampling a *different* model with identical
    input shapes in the same process may reuse a stale trace -- confirm.
    """
    global model
    model = sample_model
    cfg = model.config
    n_dev = jax.local_device_count()

    rngs = jax.random.split(jax.random.PRNGKey(seed), n_dev)

    assert video.shape[0] == n_dev, 'leading axis of video must equal the local device count'
    assert cfg.n_cond <= cfg.open_loop_ctx, 'n_cond must not exceed open_loop_ctx'

    n_steps = cfg.eval_seq_len + extra_frames
    if not cfg.use_actions:
        # In this mode the model is fed all-zero actions. Make the zero tensor
        # cover every rollout step: with extra_frames > 0 the rollout runs past
        # the end of `video`, and a too-short tensor would make the
        # sliding-window action slice below come up short.
        template = video.shape[:3] if actions is None else actions.shape
        dtype = jnp.int32 if actions is None else actions.dtype
        length = max(template[2], n_steps)
        actions = jnp.zeros((*template[:2], length, *template[3:]), dtype=dtype)

    # Build the pmapped callables once instead of once per generated frame.
    p_observe = jax.pmap(_observe, axis_name='batch')
    p_imagine = jax.pmap(_imagine, axis_name='batch')

    _, encodings = jax.pmap(model.vq_fns['encode'], axis_name='batch')(video)
    cond, zs, rngs = p_observe(state, encodings, rngs)
    # Hide latents of frames that must be generated, then keep only the
    # window the prior operates on.
    zs = zs.at[:, :, cfg.open_loop_ctx - cfg.n_cond:].set(0)
    zs = zs[:, :, :cfg.seq_len - cfg.n_cond]

    recon = [encodings[:, :, i] for i in range(cfg.open_loop_ctx)]
    dummy_encoding = jnp.zeros_like(recon[0])
    steps = range(cfg.open_loop_ctx, n_steps)
    if log_output:
        steps = tqdm(steps)
    for frame in steps:
        if frame >= cfg.seq_len:
            # Past the model's context: re-observe the most recent window,
            # with a placeholder in the slot being generated.
            window = jnp.stack([*recon[-cfg.seq_len + 1:], dummy_encoding], axis=2)
            cond, zs, rngs = p_observe(state, window, rngs)
            act = actions[:, :, frame - cfg.seq_len + 1:frame + 1]
            pos = cfg.seq_len - 1  # position of the new frame inside the window
        else:
            act = actions[:, :, :cfg.seq_len]
            pos = frame

        t = jnp.full((n_dev,), pos, dtype=jnp.int32)
        r, z, rngs = p_imagine(state, zs, act, cond, t, rngs)
        zs = zs.at[:, :, pos - cfg.n_cond].set(z)
        recon.append(r)
    generated = jnp.stack(recon, axis=2)

    def decode(codes):
        """Decode (N, B, T, ...) codes in device-sharded chunks; returns NumPy."""
        n, b, t_len = codes.shape[:3]
        flat = jax.device_get(codes)
        flat = np.reshape(flat, (-1, *flat.shape[3:]))

        # Every chunk is reshaped to (n, -1, ...) for pmap, so its size must be
        # a multiple of the device count (64 unchanged for 1/2/4/8 devices).
        # The total n * b * t_len is a multiple of n, so the tail chunk is too.
        chunk = max(n, (64 // n) * n)
        p_decode = jax.pmap(_decode)

        pieces = []
        for start in range(0, flat.shape[0], chunk):
            inp = flat[start:start + chunk]
            inp = np.reshape(inp, (n, -1, *inp.shape[1:]))
            out = jax.device_get(p_decode(inp))
            pieces.append(np.reshape(out, (-1, *out.shape[2:])))
        frames = np.concatenate(pieces, axis=0)
        frames = np.reshape(frames, (n, b, t_len, *frames.shape[1:]))
        return np.clip(frames, -1, 1)  # (N, B, T, *frame_shape)

    samples = decode(generated)

    if video.shape[3] == 16:
        video = decode(video)

    if not return_cond_frames:
        video = video[:, :, cfg.open_loop_ctx:]
        samples = samples[:, :, cfg.open_loop_ctx:]

    if return_real:
        return samples, video
    return samples


class TECO(nn.Module):
    """TECO video dynamics model operating on the discrete codes of an external VQ model.

    Components built in ``setup``:
      * posterior: ``encoder`` + ``codebook`` map each pair of consecutive
        VQ-code embeddings to a quantized latent ``z`` on a downsampled grid;
      * prior: a block-causal ``z_tfm`` transformer predicts a per-frame state
        ``deter`` from earlier latents and actions;
      * ``z_git``: a MaskGit over the codes of ``z``, conditioned on ``deter``;
      * ``decoder``: logits over the VQ model's codes from ``deter`` and ``z``.

    ``vq_fns`` must provide at least ``'encode'`` and ``'lookup'``. ``vqvae``
    must expose ``latent_shape``, ``embedding_dim`` and ``n_codes``.
    """
    config: Any  # experiment config; dict-valued sub-configs are splatted into submodules
    vq_fns: Dict[str, Callable]  # callables of the VQ model ('encode', 'lookup', ...)
    vqvae: Any  # VQ model metadata (latent_shape, embedding_dim, n_codes)
    dtype: Optional[Any] = jnp.float32  # computation dtype passed to submodules

    @property
    def metrics(self):
        """Names of the scalar entries in the dict returned by ``__call__``."""
        metrics = ['loss', 'recon_loss', 'trans_loss', 'codebook_loss',
                   'commitment_loss', 'perplexity', 'logits_min', 'logits_max', 'logits_mean']
        return metrics

    def setup(self):
        config = self.config

        # action_dim + 1 rows: presumably the extra row serves the -1 "dropped
        # action" id written by __call__ -- TODO confirm nn.Embed's handling of -1.
        self.action_embeds = nn.Embed(config.action_dim + 1, config.action_embed_dim, dtype=self.dtype)

        # Posterior
        # Learned stand-in for the frame preceding frame 0 (see encode_obs).
        self.sos_post = self.param('sos_post', nn.initializers.normal(stddev=0.02),
                                   (*self.vqvae.latent_shape, self.vqvae.embedding_dim), jnp.float32)
        self.encoder = nn.Sequential([
            ResNetEncoder(**config.encoder, dtype=self.dtype),
            nn.Dense(config.embedding_dim, dtype=self.dtype)
        ])
        # Spatial downsampling of the encoder; assumes a factor of 2 per stage
        # after the first -- confirm against ResNetEncoder.
        ds = 2 ** (len(config.encoder['depths']) - 1)
        self.z_shape = tuple([d // ds for d in self.vqvae.latent_shape])
        self.codebook = Codebook(**self.config.codebook, embedding_dim=config.embedding_dim,
                                 dtype=self.dtype)

        # Prior
        # Non-overlapping z_ds x z_ds patchify (stride == kernel, VALID padding).
        z_kernel = [config.z_ds, config.z_ds]
        self.z_tfm_shape = tuple([d // config.z_ds for d in self.z_shape])
        self.z_proj = nn.Conv(config.z_tfm_kwargs['embed_dim'], z_kernel,
                              strides=z_kernel, use_bias=False, padding='VALID', dtype=self.dtype)

        self.sos = self.param('sos', nn.initializers.normal(stddev=0.02),
                              (*self.z_tfm_shape, config.z_tfm_kwargs['embed_dim'],), jnp.float32)
        self.z_tfm = Transformer(
            **config.z_tfm_kwargs, pos_embed_type='sinusoidal',
            shape=(config.seq_len, *self.z_tfm_shape), dtype=self.dtype
        )

        # Undo z_proj: back to the z_shape grid with embedding_dim channels.
        self.z_unproj = nn.ConvTranspose(config.embedding_dim, z_kernel, strides=z_kernel,
                                         padding='VALID', use_bias=False, dtype=self.dtype)
        # MaskGit mapped over the time axis (axis 1 of the first two inputs):
        # parameters shared across frames, separate sample/dropout RNGs per frame.
        self.z_git = nn.vmap(
            MaskGit, in_axes=(1, 1, None), out_axes=1,
            variable_axes={'params': None},
            split_rngs={'sample': True, 'dropout': True, 'params': False}
        )(
            shape=self.z_shape, vocab_size=self.codebook.n_codes,
            **config.z_git, dtype=self.dtype
        )

        # Decoder: logits over the VQ model's codes.
        out_dim = self.vqvae.n_codes
        self.decoder = ResNetDecoder(**config.decoder, image_size=self.vqvae.latent_shape[0],
                                     out_dim=out_dim, dtype=self.dtype)

    def _init_mask(self):
        """Block-causal mask of shape (seq_len * n_per, seq_len * n_per).

        ``n_per`` is the number of transformer tokens per frame. A token may
        attend to every token of its own frame and of all earlier frames.
        """
        n_per = np.prod(self.z_tfm_shape)
        mask = jnp.tril(jnp.ones((self.config.seq_len, self.config.seq_len), dtype=bool))
        mask = mask.repeat(n_per, axis=0).repeat(n_per, axis=1)
        return mask

    def observe(self, encodings):
        """Encode VQ codes.

        Returns:
            ``(cond, z_embeddings)``: unquantized encoder outputs of the first
            ``n_cond`` frames and codebook embeddings of the remaining frames.
        """
        cond, out = self.encode_obs(encodings)
        return cond, out['embeddings']

    def imagine(self, z_embeddings, actions, cond, t):
        """Sample the latent of frame ``t`` and decode it to VQ code indices.

        Args:
            z_embeddings: latents of the frames after the ``n_cond``
                conditioning frames. Because of the block-causal mask, entries
                from index ``t - n_cond`` on should not affect the result.
            actions: integer actions, embedded with ``action_embeds``.
            cond: conditioning features as returned by ``observe``.
            t: frame index in the full sequence (conditioning frames included).

        Returns:
            ``(z_t, recon)``: codebook embeddings of the sampled latent and the
            argmax code indices of the decoder logits.
        """
        t -= self.config.n_cond  # index into deter, which excludes cond frames
        actions = self.process_actions(actions)
        deter = self.predict_dynamics(z_embeddings, actions, cond, deterministic=True)
        deter = deter[:, t]

        sample = self.z_git.sample(z_embeddings.shape[0], self.config.T_draft,
                                   self.config.T_revise, self.config.M,
                                   cond=deter)

        z_t = self.codebook(None, encoding_indices=sample)

        recon = jnp.argmax(self.decoder(deter, z_t), axis=-1)
        return z_t, recon

    def encode_obs(self, encodings):
        """Encode VQ codes frame-pair by frame-pair.

        Each frame's code embeddings are concatenated (channel axis) with those
        of the previous frame (``sos_post`` for the first frame), and
        ``encoder`` is applied independently per time step (axis 1).

        Returns:
            ``(cond, vq_output)``: unquantized outputs of the first ``n_cond``
            steps, and the ``codebook`` output dict for the remaining steps.
        """
        embeddings = self.vq_fns['lookup'](encodings)
        sos = jnp.tile(self.sos_post[None, None], (embeddings.shape[0], 1, 1, 1, 1))
        sos = jnp.asarray(sos, self.dtype)

        embeddings = jnp.concatenate([sos, embeddings], axis=1)
        inp = jnp.concatenate([embeddings[:, :-1], embeddings[:, 1:]], axis=-1)
        out = jax.vmap(self.encoder, 1, 1)(inp)

        z = out[:, self.config.n_cond:]
        vq_output = self.codebook(z)
        return out[:, :self.config.n_cond], vq_output

    def process_actions(self, actions):
        """Embed integer action ids with ``action_embeds``."""
        return self.action_embeds(actions)

    def predict_dynamics(self, z_embeddings, actions, cond, deterministic=False):
        """Run the block-causal prior over ``[cond, z_embeddings]``.

        Input step k concatenates (channel axis) frame k of
        ``[cond, z_embeddings]`` with embedded action k + 1. The steps are
        patch-projected, a learned ``sos`` step is prepended, and the first
        ``n_cond`` transformer outputs are dropped. Output j is therefore
        conditioned on earlier frames and the action of the frame it predicts,
        i.e. it is the state for ``z_embeddings[:, j]``.
        """
        inp = jnp.concatenate([cond, z_embeddings], axis=1)
        # Broadcast each action embedding over the spatial grid of `inp`.
        actions = jnp.tile(actions[:, :, None, None], (1, 1, *inp.shape[2:4], 1))
        inp = jnp.concatenate([inp[:, :-1], actions[:, 1:]], axis=-1)
        inp = jax.vmap(self.z_proj, 1, 1)(inp)

        sos = jnp.tile(self.sos[None, None], (z_embeddings.shape[0], 1, 1, 1, 1))
        sos = jnp.asarray(sos, self.dtype)

        inp = jnp.concatenate([sos, inp], axis=1)
        deter = self.z_tfm(inp, mask=self._init_mask(), deterministic=deterministic)
        deter = deter[:, self.config.n_cond:]

        deter = jax.vmap(self.z_unproj, 1, 1)(deter)

        return deter

    def __call__(self, video, actions, deterministic=False,
                 dropout_actions=None):
        """Compute the training losses for a batch of videos.

        Args:
            video: BTCHW video (per the original comment), encoded with
                ``vq_fns['encode']``.
            actions: BT integer actions, or ``None`` when
                ``config.use_actions`` is False.
            deterministic: forwarded to the transformer and MaskGit.
            dropout_actions: optional per-example (B,) boolean mask of
                examples whose actions are replaced by -1. When ``None``, it is
                drawn as Bernoulli(0.5); it is only applied if
                ``config.dropout_actions`` is set.

        Returns:
            Dict with the keys listed in ``metrics``.
        """
        # video: BTCHW, actions: BT
        if not self.config.use_actions:
            if actions is None:
                actions = jnp.zeros(video.shape[:2], dtype=jnp.int32)
            else:
                actions = jnp.zeros_like(actions)

        if dropout_actions is None:
            dropout_actions = jax.random.bernoulli(self.make_rng('sample'), p=0.5,
                                                   shape=(video.shape[0],))  # B

        if self.config.dropout_actions:
            actions = jnp.where(dropout_actions[:, None], -1, actions)

        actions = self.process_actions(actions)
        _, encodings = self.vq_fns['encode'](video)

        cond, vq_output = self.encode_obs(encodings)
        z_embeddings, z_codes = vq_output['embeddings'], vq_output['encodings']

        deter = self.predict_dynamics(z_embeddings, actions, cond, deterministic=deterministic)

        encodings = encodings[:, self.config.n_cond:]
        labels = jax.nn.one_hot(encodings, num_classes=self.vqvae.n_codes)
        labels = labels * 0.99 + 0.01 / self.vqvae.n_codes  # label smoothing

        assert deter.shape[1] == labels.shape[1]
        if self.config.decode_fraction is not None and self.config.decode_fraction < 1.0:
            # Only compute the losses on a random subset of time steps.
            n_frames = deter.shape[1]
            n_sample = max(1, int(self.config.decode_fraction * n_frames))
            # The bound must be the length of the arrays being indexed
            # (T - n_cond), not video.shape[1] (T). JAX clamps out-of-range
            # gather indices, so a larger bound silently over-samples the last
            # frame.
            idxs = jax.random.randint(self.make_rng('sample'),
                                      [n_sample],
                                      0, n_frames, dtype=jnp.int32)
            deter = deter[:, idxs]
            z_embeddings = z_embeddings[:, idxs]
            z_codes = z_codes[:, idxs]
            labels = labels[:, idxs]

        z_logits, z_labels, z_mask = self.z_git(z_codes, deter, deterministic)
        trans_loss = optax.softmax_cross_entropy(z_logits, z_labels)
        trans_loss = (trans_loss * z_mask).sum() / z_mask.sum()
        # Scale the masked per-position mean by the number of latent positions per frame.
        trans_loss = trans_loss * np.prod(self.z_shape) * self.config.trans_weight

        recon_logits = jax.vmap(self.decoder, 1, 1)(deter, z_embeddings)
        recon_loss = optax.softmax_cross_entropy(recon_logits, labels)
        # Sum over the last two (presumably spatial) axes, mean over the rest.
        recon_loss = recon_loss.sum(axis=(-2, -1)).mean()

        loss = recon_loss + trans_loss + vq_output['commitment_loss'] + \
            vq_output['codebook_loss']

        out = dict(loss=loss, recon_loss=recon_loss, trans_loss=trans_loss,
                   commitment_loss=vq_output['commitment_loss'],
                   codebook_loss=vq_output['codebook_loss'],
                   perplexity=vq_output['perplexity'],
                   logits_min=recon_logits.min(),
                   logits_max=recon_logits.max(),
                   logits_mean=recon_logits.mean())
        return out
  
