import numpy as np
import torch
from env import postprocess_observation, preprocess_observation_


class ExperienceReplay():
  """Fixed-capacity circular buffer of single transitions, sampled as contiguous sequence chunks.

  Symbolic observations are stored as float32 vectors. Visual observations are stored as
  (3, 64, 64) uint8 arrays after `postprocess_observation` and converted back with
  `preprocess_observation_` when sampled.
  """
  def __init__(self, size, symbolic_env, observation_size, action_size, bit_depth, device):
    self.device = device
    self.symbolic_env = symbolic_env
    self.size = size
    self.observations = np.empty((size, observation_size) if symbolic_env else (size, 3, 64, 64), dtype=np.float32 if symbolic_env else np.uint8)
    self.actions = np.empty((size, action_size), dtype=np.float32)
    self.rewards = np.empty((size, ), dtype=np.float32)
    self.nonterminals = np.empty((size, 1), dtype=np.float32)
    self.idx = 0  # Next slot to write; wraps to 0 after the last slot
    self.full = False  # Tracks if memory has been filled/all slots are valid
    self.steps, self.episodes = 0, 0  # Tracks how much experience has been used in total
    self.bit_depth = bit_depth

  def append(self, observation, action, reward, done):
    """Write one transition at the current slot and advance the circular write pointer.

    `observation` and `action` must expose `.numpy()` (presumably CPU torch tensors).
    """
    if self.symbolic_env:
      self.observations[self.idx] = observation.numpy()
    else:
      self.observations[self.idx] = postprocess_observation(observation.numpy(), self.bit_depth)  # Decentre and discretise visual observations (to save memory)
    self.actions[self.idx] = action.numpy()
    self.rewards[self.idx] = reward
    self.nonterminals[self.idx] = not done
    self.idx = (self.idx + 1) % self.size
    self.full = self.full or self.idx == 0
    self.steps, self.episodes = self.steps + 1, self.episodes + (1 if done else 0)

  # Returns an index for a valid single sequence chunk uniformly sampled from the memory
  def _sample_idx(self, L):
    """Return the slot indices of one chunk of L consecutive stored transitions.

    The chunk never runs across the write pointer `self.idx`, so it cannot join the newest
    data to the oldest data.

    Raises:
      ValueError: if fewer than L transitions are stored, or L exceeds the buffer capacity
        (no valid chunk exists, so the rejection loop below would never terminate).
    """
    if L > self.size:
      raise ValueError(f'Cannot sample chunks of length {L} from a buffer of size {self.size}')
    if not self.full and self.idx < L:
      raise ValueError(f'Need at least {L} stored transitions to sample a chunk, have {self.idx}')
    # When not full, valid starts are 0..idx-L inclusive (the chunk then ends at slot idx-1).
    high = self.size if self.full else self.idx - L + 1
    while True:
      start = np.random.randint(0, high)
      idxs = np.arange(start, start + L) % self.size
      # Slot self.idx is the next to be overwritten (the oldest data once full): a chunk may
      # start there, but must not contain it at a later position.
      if self.idx not in idxs[1:]:
        return idxs

  def _retrieve_batch(self, idxs, n, L):
    """Gather n chunks of length L given an (n, L) index array and return them time-major.

    Returns:
      (observations tensor (L, n, ...), actions (L, n, action_size), rewards (L, n),
      nonterminals (L, n, 1)). Only the observations are converted to a torch tensor here.
    """
    vec_idxs = idxs.transpose().reshape(-1)  # Unroll indices
    observations = torch.as_tensor(self.observations[vec_idxs].astype(np.float32))
    if not self.symbolic_env:
      preprocess_observation_(observations, self.bit_depth)  # Undo discretisation for visual observations
    return observations.reshape(L, n, *observations.shape[1:]), self.actions[vec_idxs].reshape(L, n, -1), self.rewards[vec_idxs].reshape(L, n), self.nonterminals[vec_idxs].reshape(L, n, 1)

  # Returns a batch of sequence chunks uniformly sampled from the memory
  def sample(self, n, L):
    """Sample n chunks of length L and return them as time-major tensors on `self.device`."""
    batch = self._retrieve_batch(np.asarray([self._sample_idx(L) for _ in range(n)]), n, L)
    return [torch.as_tensor(item).to(device=self.device) for item in batch]

class SeqExperienceReplay():
  """Episode-indexed replay buffer: one storage row per episode, prefixed by `burnin_size` padding slots.

  Layout of row e:
  - Slots [0, burnin_size) repeat the episode's first observation, with zero action and zero reward.
  - Slot burnin_size + t holds step t of the episode.

  Discounted returns are computed when an episode is completed. Rows are never overwritten: once
  `size` episodes are stored, further appends raise IndexError.
  """
  def __init__(self, size, symbolic_env, observation_size, action_size, bit_depth, device, episode_length, burnin_size=0, on_policy_episodes=1, discount=0.99):
    self.device = device
    self.symbolic_env = symbolic_env
    self.size = size  # number of episodes
    self.episode_length = episode_length
    self.burnin_size = burnin_size
    self.observations = np.empty((size, episode_length + burnin_size, observation_size) if symbolic_env else (size, episode_length+burnin_size, 3, 64, 64), dtype=np.float32 if symbolic_env else np.uint8)
    self.actions = np.empty((size, episode_length + burnin_size, action_size), dtype=np.float32)
    self.rewards = np.empty((size, episode_length + burnin_size), dtype=np.float32)
    self.nonterminals = np.empty((size, episode_length + burnin_size, 1), dtype=np.float32)
    self.returns = np.empty((size, episode_length + burnin_size), dtype=np.float32)
    # Filled slots (burn-in padding + recorded steps) of each completed episode. Sampling uses this so
    # chunks never reach into the unwritten (np.empty) tail of an episode that terminated early.
    self.episode_lengths = np.zeros((size, ), dtype=np.int64)
    self.idx = 0  # Step index within the episode currently being recorded
    self.steps, self.episodes = 0, 0  # Tracks how much experience has been used in total
    self.bit_depth = bit_depth
    self.on_policy_episodes = on_policy_episodes
    self.discount = discount

  def _compute_returns(self, episode, length):
    """Fill `self.returns[episode]` with discounted reward-to-go for the first `length` recorded steps.

    Step t gets G_t = sum_{tau >= 0} discount**tau * r_{t + tau}, computed by the backward recursion
    G_t = r_t + discount * G_{t+1} in O(length). The burn-in slots receive G_0.
    """
    start, end = self.burnin_size, self.burnin_size + length
    rewards = self.rewards[episode, start:end].astype(np.float64)
    returns = np.empty(length, dtype=np.float64)
    running = 0.0
    for t in range(length - 1, -1, -1):
      running = rewards[t] + self.discount * running
      returns[t] = running
    self.returns[episode, start:end] = returns
    self.returns[episode, :start] = returns[0] if length > 0 else 0.0

  def append(self, observation, action, reward, done):
    """Record one step of the current episode; when `done`, finalise its row and move to the next one.

    `observation` and `action` must expose `.numpy()` (presumably CPU torch tensors). On the first step
    of an episode, the burn-in slots are filled with copies of the first observation, zero actions and
    rewards, and `not done` nonterminal flags.
    """
    if self.idx == 0:  # beginning of episode
      if self.symbolic_env:
        self.observations[self.episodes, :self.burnin_size+1] = observation.numpy()
      else:
        self.observations[self.episodes, :self.burnin_size+1] = postprocess_observation(observation.numpy(), self.bit_depth)  # Decentre and discretise visual observations (to save memory)
      self.actions[self.episodes, :self.burnin_size] = np.zeros_like(action.numpy())
      self.actions[self.episodes, self.burnin_size] = action.numpy()
      self.rewards[self.episodes, :self.burnin_size] = 0.0
      self.rewards[self.episodes, self.burnin_size] = reward
      self.nonterminals[self.episodes, :self.burnin_size] = not done
      self.nonterminals[self.episodes, self.burnin_size] = not done
    else:
      if self.symbolic_env:
        self.observations[self.episodes, self.burnin_size + self.idx] = observation.numpy()
      else:
        self.observations[self.episodes, self.burnin_size + self.idx] = postprocess_observation(observation.numpy(), self.bit_depth)  # Decentre and discretise visual observations (to save memory)
      self.actions[self.episodes, self.burnin_size + self.idx] = action.numpy()
      self.rewards[self.episodes, self.burnin_size + self.idx] = reward
      self.nonterminals[self.episodes, self.burnin_size + self.idx] = not done
    if done:
      self._compute_returns(self.episodes, self.idx + 1)
      self.episode_lengths[self.episodes] = self.burnin_size + self.idx + 1
      self.idx = 0
      self.episodes = self.episodes + 1
    else:
      self.idx = self.idx + 1
    self.steps = self.steps + 1

  def append_from_expert(self, npepisode, npepisode_length):
    """Store a complete pre-recorded episode as the next row and compute its returns.

    Args:
      npepisode: mapping with 'obs', 'action', 'reward' and 'nonterminals' arrays. Each must have a
        leading (time) axis of length `npepisode_length`. No `postprocess_observation` is applied
        here, so 'obs' is presumably already in the stored format — TODO confirm with the loader.
      npepisode_length: number of recorded steps in the episode.
    """
    row = self.episodes
    end = self.burnin_size + npepisode_length
    self.observations[row, :self.burnin_size] = npepisode['obs'][0]
    self.actions[row, :self.burnin_size] = np.zeros_like(npepisode['action'][0])
    self.rewards[row, :self.burnin_size] = np.zeros_like(npepisode['reward'][0])
    # NOTE(review): burn-in slots are marked terminal (0) here, whereas `append` marks them `not done`;
    # confirm which is intended.
    self.nonterminals[row, :self.burnin_size] = np.zeros_like(npepisode['nonterminals'][0])
    self.observations[row, self.burnin_size:end] = npepisode['obs'][0:]
    self.actions[row, self.burnin_size:end] = npepisode['action'][0:]
    self.rewards[row, self.burnin_size:end] = npepisode['reward'][0:]
    self.nonterminals[row, self.burnin_size:end] = npepisode['nonterminals'][0:]
    # Without this, the sampled returns for expert rows would be uninitialised memory.
    self._compute_returns(row, npepisode_length)
    self.episode_lengths[row] = end
    self.steps, self.episodes = self.steps + npepisode_length, self.episodes + 1

  # Returns an index for a valid single sequence chunk uniformly sampled from the memory
  def _sample_idx(self, L, start):
    """Return flat indices of one chunk of L consecutive slots from a completed episode.

    The episode is drawn uniformly from those in [on_policy_episodes * start, episodes) that have at
    least L filled slots. The chunk start is drawn uniformly from every position that keeps the chunk
    inside the filled part of that row.

    Raises:
      ValueError: if no episode in that range has at least L filled slots.
    """
    low = max(self.on_policy_episodes * start, 0)  # a negative low would index rows from the end
    if not np.any(self.episode_lengths[low:self.episodes] >= L):
      raise ValueError(f'No stored episode in [{low}, {self.episodes}) has at least {L} steps to sample from')
    while True:  # rejection sampling: uniform over eligible episodes, terminates since one exists
      episode_number = np.random.randint(low, self.episodes)
      episode_length = self.episode_lengths[episode_number]
      if episode_length >= L:
        break
    idx = np.random.randint(0, episode_length - L + 1)  # starts 0..length-L inclusive
    idxs = np.arange(idx, idx + L) + episode_number*(self.episode_length + self.burnin_size)
    return idxs

  def _retrieve_batch(self, idxs, n, L):
    """Gather chunks given an (n, L + burnin_size) array of flat slot indices; return them time-major.

    `L` excludes the burn-in. Returns (observations tensor, actions, rewards, nonterminals, returns),
    each with leading axes (L + burnin_size, n). Only the observations are converted to a torch
    tensor here.
    """
    vec_idxs = idxs.transpose().reshape(-1)  # Unroll indices, transpose for n x L -> L x n
    size_after = self.size * (self.episode_length + self.burnin_size)
    observations = torch.as_tensor(self.observations.reshape(size_after, *self.observations.shape[2:])[vec_idxs].astype(np.float32))
    if not self.symbolic_env:
      preprocess_observation_(observations, self.bit_depth)  # Undo discretisation for visual observations
    return observations.reshape(L+self.burnin_size, n, *observations.shape[1:]), self.actions.reshape(size_after, *self.actions.shape[2:])[vec_idxs].reshape(L+self.burnin_size, n, -1), \
           self.rewards.reshape(size_after, *self.rewards.shape[2:])[vec_idxs].reshape(L+self.burnin_size, n), self.nonterminals.reshape(size_after, *self.nonterminals.shape[2:])[vec_idxs].reshape(L+self.burnin_size, n, 1), \
           self.returns.reshape(size_after, *self.returns.shape[2:])[vec_idxs].reshape(L + self.burnin_size, n)

  # Returns a batch of sequence chunks uniformly sampled from the memory
  def sample(self, n, L):
    """Sample n chunks of L + burnin_size steps from any completed episode, as tensors on `self.device`."""
    batch = self._retrieve_batch(np.asarray([self._sample_idx(L+self.burnin_size, 0) for _ in range(n)]), n, L)
    return [torch.as_tensor(item).to(device=self.device) for item in batch]

  def sample_recent(self, n, L, episode):
    """Like `sample`, but only from episodes numbered at least on_policy_episodes * (episode - 1)."""
    batch = self._retrieve_batch(np.asarray([self._sample_idx(L+self.burnin_size, episode-1) for _ in range(n)]), n, L)
    return [torch.as_tensor(item).to(device=self.device) for item in batch]
