# -- coding: utf-8 --
from cmath import inf


import os
from re import S
from time import sleep
from cv2 import exp, log
import torch
import torch.nn.functional as F
from torch.optim import Adam
from wandb import agent, wandb
#from SAC.utils import set_seed
from utils import soft_update, hard_update
from model import GaussianPolicy, QNetwork,QNetwork_discrete_Multihead,GaussianPolicy_discrete_Multihead,GaussianPolicy_continue,QNetwork_continue,LamadaNetwork,LamadaNetwork_hardconstraint,LamadaNetwork_hardconstraint_s,QNetwork_continue_Multihead,GaussianPolicy_continue_Multihead
import numpy as np
from mpi4py import MPI
from mpi_utils.mpi_utils import sync_networks, sync_grads

import torch.nn as nn

import torch.autograd as autograd
from pcgrad import PCGrad

def logsumexp(lamda,log_pi,q_out):
    """Return ``log(sum(lamda * exp(log_pi - q_out), dim=1) + 1e-9)``.

    The exponent ``log_pi - q_out`` is clamped to ``[-36, 36]`` before
    ``exp`` so the intermediate values stay finite, and ``1e-9`` is added
    inside the final ``log`` so an all-zero row does not produce ``-inf``.

    Args:
        lamda: non-negative weights, broadcastable against ``log_pi - q_out``.
        log_pi: tensor of log-probabilities.
        q_out: tensor of Q-values with a shape compatible with ``log_pi``.

    Returns:
        Tensor with dimension 1 reduced away.

    Raises:
        AssertionError: if any entry of ``lamda`` is negative (or NaN).
    """
    # Validate first and pass a real message: the previous
    # ``assert cond, print(...)`` form evaluated ``print`` (which returns None),
    # so the raised AssertionError carried no information.
    assert (lamda >= 0).all(), 'lamda is wrong {}'.format(lamda)
    exponent = torch.clamp(log_pi - q_out, min=-36.0, max=36.0)
    weighted = lamda * torch.exp(exponent)
    return torch.log(weighted.sum(dim=1) + 1e-9)

class SAC(object):
    def __init__(self, num_inputs, action_dim, args,name='Alice'):
        """Build the critic, target critic, policy and their optimizers.

        Args:
            num_inputs: passed through as the input size of ``QNetwork`` and
                ``GaussianPolicy``.
            action_dim: object exposing ``.n`` (number of discrete actions).
            args: configuration namespace; this constructor reads ``gamma``,
                ``tau``, ``alpha``, ``policy``, ``target_update_interval``,
                ``automatic_entropy_tuning``, ``cuda``, ``hidden_size`` and
                ``lr``.
            name: label for this agent; used in checkpoint paths.
        """
        self.name= name
        # Keep the full config on the instance so later code can consult
        # hyper-parameters such as ``args.lr`` after construction.
        self.args = args
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha
        self.action_dim = action_dim.n

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        self.critic = QNetwork(num_inputs, self.action_dim, args.hidden_size).to(device=self.device)
        # sync_networks(self.critic)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork(num_inputs, self.action_dim, args.hidden_size).to(self.device)
        # NOTE(review): the target critic is NOT copied from the online critic
        # here (the hard_update call below is disabled), so it starts from its
        # own random initialisation. Confirm this is intentional.
        # hard_update(self.critic_target, self.critic)

        # Target entropy is 18% of the maximum entropy log(action_dim) of a
        # uniform distribution over the discrete actions.
        self.target_entropy = 0.18 * -np.log(1 / self.action_dim)
        # log_alpha starts at 0, i.e. exp(log_alpha) == 1.
        self.log_alpha =torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

        self.policy = GaussianPolicy(num_inputs, self.action_dim, args.hidden_size).to(self.device)
        # sync_networks(self.policy)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

 

    def select_action(self, state, evaluate=False):
        """Pick a discrete action for a single observation.

        Args:
            state: one observation; it is flattened to a ``(1, -1)`` float tensor.
            evaluate: when exactly ``False`` an action is sampled from the
                policy's probabilities; otherwise the arg-max action is taken.

        Returns:
            The chosen action index.
        """
        obs = torch.FloatTensor(state).to(self.device).unsqueeze(0).view(1,-1)
        probs = self.policy.forward(obs).detach().cpu().numpy().squeeze(0)
        if evaluate is not False:
            # Greedy choice for evaluation.
            return np.argmax(probs)
        # Stochastic choice for exploration.
        return np.random.choice(range(self.action_dim),p=probs)

    def q_out(self,state):
        """Return the element-wise minimum of the two critic heads for one state.

        Args:
            state: one observation; a leading batch dimension of 1 is added.

        Returns:
            NumPy array of the clipped double-Q values with the batch
            dimension squeezed away.
        """
        obs = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            q1, q2 = self.critic.forward(obs)
            clipped_q = torch.min(q1, q2)
        return clipped_q.cpu().numpy().squeeze(0)

    def update_parameters(self, memory, batch_size, updates,inputTuple=False):
        """Run one critic step and one policy step of discrete SAC.

        Args:
            memory: replay buffer exposing ``sample(batch_size=...)``; or, when
                ``inputTuple`` is true, an already sampled
                ``(state, action, reward, next_state, mask)`` tuple.
            batch_size: number of transitions to sample (ignored when
                ``inputTuple`` is true).
            updates: running update counter; the target critic is soft-updated
                whenever ``updates % target_update_interval == 0``.
            inputTuple: treat ``memory`` as the batch itself.

        Returns:
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)`` as Python
            floats. ``alpha_loss`` is a placeholder equal to ``policy_loss``
            because entropy-temperature tuning is disabled in this method.
        """
        if inputTuple:
            batch = memory
        else:
            batch = memory.sample(batch_size=batch_size)
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = batch

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        # Long dtype because the actions are used as gather indices below.
        # assumes action_batch / reward_batch / mask_batch are (batch, 1) — TODO confirm
        action_batch = torch.LongTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device)

        # ---- Bellman target: r + mask * gamma * E_a'[min Q_target - alpha * log pi] ----
        with torch.no_grad():
            next_probs, next_log_probs = self.policy.get_action_info(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch)
            soft_q_next = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_log_probs
            # Exact expectation over the discrete action distribution.
            min_qf_next_target = (next_probs * soft_q_next).sum(dim=1).unsqueeze(-1)
            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)

        # ---- Critic update: both heads regress onto the same target ----
        qf1, qf2 = self.critic(state_batch)
        qf1, qf2 = qf1.gather(1,action_batch), qf2.gather(1,action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # ---- Policy update: minimise E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min Q(s,a)) ] ----
        pi, log_pi = self.policy.get_action_info(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch)
        # Detached so the policy loss does not push gradients into the critic.
        min_qf_pi = torch.min(qf1_pi, qf2_pi).detach()

        policy_loss = (pi*((self.alpha * log_pi) - min_qf_pi)).sum(dim=1).mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        # Entropy-temperature tuning is disabled: the former branch was guarded
        # by ``if not True`` and could never execute, so ``log_alpha`` is never
        # stepped here. ``alpha_loss`` is reported as ``policy_loss`` only to
        # keep the return signature stable.
        alpha_loss = policy_loss

        # NOTE(review): this replaces the ``args.alpha`` value set in __init__
        # with exp(log_alpha) after the first update — confirm this is intended.
        self.alpha = self.log_alpha.exp()
        alpha_tlogs = self.alpha.clone() # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_checkpoint(self, env_name, suffix="", ckpt_path=None):
        """Write networks, optimizers and temperature state to disk (MPI rank 0 only).

        Args:
            env_name: environment name, used in the default file name.
            suffix: run tag, used in both the default directory and file name.
            ckpt_path: explicit destination; when ``None`` the file goes to
                ``checkpoints/<suffix>/<self.name>/sac_checkpoint_<env_name>_<suffix>``.
        """
        # Only the root MPI process writes, so ranks do not race on one file.
        if MPI.COMM_WORLD.Get_rank() != 0:
            return

        default_dir = 'checkpoints/'+suffix+"/{}/".format(self.name)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(default_dir, exist_ok=True)
        if ckpt_path is None:
            ckpt_path = default_dir+"sac_checkpoint_{}_{}".format(env_name, suffix)
        else:
            # An explicit destination may live outside the default directory;
            # make sure its parent exists so torch.save does not fail.
            parent_dir = os.path.dirname(ckpt_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

        print('Saving models to {}'.format(ckpt_path))
        torch.save({'policy_state_dict': self.policy.state_dict(),
                    'critic_state_dict': self.critic.state_dict(),
                    'critic_target_state_dict': self.critic_target.state_dict(),
                    'critic_optimizer_state_dict': self.critic_optim.state_dict(),
                    'policy_optimizer_state_dict': self.policy_optim.state_dict(),
                    'alpha':self.alpha,
                    'alpha_log':self.log_alpha,
                    'alpha_opt':self.alpha_optim.state_dict()
                    }, ckpt_path)
        print(self.alpha)
    # Load model parameters
    def load_checkpoint(self, ckpt_path, evaluate=False):
        print('Loading models from {}'.format(ckpt_path))
        if ckpt_path is not None:
            checkpoint = torch.load(ckpt_path)
            self.policy.load_state_dict(checkpoint['policy_state_dict'])
            self.critic.load_state_dict(checkpoint['critic_state_dict'])
            self.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])
            self.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
            self.policy_optim.load_state_dict(checkpoint['policy_optimizer_state_dict'])
            #self.alpha=checkpoint['alpha']
            self.log_alpha = checkpoint['alpha_log']#self.alpha.log()
            self.alpha = checkpoint['alpha']
            #self.alpha_optim.load_state_dict(checkpoint['alpha_opt'])
            self.alpha_optim = Adam([self.log_alpha], lr=self.args.lr)
            print(self.alpha)

            if evaluate:
                self.policy.eval()
                self.critic.eval()
                self.critic_target.eval()
                #self.alpha.eval()
            else:
                self.policy.train()
                self.critic.train()
                self.critic_target.train()
                #self.alpha.train()
class MRF2_Fixedalpha_regV(object):
    def __init__(self, num_inputs, action_space, args):
        """Build the multi-head critics, worker/Alice policies and lambda network.

        Args:
            num_inputs: passed through as the input size of the networks.
            action_space: object exposing ``.n`` (number of discrete actions).
            args: configuration namespace; this constructor reads
                ``ensamble_num``, ``eta``, ``gamma``, ``tau``, ``agentnum``,
                ``policy``, ``target_update_interval``,
                ``automatic_entropy_tuning``, ``cuda``, ``hidden_size`` and
                ``lr``. It is also forwarded to ``LamadaNetwork_hardconstraint``.
        """
        # Keep the full config on the instance so later code can consult
        # hyper-parameters such as ``args.lr`` after construction.
        self.args = args
        self.ensamble_num=args.ensamble_num
        self.eta = args.eta
        if self.eta>0:
            print("EDAC loded")
        self.gamma = args.gamma
        self.tau = args.tau
        self.action_nums = action_space.n
        self.agentnum= args.agentnum

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        # critic output is  en * agent * batch * out (author's note)
        self.critic = QNetwork_discrete_Multihead(num_inputs, self.action_nums, args.hidden_size,self.agentnum,ensamble_num=self.ensamble_num).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork_discrete_Multihead(num_inputs, self.action_nums, args.hidden_size,self.agentnum,ensamble_num=self.ensamble_num).to(self.device)
        # Start the target critic as an exact copy of the online critic.
        hard_update(self.critic_target, self.critic)

        # Lambda (weighting) network; its optimizer is wrapped in PCGrad so
        # update_parameters can call pc_backward on a list of losses.
        self.lamadacalculator=LamadaNetwork_hardconstraint(self.agentnum,self.agentnum-1,args,args.hidden_size).to(self.device)
        self.lamada_optim =  PCGrad(Adam(self.lamadacalculator.parameters(),lr=args.lr))

        # workers policy: agentnum-1 heads (author's note: agentnum-1 * batch * action_nums)
        self.workers_policy = GaussianPolicy_discrete_Multihead(num_inputs,self.action_nums,args.hidden_size,self.agentnum-1).to(self.device)
        self.workers_policy_optim=Adam(self.workers_policy.parameters(),lr=args.lr)

        # Alice policy: a single head
        self.policy = GaussianPolicy_discrete_Multihead(num_inputs,self.action_nums,args.hidden_size, 1).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        # Temperature state is created unconditionally. Previously it was only
        # built when ``automatic_entropy_tuning is True`` while ``self.alpha``
        # (below), ``update_parameters`` and ``save_checkpoint`` read
        # ``log_alpha`` / ``alpha_optim`` regardless, which raised
        # AttributeError whenever the flag was off.
        # Target entropy is half the maximum entropy log(action_nums).
        self.target_entropy = 0.5 * -np.log(1 / self.action_nums)
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

        # exp(0) == 1 initially.
        self.alpha = self.log_alpha.exp()


    def select_action(self, state, evaluate=False):
        """Choose Alice's action for a single observation.

        Args:
            state: one observation; it is flattened to a ``(1, -1)`` float tensor.
            evaluate: when exactly ``False`` the first value returned by
                ``self.policy.sample`` is used, otherwise the third one
                (presumably the deterministic action — confirm against the
                policy class).

        Returns:
            Element ``[0][0][0]`` of the selected action array.
        """
        obs = torch.FloatTensor(state).to(self.device).unsqueeze(0).view(1,-1)
        first, _, third = self.policy.sample(obs)
        chosen = first if evaluate is False else third
        return chosen.detach().cpu().numpy()[0][0][0]

    def select_action_workers(self, state,idx, evaluate=False):
        """Choose the action of worker head ``idx`` for a single observation.

        Args:
            state: one observation; it is flattened to a ``(1, -1)`` float tensor.
            idx: index into the first axis of the sampled action array.
            evaluate: when exactly ``False`` the first value returned by
                ``self.workers_policy.sample`` is used, otherwise the third one
                (presumably the deterministic action — confirm against the
                policy class).

        Returns:
            Element ``[idx][0][0]`` of the selected action array.
        """
        obs = torch.FloatTensor(state).to(self.device).unsqueeze(0).view(1,-1)
        first, _, third = self.workers_policy.sample(obs)
        chosen = first if evaluate is False else third
        return chosen.detach().cpu().numpy()[idx][0][0]


    def update_parameters(self, memory, batch_size, updates,inputTuple=False,index=0):
        """Run one training step: critic, workers' policy, lambda network, Alice's policy.

        Steps, in order:
          1. critic regression onto a soft Bellman target (plus an optional
             EDAC-style gradient-diversity penalty when ``self.eta > 0``);
          2. workers' policy update against the minimum over the critic ensemble;
          3. lambda-network update through ``self.lamada_optim.pc_backward``;
          4. Alice's policy update using the (no-grad) lambda weights;
          5. soft update of the target critic every ``target_update_interval`` calls.

        Args:
            memory: replay buffer exposing ``sample(batch_size=...)``; or, when
                ``inputTuple`` is true, an already sampled
                ``(state, action, reward, next_state, mask)`` tuple.
            batch_size: number of transitions to sample. It is also used to
                size ``grad_outputs`` below, so it must match the actual batch
                even when ``inputTuple`` is true.
            updates: running update counter (target-update schedule and wandb logs).
            inputTuple: treat ``memory`` as the batch itself.
            index: not read anywhere in this method.

        Returns:
            7-tuple ``(qf_loss, grad_loss, workers_policy_loss, policy_loss,
            alpha_loss, alpha, lamda_loss1)``. When ``self.eta <= 0`` the second
            entry is ``qf_loss`` again. ``alpha_loss`` is a placeholder equal to
            ``policy_loss``; ``alpha`` is returned as a NumPy array, the rest as
            Python floats.
        """
        # Sample a batch from memory
        if not inputTuple:
            state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
        else:
            state_batch, action_batch, reward_batch, next_state_batch, mask_batch =memory
       

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.LongTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device)#  512 * n
        mask_batch = torch.FloatTensor(mask_batch).to(self.device)#.unsqueeze(1)
        if self.eta > 0:
            # NOTE(review): action_batch is an integer (Long) tensor; PyTorch only
            # lets floating-point/complex tensors require grad, so this call
            # presumably raises whenever eta > 0 — confirm before enabling EDAC.
            action_batch.requires_grad_(True)

        
        with torch.no_grad():
            # Target: r + mask * gamma * V(s')
            prob_workers_actions,log_prob_workers_actions = self.workers_policy.actions_logprob(next_state_batch)
            # NOTE(review): Alice's policy is evaluated on state_batch while the
            # workers and the target critic use next_state_batch — confirm
            # whether next_state_batch was intended here.
            prob_alice_action,log_prob_alice_action = self.policy.actions_logprob(state_batch)
            log_prob_all = torch.cat([log_prob_workers_actions,log_prob_alice_action],dim=0)# agentnum * batch * actionnum
            # prob_all is not used further down in this method.
            prob_all = torch.cat([prob_workers_actions,prob_alice_action],dim=0)

            # Minimum over the ensemble axis (dim 0) of the target critic.
            q_min = self.critic_target(next_state_batch).min(dim=0)[0] #  agentnum* batch *actionnums
            
            # Expectation under Alice's action probabilities of (Q - alpha * log pi_i).
            v_nextstatebatch = (torch.sum(q_min*prob_alice_action.repeat(self.agentnum,1,1),dim=-1).transpose(0,1)) -\
                 self.alpha*(torch.sum(log_prob_all*prob_alice_action.repeat(self.agentnum,1,1),dim=-1).transpose(0,1))
        
            next_q_value = reward_batch+ mask_batch*self.gamma*v_nextstatebatch # batch*n

        # Q-values of the taken actions for every ensemble member and agent head.
        qs = self.critic(state_batch).gather(-1,action_batch.unsqueeze(0).repeat(self.agentnum,1,1).unsqueeze(0).repeat(self.ensamble_num,1,1,1))#en*agentnum*batch*1
        qs = qs.squeeze(-1).transpose(-2,-1)#en*batch*agentnum
        # The same target is broadcast to every ensemble member.
        qf_loss = F.mse_loss(qs,next_q_value.unsqueeze(0).repeat(self.ensamble_num,1,1))
        
        # EDAC here; the author was unsure whether it applies to the discrete-action case.
        if self.eta>0:

            obs = state_batch.unsqueeze(0).repeat(self.ensamble_num,1,1).unsqueeze(0).repeat(self.agentnum,1,1,1)# 4 2 256 n
            ac  =  action_batch.unsqueeze(0).repeat(self.ensamble_num,1,1).unsqueeze(0).repeat(self.agentnum,1,1,1)# 4 2 256 1
            qs_p = self.critic(obs)#2 4 4 2 256 n #,ac)
            qs_p = qs_p.gather(-1,ac.unsqueeze(0).repeat(self.agentnum,1,1).unsqueeze(0).repeat(self.ensamble_num,1,1,1))#.dtype(int))#2 4 4 2 256 1
            qs_p=(qs_p.squeeze(-1).transpose(0,-1).transpose(-1,2).transpose(-2,1)*(torch.eye(self.agentnum).to(self.device))).sum(-1).transpose(-1,-3)
            # qs_p.backward(torch.eye(self.ensamble_num))
            # qs_pred_gradss , = torch.autograd.grad(qs_p, ac, retain_graph=True, create_graph=True,grad_outputs=torch.eye(self.ensamble_num).unsqueeze(1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1,self.agentnum,1,self.agentnum,batch_size,1).to(self.device))
            # Gradient of the Q-values w.r.t. the (tiled) actions; create_graph=True
            # so the penalty itself can be differentiated in qf_loss.backward().
            qs_pred_grads, = torch.autograd.grad(qs_p, ac, retain_graph=True, create_graph=True,\
                grad_outputs=torch.eye(self.ensamble_num).unsqueeze(0).unsqueeze(0).repeat(batch_size,self.agentnum,1,1).to(self.device))

            # Drop the last entry along the first axis.
            qs_pred_grads = qs_pred_grads[:-1,::]# 3 2 256 1
            # qs_pred_grads, = torch.autograd.grad(qs_preds_tile.sum(), actions_tile, retain_graph=True, create_graph=True)
            #2,4,2,256,3
            # Normalise each gradient vector to unit L2 norm (1e-10 avoids division by zero).
            qs_pred_grads = qs_pred_grads / (torch.norm(qs_pred_grads, p=2, dim=-1).unsqueeze(-1) + 1e-10)
            #3,2,256,3
            qs_pred_grads = qs_pred_grads.transpose(1, 2)
            #256,4,2,3
            # Pairwise inner products between the normalised gradients.
            qs_pred_grads = torch.einsum('abik,abjk->abij', qs_pred_grads, qs_pred_grads)
            # Zero the diagonal so only cross-member similarity is penalised.
            masks = torch.eye(self.ensamble_num, device=self.device).unsqueeze(dim=0).repeat(qs_pred_grads.size(1),1,1).unsqueeze(dim=0).repeat(qs_pred_grads.size(0),1, 1, 1)
            qs_pred_grads = (1 - masks) * qs_pred_grads
            grad_loss = torch.mean(torch.sum(qs_pred_grads, dim=(-2, -1))) / (self.ensamble_num - 1)
            qf_loss += self.eta * grad_loss

        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()
       


        # update workers' policy -- discrete
    
        pis,log_pis=self.workers_policy.actions_logprob(state_batch)#n-1*512*ac_dim n-1*512*ac_dim
        # Keep only the worker heads (all but the last agent index).
        q = self.critic(state_batch)[::,:-1,::]# 2*4 * 256* ac_dim#.unsqueeze(0).repeat(pis.shape[0],1,1),pis)
        qmin = q.min(dim=0)[0]#n-1*256*ac_dim
      
        # alpha * sum_a pi*log(pi) - sum_a pi*Q, averaged over batch and workers.
        workers_policy_loss = (((pis*log_pis).sum(-1).transpose(0,1))*self.alpha - (pis*qmin).sum(-1).transpose(0,1)).mean()
        self.workers_policy_optim.zero_grad()
        # NOTE(review): retain_graph=True keeps this graph alive; nothing below
        # appears to backpropagate through it again — confirm whether it is needed.
        workers_policy_loss.backward(retain_graph=True)
        self.workers_policy_optim.step()

        

        # Lambda-network update: policies and critic are treated as constants here.
        with torch.no_grad():
            pi, log_pi = self.policy.actions_logprob(state_batch)#get the Alice's prob of action
            workers_pi, log_workers_pi = self.workers_policy.actions_logprob(state_batch)# n-1 *batch*ac_num
            q_out = self.critic(state_batch).min(dim=0)[0]# agentnum *batch* ac_num

        q_out = q_out.transpose(0,-1)# ac_num *batch* agentnum
        
        # The lambda network sees one row of per-agent Q-values per (action, sample) pair.
        weights_out=self.lamadacalculator(q_out.reshape(q_out.shape[0]*q_out.shape[1],-1)).view(*q_out.shape[:-1],-1)#ac_num *batch* agentnum
        # Magnitude acts as the weight; the sign (+1 / -1) is applied to the worker Q-values.
        lamda_out = torch.abs(weights_out)
        flag = (weights_out>0).float()*2-1.0#ac_num *batch* agentnum

        q_rstruct = 1*(flag)*q_out[::,::,:-1]*(torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9)# ac_num*512*n-1
        # Normalise the weights so they sum to (approximately) one over the last axis.
        lamda_out = torch.div(lamda_out,torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9) # ac_num*512*n-1

        temp_pi = pi.repeat(self.agentnum-1,1,1)# n-1*batch*ac_num
        temp_log_pi = log_pi.repeat(self.agentnum-1,1,1)
        # KL(Alice || worker_i), summed over actions.
        D_kl = (temp_pi*(temp_log_pi-log_workers_pi)).sum(-1).transpose(0,1)# *batch*n-1

        
        q_r = torch.clamp(torch.div(q_rstruct,self.alpha+1e-30),min=-1e36,max=1e36)#    
        # Clamp the exponent so exp() below stays finite.
        t = torch.clamp(log_workers_pi.transpose(-1,0)-q_r,min=-36.0,max=36.0)#ac_num*batch*n-1
        t = lamda_out*torch.exp(t)#a*log_pi - q
        t = torch.log(t.sum(dim=-1)+1e-9).transpose(0,1)#batch*ac_num

        # Second term: lambda weights gathered at Alice's arg-max action, times the KL terms.
        lamda_loss1= (pi.squeeze(0)* t).sum(-1).mean() + \
              1.0*(lamda_out.transpose(0,1).gather(dim=1, index=torch.max(pi,dim=-1,keepdim=True)[1].squeeze(0).unsqueeze(-1).repeat(1,1,self.agentnum-1)).squeeze(1) * D_kl).sum(-1).mean()
        
        #reg the direction of V

        if True:

            # From here on state_batch requires grad (needed for the input-gradient below).
            state_batch.requires_grad_(True)

            obs = state_batch.unsqueeze(0).repeat(self.agentnum,1,1)# 4 256 n
            ac  =  action_batch.unsqueeze(0).repeat(self.agentnum,1,1)# 4*256 1
            qs_p = self.critic(obs)#2 4 4 256 n #,ac)
            qs_p = qs_p.gather(-1,ac.unsqueeze(0).repeat(self.agentnum,1,1,1).unsqueeze(0).repeat(self.ensamble_num,1,1,1,1))#2 4 4 256 1
            #-> 256 4 4 
            # Gradient of the ensemble-summed Q-values w.r.t. the tiled observations.
            qs_pred_grads, = torch.autograd.grad(qs_p.sum(dim=0).squeeze(0).squeeze(-1).transpose(0,-1), obs, retain_graph=True, create_graph=True,\
                grad_outputs=torch.eye(self.agentnum).unsqueeze(0).repeat(batch_size,1,1).to(self.device))
            # 4 256 n
            grads_workers = qs_pred_grads[:-1,::].transpose(0,1)# 256*3 *n
            grads_alice = qs_pred_grads[-1:,::].repeat(self.agentnum-1,1,1).transpose(0,1)# 256,3, n
            # Inner product between each worker's gradient and Alice's gradient.
            simliartys =  (grads_workers*grads_alice).sum(-1)#torch.einsum('bik,bjk->bij', qs_pred_grads, qs_pred_grads)
            # 1.0 where a worker's gradient points against Alice's (negative inner product).
            ignore_flag = (simliartys<0).float()
            lamda_loss2 = torch.norm(ignore_flag*weights_out,p=2)

            
            # Function-level import; logging below presumably needs an active
            # wandb run (wandb.init) — confirm with the training script.
            import wandb
            if  True:
                for i in range(self.agentnum-1):
                    wandb.log({'Simlarity_{}'.format(i):ignore_flag.mean(0)[i],'update':updates})
                wandb.log({'RegLoss': lamda_loss2,'update':updates})

        # print(torch.norm(ignore_flag*weights_out,p=2))
        # NOTE(review): lamda_loss2 is computed and logged above but is not part
        # of this list; the same loss is passed twice to pc_backward. Confirm
        # whether [lamda_loss1, lamda_loss2] was intended.
        lamda_loss = [lamda_loss1,lamda_loss1]
        #lamda_loss = [lamda_loss1,lamda_loss1]

        #pcGrad
        self.lamada_optim.zero_grad()
        #lamda_loss.backward(retain_graph=True)
        self.lamada_optim.pc_backward(lamda_loss)
        #lamda_loss.pc_backward(retain_graph=True)
        self.lamada_optim.step()

        #Alice's policy update
        pi, log_pi = self.policy.actions_logprob(state_batch)#get the Alice's prob of action

        # Same quantities as in the lambda step, recomputed without gradients
        # using the freshly updated lambda network.
        with torch.no_grad():
            workers_pi, log_workers_pi = self.workers_policy.actions_logprob(state_batch)# n-1 *batch*ac_num
            q_out = self.critic(state_batch).min(dim=0)[0]# agentnum *batch* ac_num

            q_out = q_out.transpose(0,-1)# ac_num *batch* agentnum
            weights_out=self.lamadacalculator(q_out.reshape(q_out.shape[0]*q_out.shape[1],-1)).view(*q_out.shape[:-1],-1)#)#ac_num *batch* agentnum
            lamda_out = torch.abs(weights_out)
            flag = (weights_out>0).float()*2-1.0#ac_num *batch* agentnum
            q_rstruct = 1*(flag)*q_out[::,::,:-1]*(torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9)# ac_num*512*n-1
            lamda_out = torch.div(lamda_out,torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9) # ac_num*512*n-1
            q_r = torch.clamp(torch.div(q_rstruct,self.alpha+1e-30),min=-1e36,max=1e36)#    
            t = torch.clamp(log_workers_pi.transpose(-1,0)-q_r,min=-36.0,max=36.0)#ac_num*batch*n-1
            t = lamda_out*torch.exp(t)#a*log_pi - q
            t = torch.log(t.sum(dim=-1)+1e-9).transpose(0,1)#batch*ac_num

        temp_pi = pi.repeat(self.agentnum-1,1,1)# n-1*batch*ac_num
        temp_log_pi = log_pi.repeat(self.agentnum-1,1,1)
        D_kl = (temp_pi*(temp_log_pi-log_workers_pi)).sum(-1).transpose(0,1)# *batch*n-1
        
        # Same form as lamda_loss1, but here gradients flow only through pi / log_pi.
        policy_loss= 1.0* (pi.squeeze(0)* t).sum(-1).mean() +   \
            (lamda_out.transpose(0,1).gather(dim=1, index=torch.max(pi,dim=-1,keepdim=True)[1].squeeze(0).unsqueeze(-1).repeat(1,1,self.agentnum-1)).squeeze(1) * D_kl).sum(-1).mean()
        
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()  
        # policy_loss = torch.tensor([0])
        # lamda_loss = policy_loss
        # Entropy-temperature tuning is switched off: this branch can never run.
        if not True:
            with torch.no_grad():
                workers_pi,workers_log_pi=self.workers_policy.actions_logprob(state_batch)# n*512*1
                Alice_pi,Alice_log_pi=  self.policy.actions_logprob(state_batch)
                log_pis = torch.cat([workers_log_pi,Alice_log_pi],dim=0)
                pis = torch.cat([workers_pi,Alice_pi],dim=0)#n*batch*ac_num

            alpha_loss = self.log_alpha*(((-1*Alice_pi*Alice_log_pi).sum(-1).transpose(0,1))-self.target_entropy).detach().mean()
            # (self.log_alpha[:-1]*((-1*Alice_pi.repeat(self.agentnum-1,1,1)*workers_log_pi).sum(-1).transpose(0,1)-self.target_entropy).detach()).mean()+\
            #     F.mse_loss(self.log_alpha[:-1].sum(-1),self.log_alpha[-1])
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
        # print(alpha_loss.item())
        # Placeholder so the return tuple keeps its shape while tuning is disabled.
        alpha_loss = policy_loss
        # log_alpha is never stepped in this method, so alpha stays at exp(log_alpha).
        self.alpha = self.log_alpha.exp()
        # self.alpha[-1]=self.alpha[:-1].sum(-1)
        alpha_tlogs = self.alpha.clone() # For TensorboardX logs
        # print(self.alpha)




        # Soft (tau-weighted) update of the target critic on the configured schedule.
        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)
        if self.eta>0:
            return qf_loss.item(), grad_loss.item(),workers_policy_loss.item(), policy_loss.item(), alpha_loss.item(),alpha_tlogs.cpu().detach().numpy(), lamda_loss1.item()
        else:
            # grad_loss does not exist without EDAC; qf_loss is returned in its slot.
            return qf_loss.item(), qf_loss.item(),workers_policy_loss.item(), policy_loss.item(), alpha_loss.item(),alpha_tlogs.cpu().detach().numpy(), lamda_loss1.item()

    # Save model parameters
    def save_checkpoint(self, env_name, suffix="", ckpt_path=None):
        """Write all networks, optimizers and temperature state to disk.

        Args:
            env_name: environment name, used in the default file name.
            suffix: run tag, used in both the default directory and file name.
            ckpt_path: explicit destination; when ``None`` the file goes to
                ``checkpoints/<suffix>/sac_checkpoint_<env_name>_<suffix>``.

        The checkpoint also carries ``'lamada_state_dict'`` (weights of
        ``self.lamadacalculator``); earlier checkpoints lack that key.
        """
        default_dir = 'checkpoints/'+suffix
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(default_dir, exist_ok=True)
        if ckpt_path is None:
            ckpt_path = "checkpoints/"+suffix+"/"+"sac_checkpoint_{}_{}".format(env_name, suffix)
        else:
            # An explicit destination may live outside the default directory;
            # make sure its parent exists so torch.save does not fail.
            parent_dir = os.path.dirname(ckpt_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

        print('Saving models to {}'.format(ckpt_path))
        torch.save({'policy_state_dict': self.policy.state_dict(),
                    'workers_policy_state_dict':self.workers_policy.state_dict(),
                    'critic_state_dict': self.critic.state_dict(),
                    'critic_target_state_dict': self.critic_target.state_dict(),
                    # The lambda network is trained in update_parameters and feeds
                    # Alice's policy loss, so it must be persisted with the rest.
                    'lamada_state_dict': self.lamadacalculator.state_dict(),
                    'critic_optimizer_state_dict': self.critic_optim.state_dict(),
                    'policy_optimizer_state_dict': self.policy_optim.state_dict(),
                    'workers_policy_optimizer_state_dict':self.workers_policy_optim.state_dict(),
                    'alpha':self.alpha,
                    'alpha_log':self.log_alpha,
                    'alpha_opt':self.alpha_optim.state_dict()
                    }, ckpt_path)
        print(self.alpha)

    # Load model parameters
    def load_checkpoint(self, ckpt_path, evaluate=False):
        """Restore networks, optimizers and temperature state from a checkpoint.

        Args:
            ckpt_path: file written by ``save_checkpoint``; ``None`` is a no-op.
            evaluate: when true every network is put in ``eval()`` mode,
                otherwise in ``train()`` mode.
        """
        print('Loading models from {}'.format(ckpt_path))
        if ckpt_path is None:
            return

        # map_location lets a checkpoint saved on GPU be opened on a CPU-only host.
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])
        self.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
        self.policy_optim.load_state_dict(checkpoint['policy_optimizer_state_dict'])
        self.workers_policy.load_state_dict(checkpoint['workers_policy_state_dict'])
        self.workers_policy_optim.load_state_dict(checkpoint['workers_policy_optimizer_state_dict'])
        # Optional key: checkpoints may not contain the lambda network weights.
        if 'lamada_state_dict' in checkpoint:
            self.lamadacalculator.load_state_dict(checkpoint['lamada_state_dict'])

        # Re-create log_alpha as a leaf tensor on this agent's device so a fresh
        # optimizer can own it.
        self.log_alpha = checkpoint['alpha_log'].detach().to(self.device).requires_grad_(True)
        self.alpha = checkpoint['alpha']
        # The learning rate is taken from the optimizer being replaced because
        # this class does not reliably carry an ``args`` attribute.
        alpha_lr = self.alpha_optim.param_groups[0]['lr']
        self.alpha_optim = Adam([self.log_alpha], lr=alpha_lr)

        # Switch every network consistently; previously workers_policy (and the
        # lambda network) were left in whatever mode they happened to be in.
        networks = (self.policy, self.workers_policy, self.critic,
                    self.critic_target, self.lamadacalculator)
        for net in networks:
            net.train(not evaluate)

class MRF2_Fixedalpha_noregV(object):
    def __init__(self, num_inputs, action_space, args):
        """Build the critic ensemble, the lambda network and both policies.

        Args:
            num_inputs: Size of the observation vector fed to every network.
            action_space: Discrete action space; only its ``n`` attribute is read.
            args: Hyper-parameter namespace. The fields read here are
                ``ensamble_num``, ``eta``, ``gamma``, ``tau``, ``agentnum``,
                ``policy``, ``target_update_interval``,
                ``automatic_entropy_tuning``, ``cuda``, ``hidden_size``, ``lr``
                and, when entropy tuning is off, ``alpha``.
        """
        # Kept on the instance because load_checkpoint() reads ``self.args.lr``
        # when it rebuilds the alpha optimizer.
        self.args = args

        self.ensamble_num = args.ensamble_num
        self.eta = args.eta
        if self.eta > 0:
            print("EDAC loded")
        self.gamma = args.gamma
        self.tau = args.tau
        self.action_nums = action_space.n
        self.agentnum = args.agentnum

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        # Critic output is ensemble * agent * batch * action.
        self.critic = QNetwork_discrete_Multihead(
            num_inputs, self.action_nums, args.hidden_size, self.agentnum,
            ensamble_num=self.ensamble_num).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork_discrete_Multihead(
            num_inputs, self.action_nums, args.hidden_size, self.agentnum,
            ensamble_num=self.ensamble_num).to(self.device)
        hard_update(self.critic_target, self.critic)

        self.lamadacalculator = LamadaNetwork_hardconstraint(
            self.agentnum, self.agentnum - 1, args, args.hidden_size).to(self.device)
        self.lamada_optim = Adam(self.lamadacalculator.parameters(), lr=args.lr)

        # Workers' policy has agentnum-1 heads: (agentnum-1) * batch * action_nums.
        self.workers_policy = GaussianPolicy_discrete_Multihead(
            num_inputs, self.action_nums, args.hidden_size, self.agentnum - 1).to(self.device)
        self.workers_policy_optim = Adam(self.workers_policy.parameters(), lr=args.lr)

        # Alice's policy is a single head.
        self.policy = GaussianPolicy_discrete_Multihead(
            num_inputs, self.action_nums, args.hidden_size, 1).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        # Entropy temperature. ``log_alpha`` and ``alpha_optim`` are needed by
        # update_parameters() and by the checkpoint methods regardless of the
        # tuning flag, so they are created on both paths; previously the
        # non-tuning path crashed on the ``self.log_alpha.exp()`` below.
        self.target_entropy = 0.5 * -np.log(1 / self.action_nums)
        if self.automatic_entropy_tuning is True:
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        else:
            # NOTE(review): assumes ``args.alpha`` is the intended fixed
            # temperature when tuning is off; falls back to 1.0 (log_alpha = 0,
            # same as the tuning path) if the field is absent - confirm.
            fixed_alpha = float(getattr(args, 'alpha', 1.0))
            self.log_alpha = torch.tensor(
                [float(np.log(fixed_alpha))], requires_grad=True, device=self.device)
        self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

        self.alpha = self.log_alpha.exp()


    def select_action(self, state, evaluate=False):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if evaluate is False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0][0][0]

    def select_action_workers(self, state,idx, evaluate=False):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if evaluate is False:
            action, _, _ = self.workers_policy.sample(state)
        else:
            _, _, action = self.workers_policy.sample(state)
        return action.detach().cpu().numpy()[idx][0][0]


    def update_parameters(self, memory, batch_size, updates,inputTuple=False,index=0):
        """Run one optimisation step for the critic, the workers' policy, the lambda network and Alice's policy.

        Args:
            memory: Replay buffer exposing ``sample(batch_size=...)``; when
                ``inputTuple`` is true it is instead an already sampled
                ``(state, action, reward, next_state, mask)`` tuple.
            batch_size: Number of transitions to sample. It is also used to
                size the ``grad_outputs`` tensors passed to
                ``torch.autograd.grad``.
            updates: Update counter. It gates the target-network soft update
                and is attached to the wandb log entries.
            inputTuple: Selects how ``memory`` is interpreted (see above).
            index: Not used inside this method.

        Returns:
            A 7-tuple of: critic loss; the gradient-penalty loss when
            ``self.eta > 0`` and the critic loss again otherwise; workers'
            policy loss; Alice's policy loss; the same policy loss once more
            (it fills the alpha-loss slot because the alpha update below is
            disabled); the current alpha as a numpy array; ``lamda_loss1``.
        """
        # Sample a batch from memory
        if not inputTuple:
            state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
        else:
            state_batch, action_batch, reward_batch, next_state_batch, mask_batch =memory
       

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.LongTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device)#  512 * n
        mask_batch = torch.FloatTensor(mask_batch).to(self.device)#.unsqueeze(1)
        # NOTE(review): action_batch is a LongTensor at this point; PyTorch only
        # lets floating-point/complex tensors require grad, so this call
        # presumably raises when eta > 0 - confirm before enabling EDAC.
        if self.eta > 0:
            action_batch.requires_grad_(True)

        
        # ---- Critic target: reward + discounted soft value, computed without gradients ----
        with torch.no_grad():
            #r+V(s)
            prob_workers_actions,log_prob_workers_actions = self.workers_policy.actions_logprob(next_state_batch)
            # NOTE(review): Alice's probabilities are evaluated on state_batch while
            # the workers' use next_state_batch - confirm this is intended for a
            # next-state value.
            prob_alice_action,log_prob_alice_action = self.policy.actions_logprob(state_batch)
            log_prob_all = torch.cat([log_prob_workers_actions,log_prob_alice_action],dim=0)# agentnum * batch * actionnum
            prob_all = torch.cat([prob_workers_actions,prob_alice_action],dim=0)

            # Element-wise minimum over the leading axis of the target critic output.
            q_min = self.critic_target(next_state_batch).min(dim=0)[0] #  agentnum* batch *actionnums
            
            # Both expectations are weighted by Alice's action probabilities,
            # repeated once per agent.
            v_nextstatebatch = (torch.sum(q_min*prob_alice_action.repeat(self.agentnum,1,1),dim=-1).transpose(0,1)) -\
                 self.alpha*(torch.sum(log_prob_all*prob_alice_action.repeat(self.agentnum,1,1),dim=-1).transpose(0,1))
        
            next_q_value = reward_batch+ mask_batch*self.gamma*v_nextstatebatch # batch*n

        # Q-values of the taken actions; the same target is repeated for every ensemble member.
        qs = self.critic(state_batch).gather(-1,action_batch.unsqueeze(0).repeat(self.agentnum,1,1).unsqueeze(0).repeat(self.ensamble_num,1,1,1))#en*agentnum*batch*1#, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
        qs = qs.squeeze(-1).transpose(-2,-1)#en*batch*agentnum
        qf_loss = F.mse_loss(qs,next_q_value.unsqueeze(0).repeat(self.ensamble_num,1,1))
        
        # EDAC here; not sure whether this applies to the discrete case
        # Adds eta * grad_loss to the critic loss, where grad_loss is built from
        # the off-diagonal (cross-ensemble) inner products of normalised gradients.
        if self.eta>0:

            obs = state_batch.unsqueeze(0).repeat(self.ensamble_num,1,1).unsqueeze(0).repeat(self.agentnum,1,1,1)# 4 2 256 n
            ac  =  action_batch.unsqueeze(0).repeat(self.ensamble_num,1,1).unsqueeze(0).repeat(self.agentnum,1,1,1)# 4 2 256 1
            qs_p = self.critic(obs)#2 4 4 2 256 n #,ac)
            qs_p = qs_p.gather(-1,ac.unsqueeze(0).repeat(self.agentnum,1,1).unsqueeze(0).repeat(self.ensamble_num,1,1,1))#.dtype(int))#2 4 4 2 256 1
            qs_p=(qs_p.squeeze(-1).transpose(0,-1).transpose(-1,2).transpose(-2,1)*(torch.eye(self.agentnum).to(self.device))).sum(-1).transpose(-1,-3)
            # qs_p.backward(torch.eye(self.ensamble_num))
            # qs_pred_gradss , = torch.autograd.grad(qs_p, ac, retain_graph=True, create_graph=True,grad_outputs=torch.eye(self.ensamble_num).unsqueeze(1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1,self.agentnum,1,self.agentnum,batch_size,1).to(self.device))
            qs_pred_grads, = torch.autograd.grad(qs_p, ac, retain_graph=True, create_graph=True,\
                grad_outputs=torch.eye(self.ensamble_num).unsqueeze(0).unsqueeze(0).repeat(batch_size,self.agentnum,1,1).to(self.device))

            # Drop the last entry along the first axis.
            qs_pred_grads = qs_pred_grads[:-1,::]# 3 2 256 1
            # qs_pred_grads, = torch.autograd.grad(qs_preds_tile.sum(), actions_tile, retain_graph=True, create_graph=True)
            #2,4,2,256,3
            # L2-normalise along the last axis; the epsilon guards against division by zero.
            qs_pred_grads = qs_pred_grads / (torch.norm(qs_pred_grads, p=2, dim=-1).unsqueeze(-1) + 1e-10)
            #3,2,256,3
            qs_pred_grads = qs_pred_grads.transpose(1, 2)
            #256,4,2,3
            qs_pred_grads = torch.einsum('abik,abjk->abij', qs_pred_grads, qs_pred_grads)
            # Identity mask removes the diagonal (self inner-product) terms.
            masks = torch.eye(self.ensamble_num, device=self.device).unsqueeze(dim=0).repeat(qs_pred_grads.size(1),1,1).unsqueeze(dim=0).repeat(qs_pred_grads.size(0),1, 1, 1)
            qs_pred_grads = (1 - masks) * qs_pred_grads
            grad_loss = torch.mean(torch.sum(qs_pred_grads, dim=(-2, -1))) / (self.ensamble_num - 1)
            qf_loss += self.eta * grad_loss

        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()
       


        # update workers' policy --dicrete
    
        pis,log_pis=self.workers_policy.actions_logprob(state_batch)#n-1*512*ac_dim n-1*512*ac_dim
        # All critic heads except the last one, which is the one paired with Alice below.
        q = self.critic(state_batch)[::,:-1,::]# 2*4 * 256* ac_dim#.unsqueeze(0).repeat(pis.shape[0],1,1),pis)
        qmin = q.min(dim=0)[0]#n-1*256*ac_dim
      
        # alpha * sum(pi * log pi) - sum(pi * Q), averaged over batch and workers.
        workers_policy_loss = (((pis*log_pis).sum(-1).transpose(0,1))*self.alpha - (pis*qmin).sum(-1).transpose(0,1)).mean()
        self.workers_policy_optim.zero_grad()
        workers_policy_loss.backward(retain_graph=True)
        self.workers_policy_optim.step()

        

        #here change
        # ---- Lambda network update: policies and critic are treated as constants ----
        with torch.no_grad():
            pi, log_pi = self.policy.actions_logprob(state_batch)#get the Alice's prob of action
            workers_pi, log_workers_pi = self.workers_policy.actions_logprob(state_batch)# n-1 *batch*ac_num
            q_out = self.critic(state_batch).min(dim=0)[0]# agentnum *batch* ac_num

        q_out = q_out.transpose(0,-1)# ac_num *batch* agentnum
        
        # The lambda network is fed rows flattened over the first two axes and its
        # output is reshaped back to those leading axes.
        weights_out=self.lamadacalculator(q_out.reshape(q_out.shape[0]*q_out.shape[1],-1)).view(*q_out.shape[:-1],-1)#ac_num *batch* agentnum
        # Magnitude and sign of the raw weights are handled separately.
        lamda_out = torch.abs(weights_out)
        flag = (weights_out>0).float()*2-1.0#ac_num *batch* agentnum

        q_rstruct = 1*(flag)*q_out[::,::,:-1]*(torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9)# ac_num*512*n-1
        # Normalise the magnitudes so they sum to (almost) one along the last axis.
        lamda_out = torch.div(lamda_out,torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9) # ac_num*512*n-1

        temp_pi = pi.repeat(self.agentnum-1,1,1)# n-1*batch*ac_num
        temp_log_pi = log_pi.repeat(self.agentnum-1,1,1)
        # KL(Alice || worker) for each worker head.
        D_kl = (temp_pi*(temp_log_pi-log_workers_pi)).sum(-1).transpose(0,1)# *batch*n-1

        
        # Clamped, lambda-weighted log-sum-exp of (log worker prob - Q / alpha).
        q_r = torch.clamp(torch.div(q_rstruct,self.alpha+1e-30),min=-1e36,max=1e36)#    
        t = torch.clamp(log_workers_pi.transpose(-1,0)-q_r,min=-36.0,max=36.0)#ac_num*batch*n-1
        t = lamda_out*torch.exp(t)#a*log_pi - q
        t = torch.log(t.sum(dim=-1)+1e-9).transpose(0,1)#batch*ac_num

        # The second term weights the KL by the lambdas gathered at Alice's
        # most probable action.
        lamda_loss1= (pi.squeeze(0)* t).sum(-1).mean() + \
              1.0*(lamda_out.transpose(0,1).gather(dim=1, index=torch.max(pi,dim=-1,keepdim=True)[1].squeeze(0).unsqueeze(-1).repeat(1,1,self.agentnum-1)).squeeze(1) * D_kl).sum(-1).mean()
        
        #reg the direction of V
        # This branch always runs. It computes lamda_loss2 and logs diagnostics to
        # wandb; lamda_loss2 is not part of any backward pass in this method.

        if True:

            state_batch.requires_grad_(True)

            obs = state_batch.unsqueeze(0).repeat(self.agentnum,1,1)# 4 256 n
            ac  =  action_batch.unsqueeze(0).repeat(self.agentnum,1,1)# 4*256 1
            qs_p = self.critic(obs)#2 4 4 256 n #,ac)
            qs_p = qs_p.gather(-1,ac.unsqueeze(0).repeat(self.agentnum,1,1,1).unsqueeze(0).repeat(self.ensamble_num,1,1,1,1))#2 4 4 256 1
            #-> 256 4 4 
            qs_pred_grads, = torch.autograd.grad(qs_p.sum(dim=0).squeeze(0).squeeze(-1).transpose(0,-1), obs, retain_graph=True, create_graph=True,\
                grad_outputs=torch.eye(self.agentnum).unsqueeze(0).repeat(batch_size,1,1).to(self.device))
            # 4 256 n
            grads_workers = qs_pred_grads[:-1,::].transpose(0,1)# 256*3 *n
            grads_alice = qs_pred_grads[-1:,::].repeat(self.agentnum-1,1,1).transpose(0,1)# 256,3, n
            # Inner product between each worker's gradient and Alice's gradient.
            simliartys =  (grads_workers*grads_alice).sum(-1)#torch.einsum('bik,bjk->bij', qs_pred_grads, qs_pred_grads)
            # 1.0 where the inner product is negative, 0.0 elsewhere.
            ignore_flag = (simliartys<0).float()
            lamda_loss2 = torch.norm(ignore_flag*weights_out,p=2)

            # Function-scope import of wandb, used for the logging calls below.
            import wandb
            for i in range(self.agentnum-1):
                wandb.log({'Simlarity_{}'.format(i):ignore_flag.mean(0)[i],'update':updates})
            wandb.log({'RegLoss': torch.norm(ignore_flag*weights_out,p=2),'update':updates})
            # lamda_loss = [lamda_loss1,lamda_loss2]

        # print(torch.norm(ignore_flag*weights_out,p=2))
        #

        #pcGrad
        # Only lamda_loss1 is back-propagated; the PCGrad variants are commented out.
        self.lamada_optim.zero_grad()
        lamda_loss1.backward(retain_graph=True)
        # self.lamada_optim.pc_backward(lamda_loss)
        #lamda_loss.pc_backward(retain_graph=True)
        self.lamada_optim.step()

        #Alice's policy update
        # Same objective as lamda_loss1, but now only Alice's policy carries
        # gradients: workers, critic and lambda network are evaluated under no_grad.
        pi, log_pi = self.policy.actions_logprob(state_batch)#get the Alice's prob of action

        with torch.no_grad():
            workers_pi, log_workers_pi = self.workers_policy.actions_logprob(state_batch)# n-1 *batch*ac_num
            q_out = self.critic(state_batch).min(dim=0)[0]# agentnum *batch* ac_num

            q_out = q_out.transpose(0,-1)# ac_num *batch* agentnum
            weights_out=self.lamadacalculator(q_out.reshape(q_out.shape[0]*q_out.shape[1],-1)).view(*q_out.shape[:-1],-1)#)#ac_num *batch* agentnum
            lamda_out = torch.abs(weights_out)
            flag = (weights_out>0).float()*2-1.0#ac_num *batch* agentnum
            q_rstruct = 1*(flag)*q_out[::,::,:-1]*(torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9)# ac_num*512*n-1
            lamda_out = torch.div(lamda_out,torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9) # ac_num*512*n-1
            q_r = torch.clamp(torch.div(q_rstruct,self.alpha+1e-30),min=-1e36,max=1e36)#    
            t = torch.clamp(log_workers_pi.transpose(-1,0)-q_r,min=-36.0,max=36.0)#ac_num*batch*n-1
            t = lamda_out*torch.exp(t)#a*log_pi - q
            t = torch.log(t.sum(dim=-1)+1e-9).transpose(0,1)#batch*ac_num

        temp_pi = pi.repeat(self.agentnum-1,1,1)# n-1*batch*ac_num
        temp_log_pi = log_pi.repeat(self.agentnum-1,1,1)
        D_kl = (temp_pi*(temp_log_pi-log_workers_pi)).sum(-1).transpose(0,1)# *batch*n-1
        
        policy_loss= 1.0* (pi.squeeze(0)* t).sum(-1).mean() +   \
            (lamda_out.transpose(0,1).gather(dim=1, index=torch.max(pi,dim=-1,keepdim=True)[1].squeeze(0).unsqueeze(-1).repeat(1,1,self.agentnum-1)).squeeze(1) * D_kl).sum(-1).mean()
        
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()  
        # policy_loss = torch.tensor([0])
        # lamda_loss = policy_loss
        # This branch never runs, so log_alpha is not stepped anywhere in this
        # method and alpha keeps whatever value log_alpha already holds.
        if not True:
            with torch.no_grad():
                workers_pi,workers_log_pi=self.workers_policy.actions_logprob(state_batch)# n*512*1
                Alice_pi,Alice_log_pi=  self.policy.actions_logprob(state_batch)
                log_pis = torch.cat([workers_log_pi,Alice_log_pi],dim=0)
                pis = torch.cat([workers_pi,Alice_pi],dim=0)#n*batch*ac_num

            alpha_loss = self.log_alpha*(((-1*Alice_pi*Alice_log_pi).sum(-1).transpose(0,1))-self.target_entropy).detach().mean()
            # (self.log_alpha[:-1]*((-1*Alice_pi.repeat(self.agentnum-1,1,1)*workers_log_pi).sum(-1).transpose(0,1)-self.target_entropy).detach()).mean()+\
            #     F.mse_loss(self.log_alpha[:-1].sum(-1),self.log_alpha[-1])
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
        # print(alpha_loss.item())
        # Placeholder: the policy loss stands in for the alpha loss in the return value.
        alpha_loss = policy_loss
        self.alpha = self.log_alpha.exp()
        # self.alpha[-1]=self.alpha[:-1].sum(-1)
        alpha_tlogs = self.alpha.clone() # For TensorboardX logs
        # print(self.alpha)




        # The target critic is soft-updated only on multiples of target_update_interval.
        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)
        # grad_loss only exists when eta > 0; otherwise qf_loss fills its slot.
        if self.eta>0:
            return qf_loss.item(), grad_loss.item(),workers_policy_loss.item(), policy_loss.item(), alpha_loss.item(),alpha_tlogs.cpu().detach().numpy(), lamda_loss1.item()
        else:
            return qf_loss.item(), qf_loss.item(),workers_policy_loss.item(), policy_loss.item(), alpha_loss.item(),alpha_tlogs.cpu().detach().numpy(), lamda_loss1.item()

    # Save model parameters
    def save_checkpoint(self, env_name, suffix="", ckpt_path=None):
        if not os.path.exists('checkpoints/'+suffix):
            os.makedirs('checkpoints/'+suffix)
        if ckpt_path is None:
            ckpt_path = "checkpoints/"+suffix+"/"+"sac_checkpoint_{}_{}".format(env_name, suffix)
        print('Saving models to {}'.format(ckpt_path))
        torch.save({'policy_state_dict': self.policy.state_dict(),
                    'workers_policy_state_dict':self.workers_policy.state_dict(),
                    'critic_state_dict': self.critic.state_dict(),
                    'critic_target_state_dict': self.critic_target.state_dict(),
                    'critic_optimizer_state_dict': self.critic_optim.state_dict(),
                    'policy_optimizer_state_dict': self.policy_optim.state_dict(),
                    'workers_policy_optimizer_state_dict':self.workers_policy_optim.state_dict(),
                    'alpha':self.alpha,
                    'alpha_log':self.log_alpha,
                    'alpha_opt':self.alpha_optim.state_dict()
                    }, ckpt_path)
        print(self.alpha)

    # Load model parameters
    def load_checkpoint(self, ckpt_path, evaluate=False):
        print('Loading models from {}'.format(ckpt_path))
        if ckpt_path is not None:
            checkpoint = torch.load(ckpt_path)
            self.policy.load_state_dict(checkpoint['policy_state_dict'])
            self.critic.load_state_dict(checkpoint['critic_state_dict'])
            self.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])
            self.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
            self.policy_optim.load_state_dict(checkpoint['policy_optimizer_state_dict'])
            self.workers_policy.load_state_dict(checkpoint['workers_policy_state_dict'])
            self.workers_policy_optim.load_state_dict(checkpoint['workers_policy_optimizer_state_dict'])
            self.log_alpha = checkpoint['alpha_log']#self.alpha.log()
            self.alpha = checkpoint['alpha']
            self.alpha_optim = Adam([self.log_alpha], lr=self.args.lr)
            if evaluate:
                self.policy.eval()
                self.critic.eval()
                self.critic_target.eval()
            else:
                self.policy.train()
                self.critic.train()
                self.critic_target.train()

