import numpy as np

class her_sampler:
    """Sample transitions from stored episodes and relabel goals in hindsight.

    For a fraction ``future_p`` of the sampled transitions the goal ``'g'`` is
    replaced by an achieved goal ``'ag'`` taken from a later timestep of the
    same episode. Several auxiliary per-episode arrays are attached to the
    returned transitions as well (see ``sample_her_transitions``).
    """

    # Maximum look-ahead (in timesteps) when picking a future achieved goal.
    _FUTURE_WINDOW = 200

    def __init__(self, replay_strategy, replay_k, env_params, env_name, reward_func=None):
        """Store the sampling configuration.

        Args:
            replay_strategy: ``'future'`` enables goal relabeling; any other
                value disables it (``future_p`` becomes 0).
            replay_k: ratio of relabeled to regular samples; the relabeling
                probability is ``1 - 1 / (1 + replay_k)``.
            env_params: mapping; ``env_params['max_timesteps']`` is read for
                environments whose name starts with ``'Ant'``.
            env_name: environment name; its prefix selects the trajectory
                length and the reward computation.
            reward_func: callable invoked as ``reward_func(ag_next, g, None)``
                for environments that do not use the distance-based reward.
        """
        self.env_name = env_name
        self.env_params = env_params
        self.replay_strategy = replay_strategy
        self.replay_k = replay_k
        if self.replay_strategy == 'future':
            self.future_p = 1 - (1. / (1 + replay_k))
        else:
            self.future_p = 0
        self.reward_func = reward_func

    def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
        """Draw a batch of transitions and relabel part of them with future goals.

        Args:
            episode_batch: dict of arrays indexed as ``[episode, timestep, ...]``.
                The keys ``'actions'``, ``'ag'``, ``'obs'``, ``'g'`` and
                ``'ag_next'`` are read here; every key is sampled.
                NOTE(review): ``'ag'`` is indexed up to timestep ``T`` where
                ``T = actions.shape[1]``, so it presumably holds ``T + 1``
                steps -- confirm against the buffer that builds this dict.
            batch_size_in_transitions: number of transitions to sample.

        Returns:
            A tuple ``(transitions, her_indexes, her_indexes2, episode_idxs,
            t_samples, future_offset)``:

            * ``transitions``: dict with the sampled entry of every input key
              plus ``'sg'`` (achieved goal at the sampled future step),
              ``'fg'`` (last achieved goal of the episode), ``'obst'`` and
              ``'actionst'`` (full observation / action sequences, the latter
              with its last action repeated once), ``'herg'`` (achieved goals
              at randomly chosen later steps along the episode), ``'trag'``
              (last achieved goal tiled along time) and ``'r'`` (reward).
            * ``her_indexes`` / ``her_indexes2``: ``np.where`` results for the
              relabeled and the non-relabeled samples respectively.
            * ``episode_idxs``, ``t_samples``, ``future_offset``: the sampled
              episode indices, timesteps and integer future offsets.
        """
        actions = episode_batch['actions']
        rollout_batch_size, T = actions.shape[0], actions.shape[1]
        batch_size = batch_size_in_transitions

        # Select which rollouts and which timesteps are used.
        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = np.random.randint(T, size=batch_size)
        transitions = {key: value[episode_idxs, t_samples].copy()
                       for key, value in episode_batch.items()}

        # Split the batch into relabeled (HER) and untouched samples.
        random_index = np.random.uniform(size=batch_size)
        her_indexes = np.where(random_index < self.future_p)
        her_indexes2 = np.where(random_index >= self.future_p)

        # Future timestep: strictly after t_samples, at most _FUTURE_WINDOW
        # steps ahead and never beyond index T.
        target_index = np.minimum(T, t_samples + self._FUTURE_WINDOW)
        future_offset = np.random.uniform(size=batch_size) * (target_index - t_samples)
        future_offset = future_offset.astype(int)
        future_steps = t_samples + 1 + future_offset
        future_t = future_steps[her_indexes]

        # Replace the goal with the achieved goal of the future timestep.
        future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
        transitions['g'][her_indexes] = future_ag

        # Achieved goal at the sampled future step, for every sample.
        transitions['sg'] = episode_batch['ag'][episode_idxs, future_steps]

        # Episode-level context for each sample.
        transitions['fg'] = episode_batch['ag'][episode_idxs, -1]
        transitions['obst'] = episode_batch['obs'][episode_idxs, :]
        episode_actions = actions[episode_idxs, :]
        # Repeat the final action once along the time axis.
        transitions['actionst'] = np.concatenate(
            (episode_actions, episode_actions[:, -1:, :]), 1)

        # Number of trajectory steps used for 'herg' / 'trag'. This must be an
        # if/elif/else chain: with two separate ifs the Fetch value was
        # always overwritten by the trailing else.
        if self.env_name.startswith('Fetch'):
            step = 51
        elif self.env_name.startswith('Ant'):
            step = self.env_params['max_timesteps'] + 1
        else:
            step = 31

        # NOTE(review): if step exceeds the length of the 'ag' time axis the
        # indices below run out of bounds -- confirm episode lengths per env.
        t_samples_all = np.arange(0, step, 1)
        future_offset_all = np.random.uniform(size=step) * (T - t_samples_all)
        future_offset_all = future_offset_all.astype(int)
        episode_ag = episode_batch['ag'][episode_idxs]
        transitions['herg'] = episode_ag[:, future_offset_all + t_samples_all, :]
        transitions['trag'] = np.tile(episode_ag[:, -1:, :], (1, step, 1))

        # Reward: negative goal distance for Sawyer / Ant1 environments,
        # otherwise delegate to the supplied reward function. The previous
        # check compared a 3-character slice against 'Ant1' and never matched.
        if self.env_name.startswith(('Sawyer', 'Ant1')):
            distance = np.linalg.norm(transitions['ag_next'] - transitions['g'], axis=-1)
            transitions['r'] = np.expand_dims(-distance.astype(np.float32), 1)
        else:
            transitions['r'] = np.expand_dims(
                self.reward_func(transitions['ag_next'], transitions['g'], None), 1)

        transitions = {k: v.reshape(batch_size, *v.shape[1:])
                       for k, v in transitions.items()}

        return transitions, her_indexes, her_indexes2, episode_idxs, t_samples, future_offset