import numpy as np
import pdb


class ReplayBuffer:
    def __init__(
            self,
            buffer_size,
            obs_shape,
            obs_dtype,
            action_dim,
            action_dtype,
    ):
        self.max_size = buffer_size
        self.obs_shape = obs_shape
        self.obs_dtype = obs_dtype
        self.action_dim = action_dim
        self.action_dtype = action_dtype

        self.ptr = 0
        self.size = 0

        self.observations = np.zeros((self.max_size,) + self.obs_shape, dtype=obs_dtype)
        self.next_observations = np.zeros((self.max_size,) + self.obs_shape, dtype=obs_dtype)
        self.actions = np.zeros((self.max_size, self.action_dim), dtype=action_dtype)
        self.rewards = np.zeros((self.max_size, 1), dtype=np.float32)
        self.terminals = np.zeros((self.max_size, 1), dtype=np.float32)
        self.episode_idx = np.zeros((self.max_size, 1), dtype=np.float32)

    def add(self, obs, next_obs, action, reward, terminal):
        # Copy to avoid modification by reference
        self.observations[self.ptr] = np.array(obs).copy()
        self.next_observations[self.ptr] = np.array(next_obs).copy()
        self.actions[self.ptr] = np.array(action).copy()
        self.rewards[self.ptr] = np.array(reward).copy()
        self.terminals[self.ptr] = np.array(terminal).copy()
        self.episode_idx[self.ptr] = np.zeros_like(np.array(terminal))

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def load_dataset(self, dataset):
        observations = np.array(dataset["observations"], dtype=self.obs_dtype)
        next_observations = np.array(dataset["next_observations"], dtype=self.obs_dtype)
        actions = np.array(dataset["actions"], dtype=self.action_dtype)
        rewards = np.array(dataset["rewards"]).reshape(-1, 1)
        terminals = np.array(dataset["terminals"], dtype=np.float32).reshape(-1, 1)
        episode_idx = np.array(dataset["episode_idx"], dtype=np.float32).reshape(-1, 1)

        self.observations = observations
        self.next_observations = next_observations
        self.actions = actions
        self.rewards = rewards
        self.terminals = terminals
        self.episode_idx = episode_idx

        if len(self.actions.shape) == 1:
            self.actions = np.expand_dims(self.actions, axis=-1)

        self.ptr = len(observations)
        self.size = len(observations)

    def add_batch(self, obs, next_obs, actions, rewards, terminals):
        batch_size = len(obs)
        if self.ptr + batch_size > self.max_size:
            begin = self.ptr
            end = self.max_size
            first_add_size = end - begin
            self.observations[begin:end] = np.array(obs[:first_add_size]).copy()
            self.next_observations[begin:end] = np.array(next_obs[:first_add_size]).copy()
            self.actions[begin:end] = np.array(actions[:first_add_size]).copy()
            self.rewards[begin:end] = np.array(rewards[:first_add_size]).copy()
            self.terminals[begin:end] = np.array(terminals[:first_add_size]).copy()
            self.episode_idx[begin:end] = np.zeros_like(np.array(terminals[:first_add_size]))

            begin = 0
            end = batch_size - first_add_size
            self.observations[begin:end] = np.array(obs[first_add_size:]).copy()
            self.next_observations[begin:end] = np.array(next_obs[first_add_size:]).copy()
            self.actions[begin:end] = np.array(actions[first_add_size:]).copy()
            self.rewards[begin:end] = np.array(rewards[first_add_size:]).copy()
            self.terminals[begin:end] = np.array(terminals[first_add_size:]).copy()
            self.episode_idx[begin:end] = np.zeros_like(np.array(terminals[first_add_size:]))

            self.ptr = end
            self.size = min(self.size + batch_size, self.max_size)

        else:
            begin = self.ptr
            end = self.ptr + batch_size
            self.observations[begin:end] = np.array(obs).copy()
            self.next_observations[begin:end] = np.array(next_obs).copy()
            self.actions[begin:end] = np.array(actions).copy()
            self.rewards[begin:end] = np.array(rewards).copy()
            self.terminals[begin:end] = np.array(terminals).copy()
            self.episode_idx[begin:end] = np.zeros_like(np.array(terminals))

            self.ptr = end
            self.size = min(self.size + batch_size, self.max_size)

    def sample(self, batch_size):
        batch_indices = np.random.randint(0, self.size, size=batch_size)
        return {
            "observations": self.observations[batch_indices].copy(),
            "actions": self.actions[batch_indices].copy(),
            "next_observations": self.next_observations[batch_indices].copy(),
            "terminals": self.terminals[batch_indices].copy(),
            "rewards": self.rewards[batch_indices].copy(),
            "episode_idx": self.episode_idx[:self.size].copy()
        }

    def sample_all(self):
        return {
            "observations": self.observations[:self.size].copy(),
            "actions": self.actions[:self.size].copy(),
            "next_observations": self.next_observations[:self.size].copy(),
            "terminals": self.terminals[:self.size].copy(),
            "rewards": self.rewards[:self.size].copy(),
            "episode_idx": self.episode_idx[:self.size].copy()
        }

    @property
    def get_size(self):
        return self.size


class ReplayBuffer2:
    def __init__(
            self,
            buffer_size,
            obs_shape,
            obs_dtype,
            action_dim,
            action_dtype,
    ):
        self.max_size = buffer_size
        self.obs_shape = obs_shape
        self.obs_dtype = obs_dtype
        self.action_dim = action_dim
        self.action_dtype = action_dtype

        self.ptr = 0
        self.size = 0

        self.observations = np.zeros((self.max_size,) + self.obs_shape, dtype=obs_dtype)
        self.next_observations = np.zeros((self.max_size,) + self.obs_shape, dtype=obs_dtype)
        self.actions = np.zeros((self.max_size, self.action_dim), dtype=action_dtype)
        self.rewards = np.zeros((self.max_size, 1), dtype=np.float32)
        self.terminals = np.zeros((self.max_size, 1), dtype=np.float32)
        self.episode_idx = np.zeros((self.max_size, 1), dtype=np.float32)

        self.dataset = None

    def add(self, obs, next_obs, action, reward, terminal):
        # Copy to avoid modification by reference
        self.observations[self.ptr] = np.array(obs).copy()
        self.next_observations[self.ptr] = np.array(next_obs).copy()
        self.actions[self.ptr] = np.array(action).copy()
        self.rewards[self.ptr] = np.array(reward).copy()
        self.terminals[self.ptr] = np.array(terminal).copy()
        self.episode_idx[self.ptr] = np.zeros_like(np.array(terminal))

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def load_dataset(self, dataset):
        self.dataset = dataset

    def add_batch(self, obs, next_obs, actions, rewards, terminals):
        batch_size = len(obs)
        if self.ptr + batch_size > self.max_size:
            begin = self.ptr
            end = self.max_size
            first_add_size = end - begin
            self.observations[begin:end] = np.array(obs[:first_add_size]).copy()
            self.next_observations[begin:end] = np.array(next_obs[:first_add_size]).copy()
            self.actions[begin:end] = np.array(actions[:first_add_size]).copy()
            self.rewards[begin:end] = np.array(rewards[:first_add_size]).copy()
            self.terminals[begin:end] = np.array(terminals[:first_add_size]).copy()
            self.episode_idx[begin:end] = np.zeros_like(np.array(terminals[:first_add_size]))

            begin = 0
            end = batch_size - first_add_size
            self.observations[begin:end] = np.array(obs[first_add_size:]).copy()
            self.next_observations[begin:end] = np.array(next_obs[first_add_size:]).copy()
            self.actions[begin:end] = np.array(actions[first_add_size:]).copy()
            self.rewards[begin:end] = np.array(rewards[first_add_size:]).copy()
            self.terminals[begin:end] = np.array(terminals[first_add_size:]).copy()
            self.episode_idx[begin:end] = np.zeros_like(np.array(terminals[first_add_size:]))

            self.ptr = end
            self.size = min(self.size + batch_size, self.max_size)

        else:
            begin = self.ptr
            end = self.ptr + batch_size
            self.observations[begin:end] = np.array(obs).copy()
            self.next_observations[begin:end] = np.array(next_obs).copy()
            self.actions[begin:end] = np.array(actions).copy()
            self.rewards[begin:end] = np.array(rewards).copy()
            self.terminals[begin:end] = np.array(terminals).copy()
            self.episode_idx[begin:end] = np.zeros_like(np.array(terminals))

            self.ptr = end
            self.size = min(self.size + batch_size, self.max_size)

    def sample(self, batch_size):
        batch_indices = np.random.randint(0, self.size, size=batch_size)
        pdb.set_trace()
        return {
            "observations": self.observations[batch_indices].copy(),
            "actions": self.actions[batch_indices].copy(),
            "next_observations": self.next_observations[batch_indices].copy(),
            "terminals": self.terminals[batch_indices].copy(),
            "rewards": self.rewards[batch_indices].copy(),
            "episode_idx": self.episode_idx[:self.size].copy()
        }

    def sample_all(self):
        return self.dataset

    @property
    def get_size(self):
        return self.size
