from typing import Optional, Any, Tuple, Dict, Callable
from functools import cached_property
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap
import flax.linen as nn

from .schedule import get_cosine_schedule, t_to_alpha_sigma_cosine
from .unet import UNetShard
from .. import sharding

  
def _x_hat(state, x_cond, x_t, t, actions):
    """Predict the clean version of the noisy frames ``x_t`` given clean context ``x_cond``.

    Reads the module-level global ``model`` (assigned by ``sample``). Context and
    noisy frames are concatenated along axis 1. A boolean mask marks the noisy
    positions: True at or after index ``x_cond.shape[1]``. The model's
    ``_x_hat`` method is applied with ``deterministic=True``. In 'pixel' mode the
    prediction is clipped to [-1, 1]. Only the part aligned with ``x_t`` is
    returned.
    """
    cfg = model.config
    n_ctx = x_cond.shape[1]
    assert n_ctx + x_t.shape[1] == cfg.seq_len

    frames = jnp.concatenate([x_cond, x_t], axis=1)
    # True where the frame is being generated, False for conditioning frames.
    is_generated = (jnp.arange(cfg.seq_len) >= n_ctx)[None, :, None, None, None]
    mask = jnp.tile(is_generated, (frames.shape[0], 1, 1, 1, 1))

    prediction = model.apply(
        {'params': state.params},
        x_t=frames,
        t=t,
        actions=actions,
        mask=mask,
        deterministic=True,
        method=model._x_hat,
    )

    if cfg.mode == 'pixel':
        prediction = jnp.clip(prediction, -1, 1)

    # Drop the conditioning positions; keep only predictions for x_t's frames.
    return prediction[:, n_ctx:]


def _ddim_sample_step(x_t, x_hat, t, rng):
    """Deterministic update of ``x_t`` from time ``t`` to ``t - 1/model.config.num_steps``.

    The noise implied by ``x_t`` and the clean prediction ``x_hat`` at time ``t``
    is recombined with ``x_hat`` using the schedule values at the earlier time.
    No randomness is drawn. A key split from ``rng`` is still returned, so the
    call has the same ``(x_next, rng)`` shape as ``_ddpm_sample_step``.
    """
    _, next_rng = jax.random.split(rng)

    s = t - 1. / model.config.num_steps
    alpha_t, sigma_t = t_to_alpha_sigma_cosine(unsqueeze4x(t))
    alpha_s, sigma_s = t_to_alpha_sigma_cosine(unsqueeze4x(s))

    # Noise estimate consistent with x_t = alpha_t * x_hat + sigma_t * eps.
    eps_hat = (x_t - alpha_t * x_hat) / sigma_t
    return alpha_s * x_hat + sigma_s * eps_hat, next_rng


def _ddpm_sample_step(x_t, x_hat, t, rng, *, gamma=0.1):
    """Stochastic update of ``x_t`` from time ``t`` to ``s = t - 1/model.config.num_steps``.

    With ``r = exp(logsnr_t - logsnr_s)`` (logsnr = 2 * log(alpha / sigma)):

        mean = r * (alpha_s / alpha_t) * x_t + (1 - r) * alpha_s * x_hat
        var  = ((1 - r) * sigma_s**2) ** (1 - gamma) * ((1 - r) * sigma_t**2) ** gamma

    and the result is ``mean + sqrt(var) * eps`` with ``eps ~ N(0, I)``.

    Args:
        x_t: current noisy sample.
        x_hat: model prediction of the clean sample at time ``t``.
        t: current time; broadcast against ``x_t`` via ``unsqueeze4x``.
        rng: PRNG key. It is split into the key used for ``eps`` and the
            returned key.
        gamma: log-space interpolation weight between the two variances above
            (0 -> ``(1 - r) * sigma_s**2``, 1 -> ``(1 - r) * sigma_t**2``).
            Keyword-only so mapped callers (pmap/xmap with four positional
            arguments) are unaffected. Defaults to the previously hard-coded 0.1.

    Returns:
        ``(x_s, new_rng)``.
    """
    noise_rng, new_rng = jax.random.split(rng)

    alpha_t, sigma_t = t_to_alpha_sigma_cosine(unsqueeze4x(t))
    logsnr_t = 2 * jnp.log(alpha_t / sigma_t)

    s = t - 1. / model.config.num_steps
    alpha_s, sigma_s = t_to_alpha_sigma_cosine(unsqueeze4x(s))
    logsnr_s = 2 * jnp.log(alpha_s / sigma_s)
    exp_logsnr_diff = jnp.exp(logsnr_t - logsnr_s)

    sigma_tilde_sq_s_t = (1 - exp_logsnr_diff) * sigma_s ** 2
    sigma_sq_t_s = (1 - exp_logsnr_diff) * sigma_t ** 2
    noise_scale = jnp.sqrt(sigma_tilde_sq_s_t ** (1 - gamma) * sigma_sq_t_s ** gamma)
    eps = jax.random.normal(noise_rng, shape=x_t.shape)

    x_t = exp_logsnr_diff * (alpha_s / alpha_t) * x_t + \
            (1 - exp_logsnr_diff) * alpha_s * x_hat + \
            noise_scale * eps
    return x_t, new_rng


def _decode(x):
    """Decode latent frames with ``model.ae_fns['decode']``.

    ``model`` is the module-level global assigned by ``sample``. A singleton
    axis is inserted at position 1 before decoding and removed afterwards.
    """
    decode_fn = model.ae_fns['decode']
    with_time_axis = x[:, None, ...]
    decoded = decode_fn(with_time_axis)
    return decoded[:, 0]

    
def sample(sample_model, state, video, actions, seed=0,
           log_output=False, extra_frames=0, state_spec=None):
    """Sample video frames conditioned on the first ``config.open_loop_ctx`` frames of ``video``.

    ``sample_model`` is stored in the module-level global ``model``, because the
    module-level step functions (``_x_hat``, ``_ddpm_sample_step``,
    ``_ddim_sample_step``, ``_decode``) read it.

    Sampling has two phases:
      1. If ``config.open_loop_ctx < config.seq_len``, frames
         ``[open_loop_ctx, seq_len)`` are sampled jointly with DDPM steps of
         size 1/1000, conditioned on the first ``open_loop_ctx`` frames.
      2. Every remaining frame up to ``video.shape[2]`` is sampled one at a
         time with DDIM steps of size 1/50. Each frame is conditioned on the
         previous ``seq_len - 1`` frames.

    NOTE: this sets ``sample_model.config.num_steps`` (to 1000, then 50) as a
    side effect, because the step functions read it.

    Args:
        sample_model: model providing ``config``, ``apply``, ``_x_hat`` and
            ``ae_fns``.
        state: object exposing ``params``, passed to ``model.apply``.
        video: array laid out NBTHWC. N must equal the number of local data
            shards.
        actions: per-frame actions, or None when ``config.use_actions`` is
            False (zeros are used then).
        seed: PRNG seed.
        log_output: show tqdm progress bars.
        extra_frames: not used by this function (kept for interface
            compatibility).
        state_spec: when not None, the calls are mapped with ``xmap`` using
            this spec for ``state``; otherwise ``jax.pmap`` is used.

    Returns:
        ``(samples, video)``. ``samples`` are the decoded frames (NBTHWC, or
        the raw embeddings in 'pixel' mode). ``video`` is re-encoded and decoded
        through the autoencoder when ``video.shape[3] == 16``; otherwise it is
        returned unchanged.
    """
    global model
    model = sample_model
    config = model.config

    use_xmap = state_spec is not None

    if use_xmap:
        num_local_data = max(1, jax.local_device_count() // model.config.num_shards)
    else:
        num_local_data = jax.local_device_count()

    rngs = jax.random.PRNGKey(seed)
    init_rng, rngs = jax.random.split(rngs)
    # One key per local data shard (the leading axis of the mapped calls).
    rngs = jax.random.split(rngs, num_local_data)

    assert video.shape[0] == num_local_data

    if not config.use_actions:
        if actions is None:
            actions = jnp.zeros(video.shape[:3], dtype=jnp.int32)
        else:
            actions = jnp.zeros_like(actions)

    if video.shape[0] < jax.local_device_count():
        devices = jax.local_devices()[:video.shape[0]]
    else:
        devices = None
    if config.mode == 'vq':
        embeddings = jax.pmap(model.ae_fns['encode'], devices=devices)(video, rngs) * config.scale_factor
    else:
        embeddings = video
    # Discard ground truth past the open-loop context; those frames are filled
    # in by the sampling loops below.
    embeddings = embeddings.at[:, :, config.open_loop_ctx:].set(0)

    # The step functions read model.config.num_steps, so it is set on the
    # config before each phase.
    num_steps = config.num_steps = 1000

    def get_shape(t):
        """Shape of ``embeddings`` (NBTHWC) with the time axis replaced by ``t``."""
        shape = list(embeddings.shape)
        shape[2] = t
        return tuple(shape)

    def timestep_grid(n):
        """Decreasing times (n-1)/n, ..., 1/n (t = 1 itself is not evaluated).

        The spacing is exactly 1/n. This matches the step functions, which move
        a sample from t to t - 1/num_steps, so every model call is evaluated at
        the time the current x_t actually corresponds to.
        """
        return np.linspace(1. - 1. / n, 1. / n, n - 1)

    if use_xmap:
        p_x_hat = xmap(
            _x_hat,
            in_axes=(state_spec, ('data', ...), ('data', ...), ('data', ...), ('data', ...)),
            out_axes=('data', ...),
            axis_resources={'data': 'dp', 'model': 'mp'}
        )
        p_ddpm_sample_step = xmap(
            _ddpm_sample_step,
            in_axes=(('data', ...), ('data', ...), ('data', ...), ('data', ...)),
            out_axes=('data', ...),
            axis_resources={'data': 'dp', 'model': 'mp'}
        )
        p_ddim_sample_step = xmap(
            _ddim_sample_step,
            in_axes=(('data', ...), ('data', ...), ('data', ...), ('data', ...)),
            out_axes=('data', ...),
            axis_resources={'data': 'dp', 'model': 'mp'}
        )
    else:
        p_x_hat = jax.pmap(_x_hat)
        p_ddpm_sample_step = jax.pmap(_ddpm_sample_step)
        p_ddim_sample_step = jax.pmap(_ddim_sample_step)

    # Phase 1: sample the rest of the first window if open_loop_ctx < seq_len.
    if config.open_loop_ctx < config.seq_len:
        x_cond = embeddings[:, :, :config.open_loop_ctx]
        act = actions[:, :, :config.seq_len]
        x_t = jax.random.normal(init_rng, get_shape(config.seq_len - config.open_loop_ctx), dtype=jnp.float32)
        init_rng, _ = jax.random.split(init_rng)

        itr = timestep_grid(num_steps)
        if log_output:
            itr = tqdm(itr)

        for timestep in itr:
            t = np.full(embeddings.shape[:-4], fill_value=timestep, dtype=jnp.float32)
            x_hat = p_x_hat(state, x_cond, x_t, t, act)
            x_t, rngs = p_ddpm_sample_step(x_t, x_hat, t, rngs)
        # The final clean prediction (from the smallest t) becomes the frames.
        embeddings = embeddings.at[:, :, config.open_loop_ctx:config.seq_len].set(x_hat)

    # Phase 2: sample each remaining frame from the previous seq_len - 1 frames.
    num_steps = config.num_steps = 50
    ddim_times = timestep_grid(num_steps)

    assert embeddings.shape[2] >= config.seq_len
    itr = range(max(config.open_loop_ctx, config.seq_len), embeddings.shape[2])
    if log_output:
        itr = tqdm(itr)
    for i in itr:
        x_cond = embeddings[:, :, i - config.seq_len + 1:i]
        act = actions[:, :, i - config.seq_len + 1:i + 1]
        assert x_cond.shape[2] == config.seq_len - 1, x_cond.shape
        x_t = jax.random.normal(init_rng, get_shape(1), dtype=jnp.float32)
        init_rng, _ = jax.random.split(init_rng)

        for timestep in ddim_times:
            t = np.full(embeddings.shape[:-4], fill_value=timestep, dtype=jnp.float32)
            x_hat = p_x_hat(state, x_cond, x_t, t, act)
            x_t, rngs = p_ddim_sample_step(x_t, x_hat, t, rngs)
        embeddings = embeddings.at[:, :, i:i + 1].set(x_hat)

    if config.mode == 'vq':
        embeddings /= config.scale_factor

    def decode(samples):
        """Decode NBTHWD latents to NBTHWC frames clipped to [-1, 1] (no-op in 'pixel' mode)."""
        if config.mode == 'pixel':
            return samples

        N, B, T = samples.shape[:3]
        if N < jax.local_device_count():
            devices = jax.local_devices()[:N]
        else:
            devices = None

        samples = jax.device_get(samples)
        samples = np.reshape(samples, (-1, *samples.shape[3:]))

        # Decode about 64 frames at a time. Each chunk is split evenly across
        # the N devices, so the chunk length must be a multiple of N (this
        # equals 64 whenever N divides 64).
        chunk = N * max(1, 64 // N)
        p_decode = jax.pmap(_decode, devices=devices)

        recons = []
        for start in range(0, N * B * T, chunk):
            inp = samples[start:start + chunk]
            inp = np.reshape(inp, (N, -1, *inp.shape[1:]))
            recon = jax.device_get(p_decode(inp))
            recons.append(np.reshape(recon, (-1, *recon.shape[2:])))
        recons = np.concatenate(recons, axis=0)
        recons = np.reshape(recons, (N, B, T, *recons.shape[1:]))
        return np.clip(recons, -1, 1)  # NBTHWC

    samples = decode(embeddings)

    if video.shape[3] == 16:
        video = jax.pmap(model.ae_fns['encode'], devices=devices)(video, rngs)
        video = decode(video)

    return samples, video
 

class LatentFDMShard(nn.Module):
    """Video diffusion model whose denoiser is a ``UNetShard``.

    The network predicts the noise (eps) for a clip of ``config.seq_len``
    frames. A per-frame mask, appended to the input as an extra channel, marks
    which frames are being generated (1) and which are clean conditioning (0).
    In 'vq' mode the clip is first encoded with ``ae_fns['encode']`` and scaled
    by ``config.scale_factor``.
    """
    # Reads: mode, unet, num_shards, use_actions, dropout_actions, seq_len,
    # scale_factor, image_size.
    config: Any
    # Autoencoder functions; 'encode' is called as encode(video, rng).
    ae_fns: Dict[str, Callable]
    # Autoencoder object; `embedding_dim` and `latent_shape` are read.
    ae: Any
    # Not referenced inside this class.
    dtype: Optional[Any] = jnp.float32

    @property
    def metrics(self):
        """Names of the entries in the dict returned by ``__call__``."""
        return ['loss']

    def aggregate_metrics(self, metrics):
        """Average ``metrics`` across the mapped 'data' axis."""
        metrics = jax.lax.pmean(metrics, 'data')
        return metrics

    def setup(self):
        # Output channels: 3 in 'pixel' mode, otherwise the autoencoder's
        # embedding dimension.
        self.unet = UNetShard(
            **self.config.unet,
            num_shards=self.config.num_shards,
            out_channels=3 if self.config.mode == 'pixel' else self.ae.embedding_dim
        )

    def _x_hat(self, x_t, t, actions, mask, deterministic):
        """Predict the clean input by inverting ``x_t = alpha_t * x + sigma_t * eps``."""
        eps_hat = self._eps(x_t, t, actions, mask, deterministic)
        alpha_t, sigma_t, _ = get_cosine_schedule(unsqueeze4x(t))
        x_hat = (x_t - sigma_t * eps_hat) / alpha_t
        return x_hat

    def _eps(self, x_t, t, actions, mask, deterministic):
        """Predict the noise in ``x_t``; ``mask`` is fed as an extra last-axis channel.

        ``mask`` is tiled by ``x_t``'s sizes on axes 2 and 3, so it is expected
        to have size 1 on those axes (as both mask builders in this file
        produce).
        """
        mask = jnp.tile(mask, (1, 1, x_t.shape[2], x_t.shape[3], 1))
        x_t = jnp.concatenate([x_t, mask], axis=-1)
        eps_hat = self.unet(x_t, t, actions, deterministic=deterministic)
        return eps_hat

    def __call__(self, video, actions, deterministic=False, dropout_actions=None):
        """Compute the masked noise-prediction MSE loss on a batch of clips.

        Args:
            video: batch whose leading axes are (B, T). It is encoded with
                ``ae_fns['encode']`` in 'vq' mode.
            actions: per-frame actions, or None when ``config.use_actions`` is
                False (zeros are used then).
            deterministic: forwarded to the UNet.
            dropout_actions: optional per-example boolean array of shape (B,).
                It is sampled as Bernoulli(0.5) when None. It is applied only
                when ``config.dropout_actions`` is set; selected examples then
                get all actions replaced by -1.

        Returns:
            ``dict(loss=...)``: squared error between predicted and true noise,
            averaged over generated (mask == 1) positions only.
        """
        # NOTE(review): each make_rng call below draws a new key; reordering
        # these calls likely changes the random draws, so keep their order.

        # Handle actions
        if not self.config.use_actions:
            if actions is None:
                actions = jnp.zeros(video.shape[:2], dtype=jnp.int32)
            else:
                actions = jnp.zeros_like(actions)

        if dropout_actions is None:
            dropout_actions = jax.random.bernoulli(self.make_rng('sample'), p=0.5,
                                                shape=(video.shape[0],)) # B
        if self.config.dropout_actions:
            actions = jnp.where(dropout_actions[:, None], -1, actions)

        # Frame mask: 0 for conditioning frames, 1 for generated frames.
        # n_cond is drawn from [0, seq_len), so at least one frame is generated.
        n_cond = jax.random.randint(self.make_rng('sample'), shape=(video.shape[0], 1),
                                    minval=0, maxval=self.config.seq_len)
        mask = (jnp.arange(self.config.seq_len)[None] >= n_cond).astype(jnp.float32)
        mask = mask[:, :, None, None, None]

        # Encode
        if self.config.mode == 'vq':
            video = self.ae_fns['encode'](video, self.make_rng('noise')) * self.config.scale_factor
        x_0 = video

        # Diffusion loss: conditioning frames stay clean, generated frames are noised.
        t = jax.random.uniform(self.make_rng('sample'), shape=(x_0.shape[0],),
                               minval=0, maxval=1, dtype=jnp.float32)
        alpha_t, sigma_t, _ = get_cosine_schedule(unsqueeze4x(t))
        eps = jax.random.normal(self.make_rng('sample'), shape=x_0.shape, dtype=x_0.dtype)
        x_t = (1 - mask) * x_0 + mask * (alpha_t * x_0 + sigma_t * eps)

        eps_hat = self._eps(x_t, t, actions, mask, deterministic=deterministic)

        # Average the squared error over generated positions only.
        loss = jnp.square(eps_hat - eps)
        mask = jnp.broadcast_to(mask, loss.shape)
        loss = (loss * mask).sum() / mask.sum()
        return dict(loss=loss)

    @cached_property
    def model_spec(self):
        """Return the ``sharding.GenericDict`` spec for the 'unet' submodule.

        It is built with ``UNetShard.model_spec``. The image size is the
        autoencoder's ``latent_shape[0]`` in 'vq' mode and
        ``config.image_size`` otherwise.
        """
        if self.config.mode == 'vq':
            image_size = self.ae.latent_shape[0]
        else:
            image_size = self.config.image_size
        return sharding.GenericDict({
            'unet': UNetShard.model_spec(image_size, **self.config.unet, conv_resample=True)
        })

    
def unsqueeze4x(x):
    """Return ``x`` with four trailing size-1 axes appended.

    Used so that a per-example value such as a timestep of shape (B,)
    broadcasts against 5-D (B, T, H, W, C) tensors.
    """
    trailing = (None,) * 4
    return x[(Ellipsis,) + trailing]
