import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import utils
import hydra
import time
from typing import Union
import warnings
import matplotlib.pyplot as plt


class D2CAgent(object):
    def __init__(self, obs_shape, action_shape, action_range, device,
                 encoder_cfg, # encoder_target_cfg, encoder_target_tau, encoder_update_frequency, 
                 critic_cfg, critic_target_cfg, 
                 actor_cfg, 
                 discount,
                 init_temperature, lr, actor_update_frequency,
                 critic_target_tau, critic_target_update_frequency,
                 batch_size,
                 num_seed_steps,                 
                 env_name = None,                 
                 
                 consider_done_true_in_critic = False,   
                 normalize_rl_obs = False,
                 randomwalk_method = 'rand_action',
                 goal_dim = None,                  
                 adam_eps = 1e-8, optim='adam',
                 rl_reward_type = 'sparse',
                 sparse_reward_type = 'negative',
                 use_d2c = False, d2c_cfg = None, d2c_kwargs = None, d2c_reward_type = None,
                 
                 grad_norm_clipping = 0., grad_value_clipping = 0., q_clip=False,

                 d2c_feature_dim = None,
                 alpha_auto = True,
                 ):
        self.action_range = action_range
        self.device = device
        self.discount = discount
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_tau = critic_target_tau
        self.critic_target_update_frequency = critic_target_update_frequency
        
        self.alpha_auto = alpha_auto
        self.batch_size = batch_size
        self.goal_dim = goal_dim
        
        self.num_seed_steps = num_seed_steps
        self.lr = lr
        self.rl_reward_type = rl_reward_type
        self.sparse_reward_type = sparse_reward_type
        self.encoder = encoder_cfg.to(self.device) 
        
        
        self.d2c_feature_dim = d2c_feature_dim
    
        actor_cfg.repr_dim = self.encoder.repr_dim
        critic_cfg.repr_dim = self.encoder.repr_dim        
        
        self.actor = actor_cfg.to(self.device) 
        self.critic = critic_cfg.to(self.device) 

        self.critic_target = critic_target_cfg.to(self.device) 
        self.critic_target.load_state_dict(self.critic.state_dict())
        
        self.log_alpha = torch.from_numpy(np.array(np.log(init_temperature))).float().to(self.device)
        self.log_alpha.requires_grad = True    
        self.alpha_lr = 1e-5
            
        self.grad_value_clipping = grad_value_clipping
        self.grad_norm_clipping = grad_norm_clipping
        self.q_clip = q_clip
        
          
        self.env_name = env_name        

        self.consider_done_true_in_critic = consider_done_true_in_critic
        self.adam_eps = adam_eps
        self.optim = optim
        
        
        self.initial_state = None
        self.final_goal_states = None

        self.normalize_rl_obs = normalize_rl_obs
        self.randomwalk_method = randomwalk_method
        
        self.use_d2c = use_d2c
        self.d2c_dict = None
        if use_d2c:            
            self.d2c_kwargs = d2c_kwargs
            self.d2c_reward_type = d2c_reward_type
            self.d2c_gcrl = d2c_kwargs['goal_condition']
            self.d2c_goal_candidate_type = d2c_kwargs['goal_candidate_type']
            self.d2c_n_goal_candidates = d2c_kwargs['n_goal_candidates']
            self.d2c_n_noise_augment_per_goal = d2c_kwargs['n_noise_augment_per_goal']
            self.d2c_noise_scale = d2c_kwargs['noise_scale']
            self.d2c_mode = d2c_kwargs['mode']
            self.d2c_reduction = d2c_kwargs['reduction']
            self.d2c_aux_weight = d2c_kwargs['aux_weight']
            self.d2c_normalize = d2c_kwargs['normalize']            
            self.d2c_batch_size = d2c_kwargs['batch_size']
            self.d2c_train_every_k = d2c_kwargs['train_every_k']
            self.d2c_num_init_update = d2c_kwargs['num_init_update']
            self.d2c_temperature = d2c_kwargs['temperature']
            self.d2c_num_update = d2c_kwargs['num_update']
            self.d2c_lr = d2c_kwargs['lr']
            self.d2c_use_randomwalk_buffer = d2c_kwargs['use_randomwalk_buffer']
            self.d2c = d2c_cfg.to(self.device) 
            from d2c.d2c import DiversifyLoss
            self.d2c_loss_fn = DiversifyLoss(heads=self.d2c.heads, mode=self.d2c_mode, reduction=self.d2c_reduction)
        
                    
        # set target entropy to -|A|
        self.target_entropy = -action_shape[0] # default
        
        # optimizers
        self.init_optimizers(lr)

        self.train()
        
        self.critic_target.train() 
        
    def init_optimizers(self, lr):
        if self.optim=='adam':
            self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr, eps=self.adam_eps)
            self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr, eps=self.adam_eps)
            
            if self.alpha_auto:
                self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=lr, eps=self.adam_eps) #eps=1e-02
            if self.use_d2c:
                self.d2c_optimizer = torch.optim.Adam(self.d2c.parameters(), lr=self.d2c_lr, eps=self.adam_eps) #eps=1e-02
            

    def train(self, training=True):
        self.training = training
        self.actor.train(training)
        self.critic.train(training)        
        self.encoder.train(training)
        
        if self.use_d2c:
            self.d2c.train(training)
        
    @property
    def alpha(self):
        return self.log_alpha.exp()
    

    def act(self, obs, spec, sample=False):
        """Select a single action for one observation.

        Args:
            obs: NumPy observation (converted with ``torch.from_numpy``); a
                batch axis is added before encoding.
            spec: Not used by this method.
            sample: When true, draw from the policy distribution; otherwise
                use its mean.

        Returns:
            The action clamped to ``self.action_range``, converted with
            ``utils.to_np``.
        """
        raw = self.normalize_obs(obs, self.env_name) if self.normalize_rl_obs else obs

        batch = torch.from_numpy(raw).float().to(self.device).unsqueeze(0)
        features = self.encoder.encode(batch)

        policy = self.actor(features)
        if sample:
            chosen = policy.sample()
        else:
            chosen = policy.mean
        chosen = chosen.clamp(*self.action_range)
        assert chosen.ndim == 2 and chosen.shape[0] == 1
        return utils.to_np(chosen[0])


    def sample_negatives(self, replay_buffer, goal_env, size): # from replay buffer        
        obs, _, _, _, _, _ = replay_buffer.sample_without_relabeling(size, self.discount, sample_only_state = False)
        batch = goal_env.convert_obs_to_dict(obs.detach().cpu().numpy())['achieved_goal']
        
        negatives = batch
        labels = np.zeros(len(negatives))
        
        return negatives.astype(np.float32), labels
        
    def sample_positives(self, size): # from final goal
        
        final_goal = self.final_goal_states.copy()
        
        rand_positive_ind = np.random.randint(0, final_goal.shape[0], size=size)
        
        batch = final_goal[rand_positive_ind]
        
        positives = batch

        return positives.astype(np.float32), np.ones(len(positives))
    
    def get_d2c_target_data(self, uniform_goal_sampler, size):
        # get uniform samples from entire state space
        return uniform_goal_sampler.sample(num_sample=size, sample_feasible=False)
        
    def sample_d2c_batch(self, size, uniform_goal_sampler, replay_buffer=None, goal_env=None, randomwalk_buffer=None):    
        if self.d2c_gcrl:             
            if randomwalk_buffer is not None and len(randomwalk_buffer)!=0:
                num_randomwalk_data = int(self.d2c_n_goal_candidates*self.d2c_n_noise_augment_per_goal/10)
                num_original_data = self.d2c_n_goal_candidates*self.d2c_n_noise_augment_per_goal - num_randomwalk_data
                data_1, _ = self.sample_negatives(replay_buffer, goal_env, num_original_data)
                data_2, _ = self.sample_negatives(randomwalk_buffer, goal_env, num_randomwalk_data)
                buffer_data = np.concatenate([data_1, data_2], axis =0)
                
            else:
                buffer_data, _ = self.sample_negatives(replay_buffer, goal_env, self.d2c_n_goal_candidates*self.d2c_n_noise_augment_per_goal)
            
            if self.d2c_goal_candidate_type=='buffer':
                if randomwalk_buffer is not None:
                    raise NotImplementedError
                                
                obs, _, _, _, _, _ = replay_buffer.sample_without_relabeling(int(self.d2c_n_goal_candidates/2), self.discount, sample_only_state = False)

                goal_candidates = goal_env.convert_obs_to_dict(obs.detach().cpu().numpy())['achieved_goal'] # [can, dim] , [c, f*c, h, w] if image               
                
                obs2 = self.final_goal_states.copy() # num_curriculum*num_target
                
                while obs2.shape[0] < goal_candidates.shape[0]:
                    obs2 = np.tile(obs2, (2, 1))
                
                indices = np.random.randint(0, obs2.shape[0], size=int(self.d2c_n_goal_candidates/2))
                obs2 = obs2[indices]

                goal_candidates = np.concatenate([goal_candidates, obs2], axis=0)
    
            elif self.d2c_goal_candidate_type=='uniform':                
                goal_candidates = uniform_goal_sampler.sample(num_sample=self.d2c_n_goal_candidates, sample_feasible=False) # [can, dim]
                # already feature level uniform sample -> do not need to encode
                
            else:
                raise NotImplementedError

            noise = np.random.uniform(-self.d2c_noise_scale, self.d2c_noise_scale, size=(1, self.d2c_n_noise_augment_per_goal, self.d2c_feature_dim)) # [1, num_noise, dim]
            
            tiled_goal_candidates = np.tile(goal_candidates[:, None, :], (1, self.d2c_n_noise_augment_per_goal,1)) # [can, dim] -> [can, num_noise, dim]]
            augmented_goal_candidates = goal_candidates[:, None, :] + noise # [can, 1, dim] + [1, num_noise, dim]  = [can, num_noise, dim]
            
            # [label 0 inputs, conditioned goal]
            x_0_s = np.concatenate([buffer_data, tiled_goal_candidates.reshape(-1, self.d2c_feature_dim)], axis=-1) # [can*num_noise, dim*2]
            
            # [label 1 inputs, conditioned goal]
            x_1_s = np.concatenate([augmented_goal_candidates.reshape(-1, self.d2c_feature_dim), tiled_goal_candidates.reshape(-1, self.d2c_feature_dim)], axis=-1) # [can*num_noise, dim*2]
            
            y_0_s = np.zeros([x_0_s.shape[0],1])
            y_1_s = np.ones([x_1_s.shape[0],1])
            
            train_data_x = np.concatenate([x_0_s, x_1_s], 0)
            train_data_y = np.concatenate([y_0_s, y_1_s], 0)

            x_0_t = uniform_goal_sampler.sample(num_sample=self.d2c_n_goal_candidates*self.d2c_n_noise_augment_per_goal, sample_feasible=False) # [can, dim]
            x_0_t = np.concatenate([x_0_t, tiled_goal_candidates.reshape(-1, self.d2c_feature_dim)], axis =-1)
            y_0_t = np.zeros([x_0_t.shape[0],1])

            test_data_x = np.concatenate([x_0_t, x_0_s, x_1_s], 0)
            test_data_y = np.concatenate([y_0_t, y_0_s, y_1_s], 0)
        else:
            if randomwalk_buffer is not None:
                raise  NotImplementedError
            negatives = self.sample_negatives(replay_buffer, goal_env, size)
            positives = self.sample_positives(size)
            underspecified_target_data = self.get_d2c_target_data(uniform_goal_sampler, size)
            
            train_data_x = np.concatenate([negatives[0], positives[0]], axis=0)
            train_data_y = np.concatenate([negatives[1], positives[1]], axis=0)
            if len(train_data_y.shape)==1:
                train_data_y = train_data_y[:, None] # [bs]->[bs,1]
                
            test_data_x = np.concatenate([underspecified_target_data, negatives[0], positives[0]], axis=0)
        
        return train_data_x, train_data_y, test_data_x
    
    def get_prob_by_d2c(self, observations):
        inputs = observations
        if self.d2c_normalize:
            inputs = self.normalize_obs(inputs, self.env_name)
        
        if type(inputs)==np.ndarray:
            inputs = torch.from_numpy(inputs).float().to(self.device)
            
        with torch.no_grad():
            preds = self.d2c(inputs).sigmoid().detach().cpu().numpy() # [bs, heads]
        
        pseudo_prob_list = []
        for t in range(self.d2c.heads):
            pseudo_prob_list.append(preds[:,t])
        prob = np.stack(pseudo_prob_list, axis=1).mean(1)

        return prob

        
    def update_d2c(self, uniform_goal_sampler, replay_buffer=None, goal_env=None, randomwalk_buffer=None):
        train_data_x, train_data_y, test_data_x = self.sample_d2c_batch(self.d2c_batch_size, uniform_goal_sampler, replay_buffer=replay_buffer, goal_env=goal_env, randomwalk_buffer=randomwalk_buffer)
        train_data_x = torch.from_numpy(train_data_x).float().to(self.device)
        train_data_y = torch.from_numpy(train_data_y).float().to(self.device)
        test_data_x = torch.from_numpy(test_data_x).float().to(self.device)
        
        if self.d2c_normalize:
            train_data_x = self.normalize_obs(train_data_x, env_name=self.env_name)
            test_data_x = self.normalize_obs(test_data_x, env_name=self.env_name)

        

        logits = self.d2c(train_data_x)
        logits_chunked = torch.chunk(logits, self.d2c.heads, dim=-1)
        losses = [F.binary_cross_entropy_with_logits(logit, train_data_y) for logit in logits_chunked]
        xent = sum(losses)

        target_logits = self.d2c(test_data_x)
        repulsion_loss = self.d2c_loss_fn(target_logits)

        d2c_loss = xent + self.d2c_aux_weight * repulsion_loss
        
        self.d2c_optimizer.zero_grad()
        d2c_loss.backward()
        self.d2c_optimizer.step()

        d2c_dict = {'d2c_loss' : d2c_loss, 
                        'xent_loss' : xent, 
                        'repulsion_loss' : repulsion_loss,
                        }
        
        return d2c_dict


    def normalize_obs(self, obs, env_name): 
        # normalize to [-1,1]
        if obs is None:
            return None
        if type(obs)==np.ndarray:
            obs = obs.copy()    
        elif type(obs)==torch.Tensor:
            obs = copy.deepcopy(obs)
        else:
            raise NotImplementedError
        
        
        if env_name in ['AntMazeComplex2Way-v0', 'Point2WaySpiralMaze-v0', 'Point4WayComplexVer2Maze-v0','Point4WayFarmlandMaze-v0']:
            center_g, scale_g = None, None
            if env_name in ["Point2WaySpiralMaze-v0"]:                
                center, scale = np.array([0, 0]), np.array([14, 18])
            elif env_name in ['Point4WayComplexVer2Maze-v0','Point4WayFarmlandMaze-v0']:                
                center, scale = np.array([0, 0]), np.array([18, 18])
            elif env_name in ['AntMazeComplex2Way-v0']:                
                center, scale = np.array([0, 0]), np.array([6, 10])

            if obs.shape[-1]==self.goal_dim:
                pass
            elif obs.shape[-1]==self.goal_dim*2:
                center = np.tile(center, 2)
                scale = np.tile(scale, 2)
            else: # full obs
                center_g = np.tile(center, 2)
                scale_g = np.tile(scale, 2)

            if torch.is_tensor(obs):
                center = torch.from_numpy(center).to(self.device)
                scale = torch.from_numpy(scale).to(self.device)
                if center_g is not None or scale_g is not None:
                    center_g = torch.from_numpy(center_g).to(self.device)
                    scale_g = torch.from_numpy(scale_g).to(self.device)

            if obs.shape[-1]==self.goal_dim or obs.shape[-1]==self.goal_dim*2:
                obs = (obs-center)/scale            
            else: # full obs concatnated with ag, dg
                obs[..., :2] = (obs[..., :2]-center)/scale
                obs[..., -4:] = (obs[..., -4:]-center_g)/scale_g
            
            return obs
 
        elif env_name in ['sawyer_peg_push','sawyer_peg_pick_and_place']:
            if self.use_d2c:
                # normalization maybe not needed                
                assert not self.normalize_rl_obs, 'should not normalize rl obs as as scale is 0.05 for normlaize_d2c_obs'
                center, scale = 0, 0.05
            
            return obs
        
        else:
            raise NotImplementedError

        
        
    
    
    def sample_randomwalk_goals(self, obs, ag, episode, env, replay_buffer, num_candidate = 5, random_noise = 2.5, uncertainty_mode = 'f', dg = None):
        noise = np.random.uniform(low=-random_noise, high=random_noise, size=(num_candidate, env.goal_dim))

        if self.env_name in ['sawyer_peg_pick_and_place']:
            pass
        elif self.env_name in ['sawyer_peg_push']:
            noise[2] = 0
            
        candidate_goal = np.tile(ag, (num_candidate,1)) + noise
        
        if uncertainty_mode == 'd2c' and self.use_d2c:
            if self.d2c_gcrl:
                assert dg is not None
                classification_probabilities = self.get_prob_by_d2c(np.concatenate([candidate_goal, np.tile(dg, (candidate_goal.shape[0],1))], axis=-1))
            else:
                classification_probabilities = self.get_prob_by_d2c(candidate_goal)
            
            satisfied = False
            epsilon = 0
            iter = 0
            while not satisfied:                
                epsilon = 0.02*iter
                lb = 0.4-epsilon                
                if lb <0:
                    warnings.warn(f'd2c uncertainty threshold is out of range!!, lb : {lb} prob {classification_probabilities}')
                
                uncertain_indices = np.where(((classification_probabilities>=0.4-epsilon))==1)[0]
                if uncertain_indices.shape[0]==0:
                    # epsilon +=0.02 # due to machine zero
                    iter+=1
                else:
                    satisfied = True
        
            prob = F.softmax(torch.from_numpy(classification_probabilities[uncertain_indices]/self.d2c_temperature).float().to(self.device), dim = 0)
            dist = torch.distributions.Categorical(probs=prob)
            idxs = dist.sample((1,)).detach().cpu().numpy()
            

            obs = candidate_goal[uncertain_indices[idxs]]
            
        
        else:
            raise NotImplementedError
        
        return np.squeeze(obs)


    def update_critic(self, obs, action, reward, next_obs, discount, done, step):
        
        with torch.no_grad():
            dist = self.actor(next_obs)
            next_action = dist.rsample()            
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            
            log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
            target_V = torch.min(target_Q1,
                                target_Q2) - self.alpha.detach() * log_prob
    
            # target_Q = reward + (discount * target_V)
            if self.consider_done_true_in_critic:
                target_Q = reward + (discount * target_V)*(1-done)
            else:
                target_Q = reward + (discount * target_V)

        
        if self.q_clip:
            if self.rl_reward_type in ['d2c', 'sparse']:
                if self.d2c_reward_type=='positive':
                    target_Q = torch.clamp(target_Q, 0, 1/(1-self.discount))
                elif self.d2c_reward_type=='negative':
                    target_Q = torch.clamp(target_Q, -1/(1-self.discount), 0)
            else:
                raise NotImplementedError

        # get current Q estimates
        Q1, Q2 = self.critic(obs, action)
      
        critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)

        # optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        
        c_norm = None

        if self.grad_norm_clipping > 0.:
            c_norm = torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.grad_norm_clipping)
        if self.grad_value_clipping > 0.:            
            torch.nn.utils.clip_grad_value_(self.critic.parameters(), self.grad_value_clipping)

        self.critic_optimizer.step()
        

        return Q1, Q2, critic_loss, c_norm
    
    
    def update_actor_and_alpha(self, obs, step, goal_env = None, replay_buffer = None):
        # already normalize obs if normalize is true        
        dist = self.actor(obs)
        action = dist.rsample()
        D_KL = None
        start = time.time()
        
        actor_Q1, actor_Q2 = self.critic(obs, action)
        
        actor_Q = torch.min(actor_Q1, actor_Q2)
        
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()


        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()

        a_norm = None

        if self.grad_norm_clipping > 0.:
            a_norm = torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.grad_norm_clipping)            
        if self.grad_value_clipping > 0.:            
            torch.nn.utils.clip_grad_value_(self.actor.parameters(), self.grad_value_clipping)

        self.actor_optimizer.step()


        if self.alpha_auto:
            self.log_alpha_optimizer.zero_grad()
            alpha_loss = (self.alpha *
                        (-log_prob - self.target_entropy).detach()).mean()
            alpha_loss.backward()
            self.log_alpha_optimizer.step()
        else:
            # log_prob = None
            alpha_loss = None

        return actor_loss, alpha_loss, log_prob, a_norm, D_KL
    
    
    def update(self, replay_buffer, randomwalk_buffer, step, goal_env = None, uniform_goal_sampler = None):
        

        if step == self.num_seed_steps // 2:
            if self.use_d2c:
                for _ in range(self.d2c_num_init_update):
                    if self.d2c_use_randomwalk_buffer:
                        self.d2c_dict = self.update_d2c(uniform_goal_sampler, replay_buffer=replay_buffer, goal_env=goal_env, randomwalk_buffer=randomwalk_buffer)
                    else:
                        self.d2c_dict = self.update_d2c(uniform_goal_sampler, replay_buffer=replay_buffer, goal_env=goal_env)


        if step < self.num_seed_steps:
            return
        
        
        if self.use_d2c and step % self.d2c_train_every_k == 0:
            for _ in range(self.d2c_num_update):
                if self.d2c_use_randomwalk_buffer:
                    self.d2c_dict = self.update_d2c(uniform_goal_sampler, replay_buffer=replay_buffer, goal_env=goal_env, randomwalk_buffer=randomwalk_buffer)
                else:
                    self.d2c_dict = self.update_d2c(uniform_goal_sampler, replay_buffer=replay_buffer, goal_env=goal_env)
        
        
        if randomwalk_buffer is None or self.randomwalk_method == 'rand_action': 
            obs, action, extr_reward, next_obs, discount, dones = replay_buffer.sample(self.batch_size, self.discount)            
        else:
            obs, action, extr_reward, next_obs, discount, dones = utils.sample_mixed_buffer(replay_buffer, randomwalk_buffer, self.batch_size, self.discount)
        
    
        if self.rl_reward_type=='sparse':
            reward = extr_reward
        elif self.rl_reward_type=='d2c':
            # NOTE : Currently, relabeled obs, next_obs are used!!!!
            
            obs_dict, next_obs_dict = map(goal_env.convert_obs_to_dict, (obs, next_obs)) 
            
            
            prob = self.get_prob_by_d2c(torch.cat([next_obs_dict['achieved_goal'], next_obs_dict['desired_goal']], dim = -1))
            reward = torch.from_numpy(prob[:, None]).float().to(self.device) # [bs, 1]
            if self.d2c_reward_type=='negative': # [-1, 0]
                reward = reward - 1.0
            elif self.d2c_reward_type=='positive': # [0, 1]
                pass
            else:
                raise NotImplementedError
        
        
        # From here, RL related 
        if self.normalize_rl_obs:
            obs = self.normalize_obs(obs, self.env_name)
            next_obs = self.normalize_obs(next_obs, self.env_name)
        
        
        # decouple representation
        with torch.no_grad():
            obs = self.encoder.encode(obs)
            next_obs = self.encoder.encode(next_obs)

        Q1, Q2, critic_loss, c_norm = self.update_critic(obs, action, reward, next_obs, discount, dones, step)
        

        if step % self.actor_update_frequency == 0:
            self.actor_loss, self.alpha_loss, self.actor_log_prob, self.a_norm, self.D_KL= self.update_actor_and_alpha(obs, step, goal_env=goal_env, replay_buffer=replay_buffer)
            
        if step % self.critic_target_update_frequency == 0:            
            utils.soft_update_params(self.critic, self.critic_target,
                                    self.critic_target_tau)
            

        # logging
        logging_dict = dict(q1=Q1.detach().cpu().numpy().mean(),
                            q2=Q2.detach().cpu().numpy().mean(),
                            critic_loss=critic_loss.detach().cpu().numpy(),
                            actor_loss = self.actor_loss.detach().cpu().numpy(),                            
                            batch_reward_mean = reward.detach().cpu().numpy().mean(),                            
                            )
        
        
        logging_dict.update(dict(
                                bacth_actor_log_prob = self.actor_log_prob.detach().cpu().numpy().mean(),
                                alpha = self.alpha.detach().cpu().numpy(),
                                entropy_diff = (-self.actor_log_prob-self.target_entropy).detach().cpu().numpy().mean(),
                                ))
        if self.alpha_auto:
            logging_dict.update(dict(alpha_loss = self.alpha_loss.detach().cpu().numpy()))

        if c_norm is not None:
            logging_dict.update(dict(critic_grad_norm=c_norm.detach().cpu().numpy().mean()))
        if self.a_norm is not None:
            logging_dict.update(dict(actor_grad_norm=self.a_norm.detach().cpu().numpy().mean()))
        
        
        if self.use_d2c:
            if self.d2c_dict is not None:
                for key, val in self.d2c_dict.items():
                    logging_dict.update({'d2c_'+key : val})

        return logging_dict

        
        
        