import numpy as np


class her_sampler:
    """Sample training transitions from stored episodes with hindsight relabeling.

    With the ``'future'`` strategy a fraction ``replay_k / (1 + replay_k)`` of
    the sampled transitions get their goal ``'g'`` replaced by an achieved
    goal ``'ag'`` taken from a later timestep of the same episode. With
    ``'none'`` no goal is ever replaced. Rewards are always recomputed with
    ``reward_func`` after relabeling.
    """

    def __init__(self, replay_strategy, replay_k, reward_func=None):
        """Configure the sampler.

        Args:
            replay_strategy: ``'future'`` or ``'none'``.
            replay_k: ratio of relabeled to original transitions; only used
                by the ``'future'`` strategy.
            reward_func: callable ``reward_func(ag, g)`` used to recompute
                rewards. It may be omitted here, but it must be callable by
                the time ``sample_her_transitions`` is invoked.

        Raises:
            NotImplementedError: if ``replay_strategy`` is not supported.
        """
        if replay_strategy == 'future':
            # k relabeled samples per original one => probability k / (1 + k).
            future_p = 1 - (1. / (1 + replay_k))
        elif replay_strategy == 'none':
            future_p = 0
        else:
            raise NotImplementedError(
                'unsupported replay_strategy: {!r}'.format(replay_strategy))
        self.replay_strategy = replay_strategy
        self.replay_k = replay_k
        self.future_p = future_p
        self.reward_func = reward_func

    def sample_her_transitions(self, episode_batch, batch_size_in_transitions, good_ids=None, good_weight=2):
        """Draw a batch of transitions and relabel part of their goals.

        Args:
            episode_batch: mapping from key to array indexed as
                ``[episode, timestep, ...]``. Must contain ``'actions'``
                (its first two dimensions define the number of episodes and
                the horizon ``T``), ``'g'`` and ``'ag'``. ``'ag'`` must hold
                at least ``T + 1`` timesteps because the relabeling index
                can be as large as ``T``.
            batch_size_in_transitions: number of transitions to return.
            good_ids: optional indices of episodes that are sampled with
                relative weight ``good_weight``; every other episode has
                weight 1. When ``None`` episodes are sampled uniformly.
            good_weight: sampling weight assigned to ``good_ids``.

        Returns:
            dict with every key of ``episode_batch`` plus ``'reward'``; each
            value has ``batch_size_in_transitions`` as its leading dimension.
            ``'reward'`` is the output of ``reward_func`` with a new axis
            inserted at position 1.

        Raises:
            TypeError: if no callable ``reward_func`` has been configured.
        """
        # Fail early with a clear message instead of an opaque
        # "'NoneType' object is not callable" after all the sampling work.
        if not callable(self.reward_func):
            raise TypeError(
                'her_sampler needs a callable reward_func to recompute rewards, '
                'got {!r}'.format(self.reward_func))

        rollout_batch_size, T = episode_batch['actions'].shape[:2]
        batch_size = batch_size_in_transitions

        # NOTE: the order of the np.random calls below is kept stable so that
        # seeded runs stay reproducible.
        # Select which rollouts and which timesteps are used.
        if good_ids is None:
            episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        else:
            weights = np.ones(rollout_batch_size)
            weights[good_ids] = good_weight
            weights = weights / np.sum(weights)
            episode_idxs = np.random.choice(np.arange(rollout_batch_size), batch_size, p=weights)
        t_samples = np.random.randint(T, size=batch_size)
        # Copy so that relabeling below never writes into episode_batch.
        transitions = {key: values[episode_idxs, t_samples].copy()
                       for key, values in episode_batch.items()}

        # Rows whose goal is replaced by a future achieved goal.
        her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p)
        # Offset in [0, T - t - 1], so the future index lies in [t + 1, T].
        future_offset = (np.random.uniform(size=batch_size) * (T - t_samples)).astype(int)
        future_t = (t_samples + 1 + future_offset)[her_indexes]
        transitions['g'][her_indexes] = episode_batch['ag'][episode_idxs[her_indexes], future_t]

        # Recompute rewards against the (possibly relabeled) goals.
        # NOTE(review): this uses the achieved goal at the sampled step, not
        # at the following step -- confirm this is what reward_func expects.
        transitions['reward'] = np.expand_dims(
            self.reward_func(transitions['ag'], transitions['g']), 1)

        # Shape-preserving for well-formed input; raises if an entry does not
        # have batch_size elements along its leading dimension.
        return {key: values.reshape(batch_size, *values.shape[1:])
                for key, values in transitions.items()}
