import threading
import numpy as np
import pickle

"""
The replay buffer here is adapted from the OpenAI Baselines code.

"""


class replay_buffer:
    """Episode-based replay buffer guarded by a lock.

    Whole episodes are stored in pre-allocated numpy arrays indexed by an
    episode slot. Once the buffer is full, new episodes overwrite randomly
    chosen slots.

    Args:
        env_params: mapping read for the keys ``'max_timesteps'``, ``'obs'``
            and ``'goal'``.
        buffer_size: capacity in transitions; the number of episode slots is
            ``buffer_size // env_params['max_timesteps']``.
        sample_func: callable invoked as ``sample_func(buffers, batch_size)``
            or, when ``good`` is true, as
            ``sample_func(buffers, batch_size, good_ids=...)``.
        good: when true, ``sample`` also passes the ids of the "good"
            trajectories (see ``preprocess_buffer``) to ``sample_func``.
    """

    def __init__(self, env_params, buffer_size, sample_func, good=False):
        self.env_params = env_params
        self.T = env_params['max_timesteps']
        self.size = buffer_size // self.T
        # memory management
        self.current_size = 0
        self.n_transitions_stored = 0
        self.sample_func = sample_func
        self.good = good
        # Ids of the "good" trajectories; None means "not computed yet" and
        # makes sample() compute them on demand.
        self.good_traj_ids = None
        # create the buffer to store info
        self.buffers = {
            'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]),
            'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]),
            'g': np.empty([self.size, self.T, self.env_params['goal']]),
            'actions': np.empty([self.size, self.T]),
            # Position-based id of every (slot, timestep) pair, built the same
            # way load() builds it, so it is valid for stored episodes too
            # instead of holding uninitialised memory.
            'obs_hash': np.arange(self.size * (self.T + 1)).reshape(self.size, self.T + 1),
        }
        # thread lock
        self.lock = threading.Lock()

    def store_episode(self, episode_batch):
        """Store a batch of episodes.

        Args:
            episode_batch: a 4-tuple ``(mb_obs, mb_ag, mb_g, mb_actions)``
                whose arrays share the same leading (episode) dimension.
        """
        mb_obs, mb_ag, mb_g, mb_actions = episode_batch
        batch_size = mb_obs.shape[0]
        with self.lock:
            idxs = self._get_storage_idx(inc=batch_size)
            # store the information
            self.buffers['obs'][idxs] = mb_obs
            self.buffers['ag'][idxs] = mb_ag
            self.buffers['g'][idxs] = mb_g
            self.buffers['actions'][idxs] = mb_actions
            self.n_transitions_stored += self.T * batch_size
            # Slots may have been overwritten, so previously computed ids are
            # stale; sample() recomputes them lazily.
            self.good_traj_ids = None

    def sample(self, batch_size):
        """Sample transitions through ``sample_func``.

        Only the filled part of each buffer is handed to ``sample_func``,
        together with the derived ``'obs_next'``, ``'ag_next'`` and
        ``'next_obs_hash'`` entries (the per-episode arrays shifted by one
        timestep).

        Returns:
            Whatever ``sample_func`` returns.
        """
        with self.lock:
            temp_buffers = {
                key: value[:self.current_size]
                for key, value in self.buffers.items()
            }
            if self.good and self.good_traj_ids is None:
                self.preprocess_buffer()
            good_ids = self.good_traj_ids
        temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :]
        temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :]
        temp_buffers['next_obs_hash'] = temp_buffers['obs_hash'][:, 1:]
        # sample transitions
        if self.good:
            return self.sample_func(temp_buffers, batch_size, good_ids=good_ids)
        return self.sample_func(temp_buffers, batch_size)

    def _get_storage_idx(self, inc=None):
        """Reserve ``inc`` episode slots and return their indices.

        Free slots are used first; once the buffer is full, slots are picked
        uniformly at random. Returns a scalar index when ``inc`` is 1 and an
        index array otherwise. Must be called with ``self.lock`` held.
        """
        inc = inc or 1
        if self.current_size + inc <= self.size:
            idx = np.arange(self.current_size, self.current_size + inc)
        elif self.current_size < self.size:
            # Fill the remaining free slots, then overwrite random used ones.
            overflow = inc - (self.size - self.current_size)
            idx_a = np.arange(self.current_size, self.size)
            idx_b = np.random.randint(0, self.current_size, overflow)
            idx = np.concatenate([idx_a, idx_b])
        else:
            idx = np.random.randint(0, self.size, inc)
        self.current_size = min(self.size, self.current_size + inc)
        if inc == 1:
            idx = idx[0]
        return idx

    def preprocess_buffer(self):
        """Compute ``good_traj_ids`` from the filled part of the buffer.

        A trajectory counts as good when the summed absolute difference
        between its last and first ``'ag'`` entries exceeds 0.1.
        """
        # Restrict to filled slots: the rest of the array is uninitialised.
        ag = self.buffers['ag'][:self.current_size]
        displacement = np.sum(np.abs(ag[:, -1] - ag[:, 0]), axis=1)
        self.good_traj_ids = np.where(displacement > 0.1)[0]

    def save(self, filename):
        """Pickle the filled part of the buffers to ``filename``."""
        with self.lock:
            # Drop unfilled slots so that load() does not mistake
            # uninitialised rows for stored episodes.
            filled = {
                key: value[:self.current_size]
                for key, value in self.buffers.items()
            }
            with open(filename, 'wb') as f:
                pickle.dump(filled, f)

    def load(self, filename):
        """Replace the buffer contents with a dataset pickled by ``save``.

        The capacity (``self.size``) becomes the number of loaded episodes,
        so later ``store_episode`` calls overwrite random loaded slots
        rather than indexing past the end of the loaded arrays.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # files from trusted sources.
        with open(filename, 'rb') as f:
            buffers = pickle.load(f)
        with self.lock:
            for key in ('obs', 'ag', 'g', 'actions'):
                self.buffers[key] = np.array(buffers[key])
            self.current_size = self.buffers['obs'].shape[0]
            # Keep the bookkeeping consistent with the arrays just swapped in.
            self.size = self.current_size
            self.n_transitions_stored = self.current_size * self.T
            self.buffers['obs_hash'] = np.arange(
                self.current_size * (self.T + 1)
            ).reshape(self.current_size, self.T + 1)
            self.good_traj_ids = None
        print('Dataset loaded')
        if self.good:
            print('Good buffer')
            self.preprocess_buffer()
            print(self.good_traj_ids.shape)