import numpy as np

class her_sampler:
    """Hindsight Experience Replay (HER) transition sampler.

    Draws random transitions from a buffer of stored episodes. With
    probability ``future_p``, it replaces each transition's goal ``'g'`` with
    the achieved goal from a later timestep of the same episode (the
    "future" strategy). It then recomputes the reward against the
    (possibly relabeled) goal.
    """

    # Environment-name prefixes whose rewards are computed by ``reward_func``.
    # Any other environment gets a dense negative-distance reward.
    _REWARD_FUNC_ENV_PREFIXES = ('Fetch', 'Maze', 'Hand')

    def __init__(self, replay_strategy, replay_k, env_name, reward_func=None):
        """Configure the sampler.

        Args:
            replay_strategy: Relabeling strategy. Only ``'future'`` enables
                relabeling; any other value disables it (``future_p == 0``).
            replay_k: Ratio of relabeled to original goals. With ``'future'``
                the relabel probability is ``1 - 1 / (1 + replay_k)``.
            env_name: Environment name. It selects how rewards are computed.
            reward_func: Callable ``(achieved_goal, goal, info)`` returning one
                reward per row. Required when ``env_name`` starts with
                ``'Fetch'``, ``'Maze'`` or ``'Hand'``.
        """
        self.env_name = env_name
        self.replay_strategy = replay_strategy
        self.replay_k = replay_k
        if self.replay_strategy == 'future':
            self.future_p = 1 - (1. / (1 + replay_k))
        else:
            self.future_p = 0
        self.reward_func = reward_func

    def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
        """Sample a batch of transitions, relabeling goals HER-style.

        Args:
            episode_batch: dict of arrays indexed ``[episode, timestep, ...]``.
                It must contain ``'actions'``, ``'ag'``, ``'ag_next'`` and
                ``'g'``. ``'ag'`` needs at least ``T + 1`` timesteps (``T`` is
                ``actions.shape[1]``), because future achieved goals are read
                up to index ``T``.
            batch_size_in_transitions: Number of transitions to sample.

        Returns:
            dict with one entry per key of ``episode_batch`` (each with leading
            dimension ``batch_size_in_transitions``), plus:

            - ``'sg'``: achieved goal at a randomly chosen later timestep.
            - ``'fg'``: achieved goal at the last stored timestep.
            - ``'r'``: reward, shape ``(batch_size_in_transitions, 1)``.

        Raises:
            TypeError: if the environment needs ``reward_func`` but none was
                given.
        """
        T = episode_batch['actions'].shape[1]
        rollout_batch_size = episode_batch['actions'].shape[0]
        batch_size = batch_size_in_transitions

        # Pick an (episode, timestep) pair for every sampled transition.
        # The order of np.random calls below is kept fixed so seeded runs
        # reproduce the same samples.
        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = np.random.randint(T, size=batch_size)
        transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
                       for key in episode_batch}

        # Decide which transitions have their goal relabeled.
        her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p)

        # Future timestep in [t + 1, T]: the offset is uniform over [0, T - t - 1].
        future_offset = (np.random.uniform(size=batch_size) * (T - t_samples)).astype(int)
        future_t = t_samples + 1 + future_offset

        # Replace the desired goal with the achieved goal at the future timestep.
        transitions['g'][her_indexes] = episode_batch['ag'][episode_idxs[her_indexes],
                                                            future_t[her_indexes]]

        transitions['sg'] = episode_batch['ag'][episode_idxs, future_t]
        transitions['fg'] = episode_batch['ag'][episode_idxs, -1]

        if self.env_name.startswith(self._REWARD_FUNC_ENV_PREFIXES):
            if self.reward_func is None:
                raise TypeError(
                    f"reward_func is required for environment {self.env_name!r}")
            rewards = self.reward_func(transitions['ag_next'], transitions['g'], None)
            transitions['r'] = np.expand_dims(rewards, 1)
        else:
            # Dense reward: negative Euclidean distance to the goal.
            distance = np.linalg.norm(transitions['ag_next'] - transitions['g'], axis=-1)
            transitions['r'] = np.expand_dims(-distance.astype(np.float32), 1)

        # Normalize every entry to a leading dimension of batch_size.
        transitions = {k: v.reshape(batch_size, *v.shape[1:])
                       for k, v in transitions.items()}

        return transitions