from typing import NamedTuple, Tuple, Iterable, Callable, Optional, Protocol

import jax.lax
import jax.numpy as jnp
import chex
from functools import partial
import jaxlib
from jax._src.scipy.special import logsumexp

from algorithms.fab.utils.jax_util import broadcasted_where


def sample_without_replacement(key: chex.Array, logits: chex.Array, n: int) -> chex.Array:
    """Pick `n` distinct indices using the Gumbel top-k trick.

    Independent Gumbel noise is added to every logit, the `n` largest perturbed
    entries are selected, and the selected indices are then shuffled so that the
    output is not ordered by perturbed score.
    Reference: https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/

    Args:
        key: PRNG key; split once, one half for the noise and one for the shuffle.
        logits: Unnormalised log-probabilities (assumed 1-D - TODO confirm; `top_k`
            acts on the last axis while the shuffle acts on the leading axis).
        n: Number of indices to select.

    Returns:
        The `n` selected indices, in shuffled order.
    """
    noise_key, shuffle_key = jax.random.split(key)
    gumbel_noise = jax.random.gumbel(key=noise_key, shape=logits.shape)
    # Exact top-k is used; `jax.lax.approx_max_k` was left as a commented-out
    # alternative in an earlier revision of this function.
    _, chosen = jax.lax.top_k(gumbel_noise + logits, n)
    return jax.random.permutation(shuffle_key, chosen)


class Data(NamedTuple):
    """Samples paired with their log weights, as produced by annealed importance sampling."""
    # The samples.
    x: chex.Array
    # Log weights belonging to the samples in `x`.
    log_w: chex.Array


class PrioritisedBufferState(NamedTuple):
    """Buffer state: the stored data plus the bookkeeping needed for its use."""
    # Stored samples and their log weights.
    data: Data
    # Presumably set once the buffer has been filled to capacity - confirm
    # against the `add` implementation.
    is_full: jnp.bool_
    # Presumably set once enough data is stored for sampling to be allowed -
    # confirm against the `add` implementation.
    can_sample: jnp.bool_
    # Presumably the next write position in the buffer - confirm against the
    # `add` implementation.
    current_index: jnp.int32


class InitFn(Protocol):
    """Call signature of the function that creates the initial buffer state."""

    def __call__(
        self,
        x: chex.Array,
        log_w: chex.Array,
        on_cpu: bool,
    ) -> PrioritisedBufferState:
        """Build the initial buffer state from a first set of samples.

        The buffer is initialised by filling it above `min_sample_length`.

        Args:
            x: Initial samples.
            log_w: Log weights of the initial samples.
            on_cpu: Flag whose handling is up to the implementation (presumably
                selects host-side storage - TODO confirm).

        Returns:
            The initialised buffer state.
        """
        ...


class AddFn(Protocol):
    """Call signature of the function that inserts new data into the buffer."""

    def __call__(
        self,
        x: chex.Array,
        log_w: chex.Array,
        buffer_state: PrioritisedBufferState,
    ) -> PrioritisedBufferState:
        """Return the buffer state updated with a new batch of data.

        Args:
            x: New batch of samples.
            log_w: Log weights of the new samples.
            buffer_state: State of the buffer before the update.

        Returns:
            The buffer state after the new batch has been added.
        """
        ...


class SampleFn(Protocol):
    """Call signature of the function that samples one batch from the buffer."""

    def __call__(self, key: chex.PRNGKey,
                 buffer_state: PrioritisedBufferState,
                 batch_size: int,
                 subtraj_id: Optional[int] = None) -> Tuple[chex.Array, chex.Array]:
        """
        Sample a batch from the buffer in proportion to the log weights.
        Args:
            key: PRNG key used for the sampling.
            buffer_state: Buffer state to sample from.
            batch_size: Number of samples to draw.
            subtraj_id: Optional identifier, `None` by default; its meaning is
                defined by the implementation (not visible here - TODO document).
        Returns:
            x: Samples.
            indices: Indices of samples for their location in the buffer state.
        """


class SampleNBatchesFn(Protocol):
    """Call signature of the function that samples several batches at once."""

    def __call__(
        self,
        key: chex.PRNGKey,
        buffer_state: PrioritisedBufferState,
        batch_size: int,
        n_batches: int,
    ) -> Iterable[Tuple[chex.Array, chex.Array, chex.Array]]:
        """Return a dataset with n-batches on the leading axis; see `SampleFn`.

        NOTE(review): each element is annotated here as a 3-tuple of arrays,
        while `SampleFn` is annotated as returning a 2-tuple - confirm which
        of the two is intended.
        """
        ...


class AdjustFn(Protocol):
    """Call signature of the function that overwrites stored log weights."""

    def __call__(self, log_w_new: chex.Array,
                 indices: chex.Array,
                 buffer_state: PrioritisedBufferState,
                 subtraj_id: Optional[int] = None) \
            -> PrioritisedBufferState:
        """Adjust log weights to match the new value of theta.

        This is typically performed over minibatches, rather than over the
        whole dataset at once.

        Args:
            log_w_new: New log weights for the selected entries.
            indices: Locations in the buffer state of the entries to adjust
                (presumably the indices returned when sampling - confirm).
            buffer_state: Buffer state before the adjustment.
            subtraj_id: Optional identifier, `None` by default; its meaning is
                defined by the implementation (not visible here - TODO document).

        Returns:
            The buffer state with the adjusted log weights.
        """


class PrioritisedBuffer(NamedTuple):
    """Bundle of the functions and size settings that make up a prioritised buffer."""
    # Creates the initial buffer state.
    init: InitFn
    # Adds a new batch of data to a buffer state.
    add: AddFn
    # Samples a single batch from a buffer state.
    sample: SampleFn
    # Samples several batches from a buffer state in one call.
    sample_n_batches: SampleNBatchesFn
    # Presumably the minimum amount of stored data before sampling is allowed -
    # confirm against the factory. NOTE: the spelling ("lengtht") is part of the
    # public field name and is kept so existing callers keep working.
    min_lengtht_to_sample: int
    # Presumably the buffer capacity - confirm against the factory.
    max_length: int
    # Optional function for adjusting stored log weights; `None` when not provided.
    upd_weights: Optional[AdjustFn] = None
    


