import numpy as np
from utils import log_trajectory_statistics


class ReplayBuffer(object):
    """Bounded first-in-first-out store of transitions held in numpy arrays.

    Once data is present the buffer exposes ``obs``, ``nobs``, ``act`` and
    ``rew`` (cast to float32), ``don`` (stored as given) and ``N``, the
    number of stored transitions. ``N == -1`` marks a buffer that has not
    received any data yet.
    """

    def __init__(self, buffer_size, initial_data={}):
        """Create the buffer, optionally seeding it with ``initial_data``.

        Args:
            buffer_size: maximum number of transitions kept by ``add``.
            initial_data: dict with keys 'obs', 'nobs', 'act', 'rew', 'don'
                and 'n'. An empty (or None) value leaves the buffer unset.
                The default dict is only read, never mutated.
        """
        self.buffer_size = buffer_size
        if not initial_data:
            self.N = -1
        else:
            self._initial_setup(initial_data)

    def _initial_setup(self, initial_data={}):
        """Take the arrays of ``initial_data`` as the buffer contents."""
        self.obs = initial_data['obs'].astype('float32')
        self.nobs = initial_data['nobs'].astype('float32')
        self.act = initial_data['act'].astype('float32')
        self.rew = initial_data['rew'].astype('float32')
        self.don = initial_data['don']
        self.N = initial_data['n']

    def add(self, other_data):
        """Append ``other_data``, evicting the oldest transitions on overflow.

        ``other_data`` has the same layout as ``initial_data``. When the
        incoming chunk alone overflows the buffer, its oldest entries are
        dropped as well, so every stored array always has exactly ``N``
        rows (previously the arrays could grow past ``N``, leaving the
        newest samples unreachable by ``get_random_batch``).
        """
        if self.N == -1:
            self._initial_setup(other_data)
            return
        incoming_n = other_data['n']
        overflow = int(max(self.N + incoming_n - self.buffer_size, 0))
        # Evict stored rows first; only what still overflows is cut from
        # the front (oldest part) of the incoming chunk.
        drop_old = min(overflow, self.N)
        drop_new = overflow - drop_old
        for key in ('obs', 'nobs', 'act', 'rew'):
            incoming = other_data[key].astype('float32')
            merged = np.concatenate(
                (getattr(self, key)[drop_old:], incoming[drop_new:]), axis=0)
            setattr(self, key, merged)
        self.don = np.concatenate(
            (self.don[drop_old:], other_data['don'][drop_new:]), axis=0)
        self.N += incoming_n - overflow

    def gather_indices(self, indices):
        """Return a dict with the stored fields selected at ``indices``."""
        out_dict = {}
        out_dict['obs'], out_dict['act'], out_dict['nobs'] = (self.obs[indices],
                                                              self.act[indices],
                                                              self.nobs[indices])
        out_dict['rew'], out_dict['don'] = self.rew[indices], self.don[indices]
        return out_dict

    def get_random_batch(self, batch_size):
        """Sample ``batch_size`` transitions uniformly with replacement."""
        indices = np.random.randint(self.N, size=batch_size)
        return self.gather_indices(indices)

    def get_stats_previous_timesteps(self, num_timesteps):
        """Report the returns of trajectories completed in the latest steps.

        Looks at the last ``num_timesteps + 1`` transitions, sums the
        rewards between consecutive done flags and hands the list to
        ``log_trajectory_statistics``. Prints a message instead when the
        buffer holds fewer than ``num_timesteps`` samples or when no
        trajectory lies fully inside the window. Returns None.
        """
        if num_timesteps > self.N:
            print('Not enough samples in the buffer')
            return
        latest_rew = self.rew[-num_timesteps - 1:]
        latest_don = self.don[-num_timesteps - 1:]
        # A trajectory starts on the step right after a done flag.
        trajectory_startpoints = np.where(latest_don)[0] + 1
        trajectory_rewards = [
            np.sum(latest_rew[start:end])
            for start, end in zip(trajectory_startpoints[:-1],
                                  trajectory_startpoints[1:])]
        if len(trajectory_rewards) == 0:
            print('Latest number of completed trajectories - {}'.format(
                len(trajectory_rewards)))
        else:
            log_trajectory_statistics(trajectory_rewards)


class VisualReplayBuffer(ReplayBuffer):
    """Replay buffer that additionally stores image frames (``ims``) as uint8.

    Frames are kept as uint8 and only converted to float32 in
    ``gather_indices``, where they are rescaled with ``(x + 0.5) / 256``.
    """

    def __init__(self, buffer_size, initial_data={}):
        """See ``ReplayBuffer``; ``initial_data`` must also provide 'ims'."""
        super(VisualReplayBuffer, self).__init__(buffer_size, initial_data)

    def _initial_setup(self, initial_data={}):
        """Take the arrays of ``initial_data``, including the 'ims' frames."""
        super(VisualReplayBuffer, self)._initial_setup(initial_data)
        self.ims = initial_data['ims'].astype(np.uint8)

    def add(self, other_data):
        """Append ``other_data``, evicting the oldest transitions on overflow.

        When the incoming chunk alone overflows the buffer, its oldest
        entries are dropped too, so all stored arrays keep exactly ``N``
        rows (previously they could grow past ``N``).
        """
        if self.N == -1:
            self._initial_setup(other_data)
            return
        incoming_n = other_data['n']
        overflow = int(max(self.N + incoming_n - self.buffer_size, 0))
        # Evict stored rows first; the remainder of the overflow is cut
        # from the front (oldest part) of the incoming chunk.
        drop_old = min(overflow, self.N)
        drop_new = overflow - drop_old
        for key in ('obs', 'nobs', 'act', 'rew'):
            incoming = other_data[key].astype('float32')
            merged = np.concatenate(
                (getattr(self, key)[drop_old:], incoming[drop_new:]), axis=0)
            setattr(self, key, merged)
        self.don = np.concatenate(
            (self.don[drop_old:], other_data['don'][drop_new:]), axis=0)
        self.ims = np.concatenate(
            (self.ims[drop_old:],
             other_data['ims'].astype(np.uint8)[drop_new:]), axis=0)
        self.N += incoming_n - overflow

    def gather_indices(self, indices):
        """Return the selected transitions plus rescaled float32 'ims'."""
        out_dict = super(VisualReplayBuffer, self).gather_indices(indices)
        out_dict['ims'] = ((self.ims[indices].astype('float32') + 0.5) / 256)
        return out_dict


class LatentReplayBuffer(VisualReplayBuffer):
    """Visual replay buffer that also stores latent encodings (``lat``).

    Encodings are produced by ``model.get_latent_space`` from the 'lat'
    entry of the incoming data. Sampled latents are multiplied by
    ``latent_mask``, which is initialised to all ones.
    """

    def __init__(self, buffer_size, model, initial_data={}):
        """Create the buffer.

        Args:
            buffer_size: maximum number of transitions kept by ``add``.
            model: object exposing ``get_latent_space``.
            initial_data: optional seed data; must also provide 'ims' and 'lat'.
        """
        # The model has to be in place before the parent constructor runs:
        # with non-empty initial_data it calls _initial_setup, which
        # already needs self._latent_model.
        self._latent_model = model
        super(LatentReplayBuffer, self).__init__(buffer_size, initial_data)

    def _initial_setup(self, initial_data={}):
        """Take the arrays of ``initial_data`` and encode its 'lat' entry."""
        super(LatentReplayBuffer, self)._initial_setup(initial_data)
        self.lat = self._latent_model.get_latent_space(initial_data['lat'])
        latent_shape = self.lat.shape
        # Mask broadcasts over the first two axes of the latents.
        self.latent_mask = np.ones([1, 1] + list(latent_shape[2:])).astype('float32')

    def add(self, other_data):
        """Append ``other_data``, evicting the oldest transitions on overflow.

        When the incoming chunk alone overflows the buffer, its oldest
        entries are dropped too, so all stored arrays keep exactly ``N``
        rows (previously they could grow past ``N``).
        """
        if self.N == -1:
            self._initial_setup(other_data)
            return
        incoming_n = other_data['n']
        overflow = int(max(self.N + incoming_n - self.buffer_size, 0))
        # Evict stored rows first; the remainder of the overflow is cut
        # from the front (oldest part) of the incoming chunk.
        drop_old = min(overflow, self.N)
        drop_new = overflow - drop_old
        for key in ('obs', 'nobs', 'act', 'rew'):
            incoming = other_data[key].astype('float32')
            merged = np.concatenate(
                (getattr(self, key)[drop_old:], incoming[drop_new:]), axis=0)
            setattr(self, key, merged)
        self.don = np.concatenate(
            (self.don[drop_old:], other_data['don'][drop_new:]), axis=0)
        self.ims = np.concatenate(
            (self.ims[drop_old:],
             other_data['ims'].astype(np.uint8)[drop_new:]), axis=0)
        new_latents = self._latent_model.get_latent_space(other_data['lat'])
        self.lat = np.concatenate(
            (self.lat[drop_old:], new_latents[drop_new:]), axis=0)
        self.N += incoming_n - overflow

    def gather_indices(self, indices):
        """Return the selected transitions plus their masked latents."""
        out_dict = super(LatentReplayBuffer, self).gather_indices(indices)
        out_dict['lat'] = self.lat[indices] * self.latent_mask
        return out_dict


class DACReplayBuffer(ReplayBuffer):
    """Replay buffer whose sampled rewards can be recomputed by a discriminator."""

    def __init__(self, discriminator, buffer_size, initial_data={}):
        """Build the underlying buffer and keep a handle on ``discriminator``."""
        super(DACReplayBuffer, self).__init__(buffer_size, initial_data)
        self._disc = discriminator

    def get_random_batch(self, batch_size, re_eval_rw=True):
        """Sample a batch; with ``re_eval_rw`` the 'rew' entry is replaced by
        ``discriminator.get_reward(obs, act)`` for the sampled transitions."""
        batch = super(DACReplayBuffer, self).get_random_batch(batch_size)
        if not re_eval_rw:
            return batch
        batch['rew'] = self._disc.get_reward(batch['obs'], batch['act'])
        return batch


class VisualDACReplayBuffer(VisualReplayBuffer):
    """Visual replay buffer whose rewards can be recomputed from the frames."""

    def __init__(self, discriminator, buffer_size, initial_data={}):
        """Build the underlying buffer and keep a handle on ``discriminator``."""
        super(VisualDACReplayBuffer, self).__init__(buffer_size, initial_data)
        self._disc = discriminator

    def get_random_batch(self, batch_size, re_eval_rw=True):
        """Sample a batch; with ``re_eval_rw`` the 'rew' entry is replaced by
        ``discriminator.get_reward`` evaluated on the sampled 'ims'."""
        batch = super(VisualDACReplayBuffer, self).get_random_batch(batch_size)
        if not re_eval_rw:
            return batch
        batch['rew'] = self._disc.get_reward(batch['ims'])
        return batch


class KLPreprocessingDACReplayBuffer(VisualReplayBuffer):
    """Visual replay buffer that scores preprocessed frames with a discriminator.

    The preprocessor and discriminator are taken from ``gail._pre`` and
    ``gail._disc``.
    """

    def __init__(self, gail, buffer_size, reward_noise=True,
                 initial_data={}):
        """Build the underlying buffer and store the ``gail`` components.

        ``reward_noise`` selects how frames are preprocessed when rewards
        are re-evaluated (see ``get_random_batch``).
        """
        super(KLPreprocessingDACReplayBuffer, self).__init__(buffer_size, initial_data)
        self._rn = reward_noise
        self._pre = gail._pre
        self._disc = gail._disc

    def get_random_batch(self, batch_size, re_eval_rw=True):
        """Sample a batch and, if ``re_eval_rw``, add 'pre' and refresh 'rew'.

        With ``reward_noise`` the preprocessor is called directly on the
        frames; otherwise the first value returned by its
        ``get_distribution_info`` is used.
        """
        batch = super(KLPreprocessingDACReplayBuffer, self).get_random_batch(
            batch_size)
        if not re_eval_rw:
            return batch
        frames = batch['ims']
        if self._rn:
            features = self._pre(frames)
        else:
            features, _ = self._pre.get_distribution_info(frames)
        batch['pre'] = features
        batch['rew'] = self._disc.get_reward(features)
        return batch


class WeightedPreReplayBuffer(VisualReplayBuffer):
    """Visual replay buffer that samples from two discriminator-ranked partitions.

    ``update_buffer`` scores every stored frame with the discriminator and
    splits the buffer at ``expert_percentile`` of the sorted scores into
    "expert" (top) and "non-expert" (rest) indices. Weighted batches draw
    ``expert_proportion`` of their samples from the expert partition.
    ``update_buffer`` must be called before a weighted batch is requested
    (and before using ``buffer_baseline``), since it creates the index
    partitions and ``baseline_rew``.
    """

    def __init__(self, gail, buffer_size, reward_noise=True,
                 expert_proportion=0.5, expert_percentile=0.9,
                 buffer_baseline=False, initial_data={}):
        """Build the underlying buffer and store the sampling configuration.

        Args:
            gail: object providing ``_pre`` (preprocessor) and ``_disc``.
            buffer_size: maximum number of stored transitions.
            reward_noise: if True call the preprocessor directly, otherwise
                use the first output of its ``get_distribution_info``.
            expert_proportion: fraction of a weighted batch drawn from the
                expert partition.
            expert_percentile: fraction of the sorted buffer below the
                expert partition.
            buffer_baseline: subtract ``baseline_rew`` from recomputed rewards.
            initial_data: optional seed data for the buffer.
        """
        super(WeightedPreReplayBuffer, self).__init__(buffer_size, initial_data)
        self._pre = gail._pre
        self._disc = gail._disc
        self._rn = reward_noise
        self._e_prop = expert_proportion
        self._e_perc = expert_percentile
        self._buffer_baseline = buffer_baseline

    def get_random_batch(self, batch_size, re_eval_rw=True, weighted=True):
        """Sample a (weighted or uniform) batch, optionally refreshing rewards.

        With ``re_eval_rw`` the batch gains a 'pre' entry and 'rew' is
        replaced by the discriminator reward of 'pre', minus
        ``baseline_rew`` when ``buffer_baseline`` is enabled.
        """
        if not weighted:
            out = super(WeightedPreReplayBuffer, self).get_random_batch(
                batch_size)
        else:
            out = self.get_random_weighted_batch(batch_size)
        if re_eval_rw:
            if self._rn:
                out['pre'] = self._pre(out['ims'])
            else:
                out['pre'], _ = self._pre.get_distribution_info(out['ims'])
            out['rew'] = self._disc.get_reward(out['pre'])
            if self._buffer_baseline:
                out['rew'] = out['rew'] - self.baseline_rew
        return out

    def get_random_weighted_batch(self, batch_size):
        """Sample with replacement: expert indices first, then non-expert."""
        n_expert_samples = int(int(batch_size) * self._e_prop)
        expert_samples_indices = np.random.choice(
            self.expert_indices, n_expert_samples)
        non_expert_samples_indices = np.random.choice(
            self.non_expert_indices, batch_size - n_expert_samples)
        indices = np.concatenate((expert_samples_indices, non_expert_samples_indices))
        out = self.gather_indices(indices)
        return out

    def update_buffer(self, ):
        """Re-score all stored frames and rebuild the two index partitions."""
        normalised_ims = (self.ims.astype('float32') + 0.5) / 256
        if self._rn:
            pre = self._pre(normalised_ims)
        else:
            pre, _ = self._pre.get_distribution_info(normalised_ims)
        self.likelihood = np.array(np.reshape(self._disc(pre), [-1]))
        sorted_indices = np.argsort(self.likelihood)
        split_index = int(self.N * self._e_perc)
        if self._buffer_baseline:
            # Clamp so a split at the very end (e.g. percentile 1.0) uses
            # the top-ranked sample instead of indexing out of range.
            boundary = min(split_index, sorted_indices.shape[0] - 1)
            self.baseline_rew = self._disc.get_reward(
                np.expand_dims(pre[sorted_indices[boundary]], axis=0))
        self.expert_indices = sorted_indices[split_index:]
        self.non_expert_indices = sorted_indices[:split_index]
        # np.random.choice cannot draw from an empty array, so an empty
        # partition borrows the other one (either side may be empty for
        # extreme percentiles or very small buffers).
        if self.expert_indices.shape[0] == 0:
            self.expert_indices = self.non_expert_indices
        elif self.non_expert_indices.shape[0] == 0:
            self.non_expert_indices = self.expert_indices


class LatentDACReplayBuffer(LatentReplayBuffer):
    """Latent replay buffer whose rewards can be recomputed from the latents."""

    def __init__(self, discriminator, buffer_size, model, initial_data={}):
        """Build the underlying latent buffer and keep ``discriminator``."""
        super(LatentDACReplayBuffer, self).__init__(buffer_size, model, initial_data)
        self._disc = discriminator

    def get_random_batch(self, batch_size, re_eval_rw=True):
        """Sample a batch; with ``re_eval_rw`` the 'rew' entry is replaced by
        ``discriminator.get_reward`` evaluated on the sampled 'lat'."""
        batch = super(LatentDACReplayBuffer, self).get_random_batch(batch_size)
        if not re_eval_rw:
            return batch
        batch['rew'] = self._disc.get_reward(batch['lat'])
        return batch


class ExpertBuffer(object):
    """Fixed set of expert transitions (``obs``, ``nobs``, ``act``, ``don``)."""

    def __init__(self, initial_data={}):
        """Store the expert arrays from ``initial_data``.

        ``initial_data`` must provide 'obs', 'nobs', 'act' (cast to float32)
        and 'don'. ``N`` is the number of rows of 'obs'.
        """
        self.obs = initial_data['obs'].astype('float32')
        self.nobs = initial_data['nobs'].astype('float32')
        self.act = initial_data['act'].astype('float32')
        self.don = initial_data['don']
        self.N = self.obs.shape[0]

    def gather_indices(self, indices):
        """Return a dict with 'obs', 'act' and 'nobs' selected at ``indices``."""
        out_dict = {}
        out_dict['obs'], out_dict['act'], out_dict['nobs'] = (self.obs[indices],
                                                              self.act[indices],
                                                              self.nobs[indices])
        return out_dict

    def get_random_batch(self, batch_size):
        """Sample ``batch_size`` transitions uniformly with replacement."""
        indices = np.random.randint(self.N, size=batch_size)
        return self.gather_indices(indices)

    def dac_conversion(self, env):
        """Rewrite the data so that episodes end in ``env.absorbing_state``.

        Observations are passed through ``env.wrap_non_absorbing_states``.
        For every transition flagged done, the next observation becomes the
        absorbing state, the flag is cleared, and an extra
        absorbing -> absorbing transition with a zero action is inserted
        right after it. ``N`` grows by the number of inserted rows.

        The done flags are modified on a copy, so the array passed in at
        construction time is left untouched.
        """
        self.obs = env.wrap_non_absorbing_states(self.obs).astype('float32')
        self.nobs = env.wrap_non_absorbing_states(self.nobs).astype('float32')
        done_flags = np.array(self.don)  # copy: do not edit the caller's array
        terminal_indices, = np.where(done_flags)
        self.nobs[terminal_indices] = env.absorbing_state.astype('float32')
        done_flags[terminal_indices] = False
        # np.insert places values *before* the given positions, so the new
        # rows go directly after each former terminal transition.
        insert_at = terminal_indices + 1
        state_to_insert = np.expand_dims(env.absorbing_state, axis=0).astype('float32')
        act_to_insert = np.zeros((1, env.action_space.shape[0])).astype('float32')
        self.obs = np.insert(self.obs, insert_at, state_to_insert, axis=0)
        self.nobs = np.insert(self.nobs, insert_at, state_to_insert, axis=0)
        self.act = np.insert(self.act, insert_at, act_to_insert, axis=0)
        self.don = np.insert(done_flags, insert_at, False)
        self.N += insert_at.shape[0]


class VisualExpertBuffer(object):
    """Expert frames kept as uint8 and rescaled to float32 when sampled."""

    def __init__(self, initial_data):
        """Store ``initial_data['ims']`` as uint8; ``N`` is its first dimension."""
        frames = initial_data['ims']
        self.ims = frames.astype(np.uint8)
        self.N = self.ims.shape[0]

    def gather_indices(self, indices):
        """Return ``{'ims': ...}`` with the selected frames as
        ``(x + 0.5) / 256`` in float32."""
        selected = self.ims[indices].astype('float32')
        return {'ims': (selected + 0.5) / 256}

    def get_random_batch(self, batch_size):
        """Sample ``batch_size`` frames uniformly with replacement."""
        return self.gather_indices(
            np.random.randint(self.N, size=batch_size))


class OldVisualExpertBuffer(object):
    """Expert frames rescaled to float32 once, at construction time."""

    def __init__(self, initial_data):
        """Store ``initial_data['ims']`` as float32 ``(x + 0.5) / 256``;
        ``N`` is its first dimension."""
        as_float = initial_data['ims'].astype('float32')
        self.ims = (as_float + 0.5) / 256
        self.N = self.ims.shape[0]

    def gather_indices(self, indices):
        """Return ``{'ims': ...}`` with the stored frames at ``indices``."""
        selected = self.ims[indices]
        return {'ims': selected}

    def get_random_batch(self, batch_size):
        """Sample ``batch_size`` frames uniformly with replacement."""
        return self.gather_indices(
            np.random.randint(self.N, size=batch_size))


class VisualExpertBufferLight(object):
    """Expert image buffer that stores one frame per sample and rebuilds
    the multi-frame stack on demand.

    Only frame 0 of every sample's stack is kept (plus frame 1 of the first
    sample of each id); ``gather_indices`` reassembles a stack from the
    frame-0 images of the preceding samples.
    NOTE(review): ``ids`` presumably identifies the trajectory each sample
    belongs to, and samples of one trajectory are presumably stored
    contiguously in time order - confirm against the data producer.
    """

    def __init__(self, initial_data):
        # 'ims' is indexed with five subscripts, so it is a rank-5 array;
        # axis 1 is the per-sample frame stack (its length is past_frames).
        self.ims = initial_data['ims'][:, 0, :, :, :].astype(np.uint8)
        self.N = self.ims.shape[0]
        self.ids = initial_data['ids']
        self.past_frames = initial_data['ims'].shape[1]
        # Row vector [0, 1, ..., past_frames - 1] subtracted from a sample
        # index to address the preceding samples.
        self.idx_shifts = np.expand_dims(np.arange(self.past_frames), axis=0)

        # Constant image used for padding: an all-zero frame after the same
        # (x + 0.5) / 256 rescaling applied to real frames.
        self.pad_image = (np.zeros_like(self.ims[0]).astype('float32') + 0.5) / 256
        # Index of the first occurrence of every distinct id.
        _, self.first_indices = np.unique(self.ids, return_index=True)

        # Frame 1 of the first sample of each id.
        self.first_ims = initial_data['ims'][self.first_indices, 1, :, :, :].astype(np.uint8)

        # Lookup from sample index to row in first_ims. Entries for samples
        # that are not a first occurrence hold 1000000 - presumably a
        # sentinel meant to be out of range for first_ims.
        retrieval_indices = np.arange(self.first_indices.shape[0])
        self.padded_retrieval_list = np.zeros([self.N]).astype('int') + 1000000
        self.padded_retrieval_list[self.first_indices] = retrieval_indices

    def gather_indices(self, indices):
        """Return ``{'ims': stacks}`` of rescaled float32 frames for ``indices``.

        Column k of a stack is initially the stored frame of sample
        ``index - k``. Where a column (other than the last) is the first
        sample of an id, the next column is replaced by that sample's stored
        frame 1, and columns further on are overwritten with ``pad_image``.
        """
        # (batch, past_frames) sample indices: index, index - 1, ...
        # Values below zero wrap around to the end of ``ims`` (numpy
        # negative indexing).
        all_indices = np.expand_dims(indices, axis=-1) - self.idx_shifts
        images = (self.ims[all_indices].astype('float32') + 0.5) / 256
        # Positions (excluding the last column) holding a first-of-id sample.
        start_indices_mask = np.isin(all_indices[:, :-1], self.first_indices)
        start_indices_x, start_indices_y = np.where(start_indices_mask)
        # Column right after each start position.
        first_indices_y = start_indices_y + 1

        trajectories_start_indices = all_indices[start_indices_x, start_indices_y]
        images[start_indices_x, first_indices_y] = (self.first_ims[
                                                        self.padded_retrieval_list[trajectories_start_indices]].astype(
            'float32') + 0.5) / 256

        # Basic slicing returns a view, so the in-place OR below also writes
        # into start_indices_mask; it is not read again after this point.
        pad_indices_mask = start_indices_mask[:, :-1]
        # Running OR along the columns: once a start was seen, every later
        # column of the mask is set as well.
        for i in range(self.past_frames - 3):
            pad_indices_mask[:, i + 1] = np.logical_or(pad_indices_mask[:, i + 1],
                                                       pad_indices_mask[:, i])
        pad_indices_x, pad_indices_y = np.where(pad_indices_mask)
        # Mask column j refers to stack column j + 2, i.e. the columns after
        # the one filled from first_ims.
        pad_indices_y += 2
        images[pad_indices_x, pad_indices_y] = self.pad_image

        return {'ims': images}

    def get_random_batch(self, batch_size):
        """Sample ``batch_size`` stacks uniformly with replacement."""
        indices = np.random.randint(self.N, size=batch_size)
        return self.gather_indices(indices)


class LatentExpertBuffer(object):
    """Expert frames together with their latent encodings.

    Frames are kept as uint8 and rescaled to float32 when sampled; latents
    come from ``model.get_latent_space`` and are multiplied by
    ``latent_mask`` (initialised to all ones) when sampled.
    """

    def __init__(self, initial_data, model):
        """Store ``initial_data['ims']`` and encode it with ``model``.

        Args:
            initial_data: dict providing the 'ims' array.
            model: object exposing ``get_latent_space``.
        """
        self._latent_model = model
        # Was ``np.unit8`` (typo), which raised AttributeError on construction.
        self.ims = initial_data['ims'].astype(np.uint8)
        self.lat = model.get_latent_space(initial_data['ims'])
        latent_shape = self.lat.shape
        # Mask broadcasts over the first two axes of the latents.
        self.latent_mask = np.ones([1, 1] + list(latent_shape[2:])).astype('float32')
        self.N = self.ims.shape[0]

    def gather_indices(self, indices):
        """Return rescaled 'ims' and masked 'lat' selected at ``indices``."""
        return {'ims': (self.ims[indices].astype('float32') + 0.5) / 256,
                'lat': self.lat[indices] * self.latent_mask}

    def get_random_batch(self, batch_size):
        """Sample ``batch_size`` entries uniformly with replacement."""
        indices = np.random.randint(self.N, size=batch_size)
        return self.gather_indices(indices)
