# -- coding: utf-8 --
from cmath import inf
import imp
import os
from time import sleep
# from cv2 import exp, log
import torch
import torch.nn.functional as F
from torch.optim import Adam
from wandb import agent
#from SAC.utils import set_seed
from utils import soft_update, hard_update
from model import GaussianPolicy, QNetwork, DeterministicPolicy,GaussianPolicy_continue,QNetwork_continue,LamadaNetwork,LamadaNetwork_hardconstraint,LamadaNetwork_hardconstraint_sa,QNetwork_continue_Multihead,GaussianPolicy_continue_Multihead
import numpy as np
from mpi4py import MPI
from mpi_utils.mpi_utils import sync_networks, sync_grads

import torch.nn as nn 

import torch.autograd as autograds



def logsumexp(lamda,log_pi,q_out):
    """Return ``log(sum_j lamda[:, j] * exp(log_pi[:, j] - q_out[:, j]) + 1e-9)``.

    The exponent ``log_pi - q_out`` is clamped to ``[-36, 36]`` before
    exponentiation so ``exp`` cannot overflow, and ``1e-9`` is added inside
    the log so an all-zero row does not produce ``-inf``.

    Args:
        lamda: non-negative weights; reduced over ``dim=1`` together with
            the other two arguments (all three must broadcast).
        log_pi: log-probabilities.
        q_out: values subtracted from ``log_pi`` in the exponent.

    Returns:
        Tensor with ``dim=1`` reduced away.

    Raises:
        AssertionError: if any entry of ``lamda`` is negative.
    """
    # The message must be a string: the previous ``print(...)`` expression
    # evaluated to None, so the raised AssertionError carried no information.
    assert bool((lamda >= 0).all()), 'lamda is wrong {}'.format(lamda)
    exponent = torch.clamp(log_pi - q_out, min=-36.0, max=36.0)
    weighted = lamda * torch.exp(exponent)
    return torch.log(weighted.sum(dim=1) + 1e-9)


class SAC_continue(object):
    """Soft Actor-Critic agent for continuous action spaces.

    Owns a critic that returns two Q estimates, a target copy of that
    critic, a policy (Gaussian or deterministic) and, for a Gaussian policy
    with automatic entropy tuning enabled, a learnable log-temperature
    ``log_alpha`` with its own optimizer.
    """

    def __init__(self, num_inputs, action_space, args):
        """Create networks and optimizers.

        Args:
            num_inputs: observation dimension passed to the networks.
            action_space: object exposing ``shape``; ``shape[0]`` is used as
                the action dimension and the object itself is forwarded to
                the policy constructor.
            args: configuration namespace; reads ``gamma``, ``tau``,
                ``alpha``, ``policy``, ``target_update_interval``,
                ``automatic_entropy_tuning``, ``cuda``, ``hidden_size`` and
                ``lr``.
        """
        self.args=args
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        self.critic = QNetwork_continue(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        # The target critic starts as an exact copy of the online critic.
        self.critic_target = QNetwork_continue(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        if self.policy_type == "Gaussian":
            # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2) as given in the paper.
            # NOTE: log_alpha / alpha_optim exist only in this configuration;
            # the checkpoint helpers below account for their absence.
            if self.automatic_entropy_tuning is True:
                self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy_continue(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        else:
            # Deterministic policy: no entropy term and no temperature tuning.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, evaluate=False):
        """Return one action for a single ``state`` as a numpy array.

        With ``evaluate=False`` the first element returned by
        ``policy.sample`` is used; with ``evaluate=True`` the third one is
        (presumably the sampled action and the mean action respectively --
        confirm against the policy class).
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if evaluate is False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates,inputTuple=False,index=0):
        """Run one critic / policy / temperature update step.

        Args:
            memory: replay buffer exposing ``sample(batch_size=...)``, or,
                when ``inputTuple`` is true, an already sampled
                ``(state, action, reward, next_state, mask)`` tuple.
            batch_size: number of transitions requested from ``memory``.
            updates: update counter; the target critic is soft-updated when
                ``updates % target_update_interval == 0``.
            inputTuple: treat ``memory`` as a pre-sampled batch.
            index: reward column to learn from
                (``reward_batch[:, index:index+1]``).

        Returns:
            Tuple of floats ``(qf1_loss, qf2_loss, policy_loss, alpha_loss,
            alpha)``.
        """
        # Sample a batch from memory
        if not inputTuple:
            state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
        else:
            state_batch, action_batch, reward_batch, next_state_batch, mask_batch =memory

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        # Keep a trailing singleton dimension so the reward matches the critic output.
        reward_batch = torch.FloatTensor(reward_batch[::,index:index+1]).to(self.device)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device)

        # Bootstrapped target; no gradients flow through the target computation.
        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)

        # Two Q-functions to mitigate positive bias in the policy improvement step.
        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # J_pi = E[alpha * log pi(a|s) - Q(s, a)]
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone() # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_checkpoint(self, env_name, suffix="", ckpt_path=None):
        """Write networks, optimizers and temperature state to disk.

        The directory ``checkpoints/<suffix>`` is always created. When
        ``ckpt_path`` is ``None`` the file is written to
        ``checkpoints/<suffix>/sac_checkpoint_<env_name>``.

        The ``alpha_log`` / ``alpha_opt`` entries are only written when the
        agent actually owns a learnable temperature (Gaussian policy with
        automatic entropy tuning); otherwise they are omitted instead of
        raising ``AttributeError``.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs('checkpoints/'+suffix, exist_ok=True)
        if ckpt_path is None:
            ckpt_path = "checkpoints/"+suffix+"/"+"sac_checkpoint_{}".format(env_name)

        checkpoint = {'policy_state_dict': self.policy.state_dict(),
                      'critic_state_dict': self.critic.state_dict(),
                      'critic_target_state_dict': self.critic_target.state_dict(),
                      'critic_optimizer_state_dict': self.critic_optim.state_dict(),
                      'policy_optimizer_state_dict': self.policy_optim.state_dict(),
                      'alpha':self.alpha,
                      }
        log_alpha = getattr(self, 'log_alpha', None)
        alpha_optim = getattr(self, 'alpha_optim', None)
        if log_alpha is not None:
            checkpoint['alpha_log'] = log_alpha
        if alpha_optim is not None:
            checkpoint['alpha_opt'] = alpha_optim.state_dict()
        torch.save(checkpoint, ckpt_path)

    # Load model parameters
    def load_checkpoint(self, ckpt_path, evaluate=False):
        """Restore state written by ``save_checkpoint``.

        Does nothing when ``ckpt_path`` is ``None`` or does not exist.
        Tensors are mapped onto ``self.device`` so a checkpoint saved on a
        GPU can be loaded on a CPU-only machine. With ``evaluate=True`` the
        networks are put in eval mode, otherwise in train mode.
        """
        if ckpt_path is None or not os.path.exists(ckpt_path):
            return

        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])
        self.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
        self.policy_optim.load_state_dict(checkpoint['policy_optimizer_state_dict'])

        self.alpha = checkpoint['alpha']
        # Temperature state is only present for agents that were tuning it.
        if 'alpha_log' in checkpoint:
            self.log_alpha = checkpoint['alpha_log']
            self.alpha_optim = Adam([self.log_alpha], lr=self.args.lr)
            if 'alpha_opt' in checkpoint:
                self.alpha_optim.load_state_dict(checkpoint['alpha_opt'])

        if evaluate:
            self.policy.eval()
            self.critic.eval()
            self.critic_target.eval()
        else:
            self.policy.train()
            self.critic.train()
            self.critic_target.train()
                

from pcgrad import PCGrad

class MRF2(object):
    """Multi-agent SAC variant with a multi-head critic and a lambda network.

    The agent owns:
      * ``critic`` / ``critic_target``: multi-head Q networks built with
        ``agentnum`` heads and ``ensamble_num`` ensemble members,
      * ``workers_policy``: a multi-head Gaussian policy with
        ``agentnum - 1`` heads (the "workers"),
      * ``policy``: a separate Gaussian policy referred to as "Alice",
      * ``lamadacalculator``: a network mapping Q values to per-worker
        weights, optimised through a ``PCGrad`` wrapper,
      * ``log_alpha``: one learnable log-temperature per agent.
    """

    def __init__(self, num_inputs, action_space, args):
        """Create networks and optimizers.

        Args:
            num_inputs: observation dimension passed to the networks.
            action_space: object exposing ``shape``; ``shape[0]`` is the
                action dimension and the object is forwarded to the policies.
            args: configuration namespace; reads ``ensamble_num``, ``eta``,
                ``gamma``, ``tau``, ``agentnum``, ``policy``,
                ``target_update_interval``, ``automatic_entropy_tuning``,
                ``cuda``, ``hidden_size`` and ``lr``.
        """
        # Kept so load_checkpoint can read args.lr when rebuilding alpha_optim.
        self.args = args
        self.ensamble_num=args.ensamble_num
        self.eta = args.eta
        if self.eta>0:
            print("EDAC loded")
        self.gamma = args.gamma
        self.tau = args.tau
        self.agentnum= args.agentnum

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        # Critic output is batchsize * agentnum (per the original author's note).
        self.critic = QNetwork_continue_Multihead(num_inputs, action_space.shape[0], args.hidden_size,self.agentnum,ensamble_num=self.ensamble_num).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork_continue_Multihead(num_inputs, action_space.shape[0], args.hidden_size,self.agentnum,ensamble_num=self.ensamble_num).to(self.device)
        hard_update(self.critic_target, self.critic)

        self.lamadacalculator=LamadaNetwork_hardconstraint(self.agentnum,self.agentnum-1,args,args.hidden_size).to(self.device)
        self.lamada_optim = PCGrad(Adam(self.lamadacalculator.parameters(),lr=args.lr))

        # The workers' policy has agentnum - 1 heads.
        self.workers_policy = GaussianPolicy_continue_Multihead(num_inputs,action_space.shape[0],args.hidden_size,self.agentnum-1,action_space).to(self.device)
        self.workers_policy_optim=Adam(self.workers_policy.parameters(),lr=args.lr)

        # Alice's policy.
        # Alice may need her own network because her loss differs from the
        # workers'; open question: does that waste network capacity?
        self.policy = GaussianPolicy_continue(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
        # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2) as given in the paper.
        if self.automatic_entropy_tuning is True:
            self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
            self.log_alpha = torch.zeros(self.agentnum, requires_grad=True, device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

        # NOTE(review): this line and update_parameters both require
        # log_alpha / alpha_optim, so the class only works when
        # args.automatic_entropy_tuning is exactly True.
        self.alpha = self.log_alpha.exp()

    def select_action(self, state, evaluate=False):
        """Return one action from Alice's policy for a single ``state``.

        ``evaluate`` selects the third instead of the first value returned
        by ``policy.sample``.
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if evaluate is False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def select_action_workers(self, state,idx, evaluate=False):
        """Return one action from worker head ``idx`` for a single ``state``.

        The multi-head sample is indexed as ``[idx][0]``: head first, then
        the singleton batch entry.
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if evaluate is False:
            action, _, _ = self.workers_policy.sample(state)
        else:
            _, _, action = self.workers_policy.sample(state)
        return action.detach().cpu().numpy()[idx][0]

    def update_parameters(self, memory, batch_size, updates,inputTuple=False,index=0):
        """Run one full update: critic, workers, lambda network, Alice, alpha.

        Args:
            memory: replay buffer exposing ``sample(batch_size=...)``, or,
                when ``inputTuple`` is true, an already sampled
                ``(state, action, reward, next_state, mask)`` tuple.
            batch_size: number of transitions requested from ``memory``.
            updates: update counter used for the target-network schedule,
                periodic printing and wandb logging.
            inputTuple: treat ``memory`` as a pre-sampled batch.
            index: accepted for signature parity with ``SAC_continue``;
                not used here.

        Returns:
            ``(qf_loss, grad_loss_or_qf_loss, workers_policy_loss,
            policy_loss, alpha_loss, alpha (numpy array), lamda_loss)``.
            The second entry is the EDAC gradient penalty when ``eta > 0``
            and ``qf_loss`` again otherwise.

        Raises:
            AssertionError: if the lambda loss or Alice's policy loss is
                inf or NaN.
        """
        # Sample a batch from memory
        if not inputTuple:
            state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
        else:
            state_batch, action_batch, reward_batch, next_state_batch, mask_batch =memory

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device)  # batch * n (author's note)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device)
        # Use the real batch dimension for the grad_outputs shapes below so a
        # pre-sampled tuple whose size differs from `batch_size` still works.
        n_batch = state_batch.shape[0]
        if self.eta > 0:
            action_batch.requires_grad_(True)

        with torch.no_grad():
            # Two parts: the workers' Q targets and Alice's Q target.
            # Workers first (author's shape notes: N*B*A, N*B*1, N*B*A).
            next_state_action,next_state_log_pi,_= self.workers_policy.sample(next_state_batch)
            next_state_action_Alice, next_state_log_pi_Alice, _ = self.policy.sample(next_state_batch)
            # Append Alice as the last "agent" along dim 0.
            next_state_action=torch.cat([next_state_action,next_state_action_Alice.unsqueeze(dim=0)],dim=0)
            next_state_log_pi=torch.cat([next_state_log_pi,next_state_log_pi_Alice.unsqueeze(dim=0)],dim=0)

            # Minimum over dim 0 of the target critic output, then keep the
            # diagonal (via the identity mask) of the last two dimensions.
            q_min = self.critic_target(next_state_batch.unsqueeze(0).repeat(next_state_action.shape[0],1,1), next_state_action).min(dim=0)[0].squeeze(-1).transpose(-1,0)
            mask = torch.eye(next_state_action.shape[0]).to(self.device)
            min_qf_next_target = (q_min*mask).sum(-1) -self.alpha* next_state_log_pi.squeeze(-1).transpose(-1,0)

            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)

        qs = self.critic(state_batch, action_batch)
        # Every ensemble member regresses onto the same target.
        qf_loss = F.mse_loss(qs.squeeze(dim=-1).transpose(-1,-2),torch.repeat_interleave(next_q_value.unsqueeze(dim=0),self.ensamble_num,dim=0))

        # EDAC-style penalty on the similarity of the ensemble members'
        # action gradients (only when eta > 0).
        if self.eta>0:
            obs = state_batch.unsqueeze(0).repeat(self.ensamble_num,1,1).unsqueeze(0).repeat(self.agentnum,1,1,1)
            ac  =  action_batch.unsqueeze(0).repeat(self.ensamble_num,1,1).unsqueeze(0).repeat(self.agentnum,1,1,1)
            qs_p = self.critic(obs,ac)
            qs_p=(qs_p.squeeze(-1).transpose(0,-1).transpose(-1,2).transpose(-2,1)*(torch.eye(self.agentnum).to(self.device))).sum(-1).transpose(-1,-3)
            qs_pred_grads , = torch.autograd.grad(qs_p, ac, retain_graph=True, create_graph=True,\
                grad_outputs=torch.eye(self.ensamble_num).unsqueeze(0).unsqueeze(0).repeat(n_batch,self.agentnum,1,1).to(self.device))

            # Drop the last entry along dim 0 before normalising.
            qs_pred_grads = qs_pred_grads[:-1,::]
            qs_pred_grads = qs_pred_grads / (torch.norm(qs_pred_grads, p=2, dim=-1).unsqueeze(-1) + 1e-10)
            qs_pred_grads = qs_pred_grads.transpose(1, 2)
            # Pairwise inner products between normalised gradients.
            qs_pred_grads = torch.einsum('abik,abjk->abij', qs_pred_grads, qs_pred_grads)
            masks = torch.eye(self.ensamble_num, device=self.device).unsqueeze(dim=0).repeat(qs_pred_grads.size(1),1,1).unsqueeze(dim=0).repeat(qs_pred_grads.size(0),1, 1, 1)
            # Zero the diagonal (self-similarity).
            qs_pred_grads = (1 - masks) * qs_pred_grads
            grad_loss = torch.mean(torch.sum(qs_pred_grads, dim=(-2, -1))) / (self.ensamble_num - 1)
            qf_loss += self.eta * grad_loss

        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # ---- update the workers' policy ----
        pis,log_pis,_=self.workers_policy.sample(state_batch)  # n*B*A, n*B*1 (author's note)

        q = self.critic(state_batch.unsqueeze(0).repeat(pis.shape[0],1,1),pis)

        # Drop the last head (Alice's) and keep the diagonal entries.
        min1 = q.min(dim = 0)[0].squeeze(-1).transpose(-1,0)[::,::,:-1]
        mask = torch.eye(min1.shape[-1]).unsqueeze(0).to(self.device)
        min_qf_pi = (min1*mask).sum(-1)  # B*n (author's note)

        log_pi = log_pis.squeeze(-1).transpose(1,0)
        workers_policy_loss=(self.alpha[:-1]*log_pi-min_qf_pi).mean()

        self.workers_policy_optim.zero_grad()
        # retain_graph is required: self.alpha (= log_alpha.exp()) is part of
        # this graph and is reused by the later backward passes.
        workers_policy_loss.backward(retain_graph=True)
        self.workers_policy_optim.step()

        # ---- update the lambda network ----
        with torch.no_grad():
            pi, log_pi, _ = self.policy.sample(state_batch)  # Alice's action
            q_out=torch.min(self.critic(state_batch,pi),dim=0)[0].squeeze(dim=-1).transpose(1,0)  # B*n
            # log p(a|s) of Alice's action under the workers' policies.
            log_probs=self.workers_policy.calculate_prob(state_batch,pi)  # n*B*1
            # Means and log-stds of the workers' policies and of Alice.
            means,logstds=self.workers_policy.forward(state_batch)
            mean_Alice,logstd_Alice = self.policy.forward(state_batch)

        weights_out=self.lamadacalculator(q_out)
        lamda_out = torch.abs(weights_out)
        # +1 where the raw weight is positive, -1 otherwise.
        flag = (weights_out>0).float()*2-1.0
        q_rstruct = 1*(flag)*q_out[::,:-1]*(torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9)
        # Normalise the absolute weights so each row sums to (almost) one.
        lamda_out = torch.div(lamda_out,torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9)

        log_probs=log_probs.squeeze(dim=-1).transpose(1,0)  # B*n, workers' log-probs

        # Closed-form KL between diagonal Gaussians,
        # see https://www.cnblogs.com/qizhou/p/13804283.html
        logstd_Alice = logstd_Alice.unsqueeze(0).repeat(self.agentnum-1,1,1)
        mean_Alice = mean_Alice.unsqueeze(0).repeat(self.agentnum-1,1,1)

        kl = torch.sum(logstds-logstd_Alice,dim=-1)+\
            0.5*torch.sum(torch.exp(2.0*(logstd_Alice-logstds)),dim=-1)+\
                0.5*torch.sum(torch.div(torch.pow(mean_Alice-means,2),(torch.exp(2.0*logstds)+1e-9)),dim=-1)-\
                    0.5*logstd_Alice.shape[-1]

        # Scale by Alice's temperature (last alpha entry).
        q_r = torch.clamp(torch.div(q_rstruct,self.alpha[-1]+1e-30),min=-1e36,max=1e36)

        lamda_loss = logsumexp(lamda_out,log_probs,q_r).mean()+((kl.transpose(0,1)*lamda_out).sum(dim=1)).mean()\
            + torch.abs(weights_out).sum(-1).mean()

        # Real message strings: `print(...)` as an assert message is None.
        assert torch.isinf(lamda_loss).sum()==0, 'lamda loss error:{}'.format(lamda_loss.item())
        assert torch.isnan(lamda_loss).sum()==0, 'lamda loss error nan:{}'.format(lamda_loss.item())

        # Second lambda objective: penalise weights of workers whose critic
        # action-gradient points away from Alice's (negative inner product).
        action_batch.requires_grad_(True)

        obs = state_batch.unsqueeze(0).repeat(self.agentnum,1,1)
        ac  =  action_batch.unsqueeze(0).repeat(self.agentnum,1,1)
        qs_p = self.critic(obs,ac)

        qs_pred_grads, = torch.autograd.grad(qs_p.sum(dim=0).squeeze(0).squeeze(-1).transpose(0,-1), ac, retain_graph=True, create_graph=True,\
            grad_outputs=torch.eye(self.agentnum).unsqueeze(0).repeat(n_batch,1,1).to(self.device))
        grads_workers = qs_pred_grads[:-1,::].transpose(0,1)
        grads_alice = qs_pred_grads[-1:,::].repeat(self.agentnum-1,1,1).transpose(0,1)
        simliartys =  (grads_workers*grads_alice).sum(-1)
        ignore_flag = (simliartys<0).float()
        lamda_loss2 = torch.norm(ignore_flag*weights_out,p=2)

        if updates%1000==999:
            for i in range(self.agentnum-1):
                print('updates:{}, Simlarity_{}: {}'.format(updates,i,ignore_flag.mean(0)[i]))
        # Local import: the module only imports `agent` from wandb at top level.
        import wandb
        for i in range(self.agentnum-1):
            wandb.log({'Simlarity_{}'.format(i):ignore_flag.mean(0)[i],'update':updates})
        wandb.log({'RegLoss': lamda_loss2,'update':updates})

        lamda_loss_all = [lamda_loss,lamda_loss2]

        # PCGrad combines the gradients of both objectives.
        self.lamada_optim.zero_grad()
        self.lamada_optim.pc_backward(lamda_loss_all)
        self.lamada_optim.step()

        # ---- update Alice's policy ----
        # Recompute the lambda weights with the freshly updated network.
        pi, log_pi, _ = self.policy.sample(state_batch)
        with torch.no_grad():
            q_out=torch.min(self.critic(state_batch,pi),dim=0)[0].squeeze(dim=-1).transpose(1,0)  # B*n
            weights_out=self.lamadacalculator(q_out)
            lamda_out = torch.abs(weights_out)
            flag = (weights_out>0).float()*2-1.0
            q_rstruct = 1*(flag)*q_out[::,:-1]*(torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9)
            lamda_out = torch.div(lamda_out,torch.sum(lamda_out,dim=-1,keepdim=True)+1e-9)
            log_probs=self.workers_policy.calculate_prob(state_batch,pi)  # n*B*1
            means,logstds=self.workers_policy.forward(state_batch)

        # Outside no_grad: the KL term must carry gradients into Alice's policy.
        mean_Alice,logstd_Alice = self.policy.forward(state_batch)

        logstd_Alice = logstd_Alice.unsqueeze(0).repeat(self.agentnum-1,1,1)
        mean_Alice = mean_Alice.unsqueeze(0).repeat(self.agentnum-1,1,1)

        kl = torch.sum(logstds-logstd_Alice,dim=-1)+\
            0.5*torch.sum(torch.exp(2.0*(logstd_Alice-logstds)),dim=-1)+\
                0.5*torch.sum(torch.div(torch.pow(mean_Alice-means,2),(torch.exp(2.0*logstds)+1e-9)),dim=-1)-\
                    0.5*logstd_Alice.shape[-1]

        policy_loss=((kl.transpose(0,1)*lamda_out).sum(dim=1)).mean()

        log_probs=log_probs.squeeze(dim=-1).transpose(1,0)  # B*n, workers' log-probs

        q_r = torch.clamp(torch.div(q_rstruct,self.alpha[-1]+1e-30),min=-1e36,max=1e36)
        policy_loss+=logsumexp(lamda_out,log_probs,q_r).mean()

        assert torch.isinf(policy_loss).sum()==0, 'policy loss error:{}'.format(policy_loss.item())
        assert torch.isnan(policy_loss).sum()==0, 'policy loss error_nan:{}'.format(policy_loss.item())

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        # ---- update the temperatures (one per agent, Alice last) ----
        with torch.no_grad():
            _,workers_log_pi,_=self.workers_policy.sample(state_batch)  # n*B*1
            _,Alice_log_pi,_=  self.policy.sample(state_batch)
            log_pis = torch.cat([workers_log_pi.squeeze(dim=-1).transpose(1,0),Alice_log_pi],dim=-1)

        alpha_loss = -(self.log_alpha * (log_pis + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        self.alpha = self.log_alpha.exp()
        alpha_tlogs = self.alpha.clone() # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)
        if self.eta>0:
            return qf_loss.item(), grad_loss.item(),workers_policy_loss.item(), policy_loss.item(), alpha_loss.item(),alpha_tlogs.cpu().detach().numpy(), lamda_loss.item()
        else:
            return qf_loss.item(), qf_loss.item(),workers_policy_loss.item(), policy_loss.item(), alpha_loss.item(),alpha_tlogs.cpu().detach().numpy(), lamda_loss.item()

    # Save model parameters
    def save_checkpoint(self, env_name, suffix="", ckpt_path=None):
        """Write networks, optimizers and temperature state to disk.

        The directory ``checkpoints/<suffix>`` is always created. When
        ``ckpt_path`` is ``None`` the file is written to
        ``checkpoints/<suffix>/sac_checkpoint_<env_name>_<suffix>``.
        The lambda network weights are stored under
        ``lamadacalculator_state_dict`` so training can be resumed.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs('checkpoints/'+suffix, exist_ok=True)
        if ckpt_path is None:
            ckpt_path = "checkpoints/"+suffix+"/"+"sac_checkpoint_{}_{}".format(env_name, suffix)
        print('Saving models to {}'.format(ckpt_path))
        torch.save({'policy_state_dict': self.policy.state_dict(),
                    'workers_policy_state_dict':self.workers_policy.state_dict(),
                    'critic_state_dict': self.critic.state_dict(),
                    'critic_target_state_dict': self.critic_target.state_dict(),
                    'critic_optimizer_state_dict': self.critic_optim.state_dict(),
                    'policy_optimizer_state_dict': self.policy_optim.state_dict(),
                    'workers_policy_optimizer_state_dict':self.workers_policy_optim.state_dict(),
                    'lamadacalculator_state_dict': self.lamadacalculator.state_dict(),
                    'alpha':self.alpha,
                    'alpha_log':self.log_alpha,
                    'alpha_opt':self.alpha_optim.state_dict()
                    }, ckpt_path)
        print(self.alpha)

    # Load model parameters
    def load_checkpoint(self, ckpt_path, evaluate=False):
        """Restore state written by ``save_checkpoint``.

        Does nothing (beyond printing) when ``ckpt_path`` is ``None``.
        Tensors are mapped onto ``self.device``. Entries that older
        checkpoints may lack (``lamadacalculator_state_dict``,
        ``alpha_opt``) are restored only when present. With
        ``evaluate=True`` the policies and critics are put in eval mode,
        otherwise in train mode.
        """
        print('Loading models from {}'.format(ckpt_path))
        if ckpt_path is None:
            return

        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])
        self.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
        self.policy_optim.load_state_dict(checkpoint['policy_optimizer_state_dict'])
        self.workers_policy.load_state_dict(checkpoint['workers_policy_state_dict'])
        self.workers_policy_optim.load_state_dict(checkpoint['workers_policy_optimizer_state_dict'])
        if 'lamadacalculator_state_dict' in checkpoint:
            self.lamadacalculator.load_state_dict(checkpoint['lamadacalculator_state_dict'])

        self.log_alpha = checkpoint['alpha_log']
        self.alpha = checkpoint['alpha']
        self.alpha_optim = Adam([self.log_alpha], lr=self.args.lr)
        if 'alpha_opt' in checkpoint:
            self.alpha_optim.load_state_dict(checkpoint['alpha_opt'])

        if evaluate:
            self.policy.eval()
            self.workers_policy.eval()
            self.critic.eval()
            self.critic_target.eval()
        else:
            self.policy.train()
            self.workers_policy.train()
            self.critic.train()
            self.critic_target.train()

