import numpy as np
import tensorflow as tf

class ReplayBuffer(object):
    """Basic FIFO replay buffer backed by NumPy arrays.

    Stored fields are ``obs``, ``nobs``, ``act`` and ``rew`` (cast to
    ``float32``) and ``don`` (kept with its incoming dtype). ``self.N`` is the
    number of stored transitions, or ``-1`` while the buffer is still empty.
    """

    def __init__(self, buffer_size, epi_len=1000, initial_data=None):
        """Create the buffer.

        Args:
            buffer_size: maximum number of transitions kept; once exceeded, the
                oldest transitions are dropped from the front.
            epi_len: episode length, used to locate episode-start slots.
            initial_data: optional dict with keys ``obs``, ``nobs``, ``act``,
                ``rew``, ``don`` and ``n`` used to seed the buffer. ``None`` or
                an empty dict leaves the buffer empty.
        """
        self.buffer_size = buffer_size
        self.epi_len = epi_len
        if not initial_data:
            self.N = -1
        else:
            self._initial_setup(initial_data)

        # Episode-start slots (every epi_len slots) and the slot right after each.
        init_indice = np.arange(0, self.buffer_size, self.epi_len)
        init_indice_next = init_indice + 1
        # The slot after the last start can fall past the end of the buffer
        # (e.g. buffer_size % epi_len == 1); np.delete rejects out-of-range
        # indices, so drop it.
        init_indice_next = init_indice_next[init_indice_next < self.buffer_size]

        self.total_init = tf.concat([init_indice, init_indice_next], axis=0)

        self.non_init_indice = np.delete(np.arange(self.buffer_size), self.total_init)

    def _initial_setup(self, initial_data=None):
        """Replace the buffer contents with ``initial_data`` (see ``__init__``)."""
        self.obs = initial_data['obs'].astype('float32')
        self.nobs = initial_data['nobs'].astype('float32')
        self.act = initial_data['act'].astype('float32')
        self.rew = initial_data['rew'].astype('float32')
        self.don = initial_data['don']
        self.N = initial_data['n']

    def add(self, other_data):
        """Add collected data from Sampler.

        ``other_data`` uses the same keys as ``initial_data``. When the total
        exceeds ``buffer_size``, the oldest transitions are discarded.
        """
        if self.N == -1:
            self._initial_setup(other_data)
            return

        self.N += other_data['n']
        # Number of oldest rows to drop so that at most buffer_size remain.
        offset_index = int(max(self.N - self.buffer_size, 0))
        self.N -= offset_index

        def keep_tail_and_append(old, new):
            return np.concatenate((old[offset_index:], new), axis=0)

        self.obs = keep_tail_and_append(self.obs, other_data['obs'].astype('float32'))
        self.nobs = keep_tail_and_append(self.nobs, other_data['nobs'].astype('float32'))
        self.act = keep_tail_and_append(self.act, other_data['act'].astype('float32'))
        self.rew = keep_tail_and_append(self.rew, other_data['rew'].astype('float32'))
        self.don = keep_tail_and_append(self.don, other_data['don'])

    def gather_indices(self, indices):
        """Return the stored transitions at ``indices`` as a dict.

        The dict also carries ``'n'``, the current number of stored transitions.
        """
        return {
            'obs': self.obs[indices],
            'act': self.act[indices],
            'nobs': self.nobs[indices],
            'rew': self.rew[indices],
            'don': self.don[indices],
            'n': self.N,
        }

    def get_random_batch(self, batch_size):
        """Get ``batch_size`` transitions sampled uniformly with replacement."""
        indices = np.random.randint(self.N, size=batch_size)
        return self.gather_indices(indices)

    @staticmethod
    def _timestep_fractions(indices, epi_len):
        """Map each index to ``((index % epi_len) / epi_len + 1) / 2``."""
        return [((j % epi_len) / epi_len + 1) / 2 for j in indices]

    def get_balance_batch_with_step(self, batch_size, init_num):
        """Sample a batch that mixes episode-start slots with other slots.

        Episode starts are every ``epi_len`` slots. If there are at most
        ``init_num`` of them, all are included. Otherwise ``init_num`` are
        drawn at random with ``tf.random.shuffle``. The rest of the batch is
        drawn without replacement from the remaining slots.

        NOTE(review): slots are drawn from ``range(buffer_size)`` rather than
        ``range(self.N)``, so this assumes the buffer is full.

        Returns:
            ``(batch_merge, batch_merge_step)``: one-element lists holding,
            respectively, the ``gather_indices`` dict and the per-sample
            ``((index % epi_len) / epi_len + 1) / 2`` values.
        """
        init_indice = np.arange(0, self.buffer_size, self.epi_len)
        non_init_indice = np.delete(np.arange(self.buffer_size), init_indice)

        if init_indice.shape[0] <= init_num:
            init_indices = init_indice
        else:
            # Builtin int: the np.int alias was removed in NumPy 1.24.
            init_indices = tf.random.shuffle(init_indice)[:int(init_num)]

        normal_indices = np.random.choice(non_init_indice,
                                          size=(batch_size - init_indices.shape[0]),
                                          replace=False)
        indices = tf.concat([normal_indices, init_indices], axis=0)

        batch_merge = [self.gather_indices(indices)]
        batch_merge_step = [self._timestep_fractions(indices, self.epi_len)]
        return batch_merge, batch_merge_step

class VisualReplayBuffer(ReplayBuffer):
    """Replay buffer with added support for visual observations.

    Images are stored as ``uint8`` in ``self.ims`` and are rescaled to
    ``float32`` via ``(x + 0.5) / 256`` when gathered.
    """
    def __init__(self, buffer_size, epi_len, initial_data=None):
        # Use a fresh empty dict rather than a shared mutable default; the
        # parent treats an empty dict as "no initial data".
        if initial_data is None:
            initial_data = {}
        super(VisualReplayBuffer, self).__init__(buffer_size, epi_len, initial_data)

    def _initial_setup(self, initial_data=None):
        """Replace the buffer contents, including the ``'ims'`` images."""
        super(VisualReplayBuffer, self)._initial_setup(initial_data)
        self.ims = initial_data['ims'].astype(np.uint8)

    def add(self, other_data):
        """Add collected data (including ``'ims'``) from Sampler.

        When the total exceeds ``buffer_size``, the oldest transitions are
        discarded from every field, images included.
        """
        if self.N == -1:
            self._initial_setup(other_data)
            return

        self.N += other_data['n']
        # Number of oldest rows to drop so that at most buffer_size remain.
        offset_index = int(max(self.N - self.buffer_size, 0))
        self.N -= offset_index

        def keep_tail_and_append(old, new):
            return np.concatenate((old[offset_index:], new), axis=0)

        self.obs = keep_tail_and_append(self.obs, other_data['obs'].astype('float32'))
        self.nobs = keep_tail_and_append(self.nobs, other_data['nobs'].astype('float32'))
        self.act = keep_tail_and_append(self.act, other_data['act'].astype('float32'))
        self.rew = keep_tail_and_append(self.rew, other_data['rew'].astype('float32'))
        self.don = keep_tail_and_append(self.don, other_data['don'])
        self.ims = keep_tail_and_append(self.ims, other_data['ims'].astype(np.uint8))

    def gather_indices(self, indices):
        """Return the transitions at ``indices``, with images rescaled to float32."""
        out_dict = super(VisualReplayBuffer, self).gather_indices(indices)
        out_dict['ims'] = ((self.ims[indices].astype('float32') + 0.5) / 256)

        return out_dict


class LearnerAgentReplayBuffer(VisualReplayBuffer):
    """Visual replay buffer whose rewards can be recomputed by GAIL networks.

    ``update_reward`` snapshots the buffer into ``*_RL`` attributes and fills
    ``rew_RL`` with rewards computed from ``gail``'s encoder and label
    networks. The ``*_rl`` sampling methods read from that snapshot.
    """
    def __init__(self, gail, buffer_size, epi_len, initial_data=None):
        """Create the buffer.

        Args:
            gail: object exposing ``_encoder``, ``_label_net`` and
                ``_label_net_frame`` callables.
            buffer_size, epi_len, initial_data: see ``VisualReplayBuffer``.
        """
        # Fresh empty dict instead of a shared mutable default.
        if initial_data is None:
            initial_data = {}
        super(LearnerAgentReplayBuffer, self).__init__(buffer_size, epi_len, initial_data)
        self._enc = gail._encoder
        self._label_net = gail._label_net
        self._label_net_frame = gail._label_net_frame

    def gather_indices_rl(self, indices):
        """Return transitions at ``indices`` from the ``update_reward`` snapshot."""
        return {
            'obs': self.obs_RL[indices],
            'act': self.act_RL[indices],
            'nobs': self.nobs_RL[indices],
            'rew': self.rew_RL[indices],
            'don': self.don_RL[indices],
            'n': self.N_RL,
        }

    def get_random_batch_rl(self, batch_size):
        """Sample ``batch_size`` snapshot transitions uniformly with replacement."""
        indices = np.random.randint((self.N_RL), size=batch_size)
        return self.gather_indices_rl(indices)

    def update_reward(self, split=250):
        """Snapshot the buffer and recompute its rewards.

        The reward of each stored image is
        ``-log(1 - label_frame * label + 1e-12)``, where both labels are
        computed from the encoder features of that image.

        Args:
            split: number of images passed through the networks at a time.
        """
        self.ims_RL = self.ims
        self.obs_RL = self.obs
        self.nobs_RL = self.nobs
        self.act_RL = self.act
        self.don_RL = self.don
        self.N_RL = self.N

        rew_buff = []
        for start in range(0, len(self.ims), split):
            img_chunk = (self.ims[start:start + split].astype('float32') + 0.5) / 256

            tl_feature = self._enc(img_chunk)
            tl_recon_label = self._label_net(tl_feature)
            tl_recon_label_frame = self._label_net_frame(tl_feature)

            source_label = -tf.math.log(1 - (tl_recon_label_frame * tl_recon_label) + 1e-12)

            # tf.squeeze collapses a single-image chunk to a 0-d value, which
            # np.concatenate cannot join; keep at least one axis.
            rew_buff.append(np.atleast_1d(tf.squeeze(source_label)))

        self.rew_RL = np.concatenate(rew_buff, axis=0)

    def get_random_batch(self, batch_size, re_eval_rw=True):
        """Sample ``batch_size`` transitions uniformly with replacement.

        Args:
            batch_size: number of transitions to sample.
            re_eval_rw: if True, sample from the ``update_reward`` snapshot
                (recomputed rewards). If False, sample from the live buffer with
                its stored rewards (and images).
        """
        if re_eval_rw:
            return self.get_random_batch_rl(batch_size)
        # Call the parent explicitly: ``self.get_random_batch`` would dispatch
        # back here with re_eval_rw=True and return the snapshot batch instead.
        return super(LearnerAgentReplayBuffer, self).get_random_batch(batch_size)

class DemonstrationsReplayBuffer(object):
    """Replay buffer efficiently storing priorly collected visual observations.

    Only frame 0 of each stacked observation is kept in ``self.ims``. Older
    frames are rebuilt from neighbouring rows in ``gather_indices``.
    """
    def __init__(self, initial_data, epi_len):
        """Create the buffer.

        Args:
            initial_data: dict with ``'ims'`` (indexed as
                ``[row, frame, ...]``; rank 5 assumed by the slicing below),
                ``'ids'`` (trajectory id per row), ``'act'`` and ``'don'``.
            epi_len: episode length, used to locate episode-start slots and
                to compute timestep fractions.
        """
        self.ims = initial_data['ims'][:, 0, :, :, :].astype(np.uint8)
        self.N = self.ims.shape[0]
        self.ids = initial_data['ids']
        self.act = initial_data['act']
        self.done = initial_data['don']
        self.past_frames = initial_data['ims'].shape[1]
        self.idx_shifts = np.expand_dims(np.arange(self.past_frames), axis=0)

        self.pad_image = (np.zeros_like(self.ims[0]).astype('float32') + 0.5) / 256
        # Row of the first occurrence of each trajectory id. np.unique orders
        # these by id value, not by row position.
        _, self.first_indices = np.unique(self.ids, return_index=True)

        self.first_ims = initial_data['ims'][self.first_indices, 1, :, :, :].astype(np.uint8)
        self.epi_len = epi_len
        retrieval_indices = np.arange(self.first_indices.shape[0])
        # Maps a trajectory-start row to its row in first_ims. Other rows hold
        # a large sentinel; gather_indices only reads this at start rows.
        self.padded_retrieval_list = np.zeros([self.N]).astype('int') + 1000000
        self.padded_retrieval_list[self.first_indices] = retrieval_indices

        self.timestep = self._episode_timesteps(self.first_indices, self.N)
        if self.N:
            # Running-counter value after the last row (last timestep + 1),
            # kept as an attribute for backward compatibility.
            self.inter_count = int(self.timestep[-1, 0]) + 1

        init_indice = np.arange(0, self.N, self.epi_len)
        init_indice_next = init_indice + 1
        # The slot after the last start can equal N (e.g. N % epi_len == 1);
        # np.delete rejects out-of-range indices, so drop it.
        init_indice_next = init_indice_next[init_indice_next < self.N]

        self.total_init = tf.concat([init_indice, init_indice_next], axis=0)

        self.non_init_indice = np.delete(np.arange(self.N), self.total_init)

    @staticmethod
    def _episode_timesteps(first_indices, n):
        """Return an ``(n, 1)`` float array: 1-based count of rows since the
        most recent row listed in ``first_indices`` (which must contain 0 when
        ``n > 0``)."""
        positions = np.arange(n)
        starts = np.sort(np.asarray(first_indices))
        # Index into `starts` of the latest start at or before each row.
        owner = np.searchsorted(starts, positions, side='right') - 1
        timestep = np.zeros((n, 1))
        timestep[:, 0] = positions - starts[owner] + 1
        return timestep

    @staticmethod
    def _timestep_fractions(indices, epi_len):
        """Map each index to ``((index % epi_len) / epi_len + 1) / 2``."""
        return [((j % epi_len) / epi_len + 1) / 2 for j in indices]

    def gather_indices(self, indices):
        """Rebuild stacked visual observations for the rows at ``indices``.

        Column ``j`` of each returned stack is row ``index - j`` of ``self.ims``.
        A trajectory-start row may appear among the first ``past_frames - 1``
        columns. If so, the next column is replaced with that start row's
        original frame 1 (``first_ims``), and every later column gets
        ``pad_image``.

        Returns:
            dict with ``'ims'`` (float32 frames scaled by ``(x + 0.5) / 256``),
            ``'act'`` (actions at every stacked row), ``'init'`` (boolean start
            mask; see the note below) and ``'step'`` (``self.timestep`` rows).
        """
        # Rows near 0 yield negative indices that wrap around. Those columns
        # always lie beyond start row 0 and are overwritten below.
        all_indices = np.expand_dims(indices, axis=-1) - self.idx_shifts
        images = (self.ims[all_indices].astype('float32') + 0.5) / 256

        start_indices_mask = np.isin(all_indices[:, :-1], self.first_indices)

        step = self.timestep[indices]
        start_indices_x, start_indices_y = np.where(start_indices_mask)
        first_indices_y = start_indices_y + 1

        trajectories_start_indices = all_indices[start_indices_x, start_indices_y]
        images[start_indices_x, first_indices_y] = (self.first_ims[
                                                        self.padded_retrieval_list[trajectories_start_indices]].astype(
            'float32') + 0.5) / 256

        # NOTE(review): this slice is a view, so the in-place loop below also
        # updates start_indices_mask, which is returned as 'init'. Confirm
        # whether callers expect the propagated mask or the raw start mask.
        pad_indices_mask = start_indices_mask[:, :-1]
        # Propagate each start to every later column.
        for i in range(self.past_frames - 3):
            pad_indices_mask[:, i + 1] = np.logical_or(pad_indices_mask[:, i + 1],
                                                       pad_indices_mask[:, i])
        pad_indices_x, pad_indices_y = np.where(pad_indices_mask)
        pad_indices_y += 2
        images[pad_indices_x, pad_indices_y] = self.pad_image

        return {'ims': images, 'act': self.act[all_indices], 'init': start_indices_mask, 'step': step}

    def get_random_batch(self, batch_size):
        """Get ``batch_size`` rows sampled uniformly with replacement."""
        indices = np.random.randint(self.N, size=batch_size)
        return self.gather_indices(indices)

    def get_balance_batch_with_step(self, batch_size, init_num):
        """Sample a batch that mixes trajectory-start rows with other rows.

        If there are at most ``init_num`` start rows, all are included.
        Otherwise ``init_num`` are drawn at random with ``tf.random.shuffle``.
        The rest of the batch is drawn without replacement from the remaining
        rows.

        Returns:
            ``(batch_merge, batch_merge_step)``: one-element lists holding,
            respectively, the ``gather_indices`` dict and the per-sample
            ``((index % epi_len) / epi_len + 1) / 2`` values.
        """
        non_init_indice = np.delete(np.arange(self.N), self.first_indices)

        if self.first_indices.shape[0] <= init_num:
            init_indices = self.first_indices
        else:
            # Builtin int: the np.int alias was removed in NumPy 1.24.
            init_indices = tf.random.shuffle(self.first_indices)[:int(init_num)]

        normal_indices = np.random.choice(non_init_indice,
                                          size=(batch_size - init_indices.shape[0]),
                                          replace=False)
        indices = tf.concat([normal_indices, init_indices], axis=0)

        batch_merge = [self.gather_indices(indices)]
        batch_merge_step = [self._timestep_fractions(indices, self.epi_len)]
        return batch_merge, batch_merge_step
