import torch
import torch.nn as nn
import os
from datetime import datetime
import numpy as np
import random
import pickle
import csv
from discrete_action_robots_modules.replay_buffer import replay_buffer
from discrete_action_robots_modules.models import ForwardMap, BackwardMap
from discrete_action_robots_modules.her import her_sampler
from discrete_action_robots_modules.normalizer import normalizer
from discrete_action_robots_modules.robots import goal_distance
from discrete_action_robots_modules.mdp_utils import extract_policy
from torch.distributions.cauchy import Cauchy
from discrete_action_robots_modules.models import mlp, weight_init
import wandb
import time
import torch.nn.functional as F
import typing as tp
import math

"""
FB agent with HER (MPI-version)

"""

class SamplingSeedActor(torch.nn.Module):
    """Deterministic pseudo-random policy indexed by a binary code and a state hash.

    A lookup table ``seed_to_action`` of ``2**z_dim + 20000`` entries is built
    once: entry ``i`` is the action drawn by ``torch.randint`` from a CPU
    generator seeded with ``i``. ``forward`` turns ``(obs_hash, z)`` into an
    index into that table.

    Args:
        action_dim: number of discrete actions (exclusive upper bound of the draw).
        z_dim: number of bits in the code ``z``.
        batch_size: number of rows kept in ``self.powers``.
        device: where the lookup tensors are created. ``None`` (default) picks
            ``'cuda'`` when available and ``'cpu'`` otherwise. The tensors are
            registered as non-persistent buffers, so ``module.to(...)`` moves
            them and ``state_dict()`` stays empty.
    """

    def __init__(self, action_dim, z_dim, batch_size, device=None):
        super().__init__()
        self.z_dim = z_dim
        self.action_dim = action_dim
        self.max_seed = 2**z_dim + 20000

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Bit weights, most significant first, repeated once per batch row.
        powers = torch.tensor([2**i for i in range(self.z_dim)][::-1]).repeat(batch_size, 1)

        # A private generator is re-seeded for every table entry so the global
        # torch RNG state (and any user-chosen seed) is left untouched.
        generator = torch.Generator()
        draws = []
        for seed in range(self.max_seed):
            generator.manual_seed(seed)
            draws.append(torch.randint(0, self.action_dim, (1,), generator=generator))
        # Shape (max_seed, 1, 1), matching the table layout used by forward().
        seed_to_action = torch.stack(draws).unsqueeze(1)

        self.register_buffer('powers', powers.to(device), persistent=False)
        self.register_buffer('seed_to_action', seed_to_action.to(device), persistent=False)

    def forward(self, obs_hash, z):
        """Return one action per row as an integer tensor of shape ``(N, 1)``.

        The table index is ``(sum_k z[:, k] * 2**(z_dim-1-k) + obs_hash) % max_seed``.
        """
        # A single row of bit weights broadcasts over any number of rows in z.
        seed_long = (z * self.powers[:1]).sum(1)
        final_seed = (seed_long + obs_hash.reshape(-1)) % self.max_seed
        actions = self.seed_to_action[final_seed.long()]
        return actions.reshape(-1, 1)
    
class PSM(nn.Module):
    """Two-headed network producing per-action features ``phi`` and biases ``b``.

    Both heads read the concatenation of an observation and a goal. Each head
    has its own trunk (``mlp_phi`` / ``mlp_b``) followed by a projection
    (``phi_fc`` / ``b_fc``).

    Note: the ``feature_dim`` argument is accepted but overridden by
    ``hidden_dim``, because the trunks output ``hidden_dim`` features.
    """

    def __init__(self, obs_dim, goal_dim, d_dim, action_dim, feature_dim, hidden_dim) -> None:
        super().__init__()
        self.obs_dim, self.goal_dim = obs_dim, goal_dim
        self.d_dim, self.action_dim = d_dim, action_dim

        # Identical layer specification for the two independent trunks.
        trunk_spec = (self.obs_dim + self.goal_dim, hidden_dim, "ntanh",
                      hidden_dim, "irelu",
                      hidden_dim, "irelu")
        self.mlp_phi = mlp(*trunk_spec)
        self.mlp_b = mlp(*trunk_spec)

        # The caller-supplied value is ignored: heads consume the trunk width.
        feature_dim = hidden_dim

        self.phi_fc = mlp(feature_dim, hidden_dim, "irelu", self.d_dim * self.action_dim)
        self.b_fc = mlp(feature_dim, hidden_dim, "irelu", self.action_dim)

        self.apply(weight_init)

    def forward(self, obs, goal):
        """Return ``(phi, b)`` reshaped to ``(-1, action_dim, d_dim)`` and ``(-1, action_dim)``."""
        joint = torch.cat([obs, goal], dim=-1)
        phi = self.phi_fc(self.mlp_phi(joint)).reshape(-1, self.action_dim, self.d_dim)
        b = self.b_fc(self.mlp_b(joint)).reshape(-1, self.action_dim)
        return phi, b

class PSMAgent:
    def __init__(self, args, env, env_params, buffer_path='./'):
        """Build networks, optimizers, the HER replay buffer and the log files.

        Args:
            args: run configuration; the attributes read here are ``cuda``,
                ``embed_dim``, ``sampling_z_dim``, ``batch_size``, ``lr``,
                ``replay_strategy``, ``replay_k``, ``buffer_size``,
                ``save_dir`` and ``seed``.
            env: environment; only ``env.compute_reward`` is used here.
            env_params: mapping with at least the keys ``'obs'``, ``'goal'``
                and ``'action'``.
            buffer_path: directory containing ``fetch_reach_rnd.pkl``, which is
                loaded into the replay buffer.
        """
        self.args = args
        self.env = env
        self.env_params = env_params
        self.cauchy = Cauchy(torch.tensor([0.0]), torch.tensor([0.5]))

        self.device = 'cuda' if args.cuda else 'cpu'

        self.goal_dim = env_params['goal']
        self.obs_dim = env_params['obs']
        self.action_dim = env_params['action']

        # Step counters read by update_psm()/_infer_step(); defined here so
        # those methods can be used before learn() (which resets them) runs.
        self.training_iters = 0
        self.eval_iters = 0
        self.inf_iters = 0

        # Online and target PSM networks.
        self.psm = PSM(self.obs_dim, self.goal_dim, self.args.embed_dim, self.action_dim, 256, 256).to(self.device)
        self.psm_target = PSM(self.obs_dim, self.goal_dim, self.args.embed_dim, self.action_dim, 256, 256).to(self.device)

        # Online and target networks mapping a sampled code z to an embedding.
        self.w = mlp(self.args.sampling_z_dim, 256, "irelu",
                     256, "irelu",
                     256, "irelu", self.args.embed_dim, "L2").to(self.device)

        self.w.apply(weight_init)

        self.w_target = mlp(self.args.sampling_z_dim, 256, "irelu",
                            256, "irelu",
                            256, "irelu", self.args.embed_dim, "L2").to(self.device)

        # Placeholder; init_inference() replaces it before each inference run.
        self.w_inf = torch.zeros((self.args.embed_dim), requires_grad=True, device=self.device)

        # The actor is sized for batch_size**2 rows because update_psm pairs
        # every transition with every goal of the batch.
        self.sampling_actor = SamplingSeedActor(self.action_dim, self.args.sampling_z_dim, self.args.batch_size * self.args.batch_size).to(self.device)

        # load the weights into the target networks
        self.psm_target.load_state_dict(self.psm.state_dict())
        self.w_target.load_state_dict(self.w.state_dict())

        # optimizers
        self.opt = torch.optim.Adam([{'params': self.psm.parameters()},  # type: ignore
                                     {'params': self.w.parameters(), 'lr': self.args.lr}],
                                    lr=self.args.lr)

        self.inf_opt = torch.optim.Adam([self.w_inf], lr=self.args.lr)

        # Puts both self.psm and self.w into training mode.
        self.train()

        # her sampler
        self.her_module = her_sampler(self.args.replay_strategy, self.args.replay_k, self.env.compute_reward)
        # create the replay buffer
        self.buffer = replay_buffer(self.env_params, self.args.buffer_size, self.her_module.sample_her_transitions)
        self.save_dir = f'{args.save_dir}/PSM-test/seed-{args.seed}/{datetime.now().strftime("%Y%m%d-%H%M%S")}'
        if args.save_dir is not None:
            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs(self.save_dir, exist_ok=True)

            print(' ' * 26 + 'Options')
            for k, v in vars(self.args).items():
                print(' ' * 26 + k + ': ' + str(v))

            with open(self.save_dir + "/arguments.pkl", 'wb') as f:
                pickle.dump(self.args, f)

            # newline='' is what the csv module requires so that row
            # terminators are not translated a second time on some platforms.
            with open('{}/score_monitor.csv'.format(self.save_dir), "wt", newline='') as monitor_file:
                monitor = csv.writer(monitor_file)
                monitor.writerow(['epoch', 'eval', 'avg dist'])

        self.buffer_path = buffer_path
        self.buffer.load(os.path.join(self.buffer_path, 'fetch_reach_rnd.pkl'))
        self.update_normalizer()

    def train(self, training: bool = True) -> None:
        self.training = training
        for net in [self.psm, self.w]:
            net.train(training)

    def int_to_binary_array(self, int_vector, num_bits=None):
        if num_bits is None:
            num_bits = int_vector.max().bit_length()
        
        binary_array = ((int_vector[:, None] & (1 << np.arange(num_bits))) > 0).astype(int)
        return binary_array
    
    def sample_z(self, size, device: str = "cpu"):
        z_np = np.random.randint(0, 2**self.args.sampling_z_dim, (size,))
        binary_array = self.int_to_binary_array(z_np, self.args.sampling_z_dim)
        return torch.FloatTensor(binary_array).to(device)
    
    def update_psm(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        discount: torch.Tensor,
        next_obs: torch.Tensor,
        next_obs_hash: torch.Tensor,
        next_goal: torch.Tensor,
        z: torch.Tensor,
    ) -> tp.Dict[str, float]:
        metrics: tp.Dict[str, float] = {}

        #==========================================================
        
        # mesh_making_time= time.time()

        idx = torch.arange(obs.shape[0]).to(obs.device)
        mesh = torch.stack(torch.meshgrid(idx, idx, indexing='xy')).T.reshape(-1, 2)
        
        
        # print(mesh[:, 0].shape)
        # print(obs[mesh[:, 0]].shape)
        # print(obs.shape)
        # print(mesh.shape)
        action = action.unsqueeze(1)
        m_obs = obs[mesh[:, 0]]
        # print(m_obs.shape)
        m_next_obs = next_obs[mesh[:, 0]]
        m_next_obs_hash = next_obs_hash[mesh[:, 0]]
        # print(action.shape)
        m_action = action[mesh[:, 0]]
        # print(m_action.shape)
        m_next_goal = next_goal[mesh[:, 1]]
        with torch.no_grad():
            # compute greedy action
            target_phi, target_b = self.psm_target(m_next_obs, m_next_goal)
            # target_phi_2, target_b_2 = self.psm_target_2(m_next_obs, m_next_goal)
            # pi = self.sampling_actor(m_next_obs, z)
            target_w= self.w_target(z)
            next_actions = self.sampling_actor(m_next_obs_hash, z)

            target_phi = target_phi[torch.arange(target_phi.shape[0]), next_actions.squeeze(1)]
            target_b = target_b[torch.arange(target_b.shape[0]), next_actions.squeeze(1)]

            # target_phi_2 = target_phi_2[torch.arange(target_phi_2.shape[0]), next_actions.squeeze(1)]
            # target_b_2 = target_b_2[torch.arange(target_b_2.shape[0]), next_actions.squeeze(1)]
            
            # print(target_phi.shape, target_w.shape, target_b.shape)
            target_M = torch.einsum("sd, sd -> s", target_phi, target_w) + target_b
            # target_M_2 = torch.einsum("sd, sd -> s", target_phi_2, target_w) + target_b_2

        # compute PSM loss
        phi, b = self.psm(m_obs, m_next_goal)
        # phi_2, b_2 = self.psm_2(m_obs, m_next_goal)
        # print(m_action.shape)
        phi = phi[torch.arange(phi.shape[0]), m_action.squeeze(1)]
        # print(phi.shape)
        b = b[torch.arange(b.shape[0]), m_action.squeeze(1)]

        # phi_2 = phi_2[torch.arange(phi_2.shape[0]), m_action.squeeze(1)]
        # b_2 = b_2[torch.arange(b_2.shape[0]), m_action.squeeze(1)]
        # print(phi.shape, self.w(z).shape, b.shape, target_M.shape)

        M = torch.einsum("sd, sd -> s", phi, self.w(z)) + b
        # M_2 = torch.einsum("sd, sd -> s", phi_2, w) + b_2

        M = M.reshape(obs.shape[0], obs.shape[0])
        # M_2 = M_2.reshape(obs.shape[0], obs.shape[0])

        target_M = target_M.reshape(obs.shape[0], obs.shape[0])
        # target_M_2 = target_M_2.reshape(obs.shape[0], obs.shape[0])

        I = torch.eye(*M.size(), device=M.device)
        off_diag = ~I.bool()
        psm_offdiag = 0.5 * (M - discount * target_M)[off_diag].pow(2).mean()
        # psm_offdiag += 0.5 * (M_2 - discount * torch.min(target_M, target_M))[off_diag].pow(2).mean()
        psm_diag = -(1-discount) * M.diag().mean() 
        psm_loss = psm_offdiag + psm_diag

        # sampling_action = self.sampling_actor(next_obs_hash, z)

        # print(f'Step: {self.training_iters} | PSM Loss: {psm_loss.item()} | PSM Diag: {psm_diag.item()} | PSM Offdiag: {psm_offdiag.item()}')
        # print(f'Phi: {phi.mean()} | B: {b.mean()} | M: {M.mean()} | W: {self.w(z).mean()}')
        # print('-------------------------------------------------')

        # Cov = torch.matmul(phi.T, phi)
        # I = torch.eye(*Cov.size(), device=Cov.device)
        # off_diag = ~I.bool()
        # orth_loss_diag = - 2 * Cov.diag().mean()
        # orth_loss_offdiag = Cov[off_diag].pow(2).mean()
        # orth_loss = orth_loss_offdiag + orth_loss_diag
        # psm_loss += self.cfg.ortho_coef * orth_loss


        # if self.cfg.use_tb or self.cfg.use_wandb or self.cfg.use_hiplog:

        # optimize PSM
        self.opt.zero_grad(set_to_none=True)
        # self.actor_opt.zero_grad(set_to_none=True)
        psm_loss.backward()
        self.opt.step()
        # print("Time to optimize PSM: ", time.time()-start_t)
        # self.actor_opt.step()

        # print('Step: ', step, ' | PSM Loss: ', psm_loss.item(), ' | Div Loss: ', div_loss.item(), ' | Diversity: ', diversity.item(), ' | Orth Loss: ', orth_loss.item())
        metrics['train/psm_loss'] = psm_loss.item()
        metrics['train/psm_diag'] = psm_diag.item()
        metrics['train/psm_offdiag'] = psm_offdiag.item()
        metrics['train/phi_norm'] = phi.mean()
        metrics['train/b_norm'] = b.mean()
        metrics['train/M_norm'] = M.mean()
        metrics['train/w'] = self.w(z).mean()
        metrics['train/frame'] = self.training_iters


        if isinstance(self.opt, torch.optim.Adam):
            metrics["train/opt_lr"] = self.opt.param_groups[0]["lr"]
        return metrics
    
    def update(self) -> tp.Dict[str, float]:
        """Sample a batch, run one PSM update, refresh the targets and log.

        Steps: sample ``args.batch_size`` transitions from the replay buffer,
        standardise them with the buffer statistics, draw one code ``z`` per
        transition and repeat it ``batch_size`` times (so it lines up with the
        transition-major pairing in ``update_psm``), call ``update_psm``,
        Polyak-average both target networks, send the metrics to wandb and
        increment ``training_iters``.

        Returns:
            The metrics dict produced by ``update_psm``.
        """
        metrics: tp.Dict[str, float] = {}

        # update_psm reads the counter, so make sure it exists even when
        # update() is called without going through learn().
        if not hasattr(self, 'training_iters'):
            self.training_iters = 0

        transitions = self.buffer.sample(self.args.batch_size)

        # Only the quantities update_psm consumes are normalised and moved to
        # the device.
        obs = (transitions['obs'] - self.obs_mean) / (self.obs_std + 1e-6)
        obs_next = (transitions['obs_next'] - self.obs_mean) / (self.obs_std + 1e-6)
        ag_next = (transitions['ag_next'] - self.goal_mean) / (self.goal_std + 1e-6)

        # transfer them into tensors
        obs_tensor = torch.tensor(obs, dtype=torch.float32).to(self.device)
        obs_next_tensor = torch.tensor(obs_next, dtype=torch.float32).to(self.device)
        actions_tensor = torch.tensor(transitions['actions'], dtype=torch.long).to(self.device)
        ag_next_tensor = torch.tensor(ag_next, dtype=torch.float32).to(self.device)
        next_obs_hash = torch.tensor(transitions['next_obs_hash'], dtype=torch.long).to(self.device)

        z = self.sample_z(self.args.batch_size, device=self.device)
        # One code per transition, repeated for each goal it is paired with.
        z = torch.repeat_interleave(z, self.args.batch_size, 0)
        if not z.shape[-1] == self.args.sampling_z_dim:
            raise RuntimeError("There's something wrong with the logic here")

        metrics.update(self.update_psm(obs=obs_tensor, action=actions_tensor, discount=self.args.gamma,
                                       next_obs=obs_next_tensor, next_obs_hash=next_obs_hash,
                                       next_goal=ag_next_tensor, z=z))

        self.soft_update_params(self.psm, self.psm_target, self.args.polyak)
        self.soft_update_params(self.w, self.w_target, self.args.polyak)

        # Log metrics
        wandb.log(metrics)
        self.training_iters += 1
        return metrics

    def soft_update_params(self, net, target_net, tau) -> None:
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1 - tau) * target_param.data)
            
    def init_inference(self):
        '''
        Initialize the w_inf parameter
        Initialize the lagrange variables, optimizers etc
        '''

        # initialize the w_inf parameter
        self.w_inf = torch.randn((self.args.embed_dim), requires_grad=True, device=self.device)
        self.inf_opt = torch.optim.Adam([{'params': self.w_inf}], lr=self.args.lr)

        if self.args.use_dgd:
            self.lmult = mlp(self.obs_dim + self.goal_dim, 256, "irelu", 
                            256, "irelu", self.action_dim, "soft").to(self.device)
            
            self.lmult.apply(weight_init)
            
            self.lmult_opt = torch.optim.Adam([{'params': self.lmult.parameters()}], lr=self.args.lr)

    def infer_w(self, g):
        metrics: tp.Dict[str, float] = {}
        self.init_inference()
        # goal = self.g_norm.normalize(g)
        goal = (g - self.goal_mean) / (self.goal_std + 1e-6)
        goal = torch.tensor(goal).unsqueeze(0).float().to(self.device)
        # print(goal.shape)
        for step in range(self.args.n_infer_steps):
            transitions = self.buffer.sample(self.args.batch_size)
            obs = transitions['obs']
            ag = transitions['ag']
            # print(goal.shape)
            # goal = self.g_norm.normalize(goal)
            goal_rep = goal.repeat(obs.shape[0], 1)
            # print(goal.shape)
            # obs = self.o_norm.normalize(obs)
            obs = (obs - self.obs_mean) / (self.obs_std + 1e-6)
            obs = torch.tensor(obs).float().to(self.device)
            
            ag = (ag - self.goal_mean) / (self.goal_std + 1e-6)
            ag = torch.tensor(ag).float().to(self.device)

            # print(obs.shape, goal.shape)
            metrics.update(self._infer_step(obs, ag, goal_rep, step))
            
        self.w_inf = (F.normalize(self.w_inf.reshape(1,-1))*math.sqrt(self.w_inf.shape[0])).reshape(-1)

        
        return metrics
    
    def _infer_step(self, obs, ag, goal, step):
        """Take one optimisation step on ``w_inf`` (and on ``lmult`` under DGD).

        The objective is minus the mean of ``<phi(obs, goal), w>`` over states
        and actions, where ``w`` is ``w_inf`` rescaled to norm
        ``sqrt(embed_dim)``. The constraint term is evaluated on achieved
        goals shuffled across the batch: with ``args.use_dgd`` the values
        ``<phi, w> + b`` are weighted by the (detached) output of ``lmult``;
        otherwise their negative part is weighted by ``args.inf_coeff``.
        Under DGD, ``lmult`` is then updated on the un-negated constraint
        expression. Metrics are logged to wandb and ``eval_iters`` is
        incremented.

        ``step`` is accepted but not used.
        """
        shuffled_goals = ag[torch.randperm(obs.shape[0])]

        metrics = {}

        def scaled_w():
            # Re-evaluated on every use so each term has its own graph path
            # back to w_inf, exactly as if written inline.
            unit = F.normalize(self.w_inf.reshape(1, -1))
            return (unit * math.sqrt(self.w_inf.shape[0])).reshape(-1)

        def project(features):
            return torch.einsum("sad, d -> sa", features, scaled_w())

        with torch.no_grad():
            phi_g, b_g = self.psm(obs, goal)
            phi_perm, b_perm = self.psm(obs, shuffled_goals)

        obj = -project(phi_g).mean()

        if self.args.use_dgd:
            with torch.no_grad():
                l_mult = self.lmult(torch.cat([obs, shuffled_goals], dim=-1))
            constraints = - ((project(phi_perm) + b_perm) * l_mult).mean()
        else:
            negative_part = torch.min(project(phi_perm) + b_perm, torch.tensor(0.0))
            constraints = - (negative_part * self.args.inf_coeff).mean()

        self.inf_opt.zero_grad(set_to_none=True)
        loss = obj + constraints
        loss.backward()
        self.inf_opt.step()

        metrics['eval/obj'] = obj.item()
        metrics['eval/constraints'] = constraints.item()
        metrics['eval/lamb'] = l_mult.mean().item() if self.args.use_dgd else self.args.inf_coeff
        metrics['eval/frame'] = self.eval_iters
        self.eval_iters += 1

        if self.args.use_dgd:
            # Multiplier update: same expression without the leading minus,
            # evaluated with the freshly stepped w_inf and a live lmult.
            weighted = (project(phi_perm) + b_perm) * self.lmult(torch.cat([obs, shuffled_goals], dim=-1))
            constraints = weighted.mean()

            self.lmult_opt.zero_grad(set_to_none=True)
            loss = constraints
            loss.backward()
            self.lmult_opt.step()

        # Log
        wandb.log(metrics)
        return metrics
    
    def q_function(self, obs, goal):
        # obs_ = self.o_norm.normalize(obs)
        # goal_ = self.g_norm.normalize(goal)
        obs_ = (obs - self.obs_mean) / (self.obs_std + 1e-6)
        goal_ = (goal - self.goal_mean) / (self.goal_std + 1e-6)
        obs = torch.tensor(obs_).float().to(self.device)
        goal = torch.tensor(goal_).float().to(self.device)
        with torch.no_grad():
            phi, b = self.psm(obs, goal)
            # phi_2, b_2 = self.psm_2(obs, goal)
            # q = torch.min(torch.einsum("sad, d -> sa", phi_1, self.w_inf) + b_1, torch.einsum("sad, d -> sa", phi_2, self.w_inf) + b_2)
            q = torch.einsum("sad, d -> sa", phi, self.w_inf) + b

        return q
    
    def act(self, obs, goal):
        q_function = self.q_function(obs, goal)

        # pi = torch.distributions.Categorical(logits=q_function)
        # action = pi.sample()
        action = torch.argmax(q_function, dim=1)

        return action
        
    def learn(self):
        """
        train the network

        """
        # start to collect samples
        # print('MPI SIZE: ', MPI.COMM_WORLD.Get_size())
        self.training_iters = 0
        self.eval_iters = 0
        self.inf_iters = 0
        for epoch in range(self.args.n_epochs):
            for _ in range(self.args.n_cycles):
                # self._update_network()
                self.update()
                
                # self.training_iters += 1
            # start to do the evaluation
            success_rate, avg_dist = self._eval_agent()
            print('[{}] epoch is: {}, eval: {:.3f}, avg_dist : {:.3f}'.format(datetime.now(), epoch, success_rate, avg_dist))
            with open('{}/score_monitor.csv'.format(self.save_dir), "a") as monitor_file:
                monitor = csv.writer(monitor_file)
                monitor.writerow([epoch, success_rate, avg_dist])

        # self.buffer.save(os.path.join(self.buffer_path, 'fetch_reach_buffer.pkl'))

    def sample_uniform_ball(self, n, eps=1e-10):
        gaussian_rdv = torch.FloatTensor(n, self.args.embed_dim).normal_(mean=0, std=1)
        gaussian_rdv /= torch.norm(gaussian_rdv, dim=-1, keepdim=True) + eps
        uniform_rdv = torch.FloatTensor(n, 1).uniform_()
        w = np.sqrt(self.args.embed_dim) * gaussian_rdv * uniform_rdv
        if self.args.cuda:
            w = w.cuda()
        return w

    def sample_cauchy_ball(self, n, eps=1e-10):
        gaussian_rdv = torch.FloatTensor(n, self.args.embed_dim).normal_(mean=0, std=1)
        gaussian_rdv /= torch.norm(gaussian_rdv, dim=-1, keepdim=True) + eps
        cauchy_rdv = self.cauchy.sample((n, ))
        w = np.sqrt(self.args.embed_dim) * gaussian_rdv * cauchy_rdv
        if self.args.cuda:
            w = w.cuda()
        return w

    # pre_process the inputs
    def _preproc_o(self, obs):
        # obs = self._clip(obs)
        # obs_norm = self.o_norm.normalize(obs)
        obs_norm = (obs - self.obs_mean) / (self.obs_std + 1e-6)
        obs_tensor = torch.tensor(obs_norm, dtype=torch.float32).unsqueeze(0)
        if self.args.cuda:
            obs_tensor = obs_tensor.cuda()
        return obs_tensor

    def _preproc_g(self, g):
        # g = self._clip(g)
        # g_norm = self.g_norm.normalize(g)
        g_norm = (g - self.goal_mean) / (self.goal_std + 1e-6)
        g_tensor = torch.tensor(g_norm, dtype=torch.float32).unsqueeze(0)
        if self.args.cuda:
            g_tensor = g_tensor.cuda()
        return g_tensor

    def update_normalizer(self):
        # print('Buffer size: ', self.buffer.size)
        # for i in range(0, self.buffer.size, self.args.batch_size):
        #     if i + self.args.batch_size > self.buffer.size:
        #         ep_obs = self.buffer.buffers['obs'][i:]
        #         ep_ag = self.buffer.buffers['ag'][i:]
        #         ep_g = self.buffer.buffers['g'][i:]
        #         ep_actions = self.buffer.buffers['actions'][i:]
        #     else:
        #         ep_obs = self.buffer.buffers['obs'][i:i+self.args.batch_size]
        #         ep_ag = self.buffer.buffers['ag'][i:i+self.args.batch_size]
        #         ep_g = self.buffer.buffers['g'][i:i+self.args.batch_size]
        #         ep_actions = self.buffer.buffers['actions'][i:i+self.args.batch_size]

        #     self._update_normalizer([ep_obs, ep_ag, ep_g, ep_actions])
        # print(self.buffer.buffers['obs'][:5, :5, :])
        self.obs_mean = self.buffer.buffers['obs'].reshape(-1, self.obs_dim).mean(0)
        self.obs_std = self.buffer.buffers['obs'].reshape(-1, self.obs_dim).std(0)
        self.goal_mean = self.buffer.buffers['g'].reshape(-1, self.goal_dim).mean(0)
        self.goal_std = self.buffer.buffers['g'].reshape(-1, self.goal_dim).std(0)

        print('Obs mean: ', self.obs_mean)
        print('Obs std: ', self.obs_std)

        print('Goal mean: ', self.goal_mean)
        print('Goal std: ', self.goal_std)
        
    # update the normalizer
    def _update_normalizer(self, episode_batch):
        mb_obs, mb_ag, mb_g, mb_actions = episode_batch
        mb_obs_next = mb_obs[:, 1:, :]
        mb_ag_next = mb_ag[:, 1:, :]
        # get the number of normalization transitions
        num_transitions = mb_actions.shape[1]
        # create the new buffer to store them
        buffer_temp = {'obs': mb_obs,
                       'ag': mb_ag,
                       'g': mb_g,
                       'actions': mb_actions,
                       'obs_next': mb_obs_next,
                       'ag_next': mb_ag_next,
                       }
        # print('Buffer_temp:  ', buffer_temp['obs'].shape)
        # print(num_transitions)
        transitions = self.her_module.sample_her_transitions(buffer_temp, num_transitions)
        obs, g = transitions['obs'], transitions['ag']  # replace g by ag
        # pre process the obs and g
        transitions['obs'], transitions['g'] = self._clip(obs), self._clip(g)
        # update
        self.o_norm.update(transitions['obs'])
        self.g_norm.update(transitions['g'])
        # recompute the stats
        self.o_norm.recompute_stats()
        self.g_norm.recompute_stats()


    # def act(self, obs, w, g):
    #     g = torch.tensor(g, dtype=torch.float32).unsqueeze(0).cuda()
    #     f = self.forward_network(obs, w)
    #     b = self.backward_network(g)
    #     q = torch.einsum('sda, sd -> sa', f, b)
    #     return q.max(1)[1]

    def _clip(self, o):
        o = np.clip(o, -self.args.clip_obs, self.args.clip_obs)
        return o

    # update the network
    # def _update_network(self):
    #     # sample the episodes
    #     transitions = self.buffer.sample(self.args.batch_size)
    #     other_transitions = self.buffer.sample(self.args.batch_size)
    #     # pre-process the observation and goal
    #     o, o_next, g, ag, next_obs_hash = transitions['obs'], transitions['obs_next'], transitions['g'], transitions['ag'], transitions['next_obs_hash']
    #     # self._update_normalizer([o, o_next, g, actions])
    #     transitions['obs'], transitions['g'] = self.o_norm.normalize(o)\
    #         , self.g_norm.normalize(g)
    #     transitions['obs_next'] = self.o_norm.normalize(o_next)
    #     transitions['ag'] = self.g_norm.normalize(ag)
    #     other_transitions['ag'] = self.g_norm.normalize(other_transitions['ag'])
    #     # other_ag = transitions['g']

    #     # transfer them into the tensor
    #     obs_tensor = torch.tensor(transitions['obs'], dtype=torch.float32)
    #     g_tensor = torch.tensor(transitions['g'], dtype=torch.float32)
    #     obs_next_tensor = torch.tensor(transitions['obs_next'], dtype=torch.float32)
    #     actions_tensor = torch.tensor(transitions['actions'], dtype=torch.long)
    #     ag_tensor = torch.tensor(transitions['ag'], dtype=torch.float32)
    #     # ag_other_tensor = torch.tensor(other_ag, dtype=torch.float32)
    #     ag_other_tensor = torch.tensor(other_transitions['ag'], dtype=torch.float32)
    #     next_obs_hash = torch.tensor(next_obs_hash, dtype=torch.long)
    #     if self.args.cuda:
    #         obs_tensor = obs_tensor.cuda()
    #         g_tensor = g_tensor.cuda()
    #         obs_next_tensor = obs_next_tensor.cuda()
    #         actions_tensor = actions_tensor.cuda()
    #         ag_tensor = ag_tensor.cuda()
    #         ag_other_tensor = ag_other_tensor.cuda()
    #         next_obs_hash = next_obs_hash.cuda()

    #     # if self.args.w_sampling == 'goal_oriented':
    #     #     with torch.no_grad():
    #     #         w = self.backward_network(g_tensor)
    #     #         w = w.detach()
    #     # elif self.args.w_sampling == 'uniform_ball':
    #     #     w = self.sample_uniform_ball(self.args.batch_size)
    #     # elif self.args.w_sampling == 'cauchy_ball':
    #     #     w = self.sample_cauchy_ball(self.args.batch_size)

    #     z = self.sample_z(self.args.batch_size, device='cuda')

    #     # calculate the target Q value function
    #     with torch.no_grad():
    #         # if self.args.soft_update:
    #         #     pi = self.get_policy(obs_next_tensor, w, policy_type='boltzmann', temp=self.args.temp,
    #         #                          target_network=True)
    #         #     f_next = torch.einsum('sda, sa -> sd', self.forward_target_network(obs_next_tensor, w), pi)
    #         # else:
    #         #     actions_next_tensor = self.act(obs_next_tensor, w, target_network=True)
    #         #     next_idxs = actions_next_tensor[:, None].repeat(1, self.args.embed_dim)[:, :, None]
    #         #     f_next = self.forward_target_network(obs_next_tensor, w).gather(-1, next_idxs).squeeze()  # batch x dim
    #         actions = self.sampling_actor(next_obs_hash, z).squeeze()
    #         # next_idxs = actions[:, None].repeat(1, self.args.embed_dim)[:, :, None]
    #         f_next = self.forward_target_network(obs_next_tensor, self.w(z))
    #         f_next = f_next[torch.arange(f_next.size(0)), :, actions]
    #         # print(f_next.shape)
    #         # f_next = 


    #         b_next = self.backward_target_network(ag_other_tensor)  # batch x dim
    #         M_next = torch.einsum('sd, td -> st', f_next, b_next)  # batch x batch
    #         M_next = M_next.detach()
    #         # # clip the q value
    #         # clip_return = 1 / (1 - self.args.gamma)
    #         # target_q_value = torch.clamp(target_q_value, -clip_return, 0)
    #     # the forward loss
    #     idxs = actions_tensor[:, None].repeat(1, self.args.embed_dim)[:, :, None]
    #     f = self.forward_network(obs_tensor, self.w(z)).gather(-1, idxs).squeeze()
    #     b = self.backward_network(ag_tensor)
    #     b_other = self.backward_network(ag_other_tensor)
    #     M_diag = torch.einsum('sd, sd -> s', f, b)  # batch
    #     M = torch.einsum('sd, td -> st', f, b_other)  # batch x batch
    #     fb_diag = - M_diag.mean()
    #     fb_offdiag = 0.5 * (M - self.args.gamma * M_next).pow(2).mean() 
    #     fb_loss = fb_diag + fb_offdiag
    #     # compute orthonormality's regularisation loss
    #     b_b_other = torch.einsum('sd, xd -> sx', b, b_other)  # batch x batch
    #     b_b_other_detach = torch.einsum('sd, xd -> sx', b, b_other.detach())  # batch x batch
    #     b_b_detach = torch.einsum('sd, sd -> s', b, b.detach())  # batch
    #     reg_loss = (b_b_detach * b_b_other.detach()).mean() - b_b_other_detach.mean()
    #     fb_loss += self.args.reg_coef * reg_loss

    #     # update the forward_network
    #     self.fb_optim.zero_grad()
    #     fb_loss.backward()
    #     self.fb_optim.step()

    #     wandb.log({'train/fb_diag': fb_diag.item(), 'train/fb_offdiag': fb_offdiag.mean().item(), 'train/reg_loss': reg_loss.item(), 'train/frame': self.training_iters})
    #     # print(f'diag_Loss: {fb_diag.item()}, off_diag_loss: {fb_offdiag.mean().item()}, reg_loss: {reg_loss.item()}')

    #     # the backward loss
    #     # f = self.forward_network(obs_norm_tensor, actions_tensor, w)
    #     # b = self.backward_network(ag_norm_tensor)
    #     # b_other = self.backward_network(g_other_norm_tensor)
    #     # z_diag = torch.einsum('sd, sd -> s', f, b)  # batch
    #     # z = torch.einsum('sd, td -> st', f, b_other)  # batch x batch
    #     # b_loss = 0.5 * (z - self.args.gamma * z_next).pow(2).mean() - z_diag.mean()
    #     # compute orthonormality's regularisation loss
    #     # b_b_other = torch.einsum('sd, xd -> sx', b, b_other)  # batch x batch
    #     # b_b_other_detach = torch.einsum('sd, xd -> sx', b, b_other.detach())  # batch x batch
    #     # b_b_detach = torch.einsum('sd, sd -> s', b, b.detach())  # batch
    #     # reg_loss = (b_b_detach * b_b_other.detach()).mean() - b_b_other_detach.mean()
    #     # b_loss += self.args.reg_coef * reg_loss
    #     #
    #     # # update the backward_network
    #     # self.backward_optim.zero_grad()
    #     # b_loss.backward()
    #     # sync_grads(self.backward_network)
    #     # self.backward_optim.step()

    #     # print('f_loss: {}, b_loss: {}'.format(f_loss.item(), b_loss.item()))

    # def inference(self, g):
    #     g_tensor = torch.tensor(g, dtype=torch.float32)
    #     if self.args.cuda:
    #         g_tensor = g_tensor.cuda()

    #     # Initialize w and w optim

    #     self.w_inf = torch.randn((self.args.embed_dim), requires_grad=True, device='cuda')
    #     self.inf_opt = torch.optim.Adam([{'params': self.w_inf}], lr=self.args.lr)


    #     for _ in range(self.args.n_infer_steps):
    #         transitions = self.buffer.sample(self.args.batch_size)
    #         o = transitions['obs']
    #         obs_tensor = torch.tensor(o, dtype=torch.float32).cuda()

    #         w_inf = torch.nn.functional.normalize(self.w_inf.reshape(1, -1))
    #         w_inf_rep = w_inf.repeat(self.args.batch_size, 1)
            
    #         f = self.forward_network(obs_tensor, w_inf_rep)
    #         b = self.backward_network(g_tensor.repeat(self.args.batch_size, 1))
    #         # print(f.shape, b.shape)
    #         w_loss = -torch.einsum('sda, sd -> sa', f, b).mean()
    #         self.inf_opt.zero_grad()
    #         w_loss.backward()
    #         # print(self.w_inf.grad)
    #         self.inf_opt.step()
    #         # print(f'w_loss: {w_loss.item()}')
    #         wandb.log({'eval/w_loss': w_loss.item(), 'eval/frame': self.eval_iters})
    #         # self.eval_iters += 1
        
    #     return self.w_inf
    # do the evaluation
    def _eval_agent(self, max_steps=25):
        """Roll out the agent on test episodes and return averaged final results.

        Runs ``self.args.n_test_rollouts`` episodes. Each episode resets
        ``self.env``, calls ``self.infer_w`` on the episode's desired goal and
        then steps the environment with ``self.act(obs, g)`` for at most
        ``max_steps`` steps, stopping at the first step whose
        ``info['is_success']`` is positive. Only the last executed step of an
        episode contributes to the averages.

        Args:
            max_steps: Maximum number of environment steps per episode.
                Defaults to 25, the horizon that used to be hard-coded.

        Returns:
            A tuple ``(avg_success_rate, avg_dist)``: the mean over episodes of
            the final ``info['is_success']`` value and of the final
            ``goal_distance(achieved_goal, desired_goal)``.
        """
        total_success_rate = []
        total_dist = []
        for _ in range(self.args.n_test_rollouts):
            # Placeholders that are only kept if the episode runs zero steps.
            per_success_rate = []
            per_dist = []
            observation = self.env.reset()
            obs = observation['observation']
            g = observation['desired_goal']

            # The return value is not used in this method; the call is kept in
            # case infer_w updates agent state that act() reads -- TODO confirm.
            self.infer_w(g)

            for _ in range(max_steps):
                action = self.act(obs, g)
                observation_new, _, _, info = self.env.step(action)
                obs = observation_new['observation']
                new_g = observation_new['desired_goal']
                # Element-wise comparison: a signed sum of differences misses
                # changes that are negative or cancel out across components.
                if not np.array_equal(new_g, g):
                    print('Goal changed')
                    # Track the new goal so act() is conditioned on it and the
                    # change is reported (and re-inferred) only once.
                    g = new_g
                    self.infer_w(g)
                per_dist = goal_distance(observation_new['achieved_goal'],
                                         observation_new['desired_goal'])
                per_success_rate = info['is_success']
                if info['is_success'] > 0:
                    print('Success')
                    break
            total_success_rate.append(per_success_rate)
            total_dist.append(per_dist)
        avg_success_rate = np.mean(np.array(total_success_rate))
        avg_dist = np.mean(np.array(total_dist))
        return avg_success_rate, avg_dist

    # def _eval_gpi_agent(self, num_gpi=20):
    #     total_success_rate = []
    #     total_dist = []
    #     for _ in range(self.args.n_test_rollouts):
    #         per_success_rate = []
    #         per_dist = []
    #         observation = self.env.reset()
    #         obs = observation['observation']
    #         g = observation['desired_goal']
    #         if self.args.w_sampling == 'goal_oriented':
    #             transitions = self.buffer.sample(num_gpi)
    #             g_train = transitions['g']
    #             g_train_tensor = torch.tensor(g_train, dtype=torch.float32)
    #             if self.args.cuda:
    #                 g_train_tensor = g_train_tensor.cuda()
    #             w_train = self.backward_network(g_train_tensor)
    #         elif self.args.w_sampling == 'uniform_ball':
    #             w_train = self.sample_uniform_ball(num_gpi)
    #         elif self.args.w_sampling == 'cauchy_ball':
    #             w_train = self.sample_cauchy_ball(num_gpi)

    #         # for _ in range(self.env_params['max_timesteps']):
    #         for _ in range(25):
    #             with torch.no_grad():
    #                 g_tensor = self._preproc_g(g)
    #                 w = self.backward_network(g_tensor)
    #                 obs_tensor = self._preproc_o(obs)
    #                 action = self.act_gpi(obs_tensor, w_train, w).item()
    #             observation_new, _, _, info = self.env.step(action)
    #             obs = observation_new['observation']
    #             g = observation_new['desired_goal']
    #             dist = goal_distance(observation_new['achieved_goal'], observation_new['desired_goal'])
    #             # per_dist.append(dist)
    #             # per_success_rate.append(info['is_success'])
    #             per_dist = dist
    #             per_success_rate = info['is_success']
    #             if info['is_success'] > 0:
    #                 break
    #         total_success_rate.append(per_success_rate)
    #         total_dist.append(per_dist)
    #     total_success_rate = np.array(total_success_rate)
    #     avg_success_rate = np.mean(total_success_rate)
    #     total_dist = np.array(total_dist)
    #     avg_dist = np.mean(total_dist)
    #     return avg_success_rate, avg_dist
