import numpy as np
import torch


class her_sampler:
    """Sample training transitions from stored episodes with HER-style goal relabelling.

    With the ``'future'`` strategy a fraction ``1 - 1 / (1 + replay_k)`` of the
    sampled transitions get their goal ``g`` replaced by an achieved goal taken
    from a later timestep of the same episode. Any other strategy leaves every
    goal untouched (``future_p == 0``).
    """

    def __init__(self, replay_strategy, replay_k, env_params, env, env_name, reward_func=None):
        """Store the sampling configuration.

        Args:
            replay_strategy: ``'future'`` enables relabelling; any other value
                disables it.
            replay_k: ratio of relabelled to original transitions; only used
                to derive ``future_p`` when the strategy is ``'future'``.
            env_params: mapping that must provide ``'max_timesteps'`` and
                ``'alg'``.
            env: stored as-is; not used by this class itself.
            env_name: names starting with ``'Fetch'`` or ``'Hand'`` make the
                sampler compute rewards through ``reward_func``.
            reward_func: callable invoked as ``reward_func(ag_next, g, None)``
                for Fetch/Hand environments.
        """
        self.env = env
        self.env_name = env_name
        self.replay_strategy = replay_strategy
        self.replay_k = replay_k
        if self.replay_strategy == 'future':
            self.future_p = 1 - (1. / (1 + replay_k))
        else:
            self.future_p = 0
        self.reward_func = reward_func
        self.max_step = env_params['max_timesteps']
        self.tag = env_params['alg']

    def sample_her_transitions(self, episode_batch, batch_size_in_transitions, tag=10):
        """Draw a batch of transitions and relabel part of their goals.

        Args:
            episode_batch: dict of arrays indexed as ``[rollout, timestep, ...]``.
                The keys ``'actions'``, ``'obs'``, ``'ag'``, ``'g'`` and
                ``'ag_next'`` are read. ``'ag'`` needs at least ``T + 1``
                timesteps (``T = actions.shape[1]``) because sampled future
                indices reach ``T``, and it is also indexed at
                ``self.max_step``.
            batch_size_in_transitions: number of transitions to sample.
            tag: unused; kept so existing callers keep working.

        Returns:
            A 4-tuple ``(transitions, [state_goal, achieved_goal, action],
            her_indexes, her_indexes2)`` where

            * ``transitions`` holds the sampled slice of every key of
              ``episode_batch`` (with ``'g'`` relabelled at ``her_indexes``)
              plus ``'sg'`` (achieved goal at the sampled future step),
              ``'fg'`` (achieved goal at step ``self.max_step``) and ``'r'``
              (reward, shape ``(batch, 1)``);
            * ``state_goal`` / ``action`` are the observation and action at
              the sampled step and ``achieved_goal`` is the achieved goal at
              the sampled future step;
            * ``her_indexes`` / ``her_indexes2`` are the ``np.where`` tuples of
              the relabelled and the non-relabelled batch positions.
        """
        T = episode_batch['actions'].shape[1]
        rollout_batch_size = episode_batch['actions'].shape[0]
        batch_size = batch_size_in_transitions

        # Select which rollouts and which timesteps are used. Timesteps are
        # drawn from [0, T - 2], so at least one later step always exists.
        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = np.random.randint(T - 1, size=batch_size)
        transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
                       for key in episode_batch.keys()}

        # Split the batch into relabelled (HER) and untouched positions.
        random_index = np.random.uniform(size=batch_size)
        her_indexes = np.where(random_index < self.future_p)
        her_indexes2 = np.where(random_index >= self.future_p)

        # NOTE(review): the same uniform draw decides both HER membership and
        # the future offset, so relabelled samples only ever receive offsets
        # from the first `future_p` fraction of the remaining episode. Left
        # unchanged to keep the sampling distribution; confirm it is intended.
        future_offset = (random_index * (T - t_samples)).astype(int)
        # Future timestep for every sample, in [t + 1, T].
        future_t2 = t_samples + 1 + future_offset
        future_t = future_t2[her_indexes]

        # Replace the goal with an achieved goal from later in the same episode.
        future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
        transitions['g'][her_indexes] = future_ag

        state_goal = episode_batch['obs'][episode_idxs, t_samples]
        action = episode_batch['actions'][episode_idxs, t_samples]
        achieved_goal = episode_batch['ag'][episode_idxs, future_t2]

        # Indexed separately from `achieved_goal` so the two arrays do not
        # share memory.
        transitions['sg'] = episode_batch['ag'][episode_idxs, future_t2]
        transitions['fg'] = episode_batch['ag'][episode_idxs, self.max_step]

        # Rewards are recomputed against the (possibly relabelled) goal.
        if self.env_name.startswith(('Fetch', 'Hand')):
            reward = self.reward_func(transitions['ag_next'], transitions['g'], None)
        else:
            # Negative Euclidean distance between next achieved goal and goal.
            distance = np.linalg.norm(transitions['ag_next'] - transitions['g'], axis=-1)
            reward = -distance.astype(np.float32)
        transitions['r'] = np.expand_dims(reward, 1)

        transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:])
                       for k in transitions.keys()}

        return transitions, [state_goal, achieved_goal, action], her_indexes, her_indexes2

    def map102(self, x):
        """Return ``x`` as a list of six binary digits, most significant first.

        The binary representation is left-padded with zeros to six digits.
        For ``x >= 64`` only the six most-significant digits are kept (the
        low-order bits are dropped). Negative values raise ``ValueError``.
        """
        bits = bin(x)[2:].zfill(6)
        return [int(bit) for bit in bits[:6]]