import numpy as np
import torch
import os
import pickle
import random

class ReplayMemory:
    """Fixed-capacity ring buffer of ``(state, action, reward, next_state, done)`` transitions.

    The buffer grows until it holds ``capacity`` transitions. After that, each
    push overwrites the slot at ``position`` and then advances ``position``
    cyclically.
    """

    def __init__(self, capacity, seed):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept.
            seed: Passed to ``random.seed``. Note this seeds the *global*
                ``random`` module, which ``sample`` draws from.
        """
        random.seed(seed)
        self.capacity = capacity
        self.buffer = []
        # Index of the slot the next push writes to.
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        """Store one transition, reusing slots cyclically once the buffer is full."""
        if len(self.buffer) < self.capacity:
            # Grow until capacity is reached; afterwards existing slots are reused.
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def push_cf(self, augmented_states, augmented_actions, augmented_rewards,
                augmented_next_states, augmented_dones):
        """Store a batch of transitions, one per index along the first axis.

        Each argument is indexed with ``[i]`` for ``i`` in
        ``range(augmented_states.shape[0])``. The arguments are presumably numpy
        arrays or torch tensors sharing the same leading dimension; confirm
        with callers.
        """
        for i in range(augmented_states.shape[0]):
            self.push(augmented_states[i], augmented_actions[i], augmented_rewards[i],
                      augmented_next_states[i], augmented_dones[i])

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly, without replacement.

        Returns:
            A tuple ``(state, action, reward, next_state, done)``. Each element
            is the ``np.stack`` of that field across the sampled transitions.

        Raises:
            ValueError: if ``batch_size`` exceeds ``len(self)``. This comes from
                ``random.sample``.
        """
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.buffer)

    def save_buffer(self, path):
        """Pickle the transition list to ``path``, creating parent directories if needed.

        Only the transition list is written; ``position`` is not saved.
        """
        print('Saving buffer to {}'.format(path))

        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump(self.buffer, f)

    def load_buffer(self, save_path):
        """Replace the stored transitions with a list unpickled from ``save_path``.

        If the loaded list holds more than ``capacity`` transitions, only the
        last ``capacity`` are kept. This preserves the invariant
        ``len(self) <= capacity``. The file does not store the write position,
        so it is reset to ``len(buffer) % capacity``.

        Security: ``pickle.load`` can execute arbitrary code. Only load files
        from trusted sources.
        """
        print('Loading buffer from {}'.format(save_path))

        with open(save_path, "rb") as f:
            buffer = pickle.load(f)
        if len(buffer) > self.capacity:
            # NOTE(review): assumes later entries are newer. This holds for a
            # buffer that never wrapped; for a wrapped buffer the ring order
            # is not recoverable from the file.
            buffer = buffer[-self.capacity:]
        self.buffer = buffer
        self.position = len(self.buffer) % self.capacity