import numpy as np

class her_sampler:
    """Hindsight Experience Replay (HER) transition sampler.

    Samples transitions from stored episodes and, with probability
    ``self.future_p``, replaces a transition's goal with an achieved goal
    taken from a later timestep of the same episode ("future" strategy).
    Rewards are then recomputed for the (possibly relabelled) goals.
    """

    def __init__(self, replay_strategy, replay_k, env_params, env_name, reward_func=None, clipv=200):
        """
        Args:
            replay_strategy: only ``'future'`` enables relabelling; any other
                value sets the relabel probability ``future_p`` to 0.
            replay_k: relabelled-to-original ratio; gives
                ``future_p = 1 - 1 / (1 + replay_k)``.
            env_params: dict of environment parameters; ``'max_timesteps'``
                is read by ``sample_her_transitions``.
            env_name: environment name; its prefix selects how rewards are
                recomputed (see ``_compute_reward``).
            reward_func: callable ``(achieved_goal, goal, extra)`` used for the
                Ant/Fetch/Maze/Point/Hand prefixes (``extra=None``) and for
                ``'MultiGoal'`` (``extra=actions``).
            clipv: stored as ``self.clipv``; currently unused (clipping of the
                future offset is disabled).
        """
        self.env_name = env_name
        self.env_params = env_params
        self.replay_strategy = replay_strategy
        self.replay_k = replay_k
        if self.replay_strategy == 'future':
            self.future_p = 1 - (1. / (1 + replay_k))
        else:
            self.future_p = 0
        self.reward_func = reward_func
        self.clipv = clipv

    def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
        """Sample a batch of transitions and apply HER goal relabelling.

        Args:
            episode_batch: dict of arrays indexed ``[episode, timestep, ...]``.
                It must contain at least ``'actions'``, ``'obs'``, ``'ag'``,
                ``'g'`` and ``'ag_next'``. ``T`` is ``'actions'.shape[1]``.
                ``'ag'`` is indexed up to timestep ``T``, so it presumably
                holds ``T + 1`` steps (TODO confirm against the buffer).
            batch_size_in_transitions: number of transitions to sample.

        Returns:
            ``(transitions, her_indexes, her_indexes2, episode_idxs, t_samples,
            future_offset)``.
            ``transitions`` maps keys to arrays whose leading dimension is the
            batch size, plus the int ``'T'``. ``her_indexes`` and
            ``her_indexes2`` are ``np.where`` results selecting the relabelled
            and the kept samples. The remaining items are the sampled episode
            indices, the sampled timesteps and the integer future offsets.
        """
        T = episode_batch['actions'].shape[1]
        rollout_batch_size = episode_batch['actions'].shape[0]
        batch_size = batch_size_in_transitions

        # Pick a random (episode, timestep) pair for every sample.
        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = np.random.randint(T, size=batch_size)
        transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
                       for key in episode_batch.keys()}

        # Decide which samples receive a hindsight goal.
        random_index = np.random.uniform(size=batch_size)
        her_indexes = np.where(random_index < self.future_p)
        her_indexes2 = np.where(random_index >= self.future_p)

        # The offset lies in [0, T - t - 1], so future_t = t + 1 + offset lies
        # in [t + 1, T].
        # NOTE: self.clipv was meant to optionally clip this offset; clipping
        # is currently disabled.
        future_offset = (np.random.uniform(size=batch_size) * (T - t_samples)).astype(int)
        future_t = t_samples + 1 + future_offset

        # Snapshot the original desired goals BEFORE relabelling. The write
        # below mutates transitions['g'] in place, so a plain alias would end
        # up holding the relabelled goals as well.
        transitions['dgd'] = transitions['g'].copy()
        transitions['g'][her_indexes] = episode_batch['ag'][episode_idxs[her_indexes], future_t[her_indexes]]
        # Future achieved goal for every sample, relabelled or not.
        transitions['sg'] = episode_batch['ag'][episode_idxs, future_t]

        # Whole-trajectory data for each sampled episode.
        episode_ag = episode_batch['ag'][episode_idxs]
        transitions['fg'] = episode_batch['ag'][episode_idxs, -1]
        transitions['obst'] = episode_batch['obs'][episode_idxs, :]
        transitions['agt'] = episode_batch['ag'][episode_idxs, :]
        actions_traj = episode_batch['actions'][episode_idxs, :]
        # Repeat the last action so the action trajectory gets one extra step
        # (presumably to match the T + 1 length of 'ag'/'obs').
        transitions['actionst'] = np.concatenate((actions_traj, actions_traj[:, -1:, :]), 1)

        # One future achieved goal for each timestep 0..max_timesteps. The same
        # offsets are shared by every sampled episode.
        # NOTE(review): the index here is t + offset, without the +1 used for
        # future_t. It also assumes max_timesteps == T; confirm both.
        step = self.env_params['max_timesteps'] + 1
        t_samples_all = np.arange(0, step, 1)
        future_offset_all = (np.random.uniform(size=step) * (T - t_samples_all)).astype(int)
        transitions['herg'] = episode_ag[:, future_offset_all + t_samples_all, :]
        transitions['trag'] = np.tile(episode_ag[:, -1:, :], (1, step, 1))

        self._compute_reward(transitions)

        transitions = {k: v.reshape(batch_size, *v.shape[1:]) for k, v in transitions.items()}
        transitions['T'] = T

        return transitions, her_indexes, her_indexes2, episode_idxs, t_samples, future_offset

    def _compute_reward(self, transitions):
        """Set ``transitions['r']`` (shape ``(batch, 1)``) for the current goals.

        The formula is chosen from ``self.env_name``. For a name matching none
        of the branches, ``'r'`` is left untouched. It is then absent unless
        ``episode_batch`` itself supplied an ``'r'`` key.
        """
        name = self.env_name
        if name.startswith(('Ant', 'Fetch', 'Maze', 'Point', 'Hand')):
            r = self.reward_func(transitions['ag_next'], transitions['g'], None)
        elif name.startswith('Sawyer'):
            # Negative Euclidean distance between next achieved goal and goal.
            r = -(np.linalg.norm(transitions['ag_next'] - transitions['g'],
                                 axis=-1)).astype(np.float32)
        elif name == 'MultiGoal':
            r = self.reward_func(transitions['ag_next'], transitions['g'], transitions['actions'])
        else:
            return
        transitions['r'] = np.expand_dims(r, 1)