import numpy as np
from collections import deque


class ReplayMemory(object):
    """Bounded buffer of transitions with uniform random sampling.

    Transitions are pushed on the left of a ``deque``, so index 0 is always
    the most recent entry and the oldest entry is dropped from the right once
    ``capacity`` is reached.
    """

    def __init__(self, capacity, seed):
        """
        Args:
            capacity: maximum number of stored transitions (``None`` means
                unbounded, as for ``deque(maxlen=None)``).
            seed: seed of the NumPy generator used for sampling.
        """
        # Assign attributes before reset() so that its `capacity is None`
        # fallback can read self.capacity.
        self.capacity = capacity
        self.seed = seed
        self.rng = np.random.default_rng(seed=self.seed)
        self.reset(capacity)

    def reset(self, capacity=None):
        """Drop all stored transitions.

        Args:
            capacity: maximum length of the new buffer; defaults to the
                capacity given at construction.
        """
        if capacity is None:
            capacity = self.capacity
        self.memory = deque(maxlen=capacity)

    def update(self, transition):
        """Store one transition as the most recent entry."""
        self.memory.appendleft(transition)  # pop from right if full

    def sample(self, batch_size, recent_size=None):
        """Draw ``batch_size`` transitions uniformly, with replacement.

        Args:
            batch_size: number of transitions to draw.
            recent_size: if given, only the ``recent_size`` most recent
                transitions (indices ``0 .. recent_size - 1``) are eligible.

        Returns:
            A tuple ``(transitions, None)``; the second element is a
            placeholder for the "next" samples returned by the trajectory
            buffer.
        """
        length = len(self.memory)
        if recent_size is not None:
            length = min(length, recent_size)
        indices = self.rng.integers(low=0, high=length, size=(batch_size, ))
        return [self.memory[i] for i in indices], None  # dummy for nxt

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)


class ReplayMemoryMeta(ReplayMemory):
    """Replay memory split into an unbounded online buffer and a bounded meta
    buffer.

    New transitions only go to the online buffer (``self.memory_online``).
    ``fill_meta`` copies a random subset of the online buffer into the meta
    buffer (``self.memory``, inherited from ``ReplayMemory``).
    """

    def __init__(self, capacity, fill, seed):
        """
        Args:
            capacity: maximum number of transitions kept in the meta buffer.
            fill: number of online transitions copied to the meta buffer by
                each call to ``fill_meta``.
            seed: seed of the NumPy generator used for sampling.
        """
        super().__init__(capacity, seed)
        self.fill = fill
        # Create the online buffer up front so update()/sample() work without
        # an explicit reset_online() call.
        self.reset_online()

    def reset_online(self):
        """Drop all transitions stored in the online buffer."""
        self.memory_online = deque()    # unlimited

    def fill_meta(self):
        """Copy up to ``self.fill`` distinct online transitions into the meta
        buffer.

        If the online buffer holds fewer than ``self.fill`` transitions, all
        of them are copied; if it is empty, nothing happens.
        """
        num_online = len(self.memory_online)
        num_take = min(self.fill, num_online)
        if num_take <= 0:
            return
        online_sample_ind = self.rng.choice(num_online,
                                            replace=False,
                                            size=num_take)
        online_sample = [self.memory_online[ind] for ind in online_sample_ind]
        self.memory.extendleft(online_sample)

    # Overrides - only store to online memory
    def update(self, transition):
        """Store one transition as the most recent online entry."""
        self.memory_online.appendleft(transition)  # unbounded, never pops

    # Overrides
    def sample(self, batch_size, recent_size=None, online_weight=0.5):
        """Weighted sampling from the online and meta buffers.

        Args:
            batch_size: total number of transitions to draw.
            recent_size: unused here; accepted so the signature matches
                ``ReplayMemory.sample``.
            online_weight: target fraction of the batch drawn from the online
                buffer; the rest is drawn from the meta buffer.

        Returns:
            A tuple ``(transitions, None)`` with the online samples first,
            followed by the meta samples. Both parts are drawn uniformly with
            replacement.
        """
        length_online = len(self.memory_online)
        length_meta = len(self.memory)
        # Cap the meta share by the meta buffer size so that an empty meta
        # buffer contributes nothing and the online buffer fills the batch.
        batch_size_meta = min(length_meta,
                              int((1 - online_weight) * batch_size))
        batch_size_online = batch_size - batch_size_meta
        indices_online = self.rng.integers(low=0,
                                           high=length_online,
                                           size=(batch_size_online, ))
        indices_meta = self.rng.integers(low=0,
                                         high=length_meta,
                                         size=(batch_size_meta, ))
        samples = ([self.memory_online[i] for i in indices_online] +
                   [self.memory[i] for i in indices_meta])
        return samples, None  # dummy for nxt


class ReplayMemoryTraj():
    """Bounded buffer of whole trajectories that samples fixed-size segments.

    Trajectories are pushed on the left of a ``deque``, so index 0 is always
    the most recent trajectory. ``set_possible_samples`` must be called
    before ``sample`` (and again after any ``update``/``reset``, since stored
    trajectory indices shift when a new trajectory is pushed).
    """

    def __init__(self, capacity, seed, sample_next=False):
        """
        Args:
            capacity: maximum number of stored trajectories.
            seed: seed of the NumPy generator used for all sampling.
            sample_next: if True, ``sample`` also returns the segment ending
                one step after each sampled segment.
        """
        # Assign attributes before reset() so that its `capacity is None`
        # fallback can read self.capacity.
        self.capacity = capacity
        self.seed = seed
        self.rng = np.random.default_rng(seed=self.seed)
        self.sample_next = sample_next
        self.reset(capacity)

    def reset(self, capacity=None):
        """Drop all stored trajectories.

        Args:
            capacity: maximum length of the new buffer; defaults to the
                capacity given at construction.
        """
        if capacity is None:
            capacity = self.capacity
        self.memory = deque(maxlen=capacity)
        self.traj_len = deque(maxlen=capacity)

    def update(self, traj):
        """Store one trajectory (an indexable sequence of transitions) as the
        most recent entry, together with its length."""
        self.memory.appendleft(traj)  # pop from right if full
        self.traj_len.appendleft(len(traj))

    ########### For sampling batch of segments from all trajectories ###########
    def set_possible_samples(self,
                             traj_size=50,
                             frame_skip=0,
                             allow_repeat_frame=False,
                             recent_size=None):
        """Enumerate every (trajectory, transition) pair a segment may end at.

        The result is stored in ``self.possible_end_inds`` and consumed by
        ``sample``.

        Args:
            traj_size: number of transitions per sampled segment.
            frame_skip: number of transitions skipped between consecutive
                segment entries.
            allow_repeat_frame: if True, a segment may end at any transition,
                including ones with too little history for a strided
                segment; otherwise only transitions at index
                ``(traj_size - 1) * frame_skip + traj_size`` or later are
                eligible.
            recent_size: if given, only the ``recent_size`` most recent
                trajectories are considered.
        """
        #! if burn-in, not using initial steps, might be an issue with
        #! fixed_init also some trajectories can be too short for traj_size
        if allow_repeat_frame:
            self.offset = 0
        else:
            self.offset = (traj_size - 1) * frame_skip + traj_size
        if recent_size is not None:
            num_recent = min(recent_size, len(self.traj_len))
            traj_len_all = [self.traj_len[ind] for ind in range(num_recent)]
        else:
            traj_len_all = self.traj_len
        self.possible_end_inds = []
        # Enumerating from 0 is fine since recent trajectories start from
        # ind=0 (from the left).
        for traj_ind, traj_len in enumerate(traj_len_all):
            self.possible_end_inds.extend(
                (traj_ind, transition_ind)
                for transition_ind in range(self.offset, traj_len)
            )  # allow done at the end

    def _segment_indices(self, end_ind, traj_size, frame_skip, traj_cover):
        """Return ascending transition indices of a segment ending at end_ind."""
        if end_ind >= traj_cover:
            # Enough history: evenly strided indices ending at end_ind.
            steps_back = np.arange(traj_size - 1, -1, -1) * (frame_skip + 1)
            return end_ind - steps_back
        if end_ind == 0:
            # No earlier transition exists: repeat the first one.
            return np.zeros((traj_size), dtype='int')
        # Too little history: draw earlier transitions at random (repeats
        # allowed, end_ind itself excluded) and put end_ind last.
        earlier = self.rng.choice(end_ind, size=traj_size - 1, replace=True)
        return np.append(np.sort(earlier), end_ind)

    def sample(self, batch_size, traj_size=50, frame_skip=0):
        """Draw ``batch_size`` segments uniformly over the possible end points.

        Requires a prior call to ``set_possible_samples``.

        Args:
            batch_size: number of segments to draw (with replacement).
            traj_size: number of transitions per segment.
            frame_skip: number of transitions skipped between consecutive
                segment entries.

        Returns:
            A tuple ``(out, out_nxt)``. ``out`` is a list of segments, each a
            list of ``traj_size`` transitions in chronological order.
            ``out_nxt`` is empty unless ``sample_next`` was set; otherwise it
            holds, per sample, the segment ending one step later, or ``[]``
            when the sampled segment already ends at the last transition of
            its trajectory.
        """
        # min steps needed; if fewer, randomly sample
        traj_cover = (traj_size - 1) * frame_skip + traj_size
        inds = self.rng.integers(low=0,
                                 high=len(self.possible_end_inds),
                                 size=(batch_size, ))
        out = []
        out_nxt = []
        for ind in inds:
            traj_ind, transition_ind = self.possible_end_inds[ind]
            traj = self.memory[traj_ind]

            # Implicitly allow repeat frame
            seq = self._segment_indices(transition_ind, traj_size, frame_skip,
                                        traj_cover)
            out.append([traj[step] for step in seq])

            # Get next - can be empty if prev is done
            if self.sample_next:
                nxt_ind = transition_ind + 1  # cannot be 0 any more
                # Check the end of the trajectory first: for trajectories
                # shorter than traj_cover the short-history branch would
                # otherwise index one past the last transition.
                if nxt_ind >= self.traj_len[traj_ind]:
                    out_nxt.append([])
                else:
                    seq_nxt = self._segment_indices(nxt_ind, traj_size,
                                                    frame_skip, traj_cover)
                    out_nxt.append([traj[step] for step in seq_nxt])
        return out, out_nxt

    def __len__(self):
        """Number of trajectories currently stored."""
        return len(self.memory)
