import jax
from jax import numpy as jnp
from jax.nn import softmax
import jax.random as jr
from tqdm import trange
from functools import partial

from vdm.parallel_decode import decode

import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

def gibbs_corrector(res):
    """Pick the denoising logits out of a backward-pass result dict.

    Meant to be paired with the Gibbs update function only: the returned
    value is the ``"x0_logits"`` entry of ``res``, not a transition rate.
    """
    x0_logits = res["x0_logits"]
    return x0_logits

def forward_backward_corrector(res):
    """Return the sum of the ``"rates"`` and ``"Rt_eval_x"`` entries of ``res``."""
    rates = res["rates"]
    rt_eval_x = res["Rt_eval_x"]
    return rates + rt_eval_x

def euler_update(key, x, rates):
    """Sample the next state of every position from per-position jump rates.

    Args:
        key: JAX PRNG key used for the categorical draw.
        x: Integer array of current states, one entry per position.
        rates: Array indexed as ``rates[position, state]`` holding the jump
            rate towards each state; the entry for the current state is
            ignored (it is overwritten with zero).

    Returns:
        ``int32`` array with one sampled state per position.
    """
    rows = jnp.arange(x.shape[0])

    # Self transitions carry no rate.
    off_diag = rates.at[rows, x].set(0.0)
    total = jnp.sum(off_diag, axis=1)

    # log P(jump to s) = log(1 - exp(-total)) + log(rate_s) - log(total).
    # The small constant keeps the last logarithm finite when total == 0.
    log_leave = jnp.log(-jnp.expm1(-total))
    log_total = jnp.log(total + 1e-8)
    logits = log_leave[:, None] + jnp.log(off_diag) - log_total[:, None]

    # log P(stay) = -total.
    logits = logits.at[rows, x].set(-total)

    return jr.categorical(key, logits).astype(jnp.int32)

def compute_backward(y_with_label, t, apply_fn, params, config, forward_process, rng=None) -> dict:
    """Run the denoising model on one sequence and derive backward-process quantities.

    Args:
        y_with_label: Token array for a single example. It is flattened and
            its first and last entries are stripped to obtain the D tokens
            that are scored (presumably those two entries are label /
            conditioning tokens -- confirm against the caller).
        t: Time passed to ``forward_process`` and to ``apply_fn``.
        apply_fn: Model apply function, called as
            ``apply_fn({"params": params}, tokens[None], t, rngs=..., deterministic=True)``.
        params: Model parameters forwarded to ``apply_fn``.
        config: Config object; ``data.codebook_size``, ``training.min_t`` and
            ``training.eps`` are read.
        forward_process: Object providing ``transition(t)``, ``rate(t)`` and
            ``_rate_scalar(t)``.
        rng: PRNG key handed to the model under the ``"permute"`` stream;
            defaults to ``jr.PRNGKey(0)`` when omitted.

    Returns:
        Dict with the keys ``"score"``, ``"log_score"``, ``"rates"``,
        ``"x0_logits"``, ``"Rt_eval_y"``, ``"Rt_eval_x"`` and
        ``"rate_scalar"``.
    """
    rng = jr.PRNGKey(0) if rng is None else rng

    # Work on a flat token vector; drop the first and last entries to get y.
    y_with_label = y_with_label.flatten()
    y = y_with_label[1:-1]

    D = y.shape[0]
    # One extra state on top of the codebook: the mask token, stored last.
    S = config.data.codebook_size + 1
    mask = S-1
    # NOTE(review): min_t is read here but not used anywhere in this function.
    min_t = config.training.min_t
    eps = config.training.eps
    qt0 = forward_process.transition(t)
    # R^d_t(*1,*2): (S, S) float array of instantaneous transition rates
    # for a single dimension
    Rt = forward_process.rate(t)
    # Columns of Rt selected by the current tokens, transposed so that the
    # leading axis indexes positions.
    Rt_eval_y = Rt[:, y].T

    # Set corresponding values to mask
    # NOTE(review): mask == S-1, so this where() replaces S-1 with S-1 and
    # leaves the tokens unchanged.
    y_with_label = jnp.where((y_with_label == (S-1)), mask, y_with_label)
    y = y_with_label[1:-1]
    # Add a batch axis for the model call.
    x0_logits = apply_fn({"params": params}, y_with_label[None], t, rngs={"permute": rng},
        deterministic=True)
    # Only take the valid parts of the output: batch entry 0, the inner
    # positions, and the first S classes.
    x0_logits = x0_logits[0,1:-1,:S]

    # Set mask logits to minus infinity and normalize, so x0_logits holds
    # log-probabilities that put zero mass on the mask token.
    x0_logits = x0_logits.at[:, mask].set(-jnp.inf)
    x0_logits -= jax.scipy.special.logsumexp(x0_logits, axis=-1, 
        keepdims=True)

    # p^{*1}_{0|t}(*2|y): (D, S) float array of marginal likelihoods for each dimension predicted by the model
    p0t_eval_y = softmax(x0_logits, axis=-1)

    # q^{*1}_{t|0}(y^d|*2): (D, S) float array of transition probabilities to y
    # However, now each dimension is assumed to be a mask token
    # This is the change that fixes everything!
    # eps is added to keep the division below away from zero.
    qt0_eval_y = qt0[:,mask][None] + eps

    # Sum over the clean state: st[d, x] = sum_0 qt0[0, x] * p0t[d, 0] / qt0_eval_y[0, 0].
    st_eval_y = jnp.einsum("0x,d0->dx", qt0, p0t_eval_y / qt0_eval_y, 
                           precision=jax.lax.Precision.HIGHEST)

    # Since every dimension considers itself as the mask, we set the ratio to 1
    st_eval_y = st_eval_y.at[:, mask].set(1.0)
    backward_score_to_curr = st_eval_y[jnp.arange(D), y] # "+ eps" is disabled here
    # On mask dimensions this is dividing by 1, on non-mask it offsets the score function to be centered on y
    st_eval_y /= backward_score_to_curr[:,None]

    # log score is easier to compute
    # alpha is the (0, 0) entry of the transition matrix.
    alpha = qt0[0,0]
    log_score = x0_logits + jnp.log(alpha) - jnp.log(1-alpha)
    log_score = log_score.at[:, mask].set(0)
    # Re-centre each row on the entry of the current token.
    log_score = log_score - log_score[jnp.arange(D), y][:, None]

    # (D, S) float array that masks out y[d] for each d index
    y_mask = jnp.ones((D, S))
    y_mask = y_mask.at[jnp.arange(D), y].set(0.0)

    results = {
        "score": st_eval_y,
        "log_score": log_score,
        # Score-weighted rates with the current token's own entry zeroed.
        "rates": (st_eval_y * Rt_eval_y) * y_mask,
        "x0_logits": x0_logits,
        "Rt_eval_y": Rt_eval_y,
        # Rows of Rt selected by the current tokens.
        "Rt_eval_x": Rt[y],
        "rate_scalar": forward_process._rate_scalar(t)
    }
    return results

def mask_conditonal_gibbs_update(key, x, x0_logits, k=1, mask=1024, temperature=0):
    """Resample the ``k`` least-confident non-mask positions of ``x``.

    Confidence of a position is the margin between the logit of its current
    token and the best logit among all other (non-mask) tokens, optionally
    perturbed with Gumbel noise. Masked positions are never modified.

    Args:
        key: JAX PRNG key.
        x: Integer array of current tokens, one per position.
        x0_logits: Denoising logits indexed as ``[position, token]``.
        k: Number of positions to resample. ``k <= 0`` leaves ``x``
            unchanged. May be a Python int or a traced scalar.
        mask: Id of the mask token.
        temperature: Scale of the Gumbel noise added to the confidences.

    Returns:
        Array like ``x`` with the selected positions replaced by tokens
        sampled from ``x0_logits``. Ties at the threshold may cause more
        than ``k`` positions to be resampled.
    """
    D = x.shape[0]
    rows = jnp.arange(D)

    key1, key2 = jr.split(key)

    # The mask token is never a valid resampling target.
    logits = x0_logits.at[:, mask].set(-jnp.inf)
    # Sample a candidate replacement for every position.
    jump_target = jr.categorical(key1, logits).astype(jnp.int32)

    # Margin trick: confidence = logit of x minus the best other logit, so
    # the positions that can be improved the most are updated first.
    x_logits = logits[rows, x]
    best_other = jnp.max(logits.at[rows, x].set(-jnp.inf), axis=-1)
    scores = x_logits - best_other

    # Temperature annealing. Subtracted because the smallest scores are
    # selected, whereas the usual Gumbel trick adds noise and takes the max.
    scores -= temperature * jr.gumbel(key2, shape=(D,))

    # Masked positions get infinite confidence so they sort last.
    scores = jnp.where(x == mask, jnp.inf, scores)

    # k-th smallest score acts as the selection threshold.
    thres = jnp.sort(scores, axis=-1)[k - 1]
    selected = (scores <= thres) & (x != mask)
    # For k <= 0 the index k - 1 wraps around to the largest score, which
    # would select every non-mask position; explicitly select none instead.
    selected = jnp.logical_and(selected, jnp.asarray(k > 0))

    return jnp.where(selected, jump_target, x)

def md4_predictor_update(key, x, x0_logits, unmask_prob, mask=1024):
    """Unmask positions of ``x`` according to the model's denoising distribution.

    Each position draws from a categorical whose mass is ``unmask_prob``
    times the softmax of ``x0_logits``, with the ``mask`` column overwritten
    by ``1 - unmask_prob``. Only positions currently equal to ``mask`` take
    the drawn value; all others are returned unchanged.
    """
    transition_probs = unmask_prob * softmax(x0_logits, axis=-1)
    # Probability of remaining masked during this step.
    transition_probs = transition_probs.at[:, mask].set(1 - unmask_prob)

    sampled = tfd.Categorical(probs=transition_probs).sample(seed=key)
    still_masked = x == mask
    return jnp.where(still_masked, sampled, x)

def remdm_predictor_update(key, x, x0_logits, unmask_prob, sigma_t,
    mask=1024, softmax_temp=0.8):
    """Sample one predictor step in which unmasked tokens may be remasked.

    Masked positions draw from ``unmask_prob`` times the tempered softmax of
    ``x0_logits``, with probability ``1 - unmask_prob`` of staying masked.
    Unmasked positions keep their token with probability ``1 - sigma_t`` and
    become ``mask`` with probability ``sigma_t``.
    """
    tempered = softmax(x0_logits / softmax_temp, axis=-1)
    from_mask = (unmask_prob * tempered).at[:, mask].set(1 - unmask_prob)

    # Identity transition, except for a small probability of remasking.
    num_classes = from_mask.shape[-1]
    from_token = jax.nn.one_hot(x, num_classes) * (1 - sigma_t)
    from_token = from_token.at[:, mask].set(sigma_t)

    is_unmasked = x[..., None] != mask
    probs = jnp.where(is_unmasked, from_token, from_mask)

    return tfd.Categorical(probs=probs).sample(seed=key)

def backward_process_remdm(apply_fn, params, ts, config, xT, key, forward_process):
    """Sample with the REMDM predictor, or MD4 predictor plus a Gibbs corrector.

    Args:
        apply_fn: Model apply function forwarded to ``compute_backward``.
        params: Model parameters forwarded to ``compute_backward``.
        ts: Array of times; step ``i`` moves from ``ts[i]`` to ``ts[i+1]``.
        config: Config object; reads ``data.codebook_size`` and
            ``sampler.{corrector, restricted}``, plus ``sampler.{k,
            anneal_temperature, top_k_temperature}`` in Gibbs mode or
            ``sampler.sigma_scale`` otherwise.
        xT: Initial token array; only ``xT[1:-1]`` is ever updated.
        key: JAX PRNG key.
        forward_process: Object providing ``mask_percentage(t)`` and whatever
            ``compute_backward`` needs.

    Returns:
        ``(x0_pred, x_hist)``: the final token prediction for the inner
        positions and the stacked state after every step.

    Raises:
        ValueError: If ``config.sampler.corrector`` is not ``"gibbs"``,
            ``"forward_backward"`` or ``""``.
    """
    S = config.data.codebook_size + 1
    mask = S - 1

    corrector = config.sampler.corrector
    use_gibbs = corrector == "gibbs"
    if not use_gibbs and corrector not in ("forward_backward", ""):
        raise ValueError("Only gibbs and forward backward correctors are supported for REMDM")

    def _step(carry, idx):
        x, key = carry
        key, p_key, c_key = jr.split(key, 3)

        t = ts[idx]
        dt = t - ts[idx+1]
        res = compute_backward(x, t, apply_fn, params, config, forward_process)

        # Closed-form MD4-style update driven by the mask schedule, so no
        # transition rates are needed.
        m1 = forward_process.mask_percentage(t)
        m2 = forward_process.mask_percentage(t-dt)

        if use_gibbs:
            unmask_prob = (m1 - m2) / m1

            # Apply the corrector first, reusing the logits of this step so
            # that no second forward pass is required.
            rc = gibbs_corrector(res)
            temperature_coeff = t if config.sampler.anneal_temperature else 1

            c_update = mask_conditonal_gibbs_update(
                c_key, x[1:-1], rc, k=config.sampler.k, mask=mask,
                temperature=config.sampler.top_k_temperature * temperature_coeff)
            x = x.at[1:-1].set(c_update)

            p_update = md4_predictor_update(p_key, x[1:-1], res["x0_logits"], unmask_prob, mask=mask)
            x = x.at[1:-1].set(p_update)

        else:
            alpha_t = 1 - m1
            alpha_s = 1 - m2

            # Remasking changes the unmask probability: sigma_t is a
            # fraction of the largest admissible remasking probability.
            sigma_t_max = jnp.minimum((1 - alpha_s) / alpha_t, 1)
            sigma_t = config.sampler.sigma_scale * sigma_t_max

            unmask_prob = (alpha_s - (1 - sigma_t) * alpha_t) / (1 - alpha_t)

            p_update = remdm_predictor_update(p_key, x[1:-1], res["x0_logits"], unmask_prob, sigma_t, mask=mask)
            x = x.at[1:-1].set(p_update)

        out = { "x": x, }

        return (x, key), out

    (x, _), x_hist = jax.lax.scan(_step, (xT, key), jnp.arange(len(ts)-1))

    # Final denoising pass at the time the scan ended on. The scan never
    # updates an outer time variable, so the end time is read from ts
    # directly instead of reusing the start time ts[0].
    res = compute_backward(x, ts[-1], apply_fn, params, config, forward_process)
    x0_logits = res["x0_logits"]

    if not config.sampler.restricted:
        x0_pred = jnp.argmax(x0_logits, axis=1)
    else:
        # Instead of potentially updating every position, update only the masks
        x0_pred = jnp.where(x[1:-1] == mask, jnp.argmax(x0_logits, axis=1), x[1:-1])

    return x0_pred, x_hist["x"]

def backward_process_corrector(p_apply_fn, p_params, ts, config, xT, key, forward_process,
    c_apply_fn=None, c_params=None):  

    S = config.data.codebook_size + 1
    D = config.data.seq_length
    mask = S - 1
    k = config.sampler.k
    t = ts[0]
    x = xT
    
    c_apply_fn = c_apply_fn or p_apply_fn
    c_params = c_params or p_params

    corrector = config.sampler.corrector
    corrector_update = mask_conditonal_gibbs_update

    if corrector == "gibbs":
        corrector_rate = gibbs_corrector
    elif corrector == "forward_backward":
        corrector_rate = forward_backward_corrector
    else:
        raise Exception(f"Unknown corrector: {corrector}")

    def _c_step(i, carry):
        x, key, t, k = carry
        key, c_key = jr.split(key, 2)
        
        res = compute_backward(x, t, c_apply_fn, c_params, config, forward_process)
        rc = corrector_rate(res)

        if corrector == "gibbs":
            temperature_coeff = t if config.sampler.anneal_temperature else 1
            x_update = corrector_update(c_key, x[1:-1], rc, k=k, mask=mask,
                temperature=config.sampler.top_k_temperature * temperature_coeff)
        else:
            corrector_step_size = config.sampler.corrector_step_size
            x_update = euler_update(c_key, x[1:-1], rc * corrector_step_size)

        x = x.at[1:-1].set(x_update)
        
        return (x, key, t, k)

    def _step(carry, idx):
        x, key = carry
        key, p_key, c_key = jr.split(key, 3)

        t = ts[idx]
        dt = t - ts[idx+1]
        res = compute_backward(x, t, p_apply_fn, p_params, config, forward_process)
        
        # Changing update function from euler to MD4 (closed form?)
        # This means that we no longer need rates
        m1 = forward_process.mask_percentage(t)
        m2 = forward_process.mask_percentage(t-dt)
        unmask_prob = (m1 - m2) / m1
        update = md4_predictor_update(p_key, x[1:-1], res["x0_logits"], unmask_prob, mask=mask)

        x = x.at[1:-1].set(update)

        # Change current time (!!)
        t -= dt

        # Corrector
        if (k <= 0) and corrector == "gibbs":
            # We only apply the corrector if k > 0 to avoid unnecessary computation
            pass
        else:
            x = jax.lax.cond(t <= config.sampler.corrector_entry_time,
                            lambda x: jax.lax.fori_loop(0, config.sampler.num_corrector_steps, 
                                                        _c_step, (x, c_key, t, k))[0],
                            lambda x: x, x)

        out = { "x": x, }
        
        return (x, key), out

    (x, _), x_hist = jax.lax.scan(_step, (xT, key), jnp.arange(len(ts)-1))
    res = compute_backward(x, t, c_apply_fn, c_params, config, forward_process)
    x0_logits = res["x0_logits"]

    if not config.sampler.restricted:
        x0_pred = jnp.argmax(x0_logits, axis=1)
    else:
        # Instead of potentially updating every position, update only the masks
        x0_pred = jnp.where(x[1:-1] == mask, jnp.argmax(x0_logits, axis=1), x[1:-1])

    return x0_pred, x_hist["x"]

def maskgit(apply_fn, params, ts, config, xT, key, forward_process):
    """Sample with the parallel ``decode`` routine, then denoise once more.

    ``decode`` is run for ``config.sampler.num_steps`` iterations with a
    cosine mask schedule. The last decoded state is then passed through
    ``compute_backward`` at time 0 to obtain the final prediction.

    Returns:
        ``(x0_pred, x_hist)``: the final token prediction for the inner
        positions and the decoding history of the single batch entry.
    """
    num_states = config.data.codebook_size + 1
    mask_id = num_states - 1

    def tokens_to_logits(tokens):
        logits = apply_fn({"params": params}, tokens, t=0, deterministic=True)
        # Label positions are kept: they are never updated anyway.
        return logits[..., :num_states]

    # decode works on batches, so a leading batch axis is added.
    x_hist = decode(
        xT[None],
        key,
        tokens_to_logits,
        mask_token_id=mask_id,
        num_iter=config.sampler.num_steps,
        start_iter=0,
        choice_temperature=config.sampler.maskgit_temperature,
        mask_scheduling_method="cosine",
    )

    final_tokens = x_hist[0, -1]

    x0_logits = compute_backward(
        final_tokens, 0, apply_fn, params, config, forward_process
    )["x0_logits"]

    if config.sampler.restricted:
        # Only masked positions take the model's prediction; every other
        # position keeps its decoded token.
        inner = final_tokens[1:-1]
        x0_pred = jnp.where(inner == mask_id, jnp.argmax(x0_logits, axis=1), inner)
    else:
        x0_pred = jnp.argmax(x0_logits, axis=1)

    return x0_pred, x_hist[0]