from typing import NamedTuple, Tuple, Iterable, Callable, Optional, Protocol

import jax.lax
import jax.numpy as jnp
import chex
from functools import partial

from jax._src.scipy.special import logsumexp

from algorithms.fab.utils.jax_util import broadcasted_where
from algorithms.scld.prioritised_buffer_utils import *

# This file is deprecated

def build_prioritised_buffer(  # todo incorporate prio using the target density
        dim: int,
        n_sub_traj: int,
        max_length: int,
        min_length_to_sample: int,
        prioritized: bool = True,
        sample_with_replacement: bool = False
) -> PrioritisedBuffer:
    """
    Create replay buffer for batched sampling and adding of data.

    The buffer stores, for each of the `n_sub_traj + 1` sub-trajectory start points,
    up to `max_length` points of dimension `dim` together with one log weight per point.

    Args:
        dim: Dimension of x data.
        n_sub_traj: Number of sub-trajectories; the buffer's leading axis has
            `n_sub_traj + 1` entries.
        max_length: Maximum length of the buffer.
        min_length_to_sample: Minimum length of buffer required for sampling.
        prioritized: Currently unused by this function (kept for interface compatibility).
        sample_with_replacement: Whether to sample with replacement.

    Returns:
        A `PrioritisedBuffer` bundling the `init`, `add`, `sample` and `sample_n_batches`
        functions together with `min_length_to_sample` and `max_length`.

    The `max_length` and `min_length_to_sample` should be sufficiently long to prevent overfitting
    to the replay data. For example, if `min_length_to_sample` is equal to the
    sampling batch size, then we may overfit to the first batch of data, as we would update
    on it many times during the start of training.
    """
    assert min_length_to_sample <= max_length

    def init(x: chex.Array, log_w: chex.Array) -> PrioritisedBufferState:
        """
        Initialise the buffer state, by filling it above `min_length_to_sample`.

        Args:
            x: Rank-3 array whose trailing axis has size `dim` and whose axis 1 indexes samples.
                Presumably of shape (n_sub_traj + 1, n_samples, dim) -- the leading axis is not
                checked here.
            log_w: Log weights for `x`; presumably of shape (n_sub_traj + 1, n_samples).

        Returns:
            The buffer state after adding `x` and `log_w` to an empty buffer.
        """
        chex.assert_rank(x, 3)
        chex.assert_shape(x[0, 0], (dim,))
        n_samples = x.shape[1]
        assert n_samples >= min_length_to_sample, "Buffer requires at least `min_sample_length` samples for init."

        # Empty buffer: x is NaN and log_w is -inf, so unfilled slots carry zero sampling probability.
        empty_data = Data(x=jnp.zeros((n_sub_traj + 1, max_length, dim)) * float("nan"),
                          log_w=-jnp.ones((n_sub_traj + 1, max_length), ) * float("inf"))
        empty_state = PrioritisedBufferState(data=empty_data,
                                             is_full=False,  # whether the buffer is full
                                             can_sample=False,  # whether sampling may begin
                                             current_index=0)
        return add(x, log_w, empty_state)

    # I'm fairly sure log_w corresponds to logAnnealingSchedule not log_rnds
    def add(x: chex.Array, log_w: chex.Array,
            buffer_state: PrioritisedBufferState) -> PrioritisedBufferState:
        """
        Update the buffer's state with a new batch of data.

        The batch is written at positions `current_index, current_index + 1, ...` (modulo
        `max_length`), so old entries are overwritten once the buffer wraps around.

        Args:
            x: Rank-3 batch of points; axis 1 is the batch axis.
            log_w: Log weights matching the first two axes of `x`.
            buffer_state: State to update.

        Returns:
            The new buffer state.
        """
        x = jax.lax.stop_gradient(x)
        log_w = jax.lax.stop_gradient(log_w)  # best to stop gradient too in case?
        chex.assert_rank(x, 3)
        chex.assert_equal_shape((x[0, 0], buffer_state.data.x[0, 0]))
        batch_size = x.shape[1]
        valid_samples = jnp.isfinite(log_w) & jnp.all(jnp.isfinite(x), axis=-1)
        indices = (jnp.arange(batch_size) + buffer_state.current_index) % max_length

        # Remove invalid samples by keeping whatever the buffer already holds in that slot.
        # This may result in duplicating samples from first bit of buffer
        # but this is probably okay if invalid samples are rare enough
        previous_x = buffer_state.data.x[:, indices]
        previous_log_w = buffer_state.data.log_w[:, indices]
        x = broadcasted_where(valid_samples, x, previous_x)
        log_w = broadcasted_where(valid_samples, log_w, previous_log_w)

        # Add valid samples to buffer (possibly overwriting old data).
        x = buffer_state.data.x.at[:, indices].set(x)
        log_w = buffer_state.data.log_w.at[:, indices].set(log_w)

        # Keep track of index, and whether buffer is full.
        new_index = buffer_state.current_index + batch_size
        is_full = jax.lax.select(buffer_state.is_full, buffer_state.is_full,
                                 new_index >= max_length)
        can_sample = jax.lax.select(buffer_state.is_full, buffer_state.can_sample,
                                    new_index >= min_length_to_sample)
        current_index = new_index % max_length

        data = Data(x=x, log_w=log_w)
        state = PrioritisedBufferState(data=data,
                                       current_index=current_index,
                                       is_full=is_full,
                                       can_sample=can_sample)
        return state

    def sample(key: chex.PRNGKey,
               buffer_state: PrioritisedBufferState,
               batch_size: int,
               subtraj_id: Optional[int] = None) -> Tuple[chex.Array, chex.Array]:
        """
        Sample a batch from the buffer in proportion to the weights `exp(log_w)`.

        Sampling is done independently for every entry of the buffer's leading axis.

        Args:
            key: PRNG key.
            buffer_state: State to sample from.
            batch_size: Number of points to draw per leading-axis entry.
            subtraj_id: Currently unused.

        Returns:
            x: Samples: Shape is (n_sub_traj + 1, batch_size, dim)
            indices: Indices of samples for their location in the buffer state.
                Shape is (n_sub_traj + 1, batch_size).

        NOTE(review): the filled length is selected with a Python conditional on
        `buffer_state.is_full` and then used as a slice bound, so this function presumably
        cannot be traced by `jax.jit` with `buffer_state` as a traced argument -- confirm with callers.
        """
        assert batch_size <= min_length_to_sample, "Min length to sample must be greater than or equal to " \
                                                   "the batch size."
        # Only the filled part of the buffer is considered.
        buffer_size = max_length if buffer_state.is_full else buffer_state.current_index

        def sample_from_probabilities(probabilities, key):
            return jax.random.choice(key, buffer_size, shape=(batch_size,), replace=sample_with_replacement,
                                     p=probabilities)

        # Suppose our buffer is represented as (s,n,d) tensor, s = #subtraj+1, n=#items in buffer, d = problem dim
        # so this gives n possible starting points at each subtraj startpt
        # For each starting time 0<=t<s, the n pts at that time is weighed by
        # pi_t(pt) and thus we importance sample (by default without replacement) according to these
        # the lines of code below generates the sampled indices
        # where indices[i][j] means buffer[i][indices[i][j]] is the j-th sampled
        # point for the ith subtraj starting point
        log_w = buffer_state.data.log_w[:, :buffer_size]
        # Shift by the per-row maximum before exponentiating: the weights are only defined up to a
        # constant factor, and the raw exp() overflows/underflows for large-magnitude log weights.
        shift = jnp.max(log_w, axis=-1, keepdims=True)
        shift = jnp.where(jnp.isfinite(shift), shift, 0.)  # avoid -inf - (-inf) = nan
        weights = jnp.exp(log_w - shift)
        total = jnp.sum(weights, axis=-1, keepdims=True)
        probs = weights / jnp.where(total > 0, total, 1.)
        sample_fn = jax.vmap(sample_from_probabilities, in_axes=(0, 0))
        indices = sample_fn(probs, jax.random.split(key, probs.shape[0]))

        # retrieves points corresponding to indices
        return buffer_state.data.x[jnp.arange(indices.shape[0])[:, None], indices], indices

    def sample_n_batches(
            key: chex.PRNGKey,
            buffer_state: PrioritisedBufferState,
            batch_size: int,
            n_batches: int) -> Tuple[chex.Array, chex.Array]:
        """
        Returns dataset with n-batches on the leading axis.

        Draws `batch_size * n_batches` points with a single call to `sample` and splits them.

        Returns:
            x: Shape is (n_batches, n_sub_traj + 1, batch_size, dim)
            indices: Shape is (n_batches, n_sub_traj + 1, batch_size)
        """
        x, indices = sample(key, buffer_state, batch_size * n_batches)

        def split_into_batches(a):
            # `sample` puts the drawn points on axis 1, so split that axis (not axis 0)
            # and then move the new batch axis to the front.
            a = a.reshape((a.shape[0], n_batches, batch_size, *a.shape[2:]))
            return jnp.swapaxes(a, 0, 1)

        return split_into_batches(x), split_into_batches(indices)

    # NOTE(review): the keyword `min_lengtht_to_sample` is spelled as in the original call;
    # presumably it matches the field declared on `PrioritisedBuffer` -- verify before renaming.
    return PrioritisedBuffer(init=init,
                             add=add,
                             sample=sample,
                             sample_n_batches=sample_n_batches,
                             min_lengtht_to_sample=min_length_to_sample,
                             max_length=max_length)


if __name__ == '__main__':
    # Small configuration for manually exercising the buffer.
    dim, n_sub_traj = 3, 32
    max_length, min_length_to_sample = 1000, 10

    # Disabled manual smoke test, kept for reference.
    # NOTE(review): `init` in this file takes both `x` and `log_w`, so the `buffer.init(x_init)`
    # call below needs a `log_w` argument before it can be re-enabled.
    # x_init = jnp.ones([n_sub_traj + 1, min_length_to_sample + 1, dim])
    # buffer = build_prioritised_buffer(dim, n_sub_traj, max_length, min_length_to_sample)
    # buffer_state = buffer.init(x_init)
    # data, idx = buffer.sample(jax.random.PRNGKey(2), buffer_state, 5)
