import glob

import gym
import numpy as np
import  os
from pointEnv.PointEnv import GymPointEnv

class her_sampler:
    """Sample HER-relabelled transitions together with expert demonstration samples.

    On construction the sampler builds an environment once to read the
    observation and goal sizes, then loads a stored ``.npy`` array of expert
    data and splits its last axis into observation, goal and action columns.
    """

    def __init__(self, replay_strategy, replay_k, args, env_name, compute_reward):
        """Configure the relabelling probability and load the expert samples.

        Args:
            replay_strategy: ``'future'`` enables relabelling with probability
                ``1 - 1 / (1 + replay_k)``; any other value disables it.
            replay_k: Ratio used to derive the relabelling probability.
            args: Object exposing ``env_name`` (used to build the environment).
            env_name: Environment name; selects the environment type, the
                reward rule and the expert sample directory.
            compute_reward: Callable invoked as
                ``compute_reward(achieved_goal_next, goal, None)`` for
                environments whose name starts with ``'Point'`` or ``'Fetch'``.

        Raises:
            FileNotFoundError: If no expert sample file matches the expected
                location.
        """
        self.env_name = env_name
        self.args = args
        self.replay_strategy = replay_strategy
        self.compute_reward = compute_reward
        self.replay_k = replay_k
        if self.replay_strategy == 'future':
            self.future_p = 1 - (1. / (1 + replay_k))
        else:
            self.future_p = 0

        # The environment is only needed to read the observation/goal sizes.
        if self.env_name.startswith('Point'):
            env = GymPointEnv(args.env_name.split('Point')[1], max_episode_steps=25, resize_factor=1)
            difficulty = 0.85
            max_goal_dist = env.max_goal_dist
            # Constrain sampled goals to a +/-0.05 band around the difficulty.
            env.set_sample_goal_args(
                prob_constraint=1.0,
                min_dist=max(0, max_goal_dist * (difficulty - 0.05)),
                max_dist=max_goal_dist * (difficulty + 0.05))
        else:
            env = gym.make(args.env_name)
        observation = env.reset()
        state_dim = np.shape(observation['observation'])[0]
        goal_dim = np.shape(observation['achieved_goal'])[0]

        pattern = os.path.join('/media/ly/data17/log/HER3/samples/',
                               self.env_name, 'HER', '5', '*npy')
        # glob returns matches in arbitrary order; sort so that the chosen
        # file does not depend on the filesystem's listing order.
        candidates = sorted(glob.glob(pattern))
        if not candidates:
            raise FileNotFoundError(
                'No expert sample file matches {!r}'.format(pattern))
        base_path = candidates[-1]

        data = np.load(base_path)
        lastdim = np.shape(data)[-1]
        # Stored samples are regrouped into episodes of 50 timesteps.
        expert_episode_len = 50
        data = np.reshape(data, (-1, expert_episode_len, lastdim))

        # Last axis layout: [observation | goal | action].
        self.expert_obs = data[:, :, :state_dim]
        self.expert_desgoal = data[:, :, state_dim:state_dim + goal_dim]
        self.expert_action = data[:, :, state_dim + goal_dim:]

    def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
        """Draw a batch of relabelled transitions plus a batch of expert samples.

        Args:
            episode_batch: Dict of arrays indexed as ``[episode, timestep]``.
                Must contain ``'actions'``, ``'ag'``, ``'g'`` and ``'ag_next'``;
                ``'ag'`` is indexed up to timestep ``T`` inclusive, where ``T``
                is ``episode_batch['actions'].shape[1]``.
            batch_size_in_transitions: Number of transitions to sample.

        Returns:
            A 6-tuple ``(transitions, expert, steps, stephers, her_indexes,
            her_indexes2)``:

            * ``transitions``: sampled entries of ``episode_batch`` with
              ``'g'`` relabelled for the HER subset and ``'r'`` recomputed.
            * ``expert``: ``[state_goal, achieved_goal, action]`` taken from
              the expert arrays (the goal is read at a later timestep of the
              same expert episode).
            * ``steps``: binary encoding (see ``map102``) of every sampled
              future timestep.
            * ``stephers``: like ``steps`` for the HER subset; every other row
              holds the encoding of the largest sampled future timestep.
            * ``her_indexes`` / ``her_indexes2``: ``np.where`` results for the
              relabelled and the non-relabelled samples respectively.
        """
        T = episode_batch['actions'].shape[1]
        rollout_batch_size = episode_batch['actions'].shape[0]
        batch_size = batch_size_in_transitions

        # Select which rollouts and which timesteps to use.
        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = np.random.randint(T, size=batch_size)
        transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
                       for key in episode_batch.keys()}

        # Split the batch into relabelled (HER) and untouched samples.
        random_index = np.random.uniform(size=batch_size)
        her_indexes = np.where(random_index < self.future_p)
        her_indexes2 = np.where(random_index >= self.future_p)

        # Future timestep in (t, T] for every sample.
        future_offset = (np.random.uniform(size=batch_size) * (T - t_samples)).astype(int)
        future_t2 = t_samples + 1 + future_offset
        future_t = future_t2[her_indexes]

        # Replace the goal with a goal achieved later in the same episode.
        transitions['g'][her_indexes] = episode_batch['ag'][episode_idxs[her_indexes], future_t]

        steps = np.array([self.map102(t) for t in future_t2])
        # Every non-HER row shares one encoding, so compute it once and tile.
        stephers = np.tile(self.map102(np.max(future_t2)), (batch_size, 1))
        stephers[her_indexes] = steps[her_indexes]

        # Expert samples: a timestep and a strictly later timestep of the same
        # expert episode. The horizon comes from the stored array instead of a
        # literal so it cannot drift from the episode length used at load time.
        n_expert, expert_len = np.shape(self.expert_obs)[:2]
        last_t = expert_len - 1
        exp_idxs = np.random.randint(0, n_expert, batch_size)
        expt_samples = np.random.randint(last_t, size=batch_size)
        expfuture_offset = (np.random.uniform(size=batch_size) * (last_t - expt_samples)).astype(int)
        expfuture_t = expt_samples + 1 + expfuture_offset

        state_goal = self.expert_obs[exp_idxs, expt_samples]
        action = self.expert_action[exp_idxs, expt_samples]
        achieved_goal = self.expert_desgoal[exp_idxs, expfuture_t]

        # Re-compute the reward against the (possibly relabelled) goal.
        if self.env_name.startswith(('Point', 'Fetch')):
            reward = self.compute_reward(transitions['ag_next'], transitions['g'], None)
        else:
            # Negative Euclidean distance between achieved and target goal.
            reward = -np.linalg.norm(
                transitions['ag_next'] - transitions['g'], axis=-1).astype(np.float32)
        transitions['r'] = np.expand_dims(reward, 1)

        return transitions, [state_goal, achieved_goal, action], steps, stephers, her_indexes, her_indexes2

    def map102(self, x, width=6):
        """Return ``x`` as a list of ``width`` binary digits, most significant first.

        Args:
            x: Non-negative integer (Python or NumPy integer).
            width: Number of digits in the result; defaults to 6.

        Returns:
            List of ``width`` ints, each 0 or 1, zero-padded on the left.

        Raises:
            ValueError: If ``x`` is negative or needs more than ``width`` bits.
                Truncating instead would make distinct values share a code.
            TypeError: If ``x`` is not an integer.
        """
        binary = bin(x)
        if binary.startswith('-'):
            raise ValueError('map102 expects a non-negative integer, got {!r}'.format(x))
        bits = binary[2:].zfill(width)
        if len(bits) > width:
            raise ValueError(
                '{!r} does not fit in {} binary digits'.format(x, width))
        return [int(bit) for bit in bits]
