import functools
from typing import Any, Callable, Sequence, Tuple, Dict, Optional

from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from ..utils import flatten

# Type alias for fields that hold a Flax module constructor, e.g. the
# `functools.partial(nn.Conv, ...)` / `functools.partial(nn.GroupNorm, ...)`
# factories passed to EncoderBlock / DecoderBlock.
ModuleDef = Any
# Module-level handle to the model, assigned by `sample()` so the pmapped
# helpers `_sample` and `_decode` can reach it without receiving it as a
# (non-array) pmap argument.
model = None


def _sample(state, video, actions, rng):
    """Per-device body for `jax.pmap`: run `model.sample` with `state.params`.

    Splits `rng` into two keys, feeding the first to the 'sample' stream and
    the second to the 'rng' stream used by the model's Gaussian LSTMs.
    Relies on the module-level `model` having been set by `sample()`.
    """
    sample_key, lstm_key = jax.random.split(rng)
    return model.apply(
        {'params': state.params},
        video,
        actions,
        rngs={'sample': sample_key, 'rng': lstm_key},
        method=model.sample,
    )

def _decode(x):
    """Decode a batch of single latent frames with `model.vq_fns['decode']`.

    A length-1 time axis is inserted at position 1 before decoding and
    removed again afterwards. Relies on the module-level `model`.
    """
    with_time_axis = x[:, None]
    decoded = model.vq_fns['decode'](with_time_axis)
    return decoded[:, 0]

def _decode_latents(latents, decode_fn, chunk_size=64):
    """Decode latent frames on all local devices, a chunk at a time.

    Args:
      latents: array whose first three axes are (devices N, batch B, time T);
        the remaining axes are the per-frame latent dims.
      decode_fn: per-device function mapping (n, *frame_dims) -> (n, *out_dims);
        it is wrapped with `jax.pmap`.
      chunk_size: target number of frames per `pmap` call. It is rounded down
        to a multiple of N (minimum N) so every chunk, including the last one,
        splits evenly across the N devices.

    Returns:
      numpy array of shape (N, B, T, *out_dims), clipped to [-1, 1].
    """
    n_dev, batch, steps = latents.shape[:3]
    flat = jax.device_get(latents)
    flat = np.reshape(flat, (-1, *flat.shape[3:]))

    # N divides N*B*T, so a step that is a multiple of N keeps every chunk
    # (the remainder included) reshapeable to (N, -1, ...).
    step = n_dev * max(1, chunk_size // n_dev)
    decode_sharded = jax.pmap(decode_fn)

    recons = []
    for start in range(0, flat.shape[0], step):
        chunk = flat[start:start + step]
        chunk = np.reshape(chunk, (n_dev, -1, *chunk.shape[1:]))
        recon = jax.device_get(decode_sharded(chunk))
        recons.append(np.reshape(recon, (-1, *recon.shape[2:])))
    recons = np.concatenate(recons, axis=0)
    recons = np.reshape(recons, (n_dev, batch, steps, *recons.shape[1:]))
    return np.clip(recons, -1, 1)


def sample(model_sample, state, video, actions, seed=0, return_cond_frames=True, log_output=False, return_real=True):
    """Generate predicted videos with `model_sample` on all local devices.

    Args:
      model_sample: the FitVid model; stored in the module-level `model` so the
        pmapped helpers can use it.
      state: object with `.params`, passed through `jax.pmap` (presumably
        replicated across devices -- confirm with caller).
      video: frames with a leading device axis: (N, B, T, ...). In 'pixel'
        mode these are fed to the model directly; otherwise they are encoded
        with `model.vq_fns['encode']` first.
      actions: (N, B, T) integer actions, or None when
        `model.config.use_actions` is False (zeros are substituted).
      seed: PRNG seed; split into one key per local device.
      return_cond_frames: if True, prepend the first `open_loop_ctx` frames of
        `video` to the samples; otherwise drop them from the returned `video`.
      log_output: currently unused.
      return_real: if True, also return the (possibly decoded/trimmed) video.

    Returns:
      `samples`, or `(samples, video)` when `return_real` is True. In VQ mode
      samples are decoded to pixels and clipped to [-1, 1].
    """
    global model
    model = model_sample

    rngs = jax.random.split(jax.random.PRNGKey(seed), jax.local_device_count())

    if not model.config.use_actions:
        if actions is None:
            actions = jnp.zeros(video.shape[:3], dtype=jnp.int32)
        else:
            actions = jnp.zeros_like(actions)

    pixel_mode = model.config.mode == 'pixel'
    if pixel_mode:
        embeddings = video
    else:
        embeddings, _ = jax.pmap(model.vq_fns['encode'], axis_name='batch')(video)
    samples = jax.pmap(_sample)(state, embeddings, actions, rngs)

    def decode(x):
        # Pixel-mode outputs are already frames; VQ outputs are code indices.
        if pixel_mode:
            return x
        return _decode_latents(x, _decode)

    samples = decode(samples)

    # Heuristic: an axis-3 size of 16 is taken to mean `video` holds latent
    # codes (presumably (N, B, T, 16, 16)) rather than pixels.
    if video.shape[3] == 16:
        video = decode(video)

    ctx = model.config.open_loop_ctx
    if return_cond_frames:
        samples = np.concatenate((video[:, :, :ctx], samples), axis=2)
    else:
        video = video[:, :, ctx:]
    print('done', samples.shape[0])
    if return_real:
        return samples, video
    return samples
    

class FitVid(nn.Module):
    """FitVid video predictor.

    Encodes each frame to a vector and rolls three Gaussian LSTMs (frame
    predictor, posterior, prior) forward in time. The predicted vectors are
    decoded using skip connections taken from the last context frame. Works
    either on pixels (`config.mode == 'pixel'`) or on VQ codes
    (`config.mode == 'vq'`).

    Attributes:
      config: hyper-parameters. Fields read here are `action_embed_dim`,
        `action_dim`, `use_actions`, `dropout_actions`, `g_dim`, `z_dim`,
        `rnn_size`, `filters`, `mode`, `open_loop_ctx`, `decode_fraction`
        and `beta`.
      vq_fns: VQ helper functions; this class calls `'encode'` and `'lookup'`.
      vqvae: VQ-VAE object; only `vqvae.n_codes` is read.
      dtype: dtype of the action embedding layer.
    """
    config: Any
    vq_fns: Dict[str, Callable]
    vqvae: Any
    dtype: Optional[Any] = jnp.float32

    def setup(self):
        self.action_embeds = nn.Dense(self.config.action_embed_dim,
                                      use_bias=False, dtype=self.dtype)
        self.encoder = ModularEncoder(
                encoder_block=functools.partial(EncoderBlock, downsample=False),
                down_block=functools.partial(EncoderBlock, downsample=True),
                stage_sizes=[1, 1, 1, 1],
                num_classes=self.config.g_dim,
                num_filters=self.config.filters)
        self.decoder = ModularDecoder(
                decoder_block=functools.partial(DecoderBlock, upsample=False),
                up_block=functools.partial(DecoderBlock, upsample=True),
                stage_sizes=[1, 1, 1, 1],
                first_block_shape=(2, 2, 1024),
                skip_type='residual',
                num_filters=self.config.filters,
                out_filters=3 if self.config.mode == 'pixel' else self.vqvae.n_codes)
        self.frame_predictor = MultiGaussianLSTM(
                hidden_size=self.config.rnn_size, output_size=self.config.g_dim, num_layers=2)
        self.posterior = MultiGaussianLSTM(
                hidden_size=self.config.rnn_size, output_size=self.config.z_dim, num_layers=1)
        self.prior = MultiGaussianLSTM(
                hidden_size=self.config.rnn_size, output_size=self.config.z_dim, num_layers=1)

    @property
    def metrics(self):
        """Names of the scalar metrics returned by `__call__`."""
        return ['loss', 'recon_loss', 'kl_loss']

    def get_input(self, hidden, action, z):
        """Concatenate frame encoding, (optional) action embedding and latent on axis 1."""
        inp = [hidden]
        if self.config.use_actions:
            inp += [action]
        inp += [z]
        return jnp.concatenate(inp, axis=1)

    def process_actions(self, actions):
        """One-hot encode integer actions and project them with a bias-free Dense layer."""
        actions = jax.nn.one_hot(actions, num_classes=self.config.action_dim)
        actions = self.action_embeds(actions)
        return actions

    def sample(self, video, actions):
        """Predict future frames autoregressively from the context frames.

        The first `config.open_loop_ctx` frames of `video` are used as context.
        After that, each step re-encodes the model's own previous prediction.
        Latents are drawn from the prior.

        Args:
          video: (B, T, ...) frames. In 'pixel' mode expected in [-1, 1], the
            same range `__call__` receives; in 'vq' mode, VQ embeddings.
          actions: (B, T) integer action ids.

        Returns:
          Predictions for steps `open_loop_ctx .. T-1` stacked on axis 1. In
          'vq' mode these are code indices (argmax over logits); in 'pixel'
          mode they are pixels in [-1, 1].
        """
        actions = self.process_actions(actions)
        if self.config.mode == 'pixel':
            # `__call__` trains the encoder on [0, 1] pixels, and the
            # predictions fed back below are sigmoids in [0, 1]. Map the
            # [-1, 1] context frames into that same range.
            video = video * 0.5 + 0.5

        batch_size, video_len = video.shape[0], video.shape[1]
        pred_s = self.frame_predictor.init_states(batch_size)
        post_s = self.posterior.init_states(batch_size)
        prior_s = self.prior.init_states(batch_size)

        hidden, skips = jax.vmap(self.encoder, 1, 1)(video)
        # Keep the skips of the last context frame only.
        skips = {k: v[:, self.config.open_loop_ctx - 1] for k, v in skips.items()}

        preds, x_pred = [], None
        for i in range(1, video_len):
            h, h_target = hidden[:, i - 1], hidden[:, i]
            if i > self.config.open_loop_ctx:
                # Past the context: encode the previous prediction instead.
                h = self.encoder(x_pred)[0]

            # The posterior's output is not used while sampling; only its state advances.
            post_s, _ = self.posterior(h_target, post_s)
            prior_s, (z_t, _, _) = self.prior(h, prior_s)

            inp = self.get_input(h, actions[:, i], z_t)
            pred_s, (_, h_pred, _) = self.frame_predictor(inp, pred_s)
            h_pred = nn.sigmoid(h_pred)
            x_pred = self.decoder(h_pred, skips)
            if i >= self.config.open_loop_ctx:
                preds.append(x_pred)

            # Turn the decoder output into an encoder input for the next step.
            if self.config.mode == 'vq':
                x_pred = jnp.argmax(x_pred, axis=-1)
                x_pred = self.vq_fns['lookup'](x_pred)
            else:
                x_pred = nn.sigmoid(x_pred)

        preds = jnp.stack(preds, axis=1)

        if self.config.mode == 'vq':
            preds = jnp.argmax(preds, axis=-1)
        else:
            preds = 2 * nn.sigmoid(preds) - 1

        return preds

    def __call__(self, video, actions, deterministic=True, dropout_actions=None):
        """Teacher-forced training pass returning losses and reconstructions.

        Args:
          video: (B, T, ...) frames; in 'pixel' mode expected in [-1, 1].
          actions: (B, T) integer action ids, or None when `config.use_actions`
            is False (zeros are substituted).
          deterministic: unused; kept for interface compatibility.
          dropout_actions: optional (B,) boolean mask of batch elements whose
            actions are dropped when `config.dropout_actions` is set. It is
            drawn with p=0.5 from the 'sample' rng when None.

        Returns:
          dict with keys:
          - 'loss': `recon_loss + beta * kl_loss`.
          - 'recon_loss': pixel MSE or label-smoothed code cross-entropy.
          - 'kl_loss': posterior-vs-prior KL, averaged over the T-1 steps.
          - 'recon': the first input frame followed by the decoded predictions.
            In 'pixel' mode these are pixels in [-1, 1]; in 'vq' mode they are
            code indices.
        """
        if not self.config.use_actions:
            if actions is None:
                actions = jnp.zeros(video.shape[:2], dtype=jnp.int32)
            else:
                actions = jnp.zeros_like(actions)

        if dropout_actions is None:
            dropout_actions = jax.random.bernoulli(self.make_rng('sample'), p=0.5,
                                                   shape=(video.shape[0],))  # B

        if self.config.dropout_actions:
            # jax.nn.one_hot maps -1 to an all-zero vector, so dropped actions
            # embed to zeros.
            actions = jnp.where(dropout_actions[:, None], -1, actions)

        actions = self.process_actions(actions)
        if self.config.mode == 'vq':
            embeddings, encodings = self.vq_fns['encode'](video)
            targets = encodings[:, 1:]
        else:
            # Model pixels live in [0, 1].
            video = video * 0.5 + 0.5
            embeddings = video
            targets = video[:, 1:]

        batch_size, video_len = video.shape[0], video.shape[1]
        pred_s = self.frame_predictor.init_states(batch_size)
        post_s = self.posterior.init_states(batch_size)
        prior_s = self.prior.init_states(batch_size)
        kl = functools.partial(kl_divergence, batch_size=batch_size)

        # Encode all frames at once (vmap over the time axis).
        hidden, skips = jax.vmap(self.encoder, 1, 1)(embeddings)
        # Keep the skips of the last context frame only.
        skips = {k: v[:, self.config.open_loop_ctx - 1] for k, v in skips.items()}

        kld = 0.0
        h_preds = []
        for i in range(1, video_len):
            h, h_target = hidden[:, i - 1], hidden[:, i]
            post_s, (z_t, mu, logvar) = self.posterior(h_target, post_s)
            prior_s, (_, prior_mu, prior_logvar) = self.prior(h, prior_s)

            inp = self.get_input(h, actions[:, i], z_t)
            pred_s, (_, h_pred, _) = self.frame_predictor(inp, pred_s)
            h_preds.append(nn.sigmoid(h_pred))
            kld += kl(mu, logvar, prior_mu, prior_logvar)
        kld /= video_len - 1
        h_preds = jnp.stack(h_preds, axis=1)

        if self.config.decode_fraction is not None:
            # Decode only a random subset of the T-1 predicted steps.
            n_sample = max(1, int(self.config.decode_fraction * video.shape[1]))
            # maxval is exclusive, so indices fall in [0, T-2].
            idxs = jax.random.randint(self.make_rng('sample'),
                                      [n_sample],
                                      0, video.shape[1] - 1, dtype=jnp.int32)
            h_preds = h_preds[:, idxs]
            targets = targets[:, idxs]
        preds = jax.vmap(self.decoder, (1, None), 1)(h_preds, skips)

        if self.config.mode == 'pixel':
            preds = nn.sigmoid(preds)
            mse = l2_loss(preds, targets)
            preds = 2 * preds - 1
            # `video` was shifted to [0, 1] above; map the first frame back to
            # [-1, 1] so the whole reconstruction shares one range.
            preds = jnp.concatenate([2 * video[:, :1] - 1, preds], axis=1)
        else:
            # Label-smoothed cross-entropy over VQ codes.
            labels = jax.nn.one_hot(targets, self.vqvae.n_codes)
            labels = labels * 0.99 + 0.01 / self.vqvae.n_codes
            preds = jnp.clip(preds, a_max=50.)
            mse = optax.softmax_cross_entropy(preds, labels)
            mse = flatten(mse, 2).sum(-1).mean()
            preds = jnp.argmax(preds, axis=-1)
            preds = jnp.concatenate([encodings[:, :1], preds], axis=1)

        loss = mse + kld * self.config.beta

        # Metrics
        metrics = dict(loss=loss, recon_loss=mse, kl_loss=kld, recon=preds)
        return metrics


def kl_divergence(mean1, logvar1, mean2, logvar2, batch_size):
    """KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) for diagonal Gaussians.

    The element-wise KL terms are summed over every element and divided by
    `batch_size`.
    """
    offset_log_ratio = -1.0 + logvar2 - logvar1
    variance_ratio = jnp.exp(logvar1 - logvar2)
    scaled_sq_diff = jnp.square(mean1 - mean2) * jnp.exp(-logvar2)
    per_element = 0.5 * (offset_log_ratio + variance_ratio + scaled_sq_diff)
    return jnp.sum(per_element) / batch_size


def l2_loss(model_logits, ground_truth):
    """Mean squared difference between `model_logits` and `ground_truth` over all elements."""
    residual = model_logits - ground_truth
    return jnp.square(residual).mean()


class MultiGaussianLSTM(nn.Module):
    """Multi layer lstm with Gaussian output.

    Embeds the input with a Dense layer and runs it through `num_layers`
    stacked LSTM cells. The top output is mapped to a diagonal Gaussian
    (mean, log-variance), and a sample is drawn with the reparameterization
    trick using the 'rng' stream.
    """
    num_layers: int = 2
    hidden_size: int = 10
    output_size: int = 10
    # A dtype object (e.g. jnp.float32), not an int.
    dtype: Any = jnp.float32

    def setup(self):
        self.embed = nn.Dense(self.hidden_size)
        self.mean = nn.Dense(self.output_size)
        self.logvar = nn.Dense(self.output_size)
        self.layers = [nn.recurrent.LSTMCell() for _ in range(self.num_layers)]

    def init_states(self, batch_size):
        """Return a list of `num_layers` zero-initialized LSTM carries for `batch_size`."""
        init_fn = functools.partial(nn.initializers.zeros, dtype=self.dtype)
        return [
            nn.recurrent.LSTMCell.initialize_carry(
                    self.make_rng('rng'),
                    (batch_size,),
                    self.hidden_size,
                    init_fn=init_fn)
            for _ in range(self.num_layers)
        ]

    def reparameterize(self, mu, logvar):
        """Sample mu + std * eps with std = exp(0.5 * logvar) and eps ~ N(0, I)."""
        std = jnp.exp(0.5 * logvar)
        epsilon = jax.random.normal(self.make_rng('rng'), std.shape)
        return mu + std * epsilon

    def __call__(self, x, states):
        """Advance every layer by one step.

        Args:
          x: input features for this step.
          states: sequence of per-layer LSTM carries (list or tuple).

        Returns:
          `(new_states, (z, mean, logvar))`, where `new_states` is a new list;
          the caller's `states` is not modified.
        """
        # Copy so the caller's container is never mutated (and tuples work too).
        states = list(states)
        x = self.embed(x)
        for i in range(self.num_layers):
            states[i], x = self.layers[i](states[i], x)
        mean = self.mean(x)
        logvar = self.logvar(x)
        z = self.reparameterize(mean, logvar)
        return states, (z, mean, logvar)


class SEBlock(nn.Module):
    """Applies Squeeze-and-Excitation.

    Mean-pools `x` over `axis` (axes 1 and 2 by default, presumably H and W of
    an NHWC input). The pooled features pass through a Dense bottleneck
    ('reduce', width max(C // 16, 4)), the activation, and a Dense expansion
    back to C ('expand'). `x` is then scaled channel-wise by the sigmoid of
    the result.
    """
    act: Callable = nn.relu
    axis: Tuple[int, int] = (1, 2)
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        channels = x.shape[-1]
        bottleneck = max(channels // 16, 4)
        squeezed = jnp.mean(x, axis=self.axis, keepdims=True)
        gate = nn.Dense(features=bottleneck, dtype=self.dtype, name='reduce')(squeezed)
        gate = self.act(gate)
        gate = nn.Dense(features=channels, dtype=self.dtype, name='expand')(gate)
        return nn.sigmoid(gate) * x


class EncoderBlock(nn.Module):
    """NVAE ResNet block.

    Pre-activation residual block: (norm -> act -> 3x3 conv) twice, followed
    by squeeze-and-excitation. The first conv uses stride 2 when `downsample`
    is set. When the shapes differ, the shortcut is projected with a 1x1 conv
    plus norm.
    """
    filters: int
    conv: ModuleDef
    norm: ModuleDef
    downsample: bool
    act: Callable = nn.swish

    @nn.compact
    def __call__(self, x):
        stride = 2 if self.downsample else 1
        strides = (stride, stride)

        # Submodules are created in the same order as before so Flax's
        # auto-generated parameter names stay stable.
        y = self.act(self.norm()(x))
        y = self.conv(self.filters, (3, 3), strides)(y)
        y = self.act(self.norm()(y))
        y = self.conv(self.filters, (3, 3))(y)
        y = SEBlock()(y)

        shortcut = x
        if shortcut.shape != y.shape:
            shortcut = self.conv(self.filters, (1, 1), strides, name='conv_proj')(shortcut)
            shortcut = self.norm(name='norm_proj')(shortcut)

        return self.act(shortcut + y)


class DecoderBlock(nn.Module):
    """NVAE ResNet block.

    Optionally applies 2x nearest-neighbour upsampling first. It then runs an
    inverted-bottleneck branch: 1x1 expand -> 5x5 conv -> 1x1 project, with a
    norm before each conv and after the last one. Squeeze-and-excitation
    follows, and the branch is added to a (projected, if needed) shortcut.
    """
    filters: int
    conv: ModuleDef
    norm: ModuleDef
    upsample: bool
    expand: int = 4
    act: Callable = nn.swish

    def upsample_image(self, img, multiplier):
        """Nearest-neighbour resize of a rank-4 `img`, scaling axes 1 and 2 by `multiplier`."""
        target_shape = (
            img.shape[0],
            img.shape[1] * multiplier,
            img.shape[2] * multiplier,
            img.shape[3],
        )
        return jax.image.resize(img, target_shape, jax.image.ResizeMethod.NEAREST)

    @nn.compact
    def __call__(self, x):
        if self.upsample:
            x = self.upsample_image(x, multiplier=2)

        wide = self.filters * self.expand
        # Submodule creation order matches the original so parameter names are unchanged.
        y = self.norm()(x)
        y = self.conv(wide, (1, 1))(y)
        y = self.act(self.norm()(y))
        y = self.conv(wide, (5, 5))(y)
        y = self.act(self.norm()(y))
        y = self.conv(self.filters, (1, 1))(y)
        # Zero-initialised scale on the last norm, so the branch output
        # starts out at (near) zero under Flax's default zero bias init.
        y = self.norm(scale_init=nn.initializers.zeros)(y)
        y = SEBlock()(y)

        shortcut = x
        if shortcut.shape != y.shape:
            shortcut = self.conv(self.filters, (1, 1), name='conv_proj')(shortcut)
            shortcut = self.norm(name='norm_proj')(shortcut)

        return self.act(shortcut + y)


class ModularEncoder(nn.Module):
    """Modular Encoder.

    Runs `stage_sizes[s]` blocks per stage s. The first block of every stage
    after the first is `down_block`; all others are `encoder_block`. The
    image is then mean-pooled over axes 1 and 2 and projected to
    `num_classes` features.

    Returns `(features, skips)`, where `skips[(stage, block)]` is the output
    of that block.
    """
    stage_sizes: Sequence[int]
    encoder_block: Callable
    down_block: Callable
    num_classes: int
    num_filters: Sequence[int]
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
        norm = functools.partial(nn.GroupNorm, num_groups=32)

        skips = {}
        for stage, n_blocks in enumerate(self.stage_sizes):
            stage_filters = self.num_filters[stage]
            for idx in range(n_blocks):
                starts_downsampled_stage = stage > 0 and idx == 0
                make_block = self.down_block if starts_downsampled_stage else self.encoder_block
                x = make_block(filters=stage_filters, conv=conv, norm=norm)(x)
                skips[(stage, idx)] = x

        pooled = x.mean(axis=(1, 2))
        features = nn.Dense(self.num_classes, dtype=self.dtype)(pooled)
        return jnp.asarray(features, self.dtype), skips


class ModularDecoder(nn.Module):
    """Modular Decoder.

    Projects `x` with a Dense layer and reshapes it to `first_block_shape`.
    It then runs the stages of `stage_sizes` in reverse order; the first
    block of every stage but the first is `up_block`. After each block the
    matching encoder skip `skips[(stage, block)]` is added ('residual') or
    concatenated ('concat'); with `skip_type=None` no skip is used. A final
    conv maps to `out_filters` channels (3x3 when `out_filters == 3`,
    otherwise 1x1).
    """
    # One of None, 'residual' or 'concat'.
    skip_type: Optional[str]
    stage_sizes: Sequence[int]
    decoder_block: Callable
    up_block: Callable
    first_block_shape: Sequence[int]
    out_filters: int
    num_filters: Sequence[int]
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x, skips):
        # Fail fast on a bad configuration instead of after building layers.
        if self.skip_type not in (None, 'residual', 'concat'):
            raise ValueError(f'Unknown Skip Type: {self.skip_type!r}.')

        conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
        norm = functools.partial(nn.GroupNorm, num_groups=32)

        # tuple() so list-valued shapes (allowed by the Sequence annotation)
        # can be concatenated with the batch dimension below.
        first_block_shape = tuple(self.first_block_shape)
        x = nn.Dense(int(np.prod(first_block_shape)), dtype=self.dtype)(x)
        x = jnp.reshape(x, (x.shape[0],) + first_block_shape)

        n_stages = len(self.stage_sizes)
        for i, block_size in enumerate(reversed(self.stage_sizes)):
            stage = n_stages - i - 1
            for j in range(block_size):
                block = self.up_block if i > 0 and j == 0 else self.decoder_block
                x = block(filters=self.num_filters[stage], conv=conv, norm=norm)(x)

                # Skips are looked up only when used, so `skips` may be empty
                # when skip_type is None.
                skip_key = (stage, block_size - j - 1)
                if self.skip_type == 'residual':
                    x = x + skips[skip_key]
                elif self.skip_type == 'concat':
                    x = jnp.concatenate([x, skips[skip_key]], axis=-1)

        kernel = (3, 3) if self.out_filters == 3 else (1, 1)
        x = conv(self.out_filters, kernel)(x)
        x = jnp.asarray(x, self.dtype)
        return x

