import numpy as np
import torch
import os


class ReplayBuffer(object):
    """Buffer to store environment transitions.

    Transitions live in pre-allocated NumPy arrays with ``capacity`` rows and
    are written in ring-buffer order: once ``capacity`` transitions have been
    added, each new one overwrites the oldest slot. A transition consists of
    the current, next and previous observation (each as an ``obs`` / ``t_obs``
    pair), the action, reward, timestep and two not-done flags.
    """

    # Names of every per-transition storage array, used where all of them
    # must be handled identically (e.g. ``purge_frac``) so no array can be
    # forgotten or paired with the wrong source.
    _STORAGE_ATTRS = ('obses', 't_obses', 'actions', 'rewards', 'timestep',
                      'next_obses', 'next_t_obses', 'prev_obses',
                      'prev_t_obses', 'not_dones', 'not_dones_no_max')

    def __init__(self, obs_shape, t_obs_shape, action_shape, capacity, device):
        """Allocate storage for ``capacity`` transitions.

        Args:
            obs_shape: shape of a single observation.
            t_obs_shape: shape of a single "t" observation.
            action_shape: shape of a single action.
            capacity: maximum number of transitions held at once.
            device: torch device on which ``sample`` creates its tensors.
        """
        self.capacity = capacity
        self.device = device

        # the proprioceptive obs is stored as float32, pixels obs as uint8
        # NOTE(review): the t_obs arrays reuse the dtype chosen from
        # obs_shape, so a 1-D float t_obs paired with multi-dim obs would be
        # stored as uint8 (truncated) -- confirm that pairing never occurs.
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.t_obses = np.empty((capacity, *t_obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.next_t_obses = np.empty((capacity, *t_obs_shape), dtype=obs_dtype)
        self.prev_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.prev_t_obses = np.empty((capacity, *t_obs_shape), dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.timestep = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones_no_max = np.empty((capacity, 1), dtype=np.float32)

        self.idx = 0          # slot that the next ``add`` writes to
        self.last_save = 0    # not read or written elsewhere in this class
        self.full = False     # True once the write index has wrapped around

    def __len__(self):
        """Return the number of valid transitions currently stored."""
        return self.capacity if self.full else self.idx

    def add(self, obs, t_obs, action, reward,
            next_obs, next_t_obs, prev_obs, prev_t_obs,
            timestep, done, done_no_max):
        """Write one transition into the next slot.

        Once the buffer is full, this overwrites the oldest transition.
        ``timestep`` is stored shifted down by one (``timestep - 1``), and
        ``done`` / ``done_no_max`` are stored negated as not-done flags.
        """
        i = self.idx
        np.copyto(self.obses[i], obs)
        np.copyto(self.t_obses[i], t_obs)
        np.copyto(self.actions[i], action)
        np.copyto(self.rewards[i], reward)
        np.copyto(self.timestep[i], timestep - 1.)
        np.copyto(self.next_obses[i], next_obs)
        np.copyto(self.next_t_obses[i], next_t_obs)
        # Previous-step observations come from the prev_* arguments (they
        # were previously overwritten with next_* by mistake).
        np.copyto(self.prev_obses[i], prev_obs)
        np.copyto(self.prev_t_obses[i], prev_t_obs)
        np.copyto(self.not_dones[i], not done)
        np.copyto(self.not_dones_no_max[i], not done_no_max)

        self.idx = (i + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def purge_frac(self, frac=0.5):
        """Discard roughly ``frac`` of the stored transitions.

        A random subset of ``int((1 - frac) * len(self))`` *distinct*
        transitions is kept and compacted to the front of every storage
        array. The write index is then placed just after them, and the
        buffer is marked not full.

        Raises:
            ValueError: if ``frac`` is outside ``[0, 1]``.
        """
        if not 0. <= frac <= 1.:
            raise ValueError('frac must be in [0, 1], got {}'.format(frac))
        size = len(self)
        to_keep = int((1. - frac) * size)
        # Sample without replacement so the kept set contains no duplicates.
        idxs = np.random.choice(size, size=to_keep, replace=False)

        for name in self._STORAGE_ATTRS:
            arr = getattr(self, name)
            # Fancy indexing on the right-hand side copies first, so this
            # in-place overwrite is safe.
            arr[:to_keep] = arr[idxs]

        self.idx = to_keep
        self.full = False

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions uniformly with replacement.

        Returns:
            Tuple of tensors on ``self.device``, in this order: obses,
            t_obses, actions, rewards, next_obses, next_t_obses, prev_obses,
            prev_t_obses, timesteps, not_dones, not_dones_no_max. The
            observation tensors are converted to float.

        Raises:
            ValueError: if the buffer is empty.
        """
        if len(self) == 0:
            raise ValueError('cannot sample from an empty replay buffer')
        idxs = np.random.randint(0, len(self), size=batch_size)

        def to_tensor(arr):
            return torch.as_tensor(arr[idxs], device=self.device)

        obses = to_tensor(self.obses).float()
        t_obses = to_tensor(self.t_obses).float()
        actions = to_tensor(self.actions)
        rewards = to_tensor(self.rewards)
        timesteps = to_tensor(self.timestep)
        next_obses = to_tensor(self.next_obses).float()
        next_t_obses = to_tensor(self.next_t_obses).float()
        prev_obses = to_tensor(self.prev_obses).float()
        prev_t_obses = to_tensor(self.prev_t_obses).float()
        not_dones = to_tensor(self.not_dones)
        not_dones_no_max = to_tensor(self.not_dones_no_max)

        return (obses, t_obses, actions, rewards,
                next_obses, next_t_obses,
                prev_obses, prev_t_obses, timesteps,
                not_dones, not_dones_no_max)

    def save(self, path=None, filename=None):
        """Save the buffer to a compressed ``.npz`` file.

        The file holds every storage array, including unfilled slots, plus
        ``idx`` and ``full``. ``path`` defaults to the current working
        directory and ``filename`` to ``'replay_buffer.npz'``.
        """
        if path is None:
            path = os.getcwd()
        save_path = os.path.join(path, filename or 'replay_buffer.npz')
        save_dict = {
            'idx': self.idx,
            'full': self.full,
            'obses': self.obses,
            't_obses': self.t_obses,
            'actions': self.actions,
            'rewards': self.rewards,
            'timesteps': self.timestep,
            'next_obses': self.next_obses,
            'next_t_obses': self.next_t_obses,
            'prev_obses': self.prev_obses,
            'prev_t_obses': self.prev_t_obses,
            'not_dones': self.not_dones,
            'not_dones_no_max': self.not_dones_no_max
        }
        np.savez_compressed(file=save_path, **save_dict)

    def load(self, path=None, filename=None):
        """Restore the buffer from a file written by ``save``.

        The storage arrays are replaced wholesale by the loaded ones.
        ``capacity`` is updated to match them, so later ``add`` calls wrap
        at the right index.
        """
        if path is None:
            path = os.getcwd()
        load_path = os.path.join(path, filename or 'replay_buffer.npz')
        # Use a context manager so the underlying zip file handle is closed;
        # indexing an NpzFile returns fully-read, independent arrays.
        with np.load(load_path) as load_dict:
            self.idx = int(load_dict['idx'])
            # Stored as a 0-d array; convert back to a plain bool.
            self.full = bool(load_dict['full'])

            self.obses = load_dict['obses']
            self.t_obses = load_dict['t_obses']
            self.actions = load_dict['actions']
            self.rewards = load_dict['rewards']
            self.timestep = load_dict['timesteps']
            self.next_obses = load_dict['next_obses']
            self.next_t_obses = load_dict['next_t_obses']
            self.prev_obses = load_dict['prev_obses']
            self.prev_t_obses = load_dict['prev_t_obses']
            self.not_dones = load_dict['not_dones']
            self.not_dones_no_max = load_dict['not_dones_no_max']

        self.capacity = self.obses.shape[0]
