from typing import Optional, Any, Tuple, Dict, Callable
import math
from tqdm import tqdm
import optax
import numpy as np
import jax
import flax.linen as nn
import jax.numpy as jnp
from flax.linen.partitioning import param_with_axes

from . import transformer
from . import layers
from ..utils import topk_sample

# Placeholder aliases; both are simply typing.Any.
Array = Any
Dtype = Any


# Integer token id written into positions that are currently masked
# (see its uses in MaskGit: fully-masked initial samples, re-masking, and
# the `samples == MASK_ID` test for still-unknown positions).
MASK_ID = -1


def schedule(ratio, total_unknown, method='cosine'):
    """Map a progress ratio to a mask ratio clipped to ``[1e-6, 1]``.

    Args:
        ratio: Scalar or array progress value. The formulas are written for
            values in ``[0, 1]`` (assumption -- not enforced here).
        total_unknown: Token count; only read by the ``'log'`` and ``'exp'``
            schedules.
        method: ``'uniform'``, ``'cosine'``, ``'log'``, ``'exp'``, or any
            string containing ``'pow'`` whose remainder parses as a float
            exponent (e.g. ``'pow2'``, ``'pow0.5'``).

    Returns:
        The mask ratio, same shape as ``ratio``, clipped to ``[1e-6, 1]``.

    Raises:
        ValueError: If ``method`` is not one of the supported schedules (or
            the exponent of a ``'pow'`` schedule cannot be parsed).
    """
    if method == 'uniform':
        mask_ratio = 1. - ratio
    elif 'pow' in method:
        exponent = float(method.replace('pow', ''))
        mask_ratio = 1. - ratio ** exponent
    elif method == 'cosine':
        mask_ratio = jnp.cos(math.pi / 2. * ratio)
    elif method == 'log':
        mask_ratio = -jnp.log2(ratio) / jnp.log2(total_unknown)
    elif method == 'exp':
        mask_ratio = 1 - jnp.exp2(-jnp.log2(total_unknown) * (1 - ratio))
    else:
        # Previously this fell through and failed with an UnboundLocalError
        # on the clip below; fail early with an actionable message instead.
        raise ValueError(
            f"Unknown mask schedule {method!r}; expected 'uniform', 'cosine', "
            "'log', 'exp' or 'pow<exponent>'."
        )
    # The lower bound keeps the ratio strictly positive; the upper bound
    # absorbs infinities such as the 'log' schedule at ratio == 0.
    return jnp.clip(mask_ratio, 1e-6, 1.)

    
def mask_by_random_topk(rng, mask_len, probs, temperature=1.0):
    """Select the lowest-scoring positions of each row for masking.

    Each position is scored by ``log(probs)`` plus Gumbel noise scaled by
    ``temperature``. A position is selected when its score is strictly below
    the row's ``mask_len``-th smallest score (0-indexed), so barring ties
    exactly ``mask_len`` positions per row are selected.

    Args:
        rng: JAX PRNG key used to draw the Gumbel noise.
        mask_len: Integer array of indices into the sorted scores, taken
            along the last axis (the caller in this file passes one index per
            row with a trailing axis of size 1).
        probs: Per-position probabilities.
        temperature: Scale applied to the Gumbel noise.

    Returns:
        Boolean array shaped like ``probs``; ``True`` marks selected positions.
    """
    noise = jax.random.gumbel(rng, probs.shape)
    scores = jnp.log(probs) + temperature * noise
    # Ascending sort: the entry at index `mask_len` is the selection threshold.
    ascending = jnp.sort(scores, axis=-1)
    threshold = jnp.take_along_axis(ascending, mask_len, axis=-1)
    return scores < threshold


def sample_mask(Z, T, rng):
    """Randomly partition the positions of a grid into ``T`` boolean masks.

    The ``prod(Z)`` flat positions are shuffled and split into ``T``
    consecutive chunks with ``jnp.array_split`` (chunk sizes differ by at
    most one). Each chunk becomes one mask, so every position is ``True`` in
    exactly one of the returned masks.

    Args:
        Z: Shape of the grid to mask.
        T: Number of masks to produce.
        rng: JAX PRNG key used for the shuffle.

    Returns:
        A list of ``T`` boolean arrays of shape ``Z``.
    """
    N = int(np.prod(Z))
    order = jax.random.permutation(rng, jnp.arange(N, dtype=jnp.int32))
    # Scatter each chunk into a flat all-False vector. This replaces summing
    # a (chunk_len, N) one-hot matrix, which needed quadratic memory for the
    # same result (the indices inside a chunk are unique).
    blank = jnp.zeros((N,), dtype=bool)
    return [
        jnp.reshape(blank.at[chunk].set(True), Z)
        for chunk in jnp.array_split(order, T)
    ]


class MaskGit(nn.Module):
    """Masked-token model: embedding table, transformer trunk and MLM head.

    ``__call__`` hides a randomly chosen subset of positions and returns
    logits for every position together with the labels and the mask.
    ``sample`` and ``sample2`` are two iterative decoding procedures that
    both start from a fully masked grid.

    Masking and sampling draw from the RNG stream named ``'sample'``.
    """
    # Shape of the token grid, excluding the leading batch axis.
    shape: Tuple[int]
    # Number of token classes predicted by the head.
    vocab_size: int
    # Width of the token embeddings.
    vocab_dim: int
    # Schedule name forwarded to `schedule`; the value 'constant' is handled
    # separately inside `sample2_step`.
    mask_schedule: str
    # Extra keyword arguments forwarded to `transformer.Transformer`.
    tfm_kwargs: Dict[str, Any]
    # If True, inputs carry a trailing axis of size `vocab_size` and masked
    # positions are all-zero vectors; otherwise inputs are integer token ids
    # and masked positions hold MASK_ID.
    one_hot_input: bool = False
    # Computation dtype (parameters themselves are created as float32).
    dtype: Optional[Any] = jnp.float32

    def setup(self):
        """Create the embedding table, the transformer trunk and the MLM head."""
        if self.one_hot_input:
            self.token_embed = param_with_axes(
                'token_embed', nn.initializers.normal(stddev=0.02),
                [self.vocab_size, self.vocab_dim], jnp.float32
            )
        else:
            # Integer-id mode allocates 32 rows beyond `vocab_size`; MASK_ID
            # (-1) resolves to the last of them under NumPy-style negative
            # indexing in `_step`.
            self.token_embed = param_with_axes(
                'token_embed', nn.initializers.normal(stddev=0.02),
                [self.vocab_size + 32, self.vocab_dim], jnp.float32,
                axes=('vocab', 'embed')
            )
        self.net = transformer.Transformer(
            **self.tfm_kwargs,
            shape=self.shape,
            pos_embed_type='broadcast',
            dtype=self.dtype
        )
        self.mlm = MlmLayer(self.vocab_dim, self.dtype)

    def _step(self, x, cond=None, deterministic=False):
        """Embed ``x``, run the transformer and return per-position logits.

        Args:
            x: Integer token ids, or vectors over the vocabulary when
                ``one_hot_input`` is set.
            cond: Conditioning input handed to the transformer as
                ``dict(cat=cond)``; may be ``None``.
            deterministic: Forwarded to the transformer.

        Returns:
            Logits over the first ``vocab_size`` embedding rows.
        """
        token_embed = jnp.asarray(self.token_embed, self.dtype)
        if self.one_hot_input:
            x = x @ token_embed
        else:
            x = token_embed[(x,)]
        x = self.net(x, cond=dict(cat=cond), deterministic=deterministic)
        # The output projection reuses the (un-cast) embedding rows of the
        # real vocabulary; the extra rows are excluded.
        logits = self.mlm(x, self.token_embed[:self.vocab_size])
        return logits

    def sample(self, n, T_draft, T_revise, M, cond=None):
        """Generate ``n`` samples with a draft pass followed by revise rounds.

        The draft pass splits all positions into ``T_draft`` random disjoint
        subsets and fills them one subset at a time. Each of the ``M`` revise
        rounds draws a fresh partition into ``T_revise`` subsets and, subset
        by subset, re-masks and resamples those positions.

        Args:
            n: Batch size.
            T_draft: Number of steps (subsets) in the draft pass.
            T_revise: Number of steps (subsets) per revise round.
            M: Number of revise rounds.
            cond: Optional conditioning input forwarded to ``_step``.

        Returns:
            Integer tokens of shape ``(n, *shape)``, or one-hot vectors of
            shape ``(n, *shape, vocab_size)`` when ``one_hot_input`` is set.
        """
        if self.one_hot_input:
            sample = jnp.zeros((n, *self.shape, self.vocab_size), dtype=jnp.float32)
        else:
            sample = jnp.full((n, *self.shape), MASK_ID, dtype=jnp.int32)

        def _update(samples, masks):
            # For each subset: blank it out, predict, then write the newly
            # drawn tokens back into exactly those positions.
            for mask in masks:
                if self.one_hot_input:
                    samples = jnp.where(mask[..., None], 0, samples)
                else:
                    samples = jnp.where(mask, MASK_ID, samples)
                logits = self._step(samples, cond=cond, deterministic=True)
                s = topk_sample(self.make_rng('sample'), logits)
                if self.one_hot_input:
                    s = jax.nn.one_hot(s, num_classes=self.vocab_size)
                    samples = jnp.where(mask[..., None], s, samples)
                else:
                    samples = jnp.where(mask, s, samples)
            return samples

        # Draft
        masks = sample_mask(self.shape, T_draft, self.make_rng('sample'))
        sample = _update(sample, masks)

        # Revise
        for _ in range(M):
            masks = sample_mask(self.shape, T_revise, self.make_rng('sample'))
            sample = _update(sample, masks)

        return sample

    def sample2(self, n, T, cond=None, temperature=1.0):
        """Generate ``n`` samples by running ``sample2_step`` for ``T`` steps.

        Args:
            n: Batch size.
            T: Number of decoding steps.
            cond: Optional conditioning input.
            temperature: Noise scale forwarded to ``sample2_step``.

        Returns:
            The output of the final ``sample2_step`` call.
        """
        samples = jnp.full((n, *self.shape), MASK_ID, dtype=jnp.int32)
        for t in range(T):
            samples = self.sample2_step(samples, t, T, cond=cond, temperature=temperature)
        return samples

    def sample2_step(self, samples, t, T, cond=None, temperature=1.0):
        """Run one confidence-based decoding step.

        Still-masked positions receive tokens drawn from the model, then the
        least confident of those new tokens are masked again; how many is
        decided by ``schedule`` at progress ``(t + 1) / T``.

        Args:
            samples: Integer tokens with MASK_ID at unknown positions; the
                leading axis is the batch.
            t: Zero-based index of the current step.
            T: Total number of steps.
            cond: Optional conditioning input; when given it is flattened to
                ``(batch, -1, feature)`` before being passed to ``_step``.
            temperature: Scale of the Gumbel noise used when choosing which
                positions to re-mask; it is multiplied by ``1 - ratio``.

        Returns:
            Updated tokens of shape ``(batch, *shape)``. NOTE: when
            ``mask_schedule == 'constant'`` the per-position argmax is
            returned as a flat ``(batch, prod(shape))`` array and nothing is
            re-masked; this pre-existing return shape is kept for callers.
        """
        total_unknown = jnp.full((samples.shape[0],), np.prod(self.shape), dtype=jnp.int32)
        samples = samples.reshape(samples.shape[0], -1)
        # `cond` defaults to None; reshaping it unconditionally used to raise
        # AttributeError for unconditional sampling.
        if cond is not None:
            cond = cond.reshape(cond.shape[0], -1, cond.shape[-1])

        unknown_map = samples == MASK_ID
        logits = self._step(samples, cond=cond, deterministic=True)
        if self.mask_schedule == 'constant':
            new_tokens = jnp.argmax(logits, axis=-1)
            return new_tokens
        else:
            new_tokens = jax.random.categorical(self.make_rng('sample'), logits, axis=-1) # BL
        # Only unknown positions take the new tokens; known ones are kept.
        samples = jnp.where(unknown_map, new_tokens, samples)

        ratio = 1. * (t + 1) / T
        mask_ratio = schedule(ratio, total_unknown, method=self.mask_schedule)

        probs = jax.nn.softmax(logits, axis=-1)
        selected_probs = jnp.squeeze(
            jnp.take_along_axis(probs, jnp.expand_dims(samples, -1), -1), -1
        )
        # Previously known positions get infinite confidence so they sort
        # last and are not picked for re-masking.
        selected_probs = jnp.where(unknown_map, selected_probs, jnp.inf)

        mask_len = jnp.expand_dims(jnp.floor(total_unknown * mask_ratio), 1)
        # Clamp to [0, unknown_count - 1] so that every step reveals at least
        # one of the positions that were unknown on entry.
        mask_len = jnp.maximum(
            0,
            jnp.minimum(jnp.sum(unknown_map, axis=-1, keepdims=True) - 1, mask_len))

        mask_len = mask_len.astype(jnp.int32)
        masking = mask_by_random_topk(self.make_rng('sample'), mask_len, selected_probs,
                                      temperature * (1. - ratio))
        samples = jnp.where(masking, MASK_ID, samples)
        samples = samples.reshape(samples.shape[0], *self.shape)
        return samples

    def __call__(self, x, cond=None, deterministic=False, rng=None):
        """Mask a random subset of ``x`` and predict every position.

        Args:
            x: Batch of integer tokens, or vectors over the vocabulary when
                ``one_hot_input`` is set (trailing axis = vocabulary).
            cond: Optional conditioning input forwarded to ``_step``.
            deterministic: Forwarded to the transformer.
            rng: Unused; randomness comes from the ``'sample'`` RNG stream.

        Returns:
            ``(logits, labels, mask)``: the model logits, the targets (``x``
            itself in one-hot mode, otherwise ``one_hot(x, vocab_size)``) and
            the boolean mask of hidden positions.
        """
        # x: B..., cond: B...D
        if self.one_hot_input:
            B, L = x.shape[0], np.prod(x.shape[1:-1])
        else:
            B, L = x.shape[0], np.prod(x.shape[1:])

        # Per-example number of positions to hide: the schedule of a uniform
        # draw, scaled by L, floored, and never less than one.
        ratio = jax.random.uniform(self.make_rng('sample'), shape=(B,), dtype=self.dtype)
        ratio = schedule(ratio, L, method=self.mask_schedule)
        ratio = jnp.maximum(1, jnp.floor(ratio * L))

        # Independently shuffle 0..L-1 in every row; the positions holding a
        # value below the row's count are the ones that get hidden.
        sample = jnp.arange(L)[None, :].repeat(B, axis=0)
        sample = jax.random.permutation(self.make_rng('sample'), sample, axis=-1, independent=True)
        mask = sample < ratio[:, None]

        if self.one_hot_input:
            mask = mask.reshape(x.shape[:-1])
        else:
            mask = mask.reshape(x.shape)

        if self.one_hot_input:
            masked_x = jnp.where(mask[..., None], 0, x)
        else:
            masked_x = jnp.where(mask, MASK_ID, x)
        logits = self._step(masked_x, cond=cond, deterministic=deterministic)

        if self.one_hot_input:
            labels = x
        else:
            labels = jax.nn.one_hot(x, num_classes=self.vocab_size)

        return logits, labels, mask

        
class MlmLayer(nn.Module):
    """Prediction head: dense -> GELU -> layer norm -> embedding logits + bias."""
    # Output width of the dense projection.
    vocab_dim: int
    # Computation dtype handed to the sub-layers.
    dtype: Optional[Any] = jnp.float32

    @nn.compact
    def __call__(self, x, embeddings):
        """Project ``x`` and score it against every row of ``embeddings``.

        Args:
            x: Input features.
            embeddings: 2-D embedding matrix; its transpose is used as the
                output projection, so each row yields one logit.

        Returns:
            Logits with one entry per embedding row, plus a learned bias.
        """
        hidden = layers.DenseGeneral(
            self.vocab_dim,
            dtype=self.dtype,
            kernel_axes=('mlp', 'embed'),
            kernel_init=nn.initializers.normal(stddev=0.02),
        )(x)
        hidden = transformer.LayerNorm(dtype=self.dtype)(nn.gelu(hidden))

        # Score against the transposed embedding matrix instead of a
        # separately learned output kernel.
        scores = jnp.matmul(hidden, jnp.transpose(embeddings))
        return Bias(self.dtype)(scores)

        
class Bias(nn.Module):
    """Adds a learned, zero-initialised bias along the last axis of the input."""
    # dtype the float32 parameter is cast to before the addition.
    dtype: Optional[Any] = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Return ``x`` plus a bias vector of length ``x.shape[-1]``."""
        width = x.shape[-1]
        raw = param_with_axes(
            'bias', nn.initializers.zeros, (width,), jnp.float32, axes=('vocab',)
        )
        offset = jnp.broadcast_to(jnp.asarray(raw, self.dtype), x.shape)
        return x + offset
    
