import os
import torch
import torch.nn.functional as F
from torch.optim import Adam
from utils import soft_update, hard_update
import numpy as np
from sysmodel_kfc import MLP_Koopman, InverseDynamicsModel, SigmaModel
import random


# Default compute device for this module: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==========================================================================================
# NEW: Policy Model Definition for Behavior Cloning
# ==========================================================================================
class PolicyModel(torch.nn.Module):
    """MLP policy trained by behavior cloning.

    The network maps a state vector to either a continuous action (squashed
    by a final ``tanh``) or, for discrete action spaces, to unnormalized
    logits over the ``n`` actions.

    Args:
        env: Object exposing ``observation_space`` and ``action_space``. A
            space with a non-empty ``shape`` is treated as vector-valued
            (its first dimension is used); otherwise its ``n`` attribute is
            used and the space is treated as discrete.
        hidden_dim: Width of the three hidden layers.
    """

    def __init__(self, env, hidden_dim=256):
        super(PolicyModel, self).__init__()
        self.activation = torch.nn.Tanh()

        # State / action dimensions of the environment.
        self.state_dim, _ = self._space_dim(env.observation_space)
        self.action_dim, self.is_continuous = self._space_dim(env.action_space)

        # Policy network. ``torch.nn`` is spelled out because the module
        # only imports ``torch`` and ``torch.nn.functional``.
        self.policy_net = torch.nn.Sequential(
            torch.nn.Linear(self.state_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, self.action_dim)
        )

        # Continuous actions are squashed with a trailing tanh.
        if self.is_continuous:
            self.policy_net.add_module('tanh', torch.nn.Tanh())

    @staticmethod
    def _space_dim(space):
        """Return ``(dimension, is_vector_space)`` for a Gym-style space.

        Discrete spaces also carry a ``shape`` attribute (an empty tuple),
        so the presence of ``shape`` alone cannot tell the two kinds apart;
        only a non-empty shape counts as vector-valued.
        """
        shape = getattr(space, 'shape', None)
        if shape is not None and len(shape) > 0:
            return int(shape[0]), True
        return int(space.n), False

    def forward(self, state):
        """Return the raw network output for ``state``."""
        return self.policy_net(state)

    def get_action(self, state):
        """Return the action for ``state`` without tracking gradients.

        A 1-D ``state`` is promoted to a batch of one. For discrete action
        spaces the logits are converted to probabilities with a softmax.
        """
        with torch.no_grad():
            if len(state.shape) == 1:
                state = state.unsqueeze(0)
            action = self.forward(state)
            if not self.is_continuous:
                action = torch.softmax(action, dim=-1)
        return action



class KFC_QL(object):
    def __init__(self, env, args, device=device):
        """Store hyper-parameters and build the models and their optimizers.

        Args:
            env: Environment exposing ``observation_space`` (with ``shape``,
                ``high`` and ``low``) and ``action_space``.
            args: Hyper-parameter namespace; every attribute read from it is
                visible in the assignments below.
            device: Torch device the models are moved to.

        NOTE(review): ``select_action``, ``update_parameters``, ``save`` and
        ``load`` read attributes such as ``self.policy``, ``self.critic``,
        ``self.critic_target`` and their optimizers, none of which are
        assigned in this constructor -- confirm where they are created.
        """
        # Looked up as in the original constructor; not used afterwards.
        obs_dim = env.observation_space.shape[0]
        act_space = env.action_space

        # --- RL hyper-parameters -------------------------------------------
        self.gamma = args.gamma
        self.gamma2 = args.gamma2
        self.tau = args.tau
        self.alpha = args.alpha
        self.latent_dim = args.latent_dim

        # --- Noise / symmetry-shift scales ---------------------------------
        self.noise_sigma = 3e-3   # value taken from the S4RL paper (per the original comment)
        self.noise_sigma2 = 6e-2  # original comment listed 6e-3 as an alternative
        self.shift_sigma = args.shift_sigma
        self.shift_sigma2 = self.shift_sigma * self.noise_sigma

        # --- CQL / policy settings -----------------------------------------
        self.num_random = args.num_random
        self.symmetry_type = args.symmetry_type
        self.policy_eval_start = args.policy_eval_start
        self.temp = args.temp
        self.min_q_weight = args.min_q_weight
        self.policy_type = args.policy_type
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = device

        # --- KFC (Koopman) settings ----------------------------------------
        self.policy_forwards = args.policy_forwards
        self.koopman_augmentation = args.koopman_augmentation
        self.koopman_probability = args.koopman_probability

        # Models are created in the same order as before so that parameter
        # initialization consumes the random stream identically.
        hidden = args.sys_hidden_size
        self.sysmodel = MLP_Koopman(
            env, latent_dim=self.latent_dim, hidden_dim=hidden, device=device
        ).to(device=self.device)
        self.inverse_model = InverseDynamicsModel(
            env, self.latent_dim, hidden_dim=hidden
        ).to(self.device)
        self.sigma_model = SigmaModel(self.latent_dim).to(self.device)
        self.policy_model = PolicyModel(env, hidden_dim=hidden).to(self.device)

        # Optimizers: the system and inverse models share one learning rate,
        # the sigma and policy models take theirs from ``args``.
        self.variable_lr = 3e-4
        self.sysmodel_optimizer = Adam(self.sysmodel.parameters(), lr=self.variable_lr)
        self.inverse_model_optimizer = Adam(self.inverse_model.parameters(), lr=self.variable_lr)
        self.sigma_model_optimizer = Adam(self.sigma_model.parameters(), lr=args.sigma_lr)
        self.sigma_tau = args.sigma_tau
        self.policy_optimizer = Adam(self.policy_model.parameters(), lr=args.policy_lr)

        # Scalar clamp bounds: only the first observation dimension's limits are read.
        self.obs_upper_bound = float(env.observation_space.high[0])
        self.obs_lower_bound = float(env.observation_space.low[0])
        self.reward_lower_bound = 0
        self.reward_upper_bound = 0
        
    
    def noise(self,state):
        #state = torch.FloatTensor(state).to(self.device) 
  
        noise =  torch.normal(0,self.noise_sigma2, size=state.shape).to(self.device)
        state_noise = state + noise
        state_shift = state_noise
            
        return state_shift.clamp(self.obs_lower_bound ,self.obs_upper_bound)
 

    def eval_model( self, dataloader):
        running_loss = 0.0
        count = 0
        for i, batch in enumerate(dataloader, 0):
            state = torch.FloatTensor(batch['observations']).to(self.device)
            action = torch.FloatTensor(batch['actions']).to(self.device)
            next_state = torch.FloatTensor(batch['next_observations']).to(self.device)

            #predict the next state
            predict_next_state = self.sysmodel(state, action)

            #define the loss; constraint on model
            sysmodel_loss = F.smooth_l1_loss(predict_next_state, next_state)

            running_loss += sysmodel_loss.item()
            count +=1

        epoch_avg_loss = running_loss/count
        return epoch_avg_loss


    def eval_model_VAE(self,dataloader):
        running_loss = 0.0
        count = 0
        for i, batch in enumerate(dataloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            state = torch.FloatTensor(batch['observations']).to(self.device)
            action = torch.FloatTensor(batch['actions']).to(self.device)
            next_state = torch.FloatTensor(batch['next_observations']).to(self.device)

            ##Train VAE --> autoencoder
            states = torch.cat([state,next_state],dim=0)
            states = self.noise(states)
            predict_same_state = self.sysmodel(states, action,choice=False)
            sysmodel_loss_VAE = F.smooth_l1_loss(predict_same_state, states)


            running_loss +=   sysmodel_loss_VAE.item()
            count +=1

        epoch_avg_loss = running_loss/count
        return epoch_avg_loss
         
    
    def train_sysmodel(self,batch,epoch):
        state_batch = torch.FloatTensor(batch['observations']).to(self.device)
        action_batch = torch.FloatTensor(batch['actions']).to(self.device)
        next_state_batch = torch.FloatTensor(batch['next_observations']).to(self.device)
        
        #predict the next state
        predict_next_state_batch = self.sysmodel(state_batch, action_batch)
        
        #define the loss; constraint on model
        sysmodel_loss = F.smooth_l1_loss(predict_next_state_batch, next_state_batch)  
        
       
        #Train VAE --> autoencoder
        states_batch = torch.cat([state_batch,next_state_batch],dim=0)
        states_batch = self.noise(states_batch)
        predict_same_state_batch = self.sysmodel(states_batch, action_batch,choice=False)
        
        sysmodel_loss_VAE = F.smooth_l1_loss(predict_same_state_batch, states_batch)
        sysmodel_loss += 10*sysmodel_loss_VAE
        
        #Model compute gradient, back-prop, perform step
        
        
        self.sysmodel_optimizer.zero_grad()
        sysmodel_loss.backward()
        self.sysmodel_optimizer.step()    
        
        
        running_loss = sysmodel_loss.item() - 10*sysmodel_loss_VAE.item() #1
        running_loss_VAE = sysmodel_loss_VAE.item()
        return running_loss , running_loss_VAE      

    def train_inverse_model(self, batch):
        """
        Trains the inverse dynamics model.
        The goal is to predict the action `a_t` given two consecutive latent states `(z_t, z_{t+1})`.
        """
        # 1. Unpack data from the batch
        state_batch = torch.FloatTensor(batch['observations']).to(self.device)
        action_batch = torch.FloatTensor(batch['actions']).to(self.device)
        next_state_batch = torch.FloatTensor(batch['next_observations']).to(self.device)
        
        # 2. Encode states into the latent space using the pre-trained sysmodel encoder.
        # We use torch.no_grad() because we don't want to update the sysmodel here.
        with torch.no_grad():
            # The sysmodel's forward pass with choice=False should be the encoder
            # NOTE: Please verify that `self.sysmodel(state, action, choice=False)` correctly calls the encoder.
            # If not, you might need a dedicated `self.sysmodel.encode(state)` method.
            # Assuming `self.sysmodel(state, action, choice=False)` is the VAE/Encoder part:
            latent_z = self.sysmodel.Encoder_obs(state_batch) 
            next_latent_z = self.sysmodel.Encoder_obs(next_state_batch)

        # 3. Predict the action using the inverse model
        predicted_action = self.inverse_model(latent_z, next_latent_z)
        
        # 4. Calculate the loss (Mean Squared Error between predicted and actual actions)
        inverse_loss = F.mse_loss(predicted_action, action_batch)
        
        # 5. Perform the optimization step
        self.inverse_model_optimizer.zero_grad()
        inverse_loss.backward()
        self.inverse_model_optimizer.step()
        
        return inverse_loss.item()

    def eval_inverse_model(self, dataloader):
        """
        Evaluates the inverse dynamics model on a test/validation dataloader.
        Calculates the loss without performing any gradient updates.
        """
        self.inverse_model.eval()  # Set the model to evaluation mode
        total_loss = 0.0
        count = 0
        
        with torch.no_grad():  # Crucial: disable gradient calculation
            for batch in dataloader:
                # 1. Unpack data
                state_batch = torch.FloatTensor(batch['observations']).to(self.device)
                action_batch = torch.FloatTensor(batch['actions']).to(self.device)
                next_state_batch = torch.FloatTensor(batch['next_observations']).to(self.device)
                
                # 2. Encode states (same as in training)
                # !! 再次确认这里的编码器调用方式是正确的 !!
                latent_z = self.sysmodel.Encoder_obs(state_batch) 
                next_latent_z = self.sysmodel.Encoder_obs(next_state_batch)
                
                # 3. Predict action
                predicted_action = self.inverse_model(latent_z, next_latent_z)
                
                # 4. Calculate loss
                loss = F.mse_loss(predicted_action, action_batch)
                
                total_loss += loss.item()
                count += 1
                
        self.inverse_model.train()  # Set the model back to training mode
        
        avg_loss = total_loss / count
        return avg_loss

    # The methods below are part of class KFC_QL (file KFC_prime.py, per the
    # original note) and follow eval_inverse_model.

# =====================================================================
# NEW: Training and Evaluation functions for the Sigma Model
# =====================================================================
    def train_sigma_model(self, batch):
        """
        Trains the Sigma model based on the custom loss function.
        loss = w * (K * sigma(z_t) - sigma(z_{t+1}))
        w = exp(tau * ||z_{t+1} - Kz_t||^2)
        """
        self.sigma_model.train()

        # 1. Unpack data
        state_batch = torch.FloatTensor(batch['observations']).to(self.device)
        # action_batch = torch.FloatTensor(batch['actions']).to(self.device)
        next_state_batch = torch.FloatTensor(batch['next_observations']).to(self.device)

        # 2. Get latent states and Koopman operator K (without gradients for sysmodel)
        with torch.no_grad():
            # !! VERIFY ENCODER CALL !!
            z_t = self.sysmodel.Encoder_obs(state_batch)
            z_t1 = self.sysmodel.Encoder_obs(next_state_batch)
            
            # !! VERIFY K ACCESS !! This is a critical assumption.
            # Replace with the correct way to access your Koopman matrix if this is wrong.
            K = self.sysmodel.layerK.weight.data
            
            # Predict next latent state using Koopman operator
            Kz_t = z_t @ K.T

        # 3. Calculate the weight 'w'
        # The loss 'w' depends on the Koopman model's prediction error.
        # We use squared L2 norm for the error distance.
        error = z_t1 - Kz_t
        # Weight `w` will have shape (batch_size, 1)
        w = torch.exp(self.sigma_tau * torch.sum(error**2, dim=1, keepdim=True)).detach()

        # 4. Apply the Sigma model
        sigma_zt = self.sigma_model(z_t)
        sigma_zt1 = self.sigma_model(z_t1)

        # 5. Calculate the core part of the loss
        K_sigma_zt = sigma_zt @ K.T
        sigma_loss_vec = K_sigma_zt - sigma_zt1
        
        # 6. Apply the weight and compute final loss
        # The weighted error has shape (batch_size, latent_dim)
        weighted_loss_vec = w * sigma_loss_vec
        # We take the mean squared error of this weighted vector
        loss = F.mse_loss(weighted_loss_vec, torch.zeros_like(weighted_loss_vec))
        
        # 7. Optimization step
        self.sigma_model_optimizer.zero_grad()
        loss.backward()
        self.sigma_model_optimizer.step()
        
        return loss.item()

    def eval_sigma_model(self, dataloader):
        """
        Evaluates the Sigma model on a test/validation dataloader.
        """
        self.sigma_model.eval()
        total_loss = 0.0
        count = 0
        
        with torch.no_grad():
            for batch in dataloader:
                state_batch = torch.FloatTensor(batch['observations']).to(self.device)
                action_batch = torch.FloatTensor(batch['actions']).to(self.device)
                next_state_batch = torch.FloatTensor(batch['next_observations']).to(self.device)

                z_t = self.sysmodel.Encoder_obs(state_batch)
                z_t1 = self.sysmodel.Encoder_obs(next_state_batch)
                K = self.sysmodel.layerK.weight.data
                Kz_t = z_t @ K.T
                
                error = z_t1 - Kz_t
                w = torch.exp(self.sigma_tau * torch.sum(error**2, dim=1, keepdim=True))

                sigma_zt = self.sigma_model(z_t)
                sigma_zt1 = self.sigma_model(z_t1)
                
                K_sigma_zt = sigma_zt @ K.T
                sigma_loss_vec = K_sigma_zt - sigma_zt1
                
                weighted_loss_vec = w * sigma_loss_vec
                loss = F.mse_loss(weighted_loss_vec, torch.zeros_like(weighted_loss_vec))

                total_loss += loss.item()
                count += 1
                
        avg_loss = total_loss / count
        return avg_loss

    # =========================================================================================
    # NEW: Data Augmentation using trained models
    # =========================================================================================
    def augment_data(self, batch):
        """
        使用训练好的模型对数据进行扩充：
        1. 编码状态到潜在空间
        2. 应用sigma变换
        3. 解码回状态空间
        4. 使用逆动力学模型预测对应动作
        """
        augmented_data = {'observations': [], 'actions': [], 'next_observations': []}
        
        # 设置所有模型为评估模式
        self.sysmodel.eval()
        self.sigma_model.eval()
        self.inverse_model.eval()
        
        with torch.no_grad():
            # 1. 获取原始数据
            state_batch = torch.FloatTensor(batch['observations']).to(self.device)
            action_batch = torch.FloatTensor(batch['actions']).to(self.device)
            next_state_batch = torch.FloatTensor(batch['next_observations']).to(self.device)
            
            # 2. 编码到潜在空间
            z_t = self.sysmodel.Encoder_obs(state_batch)  # 编码当前状态
            z_next = self.sysmodel.Encoder_obs(next_state_batch)  # 编码下一状态
            
            # 3. 应用sigma变换
            sigma_z_t = self.sigma_model(z_t)
            sigma_z_next = self.sigma_model(z_next)
            
            # 4. 解码回状态空间
            # 注意：这里需要根据您的sysmodel结构调整解码方式
            # 假设sysmodel有一个decoder部分，您需要根据实际情况修改
            augmented_states = self.sysmodel.Decoder_obs(sigma_z_t)  # 需要根据实际sysmodel结构调整
            augmented_next_states = self.sysmodel.Decoder_obs(sigma_z_next)  # 需要根据实际sysmodel结构调整
            
            # 5. 使用逆动力学模型预测对应的动作
            augmented_actions = self.inverse_model(augmented_states, augmented_next_states)
            
            # 6. 转换为numpy数组并添加到扩充数据中
            augmented_data['observations'] = augmented_states.cpu().numpy()
            augmented_data['actions'] = augmented_actions.cpu().numpy()
            augmented_data['next_observations'] = augmented_next_states.cpu().numpy()
        

        
        return augmented_data

    # =========================================================================================
    # NEW: Behavior Cloning Training
    # =========================================================================================
    def train_policy_bc(self, batch):
        """
        使用行为克隆损失训练策略模型
        BC Loss = MSE(policy(state), action) for continuous actions
                = CrossEntropy(policy(state), action) for discrete actions
        """
        self.policy_model.train()
        
        # 1. 准备数据
        state_batch = torch.FloatTensor(batch['observations']).to(self.device)
        action_batch = torch.FloatTensor(batch['actions']).to(self.device)
        
        # 2. 策略网络预测
        predicted_actions = self.policy_model(state_batch)
        
        # 3. 计算BC损失
        if self.policy_model.is_continuous:
            # 连续动作空间：使用MSE损失
            bc_loss = F.mse_loss(predicted_actions, action_batch)
        else:
            # 离散动作空间：使用交叉熵损失
            action_targets = action_batch.long().squeeze(-1)
            bc_loss = F.cross_entropy(predicted_actions, action_targets)
        
        # 4. 优化步骤
        self.policy_optimizer.zero_grad()
        bc_loss.backward()
        self.policy_optimizer.step()
        
        return bc_loss.item()

    # =========================================================================================
    # NEW: Policy Evaluation
    # =========================================================================================
    def eval_policy_bc(self, dataloader):
        """
        评估策略模型的BC损失
        """
        self.policy_model.eval()
        total_loss = 0.0
        count = 0
        
        with torch.no_grad():
            for batch in dataloader:
                state_batch = torch.FloatTensor(batch['observations']).to(self.device)
                action_batch = torch.FloatTensor(batch['actions']).to(self.device)
                
                predicted_actions = self.policy_model(state_batch)
                
                if self.policy_model.is_continuous:
                    loss = F.mse_loss(predicted_actions, action_batch)
                else:
                    action_targets = action_batch.long().squeeze(-1)
                    loss = F.cross_entropy(predicted_actions, action_targets)
                
                total_loss += loss.item()
                count += 1
        
        self.policy_model.train()
        avg_loss = total_loss / count
        return avg_loss

    # =========================================================================================
    # NEW: Combined training with augmented data
    # =========================================================================================
    def train_policy_with_augmentation(self, original_batch):
        """Train the policy on a batch and on its augmented counterpart.

        One behavior-cloning step is taken on ``original_batch``, then the
        batch is augmented with ``augment_data`` and a second step is taken
        on the augmented data.

        Returns:
            ``(mean_loss, original_loss, augmented_loss)`` where ``mean_loss``
            is the arithmetic mean of the other two.
        """
        loss_on_original = self.train_policy_bc(original_batch)
        loss_on_augmented = self.train_policy_bc(self.augment_data(original_batch))

        mean_loss = (loss_on_original + loss_on_augmented) / 2.0
        return mean_loss, loss_on_original, loss_on_augmented


    def state_noise(self,state,next_state,batch):
        #state = torch.FloatTensor(state).to(self.device)
        rand = random.uniform(0, 1)
        if self.koopman_augmentation and rand <= self.koopman_probability: #generate dynamical symmetry shift of state 
            
            
            if self.symmetry_type == "Sylvester":
                sym_gen = torch.FloatTensor(np.array(batch['symmetries'])).to(self.device)
                symmetry_scaling  = torch.normal(0,self.shift_sigma,(state.shape[0],1)).to(self.device)
                state_shift = self.sysmodel.Symmetry_Encoder_Decoder(state,sym_gen,symmetry_scaling)
                next_state_shift = self.sysmodel.Symmetry_Encoder_Decoder(next_state,sym_gen,symmetry_scaling)
                
            elif self.symmetry_type == "Eigenspace":
                sym_gen = torch.tensor(np.array(batch['symmetries']))
                symmetry_scaling  = np.random.normal(0,self.shift_sigma2,(state.shape[0],sym_gen.shape[-1]))
                symmetry_scaling  = torch.FloatTensor(np.apply_along_axis(np.diag,1,symmetry_scaling))
                symmetry_scaling  = torch.complex(symmetry_scaling,torch.FloatTensor(torch.zeros(symmetry_scaling.shape)))

                state_shift = self.sysmodel.Symmetry_Encoder_Decoder_Eigenspace(state,sym_gen,symmetry_scaling)
                next_state_shift = self.sysmodel.Symmetry_Encoder_Decoder_Eigenspace(next_state,sym_gen,symmetry_scaling)
            
            
        else:
            noise =  torch.normal(0,self.noise_sigma, size=state.shape).to(self.device)
            state_shift = state + noise
            
            noise2 =  torch.normal(0,self.noise_sigma, size=state.shape).to(self.device)
            next_state_shift = next_state + noise2
            
        return state_shift.clamp(self.obs_lower_bound ,self.obs_upper_bound), next_state_shift.clamp(self.obs_lower_bound ,self.obs_upper_bound)
    

            
        
    def select_action(self, state, evaluate=False):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if evaluate is False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]
    
    
    
    def _get_tensor_values(self, obs, actions, network=None):
        action_shape = actions.shape[0]
        obs_shape = obs.shape[0]
        num_repeat = int (action_shape / obs_shape)
        obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(obs.shape[0] * num_repeat, obs.shape[1])
        preds1 , preds2 = network(obs_temp, actions)
        preds1 = preds1.view(obs.shape[0], num_repeat, 1)
        preds2 = preds2.view(obs.shape[0], num_repeat, 1)
        return preds1, preds2

    def _get_policy_actions(self, obs, num_actions, network=None):
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        new_obs_actions,new_obs_log_pi,_ = network.sample(obs_temp)

        return new_obs_actions, new_obs_log_pi.view(obs.shape[0], num_actions, 1)


    def update_parameters(self, batch, updates):
        # Sample a batch from memory
        state_batch = torch.FloatTensor(batch['observations']).to(self.device)
        action_batch = torch.FloatTensor(batch['actions']).to(self.device)
        next_state_batch = torch.FloatTensor(batch['next_observations']).to(self.device)
        reward_batch =  torch.FloatTensor(batch['rewards']).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(batch['terminals'].numpy()).to(self.device).unsqueeze(1)
        
        
        
        with torch.no_grad():
            #Symmetry add noise or koopman symetry based augmentation
            state_shift_batch, next_state_shift_batch = self.state_noise(state_batch,next_state_batch,batch)
            #SAC 
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_shift_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)
            
        
        
            
            
        qf1_pred, qf2_pred = self.critic(state_shift_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(qf1_pred, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(qf2_pred, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        
        ## add CQL
        random_actions_tensor = torch.FloatTensor(qf2_pred.shape[0] * self.num_random, action_batch.shape[-1]).uniform_(-1, 1).to(self.device)
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(state_batch, num_actions=self.num_random, network=self.policy)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(next_state_batch, num_actions=self.num_random, network=self.policy)
        q1_rand , q2_rand = self._get_tensor_values(state_batch, random_actions_tensor, network=self.critic)
        q1_curr_actions , q2_curr_actions = self._get_tensor_values(state_batch, curr_actions_tensor, network=self.critic)
        q1_next_actions , q2_next_actions = self._get_tensor_values(state_batch, new_curr_actions_tensor, network=self.critic)
       
        cat_q1 = torch.cat(
            [q1_rand, qf1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1
        )
        cat_q2 = torch.cat(
            [q2_rand, qf2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1
        )
        std_q1 = torch.std(cat_q1, dim=1)
        std_q2 = torch.std(cat_q2, dim=1)

   
        # importance sammpled version
        random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])
        cat_q1 = torch.cat(
                [q1_rand - random_density, q1_next_actions - new_log_pis.detach(), q1_curr_actions - curr_log_pis.detach()], 1
            )
        cat_q2 = torch.cat(
                [q2_rand - random_density, q2_next_actions - new_log_pis.detach(), q2_curr_actions - curr_log_pis.detach()], 1
            )
            
        min_qf1_loss = torch.logsumexp(cat_q1 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp
        min_qf2_loss = torch.logsumexp(cat_q2 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp
                    
        """Subtract the log likelihood of data"""
        min_qf1_loss = min_qf1_loss - qf1_pred.mean() * self.min_q_weight
        min_qf2_loss = min_qf2_loss - qf2_pred.mean() * self.min_q_weight
        
        if self.with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
            min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap)
            min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap)

            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss = (-min_qf1_loss - min_qf2_loss)*0.5 
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()
            

        qf1_loss = qf1_loss + min_qf1_loss
        qf2_loss = qf2_loss + min_qf2_loss

        qf_loss = qf1_loss + qf2_loss

        """
        Update critic
        """
        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()


        
        #Policy Loss
        pi, log_pi, _ = self.policy.sample(state_batch)
        
         #entropy tuning for alpha
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp()
            #alpha_tlogs = self.alpha.clone() # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            #alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs


        #Policy Loss...
        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

        if updates < self.policy_eval_start:
            """
            For the initial few epochs, try doing behaivoral cloning, if needed
            conventionally, there's not much difference in performance with having 20k 
            gradient steps here, or not having it
            """
            policy_log_prob = self.policy.log_prob(state_batch, action_batch)
            policy_loss = (self.alpha * log_pi - policy_log_prob).mean()
        #Forward looking Q-function policy update
        elif updates > 12*self.policy_eval_start and self.policy_forwards > 0:
            state_batch_Fwd = next_state_batch
            pi_Fwd, log_pi_Fwd,_ = self.policy.sample(state_batch_Fwd)
            for _ in range(0,self.policy_forwards):
            
                next_state_batch_Fwd = self.sysmodel(state_batch_Fwd, pi_Fwd)
            
                pi_Fwd, log_pi_Fwd, _ = self.policy.sample(next_state_batch_Fwd)
            
                qf1_pi_Fwd, qf2_pi_Fwd = self.critic(next_state_batch_Fwd, pi_Fwd)
                min_qf_pi = torch.min(qf1_pi, qf2_pi)
                policy_loss += self.gamma2*(-min_qf_pi.mean()) 
            
                state_batch_Fwd = next_state_batch_Fwd
        """
        Update policy
        """
        
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        
        
       
        #soft updates 
        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item()

    # Save model parameters

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optim.state_dict(), filename + "_critic_optimizer")

        torch.save(self.policy.state_dict(), filename + "_actor")
        torch.save(self.policy_optim.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic.pth"))
        self.critic_optim.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.critic_target = copy.deepcopy(self.critic)

        self.policy.load_state_dict(torch.load(filename + "_actor.pth"))
        self.policy_optim.load_state_dict(torch.load(filename + "_actor_optimizer"))
