import time
from tqdm import tqdm
from copy import deepcopy
from collections import defaultdict
import numpy as np
from rdkit import Chem
from pathlib import Path

import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, lr_scheduler
import torch.nn.functional as F

from torch_geometric.data import Batch
import dgl

from model.ac import GCNActorCritic
from utils_sac.utils import get_att_points, delete_multiple_element
from utils_ga.ga import reproduce
from utils_fgib.utils import get_load_model, get_sanitize_error_frags
from utils_fgib.data import get_graph

import pdb
from deprecated import deprecated
from timeout_decorator import timeout


class ReplayBuffer:
    """Fixed-capacity experience replay for the fragment-based molecular SAC agent.

    Graph observations and per-step tensors are kept in Python lists, while actions,
    rewards and done flags live in pre-allocated NumPy arrays of length ``size``.
    Index ``i`` of every list and every array always refers to the same transition:
    while the buffer is filling, transitions are appended at index ``self.size``
    (== ``self.ptr``); once full, the slot at ``self.ptr`` is overwritten in place in
    the lists *and* the arrays, so they never drift out of step.
    """

    def __init__(self, size, reward_dim):
        self.obs_buf = []                                   # o
        self.obs2_buf = []                                  # o2
        self.act_buf = np.zeros((size, 3), dtype=np.int32)  # ac
        # (size, reward_dim) for multi-objective rewards, (size,) for a single objective
        self.rew_buf = np.zeros((size, reward_dim), dtype=np.float32) \
            if reward_dim > 1 else np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)    # d

        self.ac_prob_buf = []
        self.log_ac_prob_buf = []

        self.ac_first_buf = []
        self.ac_second_buf = []
        self.ac_third_buf = []
        self.o_embeds_buf = []

        self.ptr, self.size, self.max_size = 0, 0, size
        # Slots of terminal transitions whose task reward is still pending (see rew_store).
        self.done_location = []

    @staticmethod
    def _detach(x):
        """Return ``x.detach()`` for tensors (drops the autograd graph), else ``x``."""
        return x.detach() if hasattr(x, 'detach') else x

    def _list_buffers(self):
        """Return every list-backed buffer, in a fixed order."""
        return (self.obs_buf, self.obs2_buf,
                self.ac_prob_buf, self.log_ac_prob_buf,
                self.ac_first_buf, self.ac_second_buf, self.ac_third_buf,
                self.o_embeds_buf)

    @staticmethod
    def _compact(arr, ptrs):
        """Delete rows ``ptrs`` from ``arr`` and zero-pad at the end to keep its length.

        The arrays are written at ``ptr`` up to ``max_size - 1``; shrinking them
        (as a bare ``np.delete`` would) makes later writes go out of bounds.
        """
        kept = np.delete(arr, ptrs, axis=0)
        pad = np.zeros((arr.shape[0] - kept.shape[0],) + arr.shape[1:], dtype=arr.dtype)
        return np.concatenate([kept, pad], axis=0)

    def store(self, obs, act, rew, next_obs, done, ac_prob, log_ac_prob,
              ac_first_prob, ac_second_hot, ac_third_prob, o_embeds):
        """Store one transition; (obs, act, rew, next_obs) is (s_t, a_t, r_t, s_{t+1}) in SAC.

        Tensors are detached before storage so the buffer does not keep autograd
        graphs alive. ``rew`` may be a scalar/1-element sequence for a single-objective
        buffer, or a ``reward_dim``-long sequence for a multi-objective one.
        """
        detach = self._detach
        if isinstance(o_embeds, (tuple, list)):
            o_embeds = tuple(detach(embed) for embed in o_embeds)
        else:
            o_embeds = detach(o_embeds)

        items = (obs, next_obs,
                 detach(ac_prob), detach(log_ac_prob),
                 detach(ac_first_prob), detach(ac_second_hot), detach(ac_third_prob),
                 o_embeds)

        if self.size < self.max_size:
            # Filling phase: ptr == size == len(list), so appending keeps slots aligned.
            for buf, item in zip(self._list_buffers(), items):
                buf.append(item)
        else:
            # Full: overwrite the oldest slot in place, exactly like the arrays below.
            for buf, item in zip(self._list_buffers(), items):
                buf[self.ptr] = item

        self.act_buf[self.ptr] = act
        # ==== MO modification ====
        # Row assignment works for both layouts of rew_buf; a 1-D (single objective)
        # buffer takes exactly one value per slot.
        rew = np.asarray(rew)
        if self.rew_buf.ndim == 1:
            rew = rew.reshape(())
        self.rew_buf[self.ptr] = rew
        # =========================

        self.done_buf[self.ptr] = done

        if done:
            self.done_location.append(self.ptr)
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def rew_store(self, rew):
        """Add the task (multi-objective) rewards to the pending terminal transitions.

        The reward stored at step time is only an encouraging reward for growing
        molecules, not the actual task score. ``rew`` holds one row per entry of
        ``self.done_location`` (in order); non-zero rows are added onto the stored
        reward of their slot, and transitions whose row is all zeros are removed
        from the buffer entirely.

        Args:
            rew: array of shape (n_done, reward_dim) (or (n_done,) for one objective).
        """
        # int dtype even when empty, so it is a valid index array for np.delete
        done_location_np = np.asarray(self.done_location, dtype=np.int64)
        rew = np.asarray(rew)
        # ===== MO zero judge: select those row indexes with all zero  ===
        rows = rew[:, None] if rew.ndim == 1 else rew
        zero_rows = np.all(rows == 0, axis=1)
        zeros = np.flatnonzero(zero_rows)
        nonzeros = np.flatnonzero(~zero_rows)
        # ================================================================

        zero_ptrs = done_location_np[zeros]
        keep_ptrs = done_location_np[nonzeros]
        rew = rew[nonzeros]

        if len(self.done_location) > 0:
            # =========== MO Modification: 2D array assign ======
            self.rew_buf[keep_ptrs] += rew.reshape((len(keep_ptrs),) + self.rew_buf.shape[1:])
            # ===================================================
            self.done_location = []

        n_removed = len(zero_ptrs)
        if n_removed == 0:
            return

        self.act_buf = self._compact(self.act_buf, zero_ptrs)
        self.rew_buf = self._compact(self.rew_buf, zero_ptrs)
        self.done_buf = self._compact(self.done_buf, zero_ptrs)
        for buf in self._list_buffers():
            # fresh index list per call in case the helper mutates its argument
            delete_multiple_element(buf, zero_ptrs.tolist())

        self.size = min(self.size - n_removed, self.max_size)
        # Resume appending right after the last live transition (ptr == size == len(list)).
        self.ptr = self.size

    def sample_batch(self, device, batch_size=32):
        """Sample ``batch_size`` transitions uniformly (with replacement).

        Returns a dict: ``obs``/``obs2``/``ac_prob``/``log_ac_prob`` are lists;
        ``ac_first``/``ac_second``/``ac_third``/``o_g_emb`` are stacked tensors with
        dim 1 squeezed; ``act`` (float, trailing unit dim), ``rew`` and ``done`` are
        tensors on ``device``. ``o_g_emb`` is element 2 of each stored ``o_embeds``
        and is not moved to ``device``.
        """
        idxs = np.random.randint(0, self.size, size=batch_size)
        obs_batch = [self.obs_buf[idx] for idx in idxs]
        obs2_batch = [self.obs2_buf[idx] for idx in idxs]

        ac_prob_batch = [self.ac_prob_buf[idx] for idx in idxs]
        log_ac_prob_batch = [self.log_ac_prob_buf[idx] for idx in idxs]

        ac_first_batch = torch.stack([self.ac_first_buf[idx].to(device) for idx in idxs]).squeeze(1)
        ac_second_batch = torch.stack([self.ac_second_buf[idx].to(device) for idx in idxs]).squeeze(1)
        ac_third_batch = torch.stack([self.ac_third_buf[idx].to(device) for idx in idxs]).squeeze(1)
        o_g_emb_batch = torch.stack([self.o_embeds_buf[idx][2] for idx in idxs]).squeeze(1)

        act_batch = torch.as_tensor(self.act_buf[idxs], dtype=torch.float32).unsqueeze(-1).to(device)
        rew_batch = torch.as_tensor(self.rew_buf[idxs], dtype=torch.float32).to(device)
        done_batch = torch.as_tensor(self.done_buf[idxs], dtype=torch.float32).to(device)

        batch = dict(obs=obs_batch,
                     obs2=obs2_batch,
                     act=act_batch,
                     rew=rew_batch,
                     done=done_batch,
                     ac_prob=ac_prob_batch,
                     log_ac_prob=log_ac_prob_batch,
                     ac_first=ac_first_batch,
                     ac_second=ac_second_batch,
                     ac_third=ac_third_batch,
                     o_g_emb=o_g_emb_batch)
        return batch


def xavier_uniform_init(m):
    """Xavier/Glorot-uniform initialise the weight of a plain ``nn.Linear`` module.

    Intended for ``module.apply(xavier_uniform_init)``. The exact-type check is
    deliberate: subclasses of ``nn.Linear`` are left untouched. Biases are not
    modified.
    """
    if type(m) is nn.Linear:
        # `xavier_uniform_` is the in-place API; `xavier_uniform` is a deprecated alias.
        torch.nn.init.xavier_uniform_(m.weight)


def print_gpu_memory_usage(step=None):
    """Print CUDA memory statistics for the current device, in units of 1024**3 bytes.

    Prints nothing when CUDA is unavailable.

    Args:
        step: optional step/iteration label prepended to the printed line.
    """
    if not torch.cuda.is_available():
        return

    gib = 1024 ** 3
    allocated_gb = torch.cuda.memory_allocated() / gib
    reserved_gb = torch.cuda.memory_reserved() / gib
    peak_gb = torch.cuda.max_memory_allocated() / gib

    prefix = "" if step is None else f"Step {step}: "
    print(
        prefix
        + f"GPU Memory - Allocated: {allocated_gb:.2f}GB, "
        + f"Cached: {reserved_gb:.2f}GB, Max: {peak_gb:.2f}GB"
    )


class SAC:
    def __init__(self, args, vocab, env_fn,
                 replay_size=int(1e6), gamma=0.99, polyak=0.995, train_alpha=True):
        """Build the multi-objective SAC agent: networks, optimizers, buffers, GA state.

        Args:
            args: experiment namespace; every hyper-parameter read below comes from it.
            vocab: fragment vocabulary dict; ``len(vocab['FRAG'])`` sizes the second
                action head.
            env_fn: environment instance (stored as ``self.env`` and used directly,
                it is not called).
            replay_size: replay-buffer capacity.
            gamma: discount factor.
            polyak: target-network averaging coefficient.
            train_alpha: whether the entropy temperature ``log_alpha`` is learned.
        """
        super().__init__()

        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        self.device = args.device
        self.num_mols = args.num_mols
        self.gamma = gamma
        self.polyak = polyak

        tm = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        self.fname = f'results/{tm}_{args.target}_{args.seed}.csv'
        self.ckptname = f'ckpt/{args.target}/'
        print(f'\033[92m{self.fname}\033[0m')

        self.batch_size = args.batch_size
        self.start_steps = args.start_steps
        self.update_after = args.update_after
        self.update_every = args.update_every
        self.save_threshold = 0  # Counter to save model check point
        # Used as a modulus in run(); clamp to >= 1 so update_every == 1 is valid.
        self.docking_every = max(1, int(args.update_every / 2))
        self.ga_steps = args.ga_steps
        self.train_alpha = train_alpha

        self.env = env_fn
        self.vocab = vocab

        self.obs_dim = args.emb_size * 2
        self.action_dims = [40, len(vocab['FRAG']), 40]

        self.target_entropy = args.target_entropy

        self.log_alpha = torch.tensor([np.log(args.init_alpha)], requires_grad=train_alpha)

        self.ac = GCNActorCritic(self.env, args, vocab).to(args.device)
        # ==== apply xavier uniform initialization ====
        # Done BEFORE copying so the target network starts as an exact copy of the
        # online network, as polyak averaging assumes.
        self.ac.apply(xavier_uniform_init)
        self.ac_targ = deepcopy(self.ac).to(args.device).eval()

        # Freeze target networks with respect to optimizers (only update via polyak averaging)
        for p in self.ac_targ.parameters():
            p.requires_grad = False

        for q in self.ac.parameters():
            q.requires_grad = True

        self.replay_buffer = ReplayBuffer(size=replay_size, reward_dim=args.reward_dim)

        pi_lr = args.init_pi_lr
        q_lr = args.init_q_lr
        alpha_lr = args.init_alpha_lr

        self.pi_params = list(self.ac.pi.parameters())
        # The graph embedder is trained by the critic loss only.
        self.q_params = list(self.ac.q1.parameters()) + list(self.ac.q2.parameters()) + list(self.ac.embed.parameters())

        self.pi_optimizer = Adam(self.pi_params, lr=pi_lr, weight_decay=1e-4)
        self.q_optimizer = Adam(self.q_params, lr=q_lr, weight_decay=1e-4)
        self.alpha_optimizer = Adam([self.log_alpha], lr=alpha_lr, eps=1e-4)

        self.q_scheduler = lr_scheduler.ReduceLROnPlateau(self.q_optimizer, factor=0.1, patience=768)
        self.pi_scheduler = lr_scheduler.ReduceLROnPlateau(self.pi_optimizer, factor=0.1, patience=768)

        # log_alpha is only optimized while alpha_start <= t < alpha_end (see debug_update).
        self.alpha_start = self.start_steps
        self.alpha_end = self.start_steps + 30000
        self.alpha_max = args.alpha_max
        self.alpha_min = args.alpha_min

        self.population_size = args.population_size
        self.mutation_rate = args.mutation_rate
        self.population = []
        self.population_score = []
        self.ga_smiles_list = []

        self.gib = get_load_model(args.gib_path, device=self.device)
        self.max_vocab_update = args.max_vocab_update
        self.max_vocab_size = args.max_vocab_size

        self.t = 0
        # ==== reward dim ====
        self.reward_dim = args.reward_dim
        self.pref_num = args.pref_num
        self.mean = args.mean
        self.mpo_task = args.target.split('_')[0]

        # === homotopy dynamic beta ====
        # Chosen so that after episode_num calls to update_beta(),
        # beta - beta_init == beta_expbase**episode_num / beta_tau == beta_uplim - beta_init.
        self.beta = args.beta
        self.beta_init = self.beta
        self.beta_uplim = args.beta_uplim
        self.beta_tau = args.beta_tau
        self.episode_num = args.episode_num
        self.beta_expbase = float(np.power(self.beta_tau * (self.beta_uplim - self.beta), 1.0 / self.episode_num))
        self.beta_delta = self.beta_expbase / self.beta_tau

    def update_vocab(self, mol_list, scores_list):
        """Refresh the fragment vocabulary (and so the second action head) from scored molecules.

        Each molecule is turned into a graph and run through the GIB model
        (``self.gib(batch, get_w=True)``), whose output ``p`` is read as one weight per
        fragment occurrence, in batch order. A fragment's score is
        ``mean(property * weight) / sqrt(num_atoms)`` (equation (6)). Of the top
        ``self.max_vocab_update`` fragments, those not already in ``vocab['FRAG']`` are
        merged into ``vocab['FRAG_QUEUE']``. The queue is kept to the best
        ``self.max_vocab_size`` entries and becomes the new ``vocab['FRAG']``. The
        environment and policy are then re-synced.

        Args:
            mol_list: RDKit molecules; ``None`` entries are skipped.
            scores_list: property score per molecule, aligned with ``mol_list``.
        """
        import gc

        # Keep every graph paired with its own score: molecules whose graph cannot be
        # built are dropped from BOTH lists, so batch index j maps to the right score.
        graphs, graph_scores = [], []
        for mol, score in zip(mol_list, scores_list):
            if mol is None:
                continue
            graph = get_graph(Chem.MolToSmiles(mol))
            if graph is not None:
                graphs.append(graph)
                graph_scores.append(score)
        if not graphs:
            return

        batch = Batch.from_data_list(graphs).to(self.device)
        with torch.no_grad():
            p = self.gib(batch, get_w=True)

        i, frag_w_dict, frag_prop_dict = 0, defaultdict(list), defaultdict(list)
        for j, frags in enumerate(batch.frags):
            for frag in frags:
                frag_w_dict[frag].append(p[i].item())
                frag_prop_dict[frag].append(graph_scores[j])
                i += 1

        # Free the GPU batch and weights before the CPU-side bookkeeping
        del batch
        del p

        error_frags = get_sanitize_error_frags(frag_w_dict)
        for frag in error_frags:
            del frag_w_dict[frag]
            del frag_prop_dict[frag]

        frag_num_dict = {frag: Chem.MolFromSmiles(frag).GetNumAtoms() for frag in frag_w_dict}

        # fragment score as defined in equation (6)
        scores = [(np.array(frag_prop_dict[k]) * np.array(frag_w_dict[k])).mean() / np.sqrt(frag_num_dict[k])
                  for k in frag_w_dict]
        frag_tuples = list(zip(frag_w_dict, scores))
        frag_tuples = sorted(frag_tuples, key=lambda x: x[1], reverse=True)[:self.max_vocab_update]
        known_frags = set(self.vocab['FRAG'])
        frag_tuples = [(frag, score) for frag, score in frag_tuples if frag not in known_frags]

        self.vocab['FRAG_QUEUE'].extend(frag_tuples)
        self.vocab['FRAG_QUEUE'] = sorted(self.vocab['FRAG_QUEUE'],
                                          key=lambda x: x[1], reverse=True)[:self.max_vocab_size]
        self.vocab['FRAG'] = [frag for frag, score in self.vocab['FRAG_QUEUE']]
        self.vocab['FRAG_MOL'] = [Chem.MolFromSmiles(frag) for frag in self.vocab['FRAG']]
        self.vocab['FRAG_ATT'] = [get_att_points(mol) for mol in self.vocab['FRAG_MOL']]

        self.action_dims = [40, len(self.vocab['FRAG']), 40]
        self.env.update_vocab(self.vocab)
        self.ac.pi.update_vocab(self.vocab)

        # Force garbage collection and GPU cache cleanup
        gc.collect()
        torch.cuda.empty_cache()

    def update_beta(self):
        """Advance the homotopy weight ``beta`` by one episode and return the new value.

        ``beta`` mixes the preference-scalarised TD loss with the per-objective TD
        loss in ``compute_loss_q``. The increment for the next call is derived from
        the *unclamped* stepped value; only afterwards is ``beta`` clipped to
        ``[0, beta_uplim]``.
        """
        stepped = self.beta + self.beta_delta
        self.beta_delta = (stepped - self.beta_init) * self.beta_expbase + self.beta_init - stepped
        self.beta = max(min(stepped, self.beta_uplim), 0.0)
        return self.beta

    def get_current_beta(self):
        """Return the homotopy weight ``beta`` currently used to mix the critic losses."""
        current_beta = self.beta
        return current_beta

    def compute_loss_q(self, data, shared_w_batch=None):
        """Envelope multi-objective critic loss for the twin Q-networks.

        Every transition is evaluated under ``self.pref_num`` preference vectors. All
        expanded tensors use a *preference-major* layout: row ``p * B + b`` is sample
        ``b`` under preference ``p``. This is what ``np.repeat(axis=0)`` on the
        preferences and ``Tensor.repeat`` / list tiling on the observations produce.
        For each sample, the bootstrap target uses the preference whose scalarised
        target Q is largest.

        Args:
            data: batch dict produced by ``ReplayBuffer.sample_batch``.
            shared_w_batch: optional preference matrix; assumed to have
                ``pref_num * B`` rows in the same preference-major layout — TODO
                confirm with any external caller.

        Returns:
            (total_loss, w_batch): ``beta * scalarised_loss + (1 - beta) * per-objective
            loss``, and the preference matrix that was used.
        """
        ac_first, ac_second, ac_third = data['ac_first'], data['ac_second'], data['ac_third']
        self.ac.q1.train()
        self.ac.q2.train()

        _, _, o_g_emb = self.ac.embed(data['obs'])
        batch_size = o_g_emb.shape[0]
        pref_num = self.pref_num

        if shared_w_batch is not None:
            w_batch = shared_w_batch
        else:
            w_batch = np.random.randn(pref_num, self.reward_dim)
            w_batch = np.abs(w_batch) / np.linalg.norm(w_batch, ord=1, axis=1, keepdims=True)
            # Each preference row repeated B times (preference-major); B is the real
            # batch size so the rows always line up with the tiled observations.
            w_batch = torch.from_numpy(w_batch.repeat(batch_size, axis=0)).to(self.device).float()

        o_g_emb_expanded = o_g_emb.repeat(pref_num, 1)
        ac_first_expanded = ac_first.repeat(pref_num, 1)
        ac_second_expanded = ac_second.repeat(pref_num, 1)
        ac_third_expanded = ac_third.repeat(pref_num, 1)

        q1 = self.ac.q1(o_g_emb_expanded, ac_first_expanded, ac_second_expanded, ac_third_expanded, w_batch).squeeze()
        q2 = self.ac.q2(o_g_emb_expanded, ac_first_expanded, ac_second_expanded, ac_third_expanded, w_batch).squeeze()

        r, d = data['rew'], data['done']

        with torch.no_grad():
            o2_g, o2_n_emb, o2_g_emb = self.ac.embed(data['obs2'])
            cands = self.ac.embed(self.ac.pi.cand)

            o2_g_expanded = dgl.batch(dgl.unbatch(o2_g) * pref_num)
            o2_g_emb_expanded = o2_g_emb.repeat(pref_num, 1)
            o2_n_emb_expanded = o2_n_emb.repeat(pref_num, 1)

            _, (a2_prob, log_a2_prob), (ac2_first, ac2_second, ac2_third) = \
                self.ac.pi(o2_g_emb_expanded, o2_n_emb_expanded, o2_g_expanded, cands, preference=w_batch)
            # Drop the DGL graph structures as soon as the policy is done with them
            del o2_g_expanded

            q1_pi_targ = self.ac_targ.q1(o2_g_emb_expanded, ac2_first, ac2_second, ac2_third, w_batch)
            q2_pi_targ = self.ac_targ.q2(o2_g_emb_expanded, ac2_first, ac2_second, ac2_third, w_batch)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)  # [pref_num * B, reward_dim]

            # Scalarise each row with its own preference, then pick, per sample b, the
            # preference p maximising it. Row p*B+b -> element [p, b] of a (P, B) view.
            scalarized_q = torch.bmm(w_batch.unsqueeze(1), q_pi_targ.unsqueeze(2)).view(pref_num, batch_size)
            best_pref = torch.argmax(scalarized_q, dim=0)  # [B]
            batch_indices = torch.arange(batch_size, device=best_pref.device)
            envelope_targets = q_pi_targ[best_pref * batch_size + batch_indices]  # [B, reward_dim]

            backup = r + self.gamma * (1 - d.unsqueeze(1).float()) * envelope_targets
            backup_expanded = backup.repeat(pref_num, 1)  # [pref_num * B, reward_dim]

        # Per-objective TD loss
        loss_standard = ((q1 - backup_expanded) ** 2).mean() + ((q2 - backup_expanded) ** 2).mean()

        # Preference-scalarised TD loss
        w_ext = w_batch.unsqueeze(1)
        wq1 = torch.bmm(w_ext, q1.unsqueeze(2)).squeeze()
        wq2 = torch.bmm(w_ext, q2.unsqueeze(2)).squeeze()
        wtq = torch.bmm(w_ext, backup_expanded.unsqueeze(2)).squeeze()
        loss_weighted = ((wq1 - wtq) ** 2).mean() + ((wq2 - wtq) ** 2).mean()

        total_loss = self.beta * loss_weighted + (1 - self.beta) * loss_standard
        return total_loss, w_batch

    def compute_loss_pi(self, data, w_batch):
        """Actor and temperature losses for the preference-conditioned policy.

        Observations are tiled ``self.pref_num`` times (preference-major, matching the
        rows of ``w_batch``). The policy loss is ``-mean(w^T min(Q1, Q2))``. The
        entropy term uses the full joint distribution over the three factorised
        action heads (sizes ``self.action_dims``).

        Args:
            data: batch dict from ``ReplayBuffer.sample_batch`` (only ``'obs'`` is read).
            w_batch: preference matrix returned by ``compute_loss_q``.

        Returns:
            (loss_entropy, loss_policy, loss_alpha).
        """
        # The embedder is trained by the critic only.
        with torch.no_grad():
            o_g, o_n_emb, o_g_emb = self.ac.embed(data['obs'])
            cands = self.ac.embed(self.ac.pi.cand)

        o_g_expanded = dgl.batch(dgl.unbatch(o_g) * self.pref_num)
        o_g_emb_expanded = o_g_emb.repeat(self.pref_num, 1)
        o_n_emb_expanded = o_n_emb.repeat(self.pref_num, 1)

        _, (ac_prob, log_ac_prob), (ac_first, ac_second, ac_third) = \
            self.ac.pi(o_g_emb_expanded, o_n_emb_expanded, o_g_expanded, cands, preference=w_batch)
        # Drop the DGL graph structures as soon as the policy is done with them
        del o_g_expanded

        q1_pi = self.ac.q1(o_g_emb_expanded, ac_first, ac_second, ac_third, w_batch)
        q2_pi = self.ac.q2(o_g_emb_expanded, ac_first, ac_second, ac_third, w_batch)
        q_pi = torch.min(q1_pi, q2_pi)
        # [n, 1, reward_dim] x [n, reward_dim, 1] -> scalarised Q per row
        wq_pi = torch.bmm(w_batch.unsqueeze(1), q_pi.unsqueeze(2)).squeeze()
        del o_g_emb_expanded, o_n_emb_expanded

        loss_policy = torch.mean(-wq_pi)

        alpha = self.log_alpha.exp().clamp(self.alpha_min, self.alpha_max).item()

        n = ac_prob.shape[0]
        d0, d1, d2 = self.action_dims
        p0, p1, p2 = torch.split(ac_prob, self.action_dims, dim=1)
        lp0, lp1, lp2 = torch.split(log_ac_prob, self.action_dims, dim=1)

        # Joint probability over (first, second, third), flattened as x*d1*d2 + y*d2 + z
        ac_prob_comb = torch.einsum('by, bz->byz', p1, p2).reshape(n, -1)
        ac_prob_comb = torch.einsum('bx, bz->bxz', p0, ac_prob_comb).reshape(n, -1)

        # Matching joint log-probability via broadcasting (same layout and values as
        # materialising three repeated copies).
        log_ac_prob_comb = (lp0.reshape(n, d0, 1, 1)
                            + lp1.reshape(n, 1, d1, 1)
                            + lp2.reshape(n, 1, 1, d2)).reshape(n, -1)

        neg_entropy = (ac_prob_comb * log_ac_prob_comb).sum(dim=1)
        loss_entropy = (alpha * ac_prob_comb * log_ac_prob_comb).sum(dim=1).mean()
        loss_alpha = -(self.log_alpha.to(self.device) * (neg_entropy + self.target_entropy).detach()).mean()

        return loss_entropy, loss_policy, loss_alpha
    
    @timeout(900)
    def debug_update(self, data):
        """One SAC gradient step: critic, then actor (+ temperature), then polyak.

        Wrapped by ``timeout(900)`` from ``timeout_decorator``; presumably this aborts
        the step with ``timeout_decorator.TimeoutError`` after 900 s — see ``update``.
        The critic parameters are frozen during the actor step and are always
        unfrozen again, even if that step raises.

        Args:
            data: batch dict from ``ReplayBuffer.sample_batch``.
        """
        import gc

        # 1. Critic: one gradient descent step for Q1 and Q2 (and the embedder).
        loss_q, w_batch = self.compute_loss_q(data)
        self.q_optimizer.zero_grad()
        loss_q.backward()
        clip_grad_norm_(self.q_params, 5)
        self.q_optimizer.step()
        self.q_scheduler.step(loss_q)

        # 2. Actor: freeze the Q-networks so no gradients are computed for them.
        for q in self.q_params:
            q.requires_grad = False
        try:
            loss_entropy, loss_policy, loss_alpha = self.compute_loss_pi(data, w_batch)
            loss_pi = loss_entropy + loss_policy
            self.pi_optimizer.zero_grad()
            loss_pi.backward()
            clip_grad_norm_(self.pi_params, 5)
            self.pi_optimizer.step()
            self.pi_scheduler.step(loss_policy)

            # Temperature is only learned inside the [alpha_start, alpha_end) window.
            if self.train_alpha and self.alpha_start <= self.t < self.alpha_end:
                self.alpha_optimizer.zero_grad()
                loss_alpha.backward()
                self.alpha_optimizer.step()
        finally:
            # Unfreeze Q-networks so they can be optimized at the next step.
            for q in self.q_params:
                q.requires_grad = True

        # 3. Target networks: polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(self.ac.parameters(), self.ac_targ.parameters()):
                p_targ.data.mul_(self.polyak)
                p_targ.data.add_((1 - self.polyak) * p.data)

        del data, loss_q, loss_pi, loss_entropy, loss_policy, loss_alpha
        # A full GC traverses every object (including the replay buffer), so run it
        # only on the periodic cleanup steps, together with the CUDA cache flush.
        if self.t % 100 == 0:
            gc.collect()
            torch.cuda.empty_cache()

    def update(self, data):
        """Run one SAC update on ``data``; a timed-out update is reported and skipped.

        ``timeout_decorator`` raises its own ``TimeoutError`` class (an
        ``AssertionError`` subclass, not the builtin), so both are caught. Skipping
        keeps unattended training runs from blocking.
        """
        import timeout_decorator

        try:
            self.debug_update(data)
        except (TimeoutError, timeout_decorator.TimeoutError) as err:
            print(f"[SAC.update] step {self.t}: update timed out ({err}); skipping this batch")

    def run(self):
        """Main training loop: act, step the environment, score molecules, learn.

        Each iteration is one environment step (counter ``self.t``):
          1. Choose an action under a sampled exploration preference: with the
             policy after ``start_steps``, otherwise with a random motif.
          2. Step the environment; store the transition while ``o2`` still has
             attachment points.
          3. On episode end: reset, advance ``beta``, and let the GA add offspring
             SMILES from the current population.
          4. Every ``docking_every`` steps: score all pending SMILES, back-fill the
             replay buffer with task rewards, append scores to ``self.fname``, and
             update the vocabulary and GA population.
          5. Every ``update_every`` steps (after ``update_after``): run
             ``update_every`` gradient updates.
          6. Checkpoint the whole agent once 500+ new molecules have been scored.

        Stops once at least ``self.num_mols`` molecules have been scored.
        """
        num_generated = 0
        pbar = tqdm(total=self.num_mols, desc="Generated Molecules")
        o = self.env.reset()  # benzene with attachment point
        episode_num = 0

        while True:
            # 1. Action selection: policy network once start_steps is reached,
            #    otherwise a randomly sampled motif.
            with torch.no_grad():
                cands = self.ac.embed(self.ac.pi.cand)
                o_embeds = self.ac.embed([o])
                o_g, o_n_emb, o_g_emb = o_embeds

                if self.t >= self.start_steps:
                    w_explore = self._sample_explore_preference()
                    ac, (ac_prob, log_ac_prob), (ac_first, ac_second, ac_third) = \
                        self.ac.pi(o_g_emb, o_n_emb, o_g, cands, preference=w_explore)
                else:
                    ac = self.env.sample_motif()[np.newaxis]  # add a leading batch axis
                    # sampled after the motif, keeping the RNG call order
                    w_explore = self._sample_explore_preference()
                    (ac_prob, log_ac_prob), (ac_first, ac_second, ac_third) = \
                        self.ac.pi.sample(ac[0], o_g_emb, o_n_emb, o_g, cands, preference=w_explore)

            # 2. Interaction with the environment.
            #  o2: edited molecule
            #  r: intermediate reward only; NOT the task/docking score, just an
            #     encouraging reward for growing bigger molecules
            #  d: True when no attachment points are left or atom num > 30
            #  info['stop']: True for atom num > 30, False for no attachment points
            o2, r, d, info = self.env.step(ac[0])
            # ===== multi objective reward =====
            r = [r] * self.reward_dim
            # ==================================
            r_d = info['stop']

            # Replay buffer storage: only when o2 still has attachment points.
            if any(o2['att']):
                act = ac if isinstance(ac, np.ndarray) else ac.detach().cpu().numpy()
                self.replay_buffer.store(o, act, r, o2, r_d,
                                         ac_prob, log_ac_prob, ac_first, ac_second, ac_third,
                                         o_embeds)

            # Critical, easy-to-miss step: advance to the most recent observation.
            o = o2

            # End-of-trajectory handling
            if get_att_points(self.env.mol) == []:  # Temporarily force attachment calculation
                d = True
            if not any(o2['att']):
                d = True

            if d:
                o = self.env.reset()
                episode_num += 1
                self.update_beta()

                # 3. GA reproduction to obtain offspring molecules.
                if self.t >= self.start_steps and len(self.population) >= 2:
                    for _ in range(self.ga_steps):
                        offspring = reproduce(self.population, self.population_score, self.mutation_rate)
                        if offspring is not None:
                            self.ga_smiles_list.append(Chem.MolToSmiles(offspring))

            # 4. Score the generated molecules and back-fill the buffer with the task reward.
            if self.t > 1 and self.t % self.docking_every == 0:
                if self.env.smile_list != []:
                    n_sac_smi = len(self.env.smile_list)
                    self.env.smile_list += self.ga_smiles_list
                    n_smi = len(self.env.smile_list)

                    rews, ext_rew = self.env.reward_batch_res_interaction()
                    rews_array = np.array(rews).T
                    # only the RL-generated molecules have transitions in the buffer
                    r_batch = rews_array[:n_sac_smi]
                    self.replay_buffer.rew_store(r_batch)

                    # CSV row: smiles, one column per objective, then the scalar score
                    results_path = Path(self.fname)
                    results_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(results_path, 'a') as f:
                        for i in range(n_smi):
                            fields = [f'{self.env.smile_list[i]}']
                            fields.extend(f'{rew[i]}' for rew in rews)
                            fields.append(f'{ext_rew[i]}')
                            f.write(','.join(fields) + '\n')

                    mols = [Chem.MolFromSmiles(s) for s in self.env.smile_list]

                    # update vocab from the GA offspring
                    if self.t >= self.start_steps:
                        self.update_vocab(mols[n_sac_smi:], ext_rew[n_sac_smi:])

                    # GA population: keep the best population_size molecules
                    self.population.extend(mols)
                    self.population_score.extend(ext_rew)
                    population_tuples = list(zip(self.population, self.population_score))
                    population_tuples = sorted(population_tuples, key=lambda x: x[1], reverse=True)[:self.population_size]
                    self.population = [t[0] for t in population_tuples]
                    self.population_score = [t[1] for t in population_tuples]

                    num_generated += n_smi
                    self.save_threshold += n_smi
                    pbar.update(n_smi)

                    if num_generated >= self.num_mols:
                        pbar.close()
                        break

                    self.env.reset_batch()  # Clean up self.smile_list

                    self.ga_smiles_list = []
                    print("=== Reward Once ===")
                else:
                    # No molecules were finished since the last scoring round.
                    print("The SMILES list is EMPTY !!!")

            # 5. Training on batches sampled from the replay buffer.
            if self.t >= self.update_after and self.t % self.update_every == 0:
                for _ in range(self.update_every):
                    batch = self.replay_buffer.sample_batch(self.device, self.batch_size)
                    self.update(data=batch)

                # Periodic memory cleanup during training
                if self.t % (self.update_every * 10) == 0:
                    print_gpu_memory_usage(self.t)
                    import gc
                    gc.collect()
                    torch.cuda.empty_cache()
                    print(f"Step {self.t}: GPU memory cleaned")
                    print_gpu_memory_usage()

            # 6. Save the policy and critic models.
            if self.t >= self.update_after and self.save_threshold >= 500:
                save_path = Path(f"{self.ckptname}SMILES_{num_generated}.pt")
                save_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(self, save_path)
                self.save_threshold = 0
                print("Model Saved")

            self.t += 1
        print(f"Total episode number: {episode_num}")

    def _sample_explore_preference(self):
        """Draw a (1, reward_dim) float exploration preference on ``self.device``.

        With probability 0.3 a random one-hot vector; otherwise, with probability 0.5,
        the uniform vector; otherwise a Dirichlet(1, ..., 1) sample.
        """
        if np.random.rand() < 0.3:
            w_explore = np.zeros((1, self.reward_dim))
            w_explore[0, np.random.randint(self.reward_dim)] = 1.0
        elif np.random.rand() < 0.5:
            w_explore = np.ones((1, self.reward_dim)) / self.reward_dim
        else:
            w_explore = np.random.dirichlet(np.ones(self.reward_dim), size=1)
        return torch.from_numpy(w_explore).to(self.device).float()
    
    def generate(self, batch_size, preference=None):
        """Sample molecules with the current actor and score them as one batch.

        The policy is rolled out without gradients and in eval mode, episode
        after episode. The env is reset whenever the step reports ``done`` or
        the current molecule has no attachment points left. Once the env's
        ``smile_list`` holds at least ``batch_size`` SMILES:
        - the list is copied,
        - the whole list is scored with
          ``env.reward_batch_res_interaction_unscaled()``,
        - the env's batch state is cleared with ``env.reset_batch()``.

        The actor-critic's train/eval mode is restored on exit.

        Args:
            batch_size: number of SMILES to collect before scoring.
            preference: optional weight vector of length ``self.reward_dim``
                (list, array or tensor). Defaults to uniform weights. It is
                normalized to sum to 1 before being passed to the actor.

        Returns:
            ``(sampled_smiles, rews, ext_rew)``: a copy of the collected SMILES
            list and the two values returned by the env reward call.
        """
        o = self.env.reset()

        if preference is None:
            preference = torch.ones(self.reward_dim, device=self.device) / self.reward_dim
        else:
            # as_tensor avoids the copy-construct warning torch.tensor() emits
            # when `preference` is already a tensor.
            preference = torch.as_tensor(preference, dtype=torch.float32, device=self.device)
        preference = preference / preference.sum()
        # The actor is called with a leading batch dimension: [1, reward_dim].
        preference_batch = preference.unsqueeze(0)

        # Remember the module mode so a call made mid-training does not leave
        # the actor-critic stuck in eval mode afterwards.
        was_training = getattr(self.ac, 'training', False)
        self.ac.eval()
        try:
            with torch.no_grad():
                # Nothing in this loop changes the candidate vocabulary, so embed
                # it once instead of on every environment step.
                cands = self.ac.embed(self.ac.pi.cand)

            while True:
                with torch.no_grad():
                    o_g, o_n_emb, o_g_emb = self.ac.embed([o])
                    ac, _, _ = self.ac.pi(o_g_emb, o_n_emb, o_g, cands, preference=preference_batch)

                o, _, done, _ = self.env.step(ac[0])
                # An episode also ends when the molecule has no attachment points left.
                if done or get_att_points(self.env.mol) == [] or not any(o['att']):
                    o = self.env.reset()

                # `>=` rather than `==`: if the list ever grows past batch_size in
                # one step, an exact-equality test would never fire and this
                # loop would never terminate.
                if len(self.env.smile_list) >= batch_size:
                    sampled_smiles = self.env.smile_list.copy()
                    # NOTE: alternative scorer previously used here:
                    #   self.env.reward_batch_guaca_mpo(name=self.mpo_task, scale=True)
                    rews, ext_rew = self.env.reward_batch_res_interaction_unscaled()
                    self.env.reset_batch()
                    return sampled_smiles, rews, ext_rew
        finally:
            if was_training:
                self.ac.train()



    def preference_sensitivity(self, num_states=64, preference1=None, preference2=None, rand_steps=2, seed=None):
        """Measure how strongly the actor's action distribution reacts to the preference.

        Collecting states: for each state the env is reset and up to
        ``rand_steps`` random motif actions are taken. A walk stops early when
        the env reports ``done`` or no attachment points remain.

        Comparing distributions: the actor is queried on the same batch of
        states under two preference vectors. Its probability output is split
        into heads with ``torch.split(..., self.action_dims, dim=1)``, and the
        per-head KL terms are summed. That sum equals the KL of the joint
        distribution when the heads are independent.

        Args:
            num_states: number of states to evaluate on.
            preference1, preference2: optional weight vectors of length
                ``self.reward_dim``. Each defaults to random weights from
                ``np.random.rand``. Both are normalized to sum to 1; a 1e-8
                epsilon guards against an all-zero vector.
            rand_steps: maximum random actions per state (negative values act as 0).
            seed: if given, seeds the global torch and numpy RNGs. Note that this
                mutates global RNG state.

        Returns:
            dict with means and population stds (``unbiased=False``) of:
            - KL(p1||p2),
            - KL(p2||p1),
            - the symmetric KL (their average),
            - the L2 distance between the concatenated probability vectors.
        """
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        def _to_simplex(pref):
            # Turn an optional preference into a normalized float32 weight vector.
            if pref is None:
                w = np.random.rand(self.reward_dim).astype(np.float32)
            else:
                w = np.asarray(pref, dtype=np.float32)
            return w / (w.sum() + 1e-8)

        def _head_kl(p, log_p, log_q):
            # KL(p||q) = sum_i p_i * (log p_i - log q_i) for one action head.
            # Terms with p_i == 0 contribute 0 (the 0*log0 convention) instead of
            # turning into NaN via 0 * (-inf).
            term = p * (log_p - log_q)
            return torch.where(p > 0, term, torch.zeros_like(term)).sum(dim=1)

        # Drawn in this order so a given seed reproduces the same (w1, w2) pair.
        w1 = _to_simplex(preference1)
        w2 = _to_simplex(preference2)

        was_training = getattr(self.ac, 'training', False)
        self.ac.eval()
        try:
            # Collect a batch of states via short random walks from reset.
            obs_list = []
            for _ in range(num_states):
                o = self.env.reset()
                for _ in range(max(0, rand_steps)):
                    o, _, done, _ = self.env.step(self.env.sample_motif())
                    # Stop on a terminal step instead of acting past episode end.
                    if done or not any(o['att']):
                        break
                obs_list.append(o)

            with torch.no_grad():
                o_g, o_n_emb, o_g_emb = self.ac.embed(obs_list)
                cands = self.ac.embed(self.ac.pi.cand)

                bsz = o_g_emb.shape[0]
                w1_t = torch.from_numpy(w1).to(self.device).float().unsqueeze(0).repeat(bsz, 1)
                w2_t = torch.from_numpy(w2).to(self.device).float().unsqueeze(0).repeat(bsz, 1)

                # Actor distributions for the same states under the two preferences.
                _, (ac_prob1, log_ac_prob1), _ = self.ac.pi(o_g_emb, o_n_emb, o_g, cands, preference=w1_t)
                _, (ac_prob2, log_ac_prob2), _ = self.ac.pi(o_g_emb, o_n_emb, o_g, cands, preference=w2_t)

                # One slice per action head, following self.action_dims.
                p_heads = torch.split(ac_prob1, self.action_dims, dim=1)
                q_heads = torch.split(ac_prob2, self.action_dims, dim=1)
                lp_heads = torch.split(log_ac_prob1, self.action_dims, dim=1)
                lq_heads = torch.split(log_ac_prob2, self.action_dims, dim=1)

                kl_pq = sum(_head_kl(p, lp, lq) for p, lp, lq in zip(p_heads, lp_heads, lq_heads))  # [bsz]
                kl_qp = sum(_head_kl(q, lq, lp) for q, lq, lp in zip(q_heads, lq_heads, lp_heads))  # [bsz]
                sym_kl = 0.5 * (kl_pq + kl_qp)

                # L2 distance between the concatenated probability vectors.
                l2 = torch.norm(ac_prob1 - ac_prob2, dim=1)
        finally:
            if was_training:
                self.ac.train()
            # reset_batch() clears env.smile_list. The random walks above may have
            # queued SMILES there, so clear them. Cleanup is best-effort, but
            # failures are reported instead of being silently ignored.
            if hasattr(self.env, 'reset_batch'):
                try:
                    self.env.reset_batch()
                except Exception as exc:
                    import warnings
                    warnings.warn(f"env.reset_batch() failed during cleanup: {exc!r}")

        return {
            'kl_p1_p2_mean': kl_pq.mean().item(),
            'kl_p1_p2_std': kl_pq.std(unbiased=False).item(),
            'kl_p2_p1_mean': kl_qp.mean().item(),
            'kl_p2_p1_std': kl_qp.std(unbiased=False).item(),
            'sym_kl_mean': sym_kl.mean().item(),
            'sym_kl_std': sym_kl.std(unbiased=False).item(),
            'l2_mean': l2.mean().item(),
            'l2_std': l2.std(unbiased=False).item(),
        }



