from typing import Callable, Optional, Dict, Tuple
from jaxtyping import Array, PyTree

from jax import tree, numpy as jnp, random as jr

from .tfmpe import TFMPE
from ..preprocessing.utils import Labeller, Independence
from ..preprocessing.tokens import Tokens

def truncated_proposal_rejection(
    key: Array,
    model: TFMPE,
    labeller: Labeller,
    independence: Independence,
    f_in: Dict[str, Array],
    n_samples: int,
    epsilon: float,
    y_obs: Dict[str, Array],
    prior_fn: Callable,
    n: int,
    prior_log_prob: Callable[[PyTree], float],
    prob_transform: Optional[Callable[[PyTree, Array], float]] = None,
    n_estimate: int =  10_000, #1_000_000,
    n_batch: Optional[int] = None,
    ) -> PyTree:
    """Sample a truncated proposal via rejection sampling from the prior.

    The threshold ``tau`` is the ``epsilon``-quantile of the model's
    (optionally ``prob_transform``-ed) log probability over ``n_estimate``
    samples drawn from the approximate posterior. Batches of prior draws
    are then scored with the same log probability and only draws with
    ``log_prob > tau`` are kept, until ``n_samples`` have been accepted.
    The posterior samples are used only to set ``tau``; they are not part
    of the returned set.

    Args:
        key: PRNG key.
        model: Model providing posterior sampling and log probabilities.
        labeller: Token labeller passed to ``Tokens.from_pytree``.
        independence: Independence spec passed to ``Tokens.from_pytree``.
        f_in: Functional inputs passed to ``Tokens.from_pytree``.
        n_samples: Number of accepted samples to return.
        epsilon: Quantile level of the log probability used as threshold.
        y_obs: Observed data the posterior is conditioned on.
        prior_fn: Prior sampler called as ``prior_fn(key, n=..., n_samples=...)``.
        n: Forwarded to ``prior_fn``.
        prior_log_prob: Unused here; kept so the signature matches
            ``truncated_proposal_sir``.
        prob_transform: Optional ``(values, log_prob) -> log_prob`` map
            applied to model log probabilities.
        n_estimate: Number of posterior samples used to estimate ``tau``.
        n_batch: Batch size for sampling; defaults to ``n_samples``.

    Returns:
        Pytree of accepted parameter samples, truncated to ``n_samples``
        along the leading axis.

    Raises:
        ValueError: If ``n_samples`` is not positive.

    Note:
        The loop only terminates once enough prior draws pass the
        threshold; if (almost) none do, it will run indefinitely.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be positive, got {n_samples}")
    if n_batch is None:
        n_batch = n_samples
    estimate_key, key = jr.split(key)

    print('estimating tau')

    print('Sampling')
    # Only the log probabilities are needed: the posterior draws serve to
    # calibrate tau and must not be mixed into the truncated-prior samples.
    _, estimate_log_prob = _batch_sample(
        estimate_key,
        model,
        labeller,
        independence,
        prior_fn,
        n,
        y_obs,
        f_in,
        n_batch,
        n_estimate,
        prob_transform,
    )

    tau = jnp.quantile(estimate_log_prob, epsilon)
    print(f"tau = {tau}")

    accepted_batches = []
    n_accepted = 0
    while n_accepted < n_samples:
        print(f'progress: {n_accepted} out of {n_samples}')
        sample_key, key = jr.split(key)
        new_samples, new_log_prob = _batch_sample_prior(
            sample_key,
            model,
            labeller,
            independence,
            prior_fn,
            n,
            y_obs,
            f_in,
            n_batch,
            n_batch,
            prob_transform,
        )
        # The mask is computed from this batch only, so it stays aligned
        # with new_samples.
        keep = new_log_prob > tau
        accepted_batches.append(
            tree.map(
                lambda leaf: jnp.compress(keep, leaf, axis=0),
                new_samples,
            )
        )
        n_accepted += int(jnp.sum(keep))

    samples = tree.map(
        lambda *leaves: jnp.concatenate(leaves, axis=0)[:n_samples],
        *accepted_batches,
    )
    print(f"e.g.: {tree.map(lambda leaf: leaf[:2], samples)}")

    return samples

def truncated_proposal_sir(
    key: Array,
    model: TFMPE,
    labeller: Labeller,
    independence: Independence,
    f_in: Dict[str, Array],
    n_samples: int,
    epsilon: float,
    y_obs: Dict[str, Array],
    prior_fn: Callable,
    n: int,
    prior_log_prob: Callable[[PyTree], float],
    prob_transform: Optional[Callable[[PyTree, Array], float]] = None,
    n_estimate: int =  10_000, #1_000_000,
    n_batch: Optional[int] = None,
    ) -> PyTree:
    """Sample a truncated proposal via sampling importance resampling.

    Draws ``n_estimate`` samples from the approximate posterior, sets
    ``tau`` to the ``epsilon``-quantile of their (optionally
    ``prob_transform``-ed) log probabilities, keeps the samples with
    ``log_prob > tau``, and resamples ``n_samples`` of those with
    replacement using log importance weights
    ``prior_log_prob(samples) - log_prob``.

    Args:
        key: PRNG key.
        model: Model providing posterior sampling and log probabilities.
        labeller: Token labeller passed to ``Tokens.from_pytree``.
        independence: Independence spec passed to ``Tokens.from_pytree``.
        f_in: Functional inputs passed to ``Tokens.from_pytree``.
        n_samples: Number of resampled draws to return.
        epsilon: Quantile level of the log probability used as threshold.
        y_obs: Observed data the posterior is conditioned on.
        prior_fn: Prior sampler called as ``prior_fn(key, n=..., n_samples=...)``.
        n: Forwarded to ``prior_fn``.
        prior_log_prob: Prior log density evaluated on the sample pytree.
        prob_transform: Optional ``(values, log_prob) -> log_prob`` map
            applied to model log probabilities.
        n_estimate: Number of posterior samples to draw.
        n_batch: Batch size for sampling; defaults to ``n_samples``.

    Returns:
        Pytree of ``n_samples`` resampled parameter draws.

    Raises:
        ValueError: If no sample lies strictly above ``tau`` (e.g. for
            ``epsilon`` close to 1), leaving nothing to resample from.
    """
    if n_batch is None:
        n_batch = n_samples
    estimate_key, resample_key = jr.split(key)

    print('estimating tau')

    print('Sampling')
    samples, log_prob = _batch_sample(
        estimate_key,
        model,
        labeller,
        independence,
        prior_fn,
        n,
        y_obs,
        f_in,
        n_batch,
        n_estimate,
        prob_transform,
    )

    tau = jnp.quantile(log_prob, epsilon)
    # Mask of samples inside the estimated high-probability region.
    m = log_prob > tau
    n_w = jnp.sum(m)
    print(f"tau = {tau}")
    # Divide by the number of samples actually drawn; it can exceed
    # n_estimate when n_estimate is not a multiple of n_batch.
    print(f"acceptance = {n_w / log_prob.shape[0]}")
    print(f"e.g.: {tree.map(lambda leaf: jnp.compress(m, leaf, axis=0)[:3], samples)}")

    print('Importance')
    # Log importance weights (target prior over proposal). NOTE(review):
    # log_prob is the value after prob_transform, if one was given.
    w = prior_log_prob(samples) - log_prob
    w_in_m = jnp.extract(m, w)
    print(f"w_in_m = {w_in_m}")
    print(f"n_w = {n_w}")
    if int(n_w) == 0:
        raise ValueError(
            f"No samples have log_prob above tau = {tau}; "
            "cannot resample (try a smaller epsilon)."
        )

    print('resampling')
    indices = jr.categorical(
        resample_key,
        logits=w_in_m,
        shape=(n_samples,),
        replace=True
    )

    samples = tree.map(
        lambda leaf: jnp.compress(m, leaf, axis=0)[indices],
        samples
    )
    print(f"e.g.: {tree.map(lambda leaf: leaf[:2], samples)}")

    return samples

def _batch_sample(
    key: Array,
    model: TFMPE,
    labeller: Labeller,
    independence: Independence,
    prior_fn: Callable,
    n: int,
    y_obs: Dict[str, Array],
    f_in: Dict[str, Array],
    n_batch: int,
    n_total: int,
    prob_transform: Optional[Callable[[PyTree, Array], float]] = None,
    ) -> Tuple[PyTree, Array]:
    """Draw approximate-posterior samples in batches, with their log probs.

    Builds a token template from one prior draw (for structure and shapes)
    and the observation ``y_obs`` broadcast to ``n_batch``. It then calls
    ``model.sample_posterior`` batch by batch and scores each batch with
    ``model.log_prob_posterior_samples``, followed by ``prob_transform`` if
    given. Batches are drawn until at least ``n_total`` samples exist, and
    the result is trimmed to exactly ``n_total`` along the leading axis.

    Args:
        key: PRNG key; only used for the template prior draw.
            NOTE(review): ``model.sample_posterior`` is not given a key, so
            any sampling randomness must come from the model itself.
        model: Model providing posterior sampling and log probabilities.
        labeller: Token labeller passed to ``Tokens.from_pytree``.
        independence: Independence spec passed to ``Tokens.from_pytree``.
        prior_fn: Prior sampler called as ``prior_fn(key, n=..., n_samples=...)``.
        n: Forwarded to ``prior_fn``.
        y_obs: Observed data; conditioned on during sampling.
        f_in: Functional inputs passed to ``Tokens.from_pytree``.
        n_batch: Number of samples per model call.
        n_total: Total number of samples to return.
        prob_transform: Optional ``(values, log_prob) -> log_prob`` map.

    Returns:
        ``(values, log_prob)``: a dict holding only the parameter keys of
        the prior's output, each trimmed to ``n_total`` samples, and the
        matching 1-D array of ``n_total`` log probabilities.

    Raises:
        ValueError: If ``n_batch`` or ``n_total`` is not positive.
    """
    if n_batch < 1:
        raise ValueError(f"n_batch must be positive, got {n_batch}")
    if n_total < 1:
        raise ValueError(f"n_total must be positive, got {n_total}")

    # One prior draw is only used to discover the parameter pytree and the
    # per-sample shapes for the token template.
    samples, _ = prior_fn(
        key,
        n=n,
        n_samples=1,
    )

    theta_template = tree.map(
        lambda leaf: jnp.zeros((n_batch,) + leaf.shape[1:]),
        samples,
    )
    y_expanded = tree.map(
        lambda leaf: jnp.broadcast_to(leaf, (n_batch,) + leaf.shape[1:]),
        y_obs,
    )

    tokens, decoder = Tokens.from_pytree(
        {**y_expanded, **theta_template},
        condition=list(y_expanded.keys()),
        labeller=labeller,
        sample_ndims=1,
        independence=independence,
        functional_inputs=f_in,
        return_decoder=True
    )

    # Ceiling division: the last batch may overshoot n_total and is trimmed.
    n_batches = -(-n_total // n_batch)
    token_batches = []
    log_prob_batches = []

    for i in range(n_batches):
        print(f'progress: {i * n_batch} of {n_total}')
        print(f'sampling {n_batch}')

        output_tokens = model.sample_posterior(tokens)
        new_log_prob = model.log_prob_posterior_samples(output_tokens)

        print(f'new_log_prob: {new_log_prob}')
        print(f'new_log_prob shape: {new_log_prob.shape}')

        if prob_transform is not None:
            new_log_prob = prob_transform(decoder(output_tokens), new_log_prob)

        print(f'new_log_prob (post transform): {new_log_prob}')

        token_batches.append(output_tokens)
        log_prob_batches.append(new_log_prob)

    # Concatenate once at the end instead of re-copying every iteration.
    if len(token_batches) == 1:
        all_tokens = token_batches[0]
    else:
        all_tokens = tree.map(
            lambda *leaves: jnp.concatenate(leaves),
            *token_batches,
        )
    log_prob = jnp.concatenate(log_prob_batches)[:n_total]

    values = decoder(all_tokens)
    values = {
        k: tree.map(lambda leaf: leaf[:n_total], v)
        for k, v in values.items()
        if k in theta_template
    }

    return values, log_prob

def _batch_sample_prior(
    key: Array,
    model: TFMPE,
    labeller: Labeller,
    independence: Independence,
    prior_fn: Callable,
    n: int,
    y_obs: Dict[str, Array],
    f_in: Dict[str, Array],
    n_batch: int,
    n_total: int,
    prob_transform: Optional[Callable[[PyTree, Array], float]] = None,
    ) -> Tuple[PyTree, Array]:
    """Draw prior samples in batches and score them under the approximate posterior.

    Each batch of ``n_batch`` prior draws is tokenised together with the
    observation ``y_obs`` broadcast to ``n_batch``. The batch is scored with
    ``model.log_prob_posterior_samples``, followed by ``prob_transform`` if
    given. Batches are drawn until at least ``n_total`` samples exist, and
    the result is trimmed to exactly ``n_total`` along the leading axis.
    Every batch uses its own PRNG key, so batches are independent draws.

    Args:
        key: PRNG key for the prior draws.
        model: Model providing posterior log probabilities.
        labeller: Token labeller passed to ``Tokens.from_pytree``.
        independence: Independence spec passed to ``Tokens.from_pytree``.
        prior_fn: Prior sampler called as ``prior_fn(key, n=..., n_samples=...)``.
        n: Forwarded to ``prior_fn``.
        y_obs: Observed data; conditioned on when scoring.
        f_in: Functional inputs passed to ``Tokens.from_pytree``.
        n_batch: Number of prior draws per batch.
        n_total: Total number of samples to return.
        prob_transform: Optional ``(values, log_prob) -> log_prob`` map.

    Returns:
        ``(values, log_prob)``: a dict holding only the parameter keys of
        the prior's output, each trimmed to ``n_total`` samples, and the
        matching 1-D array of ``n_total`` log probabilities.

    Raises:
        ValueError: If ``n_batch`` or ``n_total`` is not positive.
    """
    if n_batch < 1:
        raise ValueError(f"n_batch must be positive, got {n_batch}")
    if n_total < 1:
        raise ValueError(f"n_total must be positive, got {n_total}")

    # Separate key for the structural template draw so it is not reused by
    # the first real batch.
    template_key, key = jr.split(key)
    samples, _ = prior_fn(
        template_key,
        n=n,
        n_samples=1,
    )

    theta_template = tree.map(
        lambda leaf: jnp.zeros((n_batch,) + leaf.shape[1:]),
        samples,
    )
    y_expanded = tree.map(
        lambda leaf: jnp.broadcast_to(leaf, (n_batch,) + leaf.shape[1:]),
        y_obs,
    )

    _, decoder = Tokens.from_pytree(
        {**y_expanded, **theta_template},
        condition=list(y_expanded.keys()),
        labeller=labeller,
        sample_ndims=1,
        independence=independence,
        functional_inputs=f_in,
        return_decoder=True
    )

    # Ceiling division: the last batch may overshoot n_total and is trimmed.
    n_batches = -(-n_total // n_batch)
    token_batches = []
    log_prob_batches = []

    for i in range(n_batches):
        print(f'progress: {i * n_batch} of {n_total}')
        print(f'sampling {n_batch}')

        # Fresh key per batch; reusing one key would repeat the same draws.
        batch_key, key = jr.split(key)
        prior_samples, _ = prior_fn(
            batch_key,
            n=n,
            n_samples=n_batch,
        )
        output_tokens = Tokens.from_pytree(
            {**y_expanded, **prior_samples},
            condition=list(y_expanded.keys()),
            labeller=labeller,
            sample_ndims=1,
            independence=independence,
            functional_inputs=f_in,
        )
        new_log_prob = model.log_prob_posterior_samples(output_tokens)

        print(f'new_log_prob: {new_log_prob}')
        print(f'new_log_prob shape: {new_log_prob.shape}')

        if prob_transform is not None:
            new_log_prob = prob_transform(decoder(output_tokens), new_log_prob)

        print(f'new_log_prob (post transform): {new_log_prob}')

        token_batches.append(output_tokens)
        log_prob_batches.append(new_log_prob)

    # Concatenate once at the end instead of re-copying every iteration.
    if len(token_batches) == 1:
        all_tokens = token_batches[0]
    else:
        all_tokens = tree.map(
            lambda *leaves: jnp.concatenate(leaves),
            *token_batches,
        )
    log_prob = jnp.concatenate(log_prob_batches)[:n_total]

    values = decoder(all_tokens)
    values = {
        k: tree.map(lambda leaf: leaf[:n_total], v)
        for k, v in values.items()
        if k in theta_template
    }

    return values, log_prob
