from typing import NamedTuple, Tuple, Iterable, Callable, Optional, Protocol

import jax.lax
import jax.numpy as jnp
import numpy as np
import chex
from functools import partial

from jax._src.scipy.special import logsumexp

from algorithms.fab.utils.jax_util import broadcasted_where
from algorithms.scld.prioritised_buffer_utils import *

# This file is split out from the base prioritised buffer in case we make further
# modifications to the buffer design.
# It is the responsibility of external code to ensure that the buffer and the input
# arrays are on the same device (for jitting purposes).
# If the buffer is kept on the CPU, prefer a NumPy array for its storage, because JAX
# arrays make a copy every time an element is set.

def build_prioritised_subtraj_buffer(
        dim: int,
        n_sub_traj: int,
        max_length: int,
        min_length_to_sample: int,
        length_of_subtraj: int,
        prioritized: bool = True,
        sample_with_replacement: bool = False,
        temperature: float = 1,
        on_cpu: bool = False
) -> PrioritisedBuffer:
    """
    Create a replay buffer of sub-trajectories for batched sampling and adding of data.

    The buffer stores `x` with shape (n_sub_traj, max_length, length_of_subtraj, dim) and
    `log_w` with shape (n_sub_traj, max_length). Each row along the leading axis is sampled
    independently, in proportion to `exp(log_w / temperature)`.

    Args:
        dim: Dimension of x data.
        n_sub_traj: Size of the leading axis of the buffer (one independently sampled row of
            sub-trajectories per entry).
        max_length: Maximum length of the buffer (size of axis 1).
        min_length_to_sample: Minimum length of buffer required for sampling.
        length_of_subtraj: Number of points stored per sub-trajectory (size of axis 2).
        prioritized: If False, the log weights are ignored and sampling is uniform over the
            filled part of the buffer.
        sample_with_replacement: Whether to sample with replacement.
        temperature: Divisor applied to the log weights before exponentiation.
        on_cpu: If True, `x` is stored in a NumPy array that `add` mutates in place; otherwise it
            is a JAX array updated functionally.

    The `max_length` and `min_length_to_sample` should be sufficiently long to prevent overfitting
    to the replay data. For example, if `min_length_to_sample` is equal to the
    sampling batch size, then we may overfit to the first batch of data, as we would update
    on it many times during the start of training.
    """
    assert min_length_to_sample <= max_length

    def init(x: chex.Array, log_w: chex.Array) -> PrioritisedBufferState:
        """
        Initialise the buffer state, by filling it above `min_length_to_sample`.

        Args:
            x: Array of shape (n_sub_traj, n_samples, length_of_subtraj, dim) with
                n_samples >= min_length_to_sample.
            log_w: Log weights for `x`; `add` combines them elementwise with a
                (n_sub_traj, n_samples) validity mask.
        """
        chex.assert_rank(x, 4)
        chex.assert_shape(x[0][0], (length_of_subtraj, dim))
        n_samples = x.shape[1]
        assert n_samples >= min_length_to_sample, "Buffer requires at least `min_sample_length` samples for init."

        current_index = 0
        is_full = False  # whether the buffer is full
        can_sample = False  # whether the buffer is full enough to begin sampling

        device = None if not on_cpu else jax.devices("cpu")[0]

        storage_shape = (n_sub_traj, max_length, length_of_subtraj, dim)
        buffer_array = np.zeros(storage_shape) if on_cpu else jnp.zeros(storage_shape)

        # init data to have -inf log_w to prevent these values being sampled.
        # we don't need to set xs to NaNs as the log_ws are -infinity
        data = Data(x=buffer_array,
                    log_w=-jnp.ones((n_sub_traj, max_length,), device=device) * float("inf"))

        buffer_state = PrioritisedBufferState(data=data, is_full=is_full, can_sample=can_sample,
                                              current_index=current_index)
        buffer_state = add(x, log_w, buffer_state)
        return buffer_state

    def add(x: chex.Array, log_w: chex.Array,
            buffer_state: PrioritisedBufferState) -> PrioritisedBufferState:
        """
        Update the buffer's state with a new batch of data.

        The batch is written at `current_index`, wrapping around modulo `max_length`. Entries with
        a non-finite `log_w` or any non-finite value in `x` are not written; the slot keeps its
        previous contents instead. With `on_cpu=True` the NumPy storage in `buffer_state` is
        mutated in place.
        """
        chex.assert_rank(x, 4)
        chex.assert_equal_shape((x[0, 0], buffer_state.data.x[0, 0]))
        batch_size = x.shape[1]
        valid_samples = jnp.isfinite(log_w) & jnp.all(jnp.isfinite(x), axis=(2, 3))
        indices = (jnp.arange(batch_size) + buffer_state.current_index) % max_length

        # Remove invalid samples.
        # This may result in duplicating samples from first bit of buffer
        # but this is probably okay if invalid samples are rare enough
        # (Called directly rather than through the deprecated `jax.tree_map` alias.)
        x = broadcasted_where(valid_samples, x, buffer_state.data.x[:, indices])
        log_w = broadcasted_where(valid_samples, log_w, buffer_state.data.log_w[:, indices])

        # Add valid samples to buffer (possibly overwriting old data).
        if not on_cpu:
            x = buffer_state.data.x.at[:, indices].set(x)
        else:
            # NumPy storage: write in place to avoid copying the whole buffer.
            buffer_state.data.x[:, indices] = x
            x = buffer_state.data.x

        log_w = buffer_state.data.log_w.at[:, indices].set(log_w)
        # Keep track of index, and whether buffer is full.
        new_index = buffer_state.current_index + batch_size
        is_full = jax.lax.select(buffer_state.is_full, buffer_state.is_full,
                                 new_index >= max_length)
        can_sample = jax.lax.select(buffer_state.is_full, buffer_state.can_sample,
                                    new_index >= min_length_to_sample)
        current_index = new_index % max_length

        data = Data(x=x, log_w=log_w)
        state = PrioritisedBufferState(data=data,
                                       current_index=current_index,
                                       is_full=is_full,
                                       can_sample=can_sample)
        return state

    def sample(key: chex.PRNGKey,
               buffer_state: PrioritisedBufferState,
               batch_size: int,
               subtraj_id: Optional[int] = None) -> Tuple[chex.Array, chex.Array]:
        """
        Sample a batch from the buffer in proportion to the log weights.

        Args:
            key: PRNG key; it is split into one key per sampled row.
            buffer_state: Buffer to sample from.
            batch_size: Number of samples per row; must be <= `min_length_to_sample`.
            subtraj_id: If None, every row of the buffer is sampled. Otherwise only this row is
                sampled; a plain int or a length-1 index array are both accepted.

        Returns:
            x: Samples. Shape is (n_sub_traj, batch_size, length_of_subtraj, dim) when
                `subtraj_id` is None, and (1, batch_size, length_of_subtraj, dim) otherwise.
            indices: Indices of samples for their location in the buffer state, with shape
                (n_sub_traj, batch_size) or (1, batch_size) respectively.
        """
        assert batch_size <= min_length_to_sample, "Min length to sample must be greater than or equal to " \
                                                   "the batch size."
        # Only the filled part of the buffer is eligible for sampling.
        buffer_size = max_length if buffer_state.is_full else buffer_state.current_index

        def logprob_transform(ps):
            # we can apply temperature transforms / de-prioritize the buffer etc
            if not prioritized:
                return jnp.zeros_like(ps)
            return ps / temperature

        def sample_from_probabilities(probabilities, key):
            return jax.random.choice(key, buffer_size, shape=(batch_size,), replace=sample_with_replacement,
                                     p=probabilities)

        # Suppose our buffer is represented as (s,n,d) tensor, s = #subtraj+1, n=#items in buffer, d = problem dim
        # so this gives n possible starting points at each subtraj startpt
        # For each starting time 0<=t<s, the n pts at that time is weighed by
        # pi_t(pt) and thus we importance sample (by default without replacement) according to these
        # the lines of code below generates the sampled indices
        # where indices[i][j] means buffer[i][indices[i][j]] is the j-th sampled
        # point for the ith subtraj starting point
        sample_fn = jax.vmap(sample_from_probabilities, in_axes=(0, 0))
        if subtraj_id is None:
            probs = jnp.exp(logprob_transform(buffer_state.data.log_w[:, :buffer_size]))
            indices = sample_fn(probs, jax.random.split(key, probs.shape[0]))

            # retrieves points corresponding to indices
            return buffer_state.data.x[jnp.arange(indices.shape[0])[:, None], indices], indices

        # A scalar `subtraj_id` yields a 1-D weight vector; promote it to a single row so the
        # vmapped sampler always receives a full probability vector per row.
        probs = jnp.atleast_2d(
            jnp.exp(logprob_transform(buffer_state.data.log_w[subtraj_id, :buffer_size])))
        indices = sample_fn(probs, jax.random.split(key, probs.shape[0]))

        # retrieves points corresponding to indices
        # returns (1, requested_size, traj_length, dim) buffers
        return jnp.array(buffer_state.data.x[subtraj_id, indices], copy=on_cpu), indices  # can also do jnp.array(..., copy=False)

    def adjust_weights(new_logws: chex.Array,
                       indices: chex.Array,
                       buffer_state: PrioritisedBufferState,
                       subtraj_id: Optional[int] = None) -> PrioritisedBufferState:
        """
        Overwrite the log weights of previously sampled entries.

        Args:
            new_logws: Replacement log weights; gradients are stopped before they are stored.
                Shape (n_sub_traj, B) when `subtraj_id` is None, (1, B) otherwise.
            indices: Buffer positions of the entries to update.
            buffer_state: Buffer whose weights are updated; `x` and the bookkeeping fields are
                carried over unchanged.
            subtraj_id: If given, only that row of the buffer is updated.

        Returns:
            A new buffer state with the updated log weights.
        """
        if subtraj_id is None:
            # TODO: why is this code unreachable
            assert(indices.shape == new_logws.shape)
            assert(new_logws.shape[0] == buffer_state.data.log_w.shape[0])
            # update all trajectories as we trained on all of them. new_logws is (S,B)
            new_weights = buffer_state.data.log_w.at[jnp.arange(indices.shape[0])[:, None],
                                                    indices].set(jax.lax.stop_gradient(new_logws))
        else:
            assert(new_logws.shape[0] == 1 and len(new_logws.shape) == 2)
            # update just the one subtrajectory we trained on. new_logws is (1, B)
            # NOTE(review): `indices[subtraj_id]` reads row `subtraj_id` of `indices`; presumably
            # callers pass either the full (S, B) index array or the (1, B) array returned by
            # `sample` -- confirm against the caller.
            new_weights = buffer_state.data.log_w.at[jnp.array([[subtraj_id]]),
                                                    indices[subtraj_id]].set(jax.lax.stop_gradient(new_logws))

        data = Data(x=buffer_state.data.x, log_w=new_weights)
        state = PrioritisedBufferState(data=data,
                                       current_index=buffer_state.current_index,
                                       is_full=buffer_state.is_full,
                                       can_sample=buffer_state.can_sample)
        return state

    # we don't really use this:
    def sample_n_batches(
            key: chex.PRNGKey,
            buffer_state: PrioritisedBufferState,
            batch_size: int,
            n_batches: int) -> Tuple[chex.Array, chex.Array]:
        """
        Returns dataset with n-batches on the leading axis.

        A single draw of `batch_size * n_batches` samples per row is made (so that product must
        satisfy the size check in `sample`) and then split into batches.

        Returns:
            x: Shape (n_batches, n_sub_traj, batch_size, length_of_subtraj, dim).
            indices: Shape (n_batches, n_sub_traj, batch_size).
        """
        x, indices = sample(key, buffer_state, batch_size * n_batches)

        def split_into_batches(a):
            # (S, n_batches * batch_size, ...) -> (S, n_batches, batch_size, ...), then bring the
            # batch axis to the front. The sample axis is axis 1 here, not axis 0.
            grouped = a.reshape((a.shape[0], n_batches, batch_size, *a.shape[2:]))
            return grouped.swapaxes(0, 1)

        return split_into_batches(x), split_into_batches(indices)

    return PrioritisedBuffer(init=init,
                             add=add,
                             sample=sample,
                             sample_n_batches=sample_n_batches,
                             min_lengtht_to_sample=min_length_to_sample,
                             max_length=max_length,
                             upd_weights=adjust_weights)


if __name__ == '__main__':
    # Toy settings for ad-hoc manual experiments when this module is run as a script.
    dim, n_sub_traj = 3, 32
    max_length, min_length_to_sample = 1000, 10

    # Disabled smoke test. NOTE(review): it looks stale -- `init` in this file takes
    # (x, log_w) and asserts a rank-4 `x` -- so update it before re-enabling.
    # x_init = jnp.ones([n_sub_traj + 1, min_length_to_sample + 1, dim])
    # buffer = build_prioritised_buffer(dim, n_sub_traj, max_length, min_length_to_sample)
    # buffer_state = buffer.init(x_init)
    # data, idx = buffer.sample(jax.random.PRNGKey(2), buffer_state, 5)
