from typing import Optional, Any, Tuple, Dict, Callable
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

from .schedule import get_cosine_schedule, t_to_alpha_sigma_cosine
from .unet import UNet


class LatentFDM(nn.Module):
    """Masked-frame diffusion model over pixel or autoencoder-latent video.

    A random prefix of each clip is kept clean as conditioning. The remaining
    frames are noised with a cosine schedule, and the UNet is trained to
    predict the added noise on those generated frames only.

    Attributes:
        config: Model config. Fields read here: ``unet`` (kwargs for
            ``UNet``), ``mode`` (``'pixel'`` or ``'vq'``), ``use_actions``,
            ``dropout_actions``, ``seq_len`` and ``scale_factor``.
        ae_fns: Autoencoder callables. ``ae_fns['encode'](video, rng)`` is
            called when ``config.mode == 'vq'``.
        ae: Autoencoder object. Its ``embedding_dim`` sets the UNet output
            channel count when ``config.mode != 'pixel'``.
        dtype: Not read by any code in this class.
    """
    config: Any
    ae_fns: Dict[str, Callable]
    ae: Any
    dtype: Optional[Any] = jnp.float32

    @property
    def metrics(self):
        """Names of the entries in the dict returned by ``__call__``."""
        return ['loss']

    def setup(self):
        # Pixel mode predicts RGB noise; otherwise the UNet predicts noise in
        # the autoencoder's embedding space.
        out_channels = 3 if self.config.mode == 'pixel' else self.ae.embedding_dim
        self.unet = UNet(**self.config.unet, out_channels=out_channels)

    def _x_hat(self, x_t, t, actions, mask, deterministic):
        """Estimate the clean sample from ``x_t`` using the predicted noise.

        Inverts ``x_t = alpha_t * x_0 + sigma_t * eps`` with the UNet's
        ``eps`` prediction and the cosine schedule at time ``t``.
        """
        eps_hat = self._eps(x_t, t, actions, mask, deterministic)
        alpha_t, sigma_t, _ = get_cosine_schedule(unsqueeze4x(t))
        return (x_t - sigma_t * eps_hat) / alpha_t

    def _eps(self, x_t, t, actions, mask, deterministic):
        """Predict the noise contained in ``x_t`` with the UNet.

        The per-frame mask (0 = conditioning frame, 1 = generated frame) is
        broadcast over x_t's non-channel axes. It is appended as extra input
        channels so the UNet can tell conditioning frames from generated ones.
        """
        # e.g. (B, T, 1, 1, 1) -> (B, T, H, W, 1). broadcast_to (rather than
        # tile) states the intent directly and also accepts a batch-1 mask.
        mask = jnp.broadcast_to(mask, x_t.shape[:-1] + mask.shape[-1:])
        x_t = jnp.concatenate([x_t, mask], axis=-1)
        return self.unet(x_t, t, actions, deterministic=deterministic)

    def __call__(self, video, actions, deterministic=False, dropout_actions=None):
        """Compute the masked noise-prediction training loss.

        Args:
            video: Input clip. Appears to be (B, T, H, W, C) with
                T == config.seq_len (TODO confirm). When
                ``config.mode == 'vq'`` it is encoded with
                ``ae_fns['encode']`` and multiplied by
                ``config.scale_factor``.
            actions: Per-frame actions, presumably (B, T) judging by the
                zero fallback below. May be None when
                ``config.use_actions`` is False.
            deterministic: Forwarded to the UNet.
            dropout_actions: Optional boolean (B,) array marking samples whose
                actions are replaced by -1. Drawn with p=0.5 when None. Only
                applied when ``config.dropout_actions`` is truthy.

        Returns:
            ``{'loss': scalar}``: squared error between predicted and true
            noise, averaged over the elements of generated (mask == 1)
            frames only.
        """
        # Without action conditioning, feed all-zero actions so the UNet call
        # is unchanged.
        if not self.config.use_actions:
            if actions is None:
                actions = jnp.zeros(video.shape[:2], dtype=jnp.int32)
            else:
                actions = jnp.zeros_like(actions)

        # NOTE: drawn even when config.dropout_actions is off. Left that way so
        # the sequence of 'sample' RNG draws does not depend on the config.
        if dropout_actions is None:
            dropout_actions = jax.random.bernoulli(
                self.make_rng('sample'), p=0.5, shape=(video.shape[0],))  # (B,)
        if self.config.dropout_actions:
            # Replace the whole action sequence of each dropped sample with -1.
            actions = jnp.where(dropout_actions[:, None], -1, actions)

        # Conditioning mask: frames [0, n_cond) are clean conditioning (0) and
        # frames [n_cond, seq_len) are generated (1). maxval is exclusive, so
        # n_cond <= seq_len - 1. At least the last frame is therefore always
        # generated, and the loss denominator below is never zero.
        n_cond = jax.random.randint(
            self.make_rng('sample'), shape=(video.shape[0], 1),
            minval=0, maxval=self.config.seq_len)
        mask = (jnp.arange(self.config.seq_len)[None] >= n_cond).astype(jnp.float32)
        mask = mask[:, :, None, None, None]  # (B, seq_len, 1, 1, 1)

        # In 'vq' mode, diffusion runs on the scaled autoencoder encoding.
        if self.config.mode == 'vq':
            video = self.ae_fns['encode'](video, self.make_rng('noise')) * self.config.scale_factor
        x_0 = video

        # Forward diffusion: only generated frames are noised; conditioning
        # frames are passed through clean.
        t = jax.random.uniform(self.make_rng('sample'), shape=(x_0.shape[0],),
                               minval=0, maxval=1, dtype=jnp.float32)
        alpha_t, sigma_t, _ = get_cosine_schedule(unsqueeze4x(t))
        eps = jax.random.normal(self.make_rng('sample'), shape=x_0.shape, dtype=x_0.dtype)
        x_t = (1 - mask) * x_0 + mask * (alpha_t * x_0 + sigma_t * eps)

        eps_hat = self._eps(x_t, t, actions, mask, deterministic=deterministic)

        # Mean squared error restricted to the generated-frame elements.
        loss = jnp.square(eps_hat - eps)
        mask = jnp.broadcast_to(mask, loss.shape)
        loss = (loss * mask).sum() / mask.sum()
        return dict(loss=loss)

    
def unsqueeze4x(x):
    """Return ``x`` with four trailing length-1 axes appended.

    For example, a (B,) array of per-sample timesteps becomes
    (B, 1, 1, 1, 1), so it broadcasts against rank-5 tensors.
    """
    trailing_axes = (None,) * 4
    return x[(Ellipsis, *trailing_axes)]
