import numpy as np
import torch


class her_sampler:
    """Hindsight Experience Replay (HER) transition sampler.

    Draws random (episode, timestep) transitions from a stored batch of
    episodes. With probability ``future_p`` it replaces the desired goal
    ``'g'`` with the achieved goal ``'ag'`` of a later timestep of the same
    episode (the 'future' strategy). Besides the HER transitions it can also
    return auxiliary arrays:
    - the future achieved goals;
    - a 6-bit encoding of ``horizon - future_t``;
    - the reward for reaching the future achieved goal from the next
      achieved goal.
    """

    def __init__(self, replay_strategy, replay_k, env, env_name, reward_func=None):
        """
        Args:
            replay_strategy: ``'future'`` enables goal relabelling; any other
                value disables it (``future_p = 0``).
            replay_k: with ``'future'``, the relabel probability is
                ``1 - 1 / (1 + replay_k)``.
            env: stored on the instance; not used by this class.
            env_name: names starting with ``'Hand'`` or ``'Fetch'`` compute
                rewards via ``reward_func``; all others use the negative L2
                distance between goals.
            reward_func: callable invoked as
                ``reward_func(achieved, desired, None)``; needed for
                Hand*/Fetch* environments.
        """
        self.env = env
        self.env_name = env_name
        self.replay_strategy = replay_strategy
        self.replay_k = replay_k
        if self.replay_strategy == 'future':
            self.future_p = 1 - (1. / (1 + replay_k))
        else:
            self.future_p = 0
        self.reward_func = reward_func

    def sample_her_transitions(self, episode_batch, batch_size_in_transitions, tag=10, horizon=50):
        """Sample a batch of (possibly goal-relabelled) transitions.

        Args:
            episode_batch: dict of arrays indexed ``[episode, t, ...]``. It must
                contain 'actions', 'obs', 'obs_next', 'ag', 'ag_next' and 'g'.
                'ag' is indexed at future timesteps up to ``T`` (where ``T`` is
                ``actions.shape[1]``), so it needs at least ``T + 1`` entries
                along axis 1.
            batch_size_in_transitions: number of transitions to sample.
            tag: ``1`` selects the compact return below; any other value
                selects the full return.
            horizon: episode length used for the steps-remaining encoding
                ``horizon - future_t``. The default of 50 keeps the previous
                hard-coded behaviour.

        Returns:
            If ``tag != 1``::

                (transitions, [state_goal, achieved_goal, action, next_action],
                 steps, her_indexes, nextstate, expreward)

            If ``tag == 1``::

                (transitions, [state_goal, achieved_goal, action],
                 1, 1, her_indexes, 1)

            ``transitions['r']`` is the (batch, 1) reward for the (possibly
            relabelled) goal. ``steps`` is a (batch, 6) int array holding the
            bits of ``horizon - future_t``, most significant bit first.

        Raises:
            ValueError: if ``tag != 1`` and any ``horizon - future_t`` lies
                outside [0, 63].
        """
        T = episode_batch['actions'].shape[1]
        rollout_batch_size = episode_batch['actions'].shape[0]
        batch_size = batch_size_in_transitions

        # Pick episodes and timesteps. t stops at T-2 so that t + 1 is still a
        # valid action index for next_action.
        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = np.random.randint(T - 1, size=batch_size)
        transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
                       for key in episode_batch}

        # Samples selected for hindsight relabelling.
        random_index = np.random.uniform(size=batch_size)
        her_indexes = np.where(random_index < self.future_p)

        # A future timestep in [t + 1, T] for every sample.
        future_offset = (np.random.uniform(size=batch_size) * (T - t_samples)).astype(int)
        future_t = t_samples + 1 + future_offset

        # Replace the desired goal with a future achieved goal for HER samples.
        transitions['g'][her_indexes] = episode_batch['ag'][
            episode_idxs[her_indexes], future_t[her_indexes]]

        state_goal = episode_batch['obs'][episode_idxs, t_samples]
        action = episode_batch['actions'][episode_idxs, t_samples]
        achieved_goal = episode_batch['ag'][episode_idxs, future_t]

        transitions['r'] = self._goal_reward(
            transitions['ag_next'], transitions['g'], distance_dtype=np.float32)

        if tag == 1:
            # NOTE: her_indexes is at position 4 here but at position 3 in
            # the full return below; kept as-is for existing callers.
            return transitions, [state_goal, achieved_goal, action], 1, 1, her_indexes, 1

        nextstate = episode_batch['obs_next'][episode_idxs, t_samples]
        next_action = episode_batch['actions'][episode_idxs, t_samples + 1]
        current_goal = episode_batch['ag_next'][episode_idxs, t_samples]
        steps = self._to_bits(horizon - future_t)
        # Reward for reaching the future achieved goal from ag_next at t.
        expreward = self._goal_reward(achieved_goal, current_goal)
        return (transitions, [state_goal, achieved_goal, action, next_action],
                steps, her_indexes, nextstate, expreward)

    def _goal_reward(self, achieved, desired, distance_dtype=None):
        """Return a (N, 1) reward for reaching ``desired`` given ``achieved``.

        Hand*/Fetch* environments use ``self.reward_func``. All other
        environments use ``-||achieved - desired||``, optionally cast to
        ``distance_dtype``.
        """
        if self.env_name.startswith(('Hand', 'Fetch')):
            return np.expand_dims(self.reward_func(achieved, desired, None), 1)
        reward = -np.linalg.norm(achieved - desired, axis=-1)
        if distance_dtype is not None:
            reward = reward.astype(distance_dtype)
        return np.expand_dims(reward, -1)

    @staticmethod
    def _to_bits(values, n_bits=6):
        """Return the ``n_bits`` binary digits (MSB first) of each integer.

        The result has shape ``values.shape + (n_bits,)``.

        Raises:
            ValueError: if any value is outside [0, 2**n_bits - 1].
        """
        values = np.asarray(values)
        limit = 1 << n_bits
        if values.size and (values.min() < 0 or values.max() >= limit):
            raise ValueError(
                f"values must lie in [0, {limit - 1}], got range "
                f"[{values.min()}, {values.max()}]")
        shifts = np.arange(n_bits - 1, -1, -1)
        return (values[..., None] >> shifts) & 1

    def map102(self, x):
        """Encode an integer in [0, 63] as a list of 6 bits, MSB first.

        Example: ``map102(5) == [0, 0, 0, 1, 0, 1]``.

        Raises:
            ValueError: if ``x`` is negative or needs more than 6 bits. The
                previous string-padding version silently kept only the six
                most significant digits for ``x >= 64``.
        """
        return self._to_bits(x).tolist()