import torch
import os
from datetime import datetime
import numpy as np
from mpi4py import MPI
from mpi_utils.mpi_utils import sync_networks, sync_grads
from rl_modules.replay_buffer import replay_buffer
from rl_modules.models import actor, critic, actor_nogoal,critic_nogoal,VAE_NoGoal
from rl_modules.models2 import VAE3
from mpi_utils.normalizer import normalizer
from her_modules.her1 import her_sampler
import torch.nn.functional as F
"""
ddpg with HER (MPI-version)
"""
class ddpg_HERTwoSubgoal:
    """DDPG + HER agent (MPI version) with a goal-conditioned and a goal-free actor.

    Every ``_GOAL_FREE_PERIOD``-th rollout of a cycle (rollout indices 0, 5,
    10, ...) is collected with ``actor_network_nogoal``, which only sees the
    observation; the remaining rollouts use ``actor_network``, which sees the
    observation concatenated with the desired goal. Both actors are updated by
    a single optimizer against the goal-conditioned critic.
    """

    # rollout indices that are multiples of this use the goal-free actor
    _GOAL_FREE_PERIOD = 5
    # ``alg`` values that _update_network knows how to train
    _GOAL_FREE_ALGS = ('HERGoalFreeMin', 'HERGoalFreeMax')
    # transitions sampled from each replay buffer per gradient step
    _BATCH_SIZE = 256

    def __init__(self, args, alg, env, env_name, env_params):
        """Build networks, optimizers, HER samplers, replay buffers and normalizers.

        Args:
            args: hyper-parameter namespace (lr_actor, lr_critic, buffer_size,
                cuda, gpuid, clip_range, save_dir, env_name, seed, weight1,
                weight2, replay_strategy, replay_k, ...).
            alg: algorithm tag; selects the partner-goal rule in
                ``_update_network`` and names the output directories.
            env: goal-based environment providing ``reset``, ``step`` and
                ``compute_reward``.
            env_name: environment name forwarded to the HER samplers.
            env_params: dict with at least 'obs', 'goal', 'action',
                'action_max' and 'max_timesteps'.
        """
        self.args = args
        self.env = env
        self.alg = alg
        self.env_name = env_name
        self.env_params = env_params

        # create the networks; the construction order is kept fixed so that
        # seeded weight initialisation stays reproducible
        self.actor_network = actor(env_params)
        self.actor_network_nogoal = actor_nogoal(env_params)
        self.actor_target_network = actor(env_params)
        self.actor_target_network_nogoal = actor_nogoal(env_params)
        self.critic_network = critic(env_params)
        self.critic_target_network = critic(env_params)
        self.critic_network_nogoal = critic_nogoal(env_params)
        self.critic_target_network_nogoal = critic_nogoal(env_params)

        # broadcast rank 0's initial weights so every MPI worker starts identical
        for net in (self.actor_network, self.actor_network_nogoal,
                    self.critic_network, self.critic_network_nogoal):
            sync_networks(net)

        # target networks start as exact copies of the online networks
        self.actor_target_network.load_state_dict(self.actor_network.state_dict())
        self.critic_target_network.load_state_dict(self.critic_network.state_dict())
        self.actor_target_network_nogoal.load_state_dict(self.actor_network_nogoal.state_dict())
        self.critic_target_network_nogoal.load_state_dict(self.critic_network_nogoal.state_dict())

        self.device = torch.device('cuda:%d' % self.args.gpuid)
        if self.args.cuda:
            for net in (self.actor_network, self.critic_network,
                        self.critic_network_nogoal, self.actor_network_nogoal,
                        self.actor_target_network, self.critic_target_network,
                        self.actor_target_network_nogoal, self.critic_target_network_nogoal):
                net.to(self.device)

        # one optimizer updates both actors jointly
        self.actor_optim = torch.optim.Adam(list(self.actor_network.parameters())
            + list(self.actor_network_nogoal.parameters()), lr=self.args.lr_actor)
        self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr_critic)
        # must wrap the goal-free critic (it previously wrapped critic_network by mistake)
        self.critic_optim_nogoal = torch.optim.Adam(self.critic_network_nogoal.parameters(),
                                                    lr=self.args.lr_critic)

        # HER sampler + replay buffer for goal-conditioned episodes
        self.her_module = her_sampler(self.args.replay_strategy, self.args.replay_k,
                                      env_params, self.env_name, self.env.compute_reward)
        self.buffer = replay_buffer(self.env_params, self.args.buffer_size,
                                    self.her_module.sample_her_transitions)
        # HER sampler + replay buffer for goal-free episodes
        self.her_module2 = her_sampler(self.args.replay_strategy, self.args.replay_k,
                                       env_params, self.env_name, self.env.compute_reward)
        self.buffer2 = replay_buffer(self.env_params, self.args.buffer_size,
                                     self.her_module2.sample_her_transitions)

        # create the normalizers (only o_norm and g_norm are updated in this class)
        self.o_norm = normalizer(size=env_params['obs'], default_clip_range=self.args.clip_range)
        self.ag_norm = normalizer(size=env_params['goal'], default_clip_range=self.args.clip_range)
        self.ag_next_norm = normalizer(size=env_params['goal'], default_clip_range=self.args.clip_range)
        self.g_norm = normalizer(size=env_params['goal'], default_clip_range=self.args.clip_range)
        self.dg_norm = normalizer(size=env_params['goal'], default_clip_range=self.args.clip_range)

        # output directories exist only on rank 0; learn() saves only from rank 0
        if MPI.COMM_WORLD.Get_rank() == 0:
            os.makedirs(self.args.save_dir, exist_ok=True)
            self.model_path = os.path.join(self.args.save_dir, self.args.env_name,
                                           alg + '_' + str(self.args.weight1)
                                           + '_' + str(self.args.weight2), str(self.args.seed))
            self.data_path = os.path.join(self.args.save_dir,
                                          self.args.env_name, alg, str(self.args.seed))
            os.makedirs(self.model_path, exist_ok=True)
            os.makedirs(self.data_path, exist_ok=True)

        self.total_step = 0

    @classmethod
    def _is_goal_free_rollout(cls, numroll):
        """Return True if rollout ``numroll`` of a cycle uses the goal-free actor."""
        return numroll % cls._GOAL_FREE_PERIOD == 0

    @staticmethod
    def _episode_snapshot(mb_obs, mb_ag, mb_g, mb_actions, max_steps=50):
        """Concatenate the first ``max_steps`` steps of obs/ag/g/actions on the feature axis.

        The length is clamped to the number of actions, because obs/ag carry one
        extra (final) step. This keeps the concatenation valid for episodes
        shorter than ``max_steps``.
        """
        steps = min(max_steps, mb_actions.shape[1])
        return np.concatenate((mb_obs[:, :steps, :], mb_ag[:, :steps, :],
                               mb_g[:, :steps, :], mb_actions[:, :steps, :]), 2)

    def _store_episodes(self, episode_batch, goal_free, epoch):
        """Route freshly collected episodes to the replay buffers.

        In epoch 0 every episode is stored in both buffers. Afterwards, episodes
        collected by the goal-conditioned actor go to ``self.buffer`` and those
        collected by the goal-free actor go to ``self.buffer2``. Empty
        sub-batches are skipped.

        Args:
            episode_batch: [mb_obs, mb_ag, mb_g, mb_actions], each indexed by episode on axis 0.
            goal_free: per-episode flags, True where the goal-free actor was used.
            epoch: current 0-based epoch.
        """
        if epoch == 0:
            self.buffer.store_episode(list(episode_batch))
            self.buffer2.store_episode(list(episode_batch))
            return
        goal_free = np.asarray(goal_free, dtype=bool)
        for buffer, mask in ((self.buffer, ~goal_free), (self.buffer2, goal_free)):
            if mask.any():
                buffer.store_episode([arr[mask] for arr in episode_batch])

    def learn(self, writer, logpath, tag):
        """Run the training loop: collect rollouts, store them, update networks, evaluate.

        Args:
            writer: object with ``add_scalar`` used to log the eval success rate.
            logpath: unused; kept for interface compatibility.
            tag: filename prefix for checkpoints and sample dumps.
        """
        is_root = MPI.COMM_WORLD.Get_rank() == 0
        for epoch in range(self.args.n_epochs):
            samples = []
            for _ in range(self.args.n_cycles):
                mb_obs, mb_ag, mb_g, mb_actions = [], [], [], []
                goal_free = []
                for numroll in range(self.args.num_rollouts_per_mpi):
                    use_goal_free = self._is_goal_free_rollout(numroll)
                    goal_free.append(use_goal_free)
                    ep_obs, ep_ag, ep_g, ep_actions = [], [], [], []
                    observation = self.env.reset()
                    obs = observation['observation']
                    ag = observation['achieved_goal']
                    g = observation['desired_goal']
                    for _ in range(self.env_params['max_timesteps']):
                        with torch.no_grad():
                            input_tensor, obs_tensor, _ = self._preproc_inputs1(obs, g)
                            if use_goal_free:
                                pi = self.actor_network_nogoal(obs_tensor)
                            else:
                                pi = self.actor_network(input_tensor)
                            action = self._select_actions(pi)
                        observation_new, _, _, info = self.env.step(action)
                        ep_obs.append(obs.copy())
                        ep_ag.append(ag.copy())
                        ep_g.append(g.copy())
                        ep_actions.append(action.copy())
                        obs = observation_new['observation']
                        ag = observation_new['achieved_goal']
                    # the final observation/achieved goal close the episode
                    ep_obs.append(obs.copy())
                    ep_ag.append(ag.copy())
                    mb_obs.append(ep_obs)
                    mb_ag.append(ep_ag)
                    mb_g.append(ep_g)
                    mb_actions.append(ep_actions)
                mb_obs = np.array(mb_obs)
                mb_ag = np.array(mb_ag)
                mb_g = np.array(mb_g)
                mb_actions = np.array(mb_actions)
                episode_batch = [mb_obs, mb_ag, mb_g, mb_actions]
                # route each episode by the actor that collected it
                self._store_episodes(episode_batch, goal_free, epoch)
                self._update_normalizer(episode_batch)
                samples.append(self._episode_snapshot(mb_obs, mb_ag, mb_g, mb_actions))
                for _ in range(self.args.n_batches):
                    self._update_network(writer, epoch + 1)
                    self.total_step += 1
                # soft update of all target networks
                self._soft_update_target_network(self.actor_target_network, self.actor_network)
                self._soft_update_target_network(self.critic_target_network, self.critic_network)
                self._soft_update_target_network(self.actor_target_network_nogoal, self.actor_network_nogoal)
                self._soft_update_target_network(self.critic_target_network_nogoal, self.critic_network_nogoal)
            # model_path/data_path only exist on rank 0
            if is_root:
                if epoch % 5 == 0:
                    torch.save([self.o_norm.mean, self.o_norm.std, self.g_norm.mean, self.g_norm.std,
                                self.actor_network.state_dict(), self.critic_network.state_dict()],
                               os.path.join(self.model_path, tag + '_' + str(epoch) + '.pt'))
                # note: np.save appends '.npy', so the file on disk ends in '.pt.npy'
                np.save(os.path.join(self.data_path, tag + '_' + str(epoch) + '.pt'), np.array(samples))

            # evaluation performs an MPI allreduce, so every rank must run it
            success_rate = self._eval_agent()
            print('[{}] epoch is: {}, Fetch eval success rate is: {:.3f}'.format(datetime.now(), epoch, success_rate))
            writer.add_scalar('reward_eval', success_rate, global_step=epoch)

    def _preproc_inputs(self, obs, g):
        """Return the normalised obs+goal input tensor of shape (1, obs_dim + goal_dim)."""
        return self._preproc_inputs1(obs, g)[0]

    def _preproc_inputs1(self, obs, g):
        """Normalise a single obs/goal pair and build network inputs.

        Returns:
            (inputs, obs_norm, g_norm): float32 tensors with a leading batch
            dimension of 1, moved to ``self.device`` when CUDA is enabled.
            ``inputs`` is the concatenation of the normalised obs and goal.
        """
        obs_norm = self.o_norm.normalize(obs)
        g_norm = self.g_norm.normalize(g)
        inputs = np.concatenate([obs_norm, g_norm])
        inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0)
        obs_norm = torch.tensor(obs_norm, dtype=torch.float32).unsqueeze(0)
        g_norm = torch.tensor(g_norm, dtype=torch.float32).unsqueeze(0)
        if self.args.cuda:
            inputs = inputs.to(self.device)
            obs_norm = obs_norm.to(self.device)
            g_norm = g_norm.to(self.device)
        return inputs, obs_norm, g_norm

    def _select_actions(self, pi):
        """Turn a policy output into an exploratory action.

        Adds Gaussian noise (scale ``noise_eps * action_max``), clips to
        ``[-action_max, action_max]``, and with probability ``random_eps``
        replaces the action by a uniformly random one.
        """
        action = pi.cpu().numpy().squeeze()
        action += self.args.noise_eps * self.env_params['action_max'] * np.random.randn(*action.shape)
        action = np.clip(action, -self.env_params['action_max'], self.env_params['action_max'])
        random_actions = np.random.uniform(low=-self.env_params['action_max'], high=self.env_params['action_max'],
                                           size=self.env_params['action'])
        action += np.random.binomial(1, self.args.random_eps, 1)[0] * (random_actions - action)
        return action

    def _update_normalizer(self, episode_batch):
        """Update the obs/goal normalizer statistics from HER-relabelled transitions of a batch."""
        mb_obs, mb_ag, mb_g, mb_actions = episode_batch
        buffer_temp = {'obs': mb_obs,
                       'ag': mb_ag,
                       'g': mb_g,
                       'actions': mb_actions,
                       'obs_next': mb_obs[:, 1:, :],
                       'ag_next': mb_ag[:, 1:, :],
                       }
        # sample as many transitions as there are steps per episode
        num_transitions = mb_actions.shape[1]
        transitions, *_ = self.her_module.sample_her_transitions(buffer_temp, num_transitions)
        transitions['obs'], transitions['g'] = self._preproc_og(transitions['obs'], transitions['g'])
        self.o_norm.update(transitions['obs'])
        self.g_norm.update(transitions['g'])
        self.o_norm.recompute_stats()
        self.g_norm.recompute_stats()

    def _preproc_og(self, o, g):
        """Clip observations and goals to ``[-clip_obs, clip_obs]``."""
        o = np.clip(o, -self.args.clip_obs, self.args.clip_obs)
        g = np.clip(g, -self.args.clip_obs, self.args.clip_obs)
        return o, g

    def _soft_update_target_network(self, target, source):
        """Polyak-average ``source`` into ``target``: t <- (1 - polyak) * s + polyak * t."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data)

    def _compute_action_value(self, obs_tensor, goal_tensor):
        """Target-network Q value of the target policy's action, clamped to [-1/(1-gamma), 0]."""
        clip_return = 1 / (1 - self.args.gamma)
        gsp_norm_tensor = torch.cat((obs_tensor, goal_tensor), dim=1)
        actions_tensor = self.actor_target_network(gsp_norm_tensor)
        actionvalue = self.critic_target_network(gsp_norm_tensor, actions_tensor)
        return torch.clamp(actionvalue, -clip_return, 0)

    def transitions_process(self, transitions):
        """Clip obs, obs_next and the goal entries ('g', 'sg', 'fg', 'dg', 'ag') in place.

        Also adds ``transitions['g_next']``, a clipped copy of 'g'. Returns the
        same dict.
        """
        o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
        for key in ('sg', 'fg', 'dg', 'ag'):
            transitions[key] = np.clip(transitions[key], -self.args.clip_obs, self.args.clip_obs)
        return transitions

    def _update_network(self, writer, epoch):
        """Run one gradient step on the goal-conditioned critic and on both actors.

        The actor loss is ``-min(Q(s, g, pi(s, g)), Q(s, g, pi_nogoal(s)))``
        under the goal-conditioned critic, plus an L2 action penalty on the
        goal-conditioned actor.

        Args:
            writer: unused; kept for interface compatibility.
            epoch: 1-based epoch counter; currently unused.

        Raises:
            ValueError: if ``self.alg`` is not one of ``_GOAL_FREE_ALGS``.
        """
        if self.alg not in self._GOAL_FREE_ALGS:
            raise ValueError('unsupported alg {!r}; expected one of {}'.format(self.alg, self._GOAL_FREE_ALGS))

        # only the transition dicts are needed from each sample
        transitions, *_ = self.buffer.sample(self._BATCH_SIZE)
        transitions2, *_ = self.buffer2.sample(self._BATCH_SIZE)
        transitions = self.transitions_process(transitions)
        transitions2 = self.transitions_process(transitions2)

        obs_norm = self.o_norm.normalize(transitions['obs'])
        inputs_norm = np.concatenate([obs_norm, self.g_norm.normalize(transitions['g'])], axis=1)
        inputs_next_norm = np.concatenate([self.o_norm.normalize(transitions['obs_next']),
                                           self.g_norm.normalize(transitions['g_next'])], axis=1)
        obs_next_norm2 = self.o_norm.normalize(transitions2['obs_next'])

        # reward for the goal-free batch: pair each goal-conditioned goal with the
        # nearest (Min) or farthest (Max) goal of the goal-free batch
        g1, g2 = transitions['g'], transitions2['g']
        distances = np.linalg.norm(g1[:, np.newaxis] - g2, axis=2)
        select = np.argmin if self.alg == 'HERGoalFreeMin' else np.argmax
        partner = select(distances, axis=1)
        r2 = np.array([self.env.compute_reward(g1[i], g2[j], None) for i, j in enumerate(partner)])

        def to_tensor(x):
            t = torch.tensor(x, dtype=torch.float32)
            return t.to(self.device) if self.args.cuda else t

        obs_norm_tensor = to_tensor(obs_norm)
        obs_next_norm_tensor2 = to_tensor(obs_next_norm2)
        inputs_norm_tensor = to_tensor(inputs_norm)
        inputs_next_norm_tensor = to_tensor(inputs_next_norm)
        actions_tensor = to_tensor(transitions['actions'])
        r_tensor = to_tensor(transitions['r'])
        r_tensor2 = to_tensor(np.expand_dims(r2, 1))

        with torch.no_grad():
            # NOTE: target_q_value2 is only meaningful for a goal-free critic update
            # (critic_network_nogoal regressed onto it via critic_optim_nogoal),
            # which is currently disabled, so it is computed but not consumed.
            actions_next2 = self.actor_target_network_nogoal(obs_next_norm_tensor2)
            q_next_value2 = self.critic_target_network_nogoal(obs_next_norm_tensor2, actions_next2)
            target_q_value2 = r_tensor2 + self.args.gamma * q_next_value2

            actions_next = self.actor_target_network(inputs_next_norm_tensor)
            q_next_value = self.critic_target_network(inputs_next_norm_tensor, actions_next)
            target_q_value = r_tensor + self.args.gamma * q_next_value

            # clamp targets to [-1/(1-gamma), 0], the range of a discounted sum of rewards in [-1, 0]
            clip_return = 1 / (1 - self.args.gamma)
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)
            target_q_value2 = torch.clamp(target_q_value2, -clip_return, 0)

        # critic (goal-conditioned) loss
        real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
        critic_loss = (target_q_value - real_q_value).pow(2).mean()

        # both actors are scored by the goal-conditioned critic; take the pessimistic one
        actions_real = self.actor_network(inputs_norm_tensor)
        q_goal = self.critic_network(inputs_norm_tensor, actions_real)
        actions_real_nogoal = self.actor_network_nogoal(obs_norm_tensor)
        q_nogoal = self.critic_network(inputs_norm_tensor, actions_real_nogoal)
        actor_loss = -torch.minimum(q_goal, q_nogoal).mean()
        actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()

        # actor step: actor_optim owns both actors, so sync both actors' gradients
        self.actor_optim.zero_grad()
        actor_loss.backward()
        sync_grads(self.actor_network)
        sync_grads(self.actor_network_nogoal)
        self.actor_optim.step()

        # critic step (zero_grad also clears gradients left by actor_loss.backward)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        sync_grads(self.critic_network)
        self.critic_optim.step()

    def _eval_agent(self):
        """Evaluate the goal-conditioned actor without exploration noise.

        Returns:
            Mean over MPI ranks of the fraction of test rollouts whose final
            step reports ``info['is_success']``.
        """
        total_success_rate = []
        for _ in range(self.args.n_test_rollouts):
            per_success_rate = []
            observation = self.env.reset()
            obs = observation['observation']
            g = observation['desired_goal']
            for _ in range(self.env_params['max_timesteps']):
                with torch.no_grad():
                    input_tensor = self._preproc_inputs(obs, g)
                    pi = self.actor_network(input_tensor)
                    actions = pi.detach().cpu().numpy().squeeze()
                observation_new, _, _, info = self.env.step(actions)
                obs = observation_new['observation']
                g = observation_new['desired_goal']
                per_success_rate.append(info['is_success'])
            total_success_rate.append(per_success_rate)
        total_success_rate = np.array(total_success_rate)
        local_success_rate = np.mean(total_success_rate[:, -1])
        global_success_rate = MPI.COMM_WORLD.allreduce(local_success_rate, op=MPI.SUM)
        return global_success_rate / MPI.COMM_WORLD.Get_size()

    def _eval_Point(self):
        """Evaluate on an env whose ``step`` returns 5 values and exposes ``env.distance``.

        A rollout counts as a success if ``||obs - g|| < env.distance`` holds at
        any of its last 5 steps. Returns the mean over MPI ranks.
        """
        total_success_rate = []
        for _ in range(self.args.n_test_rollouts):
            per_success_rate = []
            observation = self.env.reset()
            obs = observation['observation']
            g = observation['desired_goal']
            for _ in range(self.env_params['max_timesteps']):
                with torch.no_grad():
                    input_tensor = self._preproc_inputs(obs, g)
                    pi = self.actor_network(input_tensor)
                    actions = pi.detach().cpu().numpy().squeeze()
                observation_new, _, _, _, info = self.env.step(actions)
                obs = observation_new['observation']
                g = observation_new['desired_goal']
                per_success_rate.append(np.linalg.norm(obs - g) < self.env.distance)
            total_success_rate.append(per_success_rate)
        total_success_rate = np.array(total_success_rate)
        local_success_rate = np.mean(np.max(total_success_rate[:, -5:], axis=-1))
        global_success_rate = MPI.COMM_WORLD.allreduce(local_success_rate, op=MPI.SUM)
        return global_success_rate / MPI.COMM_WORLD.Get_size()

    def _eval_agent_sawyer(self):
        """Roll out the goal-conditioned actor and record every ``info`` value per step.

        Expects ``env.step`` to return 5 values.

        Returns:
            ``np.array`` of nested lists indexed [rollout][step][info key][0],
            with keys in ``info``'s iteration order. This is shape
            (n_test_rollouts, max_timesteps, n_keys, 1) when info values are
            scalars and the key set is fixed — confirm for the target env.
        """
        total_success_rate = []
        for _ in range(self.args.n_test_rollouts):
            per_success_rate = []
            observation = self.env.reset()
            obs = observation['observation']
            g = observation['desired_goal']
            for _ in range(self.env_params['max_timesteps']):
                with torch.no_grad():
                    input_tensor = self._preproc_inputs(obs, g)
                    pi = self.actor_network(input_tensor)
                    actions = pi.detach().cpu().numpy().squeeze()
                observation_new, _, _, _, info = self.env.step(actions)
                obs = observation_new['observation']
                g = observation_new['desired_goal']
                # exactly one entry per step (previously appended once per key)
                per_success_rate.append([[info[key]] for key in info])
            total_success_rate.append(per_success_rate)
        return np.array(total_success_rate)

