import collections

import jax
import numpy as np

# Immutable bundle of transition arrays returned by EnsembleBuffer.sample;
# field order matters for positional construction.
Batch = collections.namedtuple(
    "Batch", "observations actions rewards discounts next_observations")


class EnsembleBuffer:
    """Fixed-capacity circular replay buffer shared by an ensemble.

    Transitions live in flat NumPy arrays of length ``max_size``; once full,
    the oldest entries are overwritten. ``get_contrastive_pairs`` and the
    memory indices treat flat index ``k * ensemble_num + i`` as the ``k``-th
    transition of ensemble member ``i`` -- presumably ``add`` is called in
    round-robin order over the members; confirm against the caller.
    """

    def __init__(self,
                 ensemble_num: int,
                 obs_dim: int,
                 act_dim: int,
                 max_size: int = int(1e6),
                 memory_size: int = 250,
                 memory_skip: int = 4):
        """Allocate storage and precompute the memory-snapshot indices.

        Args:
            ensemble_num: number of ensemble members sharing the buffer.
            obs_dim: observation dimensionality.
            act_dim: action dimensionality.
            max_size: buffer capacity.
            memory_size: number of observations per member in ``get_memory``.
            memory_skip: stride, in rounds of ``ensemble_num`` entries,
                between consecutive memory observations.
        """
        self.ensemble_num = ensemble_num
        self.max_size = max_size
        self.obs_dim = obs_dim
        self.ptr = 0   # next slot to write
        self.size = 0  # number of valid entries, capped at max_size

        self.observations = np.zeros((max_size, obs_dim), dtype=np.float32)
        self.actions = np.zeros((max_size, act_dim), dtype=np.float32)
        self.next_observations = np.zeros((max_size, obs_dim), dtype=np.float32)
        self.rewards = np.zeros(max_size, dtype=np.float32)
        # NOTE(review): int32 storage truncates a fractional ``done``; only
        # 0/1 done flags round-trip exactly.
        self.discounts = np.zeros(max_size, dtype=np.int32)
        self.trajs = np.zeros(max_size, dtype=np.int32)

        # Memory indices grouped by member: member j gets offsets
        # j, j + S, j + 2S, ... with S = ensemble_num * memory_skip.
        starts = np.arange(memory_size) * ensemble_num * memory_skip
        members = np.arange(ensemble_num)
        self.memory_idxes = (members[:, None] + starts[None, :]).reshape(-1)

    def add(self,
            observation: np.ndarray,
            action: np.ndarray,
            next_observation: np.ndarray,
            reward: float,
            done: float,
            traj: int = 0):
        """Write one transition at ``ptr`` and advance the circular pointer.

        ``traj`` labels the trajectory the transition belongs to; it is used
        by ``get_contrastive_pairs`` to keep positive pairs together.
        """
        self.observations[self.ptr] = observation
        self.actions[self.ptr] = action
        self.next_observations[self.ptr] = next_observation
        self.rewards[self.ptr] = reward
        self.discounts[self.ptr] = 1 - done
        self.trajs[self.ptr] = traj
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size: int) -> "Batch":
        """Sample ``ensemble_num * batch_size`` transitions uniformly.

        Indices are drawn with replacement from all ``size`` stored entries
        (not per member) and reshaped to leading dims
        ``(ensemble_num, batch_size)``.

        Raises:
            ValueError: if the buffer is empty.
        """
        batch_shape = (self.ensemble_num, batch_size)
        idx = np.random.randint(0, self.size, size=batch_size * self.ensemble_num)
        return Batch(
            observations=self.observations[idx].reshape(*batch_shape, -1),
            actions=self.actions[idx].reshape(*batch_shape, -1),
            rewards=self.rewards[idx].reshape(batch_shape),
            discounts=self.discounts[idx].reshape(batch_shape),
            next_observations=self.next_observations[idx].reshape(*batch_shape, -1))

    def get_contrastive_pairs(self, batch_size: int, window: int = 10,
                              oversample: int = 30):
        """Sample (anchor, positive) observation pairs for contrastive learning.

        For each member, ``batch_size // ensemble_num`` pairs are drawn. A
        positive is 1..window-1 steps before or after its anchor (in the
        member's own index space), lies inside the valid region, and has the
        same ``traj`` label as the anchor.

        Args:
            batch_size: total number of pairs across all members.
            window: exclusive upper bound on the anchor/positive step gap.
            oversample: extra candidates drawn per member in the vectorized
                pass before falling back to one-at-a-time sampling.

        Returns:
            Array of shape ``(2 * P, obs_dim)``: all anchors first, then the
            matching positives in the same order, where ``P`` is the number
            of pairs.

        Raises:
            ValueError: if ``window <= 1`` or a member holds no more than
                ``window`` entries (``np.random.randint`` needs low < high).

        May loop indefinitely if no valid same-trajectory pair exists.
        """
        n_members = self.ensemble_num
        single_size = self.size // n_members
        single_batch_size = batch_size // n_members
        n_candidates = single_batch_size + oversample

        anchor_idxes, positive_idxes = [], []
        for i in range(n_members):
            # Vectorized first pass over a batch of candidate pairs.
            anchors = np.random.randint(window, single_size, size=n_candidates)
            signs = np.random.choice([1, -1], size=n_candidates)
            positives = anchors + signs * np.random.randint(1, window, size=n_candidates)
            anchor_flat = anchors * n_members + i
            positive_flat = positives * n_members + i

            # Reject positives past the valid region first: those slots are
            # unwritten (or out of bounds) and must never be indexed.
            valid = positives < single_size
            valid[valid] = (self.trajs[anchor_flat[valid]]
                            == self.trajs[positive_flat[valid]])
            keep = np.flatnonzero(valid)[:single_batch_size]
            anchor_idxes.extend(anchor_flat[keep].tolist())
            positive_idxes.extend(positive_flat[keep].tolist())

            # Top up one pair at a time if the first pass fell short.
            cnt = len(keep)
            while cnt < single_batch_size:
                idx1 = np.random.randint(window, single_size)
                idx2 = idx1 + np.random.choice([-1, 1]) * np.random.randint(1, window)
                if idx2 >= single_size:
                    continue
                flat1 = idx1 * n_members + i
                flat2 = idx2 * n_members + i
                if self.trajs[flat1] == self.trajs[flat2]:
                    anchor_idxes.append(flat1)
                    positive_idxes.append(flat2)
                    cnt += 1

        # Explicit integer dtype so an empty result still indexes cleanly.
        total_idxes = np.asarray(anchor_idxes + positive_idxes, dtype=np.int64)
        return self.observations[total_idxes]

    def get_memory(self):
        """Return observations at ``memory_idxes`` aligned to the newest entry.

        The index pattern is shifted so its largest index lands on
        ``ptr - 1``, the most recently written slot. Once the buffer is full
        the shifted indices may be negative; NumPy wraps them to the end of
        the array, which matches the circular layout.

        Returns:
            Array of shape ``(len(memory_idxes), obs_dim)``, or ``None`` if
            fewer than ``max(memory_idxes) + 1`` entries are stored.
        """
        if self.memory_idxes.size == 0:
            return self.observations[:0]
        lst_idx = self.memory_idxes.max()
        if self.size <= lst_idx:
            return None
        return self.observations[self.memory_idxes - lst_idx - 1 + self.ptr]

    @staticmethod
    def _prefetch(make_item, queue_size: int):
        """Yield ``jax.device_put(make_item())`` forever, staying ``queue_size`` ahead."""
        queue = collections.deque()
        for _ in range(queue_size):
            queue.append(jax.device_put(make_item()))
        while queue:
            yield queue.popleft()
            queue.append(jax.device_put(make_item()))

    def get_buffer_iterator(self, queue_size: int = 2, batch_size: int = 256):
        """Endless iterator of device-put ``sample(batch_size)`` batches."""
        yield from self._prefetch(lambda: self.sample(batch_size=batch_size),
                                  queue_size)

    def get_memory_iterator(self, queue_size: int = 2):
        """Endless iterator of device-put ``get_memory()`` results (may be None)."""
        yield from self._prefetch(self.get_memory, queue_size)

    def get_contrast_iterator(self,
                              queue_size: int = 2,
                              batch_size: int = 256,
                              window: int = 10):
        """Endless iterator of device-put ``get_contrastive_pairs`` results."""
        yield from self._prefetch(
            lambda: self.get_contrastive_pairs(batch_size=batch_size,
                                               window=window),
            queue_size)
