import numpy as np


class RingBuffer(object):
    """Fixed-capacity ring of rows, each holding up to ``reward_freq`` steps.

    ``data`` has shape ``(maxlen, reward_freq) + shape``. Values are written
    column by column into the row at ``start``. A row is finalized, and
    ``start`` advances, once ``reward_freq`` values have been written or when
    an append is made with ``reward_sig`` set. The row at ``start`` is always
    the one still being filled, so at most ``maxlen - 1`` finalized rows are
    exposed through ``__len__``, ``__getitem__`` and ``get_batch``.

    Args:
        maxlen: Number of rows in the ring.
        reward_freq: Number of steps (columns) per row.
        shape: Shape of a single stored value (tuple or list).
        dtype: NumPy dtype of the storage array.
        amor: If True, an append with ``reward_sig`` set writes the value into
            every column filled so far in the current row, not only the
            current column.
    """

    def __init__(self, maxlen, reward_freq, shape, dtype='float32', amor=False):
        self.maxlen = maxlen
        self.amor = amor
        self.reward_freq = reward_freq
        self.start = 0           # row currently being filled
        self.length = 0          # number of finalized rows, capped at maxlen - 1
        self.interval_start = 0  # next column to write in the current row
        # Allocate directly in the target dtype (no float64 temporary + copy).
        # Unfilled columns must read as zero, e.g. masks of 0 for absent steps.
        self.data = np.zeros((maxlen, reward_freq) + tuple(shape), dtype=dtype)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the ``idx``-th most recently finalized row (``1`` = newest).

        Raises:
            KeyError: if ``idx`` is not in ``1..len(self)``.
        """
        if idx <= 0 or idx > self.length:
            raise KeyError(idx)
        # Wrap by the ring capacity. Wrapping by ``length`` mis-addresses rows
        # once the ring has wrapped and can even return the in-progress row.
        return self.data[(self.start - idx) % self.maxlen]

    def get_batch(self, idxs):
        """Vectorized ``__getitem__`` for an integer array of 1-based ages.

        Indices are not range-checked; callers should draw them from
        ``1..len(self)``.
        """
        return self.data[(self.start - idxs) % self.maxlen]

    def append(self, v, reward_sig):
        """Write ``v`` into the current row; finalize the row when full or signalled.

        Args:
            v: Value to store (scalar or array broadcastable to ``shape``).
            reward_sig: Truthy when this step carries a reward signal. It
                ends the current row early.

        Returns:
            ``v`` unchanged.

        Raises:
            RuntimeError: if the column cursor is in an impossible state.
        """
        if self.amor and reward_sig:
            # Spread the signalled value over every column filled so far.
            self.data[self.start, :self.interval_start + 1] = v
        else:
            self.data[self.start, self.interval_start] = v

        if reward_sig or self.interval_start == self.reward_freq - 1:
            self.start = (self.start + 1) % self.maxlen
            self.length = min(self.length + 1, self.maxlen - 1)
            self.interval_start = 0
            # Clear the recycled row so that columns left unwritten by an
            # early-terminated row do not keep the previous occupant's data.
            self.data[self.start] = 0
        elif self.interval_start < self.reward_freq - 1:
            self.interval_start += 1
        else:
            raise RuntimeError(
                'interval_start {} exceeds reward_freq - 1 ({})'.format(
                    self.interval_start, self.reward_freq - 1))

        return v


class ReplayBuffer(object):
    """Replay memory that groups consecutive transitions into rows of ``reward_n`` steps.

    Each field (observations, actions, rewards, dones, next observations,
    masks, reward-signal flags) lives in its own ``RingBuffer``. A row is
    closed after ``reward_n`` steps or on a step whose ``reward_signal`` is
    set. ``sample`` draws whole rows plus one random column per row.
    """

    def __init__(self, limit, action_shape, observation_shape, ircr, batch_size, reward_n):
        self.limit = limit
        self.reward_n = reward_n
        self.ircr = ircr
        # Running extremes of signalled rewards. They start from the sentinels
        # -1e4 / +1e4, so the first reward inside that range replaces them.
        self.r_max = np.zeros([1]) - 1E4
        self.r_min = np.zeros([1]) + 1E4
        # Kept for callers that may read them. ``sample`` builds per-call
        # equivalents, so it works for any requested batch size.
        self.coef_matrix = np.array([np.eye(reward_n) for _ in range(batch_size)])
        self.tdxs = np.array(range(batch_size))
        self.obs = RingBuffer(limit, reward_n, shape=observation_shape)
        self.acts = RingBuffer(limit, reward_n, shape=action_shape)
        self.rewards = RingBuffer(limit, reward_n, shape=(1,), amor=ircr)
        self.dones = RingBuffer(limit, reward_n, shape=(1,))
        self.n_obs = RingBuffer(limit, reward_n, shape=observation_shape)
        self.masks = RingBuffer(limit, reward_n, shape=(1,))
        self.r_signals = RingBuffer(limit, reward_n, shape=(1,))

    def sample(self, batch_size):
        """Sample ``batch_size`` stored rows uniformly at random (with replacement).

        Returns:
            dict of NumPy arrays:
                'obs', 'action', 'mask': full rows, shape ``(batch_size, reward_n, ...)``.
                'n_obs', 'reward', 'done', 'pi_obs', 'r_sig': the entry at one
                    randomly chosen column per row, shape ``(batch_size, ...)``.
                'rew_sum': rewards summed over each row, shape ``(batch_size, 1)``.
                'coef': one-hot indicator of the chosen column,
                    shape ``(batch_size, reward_n, 1)``.

            When ``ircr`` is set, rewards are min-max normalized with the
            running extremes before selection and summation.

        Raises:
            ValueError: if no finalized rows are stored (raised by ``np.random.randint``).
        """
        # Ages 1..nb_entries() inclusive. This is what the deprecated
        # ``np.random.random_integers(nb_entries, size=...)`` produced.
        batch_idxs = np.random.randint(1, self.nb_entries() + 1, size=batch_size)

        obs_batch = self.obs.get_batch(batch_idxs)
        n_obs_batch = self.n_obs.get_batch(batch_idxs)
        action_batch = self.acts.get_batch(batch_idxs)
        reward_batch = self.rewards.get_batch(batch_idxs)
        done_batch = self.dones.get_batch(batch_idxs)
        mask_batch = self.masks.get_batch(batch_idxs)
        r_sig_batch = self.r_signals.get_batch(batch_idxs)

        if self.ircr:
            r_range = self.r_max - self.r_min
            # Avoid 0/0 (NaN) once only a single distinct signalled reward has been seen.
            r_range = np.where(r_range == 0, 1.0, r_range)
            reward_batch = (reward_batch - self.r_min) / r_range

        rows = np.arange(batch_size)
        idxs = np.random.randint(self.reward_n, size=(batch_size,))
        # One-hot per row. These are the same values as
        # ``self.coef_matrix[rows, idxs]`` but do not depend on the
        # constructor's batch size.
        coef = np.eye(self.reward_n)[idxs]

        reward_sum_batch = np.sum(reward_batch, axis=1)

        result = {
            'obs': obs_batch,
            'n_obs': n_obs_batch[rows, idxs],
            'rew_sum': reward_sum_batch,
            'reward': reward_batch[rows, idxs],
            'action': action_batch,
            'pi_obs': obs_batch[rows, idxs],
            'done': done_batch[rows, idxs],
            'coef': coef.reshape(batch_size, self.reward_n, 1),
            'mask': mask_batch,
            'r_sig': r_sig_batch[rows, idxs],
        }
        return result

    def can_sample(self, size):
        """Return True when at least ``size`` finalized rows are stored."""
        return len(self.n_obs) >= size

    def append(self, obs, action, reward, n_obs, done, reward_signal):
        """Store one transition; a truthy ``reward_signal`` closes the current row.

        On signalled steps, the running reward extremes used for IRCR
        normalization are updated.
        """
        self.obs.append(obs, reward_signal)
        self.acts.append(action, reward_signal)
        r = self.rewards.append(reward, reward_signal)
        self.n_obs.append(n_obs, reward_signal)
        self.dones.append(float(done), reward_signal)
        # A mask of 1 marks a real step; unfilled columns keep the buffer's zeros.
        self.masks.append(1., reward_signal)
        self.r_signals.append(float(reward_signal), reward_signal)

        if reward_signal:
            self.r_max = np.maximum(r, self.r_max)
            self.r_min = np.minimum(r, self.r_min)

    def nb_entries(self):
        """Return the number of finalized rows available for sampling."""
        return len(self.obs)
