from typing import Any, Dict, Callable, Optional
from functools import partial
from flax import linen as nn
from jax import numpy as jnp
import jax
import distrax
import optax
import numpy as np
from ..utils import flatten


# Leaky ReLU using TensorFlow's default negative slope (0.2).
leaky_relu = partial(nn.leaky_relu, negative_slope=0.2)
# Module-level handle to the CWVAE read by the pmapped helpers `_sample` and
# `_decode`; `sample()` rebinds it before dispatching any work.
model: Optional[Any] = None

def _sample(state, video, actions, rng):
    """Run ``model.open_loop_unroll`` with ``state.params`` on one device shard.

    Reads the module-level ``model``; ``sample`` maps this function over the
    local devices with ``jax.pmap``.
    """
    return model.apply(
        {'params': state.params},
        video,
        actions,
        rngs={'sample': rng},
        method=model.open_loop_unroll,
    )


def _decode(x):
    """Decode a batch of single-frame latents with the global model's VQ decoder."""
    # The VQ decoder takes a time axis: insert a length-1 one, then drop it.
    with_time = jnp.expand_dims(x, axis=1)
    return model.vq_fns['decode'](with_time)[:, 0]


def sample(model_sample, state, video, actions, seed=0, return_cond_frames=True, log_output=False, return_real=True):
    """Generate open-loop predictions from ``model_sample`` on all local devices.

    Side effects: rebinds the module-level ``model`` (read by ``_sample`` and
    ``_decode``) to ``model_sample`` and sets
    ``model_sample.config.dropout_actions = False``.

    Arguments:
        model_sample: CWVAE module whose ``open_loop_unroll`` produces samples.
        state: object exposing ``params``; it is fed to ``jax.pmap`` and is
            therefore presumably replicated across local devices (TODO confirm).
        video: array whose leading axes are (devices N, batch B, time T).
        actions: integer actions of shape (N, B, T), or None. Replaced by
            zeros when ``config.use_actions`` is False.
        seed: PRNG seed; one key per local device is split from it.
        return_cond_frames: if True, prepend the first ``open_loop_ctx`` real
            frames to the samples along the time axis; otherwise drop those
            frames from the returned real video instead.
        log_output: unused.
        return_real: if True, return ``(samples, video)``; otherwise only
            ``samples``.

    Returns:
        ``(samples, video)`` or ``samples``. In non-'pixel' mode the samples
        are decoded with ``vq_fns['decode']`` and clipped to [-1, 1].
    """
    global model
    model = model_sample
    model.config.dropout_actions = False

    rngs = jax.random.split(jax.random.PRNGKey(seed), jax.local_device_count())

    if not model.config.use_actions:
        # Actions are ignored by the model but the unroll still needs an array.
        if actions is None:
            actions = jnp.zeros(video.shape[:3], dtype=jnp.int32)
        else:
            actions = jnp.zeros_like(actions)

    if model.config.mode == 'pixel':
        embeddings = video
    else:
        embeddings, _ = jax.pmap(model.vq_fns['encode'], axis_name='batch')(video)
    samples = jax.pmap(_sample)(state, embeddings, actions, rngs)

    if model.config.mode == 'pixel':
        def decode(x):
            return x
    else:
        def decode(samples):
            # samples: NBTHW (per-device batches of latent frames)
            N, B, T = samples.shape[:3]
            samples = jax.device_get(samples)
            samples = np.reshape(samples, (-1, *samples.shape[3:]))

            # Decode at most 64 frames per step. The step must be a multiple
            # of N so that every chunk, including the last one, can be
            # reshaped to (N, -1, ...) for pmap; N * B * T is itself a
            # multiple of N. For N dividing 64 this is exactly 64.
            step = max(N, 64 - 64 % N)
            recons = []
            for i in range(0, N * B * T, step):
                inp = samples[i:i + step]
                inp = np.reshape(inp, (N, -1, *inp.shape[1:]))
                recon = jax.device_get(jax.pmap(_decode)(inp))
                recons.append(np.reshape(recon, (-1, *recon.shape[2:])))
            recons = np.concatenate(recons, axis=0)
            recons = np.reshape(recons, (N, B, T, *recons.shape[1:]))
            return np.clip(recons, -1, 1)  # (N, B, T, *decoded frame shape)
    samples = decode(samples)

    # A spatial size of 16 presumably marks latent frames that still need
    # decoding (see the NBTHW latents above) — TODO confirm.
    if video.shape[3] == 16:
        video = decode(video)

    ctx = model.config.open_loop_ctx
    if return_cond_frames:
        samples = np.concatenate((video[:, :, :ctx], samples), axis=2)
    else:
        video = video[:, :, ctx:]
    print('done', samples.shape[0])

    if return_real:
        return samples, video
    return samples


class CWVAE(nn.Module):
    """Clockwork VAE: a hierarchical RSSM over encoded video frames.

    In 'pixel' mode frames are reconstructed directly under a Gaussian
    likelihood; in 'vq' mode the decoder predicts logits over the discrete
    codes produced by ``vq_fns['encode']``.
    """
    config: Any
    vq_fns: Dict[str, Callable]  # 'encode' is used here; 'decode' by sampling helpers
    vqvae: Any  # only ``n_codes`` is read, and only in 'vq' mode
    dtype: Optional[Any] = jnp.float32  # dtype of the action embedding layer

    @property
    def metrics(self):
        """Names of the scalar loss entries returned by ``__call__``."""
        return ['loss', 'recon_loss', 'kl_loss']

    def setup(self):
        self.action_embeds = nn.Dense(self.config.action_embed_dim, use_bias=False, dtype=self.dtype)
        self.encoder = Encoder(self.config)
        self.model = HierRSSM(self.config)
        # 'vq' decodes to per-position code logits; otherwise to 3 channels.
        out_filters = self.vqvae.n_codes if self.config.mode == 'vq' else 3

        if self.config.decoder_type == 'simple':
            self.decoder = Decoder(self.config, out_filters)
        elif self.config.decoder_type == 'resnet':
            self.decoder = ResNetDecoder(self.config, out_filters)
        else:
            # Fail loudly here instead of with an AttributeError on first use.
            raise ValueError(f'Unknown decoder_type: {self.config.decoder_type!r}')

    def decode(self, predictions):
        """Decode the 'output' entry of the bottom level (index 0) of ``predictions``."""
        bottom_layer_output = predictions[0]['output']
        return self.decoder(bottom_layer_output)

    def process_actions(self, actions):
        """One-hot encode integer actions (``config.action_dim`` classes) and embed them."""
        actions = jax.nn.one_hot(actions, num_classes=self.config.action_dim)
        actions = self.action_embeds(actions)
        return actions

    def __call__(self, video, actions, deterministic=True, dropout_actions=None):
        """Compute the training loss for a batch of videos.

        Arguments:
            video: frames with leading (batch, time) axes. In 'vq' mode they
                are first passed through ``vq_fns['encode']``.
            actions: integer actions of shape (batch, time), or None when
                ``config.use_actions`` is False (zeros are used instead).
            deterministic: unused.
            dropout_actions: optional boolean mask of shape (batch,). Where
                True, that sample's actions are dropped if
                ``config.dropout_actions`` is set. When None it is drawn with
                p=0.5 from the 'sample' RNG.

        Returns:
            dict with scalar 'loss', 'kl_loss' and 'recon_loss', plus 'recon'
            (tanh'd frames in 'pixel' mode, argmax code indices otherwise).
        """
        obs = video
        if not self.config.use_actions:
            if actions is None:
                actions = jnp.zeros(video.shape[:2], dtype=jnp.int32)
            else:
                actions = jnp.zeros_like(actions)

        if dropout_actions is None:
            dropout_actions = jax.random.bernoulli(self.make_rng('sample'), p=0.5,
                                                   shape=(video.shape[0],))  # B

        if self.config.dropout_actions:
            # one_hot(-1) is an all-zero vector, i.e. "no action information".
            actions = jnp.where(dropout_actions[:, None], -1, actions)

        actions = self.process_actions(actions)
        if self.config.mode == 'vq':
            obs, encodings = self.vq_fns['encode'](video)

        _, priors, posteriors = self.model(self.encoder(obs), actions)

        dec_inp = posteriors[0]['output']
        if self.config.decode_fraction is not None:
            # Decode only a random subset of timesteps (sampled with
            # replacement); targets are subsampled with the same indices.
            n_sample = max(1, int(self.config.decode_fraction * video.shape[1]))
            idxs = jax.random.randint(self.make_rng('sample'),
                                      [n_sample],
                                      0, video.shape[1], dtype=jnp.int32)

            if self.config.mode == 'pixel':
                obs = obs[:, idxs]
            else:
                encodings = encodings[:, idxs]
            dec_inp = dec_inp[:, idxs]

        output = self.decoder(dec_inp)

        if self.config.mode == 'pixel':
            output = nn.tanh(output)
            recon = output
            output = distrax.Independent(distrax.Normal(output, self.config.dec_stddev), 3)
            # log_prob sums the 3 per-frame event dims, leaving (batch, time).
            # Average over both so the loss is a scalar, normalized like the
            # 'vq' branch (per-frame sum, mean over batch and time). Averaging
            # over axis 0 alone left a (time,) vector.
            nll_term = -jnp.mean(output.log_prob(obs))
        else:
            recon = jnp.argmax(output, axis=-1)
            labels = jax.nn.one_hot(encodings, self.vqvae.n_codes)
            # Label smoothing: 0.99 on the target code plus 0.01 spread uniformly.
            labels = labels * 0.99 + 0.01 / self.vqvae.n_codes
            output = jnp.clip(output, a_max=50.)
            nll_term = optax.softmax_cross_entropy(output, labels)
            nll_term = flatten(nll_term, 2).sum(-1).mean()

        priors = [distrax.MultivariateNormalDiag(d['mean'], d['stddev'])
                  for d in priors]
        posteriors = [distrax.MultivariateNormalDiag(d['mean'], d['stddev'])
                      for d in posteriors]
        kls = [jnp.mean(posterior.kl_divergence(prior))
               for prior, posterior in zip(priors, posteriors)]
        kl_term = sum(kls)
        # Clamp the summed KL from below at 5 (free-nats style).
        kl_term = jnp.clip(kl_term, a_min=5.)
        metrics = dict(loss=nll_term + kl_term,
                       kl_loss=kl_term, recon_loss=nll_term, recon=recon)
        return metrics

    def open_loop_unroll(self, obs, actions):
        """Predict the frames after the first ``config.open_loop_ctx`` steps.

        Returns tanh'd frames in 'pixel' mode, argmax code indices otherwise.
        """
        actions = self.process_actions(actions)
        pred = self.decode(self.model.open_loop_unroll(self.encoder(obs), actions))
        if self.config.mode == 'pixel':
            pred = nn.tanh(pred)
        else:
            pred = jnp.argmax(pred, axis=-1)
        return pred

class HierRSSM(nn.Module):
    """Stack of ``c.levels`` RSSM cells, unrolled from the top level down.

    Level 0 is the bottom level and is the only one that receives actions.
    Each lower level receives, as context, the posterior outputs of the level
    above repeated ``c.tmp_abs_factor`` times along the time axis.
    """
    c: Any

    @nn.compact
    def __call__(self, inputs, actions, use_observations=None, initial_state=None):
        """
        Used to unroll a list of recurrent cells.

        Arguments:
            inputs : list of encoded observations, one per level (index 0 is
                the bottom level); time is axis 1.
                Number of timesteps at every level in 'inputs' is the number of steps to be unrolled.
            actions : action embeddings for the bottom level only, indexed as
                actions[:, t].
            use_observations : None or list[bool], one per level. A False
                entry makes that level's cell skip the posterior and use its
                prior instead. None means True for every level.
            initial_state : list of cell states, one per level, or None.
                None (or a None entry) starts that level from
                ``RSSMCell.zero_state``.

        Returns:
            (last_states, priors, posteriors): per-level lists ordered bottom
            level first. Each prior/posterior is a dict of arrays stacked
            along axis 1 (time).
        """
        if use_observations is None:
            use_observations = self.c.levels * [True]
        if initial_state is None:
            initial_state = self.c.levels * [None]

        # Cells are constructed up front in level order; under Flax auto-naming
        # this order decides which parameters belong to which level, so keep it.
        cells = [RSSMCell(self.c) for _ in range(self.c.levels)]

        priors = []
        posteriors = []
        last_states = []
        is_top_level = True
        # Iterate from the top level (highest index) down to level 0.
        for level, (cell, use_obs, obs_inputs, initial) in reversed(list(
                enumerate(
                    zip(cells, use_observations, inputs, initial_state)))):

            print(f"Input shape in CWVAE level {level}: {obs_inputs.shape}")

            if is_top_level:
                # Feeding in zeros as context to the top level:
                context = jnp.zeros(obs_inputs.shape[:2] + (cell.state_size["output"],))
                is_top_level = False
            else:
                # Tiling context from previous layer in time by tmp_abs_factor:
                context = jnp.expand_dims(context, axis=2)
                context = jnp.tile(context, [1, 1, self.c.tmp_abs_factor]
                                   + (len(context.shape) - 3) * [1])
                s = context.shape
                # Merge the (time, repeat) axes back into a single time axis.
                context = context.reshape((s[0], s[1] * s[2]) + s[3:])
                # Pruning timesteps to match inputs:
                context = context[:, :obs_inputs.shape[1]]

            # Unroll of RNN cell.
            initial = cell.zero_state(obs_inputs.shape[0]
                                      ) if initial is None else initial
            if level == 0:
                inps = (obs_inputs, context, actions)
            else:
                inps = (obs_inputs, context)

            # Python loop over time, so it is unrolled when traced. Note that
            # prs[0] below requires at least one timestep at every level.
            last_state = initial
            prs, pos = [], []
            for t in range(inps[0].shape[1]):
                last_state, (pr, po) = cell(last_state, [inp[:, t] for inp in inps], use_obs=use_obs)
                prs.append(pr)
                pos.append(po)
            prior = {k: jnp.stack([o[k] for o in prs], axis=1) for k in prs[0].keys()}
            posterior = {k: jnp.stack([o[k] for o in pos], axis=1) for k in pos[0].keys()}
            # This level's posterior output becomes the next (lower) level's context.
            context = posterior["output"]

            # Prepend so the returned lists end up ordered bottom level first.
            last_states.insert(0, last_state)
            priors.insert(0, prior)
            posteriors.insert(0, posterior)
        return last_states, priors, posteriors

    def open_loop_unroll(self, inputs, actions):
        """Condition on the first ``c.open_loop_ctx`` steps, then predict the rest.

        The context steps are unrolled with observations. The remaining steps
        are unrolled from the final context states with zeroed inputs and
        ``use_observations`` False at every level, so the returned per-level
        dicts (bottom level first) are the priors of that second unroll.
        """
        # The context length must split evenly into every level's time scale.
        assert self.c.open_loop_ctx % (
                self.c.tmp_abs_factor ** (self.c.levels - 1)) == 0, \
            f"Incompatible open-loop context length {self.c.open_loop_ctx} and " \
            f"temporal abstraction factor {self.c.tmp_abs_factor} for levels {self.c.levels}."
        # Number of context steps at each level's (coarser) time scale.
        ctx_lens = [self.c.open_loop_ctx // self.c.tmp_abs_factor ** level
                    for level in range(self.c.levels)]
        pre_inputs, post_inputs = zip(*[
            (input[:, :ctx_len], jnp.zeros_like(input[:, ctx_len:]))
            for input, ctx_len in zip(inputs, ctx_lens)])

        last_states, _, _ = self(
            pre_inputs, actions[:, :self.c.open_loop_ctx], use_observations=None)
        _, predictions, _ = self(
            post_inputs, actions[:, self.c.open_loop_ctx:], 
            use_observations=self.c.levels * [False],
            initial_state=last_states)
        return predictions


class Encoder(nn.Module):
    """
    Multi-level video encoder.

    Level 0: every frame passes through the strided conv stack (leaky ReLU
    after each conv) and is flattened to one feature vector per timestep.
    Levels >= 1: a dense stack is applied to the running (un-merged) features
    carried over from the level below; then groups of
    ``tmp_abs_factor ** level`` consecutive timesteps are summed, with the end
    of the time axis zero-padded to a multiple of the group size.
    """
    c: Any

    @nn.compact
    def __call__(self, obs):
        """
        Arguments:
            obs : Tensor
                Un-flattened observations (videos) of shape (batch size, timesteps, height, width, channels)

        Returns:
            List of ``c.levels`` tensors, level 0 first, each of shape
            (batch size, timesteps at that level, features).
        """
        batch_and_time = obs.shape[:2]

        def merge_in_time(feats, group):
            # Zero-pad time to a multiple of `group`, then sum within each group.
            pad = -feats.shape[1] % group
            feats = jnp.pad(feats, ((0, 0), (0, pad), (0, 0)))
            feats = feats.reshape((feats.shape[0], -1, group, feats.shape[2]))
            return jnp.sum(feats, axis=2)

        # Fold time into the batch axis so the convolutions see single frames.
        hidden = obs.reshape((-1,) + obs.shape[2:])
        for kernel, filters in zip(self.c.enc_cnn_kernels, self.c.enc_cnn_filters):
            conv = nn.Conv(filters, [kernel, kernel], strides=[2, 2])
            hidden = leaky_relu(conv(hidden))
        hidden = hidden.reshape(batch_and_time + (-1,))
        outputs = [hidden]
        print(f"Input shape at level 0: {hidden.shape}")

        feat_size = hidden.shape[-1]
        n_hidden_dense = self.c.enc_dense_layers - 1
        for level in range(1, self.c.levels):
            for _ in range(n_hidden_dense):
                hidden = nn.relu(nn.Dense(self.c.enc_dense_embed_size)(hidden))
            if self.c.enc_dense_layers > 0:
                hidden = nn.Dense(feat_size)(hidden)
            merged = merge_in_time(hidden, self.c.tmp_abs_factor ** level)
            outputs.append(merged)
            print(f"Input shape at level {level}: {merged.shape}")

        return outputs


class ResNetDecoder(nn.Module):
    """Residual decoder from per-timestep features to a 16x16 output grid.

    Features are projected by a Dense layer to a 4x4 map, refined by
    ``c.dec_blocks`` residual blocks per stage, and upsampled 2x (nearest
    neighbour) after each of the two stages.
    """
    c: Any
    out_filters: int

    @nn.compact
    def __call__(self, features):
        """
        Arguments:
            features: array with leading (batch, time) axes, presumably
                (batch, time, feature_dim) — TODO confirm.

        Returns:
            Array of shape (batch, time, 16, 16, out_filters).
        """
        # Number of 2x upsampling stages needed to go from 4x4 to 16x16.
        stages = int(np.log2(16)) - 2
        assert len(self.c.dec_depths) == stages, f'{len(self.c.dec_depths)} != {stages}'

        lead = features.shape[:2]
        base_depth = self.c.dec_depths[0]
        h = jnp.reshape(features, (-1,) + features.shape[2:])
        h = nn.Dense(16 * base_depth)(h)
        h = jnp.reshape(h, (-1, 4, 4, base_depth))
        for stage in range(stages):
            depth = self.c.dec_depths[stage]
            for _ in range(self.c.dec_blocks):
                h = self._block(h, depth)
            n, height, width, ch = h.shape
            h = jax.image.resize(h, (n, 2 * height, 2 * width, ch),
                                 jax.image.ResizeMethod.NEAREST)
        h = self._block(h, self.c.dec_depths[-1])
        h = nn.Dense(self.out_filters)(h)
        return jnp.reshape(h, lead + h.shape[1:])

    def _block(self, x, depth):
        """Pre-norm residual block: shortcut + 0.1 * 2 x (GroupNorm -> ELU -> Conv3x3).

        A bias-free 1x1 conv projects the shortcut when the channel count
        differs from ``depth``.
        """
        if x.shape[-1] == depth:
            shortcut = x
        else:
            shortcut = nn.Conv(depth, [1, 1], use_bias=False)(x)
        residual = x
        for _ in range(2):
            residual = nn.elu(nn.GroupNorm()(residual))
            residual = nn.Conv(depth, [3, 3])(residual)
        return shortcut + 0.1 * residual


class Decoder(nn.Module):
    """Transposed-conv decoder from per-timestep states to image-shaped outputs."""
    c: Any
    out_filters: int

    @nn.compact
    def __call__(self, bottom_layer_output):
        """
        Arguments:
            bottom_layer_output : Tensor
                State tensor of shape (batch_size, timesteps, feature_dim)

        Returns:
            Tensor of shape (batch_size, timesteps, H, W, out_filters). Each
            state starts as a 1x1 map and every transposed conv has stride 2.
            H and W therefore depend on ``c.dec_cnn_kernels`` and the layer
            padding (presumably Flax's default, which would give
            2 ** len(dec_cnn_kernels)) — confirm.
        """
        n_layers = len(self.c.dec_cnn_kernels)
        h = nn.Dense(self.c.dec_init_filters)(bottom_layer_output)
        # Fold batch and time together and view each state as a 1x1 image.
        h = jnp.reshape(h, (-1, 1, 1, h.shape[-1]))
        layer_specs = zip(self.c.dec_cnn_kernels, self.c.dec_cnn_filters)
        for idx, (kernel, filters) in enumerate(layer_specs):
            h = nn.ConvTranspose(filters, [kernel, kernel], strides=(2, 2))(h)
            # No activation after the last transposed conv.
            if idx < n_layers - 1:
                h = leaky_relu(h)
        h = nn.Dense(self.out_filters)(h)
        return h.reshape(bottom_layer_output.shape[:2] + h.shape[1:])

        
class RSSMPrior(nn.Module):
    """Transition (prior) step of an RSSM cell.

    Combines the previous stochastic sample, the context from the level above
    and (optionally) actions, advances a GRU, and returns a diagonal Gaussian
    (mean, stddev) together with one sample drawn from the 'sample' RNG.
    """
    c: Any

    @nn.compact
    def __call__(self, prev_state, context, actions=None):
        features = [prev_state["sample"], context]
        if actions is not None:
            features.append(actions)
        x = jnp.concatenate(features, -1)

        embed_size = self.c.cell_embed_size
        stoch_size = self.c.cell_stoch_size
        x = nn.relu(nn.Dense(embed_size)(x))
        det_state, det_out = nn.GRUCell()(prev_state["det_state"], x)
        x = nn.relu(nn.Dense(embed_size)(det_out))
        mean = nn.Dense(stoch_size)(x)
        # NOTE(review): 0.54 is added to the Dense *input*, not to the logits
        # before softplus (the common softplus(logits + 0.54) ~= 1 init trick).
        # Kept as-is for compatibility with trained parameters; confirm intent.
        stddev = nn.softplus(nn.Dense(stoch_size)(x + .54)) + self.c.cell_min_stddev
        sample = distrax.MultivariateNormalDiag(mean, stddev).sample(
            seed=self.make_rng('sample'))
        return dict(mean=mean, stddev=stddev, sample=sample,
                    det_out=det_out, det_state=det_state,
                    output=jnp.concatenate([sample, det_out], -1))


class RSSMPosterior(nn.Module):
    """Observation (posterior) update of an RSSM cell.

    Conditions on the prior's deterministic output and the current encoded
    observation. The deterministic state is passed through from the prior
    unchanged; only the stochastic part is re-estimated.
    """
    c: Any

    @nn.compact
    def __call__(self, prior, obs_inputs):
        embed_size = self.c.cell_embed_size
        stoch_size = self.c.cell_stoch_size
        det_out = prior["det_out"]

        x = jnp.concatenate([det_out, obs_inputs], -1)
        x = nn.relu(nn.Dense(embed_size)(x))
        x = nn.relu(nn.Dense(embed_size)(x))
        mean = nn.Dense(stoch_size)(x)
        # NOTE(review): 0.54 is added to the Dense *input*, not to the logits
        # before softplus; kept as-is for parameter compatibility — confirm intent.
        stddev = nn.softplus(nn.Dense(stoch_size)(x + .54)) + self.c.cell_min_stddev
        sample = distrax.MultivariateNormalDiag(mean, stddev).sample(
            seed=self.make_rng('sample'))
        return dict(mean=mean, stddev=stddev, sample=sample,
                    det_out=det_out, det_state=prior["det_state"],
                    output=jnp.concatenate([sample, det_out], -1))


class RSSMCell(nn.Module):
    """One RSSM step: a prior transition, optionally refined by an observation."""
    c: Any

    @property
    def state_size(self):
        """Feature size of every entry of the cell-state dict."""
        stoch, deter = self.c.cell_stoch_size, self.c.cell_deter_size
        return {
            'mean': stoch,
            'stddev': stoch,
            'sample': stoch,
            'det_out': deter,
            'det_state': deter,
            'output': stoch + deter,
        }

    def zero_state(self, batch_size, dtype=jnp.float32):
        """All-zero state dict whose entries have shape (batch_size, size)."""
        state = {}
        for key, size in self.state_size.items():
            state[key] = jnp.zeros((batch_size, size), dtype=dtype)
        return state

    @nn.compact
    def __call__(self, state, inputs, use_obs):
        """Advance the cell by one step.

        Arguments:
            state: previous cell-state dict (see ``state_size``).
            inputs: [obs_input, context] or [obs_input, context, actions].
            use_obs: when falsy, the posterior network is skipped and the
                prior is used in its place.

        Returns:
            (new_state, (prior, posterior)), where new_state is the posterior.
        """
        if len(inputs) == 3:
            obs_input, context, actions = inputs
        else:
            obs_input, context = inputs
            actions = None
        prior = RSSMPrior(self.c)(state, context, actions)
        if not use_obs:
            return prior, (prior, prior)
        posterior = RSSMPosterior(self.c)(prior, obs_input)
        return posterior, (prior, posterior)
