import os
import pickle
import random
import threading

import numpy as np

class ReplayMemory:
    """Fixed-capacity ring buffer of ``(state, action, reward, next_state, done)`` tuples.

    Once ``capacity`` transitions are stored, each new push overwrites the
    oldest entry. Sampling is uniform without replacement using the
    module-level ``random`` generator.
    """

    def __init__(self, capacity, seed):
        """Create an empty buffer.

        Args:
            capacity: maximum number of transitions kept.
            seed: accepted for interface compatibility; not used here
                (seeding the global ``random`` module is left to the caller).
        """
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # index of the slot the next push writes

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one when full."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and stack each field.

        Returns:
            Tuple ``(state, action, reward, next_state, done)`` of arrays
            built with ``np.stack`` along a new leading batch axis.

        Raises:
            ValueError: if ``batch_size`` exceeds the number stored.
        """
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

    def save_buffer(self, env_name, suffix="", save_path=None):
        """Pickle the stored transitions to disk.

        Args:
            env_name: environment name used in the default file name.
            suffix: extra tag appended to the default file name.
            save_path: explicit destination; defaults to
                ``checkpoints/sac_buffer_{env_name}_{suffix}``.
        """
        if save_path is None:
            save_path = "checkpoints/sac_buffer_{}_{}".format(env_name, suffix)
        # Create whichever directory the file goes into (``checkpoints/`` for
        # the default path); exist_ok avoids the exists/makedirs race.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        print('Saving buffer to {}'.format(save_path))

        with open(save_path, 'wb') as f:
            pickle.dump(self.buffer, f)

    def load_buffer(self, save_path):
        """Replace the buffer contents with a list saved by ``save_buffer``.

        Only load trusted files: ``pickle.load`` can execute arbitrary code.
        """
        print('Loading buffer from {}'.format(save_path))

        with open(save_path, "rb") as f:
            self.buffer = pickle.load(f)
            # Resume writing right after the loaded data (wraps when full).
            self.position = len(self.buffer) % self.capacity
class replay_buffer:
    """Goal-conditioned replay buffer that stores transitions in fixed-length chunks.

    Each slot holds ``T`` consecutive transitions, and at most
    ``buffer_size // T`` slots are kept. When the buffer is full, new chunks
    overwrite randomly chosen slots (see ``_get_storage_idx``). Storage
    access in ``push`` and ``sample`` is guarded by a ``threading.Lock``.
    """

    def __init__(self, env_params, buffer_size, sample_func, T=10):
        """Allocate the storage arrays.

        Args:
            env_params: dict providing at least ``'obs'`` (observation size)
                and ``'goal'`` (goal size). ``'max_timesteps'`` is not read.
            buffer_size: total number of transitions the buffer can hold.
            sample_func: callable ``(buffers, batch_size) -> transitions``
                applied to the filled part of the storage by ``sample``.
            T: chunk length (steps per slot). Defaults to 10, the value that
                was previously hard-coded.
        """
        self.env_params = env_params
        self.T = T
        self.size = buffer_size // self.T  # maximum number of stored chunks
        # memory management
        self.current_size = 0
        self.n_transitions_stored = 0
        self.sample_func = sample_func
        obs_dim = self.env_params['obs']
        goal_dim = self.env_params['goal']
        # One (size, T, feature) array per field.
        self.buffers = {
            'obs': np.empty([self.size, self.T, obs_dim]),
            # One value per step. The original note suggests
            # env_params['action'] would be used for continuous actions.
            'actions': np.empty([self.size, self.T, 1]),
            'r_ex': np.empty([self.size, self.T, 1]),
            'obs_next': np.empty([self.size, self.T, obs_dim]),
            'done': np.empty([self.size, self.T, 1]),
            'ag': np.empty([self.size, self.T, goal_dim]),
            'g': np.empty([self.size, self.T, goal_dim]),
            'ag_next': np.empty([self.size, self.T, goal_dim]),
        }
        self.lock = threading.Lock()

    def push(self, episode_batch):
        """Store a stream of transitions as whole ``T``-step chunks.

        Args:
            episode_batch: tuple ``(obs, ag, g, actions, r_ex, done,
                next_obs, next_ag)`` of arrays, assumed to share the same
                leading length N. Each is reshaped to ``(N // T, T, -1)``.

        When N is not a multiple of ``T``, the ``N % T`` leftover steps are
        dropped from the front or from the back with equal probability. If
        ``N < T``, nothing is stored.
        """
        mb_obs, mb_ag, mb_g, mb_actions, mb_r_ex, mb_done, mb_nex_obs, mb_nex_ag = episode_batch
        n_steps = mb_obs.shape[0]
        nums = n_steps // self.T
        if nums == 0:
            # Too short for one chunk. Previously this reserved a slot
            # (``inc or 1``) and then failed reshaping an empty array.
            return
        drop_num = n_steps % self.T

        if np.random.uniform() > 0.5:  # drop the leftover steps at the front
            s = slice(drop_num, n_steps)
        else:  # drop the leftover steps at the back
            s = slice(0, n_steps - drop_num)

        fields = (
            ('obs', mb_obs),
            ('ag', mb_ag),
            ('actions', mb_actions),
            ('obs_next', mb_nex_obs),
            ('ag_next', mb_nex_ag),
            ('done', mb_done),
            ('r_ex', mb_r_ex),
            ('g', mb_g),
        )
        with self.lock:
            idxs = self._get_storage_idx(inc=nums)
            for key, data in fields:
                self.buffers[key][idxs] = data[s].reshape((nums, self.T, -1))
            # Running total of transitions pushed (was overwritten by '=').
            self.n_transitions_stored += self.T * nums

    def sample(self, batch_size):
        """Draw transitions via ``sample_func`` from the filled slots.

        Args:
            batch_size: forwarded unchanged to ``sample_func``.

        Returns:
            Whatever ``sample_func(buffers, batch_size)`` returns, where
            ``buffers`` maps each key to ``storage[:current_size]``.
        """
        with self.lock:
            # NOTE: basic slices are views into the live arrays, so a
            # concurrent push can still modify them after the lock is
            # released. Only ``current_size`` is read under the lock.
            temp_buffers = {key: buf[:self.current_size] for key, buf in self.buffers.items()}
        return self.sample_func(temp_buffers, batch_size)

    def _get_storage_idx(self, inc=None):
        """Reserve ``inc`` slot indices (default 1) and grow ``current_size``.

        Free slots are used first. Once they run out, the remainder are
        chosen uniformly at random among already-filled slots. Returns a
        scalar index when ``inc == 1``, otherwise an index array of length
        ``inc`` (the random part may contain duplicates).
        """
        inc = inc or 1
        if self.current_size + inc <= self.size:
            idx = np.arange(self.current_size, self.current_size + inc)
        elif self.current_size < self.size:
            overflow = inc - (self.size - self.current_size)
            idx_a = np.arange(self.current_size, self.size)
            # Pick filled slots to overwrite. If nothing was filled yet (one
            # push larger than the whole buffer), reuse the slots claimed
            # now; randint(0, 0) would raise ValueError.
            high = self.current_size if self.current_size > 0 else self.size
            idx_b = np.random.randint(0, high, overflow)
            idx = np.concatenate([idx_a, idx_b])
        else:
            idx = np.random.randint(0, self.size, inc)
        self.current_size = min(self.size, self.current_size + inc)
        if inc == 1:
            idx = idx[0]
        return idx


from multiprocessing import Process, Manager

class ReplayMemory_MP:
    """Ring buffer of transitions whose storage is a ``multiprocessing.Manager`` list.

    Same interface as ``ReplayMemory``, but ``self.buffer`` is a manager
    list proxy, so every access is a round-trip to the manager process.

    NOTE(review): only ``self.buffer`` lives in the manager. ``position`` is
    a plain per-process int, so pushes from several processes will not see
    each other's write position. Confirm this is the intended usage.
    """

    def __init__(self, capacity, seed):
        """Create an empty shared buffer and seed the global ``random`` module."""
        random.seed(seed)
        self.capacity = capacity
        self.buffer = Manager().list()
        self.position = 0  # index of the slot the next push writes

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one when full."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and stack each field.

        ``random.sample`` rejects a manager list proxy (it is not a
        ``collections.abc.Sequence``), so indices are sampled instead and the
        selected items are fetched one by one. This consumes the RNG exactly
        like sampling the list directly.

        Raises:
            ValueError: if ``batch_size`` exceeds the number stored.
        """
        indices = random.sample(range(len(self.buffer)), batch_size)
        batch = [self.buffer[i] for i in indices]
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

    def save_buffer(self, env_name, suffix="", save_path=None):
        """Pickle a plain-list snapshot of the stored transitions to disk.

        Args:
            env_name: environment name used in the default file name.
            suffix: extra tag appended to the default file name.
            save_path: explicit destination; defaults to
                ``checkpoints/sac_buffer_{env_name}_{suffix}``.
        """
        if save_path is None:
            save_path = "checkpoints/sac_buffer_{}_{}".format(env_name, suffix)
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        print('Saving buffer to {}'.format(save_path))

        # Pickling the proxy itself would only store a handle to the manager
        # process. Slicing returns the actual data as a plain list.
        snapshot = self.buffer[:]
        with open(save_path, 'wb') as f:
            pickle.dump(snapshot, f)

    def load_buffer(self, save_path):
        """Refill the buffer in place from a list saved by ``save_buffer``.

        The existing (shared) list object is kept so that other holders of the
        proxy see the loaded data. Only load trusted files: ``pickle.load``
        can execute arbitrary code.
        """
        print('Loading buffer from {}'.format(save_path))

        with open(save_path, "rb") as f:
            data = pickle.load(f)
        self.buffer[:] = data
        # Resume writing right after the loaded data (wraps when full).
        self.position = len(data) % self.capacity
