import torch
import os
from datetime import datetime
import numpy as np
from mpi4py import MPI
from mpi_utils.mpi_utils import sync_networks, sync_grads
from rl_modules.replay_buffer import replay_buffer
from rl_modules.models import actor, critic
from rl_modules.models import critic as criticgoal
from rl_modules.models import expert
from mpi_utils.normalizer import normalizer
from her_modules.her1 import her_sampler
from her_modules.herAnt import her_sampler as HERANT
import torch.nn.functional as F
# Make CUDA device 0 the current device of this process at import time, so a
# bare ``.cuda()`` call anywhere afterwards targets cuda:0 regardless of the
# ``gpuid`` argument used to build ``ddpg_CounterHER.device``.
torch.cuda.set_device(device=0)
"""
ddpg with HER (MPI-version)

"""

class ddpg_CounterHER:
    """DDPG + HER agent with an auxiliary goal critic ("CounterHER" variants).

    Alongside the usual actor/critic pair and their target copies, the agent
    trains a second critic, ``criticgoal_network``, on transitions whose
    desired goal has been replaced by another goal taken from the same batch.
    ``alg`` selects how the substitute goal is picked and how the goal critic
    enters the critic and actor losses (see ``_update_network``). With
    ``alg == 'HER'`` the goal critic is never optimized.

    A second actor (``actorgoal_network``) and its target are created, moved
    to the device, soft-updated and checkpointed, but no loss in this class
    optimizes it.
    """

    def __init__(self, args, alg, env, env_name, env_params, testenv=None):
        """Build networks, optimizers, HER sampler, replay buffer and normalizers.

        Args:
            args: hyper-parameter namespace (learning rates, buffer size,
                ``cuda``/``gpuid``, ``save_dir``, ...).
            alg: variant name, e.g. ``'HER'``, ``'CounterHER'``,
                ``'CounterHERShort'``, ``'CounterHERFar'``, ``'CounterHERv2'``.
            env: training environment with dict observations
                (``observation``/``achieved_goal``/``desired_goal``) and a
                ``compute_reward`` method.
            env_name: environment id; its prefix selects the HER sampler and
                the evaluation routine.
            env_params: dict providing at least ``obs``, ``goal``, ``action``,
                ``action_max`` and ``max_timesteps``.
            testenv: optional second environment used by ``_eval_agenttest``.
        """
        self.args = args
        self.env = env
        self.alg = alg
        self.env_name = env_name
        self.testenv = testenv
        self.env_params = env_params
        # Online networks. Creation order is kept fixed so weight
        # initialisation consumes the torch RNG in the same sequence.
        self.actor_network = actor(env_params)
        self.actorgoal_network = actor(env_params)
        self.actor_network.share_memory()
        self.critic_network = critic(env_params)
        self.criticgoal_network = critic(env_params)
        # NOTE(review): networks are not synchronized across MPI ranks here
        # (``sync_networks`` is imported but never called), so every rank
        # starts from its own random initialisation.

        # Target networks start as exact copies of the online networks.
        self.actor_target_network = actor(env_params)
        self.actorgoal_target_network = actor(env_params)
        self.critic_target_network = critic(env_params)
        self.critic_goaltarget_network = critic(env_params)
        self.actor_target_network.load_state_dict(self.actor_network.state_dict())
        self.actorgoal_target_network.load_state_dict(self.actorgoal_network.state_dict())
        self.critic_target_network.load_state_dict(self.critic_network.state_dict())
        self.critic_goaltarget_network.load_state_dict(self.criticgoal_network.state_dict())

        self.device = torch.device('cuda:%d' % self.args.gpuid)
        if self.args.cuda:
            for net in (self.actor_network, self.actorgoal_target_network,
                        self.actorgoal_network, self.critic_network,
                        self.criticgoal_network, self.actor_target_network,
                        self.critic_target_network, self.critic_goaltarget_network):
                net.to(self.device)

        self.actor_optim = torch.optim.Adam(self.actor_network.parameters(), lr=self.args.lr_actor)
        self.actorgoal_optim = torch.optim.Adam(self.actorgoal_network.parameters(), lr=self.args.lr_actor)
        self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr_critic)
        self.criticgoal_optim = torch.optim.Adam(self.criticgoal_network.parameters(), lr=self.args.lr_critic)
        print(self.alg, self.env_name, self.args.n_epochs)

        # Ant environments use their own HER sampler implementation.
        sampler_cls = HERANT if env_name[:3] == 'Ant' else her_sampler
        self.her_module = sampler_cls(self.args.replay_strategy, self.args.replay_k, env_params,
                                      self.env_name, self.env.compute_reward)

        self.buffer = replay_buffer(self.env_params, self.args.buffer_size,
                                    self.her_module.sample_her_transitions)
        clip_range = self.args.clip_range
        self.o_norm = normalizer(size=env_params['obs'], default_clip_range=clip_range)
        self.ag_norm = normalizer(size=env_params['goal'], default_clip_range=clip_range)
        self.ag_next_norm = normalizer(size=env_params['goal'], default_clip_range=clip_range)
        self.g_norm = normalizer(size=env_params['goal'], default_clip_range=clip_range)
        self.dg_norm = normalizer(size=env_params['goal'], default_clip_range=clip_range)

        # Output directories exist (and model_path/data_path are set) on rank 0 only.
        if MPI.COMM_WORLD.Get_rank() == 0:
            os.makedirs(self.args.save_dir, exist_ok=True)
            self.model_path = os.path.join(self.args.save_dir, self.args.env_name,
                                           alg + '_' + str(self.args.weight1)
                                           + '_' + str(self.args.weight2), str(self.args.seed))
            self.data_path = os.path.join(self.args.save_dir, 'samples',
                                          self.args.env_name, alg, str(self.args.seed))
            os.makedirs(self.model_path, exist_ok=True)
            os.makedirs(self.data_path, exist_ok=True)

        self.total_step = 0

    def _rollout_episode(self):
        """Roll out one episode of ``max_timesteps`` steps with the exploratory policy.

        Returns:
            ``(ep_obs, ep_ag, ep_g, ep_actions, info)``: per-step lists where
            ``ep_obs``/``ep_ag`` hold ``max_timesteps + 1`` entries (the final
            observation is appended), ``ep_g``/``ep_actions`` hold
            ``max_timesteps`` entries, and ``info`` is the last step's info
            (``None`` if no step was taken).
        """
        ep_obs, ep_ag, ep_g, ep_actions = [], [], [], []
        observation = self.env.reset()
        obs = observation['observation']
        ag = observation['achieved_goal']
        g = observation['desired_goal']
        info = None
        for _ in range(self.env_params['max_timesteps']):
            with torch.no_grad():
                input_tensor = self._preproc_inputs(obs, g)
                pi = self.actor_network(input_tensor)
                action = self._select_actions(pi)
            observation_new, _, _, info = self.env.step(action)
            ep_obs.append(obs.copy())
            ep_ag.append(ag.copy())
            ep_g.append(g.copy())
            ep_actions.append(action.copy())
            # The desired goal is kept from reset(); only obs/ag advance.
            obs = observation_new['observation']
            ag = observation_new['achieved_goal']
        ep_obs.append(obs.copy())
        ep_ag.append(ag.copy())
        return ep_obs, ep_ag, ep_g, ep_actions, info

    def train(self):
        """Roll out one exploratory episode and discard the collected data.

        Nothing is stored or learned; only the environment and RNG state advance.
        """
        self._rollout_episode()

    def learn(self, writer, logpath, tag):
        """Run the full training loop.

        Each epoch runs ``n_cycles`` cycles. A cycle collects
        ``num_rollouts_per_mpi`` exploratory episodes, stores them in the
        replay buffer and updates the normalizers. It then performs
        ``n_batches`` gradient steps and soft-updates every target network.
        After each epoch, a checkpoint is written every 5th epoch (rank 0
        only) and the agent is evaluated (see ``_evaluate_epoch``).

        Args:
            writer: logger exposing ``add_scalar(tag, value, global_step=...)``
                (presumably a TensorBoard ``SummaryWriter``); used on rank 0 only.
            logpath: unused.
            tag: prefix of the checkpoint file names.
        """
        for epoch in range(self.args.n_epochs):
            for _ in range(self.args.n_cycles):
                mb_obs, mb_ag, mb_g, mb_actions = [], [], [], []
                for _ in range(self.args.num_rollouts_per_mpi):
                    ep_obs, ep_ag, ep_g, ep_actions, info = self._rollout_episode()
                    mb_obs.append(ep_obs)
                    mb_ag.append(ep_ag)
                    mb_g.append(ep_g)
                    mb_actions.append(ep_actions)
                # Arrays are indexed as (episode, step, feature).
                mb_obs = np.array(mb_obs)
                mb_ag = np.array(mb_ag)
                mb_g = np.array(mb_g)
                mb_actions = np.array(mb_actions)
                self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions])
                self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions])
                for _ in range(self.args.n_batches):
                    self._update_network(writer, epoch + 1)
                    self.total_step += 1
                self._soft_update_target_network(self.actor_target_network, self.actor_network)
                self._soft_update_target_network(self.critic_target_network, self.critic_network)
                self._soft_update_target_network(self.critic_goaltarget_network, self.criticgoal_network)
                self._soft_update_target_network(self.actorgoal_target_network, self.actorgoal_network)

            # model_path only exists on rank 0, so only rank 0 may checkpoint.
            if epoch % 5 == 0 and MPI.COMM_WORLD.Get_rank() == 0:
                self._save_checkpoint(tag, epoch)
            # ``info`` is the info of the last training step of this epoch.
            self._evaluate_epoch(writer, epoch, info)

    def _save_checkpoint(self, tag, epoch):
        """Save normalizer statistics and weights to ``<model_path>/<tag>_<epoch>.pt``.

        The saved list is positional and kept unchanged for existing loaders:
        ``[o_mean, o_std, g_mean, g_std, actor, critic, actorgoal, criticgoal,
        criticgoal, criticgoal]``.
        """
        # NOTE(review): the goal critic fills the last three slots and no
        # target network is saved; confirm whether that is intended.
        torch.save([self.o_norm.mean, self.o_norm.std, self.g_norm.mean, self.g_norm.std,
                    self.actor_network.state_dict(), self.critic_network.state_dict(),
                    self.actorgoal_network.state_dict(), self.criticgoal_network.state_dict(),
                    self.criticgoal_network.state_dict(), self.criticgoal_network.state_dict()],
                   os.path.join(self.model_path, tag + '_' + str(epoch) + '.pt'))

    def _evaluate_epoch(self, writer, epoch, info):
        """Evaluate the current policy and log the result on rank 0.

        The Fetch/Hand, Ant and Point/MultiGoal evaluators finish with an MPI
        ``allreduce``, so every rank must call them. Only rank 0 prints and
        writes to ``writer``. Any other environment is evaluated with
        ``_eval_agent_sawyer`` (no MPI communication) on rank 0 every 5th
        epoch, logging the last-step value of every key of ``info``.
        """
        is_root = MPI.COMM_WORLD.Get_rank() == 0
        name = self.env_name
        if name.startswith(('Fetch', 'Hand', 'fetch')):
            success_rate = self._eval_agent()
            fmt = '[{}] epoch is: {}, eval success rate is: {:.3f}'
        elif name.startswith('Ant'):
            success_rate = self._eval_Ant()
            fmt = '[{}] epoch is: {}, eval success rate is: {:.3f}: '
        elif name.startswith(('Point', 'MultiGoal')):
            success_rate = self._eval_Point()
            fmt = '[{}] epoch is: {}, eval success rate is: {:.3f}'
        else:
            if is_root and epoch % 5 == 0:
                self._log_sawyer_eval(writer, epoch, info)
            return
        if is_root:
            print(fmt.format(datetime.now(), epoch, success_rate))
            writer.add_scalar('reward_eval', success_rate, global_step=epoch)

    def _log_sawyer_eval(self, writer, epoch, info):
        """Evaluate with ``_eval_agent_sawyer`` and log one scalar per info key."""
        success_rate = self._eval_agent_sawyer()
        keys = list(info.keys())
        for i, key in enumerate(keys):
            # _eval_agent_sawyer stores len(keys) identical copies of each
            # step's values, so striding by len(keys) keeps one copy per step
            # and [:, -1] selects the last step (assumes scalar info values).
            res = success_rate[:, i::len(keys), i, 0][:, -1]
            writer.add_scalar('reward_eval/' + str(key), res.mean(), global_step=epoch)
        print('[{}] epoch is: {}, eval success rate is: {:.3f}'.format(datetime.now(), epoch, res.mean()))

    def _preproc_inputs(self, obs, g):
        """Normalize ``obs`` and ``g``, concatenate them and return a (1, obs+goal) float tensor."""
        obs_norm = self.o_norm.normalize(obs)
        g_norm = self.g_norm.normalize(g)
        inputs = np.concatenate([obs_norm, g_norm])
        inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0)
        if self.args.cuda:
            inputs = inputs.to(self.device)
        return inputs

    def _select_actions(self, pi):
        """Turn the actor output into an exploratory action.

        Gaussian noise (``noise_eps * action_max``) is added and the result is
        clipped to ``[-action_max, action_max]``. With probability
        ``random_eps`` the whole action is replaced by a uniform random action.
        """
        action = pi.cpu().numpy().squeeze()
        action += self.args.noise_eps * self.env_params['action_max'] * np.random.randn(*action.shape)
        action = np.clip(action, -self.env_params['action_max'], self.env_params['action_max'])
        random_actions = np.random.uniform(low=-self.env_params['action_max'],
                                           high=self.env_params['action_max'],
                                           size=self.env_params['action'])
        action += np.random.binomial(1, self.args.random_eps, 1)[0] * (random_actions - action)
        return action

    def _update_normalizer(self, episode_batch):
        """Update ``o_norm``/``g_norm`` from a batch of episodes.

        HER-relabelled transitions (one per step of an episode) are drawn from
        ``episode_batch`` and clipped with ``_preproc_og``. They are fed to the
        normalizers, whose statistics are then recomputed.

        Args:
            episode_batch: ``[mb_obs, mb_ag, mb_g, mb_actions]`` arrays indexed
                as (episode, step, feature).
        """
        mb_obs, mb_ag, mb_g, mb_actions = episode_batch
        buffer_temp = {'obs': mb_obs,
                       'ag': mb_ag,
                       'g': mb_g,
                       'actions': mb_actions,
                       'obs_next': mb_obs[:, 1:, :],
                       'ag_next': mb_ag[:, 1:, :],
                       }
        num_transitions = mb_actions.shape[1]
        transitions, _, _, _, _, _ = self.her_module.sample_her_transitions(buffer_temp, num_transitions)
        transitions['obs'], transitions['g'] = self._preproc_og(transitions['obs'], transitions['g'])
        self.o_norm.update(transitions['obs'])
        self.g_norm.update(transitions['g'])
        self.o_norm.recompute_stats()
        self.g_norm.recompute_stats()

    def _preproc_og(self, o, g):
        """Clip observations and goals to ``[-clip_obs, clip_obs]``."""
        o = np.clip(o, -self.args.clip_obs, self.args.clip_obs)
        g = np.clip(g, -self.args.clip_obs, self.args.clip_obs)
        return o, g

    def _soft_update_target_network(self, target, source):
        """Polyak-average ``source`` into ``target`` in place.

        ``target <- (1 - polyak) * source + polyak * target``
        """
        polyak = self.args.polyak
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_((1 - polyak) * param.data + polyak * target_param.data)

    def _compute_action_value(self, obs_tensor, goal_tensor):
        """Target-network value ``Q'(s, g, pi'(s, g))`` clipped to ``[-1/(1-gamma), 0]``."""
        clip_return = 1 / (1 - self.args.gamma)
        gsp_norm_tensor = torch.cat((obs_tensor, goal_tensor), dim=1)
        actions_tensor = self.actor_target_network(gsp_norm_tensor)
        actionvalue = self.critic_target_network(gsp_norm_tensor, actions_tensor)
        return torch.clamp(actionvalue, -clip_return, 0)

    def _update_network(self, writer, epoch):
        """Perform one gradient step on the actor, the critic and the goal critic.

        1. Sample ``batch_size`` transitions and pick a substitute goal for
           each one. Let ``g~`` be the batch goals resampled uniformly with
           replacement, and ``d[a, b] = |Q(s_a, g_a, u_a) - Q(s_b, g~_b, u_b)|``.
           ``CounterHERFar`` / ``CounterHERShort`` give transition ``a`` the
           goal of the ``b`` that maximises / minimises ``d[a, b]``; exact
           ties ``d == 0`` are excluded. Every other ``alg`` uses ``g~``.
        2. Critic target: ``clamp(r + gamma * Q'(s', g, pi'(s', g)))``. Goal
           critic target: the same with the goal target critic and the
           substitute goal. Only for ``MultiGoal`` is the reward recomputed
           for that goal. Both are clamped to ``[-1/(1-gamma), 0]``.
        3. Unless ``alg == 'HER'``:
           - The critic loss subtracts ``weight1 * weight * mean((Qgoal - Q)^2)``
             on the real transitions (without ``weight`` for ``MultiGoal``).
           - The goal critic is regressed on the substituted goals.
        4. Actor loss: ``-mean Q(s, g, pi(s, g))``.
           - ``CounterHER``/``CounterHERShort`` add
             ``weight * weight2 * mean(max(Q, Qgoal))``.
           - ``CounterHERv2`` additionally subtracts the ``min`` counterpart.
           - All variants add the ``action_l2`` penalty.

        Here ``weight = 1 - epoch / n_epochs``.

        Args:
            writer: unused.
            epoch: 1-based epoch index used for ``weight``.
        """
        batch_size = self.args.batch_size
        use_cuda = self.args.cuda

        def to_dev(tensor):
            # Every tensor goes through the same device policy as the networks.
            return tensor.to(self.device) if use_cuda else tensor

        transitions, _, _, _, _, _ = self.buffer.sample(batch_size)
        o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)

        # --- 1. substitute-goal selection --------------------------------
        randomindex = np.random.randint(0, batch_size, batch_size)
        g_shufflex = transitions['g'][randomindex]

        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm_tensor = torch.tensor(np.concatenate([obs_norm, g_norm], axis=1), dtype=torch.float32)
        g_normx = self.g_norm.normalize(g_shufflex)
        inputs_norm_tensorx = torch.tensor(np.concatenate([obs_norm, g_normx], axis=1), dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)

        with torch.no_grad():
            q_real = self.critic_network(to_dev(inputs_norm_tensor), to_dev(actions_tensor))
            q_resampled = self.critic_network(to_dev(inputs_norm_tensorx), to_dev(actions_tensor))
            # q_diff[a, b] = |q_real[a] - q_resampled[b]|
            q_diff = torch.abs(q_real.tile(1, batch_size)
                               - q_resampled.tile(1, batch_size).transpose(1, 0)).cpu().numpy()
        if self.alg == 'CounterHERFar':
            # Exclude exact ties with a sentinel argmax can never prefer.
            q_diff[q_diff == 0] = -np.inf
            index = np.argmax(q_diff, -1)
        elif self.alg == 'CounterHERShort':
            q_diff[q_diff == 0] = 1000
            index = np.argmin(q_diff, -1)
        else:
            index = randomindex

        g_shuffle = transitions['g'][index]
        g_norm_shuffle = g_norm[index]

        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)

        obs_norm_tensor = to_dev(torch.tensor(obs_norm, dtype=torch.float32))
        g_norm_tensor = to_dev(torch.tensor(g_norm, dtype=torch.float32))
        g_norm_shuffle_tensor = to_dev(torch.tensor(g_norm_shuffle, dtype=torch.float32))
        obs_next_norm_tensor = to_dev(torch.tensor(obs_next_norm, dtype=torch.float32))
        inputs_norm_tensor = to_dev(inputs_norm_tensor)
        inputs_next_norm_tensor = to_dev(torch.tensor(inputs_next_norm, dtype=torch.float32))
        actions_tensor = to_dev(actions_tensor)
        r_tensor = to_dev(torch.tensor(transitions['r'], dtype=torch.float32))
        # NOTE(review): the actions are passed where compute_reward usually takes info.
        r2 = self.env.compute_reward(transitions['ag_next'], g_shuffle, transitions['actions'])
        r_tensor2 = to_dev(torch.tensor(np.expand_dims(r2, 1), dtype=torch.float32))

        # --- 2. targets --------------------------------------------------
        clip_return = 1 / (1 - self.args.gamma)
        with torch.no_grad():
            actions_next = self.actor_target_network(inputs_next_norm_tensor)
            q_next_value = self.critic_target_network(inputs_next_norm_tensor, actions_next)
            target_q_value = torch.clamp(r_tensor + self.args.gamma * q_next_value, -clip_return, 0)

            inputshuffle = torch.cat((obs_next_norm_tensor, g_norm_shuffle_tensor), 1)
            actions_next_shuffle = self.actor_target_network(inputshuffle)
            q_next_value1 = self.critic_goaltarget_network(inputshuffle, actions_next_shuffle)
            reward1 = r_tensor2 if self.env_name == 'MultiGoal' else r_tensor
            target_q_value1 = torch.clamp(reward1 + self.args.gamma * q_next_value1, -clip_return, 0)

        # --- 3. critic losses --------------------------------------------
        weight = 1.0 - epoch / self.args.n_epochs
        real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
        critic_loss = (target_q_value - real_q_value).pow(2).mean()
        real_q_value2 = self.criticgoal_network(inputs_norm_tensor, actions_tensor)

        critic_loss1 = None
        if self.alg != 'HER':
            gap = (real_q_value2 - real_q_value).pow(2).mean()
            if self.env_name == 'MultiGoal':
                critic_loss = critic_loss - float(self.args.weight1) * gap
            else:
                critic_loss = critic_loss - weight * float(self.args.weight1) * gap
            input_tensor_shuffle = torch.cat((obs_norm_tensor, g_norm_shuffle_tensor), 1)
            real_q_value1 = self.criticgoal_network(input_tensor_shuffle, actions_tensor)
            critic_loss1 = (target_q_value1 - real_q_value1).pow(2).mean()

        # --- 4. actor loss -----------------------------------------------
        actions_real = self.actor_network(inputs_norm_tensor)
        q_pi = self.critic_network(inputs_norm_tensor, actions_real)
        if self.alg in ('CounterHER', 'CounterHERShort', 'CounterHERv2'):
            x = torch.cat((obs_norm_tensor, g_norm_tensor), 1)
            actions_real_x = self.actor_network(x)
            q_goal_pi = self.criticgoal_network(x, actions_real_x)
            reglos2 = float(self.args.weight2) * torch.maximum(q_pi.detach(), q_goal_pi)
            if self.alg == 'CounterHERv2':
                reglos1 = float(self.args.weight2) * torch.minimum(q_pi.detach(), q_goal_pi)
                actor_loss = -q_pi.mean() - weight * reglos1.mean() + weight * reglos2.mean()
            else:
                actor_loss = -q_pi.mean() + weight * reglos2.mean()
        else:
            # 'HER' and any other variant (e.g. 'CounterHERFar'): plain DDPG loss.
            # Previously unlisted variants kept an unreduced (batch, 1) tensor
            # and crashed in backward().
            actor_loss = -q_pi.mean()
        actor_loss = actor_loss + self.args.action_l2 * (
            actions_real / self.env_params['action_max']).pow(2).mean()

        # --- optimisation --------------------------------------------------
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        if critic_loss1 is not None:
            # zero_grad also discards gradients the gap term pushed into the goal critic.
            self.criticgoal_optim.zero_grad()
            critic_loss1.backward()
            self.criticgoal_optim.step()

    def _greedy_episode(self, env, n_steps, step_metric):
        """Run one noise-free episode in ``env`` and record a per-step metric.

        ``step_metric(observation_new, info)`` is called after every step; the
        list of its ``n_steps`` results is returned.
        """
        observation = env.reset()
        obs = observation['observation']
        g = observation['desired_goal']
        per_step = []
        for _ in range(n_steps):
            with torch.no_grad():
                input_tensor = self._preproc_inputs(obs, g)
                pi = self.actor_network(input_tensor)
                actions = pi.detach().cpu().numpy().squeeze()
            observation_new, _, _, info = env.step(actions)
            obs = observation_new['observation']
            g = observation_new['desired_goal']
            per_step.append(step_metric(observation_new, info))
        return per_step

    @staticmethod
    def _mpi_mean(local_value):
        """Average ``local_value`` over all MPI ranks (collective: every rank must call)."""
        comm = MPI.COMM_WORLD
        return comm.allreduce(local_value, op=MPI.SUM) / comm.Get_size()

    def _eval_agent(self):
        """Mean ``info['is_success']`` at the last step of ``n_test_rollouts`` episodes, averaged over ranks."""
        total_success_rate = np.array([
            self._greedy_episode(self.env, self.env_params['max_timesteps'],
                                 lambda observation_new, info: info['is_success'])
            for _ in range(self.args.n_test_rollouts)])
        return self._mpi_mean(np.mean(total_success_rate[:, -1]))

    def _eval_agenttest(self):
        """Like ``_eval_agent`` but on ``testenv`` for ``max_timesteps2`` steps."""
        total_success_rate = np.array([
            self._greedy_episode(self.testenv, self.env_params['max_timesteps2'],
                                 lambda observation_new, info: info['is_success'])
            for _ in range(self.args.n_test_rollouts)])
        return self._mpi_mean(np.mean(total_success_rate[:, -1:]))

    def _eval_Ant(self):
        """Fraction of episodes with ``||achieved - desired|| < env.distance_threshold``
        at any of the last 3 steps, averaged over ranks."""
        def reached(observation_new, info):
            return (np.linalg.norm(observation_new['achieved_goal'] - observation_new['desired_goal'])
                    < self.env.distance_threshold)

        total_success_rate = np.array([
            self._greedy_episode(self.env, self.env_params['max_timesteps'], reached)
            for _ in range(self.args.n_test_rollouts)])
        return self._mpi_mean(np.mean(np.max(total_success_rate[:, -3:], axis=-1)))

    def _eval_Point(self):
        """Fraction of episodes with ``||observation - desired|| < env.distance``
        at any of the last 3 steps, averaged over ranks."""
        def reached(observation_new, info):
            return (np.linalg.norm(observation_new['observation'] - observation_new['desired_goal'])
                    < self.env.distance)

        total_success_rate = np.array([
            self._greedy_episode(self.env, self.env_params['max_timesteps'], reached)
            for _ in range(self.args.n_test_rollouts)])
        return self._mpi_mean(np.mean(np.max(total_success_rate[:, -3:], axis=-1)))

    def _eval_agent_sawyer(self):
        """Collect every info value at every step of ``n_test_rollouts`` noise-free episodes.

        For every step, the list ``[[info[k]] for k in keys]`` is appended once
        per key, so the result has shape
        ``(n_test_rollouts, max_timesteps * len(keys), len(keys), 1)`` for
        scalar info values. Consecutive ``len(keys)`` rows are identical;
        ``_log_sawyer_eval`` relies on this layout. No MPI reduction is done.
        """
        total_success_rate = []
        for _ in range(self.args.n_test_rollouts):
            per_success_rate = []
            observation = self.env.reset()
            obs = observation['observation']
            g = observation['desired_goal']
            for _ in range(self.env_params['max_timesteps']):
                with torch.no_grad():
                    input_tensor = self._preproc_inputs(obs, g)
                    pi = self.actor_network(input_tensor)
                    actions = pi.detach().cpu().numpy().squeeze()
                observation_new, _, _, info = self.env.step(actions)
                obs = observation_new['observation']
                g = observation_new['desired_goal']
                keys = list(info.keys())
                tmp = []
                for key in keys:
                    tmp.append([info[key]])
                    per_success_rate.append(tmp)
            total_success_rate.append(per_success_rate)
        return np.array(total_success_rate)

