import csv
import json
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from contextlib import nullcontext
import numpy as np
import math
import random
import os
from pathlib import Path
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional, Callable, Any
import pytorch_lightning as pl
from collections import defaultdict
from multiprocessing import Pool, cpu_count, TimeoutError as MPTimeoutError
import torch.multiprocessing as mp
from functools import partial
from rdkit import Chem
from omegaconf import ListConfig

import utils
from flow_matching import flow_matching_utils
from grpo_rewards import create_reward_function, MolecularValidityReward, resolve_target_task
from graph_discrete_flow_model import GraphDiscreteFlowModel
from eval_gdpo_docking import (
    gdpo_eval_smiles,
    gdpo_get_sim_threshold,
    gdpo_load_train_fps,
)
from src.grpo_core import GRPOCore
from src.trajectory_data import TrajectoryData
try:
    import swanlab
except ImportError:
    swanlab = None

# A graph is handled as a pair of tensors. The code in this module unpacks
# such pairs as (atom/node data, edge data).
Graph = Tuple[torch.Tensor, torch.Tensor]



def _set_single_thread_env(num_threads: int = 1, *, force: bool = False) -> None:
    num_threads_str = str(int(num_threads))
    for key in (
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "BLIS_NUM_THREADS",
    ):
        if force:
            os.environ[key] = num_threads_str
        else:
            os.environ.setdefault(key, num_threads_str)
    try:
        torch.set_num_threads(int(num_threads))
        torch.set_num_interop_threads(int(num_threads))
    except Exception:
        pass


# Per-process reward function cache. `_reward_worker_initializer` assigns it
# when a reward type is given and construction succeeds; otherwise it stays
# None and `_compute_batch_rewards_worker` builds a reward function per call.
_WORKER_REWARD_FUNC = None

def _reward_worker_initializer(num_threads: int = 1, reward_type: str = None, reward_kwargs: dict = None) -> None:
    import warnings
    warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*")
    warnings.filterwarnings("ignore", message=".*Boto3 will no longer support Python 3.9.*")
    warnings.filterwarnings("ignore", category=FutureWarning)
    import sys
    _set_single_thread_env(num_threads=num_threads, force=True)
    sys.stdout = sys.stderr
    
    global _WORKER_REWARD_FUNC
    if reward_type:
        import torch
        from grpo_rewards import create_reward_function
        device = torch.device('cpu')
        try:
            _WORKER_REWARD_FUNC = create_reward_function(reward_type, device=device, **(reward_kwargs or {}))
        except Exception as e:
            print(f"❌ [Worker {os.getpid()}] Failed to initialize persistent reward function: {e}", file=sys.stderr, flush=True)

def _compute_batch_rewards_worker(batch_graphs, reward_type: str, device_str: str, reward_kwargs: Optional[Dict] = None):
    import torch
    global _WORKER_REWARD_FUNC
    
    device = torch.device(device_str)
    
    if _WORKER_REWARD_FUNC is not None:
        reward_func = _WORKER_REWARD_FUNC
    else:
        from grpo_rewards import create_reward_function
        reward_func = create_reward_function(reward_type, device=device, **(reward_kwargs or {}))

    processed_graphs = []
    for atom_data, edge_data in batch_graphs:
        if torch.is_tensor(atom_data):
            atom_tensor = atom_data.detach().to(device)
        else:
            atom_tensor = torch.as_tensor(atom_data, device=device)

        if torch.is_tensor(edge_data):
            edge_tensor = edge_data.detach().to(device)
        else:
            edge_tensor = torch.as_tensor(edge_data, device=device)

        processed_graphs.append((atom_tensor, edge_tensor))

    with torch.no_grad():
        rewards = reward_func(processed_graphs)

    if isinstance(rewards, torch.Tensor):
        return rewards.cpu().tolist()
    return rewards


class GRPOTrainer:
    
    def __init__(
        self,
        model: nn.Module,
        reward_function: Callable,
        cfg: Dict,
        model_kwargs: dict,
    ):
        """Set up the trainer: reward workers, hyper-parameters, reference model.

        Args:
            model: Policy network; stored as ``self.model``.
            reward_function: Reward callable; stored as ``self.reward_function``.
                Its optional ``name`` attribute selects the reward type that is
                sent to the worker processes.
            cfg: Run configuration. It is read through attribute access
                (``cfg.grpo``, ``cfg.sample``), so despite the ``Dict``
                annotation it is presumably an OmegaConf-style config --
                TODO confirm.
            model_kwargs: Keyword arguments kept for building the reference
                model with ``GraphDiscreteFlowModel``.
        """
        # Process-wide cuDNN switches.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Applies the default thread limit (1) to the trainer process itself;
        # environment variables that are already set are left untouched.
        _set_single_thread_env()

        self.model = model
        self.reward_function = reward_function
        # Needs self.model, which is why it runs after the assignment above.
        self.reward_kwargs = self._prepare_reward_kwargs(reward_function, cfg)
        self.cfg = cfg
        self.model_kwargs = model_kwargs
        grpo_config = cfg.grpo

        # Best effort: a RuntimeError from torch.multiprocessing is ignored.
        try:
            mp.set_sharing_strategy("file_system")
        except RuntimeError:
            pass

        self.grpo_core = GRPOCore(cfg)

        # --- reward worker pool -------------------------------------------
        self.num_reward_workers = grpo_config.get('num_reward_workers', min(cpu_count(), 8))
        self.reward_worker_threads = int(grpo_config.get("reward_worker_threads", 1))

        reward_type = getattr(reward_function, "name", None) or grpo_config.get("reward_type", "default")

        reward_type_lower = reward_type.lower() if isinstance(reward_type, str) else "default"
        self.reward_pool = None
        # No pool is created when workers are disabled (<= 0) or the reward
        # type is "disabled_reward"; self.reward_pool then stays None.
        if self.num_reward_workers > 0 and reward_type_lower != "disabled_reward":
            self.reward_pool = mp.get_context('spawn').Pool(
                processes=self.num_reward_workers,
                initializer=_reward_worker_initializer,
                initargs=(self.reward_worker_threads, reward_type, self.reward_kwargs),
            )
        self.eval_timeout_seconds = grpo_config.get('eval_timeout_seconds', 600)

        # --- sampling / batching schedule -----------------------------------
        self.num_batches_per_epoch = grpo_config.get('num_batches_per_epoch', 1)
        self.train_batch_size = grpo_config.get('train_batch_size', 32)
        self.group_size = grpo_config.get('group_size', 8)
        self.concurrent_sampling_groups = max(1, grpo_config.get('concurrent_sampling_groups', 1))
        self.num_inner_epochs = grpo_config.get('num_inner_epochs', 1)
        self.gradient_accumulation_steps = grpo_config.get('gradient_accumulation_steps', 1)
        self.sample_group_num = grpo_config.get('sample_group_num', 1000)
        self._next_group_id = 0
        self.eval_interval = max(1, grpo_config.get('eval_interval', 5))

        # --- optimizer hyper-parameters (learning_rate has no default) ------
        self.learning_rate = grpo_config.learning_rate
        self.adam_beta1 = grpo_config.get('adam_beta1', 0.9)
        self.adam_beta2 = grpo_config.get('adam_beta2', 0.999)
        self.adam_weight_decay = grpo_config.get('adam_weight_decay', 1e-4)
        self.adam_epsilon = grpo_config.get('adam_epsilon', 1e-8)


        # --- reference model / KL penalty (kl_penalty has no default) -------
        self.ref_model_update_freq = grpo_config.get('ref_model_update_freq', 200)
        self.beta = grpo_config.kl_penalty

        # grpo.forward_steps, when set, takes precedence over
        # cfg.sample.sample_steps (which itself defaults to 100).
        default_sample_steps = getattr(cfg.sample, 'sample_steps', 100)
        grpo_forward_steps = grpo_config.get('forward_steps', None)
        self.sample_steps = (
            grpo_forward_steps
            if grpo_forward_steps is not None
            else default_sample_steps
        )
        self.target_node_count = grpo_config.get('target_node_count', None)
        self.node_count_min = grpo_config.get('node_count_min', None)
        self.node_count_max = grpo_config.get('node_count_max', 256)

        # --- refinement seeds (helpers are defined elsewhere in the class) ---
        self._refine_seed_smiles = self._normalize_smiles_list(
            self._get_cfg_value(cfg, "grpo.refine_seed_smiles")
        )
        self._refine_seed_graphs = self._build_refine_seed_graphs(self._refine_seed_smiles)
        self._refine_seed_smiles_set = set(self._refine_seed_smiles)
        self._lead_eval_seed_cache: Dict[str, List[str]] = {}
        # "or 0" maps an explicit None/empty value to 0.
        self._refine_topk = int(grpo_config.get("refine_topk", 0) or 0)
        self._refine_topk_graphs: List[Tuple[float, Graph]] = []


        # --- dynamic node-count distribution and LR decay settings ----------
        self.enable_dynamic_node_dist = bool(grpo_config.get("enable_dynamic_node_dist", False))
        self.dynamic_node_dist_alpha = float(grpo_config.get("dynamic_node_dist_alpha", 0.05))
        self.dynamic_node_dist_reward_threshold = float(grpo_config.get("dynamic_node_dist_reward_threshold", 0.001))
        self._pending_node_count_prob_update = None
        self.lr_decay_threshold = grpo_config.get('lr_decay_threshold', None)
        if self.lr_decay_threshold is not None:
            self.lr_decay_threshold = float(self.lr_decay_threshold)
        # "or 3" maps an explicit None/0 window back to the default of 3.
        self.lr_decay_window = int(grpo_config.get('lr_decay_window', 3) or 3)
        self.lr_decay_factor = float(grpo_config.get('lr_decay_factor', 0.5))
        self.lr_decay_min = grpo_config.get('lr_decay_min', None)
        if self.lr_decay_min is not None:
            self.lr_decay_min = float(self.lr_decay_min)
        self._lr_decay_applied = False
        self._lr_decay_history = []
        self.use_grpo_step_probs_for_sampling = grpo_config.get('use_grpo_step_probs_for_sampling', True)

        # --- training progress counters -------------------------------------
        self.global_step = 0
        self.epoch = 0

        # --- p0 buffer state (buffer size is fixed at 1000 here) ------------
        self.global_p0_buffer = []
        self.p0_buffer_size = 1000
        self._pending_p0_update = None

        # Must run last: it builds the reference model and relies on self.beta,
        # self.cfg and self.model_kwargs being set.
        self._initialize_training_components()


        # Report the node-count configuration in use.
        if self.target_node_count is not None:
            print(f"   [anonymized]: {self.target_node_count} ([anonymized])")
        else:
            print(f"   [anonymized]: [anonymized] (min={self.node_count_min}, max={self.node_count_max})")

    def _prepare_reward_kwargs(self, reward_function: Callable, cfg) -> Dict[str, Any]:
        """Collect the keyword arguments needed to rebuild the reward function.

        The result is passed to ``create_reward_function`` in the reward
        workers, so every value is converted to a plain Python object where
        possible.

        Args:
            reward_function: Reward callable. Its ``name`` attribute, when
                present and truthy, determines the reward type; otherwise
                ``cfg.grpo.reward_type`` (default ``"default"``) is used.
            cfg: Run configuration; ``cfg.grpo`` may be mapping-like (has
                ``.get``) or a plain attribute namespace.

        Returns:
            Dict of keyword arguments. ``atom_decoder`` is added whenever
            ``self.model.dataset_info`` provides one; the remaining keys
            depend on the reward type.
        """
        reward_type = getattr(reward_function, "name", None) or getattr(cfg.grpo, "reward_type", "default")
        reward_type = reward_type.lower() if isinstance(reward_type, str) else "default"
        reward_kwargs: Dict[str, Any] = {}

        try:
            grpo_cfg = cfg.grpo
        except Exception:
            grpo_cfg = None

        def cfg_get(key: str, default: Any = None) -> Any:
            # One lookup path for mapping-like configs and plain namespaces
            # (also yields `default` when grpo_cfg is None).
            try:
                return grpo_cfg.get(key, default)
            except Exception:
                return getattr(grpo_cfg, key, default)

        def set_if_present(cfg_key: str, kwarg_name: str, cast: Optional[Callable] = None) -> None:
            # Copy an optional config value into reward_kwargs, optionally cast.
            value = cfg_get(cfg_key)
            if value is not None:
                reward_kwargs[kwarg_name] = value if cast is None else cast(value)

        dataset_info = getattr(self.model, "dataset_info", None)
        if dataset_info is not None:
            atom_decoder = getattr(dataset_info, "atom_decoder", None)
            if atom_decoder is not None:
                try:
                    reward_kwargs["atom_decoder"] = list(atom_decoder)
                except Exception:
                    # Not iterable: pass it through unchanged.
                    reward_kwargs["atom_decoder"] = atom_decoder

        if reward_type in ("molecular_validity", "guacamol_reward", "gracamol_reward", "gracamol", "validity_connectivity", "valid_connectivity"):
            # "dist_coef" wins over the alternative key "reward_dist_coef".
            dist_coef = cfg_get("dist_coef")
            if dist_coef is None:
                dist_coef = cfg_get("reward_dist_coef")
            if dist_coef is not None:
                reward_kwargs["dist_coef"] = float(dist_coef)
            set_if_present("edge_dist_factor", "edge_dist_factor", float)

            if dataset_info is not None:
                for kwarg_name, attr_name in (
                    ("target_node_dist", "node_types"),
                    ("target_edge_dist", "edge_types"),
                ):
                    serialized = self._serialize_distribution_for_worker(
                        getattr(dataset_info, attr_name, None)
                    )
                    if serialized is not None:
                        reward_kwargs[kwarg_name] = serialized

        if reward_type in ("target_mpo", "guacamol_mpo", "target_goal"):
            target_task = resolve_target_task(cfg)
            if target_task:
                reward_kwargs["target_task"] = target_task

        if reward_type in ("tdc_oracle", "tdc_pmo", "pmo"):
            tdc_oracles = cfg_get("tdc_oracles")
            tdc_oracle = cfg_get("tdc_oracle")
            oracle_names: List[str] = []
            if tdc_oracles is not None:
                # A bare string is a single oracle name; list() would split it
                # into individual characters.
                if isinstance(tdc_oracles, str):
                    tdc_oracles = [tdc_oracles]
                try:
                    oracle_list = list(tdc_oracles)
                except Exception:
                    reward_kwargs["tdc_oracles"] = tdc_oracles
                    oracle_names = [str(tdc_oracles)]
                else:
                    reward_kwargs["tdc_oracles"] = oracle_list
                    oracle_names = [str(name) for name in oracle_list]
            elif tdc_oracle is not None:
                reward_kwargs["tdc_oracle"] = tdc_oracle
                oracle_names = [str(tdc_oracle)]

            set_if_present("tdc_aggregation", "tdc_aggregation")
            tdc_weights = cfg_get("tdc_weights")
            if tdc_weights is not None:
                try:
                    reward_kwargs["tdc_weights"] = [float(w) for w in tdc_weights]
                except Exception:
                    reward_kwargs["tdc_weights"] = tdc_weights
            set_if_present("tdc_minimize", "tdc_minimize", bool)
            set_if_present("tdc_invalid_score", "tdc_invalid_score", float)
            set_if_present("tdc_clip_min", "tdc_clip_min", float)
            set_if_present("tdc_clip_max", "tdc_clip_max", float)

            tdc_home = cfg_get("tdc_home")
            try:
                repo_root = Path(__file__).resolve().parents[1]
            except Exception:
                repo_root = None

            # Without an explicit tdc_home, fall back to the repository root
            # when its "oracle" directory holds a pickle for a requested oracle.
            if tdc_home is None and repo_root is not None and (repo_root / "oracle").is_dir():
                oracle_dir = repo_root / "oracle"
                if any(
                    (oracle_dir / f"{name}.pkl").is_file()
                    or (oracle_dir / f"{name}_current.pkl").is_file()
                    for name in oracle_names
                ):
                    tdc_home = str(repo_root)

            if tdc_home is not None:
                try:
                    tdc_path = Path(os.path.expanduser(str(tdc_home)))
                    if not tdc_path.is_absolute() and repo_root is not None:
                        tdc_path = repo_root / tdc_path
                    # A path that points at the "oracle" directory itself is
                    # replaced by its parent.
                    if tdc_path.is_dir() and tdc_path.name == "oracle":
                        tdc_path = tdc_path.parent
                    reward_kwargs["tdc_home"] = str(tdc_path.resolve())
                except Exception:
                    reward_kwargs["tdc_home"] = str(tdc_home)

        if reward_type in ("gdpo_docking", "gdpo") and grpo_cfg is not None:
            set_if_present("target_name", "target_name")

            dataset_cfg = getattr(cfg, "dataset", None)
            dataset_name = getattr(dataset_cfg, "name", "") or ""
            datadir = getattr(dataset_cfg, "datadir", None)
            remove_h = getattr(dataset_cfg, "remove_h", None)
            if dataset_name:
                reward_kwargs["dataset_name"] = dataset_name
            if datadir is not None:
                reward_kwargs["datadir"] = datadir
            if remove_h is not None:
                reward_kwargs["remove_h"] = bool(remove_h)

            # "gdpo_sim_threshold" wins over "gdpo_eval_sim_threshold".
            sim_override = cfg_get("gdpo_sim_threshold")
            if sim_override is None:
                sim_override = cfg_get("gdpo_eval_sim_threshold")
            reward_kwargs["sim_threshold"] = gdpo_get_sim_threshold(
                dataset_name,
                override=sim_override,
            )
            set_if_present("gdpo_sa_threshold", "sa_threshold", float)
            set_if_present("gdpo_dock_exhaustiveness", "dock_exhaustiveness", int)
            set_if_present("gdpo_dock_num_modes", "dock_num_modes", int)
            set_if_present("gdpo_dock_timeout", "dock_timeout", int)

        if reward_type in ("planar_graph", "planar", "sbm", "sbm_graph", "tree", "tree_graph"):
            if hasattr(reward_function, "state_dict_for_workers"):
                try:
                    reward_kwargs.update(reward_function.state_dict_for_workers())
                except Exception as e:
                    print(f"⚠️ Failed to serialize distribution-matching reward stats for workers: {e}")

        return reward_kwargs

    @staticmethod
    def _serialize_distribution_for_worker(dist) -> Optional[List[float]]:
        if dist is None:
            return None
        if torch.is_tensor(dist):
            arr = dist.detach().cpu().float()
            total = float(arr.sum())
            if total > 0:
                arr = arr / total
            return arr.tolist()
        if isinstance(dist, np.ndarray):
            arr = dist.astype(float)
            total = float(arr.sum())
            if total > 0:
                arr = arr / total
            return arr.tolist().copy()
        if isinstance(dist, (list, tuple)):
            arr = np.array(dist, dtype=float)
            total = float(arr.sum())
            if total > 0:
                arr = arr / total
            return arr.tolist()
        if isinstance(dist, dict):
            if not dist:
                return None
            max_idx = max(int(k) for k in dist.keys())
            values = [0.0] * (max_idx + 1)
            for k, v in dist.items():
                idx = int(k)
                if idx >= len(values):
                    values.extend([0.0] * (idx - len(values) + 1))
                values[idx] = float(v)
            total = sum(values)
            if total > 0:
                values = [v / total for v in values]
            return values
        return None

    def _initialize_training_components(self):
        """Expose the policy as ``core_model``, unfreeze it, build the reference.

        Every parameter of the policy is marked trainable, ``reference_model``
        is reset to ``None`` and ``_create_reference_model`` is then invoked.
        """
        policy = self.model
        self.core_model = policy

        for weight in policy.parameters():
            weight.requires_grad_(True)

        self.reference_model = None
        self._create_reference_model()


    def _debug_tdc_zero_reward_batch(self, graph_list: List, *, max_samples: int = 32) -> None:
        """Print diagnostics on how well sampled graphs convert to SMILES.

        For up to ``max_samples`` graphs the method counts how many survive
        tensor conversion, ``build_molecule``, RDKit sanitization and SMILES
        export, and prints the counts plus the most frequent failure. If any
        SMILES were produced and ``cfg.grpo.tdc_oracle`` is set, the first few
        are also scored with a PyTDC ``Oracle`` and the raw values printed.

        Args:
            graph_list: Sequence of ``(atom_data, edge_data)`` pairs.
            max_samples: Upper bound on the number of graphs inspected.
        """
        try:
            from rdkit import Chem
            from src.analysis.rdkit_functions import build_molecule
        except Exception as e:
            print(f"⚠️ [TDC Debug] [anonymized] RDKit/build_molecule: {e}")
            return

        info = getattr(self.model, "dataset_info", None)
        atom_decoder = None if info is None else getattr(info, "atom_decoder", None)
        if not atom_decoder:
            print("⚠️ [TDC Debug] [anonymized] dataset_info.atom_decoder，[anonymized]->SMILES [anonymized]")
            return

        n_total = min(int(max_samples), len(graph_list))
        if n_total <= 0:
            return

        n_built = n_sanitized = n_smiles = 0
        failures: Dict[str, int] = defaultdict(int)
        examples: List[str] = []

        def _note_failure(stage: str, exc: Exception) -> None:
            # Failures are bucketed by pipeline stage and exception class.
            failures[f"{stage}:{type(exc).__name__}"] += 1

        for atoms, bonds in graph_list[:n_total]:
            # Stage 1: bring both parts to CPU long tensors; a trailing class
            # axis (2-D atoms / 3-D bonds) is collapsed with argmax.
            try:
                atoms_t = atoms.detach().cpu() if torch.is_tensor(atoms) else torch.as_tensor(atoms)
                bonds_t = bonds.detach().cpu() if torch.is_tensor(bonds) else torch.as_tensor(bonds)
                if atoms_t.dim() == 2:
                    atoms_t = atoms_t.argmax(dim=-1)
                if bonds_t.dim() == 3:
                    bonds_t = bonds_t.argmax(dim=-1)
                atoms_t = atoms_t.to(dtype=torch.long)
                bonds_t = bonds_t.to(dtype=torch.long)
            except Exception as exc:
                _note_failure("tensor_cast", exc)
                continue

            # Stage 2: build the molecule object.
            try:
                mol = build_molecule(atoms_t, bonds_t, atom_decoder)
            except Exception as exc:
                _note_failure("build_molecule", exc)
                continue
            n_built += 1

            # Stage 3: RDKit sanitization.
            try:
                Chem.SanitizeMol(mol)
            except Exception as exc:
                _note_failure("sanitize", exc)
                continue
            n_sanitized += 1

            # Stage 4: SMILES export; keep at most five examples.
            try:
                smi = Chem.MolToSmiles(mol)
            except Exception as exc:
                _note_failure("smiles", exc)
                continue
            if smi:
                n_smiles += 1
                if len(examples) < 5:
                    examples.append(smi)

        # First key with the highest count (insertion order breaks ties).
        top_err = max(failures, key=failures.get) if failures else None

        summary = (
            "⚠️ [TDC Debug] [anonymized] 0；[anonymized]->SMILES [anonymized]："
            f"n={n_total}, built={n_built}, sanitize={n_sanitized}, smiles={n_smiles}"
        )
        if top_err:
            summary += f", top_err={top_err}"
        print(summary)

        if not examples:
            return
        print(f"   [anonymized]([anonymized]{len(examples)}[anonymized]): {examples}")

        # Optional cross-check: score the examples directly with PyTDC.
        try:
            from tdc import Oracle

            try:
                oracle_name = self.cfg.grpo.get("tdc_oracle", None)
            except Exception:
                oracle_name = getattr(self.cfg.grpo, "tdc_oracle", None)
            if oracle_name:
                oracle = Oracle(name=str(oracle_name))
                scores = np.asarray(oracle(examples[:5]))
                if scores.ndim == 0:
                    vals = [float(scores)]
                else:
                    vals = [float(x) for x in scores.reshape(-1).tolist()]
                print(f"   PyTDC oracle('{oracle_name}') [anonymized]: {vals}")
        except Exception as e:
            print(f"   ⚠️ PyTDC oracle [anonymized]([anonymized]): {e}")
            try:
                import traceback
                traceback.print_exc()
            except Exception:
                pass

    def _create_reference_model(self):
        """Build the frozen reference model used for the KL penalty.

        Does nothing when ``self.beta == 0``. Otherwise a fresh
        ``GraphDiscreteFlowModel`` is created on the policy's device and
        initialized from ``cfg.grpo.pretrained_checkpoint`` when that path
        exists and loads successfully, or else from the current policy
        weights. Unless ``cfg.grpo.reference_sync_p0`` is false, the policy's
        ``p0`` distributions and node-count distribution are copied over as
        well (best effort). The reference model ends up with all parameters
        frozen and in eval mode.
        """
        if self.beta == 0:
            return

        device = next(self.model.parameters()).device

        self.reference_model = GraphDiscreteFlowModel(cfg=self.cfg, **self.model_kwargs).to(device)

        loaded_from_pretrained = False
        ref_ckpt_path = self.cfg.grpo.get('pretrained_checkpoint')
        if ref_ckpt_path:
            ref_ckpt_path = os.path.expanduser(ref_ckpt_path)
            if os.path.exists(ref_ckpt_path):
                loaded_from_pretrained = self._load_reference_model_from_checkpoint(ref_ckpt_path)
            else:
                print(f"⚠️ [anonymized]checkpoint[anonymized]: {ref_ckpt_path}")

        if not loaded_from_pretrained:
            # Fallback: start the reference as an exact copy of the policy.
            with torch.no_grad():
                self.reference_model.load_state_dict(self.core_model.state_dict())

        try:
            sync_p0 = bool(self.cfg.grpo.get("reference_sync_p0", True))
        except Exception:
            sync_p0 = True

        if sync_p0:
            # Best effort: a failure is reported but does not abort setup.
            try:
                if (
                    hasattr(self.core_model, "p0_node_dist")
                    and hasattr(self.core_model, "p0_edge_dist")
                    and hasattr(self.reference_model, "update_limit_dist")
                ):
                    self.reference_model.update_limit_dist(
                        self.core_model.p0_node_dist,
                        self.core_model.p0_edge_dist,
                    )
                if (
                    hasattr(self.core_model, "node_count_prob")
                    and hasattr(self.reference_model, "update_node_count_dist")
                ):
                    self.reference_model.update_node_count_dist(self.core_model.node_count_prob)
            except Exception as e:
                print(f"⚠️ [anonymized](p0/node_count)[anonymized]: {e}")

        for param in self.reference_model.parameters():
            param.requires_grad = False

        self.reference_model.eval()
    
    def _load_reference_model_from_checkpoint(self, ckpt_path: str) -> bool:
        """Load ``self.reference_model`` weights from a checkpoint file.

        The checkpoint's keys are matched against the reference model's state
        dict, optionally after stripping a wrapper prefix such as ``model.`` or
        ``module.``. Tensors whose shape differs are dropped, except selected
        1-D distribution buffers, which are truncated/zero-padded and
        renormalized.

        Args:
            ckpt_path: Path of the checkpoint to read.

        Returns:
            True when the state dict was applied and at least 30% of the
            model's state entries were matched. False when the checkpoint has
            no dict state, the match ratio is below 30%, or any exception
            occurs. Note that ``load_state_dict(strict=False)`` runs before the
            ratio check, so the reference model may already be partially
            overwritten when False is returned.
        """
        try:
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only point this at trusted checkpoint files.
            checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
            # Accept both wrapped ({'state_dict': ...}) and bare state dicts.
            state_dict = checkpoint.get('state_dict', checkpoint)
            if not isinstance(state_dict, dict):
                print(f"⚠️ [anonymized]checkpoint[anonymized]（[anonymized]dict state_dict）: {type(state_dict)}")
                return False

            model_state = self.reference_model.state_dict()
            model_keys = set(model_state.keys())

            def _score_strip_prefix(prefix: str) -> int:
                # Number of checkpoint keys that would match a model key after
                # removing `prefix` (empty prefix: keys compared as-is).
                if not prefix:
                    return sum(1 for k in state_dict.keys() if k in model_keys)
                n = 0
                for k in state_dict.keys():
                    if isinstance(k, str) and k.startswith(prefix):
                        if k[len(prefix):] in model_keys:
                            n += 1
                return n

            prefixes = [
                "",
                "model.",
                "module.",
                "module.model.",
                "model.module.",
                "module.model.model.",
                "model.model.",
            ]
            # max() keeps the first candidate on ties, so the empty prefix wins
            # unless stripping something matches strictly more keys.
            best_prefix = max(prefixes, key=_score_strip_prefix)
            best_match = _score_strip_prefix(best_prefix)

            if best_prefix and best_match > 0:
                # Strip the chosen prefix; keys without it are kept unchanged,
                # and the first occurrence wins if two keys collide.
                remapped = {}
                for k, v in state_dict.items():
                    new_k = k[len(best_prefix):] if isinstance(k, str) and k.startswith(best_prefix) else k
                    if new_k not in remapped:
                        remapped[new_k] = v
                state_dict = remapped

            filtered_state = {}
            resized_keys = []
            dropped_keys = []

            def _resize_1d_distribution(src_tensor, target_tensor):
                # Copy as many leading entries as fit into a zero tensor shaped
                # like the target, then renormalize to sum 1 (if the sum is > 0).
                device = target_tensor.device
                dtype = target_tensor.dtype
                tgt_len = target_tensor.numel()
                out = torch.zeros(tgt_len, device=device, dtype=dtype)
                copy_len = min(src_tensor.numel(), tgt_len)
                out[:copy_len] = src_tensor.reshape(-1)[:copy_len].to(device=device, dtype=dtype)
                total = out.sum()
                if total > 0:
                    out = out / total
                return out

            for k, v in state_dict.items():
                if k in model_state and hasattr(model_state[k], "shape") and hasattr(v, "shape"):
                    if model_state[k].shape == v.shape:
                        filtered_state[k] = v
                    else:
                        # Shape mismatch: only 1-D distribution-like buffers
                        # are adapted; everything else is dropped.
                        if (
                            len(model_state[k].shape) == 1
                            and len(v.shape) == 1
                            and ("sampling_metrics" in k or k in ("p0_node_dist", "p0_edge_dist", "node_count_prob"))
                        ):
                            resized = _resize_1d_distribution(v, model_state[k])
                            filtered_state[k] = resized
                            resized_keys.append(k)
                        else:
                            dropped_keys.append(k)
                else:
                    # Keys unknown to the model (or shapeless values) pass
                    # through; strict=False below tolerates unexpected keys.
                    filtered_state[k] = v

            if resized_keys:
                print(f"   🔧 [anonymized]: {resized_keys}")
            if dropped_keys:
                print(f"   ⚠️ [anonymized]（shape[anonymized]）: {dropped_keys}")

            incompatible = self.reference_model.load_state_dict(filtered_state, strict=False)
            matched_keys = sum(1 for k in filtered_state.keys() if k in model_state)
            ratio = (matched_keys / max(1, len(model_state)))
            # Fewer than 30% matched entries is treated as a failed load.
            if ratio < 0.3:
                print(
                    "⚠️ [anonymized]checkpoint[anonymized]，[anonymized]（KL[anonymized]）。\n"
                    f"   ckpt={ckpt_path}\n"
                    f"   matched={matched_keys}/{len(model_state)} ({ratio:.1%})\n"
                    "   [anonymized] reference=[anonymized]（[anonymized]KL[anonymized]）。"
                )
                return False

            print(f"✅ [anonymized]checkpoint[anonymized]: {ckpt_path} (matched={matched_keys}/{len(model_state)}; {ratio:.1%})")
            # Re-apply the loaded distributions through the model's own update
            # hooks (best effort; failures are only reported).
            try:
                if hasattr(self.reference_model, "p0_node_dist") and hasattr(self.reference_model, "p0_edge_dist"):
                    self.reference_model.update_limit_dist(self.reference_model.p0_node_dist, self.reference_model.p0_edge_dist)
            except Exception as e:
                print(f"   ⚠️ [anonymized] p0 buffers [anonymized]: {e}")
            try:
                if hasattr(self.reference_model, "node_count_prob"):
                    self.reference_model.update_node_count_dist(self.reference_model.node_count_prob)
            except Exception as e:
                print(f"   ⚠️ [anonymized] node_count_prob [anonymized]: {e}")
            # Report what load_state_dict could not place (first 20 of each).
            if getattr(incompatible, 'missing_keys', None):
                mk = list(incompatible.missing_keys)
                print(f"   ⚠️ [anonymized]: {len(mk)} (showing up to 20): {mk[:20]}")
            if getattr(incompatible, 'unexpected_keys', None):
                uk = list(incompatible.unexpected_keys)
                print(f"   ⚠️ [anonymized]: {len(uk)} (showing up to 20): {uk[:20]}")
            return True
        except Exception as e:
            print(f"⚠️ [anonymized]checkpoint[anonymized] ({ckpt_path}): {e}")
            return False
    
    def _update_reference_model(self):
        if self.reference_model is None or self.beta == 0:
            return
        
        if self.global_step > 0 and self.global_step % self.ref_model_update_freq == 0:
            tau = self.cfg.grpo.get('ref_model_update_tau', 0.01)
            with torch.no_grad():
                for online_param, target_param in zip(
                    self.core_model.parameters(),
                    self.reference_model.parameters()
                ):
                    target_param.data.copy_(
                        tau * online_param.data + (1.0 - tau) * target_param.data
                    )
    
    def run_epoch(self, optimizer=None):
        """Run one outer epoch: sample, build a batch, train, apply deferred updates.

        Steps, in order:

        1. ``sampling_phase()`` fills the sample buffer.
        2. ``grpo_core.prepare_training_batch()`` builds the batch; if it
           returns ``None`` a warning is printed, ``self.epoch`` is advanced
           and the method returns without training.
        3. ``training_phase(batch, optimizer)`` runs, then the sample buffer
           is cleared.
        4. Any ``_pending_p0_update`` / ``_pending_node_count_prob_update``
           left on ``self`` is pushed into ``core_model`` and reset to ``None``.
        5. ``self.epoch`` is advanced, then ``_maybe_decay_lr`` and
           ``_maybe_run_gdpo_eval`` are invoked.

        Args:
            optimizer: Forwarded unchanged to ``training_phase`` and
                ``_maybe_decay_lr``.
        """
        started_at = time.time()

        self.sampling_phase()
        batch = self.grpo_core.prepare_training_batch()
        if batch is None:
            print("⚠️ [anonymized]，[anonymized]")
            self.epoch += 1
            return

        torch.cuda.empty_cache()
        tracker = self.grpo_core.stat_tracker
        if tracker:
            # Result is currently unused; the call (and 2-tuple unpack) is kept as-is.
            _avg_group_size, _num_configs = tracker.get_statistics_summary()

        self.training_phase(batch, optimizer)
        torch.cuda.empty_cache()

        self.grpo_core.clear_sample_buffer()

        # Deferred p0 (limit distribution) update: applied only after training
        # so it takes effect for the next epoch's sampling.
        pending_p0 = getattr(self, '_pending_p0_update', None)
        if pending_p0 is not None:
            node_dist, edge_dist = pending_p0
            self.core_model.update_limit_dist(node_dist, edge_dist)
            print("✅ Applied Pending p0 Update for Next Epoch.")
            self._pending_p0_update = None

        # Deferred node-count distribution update; failures are reported but
        # never abort the epoch, and the pending value is dropped either way.
        pending_node_count = getattr(self, "_pending_node_count_prob_update", None)
        if pending_node_count is not None:
            try:
                self.core_model.update_node_count_dist(pending_node_count)
                print("✅ Applied Pending node_dist update for Next Epoch.")
            except Exception as e:
                print(f"⚠️ Failed to apply pending node_dist update: {e}")
            self._pending_node_count_prob_update = None

        _elapsed = time.time() - started_at
        self.epoch += 1

        self._maybe_decay_lr(optimizer, batch)
        self._maybe_run_gdpo_eval()
    
    def sampling_phase(self):
        """Sample trajectory groups from the current model and queue them for training.

        The model is switched to eval mode for the whole phase (and again in
        the ``finally`` clause).

        Debug mode (``cfg.grpo.debug_use_model_sampler`` truthy): draws a
        single batch through ``core_model.sample_batch``, prints reward
        statistics and returns without collecting anything.

        Normal mode: loops until ``self.sample_group_num`` groups were
        processed. Each iteration

        1. samples up to ``self.concurrent_sampling_groups`` groups of
           ``self.group_size`` trajectories via
           ``sample_graphs_with_trajectory_tracking`` (same start per group);
        2. scores the final graphs with ``_compute_rewards_multiprocess_sync``;
        3. optionally scores the intermediate predictions of trajectories whose
           final reward is positive (dense rewards, ``cfg.grpo.use_dense_reward``);
        4. hands every group to ``grpo_core.collect_trajectory_samples``,
           except groups whose rewards are all ``-1``, all ``1`` or all ~``0``.

        After the loop it prints reward and token-distribution statistics and,
        when enabled, prepares the deferred dynamic-p0 and dynamic node-count
        updates (stored on ``self`` as ``_pending_p0_update`` and
        ``_pending_node_count_prob_update``).

        Side effects visible in this method: advances ``self._next_group_id``;
        may update ``self._refine_topk_graphs``, ``self.global_p0_buffer``,
        ``core_model.node_count_buffer_*``, ``self._tdc_zero_reward_debug_epoch``
        and ``self._last_reward_log_epoch``; may write
        ``reward_logs/rewards_epoch_<epoch>.csv`` and visualization outputs.

        Returns:
            None.
        """
        self.core_model.eval()
        try:
            if getattr(self.cfg.grpo, "debug_use_model_sampler", False):
                # Debug shortcut: bypass trajectory tracking and only report rewards.
                print("🔍 [GRPO Debug] [anonymized] GraphDiscreteFlowModel.sample_batch [anonymized]（[anonymized]GRPO[anonymized]）")
                debug_batch_size = self.group_size * self.concurrent_sampling_groups
                keep_chain = 0
                number_chain_steps = min(self.sample_steps, getattr(self.core_model, "sample_T", self.sample_steps))
                samples, _labels = self.core_model.sample_batch(
                    batch_id=0,
                    batch_size=debug_batch_size,
                    keep_chain=keep_chain,
                    number_chain_steps=number_chain_steps,
                    save_final=0,
                    num_nodes=None,
                    save_visualization=False,
                )
                try:
                    rewards = self.reward_function(samples)
                    if isinstance(rewards, torch.Tensor):
                        rewards_tensor = rewards
                    else:
                        rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
                    if rewards_tensor.numel() > 0:
                        mean_r = rewards_tensor.mean().item()
                        std_r = rewards_tensor.std().item()
                        min_r = rewards_tensor.min().item()
                        max_r = rewards_tensor.max().item()
                        print(
                            f"[GRPO Debug] Molecular reward stats: mean={mean_r:.4f}, "
                            f"std={std_r:.4f}, min={min_r:.4f}, max={max_r:.4f}"
                        )
                except Exception as e:
                    print(f"[GRPO Debug] reward evaluation failed: {e}")
                return

            total_rewards = []            # rewards of groups actually handed to grpo_core
            debug_all_rewards = []        # rewards of every sampled graph (incl. skipped groups)
            debug_all_node_counts = []
            debug_molecules = []          # a few graphs kept for the end-of-phase visualization

            p0_candidate_buffer = []          # (reward, atom tokens, edge tokens)
            node_count_candidate_buffer = []  # (reward, node count)

            from collections import Counter
            epoch_node_counts = Counter()
            epoch_edge_counts = Counter()
            sample_start_time = time.time()
            device = next(self.core_model.parameters()).device

            total_groups = self.sample_group_num
            groups_collected = 0

            while groups_collected < total_groups:
                active_groups = min(self.concurrent_sampling_groups, total_groups - groups_collected)
                group_ids = []
                for _ in range(active_groups):
                    group_id = f"group_{self._next_group_id}"
                    self._next_group_id += 1
                    group_ids.append(group_id)

                batch_size = active_groups * self.group_size
                enable_visualization = bool(self.cfg.grpo.get("enable_visualization", True))
                # Detailed per-step visualization: first chunk of every 5th epoch only.
                should_log_vis = enable_visualization and (self.epoch % 5 == 0) and (groups_collected == 0)

                graphs, node_mask, current_log_prob, _ref_log_prob, trajectory_states, trajectory_preds, trajectory_probs =\
                    self.sample_graphs_with_trajectory_tracking(
                        batch_size=batch_size,
                        # Add the offset *before* the modulo so the seed always
                        # stays inside the 32-bit range.
                        seed=(int(time.time() * 1000) + groups_collected) % (2**32),
                        total_inference_steps=self.sample_steps,
                        force_same_start=True,
                        group_size_for_same_start=self.group_size,
                        return_probs=should_log_vis
                    )

                node_mask_cpu = node_mask.detach().cpu()
                graph_list = self._convert_placeholder_to_graph_list_cpu(graphs, node_mask, as_tensor=True)

                if enable_visualization and len(debug_molecules) < self.group_size * 4:
                    for mol in graph_list:
                        debug_molecules.append(mol)
                        if len(debug_molecules) >= self.group_size * 4:
                            break

                # Token histogram over every sampled graph (printed after the loop).
                for at, et in graph_list:
                    if torch.is_tensor(at):
                        epoch_node_counts.update(at.tolist())
                    if torch.is_tensor(et):
                        epoch_edge_counts.update(et.reshape(-1).tolist())

                chunk_rewards = self._compute_rewards_multiprocess_sync(
                    graph_list,
                    timeout=1800,
                    context="sampling",
                )
                if chunk_rewards.numel() == 0:
                    print("⚠️ [anonymized]/[anonymized]，[anonymized]")
                    groups_collected += active_groups
                    continue
                if chunk_rewards.device != device:
                    chunk_rewards = chunk_rewards.to(device, non_blocking=True)

                try:
                    reward_name = getattr(self.reward_function, "name", "") or ""
                    reward_name = str(reward_name).lower()
                except Exception:
                    reward_name = ""
                if reward_name in ("tdc_oracle", "tdc_pmo", "pmo"):
                    # For these reward names an all-zero chunk triggers a
                    # diagnostic dump, at most once per epoch.
                    try:
                        all_zero = float(chunk_rewards.abs().max().item()) <= 0.0
                    except Exception:
                        all_zero = False
                    if all_zero and getattr(self, "_tdc_zero_reward_debug_epoch", None) != self.epoch:
                        self._tdc_zero_reward_debug_epoch = int(self.epoch)
                        self._debug_tdc_zero_reward_batch(graph_list, max_samples=32)

                rewards_np = chunk_rewards.detach().cpu().numpy()
                node_counts_batch = node_mask_cpu.sum(dim=1).tolist()

                if self._refine_topk > 0:
                    # Keep a running top-k of strictly positive-reward graphs.
                    for g, r in zip(graph_list, rewards_np):
                        if float(r) <= 0.0:
                            continue
                        self._refine_topk_graphs.append((float(r), g))
                    if self._refine_topk_graphs:
                        self._refine_topk_graphs = sorted(
                            self._refine_topk_graphs, key=lambda x: x[0], reverse=True
                        )[: self._refine_topk]

                if getattr(self.cfg.grpo, 'enable_dynamic_p0', False):
                    for (at, et), r in zip(graph_list, rewards_np):
                        if float(r) <= 0.001:
                            continue
                        # Accept either index tensors or one-hot tensors (argmax'd),
                        # as well as plain Python sequences / scalars.
                        if torch.is_tensor(at):
                            at_indices = at.argmax(-1) if at.dim() == 2 else at
                            at_list = at_indices.flatten().detach().cpu().tolist()
                        else:
                            at_list = list(at) if isinstance(at, (list, tuple)) else [int(at)]

                        if torch.is_tensor(et):
                            et_indices = et.argmax(-1) if et.dim() == 3 else et
                            et_list = et_indices.reshape(-1).detach().cpu().tolist()
                        else:
                            if isinstance(et, (list, tuple)) and et and isinstance(et[0], (list, tuple)):
                                et_list = [int(x) for row in et for x in row]
                            else:
                                et_list = list(et) if isinstance(et, (list, tuple)) else [int(et)]

                        p0_candidate_buffer.append((float(r), at_list, et_list))

                        # NOTE(review): tensor graphs were already counted in the
                        # histogram loop above, so these tokens are counted a
                        # second time in the printed distribution — confirm intended.
                        epoch_node_counts.update(at_list)
                        epoch_edge_counts.update(et_list)

                if self.enable_dynamic_node_dist:
                    thr = float(self.dynamic_node_dist_reward_threshold)
                    for r, n in zip(rewards_np, node_counts_batch):
                        if float(r) > thr:
                            node_count_candidate_buffer.append((float(r), int(n)))

                debug_all_rewards.append(chunk_rewards.detach().cpu())
                debug_all_node_counts.append(torch.as_tensor(node_counts_batch, dtype=torch.long))

                # ---- Dense (per-step) rewards -------------------------------
                dense_rewards_tensor = None
                use_dense = self.cfg.grpo.get("use_dense_reward", True)
                if trajectory_preds and use_dense:
                    steps_per_traj = [len(t) for t in trajectory_preds]
                    # The first trajectory's length defines the time axis.
                    T_len = steps_per_traj[0] if steps_per_traj else 0

                    if T_len > 0:
                        start_step = int(T_len * 0.0)  # currently always 0: score every step

                        # Only trajectories with a positive final reward are
                        # re-scored step by step; all others keep zeros.
                        valid_indices = (chunk_rewards > 0).nonzero().squeeze(-1).tolist()

                        dense_rewards_tensor = torch.zeros((batch_size, T_len), device=device, dtype=torch.float32)

                        if len(valid_indices) > 0:
                            valid_pred_graphs = []
                            valid_indices_flat_map = []  # (batch index, step index) per flattened graph

                            for idx in valid_indices:
                                traj_graphs = trajectory_preds[idx]
                                for t_idx in range(start_step, T_len):
                                    if t_idx < len(traj_graphs):
                                        valid_pred_graphs.append(traj_graphs[t_idx])
                                        valid_indices_flat_map.append((idx, t_idx))

                            if valid_pred_graphs:
                                flat_valid_rewards = self._compute_rewards_multiprocess_sync(
                                    valid_pred_graphs,
                                    timeout=1800,
                                    context="dense_sampling",
                                )

                                if flat_valid_rewards.numel() > 0:
                                    if flat_valid_rewards.device != device:
                                        flat_valid_rewards = flat_valid_rewards.to(device, non_blocking=True)

                                    map_tensor = torch.tensor(valid_indices_flat_map, device=device, dtype=torch.long)
                                    batch_idxs = map_tensor[:, 0]
                                    t_idxs = map_tensor[:, 1]

                                    # Tolerate a reward vector shorter than the request.
                                    count = min(flat_valid_rewards.numel(), map_tensor.shape[0])
                                    dense_rewards_tensor[batch_idxs[:count], t_idxs[:count]] = flat_valid_rewards[:count]

                if should_log_vis and trajectory_probs is not None:
                    try:
                        self._log_detailed_visualization(
                            trajectory_states=trajectory_states,
                            trajectory_preds=trajectory_preds,
                            trajectory_probs=trajectory_probs,
                            dense_rewards=dense_rewards_tensor,
                            final_rewards=chunk_rewards,
                            log_dir=f"visualization_outputs/epoch_{self.epoch}_group_{group_ids[0]}",
                            batch_indices=[0, 1, 2, 3]
                        )
                    except Exception as e:
                        print(f"⚠️ Visualization logging failed: {e}")

                log_prob_cpu = current_log_prob.detach().cpu()

                vectorized_traj = self.grpo_core._vectorize_trajectories(
                    trajectory_states,
                    node_mask_cpu
                )
                # Training window: the last `train_max_steps` sampling steps
                # (the whole trajectory when the option is unset).
                train_max_steps = None
                try:
                    train_max_steps = self.cfg.grpo.get("train_max_steps")
                except Exception:
                    train_max_steps = None
                if train_max_steps is None:
                    train_window_steps = int(self.sample_steps)
                else:
                    train_window_steps = max(1, min(int(train_max_steps), int(self.sample_steps)))
                train_start_step = int(self.sample_steps) - int(train_window_steps)
                vectorized_traj["trajectory_t_start"] = torch.full(
                    (batch_size,),
                    train_start_step,
                    dtype=torch.long,
                )
                vectorized_traj["trajectory_total_inference_steps"] = torch.full(
                    (batch_size,),
                    int(self.sample_steps),
                    dtype=torch.long,
                )
                for key in vectorized_traj:
                    vectorized_traj[key] = vectorized_traj[key].cpu()

                # Slice the chunk back into its individual groups.
                new_batches = []
                for local_idx, group_id in enumerate(group_ids):
                    start = local_idx * self.group_size
                    end = start + self.group_size

                    batch_rewards = chunk_rewards[start:end]
                    group_log_probs = log_prob_cpu[start:end]
                    group_node_masks = node_mask_cpu[start:end]
                    group_vectorized = {
                        key: value[start:end]
                        for key, value in vectorized_traj.items()
                    }
                    graph_configs = [group_id for _ in range(self.group_size)]

                    batch_dense = None
                    if dense_rewards_tensor is not None:
                        batch_dense = dense_rewards_tensor[start:end].cpu()

                    new_batches.append({
                        "rewards": batch_rewards,
                        "old_log_probs": group_log_probs,
                        "graph_configs": graph_configs,
                        "node_masks": group_node_masks,
                        "trajectory_tensors": group_vectorized,
                        "dense_rewards": batch_dense,
                    })

                for batch in new_batches:
                    batch_rewards = batch["rewards"]
                    if batch_rewards.numel() == 0:
                        continue

                    # Groups whose rewards are uniformly -1, +1 or ~0 are not
                    # passed on to grpo_core.
                    all_neg_one = torch.all(batch_rewards == -1.0)
                    all_pos_one = torch.all(batch_rewards == 1.0)
                    all_zero = torch.all(
                        torch.isclose(
                            batch_rewards,
                            torch.zeros_like(batch_rewards),
                            atol=1e-8,
                            rtol=0.0,
                        )
                    )
                    if all_neg_one or all_pos_one or all_zero:
                        try:
                            group_id = batch["graph_configs"][0]
                        except Exception:
                            group_id = "unknown"
                        print(
                            f"   ⚠️ [anonymized] {group_id}: [anonymized] {batch_rewards[0].item():.2f}，[anonymized]GRPO[anonymized]"
                        )
                        continue

                    self.grpo_core.collect_trajectory_samples(
                        trajectories=None,
                        rewards=batch_rewards,
                        old_log_probs=batch["old_log_probs"],
                        graph_configs=batch["graph_configs"],
                        node_masks=batch["node_masks"],
                        trajectory_tensors=batch["trajectory_tensors"],
                        dense_rewards=batch.get("dense_rewards"),
                    )

                    # Dump dense-vs-final rewards to CSV for the first group
                    # with dense rewards in each epoch (best effort).
                    if getattr(self, '_last_reward_log_epoch', -1) < self.epoch:
                        dense_R = batch.get("dense_rewards")
                        final_R = batch_rewards
                        if dense_R is not None:
                            try:
                                log_dir = "reward_logs"
                                os.makedirs(log_dir, exist_ok=True)

                                B, T_steps = dense_R.shape
                                final_R_cpu = final_R.cpu()
                                final_expanded = final_R_cpu.unsqueeze(1).expand(-1, T_steps)
                                diff_R = final_expanded - dense_R

                                dense_np = dense_R.detach().cpu().numpy()
                                final_np = final_R.detach().cpu().numpy()
                                diff_np = diff_R.detach().cpu().numpy()

                                save_path = f"{log_dir}/rewards_epoch_{self.epoch}.csv"
                                with open(save_path, "w", newline="") as f:
                                    writer = csv.writer(f)
                                    writer.writerow(["Epoch", "GlobalStep", "TrajID", "Step", "DenseReward", "FinalReward", "DiffReward"])

                                    for b in range(B):
                                        f_val = final_np[b]
                                        for t in range(T_steps):
                                            d_val = dense_np[b, t]
                                            diff_val = diff_np[b, t]
                                            writer.writerow([self.epoch, self.global_step, b, t, f"{d_val:.4f}", f"{f_val:.4f}", f"{diff_val:.4f}"])

                                print(f"   📈 Logged reward trajectories to {save_path}")
                                self._last_reward_log_epoch = self.epoch
                            except Exception as e:
                                print(f"   ⚠️ Failed to log rewards: {e}")
                    total_rewards.append(batch_rewards)

                groups_collected += active_groups
                progress = groups_collected / total_groups
                print(
                    f"   ⏳ [anonymized]: {groups_collected}/{total_groups} [anonymized] ({progress:.1%})"
                )

            # ---- End-of-phase statistics ------------------------------------
            sample_time = time.time() - sample_start_time
            if total_rewards:
                all_rewards_tensor = torch.cat(total_rewards)
                print(f"\n📊 [anonymized]:")
                print(f"   [anonymized]: {len(all_rewards_tensor)}")
                mean_r = all_rewards_tensor.mean().item()
                std_r = all_rewards_tensor.std().item()
                min_r = all_rewards_tensor.min().item()
                max_r = all_rewards_tensor.max().item()
                gt0 = int((all_rewards_tensor > 0).sum().item())
                gt1e3 = int((all_rewards_tensor > 1e-3).sum().item())
                print(f"   [anonymized]: {mean_r:.6f} ± {std_r:.6f}")
                print(f"   [anonymized]/[anonymized]: {min_r:.6f} / {max_r:.6f}")
                print(f"   >0: {gt0} ({gt0/len(all_rewards_tensor):.2%}), >1e-3: {gt1e3} ({gt1e3/len(all_rewards_tensor):.2%})")
                print(f"   [anonymized]: {sample_time:.2f}[anonymized]")

            if debug_all_rewards:
                debug_all = torch.cat(debug_all_rewards)
                mean_r = debug_all.mean().item()
                std_r = debug_all.std().item()
                min_r = debug_all.min().item()
                max_r = debug_all.max().item()
                print(
                    f"[GRPO Debug] Molecular reward stats (all sampled graphs): "
                    f"mean={mean_r:.4f}, std={std_r:.4f}, min={min_r:.4f}, max={max_r:.4f}"
                )

            if p0_candidate_buffer:
                p0_candidate_buffer.sort(key=lambda x: x[0], reverse=True)

                # Top 10% (at least one) of this epoch's candidates.
                top_k = max(1, int(len(p0_candidate_buffer) * 0.1))
                top_samples = p0_candidate_buffer[:top_k]

                # NOTE(review): these tokens were already added to the epoch
                # histogram inside the sampling loop, so the top samples are
                # counted once more here — confirm this weighting is intended.
                for _, at, et in top_samples:
                    if isinstance(at, torch.Tensor):
                        at = at.detach().cpu().flatten().tolist()
                    if isinstance(et, torch.Tensor):
                        et = et.detach().cpu().flatten().tolist()
                    if isinstance(at, (list, tuple)):
                        epoch_node_counts.update([int(x) for x in at])
                    if isinstance(et, (list, tuple)):
                        epoch_edge_counts.update([int(x) for x in et])

                avg_top_r = sum(s[0] for s in top_samples) / len(top_samples)
                print(f"🔄 Dynamic p0: Selected top {top_k}/{len(p0_candidate_buffer)} samples (Avg R: {avg_top_r:.4f})")

            def print_dist(name, counter):
                """Print the normalized histogram of `counter`, ordered by key."""
                total = sum(counter.values())
                if total == 0:
                    print(f"📊 [Epoch {self.epoch}] {name} Dist: (Empty)")
                    return
                dist_str = ", ".join(f"{k}: {counter[k]/total:.2%}" for k in sorted(counter))
                print(f"📊 [Epoch {self.epoch}] {name} Dist (N={total}): {dist_str}")

            print_dist("Node", epoch_node_counts)
            print_dist("Edge", epoch_edge_counts)

            # ---- Dynamic p0: blend the limit distribution toward the buffer --
            if getattr(self.cfg.grpo, 'enable_dynamic_p0', False) and p0_candidate_buffer:
                self.global_p0_buffer.extend(p0_candidate_buffer)

                self.global_p0_buffer.sort(key=lambda x: x[0], reverse=True)

                if len(self.global_p0_buffer) > self.p0_buffer_size:
                    self.global_p0_buffer = self.global_p0_buffer[:self.p0_buffer_size]

                print(f"🔄 [Dynamic p0] Global Buffer Size: {len(self.global_p0_buffer)}. "
                      f"Best Reward: {self.global_p0_buffer[0][0]:.4f}, "
                      f"Worst in Buffer: {self.global_p0_buffer[-1][0]:.4f}")

                global_node_counts = Counter()
                global_edge_counts = Counter()

                for _, at, et in self.global_p0_buffer:
                    if isinstance(at, torch.Tensor):
                        at = at.tolist()
                    if isinstance(et, torch.Tensor):
                        et = et.tolist()

                    if isinstance(at, list):
                        for a in at:
                            global_node_counts[a] += 1
                    else:
                        global_node_counts[at] += 1

                    if isinstance(et, list):
                        for e in et:
                            global_edge_counts[e] += 1
                    else:
                        global_edge_counts[et] += 1

                print(f"🔄 Check Dynamic p0 Update... (Valid atoms in Global Buffer: {sum(global_node_counts.values())})")

                curr_node_dist = self.core_model.limit_dist.X.to(device)
                curr_edge_dist = self.core_model.limit_dist.E.to(device)

                dx = curr_node_dist.shape[0]
                de = curr_edge_dist.shape[0]

                # Tokens outside the current distribution's support are dropped.
                new_node_dist = torch.zeros(dx, device=device)
                for k, v in global_node_counts.items():
                    if k < dx:
                        new_node_dist[k] = v

                new_edge_dist = torch.zeros(de, device=device)
                for k, v in global_edge_counts.items():
                    if k < de:
                        new_edge_dist[k] = v

                # Fall back to the current distribution when nothing was counted.
                if new_node_dist.sum() > 0:
                    new_node_dist = new_node_dist / new_node_dist.sum()
                else:
                    new_node_dist = curr_node_dist

                if new_edge_dist.sum() > 0:
                    new_edge_dist = new_edge_dist / new_edge_dist.sum()
                else:
                    new_edge_dist = curr_edge_dist

                alpha = self.cfg.grpo.get('dynamic_p0_alpha', 0.05)

                updated_node_dist = (1 - alpha) * curr_node_dist + alpha * new_node_dist
                updated_edge_dist = (1 - alpha) * curr_edge_dist + alpha * new_edge_dist

                # Stored only; this method does not apply it to the model.
                self._pending_p0_update = (updated_node_dist, updated_edge_dist)

                safe_curr = curr_node_dist + 1e-9
                safe_updated = updated_node_dist + 1e-9
                kl_node = (safe_updated * (safe_updated.log() - safe_curr.log())).sum().item()
                print(f"🔄 Pending Global p0 Update (alpha={alpha}). KL(new||old): {kl_node:.6f}")

            # ---- Dynamic node-count distribution ----------------------------
            if self.enable_dynamic_node_dist and node_count_candidate_buffer:
                if self.target_node_count is not None:
                    print("⚠️ [Dynamic node-dist] target_node_count [anonymized]，[anonymized] node_dist [anonymized]。")
                elif not hasattr(self.core_model, "node_count_prob"):
                    print("⚠️ [Dynamic node-dist] core_model [anonymized] node_count_prob buffer，[anonymized] node_dist [anonymized]。")
                else:
                    rewards_new = torch.tensor([r for r, _ in node_count_candidate_buffer], dtype=torch.float32)
                    nodes_new = torch.tensor([n for _, n in node_count_candidate_buffer], dtype=torch.long)

                    buf_rewards = self.core_model.node_count_buffer_rewards.detach().cpu()
                    buf_nodes = self.core_model.node_count_buffer_nodes.detach().cpu()
                    filled = int(self.core_model.node_count_buffer_filled.detach().cpu().item())
                    filled = max(0, min(filled, int(buf_rewards.numel())))

                    # Merge the persisted buffer with this epoch's candidates.
                    if filled > 0:
                        rewards_all = torch.cat([buf_rewards[:filled], rewards_new], dim=0)
                        nodes_all = torch.cat([buf_nodes[:filled], nodes_new], dim=0)
                    else:
                        rewards_all = rewards_new
                        nodes_all = nodes_new

                    k = min(int(buf_rewards.numel()), int(rewards_all.numel()))
                    if k <= 0:
                        print("⚠️ [Dynamic node-dist] empty buffer after merge; skip update.")
                    else:
                        top_idx = torch.topk(rewards_all, k=k, largest=True).indices
                        top_rewards = rewards_all[top_idx]
                        top_nodes = nodes_all[top_idx]

                        # Rewrite the model-side buffer with the merged top-k.
                        self.core_model.node_count_buffer_rewards.fill_(-1e9)
                        self.core_model.node_count_buffer_nodes.zero_()
                        self.core_model.node_count_buffer_rewards[:k].copy_(top_rewards.to(self.core_model.node_count_buffer_rewards.device))
                        self.core_model.node_count_buffer_nodes[:k].copy_(top_nodes.to(self.core_model.node_count_buffer_nodes.device))
                        self.core_model.node_count_buffer_filled.fill_(int(k))

                        curr_prob = self.core_model.node_count_prob.detach().cpu().to(dtype=torch.float32)
                        new_prob = torch.zeros_like(curr_prob)

                        min_n = int(self.node_count_min) if self.node_count_min is not None else None
                        max_n = int(self.node_count_max) if self.node_count_max is not None else None
                        # Histogram of top-k node counts, restricted to the
                        # table size and the optional [min_n, max_n] bounds.
                        for n in top_nodes.tolist():
                            n_int = int(n)
                            if n_int < 0 or n_int >= int(new_prob.numel()):
                                continue
                            if min_n is not None and n_int < min_n:
                                continue
                            if max_n is not None and n_int > max_n:
                                continue
                            new_prob[n_int] += 1.0

                        if float(new_prob.sum().item()) <= 0.0:
                            print("⚠️ [Dynamic node-dist] Top-K node_count histogram is empty after bounds; skip update.")
                        else:
                            new_prob = new_prob / new_prob.sum()
                            alpha = max(0.0, min(1.0, float(self.dynamic_node_dist_alpha)))
                            updated = (1.0 - alpha) * curr_prob + alpha * new_prob
                            if min_n is not None:
                                updated[:min_n] = 0.0
                            if max_n is not None and max_n + 1 < int(updated.numel()):
                                updated[max_n + 1 :] = 0.0
                            if float(updated.sum().item()) > 0.0:
                                updated = updated / updated.sum()
                                # Stored only; this method does not apply it to the model.
                                self._pending_node_count_prob_update = updated
                                print(
                                    f"🔄 Pending node_dist update (alpha={alpha}): "
                                    f"support=[{min_n},{max_n}], filled={k}, "
                                    f"best_reward={float(top_rewards.max().item()):.4f}"
                                )

            # ---- Optional visualization of a few sampled graphs --------------
            if debug_molecules and bool(self.cfg.grpo.get("enable_visualization", True)):
                try:
                    viz = getattr(self.core_model, "visualization_tools", None)
                    if viz is not None:
                        result_path = os.path.join(
                            os.getcwd(),
                            f"graphs/grpo_sampling_debug/epoch{self.epoch}",
                        )
                        num_to_vis = 5
                        print(
                            f"[GRPO Debug] Visualizing {num_to_vis} sampled graphs to {result_path}"
                        )
                        viz.visualize(result_path, debug_molecules, num_to_vis)
                except Exception as e:
                    print(f"[GRPO Debug] visualization failed: {e}")
        finally:
            self.core_model.eval()

    def training_phase(self, training_batch, optimizer):
        """Optimise the policy on one collected batch for ``self.num_inner_epochs`` passes.

        Each pass splits ``training_batch`` into mini-batches of
        ``self.train_batch_size`` samples, computes the losses through
        ``self.grpo_core.compute_losses`` and accumulates gradients over
        ``self.gradient_accumulation_steps`` mini-batches before clipping and
        stepping the optimizer. After every optimizer update ``self.global_step``
        is incremented, metrics are optionally logged to swanlab and
        ``self._maybe_run_evaluation()`` is called.

        Args:
            training_batch: Either a ``TrajectoryData`` instance or a dict that
                contains at least an ``"old_log_probs"`` tensor whose first
                dimension is the number of samples. ``None`` or an empty
                ``TrajectoryData`` makes the method return immediately.
            optimizer: Optimizer whose ``zero_grad``/``step`` are driven here.

        Returns:
            None. Progress is reported through ``print``.
        """
        if training_batch is None:
            print("⚠️ training_phase[anonymized]batch，[anonymized]")
            return

        if isinstance(training_batch, TrajectoryData) and training_batch.is_empty():
            print("⚠️ training_phase[anonymized]TrajectoryData，[anonymized]")
            return
        # NOTE(review): the model is put in eval() mode even though gradients are
        # taken below; presumably to keep dropout/normalisation layers
        # deterministic between sampling and training -- confirm.
        self.core_model.eval()
        train_start_time = time.time()
        device = next(self.core_model.parameters()).device
        gradient_accumulation_steps = max(1, self.gradient_accumulation_steps)
        shuffle_batches = self.cfg.grpo.get('shuffle_training_batches', False)
        total_samples = (
            len(training_batch)
            if isinstance(training_batch, TrajectoryData)
            else training_batch["old_log_probs"].shape[0]
        )

        for inner_epoch in range(self.num_inner_epochs):
            print(f"\n   🔄 Inner Epoch {inner_epoch + 1}/{self.num_inner_epochs}")
            # Matches the number of chunks produced by the mini-batch iterators
            # below; used to detect the last mini-batch of the pass.
            num_mini_batches = max(1, math.ceil(total_samples / self.train_batch_size))
            epoch_losses = defaultdict(list)
            # NOTE(review): called unconditionally, although the update path
            # below guards against ``optimizer is None``.
            optimizer.zero_grad(set_to_none=True)
            accumulation_counter = 0
            inner_epoch_start = time.time()
            train_loop_start = time.time()

            # Only the TrajectoryData path honours ``shuffle_training_batches``;
            # dict batches are always visited in order.
            if isinstance(training_batch, TrajectoryData):
                mini_batch_iter = self._iter_cpu_minibatches(
                    training_batch, shuffle=shuffle_batches
                )
            else:
                mini_batch_iter = self._iter_dict_minibatches(training_batch)

            def _train_step(cpu_batch, mini_idx, num_mini_batches):
                # One forward/backward on a mini-batch, plus an optimizer update
                # when the accumulation window is full or the pass is ending.
                nonlocal accumulation_counter
                if isinstance(cpu_batch, TrajectoryData):
                    batch_on_device = cpu_batch.to(device, non_blocking=True)
                else:
                    # Non-tensor entries are passed through unchanged.
                    batch_on_device = {
                        k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
                        for k, v in cpu_batch.items()
                    }

                # bfloat16 autocast on CUDA only; other devices run without autocast.
                use_autocast = device.type == "cuda"
                amp_dtype = torch.bfloat16
                autocast_ctx = autocast(enabled=use_autocast, dtype=amp_dtype) if use_autocast else nullcontext()
                with autocast_ctx:
                    loss_dict = self.grpo_core.compute_losses(
                        self.core_model,
                        batch_on_device,
                        self.reference_model,
                        max_steps=self.cfg.grpo.get("train_max_steps"),
                    )

                # The loss is always divided by the configured accumulation
                # length, including for a shorter trailing group at the end of
                # the pass.
                scaled_loss = loss_dict["total_loss"] / gradient_accumulation_steps
                scaled_loss.backward()
                accumulation_counter += 1

                # Keep detached tensor metrics for the running/epoch averages.
                for key, value in loss_dict.items():
                    if isinstance(value, torch.Tensor):
                        epoch_losses[key].append(value.detach())

                # Flush when the window is full, or on the final mini-batch so
                # leftover gradients are not dropped.
                should_update = (
                    accumulation_counter >= gradient_accumulation_steps
                    or mini_idx == num_mini_batches - 1
                )

                if should_update and optimizer is not None:
                    max_grad_norm = self.cfg.grpo.get('max_grad_norm', 1.0)
                    # clip_grad_norm_ returns the total norm measured before clipping.
                    grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(
                        self.core_model.parameters(),
                        max_grad_norm
                    ).item()

                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)

                    # Derived value (not re-measured): the norm cannot exceed the clip threshold.
                    grad_norm_after_clip = min(grad_norm_before_clip, max_grad_norm)
                    accumulation_counter = 0
                    self.global_step += 1

                    if swanlab is not None and swanlab.run is not None:
                        self._log_training_metrics_to_swanlab(
                            epoch_losses=epoch_losses,
                            grad_norm_before_clip=grad_norm_before_clip,
                            grad_norm_after_clip=grad_norm_after_clip,
                            loss_dict=loss_dict,
                            training_batch=training_batch,
                            optimizer=optimizer
                        )

                    # Console summary every 10 optimizer updates, averaged over
                    # the last (up to) 10 recorded mini-batches of this pass.
                    if self.global_step % 10 == 0:
                        recent_losses = {
                            k: torch.stack(v[-10:]).mean().item() if v else 0
                            for k, v in epoch_losses.items()
                        }
                        ratio_mean = recent_losses.get("ratio_mean", 0)
                        clipfrac = recent_losses.get("clipfrac", 0)
                        print(
                            f"      Step {self.global_step}: Loss={recent_losses.get('total_loss', 0):.4f}, "
                            f"KL={recent_losses.get('kl_loss', 0):.4f}, "
                            f"Entropy={recent_losses.get('policy_entropy', 0):.4f}, "
                            f"Ratio={ratio_mean:.4f}, Clip={clipfrac:.4f}, "
                            f"Grad={grad_norm_before_clip:.4f}→{grad_norm_after_clip:.4f}"
                        )

                    # Evaluation hook is polled after every optimizer update.
                    self._maybe_run_evaluation()

            for mini_idx, cpu_batch in enumerate(mini_batch_iter):
                _train_step(cpu_batch, mini_idx, num_mini_batches)

            train_loop_time = time.time() - train_loop_start
            # Per-pass averages over every recorded mini-batch metric.
            avg_losses = {
                k: torch.stack(v).mean().item() for k, v in epoch_losses.items() if v
            }
            inner_epoch_time = time.time() - inner_epoch_start

            print(f"   ✅ [anonymized]，[anonymized]: {train_loop_time:.2f}[anonymized]")
            print(f"\n   ✅ Inner Epoch {inner_epoch + 1} [anonymized]:")
            print(f"      Loss={avg_losses.get('total_loss', 0):.4f}")
            print(f"      Policy Loss={avg_losses.get('policy_loss', 0):.4f}")
            print(f"      Policy Entropy={avg_losses.get('policy_entropy', 0):.4f}")
            print(f"      KL={avg_losses.get('kl_loss', 0):.4f}")
            print(f"      Clipfrac={avg_losses.get('clipfrac', 0):.4f}")
            print(f"      [anonymized]: {inner_epoch_time:.2f}[anonymized]")
            print(f"      └─ [anonymized]: {train_loop_time:.2f}[anonymized]")

        train_time = time.time() - train_start_time
        print(f"\n✅ [anonymized]，[anonymized]: {train_time:.2f}[anonymized]")

    def _iter_cpu_minibatches(self, batch: TrajectoryData, shuffle: bool = False):
        dataset_size = len(batch)
        if dataset_size == 0:
            return

        indices = torch.arange(dataset_size)
        if shuffle:
            perm = torch.randperm(dataset_size)
            indices = indices[perm]

        for start in range(0, dataset_size, self.train_batch_size):
            chunk = indices[start:start + self.train_batch_size]
            if chunk.numel() == 0:
                continue
            yield batch[chunk]

    def _iter_dict_minibatches(self, training_batch: Dict[str, torch.Tensor]):
        total_samples = training_batch["old_log_probs"].shape[0]
        for i in range(0, total_samples, self.train_batch_size):
            end = min(i + self.train_batch_size, total_samples)
            yield {
                k: v[i:end] if isinstance(v, torch.Tensor) else v
                for k, v in training_batch.items()
            }

    @staticmethod
    def _normalize_smiles_list(val: Any) -> List[str]:
        if val is None:
            return []
        if isinstance(val, (list, tuple, ListConfig)):
            out = [str(x).strip() for x in val if str(x).strip()]
            return out
        s = str(val).strip()
        if not s:
            return []
        if "," in s:
            return [item.strip() for item in s.split(",") if item.strip()]
        return [s]

    @staticmethod
    def _get_cfg_value(cfg: Dict, key: str) -> Optional[Any]:
        cur: Any = cfg
        for part in str(key).split("."):
            if part == "":
                continue
            try:
                cur = cur.get(part)
            except Exception:
                cur = getattr(cur, part, None)
            if cur is None:
                return None
        return cur

    def _smiles_to_graph(self, smiles: str) -> Optional["Graph"]:
        try:
            mol = Chem.MolFromSmiles(smiles, sanitize=False)
            if mol is None:
                return None
            mol = Chem.RemoveHs(mol)
            try:
                Chem.Kekulize(mol, clearAromaticFlags=True)
            except Exception:
                return None
            dataset_info = getattr(self.model, "dataset_info", None)
            atom_encoder = getattr(dataset_info, "atom_encoder", None) if dataset_info is not None else None
            if not isinstance(atom_encoder, dict):
                return None
            x_idx: List[int] = []
            for atom in mol.GetAtoms():
                sym = atom.GetSymbol()
                if sym not in atom_encoder:
                    return None
                x_idx.append(int(atom_encoder[sym]))
            n = len(x_idx)
            if n <= 0:
                return None
            e_idx = torch.zeros((n, n), dtype=torch.long)
            bt = Chem.rdchem.BondType
            for bond in mol.GetBonds():
                a = int(bond.GetBeginAtomIdx())
                b = int(bond.GetEndAtomIdx())
                t = bond.GetBondType()
                if t == bt.SINGLE:
                    v = 1
                elif t == bt.DOUBLE:
                    v = 2
                elif t == bt.TRIPLE:
                    v = 3
                elif t == bt.AROMATIC and int(self.model.input_dims.get("E", 0) or 0) > 4:
                    v = 4
                else:
                    return None
                e_idx[a, b] = v
                e_idx[b, a] = v
            return torch.tensor(x_idx, dtype=torch.long), e_idx
        except Exception:
            return None

    def _build_refine_seed_graphs(self, smiles_list: List[str]) -> List[Tuple[str, "Graph"]]:
        graphs: List[Tuple[str, "Graph"]] = []
        for smi in smiles_list:
            g = self._smiles_to_graph(smi)
            if g is None:
                print(f"⚠️ [GRPO] refine seed SMILES cannot be converted to graph: {smi}", flush=True)
                continue
            graphs.append((smi, g))
        return graphs

    def _graph_to_smiles(self, graph: Graph) -> Optional[str]:
        """Turn an ``(X, E)`` graph into the SMILES of its largest fragment.

        ``X`` and ``E`` may be index tensors, tensors with a trailing class
        dimension (reduced with ``argmax`` when ``X`` has more than one
        dimension or ``E`` more than two) or non-tensor values (converted to
        long tensors). The molecule is built with
        ``analysis.rdkit_functions.build_molecule`` using
        ``self.model.dataset_info.atom_decoder``.

        Returns:
            The SMILES of the fragment with the most atoms, or ``None`` when
            the helper module cannot be imported, no atom decoder is
            available, or building/serialising the molecule fails or yields
            an empty string.
        """
        try:
            from analysis.rdkit_functions import build_molecule, mol2smiles
        except Exception:
            return None

        info = getattr(self.model, "dataset_info", None)
        decoder = None if info is None else getattr(info, "atom_decoder", None)
        if not decoder:
            return None

        def _as_indices(value, plain_rank):
            # Tensors above ``plain_rank`` dims carry a class axis; collapse it.
            if torch.is_tensor(value):
                if value.dim() > plain_rank:
                    return torch.argmax(value, dim=-1)
                return value
            return torch.tensor(value, dtype=torch.long)

        nodes, edges = graph
        nodes = _as_indices(nodes, 1)
        edges = _as_indices(edges, 2)

        try:
            mol = build_molecule(nodes, edges, decoder)
        except Exception:
            return None
        if mol is None:
            return None

        # The whole molecule must serialise before fragments are considered.
        try:
            whole_smiles = mol2smiles(mol)
        except Exception:
            whole_smiles = None
        if not whole_smiles:
            return None

        try:
            fragments = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
            biggest = max(fragments, default=mol, key=lambda frag: frag.GetNumAtoms())
            result = mol2smiles(biggest)
        except Exception:
            return None
        return result or None

    def _compute_batch_loss_parallel(self, batch: Dict) -> Dict[str, torch.Tensor]:
        return self.grpo_core.compute_losses(
            model=self.core_model,
            batch_data=batch,
            reference_model=self.reference_model,
            max_steps=None,
        )
    
    def sample_graphs_with_trajectory_tracking(
        self,
        batch_size: int,
        seed: Optional[int] = None,
        total_inference_steps: int = 50,
        force_same_start: bool = True,
        group_size_for_same_start: Optional[int] = None,
        return_probs: bool = False,
    ) -> Tuple:
        """Sample ``batch_size`` graphs step by step while recording the trajectory.

        The sampler iterates integer steps from ``sampling_start_step`` up to
        ``total_inference_steps - 1``. Intermediate states, per-step
        log-probabilities and (optionally) predictions are only recorded for
        steps inside the training window, i.e. ``t_int >= train_start_step``.

        Args:
            batch_size: Number of graphs to sample.
            seed: Optional seed applied to ``torch``, ``numpy`` (modulo
                ``2**31``) and ``random`` before sampling.
            total_inference_steps: Number of integer sampling steps.
            force_same_start: Together with ``group_size_for_same_start``,
                enables grouped sampling where every member of a group starts
                from the same state.
            group_size_for_same_start: Samples per group. When given with
                ``force_same_start=True``, ``batch_size`` must be divisible by
                it. Otherwise every sample gets its own start state.
            return_probs: When true, the softmax predictions of the first
                ``vis_num_samples`` samples are kept for every recorded step.

        Returns:
            A 7-tuple ``(clean_graphs, node_mask, current_log_probs,
            ref_log_probs, trajectory_states, trajectory_preds,
            trajectory_probs)``:

            * ``clean_graphs``: ``utils.PlaceHolder`` of the final state after
              ``ignore_virtual_classes`` and ``mask(node_mask, collapse=True)``.
            * ``node_mask``: boolean mask built from the per-sample node counts.
            * ``current_log_probs``: per-step log-probabilities stacked along
              ``dim=1`` (one column per recorded step), or a
              ``(batch_size, 0)`` zero tensor when nothing was recorded.
            * ``ref_log_probs``: always ``None`` here.
            * ``trajectory_states``: one list per sample holding the state
              before each recorded step followed by the final state.
            * ``trajectory_preds``: ``None``, or per-sample lists of sampled
              discrete predictions (all samples with dense reward, otherwise
              the first ``vis_num_samples`` when ``return_probs`` is set).
            * ``trajectory_probs``: ``None`` unless ``return_probs`` is set.

        Raises:
            ValueError: If grouped sampling is requested and ``batch_size`` is
                not divisible by ``group_size_for_same_start``.
        """
        # NOTE(review): both names are also imported at module level; these
        # local imports are redundant but harmless.
        import random
        from flow_matching import flow_matching_utils

        try:
            if seed is not None:
                torch.manual_seed(seed)
                np.random.seed(seed % (2**31))
                random.seed(seed)

            device = next(self.core_model.parameters()).device
            # bfloat16 autocast during sampling is opt-in
            # (cfg.grpo.sampling_autocast) and CUDA-only; any config error
            # disables it.
            use_sampling_autocast = True
            try:
                use_sampling_autocast = (
                    device.type == "cuda"
                    and bool(self.cfg.grpo.get("sampling_autocast", False))
                )
            except Exception:
                use_sampling_autocast = False
            sampling_autocast_ctx = (
                autocast(enabled=True, dtype=torch.bfloat16)
                if use_sampling_autocast
                else nullcontext()
            )

            # Defaults: start from step 0 with no refinement noise.
            sampling_start_step = 0
            refine_noise_fraction = 0.0
            if force_same_start and group_size_for_same_start is not None:
                # ---- Grouped start: each group shares a single start state ----
                if batch_size % group_size_for_same_start != 0:
                    raise ValueError("batch_size must be divisible by group_size_for_same_start when force_same_start=True")
                num_groups = batch_size // group_size_for_same_start
                # Candidate start graphs: configured seed graphs first, then
                # the stored top-k graphs, each with a descriptive label.
                base_entries: List[Tuple[str, "Graph"]] = []
                if self._refine_seed_graphs:
                    base_entries.extend([(f"seed:{smi}", g) for smi, g in self._refine_seed_graphs])
                if self._refine_topk_graphs:
                    for idx, (score, g) in enumerate(self._refine_topk_graphs, start=1):
                        base_entries.append((f"topk#{idx}:score={score:.4f}", g))

                if base_entries:
                    # With start graphs available, cfg.grpo.noise_fraction_early
                    # (clamped to [0, 1]) controls how far back in the schedule
                    # sampling resumes.
                    try:
                        refine_noise_fraction = float(
                            self.cfg.grpo.get("noise_fraction_early", 0.0) or 0.0
                        )
                    except Exception:
                        refine_noise_fraction = 0.0
                    refine_noise_fraction = max(0.0, min(1.0, refine_noise_fraction))
                    if refine_noise_fraction > 0.0:
                        steps = int(total_inference_steps)
                        sampling_start_step = int(
                            torch.round(torch.tensor((1.0 - refine_noise_fraction) * (steps + 1))).item()
                        )
                        # Keep at least one sampling step.
                        sampling_start_step = max(0, min(steps - 1, sampling_start_step))

                # Node count per group: size of the assigned start graph (start
                # graphs are cycled with modulo), else the fixed target count,
                # else a draw from the node distribution clamped to the
                # configured bounds.
                group_nodes = torch.zeros((num_groups,), device=device, dtype=torch.long)
                for g in range(num_groups):
                    if base_entries:
                        _, base_g = base_entries[g % len(base_entries)]
                        x_idx, _ = base_g
                        group_nodes[g] = int(x_idx.shape[0])
                    elif self.target_node_count is not None:
                        group_nodes[g] = int(self.target_node_count)
                    else:
                        node_count = self.core_model.node_dist.sample_n(1, device=device).long()[0]
                        if self.node_count_min is not None:
                            node_count = torch.clamp(node_count, min=int(self.node_count_min))
                        if self.node_count_max is not None:
                            node_count = torch.clamp(node_count, max=int(self.node_count_max))
                        group_nodes[g] = int(node_count)

                # One mask row per group, then repeated for every group member.
                n_max = int(group_nodes.max().item())
                arange = torch.arange(n_max, device=device).unsqueeze(0).expand(num_groups, -1)
                base_masks = arange < group_nodes.unsqueeze(1)
                node_mask = base_masks.repeat_interleave(group_size_for_same_start, dim=0)

                X_blocks = []
                E_blocks = []
                y_blocks = [] if self.core_model.conditional else None
                # Class counts are taken from the limit (noise) distribution.
                noise_dist = self.core_model.noise_dist.get_limit_dist()
                x_classes = int(noise_dist.X.size(-1))
                e_classes = int(noise_dist.E.size(-1))
                for g in range(num_groups):
                    single_mask = base_masks[g:g+1]
                    if base_entries:
                        _, base_g = base_entries[g % len(base_entries)]
                        x_raw, e_raw = base_g
                        # Accept either index tensors or one-hot/probability
                        # tensors with a trailing class dimension.
                        x_idx = x_raw.argmax(dim=-1) if x_raw.dim() == 2 else x_raw.long()
                        e_idx = e_raw.argmax(dim=-1) if e_raw.dim() == 3 else e_raw.long()
                        n_nodes = int(x_idx.shape[0])
                        pad_len = n_max - n_nodes
                        if pad_len > 0:
                            # Pad with -1 up to n_max.
                            # NOTE(review): the padding tensors are created on
                            # the default device; this assumes the stored base
                            # graphs are CPU tensors -- confirm.
                            x_idx = torch.cat([x_idx, torch.full((pad_len,), -1, dtype=torch.long)], dim=0)
                            e_pad = torch.full((n_max, n_max), -1, dtype=torch.long)
                            e_pad[:n_nodes, :n_nodes] = e_idx
                            e_idx = e_pad
                        x_idx = x_idx.unsqueeze(0).to(device)
                        e_idx = e_idx.unsqueeze(0).to(device)
                        node_mask_1 = single_mask.to(device)
                        edge_mask_1 = node_mask_1.unsqueeze(1) * node_mask_1.unsqueeze(2)
                        # -1 padding is clamped to class 0 for one_hot and then
                        # zeroed out by the masks.
                        X_single = F.one_hot(x_idx.clamp(min=0), num_classes=x_classes).float()
                        E_single = F.one_hot(e_idx.clamp(min=0), num_classes=e_classes).float()
                        X_single = X_single * node_mask_1.unsqueeze(-1)
                        E_single = E_single * edge_mask_1.unsqueeze(-1)
                        if self.core_model.conditional:
                            y_single = torch.zeros(1, 1, device=device)
                        else:
                            y_single = torch.zeros(1, 0, device=device)
                        if refine_noise_fraction > 0.0:
                            # Noise the start graph to the (distorted) time that
                            # corresponds to sampling_start_step.
                            steps = int(total_inference_steps)
                            t_raw = torch.full(
                                (1, 1),
                                float(sampling_start_step) / float(steps + 1),
                                device=device,
                                dtype=X_single.dtype,
                            )
                            t_noisy = self.core_model.time_distorter.sample_ft(
                                t_raw, self.cfg.sample.time_distortion
                            )
                            noisy = self.core_model.apply_noise(
                                X_single, E_single, y_single, single_mask, t=t_noisy
                            )
                            X_single = noisy["X_t"]
                            E_single = noisy["E_t"]
                            y_single = noisy["y_t"]
                    else:
                        # No start graph: draw one noise sample for the group
                        # from the limit distribution.
                        z_single = flow_matching_utils.sample_discrete_feature_noise(
                            limit_dist=noise_dist,
                            node_mask=single_mask
                        )
                        X_single = z_single.X
                        E_single = z_single.E
                        if self.core_model.conditional:
                            y_single = torch.zeros(1, 1, device=device)
                        else:
                            y_single = torch.zeros(1, 0, device=device)

                    # Replicate the group's start state for every member.
                    X_blocks.append(X_single.repeat(group_size_for_same_start, 1, 1))
                    E_blocks.append(E_single.repeat(group_size_for_same_start, 1, 1, 1))
                    if self.core_model.conditional:
                        y_blocks.append(y_single.repeat(group_size_for_same_start, 1))

                X = torch.cat(X_blocks, dim=0)
                E = torch.cat(E_blocks, dim=0)
                if self.core_model.conditional:
                    y = torch.cat(y_blocks, dim=0)
                else:
                    y = torch.zeros(batch_size, 0, device=device)


            else:
                # ---- Independent start: every sample gets its own noise ----
                if self.target_node_count is not None:
                    n_nodes = torch.full(
                        (batch_size,),
                        int(self.target_node_count),
                        device=device,
                        dtype=torch.long,
                    )
                else:
                    n_nodes = self.core_model.node_dist.sample_n(
                        batch_size, device=device
                    ).long()
                    if self.node_count_min is not None:
                        n_nodes = torch.clamp(n_nodes, min=int(self.node_count_min))
                    if self.node_count_max is not None:
                        n_nodes = torch.clamp(n_nodes, max=int(self.node_count_max))
                n_max = int(n_nodes.max().item())
                arange = torch.arange(n_max, device=device).unsqueeze(0).expand(batch_size, -1)
                node_mask = arange < n_nodes.unsqueeze(1)
                z_T = flow_matching_utils.sample_discrete_feature_noise(
                    limit_dist=self.core_model.noise_dist.get_limit_dist(),
                    node_mask=node_mask
                )
                if self.core_model.conditional:
                    z_T.y = torch.zeros(batch_size, 1).to(device)
                X, E, y = z_T.X, z_T.E, z_T.y
                if not self.core_model.conditional and y is None:
                    y = torch.zeros(batch_size, 0, device=device)

            # ---- Trajectory containers ----
            trajectory_states = [[] for _ in range(batch_size)]
            trajectory_log_probs = []

            use_dense_reward = False
            try:
                use_dense_reward = bool(self.cfg.grpo.get("use_dense_reward", False))
            except Exception:
                use_dense_reward = False

            # Number of leading samples whose predictions/probabilities are
            # kept, clamped to [1, batch_size].
            vis_num_samples = 1
            try:
                vis_num_samples = int(self.cfg.grpo.get("vis_num_samples", 1))
            except Exception:
                vis_num_samples = 1
            vis_num_samples = max(1, min(int(batch_size), int(vis_num_samples)))

            # Dense reward needs predictions for every sample; otherwise they
            # are only kept for the first vis_num_samples when return_probs is set.
            trajectory_preds = None
            if use_dense_reward:
                trajectory_preds = [[] for _ in range(batch_size)]
            elif return_probs:
                trajectory_preds = [[] for _ in range(vis_num_samples)]

            trajectory_probs = [[] for _ in range(vis_num_samples)] if return_probs else None

            # ---- Training window: the last `train_window_steps` sampling steps ----
            train_max_steps = None
            try:
                train_max_steps = self.cfg.grpo.get("train_max_steps")
            except Exception:
                train_max_steps = None

            if train_max_steps is None:
                train_window_steps = int(total_inference_steps) - int(sampling_start_step)
            else:
                train_window_steps = max(
                    1,
                    min(
                        int(train_max_steps),
                        int(total_inference_steps) - int(sampling_start_step),
                    ),
                )

            train_start_step = int(total_inference_steps) - int(train_window_steps)
            if train_start_step < sampling_start_step:
                train_start_step = int(sampling_start_step)

            for t_int in range(int(sampling_start_step), total_inference_steps):
                # Current (t) and next (s = t + 1) times, normalised by
                # (total_inference_steps + 1) and passed through the time distorter.
                t_array = t_int * torch.ones((batch_size, 1)).type_as(X)
                t_norm = t_array / (total_inference_steps + 1)
                s_array = t_array + 1
                s_norm = s_array / (total_inference_steps + 1)

                t_norm = self.core_model.time_distorter.sample_ft(
                    t_norm, self.cfg.sample.time_distortion
                )
                s_norm = self.core_model.time_distorter.sample_ft(
                    s_norm, self.cfg.sample.time_distortion
                )

                # Record the state *before* this step for every sample, but
                # only inside the training window.
                if t_int >= train_start_step:
                    for i in range(batch_size):
                        sample_state = utils.PlaceHolder(
                            X=X[i:i+1].clone(),
                            E=E[i:i+1].clone(),
                            y=y[i:i+1] if y is not None else None,
                        )
                        trajectory_states[i].append(sample_state)

                with torch.inference_mode(), sampling_autocast_ctx:
                    noisy_data = {
                        "X_t": X, "E_t": E, "y_t": y,
                        "t": t_norm, "node_mask": node_mask
                    }

                    extra_data = self.core_model.compute_extra_data(noisy_data)
                    pred = self.core_model.forward(noisy_data, extra_data, node_mask)

                    sampling_temperature = self.cfg.grpo.get('sampling_temperature', 1.0)

                    # Temperature scaling is skipped when the temperature is
                    # (numerically) 1.
                    if abs(sampling_temperature - 1.0) > 1e-5:
                        pred_X = F.softmax(pred.X / sampling_temperature, dim=-1)
                        pred_E = F.softmax(pred.E / sampling_temperature, dim=-1)
                    else:
                        pred_X = F.softmax(pred.X, dim=-1)
                        pred_E = F.softmax(pred.E, dim=-1)

                    # Keep the predicted class probabilities (cropped to the
                    # real nodes, moved to CPU) for the visualised samples.
                    if return_probs and t_int >= train_start_step:
                        for i in range(vis_num_samples):
                            n_nodes_i = int(node_mask[i].sum().item())
                            if n_nodes_i > 0:
                                trajectory_probs[i].append(
                                    (
                                        pred_X[i, :n_nodes_i].detach().cpu(),
                                        pred_E[i, :n_nodes_i, :n_nodes_i].detach().cpu(),
                                    )
                                )
                            else:
                                trajectory_probs[i].append(None)

                    # Sample a discrete graph from the predictions for this
                    # step. These categorical draws consume RNG state in
                    # addition to the sampler's own transition draw below.
                    if trajectory_preds is not None and t_int >= train_start_step:
                        sampled_indices_X = torch.distributions.Categorical(probs=pred_X).sample()
                        pred_X_discrete = sampled_indices_X

                        sampled_indices_E = torch.distributions.Categorical(probs=pred_E).sample()
                        pred_E_discrete = sampled_indices_E

                        pred_X_discrete = pred_X_discrete * node_mask

                        # Valid edges: both endpoints real, no self-loops.
                        edge_mask = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
                        diag_indices = torch.arange(node_mask.size(1), device=node_mask.device)
                        edge_mask[:, diag_indices, diag_indices] = 0

                        # Symmetrise by mirroring the strict upper triangle.
                        pred_E_discrete = torch.triu(pred_E_discrete, diagonal=1)
                        pred_E_discrete = pred_E_discrete + pred_E_discrete.transpose(1, 2)
                        pred_E_discrete = pred_E_discrete * edge_mask

                        if use_dense_reward:
                            target_samples = batch_size
                        else:
                            target_samples = vis_num_samples

                        for i in range(target_samples):
                            n_nodes_i = int(node_mask[i].sum().item())
                            if n_nodes_i <= 0:
                                # Placeholder entry so every list keeps one
                                # element per recorded step.
                                trajectory_preds[i].append(
                                    (
                                        torch.empty(0, dtype=torch.long),
                                        torch.empty(0, 0, dtype=torch.long),
                                    )
                                )
                                continue
                            trajectory_preds[i].append(
                                (
                                    pred_X_discrete[i, :n_nodes_i].detach().cpu(),
                                    pred_E_discrete[i, :n_nodes_i, :n_nodes_i].detach().cpu(),
                                )
                            )

                    # Transition probabilities for t -> s from the rate matrices.
                    dt = (s_norm - t_norm)[:, 0]
                    rate_designer = self.core_model.get_rate_matrix_designer() if hasattr(self.core_model, "get_rate_matrix_designer") else self.core_model.rate_matrix_designer
                    R_t_X, R_t_E = rate_designer.compute_graph_rate_matrix(
                        t_norm, node_mask, (X, E), (pred_X, pred_E)
                    )

                    limit_x = self.core_model.limit_dist.X
                    limit_e = self.core_model.limit_dist.E
                    if self.use_grpo_step_probs_for_sampling:
                        prob_X, prob_E = self.core_model.compute_step_probs_grpo(
                            R_t_X, R_t_E, X, E, dt, limit_x, limit_e
                        )
                    else:
                        prob_X, prob_E = self.core_model.compute_step_probs(
                            R_t_X, R_t_E, X, E, dt, limit_x, limit_e
                        )

                    sampled_s = flow_matching_utils.sample_discrete_features(
                        prob_X, prob_E, node_mask=node_mask
                    )

                    # One-hot encode the sampled next state and zero out padding.
                    X_next = F.one_hot(sampled_s.X, num_classes=X.size(-1)).float()
                    E_next = F.one_hot(sampled_s.E, num_classes=E.size(-1)).float()
                    X_next = X_next * node_mask.unsqueeze(-1)
                    edge_mask = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
                    E_next = E_next * edge_mask.unsqueeze(-1)

                    X_indices = sampled_s.X
                    E_indices = sampled_s.E

                    # Log-probability of the sampled node classes, summed over
                    # real nodes; probabilities are clamped at 1e-8 before the log.
                    X_log_probs = torch.log(prob_X.clamp(min=1e-8))
                    X_step_log_prob = torch.gather(X_log_probs, dim=-1,
                                                    index=X_indices.unsqueeze(-1)).squeeze(-1)
                    X_masked = (X_step_log_prob * node_mask).sum(dim=-1)

                    # Same for edges over off-diagonal valid pairs. The sum is
                    # halved, presumably because the symmetric matrix holds
                    # each undirected edge twice -- confirm.
                    E_log_probs = torch.log(prob_E.clamp(min=1e-8))
                    E_step_log_prob = torch.gather(E_log_probs, dim=-1,
                                                    index=E_indices.unsqueeze(-1)).squeeze(-1)
                    edge_mask = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
                    diag_indices = torch.arange(node_mask.size(1), device=node_mask.device)
                    edge_mask[:, diag_indices, diag_indices] = 0
                    E_masked = (E_step_log_prob * edge_mask).sum(dim=[-2, -1]) * 0.5

                    step_log_prob = X_masked + E_masked
                    if t_int >= train_start_step:
                        trajectory_log_probs.append(step_log_prob)


                    X = X_next
                    E = E_next


            # Append the final state so each trajectory has one more entry than
            # the number of recorded steps.
            for i in range(batch_size):
                sample_state = utils.PlaceHolder(
                    X=X[i:i+1].clone(),
                    E=E[i:i+1].clone(),
                    y=y[i:i+1] if y is not None else None,
                )
                trajectory_states[i].append(sample_state)

            # Sanity check only: a length mismatch is reported, not raised.
            if len(trajectory_states) > 0 and len(trajectory_states[0]) > 0:
                actual_traj_len = len(trajectory_states[0])
                expected_traj_len = int(train_window_steps) + 1
                if actual_traj_len != expected_traj_len:
                    print(f"⚠️ [anonymized]：[anonymized]。[anonymized]: {actual_traj_len}, [anonymized]: {expected_traj_len}")

            X, E, y = self.core_model.noise_dist.ignore_virtual_classes(X, E, y)
            clean_graphs = utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)

            if trajectory_log_probs:
                current_log_probs = torch.stack(trajectory_log_probs, dim=1)
            else:
                current_log_probs = torch.zeros(batch_size, 0, device=device)

            # Reference log-probabilities are not computed during sampling.
            ref_log_probs = None

            return clean_graphs, node_mask, current_log_probs, ref_log_probs, trajectory_states, trajectory_preds, trajectory_probs

        finally:
            # No cleanup is performed; the try/finally wrapper is a no-op.
            pass

    @torch.no_grad()
    def refine_candidate_via_denoising(
        self,
        init_X: torch.Tensor,
        init_E: torch.Tensor,
        num_variations: int,
        noise_fraction: "Union[float, torch.Tensor]",
        total_inference_steps: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Produce variations of one candidate graph by partial noising and re-denoising.

        The candidate is replicated ``num_variations`` times, one-hot encoded,
        noised with ``core_model.apply_noise`` at a per-variation time derived
        from ``noise_fraction``, and then advanced with the model's rate-matrix
        sampler up to step ``steps``.

        Args:
            init_X: Node data of a single graph. 1-D input is used as class
                indices, with negative entries treated as padding. 2-D input is
                reduced with ``argmax`` over the last dim, and rows whose sum is
                not positive are treated as padding.
            init_E: Edge data of the same graph. 2-D input is used as class
                indices; 3-D input is reduced with ``argmax`` over the last dim.
            num_variations: Number of copies to generate; must be > 0.
            noise_fraction: A float, or a tensor of shape ``(num_variations,)``
                or ``(num_variations, 1)``. Values are clamped to ``[0, 1]``.
            total_inference_steps: Number of sampling steps; defaults to
                ``self.sample_steps``. Must be > 0.
            seed: When given, seeds the global torch, NumPy and Python RNGs.
                The previous RNG states are not restored.

        Returns:
            A ``(X, E)`` pair. When every noise fraction is <= 0 the replicated
            index tensors are returned unchanged, with ``-1`` at padded
            positions. Otherwise the ``X`` and ``E`` of the masked, collapsed
            ``PlaceHolder`` are returned (presumably in the same index format;
            confirm against ``PlaceHolder.mask``).

        Raises:
            ValueError: For non-positive ``steps`` / ``num_variations``, a
                ``noise_fraction`` tensor of the wrong shape, or unsupported
                ``init_X`` / ``init_E`` ranks.
            RuntimeError: If ``core_model.input_dims`` does not provide the
                ``"X"`` and ``"E"`` class counts.
        """
        import random
        from flow_matching import flow_matching_utils

        if seed is not None:
            # Seeds the process-wide RNGs; callers that need their own RNG
            # stream preserved must save/restore it themselves.
            torch.manual_seed(seed)
            # Keep the NumPy seed inside the range np.random.seed accepts.
            np.random.seed(seed % (2**31))
            random.seed(seed)

        device = next(self.core_model.parameters()).device
        steps = int(total_inference_steps if total_inference_steps is not None else self.sample_steps)
        if steps <= 0:
            raise ValueError(f"total_inference_steps must be > 0, got {steps}")
        if num_variations <= 0:
            raise ValueError(f"num_variations must be > 0, got {num_variations}")

        # Normalize noise_fraction to a float32 vector with one entry per variation.
        if torch.is_tensor(noise_fraction):
            noise_frac = noise_fraction.detach().to(device=device, dtype=torch.float32)
            if noise_frac.dim() == 2 and noise_frac.shape[1] == 1:
                noise_frac = noise_frac[:, 0]
            if noise_frac.dim() != 1 or noise_frac.shape[0] != num_variations:
                raise ValueError(
                    f"noise_fraction tensor must have shape ({num_variations},) or ({num_variations},1), "
                    f"got {tuple(noise_frac.shape)}"
                )
            noise_frac = noise_frac.clamp(0.0, 1.0)
        else:
            nf = max(0.0, min(1.0, float(noise_fraction)))
            noise_frac = torch.full((num_variations,), nf, device=device, dtype=torch.float32)

        x = init_X.to(device)
        e = init_E.to(device)

        # Derive class indices and the padding mask from the node input.
        if x.dim() == 1:
            # Index form: negative entries mark padded nodes.
            node_mask_1 = (x >= 0)
            x_idx_1 = x.long()
        elif x.dim() == 2:
            # Presumably one-hot rows; all-zero rows count as padding.
            node_mask_1 = (x.sum(dim=-1) > 0)
            x_idx_1 = x.argmax(dim=-1).long()
        else:
            raise ValueError(f"Unsupported init_X shape: {tuple(x.shape)}")

        if e.dim() == 2:
            e_idx_1 = e.long()
        elif e.dim() == 3:
            e_idx_1 = e.argmax(dim=-1).long()
        else:
            raise ValueError(f"Unsupported init_E shape: {tuple(e.shape)}")

        input_dims = getattr(self.core_model, "input_dims", None)
        if not isinstance(input_dims, dict) or "X" not in input_dims or "E" not in input_dims:
            raise RuntimeError("core_model.input_dims is missing; cannot one-hot encode candidate graph")

        # Replicate the single candidate along a new leading batch dimension.
        x_idx = x_idx_1.unsqueeze(0).expand(num_variations, -1).contiguous()
        e_idx = e_idx_1.unsqueeze(0).expand(num_variations, -1, -1).contiguous()
        node_mask = node_mask_1.unsqueeze(0).expand(num_variations, -1).contiguous().bool()
        # Padded nodes, and edges touching them, are marked with -1 in index form.
        x_idx = x_idx.masked_fill(~node_mask, -1)
        edge_mask_idx = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
        e_idx = e_idx.masked_fill(~edge_mask_idx, -1)

        # clamp(min=0) lets one_hot accept the -1 placeholders; the masked
        # positions are zeroed again just below.
        X = F.one_hot(x_idx.clamp(min=0), num_classes=int(input_dims["X"])).float()
        E = F.one_hot(e_idx.clamp(min=0), num_classes=int(input_dims["E"])).float()

        X = X * node_mask.unsqueeze(-1)
        E = E * edge_mask_idx.unsqueeze(-1)

        # Zero-filled global feature: one column when the model is conditional, none otherwise.
        if self.core_model.conditional:
            y = torch.zeros(num_variations, 1, device=device)
        else:
            y = torch.zeros(num_variations, 0, device=device)

        # Nothing to perturb: hand back the replicated candidate in index form.
        if float(noise_frac.max().item()) <= 0.0:
            return x_idx, e_idx

        # Map each noise fraction to the integer step from which that variation
        # is re-denoised (more noise -> earlier step).
        # NOTE(review): because of the clamp to steps - 1, a variation with
        # noise_fraction == 0 in a batch where another variation is > 0 is still
        # noised at step steps - 1 and resampled once. Confirm this is intended.
        t_start = 1.0 - noise_frac
        start_steps = torch.round(t_start * float(steps + 1)).to(dtype=torch.long)
        start_steps = start_steps.clamp(0, steps - 1)
        min_start_step = int(start_steps.min().item())

        # Per-variation noising time, passed through the same time distortion
        # that the sampling loop below applies.
        t_raw = (start_steps.to(dtype=torch.float32) / float(steps + 1)).view(num_variations, 1)
        t_noisy = self.core_model.time_distorter.sample_ft(t_raw, self.cfg.sample.time_distortion)

        noisy_data = self.core_model.apply_noise(X, E, y, node_mask, t=t_noisy)
        X = noisy_data["X_t"]
        E = noisy_data["E_t"]
        y = noisy_data["y_t"]

        sampling_temperature = float(self.cfg.grpo.get("sampling_temperature", 1.0))

        # The whole batch is pushed through the model from the earliest start
        # step; variations whose own start step has not been reached yet keep
        # their current state (see `active` below).
        for t_int in range(min_start_step, steps):
            t_array = torch.full((num_variations, 1), float(t_int), device=device, dtype=X.dtype)
            t_norm = t_array / (steps + 1)
            s_array = t_array + 1
            s_norm = s_array / (steps + 1)

            t_norm = self.core_model.time_distorter.sample_ft(t_norm, self.cfg.sample.time_distortion)
            s_norm = self.core_model.time_distorter.sample_ft(s_norm, self.cfg.sample.time_distortion)

            noisy_data = {"X_t": X, "E_t": E, "y_t": y, "t": t_norm, "node_mask": node_mask}
            extra_data = self.core_model.compute_extra_data(noisy_data)
            pred = self.core_model.forward(noisy_data, extra_data, node_mask)

            # Temperature is applied to the logits only when it differs from 1.
            if abs(sampling_temperature - 1.0) > 1e-5:
                pred_X = F.softmax(pred.X / sampling_temperature, dim=-1)
                pred_E = F.softmax(pred.E / sampling_temperature, dim=-1)
            else:
                pred_X = F.softmax(pred.X, dim=-1)
                pred_E = F.softmax(pred.E, dim=-1)

            # Step size in distorted time for this transition.
            dt = (s_norm - t_norm)[:, 0]
            rate_designer = (
                self.core_model.get_rate_matrix_designer()
                if hasattr(self.core_model, "get_rate_matrix_designer")
                else self.core_model.rate_matrix_designer
            )
            R_t_X, R_t_E = rate_designer.compute_graph_rate_matrix(
                t_norm, node_mask, (X, E), (pred_X, pred_E)
            )

            limit_x = self.core_model.limit_dist.X
            limit_e = self.core_model.limit_dist.E
            # Choose the step-probability routine according to the trainer flag.
            if self.use_grpo_step_probs_for_sampling:
                prob_X, prob_E = self.core_model.compute_step_probs_grpo(
                    R_t_X, R_t_E, X, E, dt, limit_x, limit_e
                )
            else:
                prob_X, prob_E = self.core_model.compute_step_probs(
                    R_t_X, R_t_E, X, E, dt, limit_x, limit_e
                )

            sampled_s = flow_matching_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask)

            X_next = F.one_hot(sampled_s.X, num_classes=X.size(-1)).float()
            E_next = F.one_hot(sampled_s.E, num_classes=E.size(-1)).float()

            # Re-zero padded nodes and the edges that touch them.
            X_next = X_next * node_mask.unsqueeze(-1)
            edge_mask = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
            E_next = E_next * edge_mask.unsqueeze(-1)

            # Only variations whose start step has been reached take the new state.
            active = (t_int >= start_steps).view(num_variations, 1, 1)
            active_E = active.view(num_variations, 1, 1, 1)
            X = torch.where(active, X_next, X)
            E = torch.where(active_E, E_next, E)

        # Drop virtual classes, then mask and collapse the result.
        X, E, _ = self.core_model.noise_dist.ignore_virtual_classes(X, E, y)
        clean_graphs = utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)
        return clean_graphs.X, clean_graphs.E
    
    def _compute_rewards_multiprocess_sync(
        self,
        graph_list: List,
        timeout: Optional[float] = None,
        context: str = "reward",
    ) -> torch.Tensor:
        num_graphs = len(graph_list)
        if num_graphs == 0:
            return torch.tensor([], dtype=torch.float32)

        if self.reward_pool is None or self.num_reward_workers <= 0:
            with torch.no_grad():
                rewards = self.reward_function(graph_list)
                if not isinstance(rewards, torch.Tensor):
                    rewards = torch.tensor(rewards, dtype=torch.float32)
                return rewards
        
        reward_func_type = getattr(self.reward_function, 'name', 'default')
        reward_func_type_lower = str(reward_func_type).lower()
        device_str = 'cpu'                               
        base_reward_kwargs = getattr(self, "reward_kwargs", None) or {}
        reward_kwargs = dict(base_reward_kwargs)

        py_graphs = []
        is_optimized = False
        
        if graph_list:
            first_item = graph_list[0]
            first_X = first_item[0] if isinstance(first_item, (tuple, list)) else first_item
            
            try:
                if torch.is_tensor(first_X):
                    all_X = [item[0] for item in graph_list]
                    all_E = [item[1] for item in graph_list]
                    
                    batch_X = torch.stack(all_X).cpu()
                    batch_E = torch.stack(all_E).cpu()
                    
                    py_graphs = list(zip(batch_X, batch_E))
                    is_optimized = True
            except Exception as batch_err:
                is_optimized = False
        
        if not is_optimized:
            for item in graph_list:
                if isinstance(item, (tuple, list)) and len(item) == 2:
                    atom_types, edge_types = item
                else:
                    atom_types, edge_types = item

                if torch.is_tensor(atom_types):
                    atom_arr = atom_types.detach().cpu()
                else:
                    atom_arr = torch.tensor(atom_types)                         

                if torch.is_tensor(edge_types):
                    edge_arr = edge_types.detach().cpu()
                else:
                    edge_arr = torch.tensor(edge_types)

                py_graphs.append((atom_arr, edge_arr))

        if reward_func_type_lower in ("molecular_validity", "guacamol_reward", "gracamol_reward", "gracamol"):
            try:
                target_node_dist = reward_kwargs.get("target_node_dist")
                target_edge_dist = reward_kwargs.get("target_edge_dist")
                if target_node_dist is None and hasattr(self.reward_function, "target_node_dist"):
                    target_node_dist = getattr(self.reward_function, "target_node_dist")
                if target_edge_dist is None and hasattr(self.reward_function, "target_edge_dist"):
                    target_edge_dist = getattr(self.reward_function, "target_edge_dist")

                scale_factor = reward_kwargs.get("scale_factor") or reward_kwargs.get("dist_scale_factor")
                clip_range = reward_kwargs.get("clip_range") or reward_kwargs.get("dist_clip_range")
                if scale_factor is None and hasattr(self.reward_function, "scale_factor"):
                    scale_factor = getattr(self.reward_function, "scale_factor")
                if clip_range is None and hasattr(self.reward_function, "clip_range"):
                    clip_range = getattr(self.reward_function, "clip_range")

                node_weights, edge_weights = MolecularValidityReward.compute_distribution_weights(
                    py_graphs,
                    target_node_dist=target_node_dist,
                    target_edge_dist=target_edge_dist,
                    scale_factor=scale_factor,
                    clip_range=clip_range,
                )
                reward_kwargs["precomputed_node_weights"] = node_weights
                reward_kwargs["precomputed_edge_weights"] = edge_weights

                if target_node_dist is not None and "target_node_dist" not in reward_kwargs:
                    reward_kwargs["target_node_dist"] = target_node_dist
                if target_edge_dist is not None and "target_edge_dist" not in reward_kwargs:
                    reward_kwargs["target_edge_dist"] = target_edge_dist
                if scale_factor is not None:
                    reward_kwargs["scale_factor"] = float(scale_factor)
                if clip_range is not None:
                    reward_kwargs["clip_range"] = float(clip_range)
            except Exception as weight_err:
                print(f"⚠️ [anonymized]，[anonymized]: {weight_err}")
        
        unique_smiles_to_indices = defaultdict(list)
        graph_to_smiles = []
        
        dataset_info = getattr(self.model, "dataset_info", None)
        atom_decoder = getattr(dataset_info, "atom_decoder", None) if dataset_info is not None else None
        
        from src.analysis.rdkit_functions import build_molecule
        
        valid_graphs_for_docking = []
        original_idx_to_unique_idx = {}
        
        for i, (at, et) in enumerate(py_graphs):
            try:
                at_types = torch.argmax(at, dim=-1) if at.dim() == 2 else at
                et_types = torch.argmax(et, dim=-1) if et.dim() == 3 else et
                mol = build_molecule(at_types, et_types, atom_decoder)
                if mol:
                    smi = Chem.MolToSmiles(mol)
                    if smi:
                        if smi not in unique_smiles_to_indices:
                            unique_smiles_to_indices[smi].append(i)
                            valid_graphs_for_docking.append((at, et))
                            original_idx_to_unique_idx[i] = len(valid_graphs_for_docking) - 1
                        else:
                            first_idx = unique_smiles_to_indices[smi][0]
                            original_idx_to_unique_idx[i] = original_idx_to_unique_idx[first_idx]
                        continue
            except Exception:
                pass
            valid_graphs_for_docking.append((at, et))
            original_idx_to_unique_idx[i] = len(valid_graphs_for_docking) - 1

        num_unique = len(valid_graphs_for_docking)
        if num_unique < num_graphs:
            print(f"✅ [GRPO] SMILES Deduplication: {num_graphs} -> {num_unique} unique molecules")

        batch_size = max(1, num_unique // (self.num_reward_workers * 4))
        batch_size = min(batch_size, 2000)
        
        batches = []
        for i in range(0, num_unique, batch_size):
            batch = valid_graphs_for_docking[i:min(i + batch_size, num_unique)]
            batches.append((batch, reward_func_type, device_str, reward_kwargs))

        try:
            if timeout is not None and timeout > 0:
                async_result = self.reward_pool.starmap_async(_compute_batch_rewards_worker, batches, chunksize=1)
                try:
                    results = async_result.get(timeout=timeout)
                except MPTimeoutError:
                    print(f"⚠️ {context} [anonymized] {timeout} [anonymized]，[anonymized]。")
                    self.reward_pool.terminate()
                    self.reward_pool.join()
                    self.reward_pool = mp.get_context('spawn').Pool(
                        processes=self.num_reward_workers,
                        initializer=_reward_worker_initializer,
                        initargs=(self.reward_worker_threads, reward_func_type, reward_kwargs),
                    )
                    return torch.tensor([], dtype=torch.float32)
            else:
                results = self.reward_pool.starmap(_compute_batch_rewards_worker, batches, chunksize=1)
            
            unique_rewards = []
            for batch_rewards in results:
                unique_rewards.extend(batch_rewards)
            
            all_rewards = []
            for i in range(num_graphs):
                u_idx = original_idx_to_unique_idx[i]
                all_rewards.append(unique_rewards[u_idx])
            
            return torch.tensor(all_rewards, dtype=torch.float32)
            
        except Exception as e:
            print(f"⚠️ [anonymized]（{context}）: {e}")
            print("   [anonymized]...")
            with torch.no_grad():
                rewards = self.reward_function(graph_list)
                if not isinstance(rewards, torch.Tensor):
                    rewards = torch.tensor(rewards, dtype=torch.float32)
                return rewards
    
    def _convert_placeholder_to_graph_list_cpu(self, graphs: utils.PlaceHolder, node_mask: torch.Tensor, as_tensor: bool = False) -> List:
        graph_list = []
        X, E = graphs.X, graphs.E
        
        X_cpu = X.cpu()
        E_cpu = E.cpu()
        node_mask_cpu = node_mask.cpu()
        
        for i in range(X.size(0)):
            n_nodes = node_mask_cpu[i].sum().item()
            atom_tensor = X_cpu[i, :n_nodes].contiguous()
            edge_tensor = E_cpu[i, :n_nodes, :n_nodes].contiguous()
            
            if as_tensor:
                graph_list.append((atom_tensor, edge_tensor))
            else:
                atom_types = torch.argmax(atom_tensor, dim=-1)
                edge_types = torch.argmax(edge_tensor, dim=-1)
                graph_list.append([
                    atom_types.to(torch.int64).tolist(),
                    edge_types.to(torch.int64).tolist(),
                ])
        
        return graph_list

    def _maybe_run_gdpo_eval(self):
        every = int(
            self._get_cfg_value(self.cfg, "grpo.gdpo_eval_every_n_epochs")
            or self._get_cfg_value(self.cfg, "grpo.lead_eval_every_n_epochs")
            or 0
        )
        if every <= 0:
            return
        if self.epoch <= 0 or self.epoch % every != 0:
            return

        reward_type = str(self._get_cfg_value(self.cfg, "grpo.reward_type") or "").lower()
        if reward_type not in ("gdpo_docking", "gdpo"):
            return

        target_name = (
            self._get_cfg_value(self.cfg, "grpo.target_name")
            or self._get_cfg_value(self.cfg, "grpo.target_task")
            or ""
        )
        if not target_name:
            print("⚠️ [GDPO Eval] Missing grpo.target_name; skipping docking eval.")
            return

        dataset_name = str(self._get_cfg_value(self.cfg, "dataset.name") or "")
        sim_threshold = gdpo_get_sim_threshold(
            dataset_name,
            override=self._get_cfg_value(self.cfg, "grpo.gdpo_eval_sim_threshold"),
        )
        target_samples = 2048
        eval_exhaustiveness = self._get_cfg_value(self.cfg, "grpo.gdpo_dock_exhaustiveness") or 8
        eval_num_modes = self._get_cfg_value(self.cfg, "grpo.gdpo_dock_num_modes")
        eval_timeout = self._get_cfg_value(self.cfg, "grpo.gdpo_dock_timeout")
        eval_workers = int(self._get_cfg_value(self.cfg, "grpo.num_reward_workers") or 1)
        eval_workers = max(1, eval_workers)
        eval_cpu_per_worker =  1
        out_dir = self._get_cfg_value(self.cfg, "grpo.gdpo_eval_out_dir") or "gdpo_eval_results"
        out_dir = os.path.abspath(os.path.expanduser(str(out_dir)))
        os.makedirs(out_dir, exist_ok=True)

        py_state = random.getstate()
        np_state = np.random.get_state()
        torch_state = torch.get_rng_state()
        cuda_states = None
        if torch.cuda.is_available():
            try:
                cuda_states = torch.cuda.get_rng_state_all()
            except Exception:
                cuda_states = None

        was_training = self.core_model.training
        self.core_model.eval()

        try:
            print(f"📊 [GDPO Eval] Sampling {target_samples} graphs for docking evaluation...")
            graphs, node_mask, _, _, _, _, _ = self.sample_graphs_with_trajectory_tracking(
                batch_size=target_samples,
                seed=int(time.time() * 1000) % (2**32) + self.global_step,
                total_inference_steps=self.sample_steps,
                force_same_start=False,
            )
            eval_graphs: List[Graph] = self._convert_placeholder_to_graph_list_cpu(
                graphs, node_mask, as_tensor=True
            )

            smiles_list = [self._graph_to_smiles(g) for g in eval_graphs]
            valid_smiles = [s for s in smiles_list if s]
            valid_r = len(valid_smiles) / (len(smiles_list) + 1e-8) if smiles_list else 0.0
            uniq_r = len(set(valid_smiles)) / (len(valid_smiles) + 1e-8) if valid_smiles else 0.0

            repo_root = Path(__file__).resolve().parents[1]
            train_fps = gdpo_load_train_fps(
                dataset_name=dataset_name,
                datadir=str(self._get_cfg_value(self.cfg, "dataset.datadir") or ""),
                remove_h=bool(self._get_cfg_value(self.cfg, "dataset.remove_h")),
                repo_root=repo_root,
            )
            result = gdpo_eval_smiles(
                target_name=str(target_name),
                smiles=valid_smiles,
                train_fps=train_fps,
                sim_threshold=float(sim_threshold),
                repo_root=repo_root,
                dock_exhaustiveness=eval_exhaustiveness,
                dock_num_modes=int(eval_num_modes) if eval_num_modes is not None else None,
                dock_num_workers=eval_workers,
                dock_cpu_per_worker=eval_cpu_per_worker,
                dock_timeout=int(eval_timeout) if eval_timeout is not None else None,
            )

            top_ds_mean, top_ds_std = result.get("top_ds", (float("nan"), float("nan")))
            log_entry = {
                "epoch": int(self.epoch),
                "global_step": int(self.global_step),
                "dataset": str(self._get_cfg_value(self.cfg, "dataset.name") or ""),
                "target_prop": str(target_name),
                "VALID": round(100 * valid_r, 4),
                "UNIQ": round(100 * uniq_r, 4),
                "novelty": result.get("novelty", 0.0),
                "top_ds": [top_ds_mean, top_ds_std],
                "avgscore": result.get("avgscore", 0.0),
                "hit": result.get("hit", 0.0),
                "avgds": round(result.get("avgds", 0.0), 4),
                "avgqed": round(result.get("avgqed", 0.0), 4),
                "avgsa": round(result.get("avgsa", 0.0), 4),
                "sim_threshold": float(sim_threshold),
                "samples": int(target_samples),
            }

            log_suffix = "moses" if "moses" in dataset_name.lower() else "zinc"
            log_path = os.path.join(out_dir, f"evaluation_dict{log_suffix}.log")
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(log_entry) + "\n")

            print(
                f"[GDPO Eval] VALID={log_entry['VALID']} UNIQ={log_entry['UNIQ']} "
                f"Novelty={log_entry['novelty']:.4f} Top-DS={top_ds_mean:.4f}±{top_ds_std:.4f} "
                f"Hit={log_entry['hit']:.4f} AvgDS={log_entry['avgds']:.4f}"
            )

            if swanlab is not None and swanlab.run is not None:
                swanlab.log(
                    {
                        "gdpo_eval/valid_percent": log_entry["VALID"],
                        "gdpo_eval/uniq_percent": log_entry["UNIQ"],
                        "gdpo_eval/novelty": log_entry["novelty"],
                        "gdpo_eval/top_ds_mean": top_ds_mean,
                        "gdpo_eval/top_ds_std": top_ds_std,
                        "gdpo_eval/hit": log_entry["hit"],
                        "gdpo_eval/avgds": log_entry["avgds"],
                        "gdpo_eval/avgqed": log_entry["avgqed"],
                        "gdpo_eval/avgsa": log_entry["avgsa"],
                    },
                    step=self.global_step,
                )
        finally:
            try:
                random.setstate(py_state)
                np.random.set_state(np_state)
                torch.set_rng_state(torch_state)
                if cuda_states is not None:
                    torch.cuda.set_rng_state_all(cuda_states)
            except Exception:
                pass
            self.core_model.train(was_training)

    def _maybe_run_evaluation(self):
        if self.eval_interval <= 0:
            return
        if self.global_step <= 0 or self.global_step % self.eval_interval != 0:
            return
        eval_rewards = self._run_evaluation_rollout()
        if eval_rewards.numel() == 0:
            return
        reward_mean = eval_rewards.mean().item()
        reward_std = eval_rewards.std().item()
        reward_min = eval_rewards.min().item()
        reward_max = eval_rewards.max().item()
        print(
            f"\n🎯 Eval @ step {self.global_step}: mean={reward_mean:.4f}, std={reward_std:.4f}, "
            f"min={reward_min:.4f}, max={reward_max:.4f}"
        )
        valid_mask = eval_rewards > 0.01
        num_valid = valid_mask.sum().item()
        valid_rate = num_valid / eval_rewards.numel() if eval_rewards.numel() > 0 else 0.0
        
        if num_valid > 0:
            avg_valid_reward = eval_rewards[valid_mask].mean().item()
        else:
            avg_valid_reward = 0.0

        if swanlab is not None and swanlab.run is not None:
            swanlab.log({
                'eval/reward_mean': reward_mean,
                'eval/reward_std': reward_std,
                'eval/reward_min': reward_min,
                'eval/reward_max': reward_max,
                'eval/valid_rate': valid_rate,                      
                'eval/avg_valid_reward': avg_valid_reward             
            }, step=self.global_step)
        print(f"  [Stats] Mean Reward: {reward_mean:.4f} | Max Reward: {reward_max:.4f} | Valid Rate: {valid_rate:.2%} | Avg Valid Reward: {avg_valid_reward:.4f}")
        
    def _run_evaluation_rollout(self) -> torch.Tensor:
        eval_graphs = []
        self.core_model.eval()
        try:
            target_samples = 2048
            batch_size = 128                                   
            num_batches = (target_samples + batch_size - 1) // batch_size
            
            print(f"📊 Sampling {target_samples} graphs for evaluation in {num_batches} batches...")
            
            for i in range(num_batches):
                current_batch_size = min(batch_size, target_samples - len(eval_graphs))
                if current_batch_size <= 0:
                    break
                    
                graphs, node_mask, _, _, _, _, _ = self.sample_graphs_with_trajectory_tracking(
                    batch_size=current_batch_size,
                    seed=int(time.time() * 1000) % (2**32) + self.global_step + i,
                    total_inference_steps=self.sample_steps,
                    force_same_start=False,
                )

                batch_graphs = self._convert_placeholder_to_graph_list_cpu(graphs, node_mask, as_tensor=True)
                eval_graphs.extend(batch_graphs)
                print(f"   - Batch {i+1}/{num_batches}: Collected {len(eval_graphs)}/{target_samples}")
        finally:
            self.core_model.eval()
            
        if not eval_graphs:
            return torch.tensor([], dtype=torch.float32)

        print(f"\n📊 Starting Evaluation on {len(eval_graphs)} samples...")

        eval_rewards = self._compute_rewards_multiprocess_sync(
            eval_graphs,
            timeout=1800,
            context="eval",
        )
        
        return eval_rewards
    
    @staticmethod
    def _compute_single_reward_worker(graph_data, reward_type: str, device_str: str):
        """Build a reward function of ``reward_type`` and score one graph.

        The imports are local to the function body. Returns the first reward
        as a Python scalar via ``.item()``.
        """
        import torch
        from grpo_rewards import create_reward_function

        scorer = create_reward_function(reward_type, device=torch.device(device_str))
        # The reward function takes a list of graphs, so wrap the single graph.
        scores = scorer([graph_data])
        return scores[0].item()
    
    @staticmethod
    def _convert_placeholder_to_graph_list(graphs: utils.PlaceHolder, node_mask: torch.Tensor) -> List:
        graph_list = []
        X, E = graphs.X, graphs.E
        
        device = X.device
        
        for i in range(X.size(0)):
            n_nodes = node_mask[i].sum().item()
            
            if X.dim() == 3:
                atom_types = torch.argmax(X[i, :n_nodes], dim=-1)
            else:
                atom_types = X[i, :n_nodes]
            
            if E.dim() == 4:
                edge_types = torch.argmax(E[i, :n_nodes, :n_nodes], dim=-1)
            else:
                edge_types = E[i, :n_nodes, :n_nodes]
            
            graph_list.append([atom_types, edge_types])
        
        return graph_list
    
    def state_dict(self) -> Dict:
        """Return the trainer's resumable state as a plain dict.

        Always contains ``global_step`` and ``epoch``. The stat tracker's
        ``stats`` and the global p0 buffer are added only when present and
        truthy.
        """
        snapshot = dict(global_step=self.global_step, epoch=self.epoch)

        tracker = self.grpo_core.stat_tracker
        if tracker:
            snapshot['stat_tracker_stats'] = tracker.stats

        # The buffer attribute may not exist at all, hence getattr with a default.
        p0_buffer = getattr(self, 'global_p0_buffer', None)
        if p0_buffer:
            snapshot['global_p0_buffer'] = p0_buffer

        return snapshot
    
    def load_state_dict(self, state_dict: Dict):
        """Restore trainer state from a dict produced by ``state_dict``.

        Keys missing from ``state_dict`` leave the current values untouched.
        Tracker statistics are restored only when a stat tracker exists.
        """
        for field in ('global_step', 'epoch'):
            setattr(self, field, state_dict.get(field, getattr(self, field)))

        if 'stat_tracker_stats' in state_dict:
            tracker = self.grpo_core.stat_tracker
            if tracker:
                tracker.stats = state_dict['stat_tracker_stats']

        if 'global_p0_buffer' not in state_dict:
            return
        self.global_p0_buffer = state_dict['global_p0_buffer']
        print(f"📥 Restored Global p0 Buffer with {len(self.global_p0_buffer)} items.")
    
    def _log_training_metrics_to_swanlab(
        self,
        epoch_losses: Dict,
        grad_norm_before_clip: float,
        grad_norm_after_clip: float,
        loss_dict: Dict,
        training_batch,
        optimizer
    ):
        log_metrics = {}

        if 'policy_entropy' in loss_dict:
            policy_entropy = loss_dict['policy_entropy']
            if isinstance(policy_entropy, torch.Tensor):
                policy_entropy_value = policy_entropy.detach().item()
            else:
                policy_entropy_value = float(policy_entropy)
            log_metrics['train/policy_entropy'] = policy_entropy_value
        
        log_metrics['train/grad_norm_before_clip'] = grad_norm_before_clip
        log_metrics['train/grad_norm_after_clip'] = grad_norm_after_clip
        log_metrics['train/grad_clip_ratio'] = grad_norm_after_clip / (grad_norm_before_clip + 1e-8)
        
        batch_view = training_batch.as_dict() if isinstance(training_batch, TrajectoryData) else training_batch

        if 'rewards' in batch_view:
            rewards = batch_view['rewards']
            if isinstance(rewards, torch.Tensor):
                log_metrics['train/avg_reward'] = rewards.mean().item()
                log_metrics['train/std_reward'] = rewards.std().item()
                log_metrics['train/min_reward'] = rewards.min().item()
                log_metrics['train/max_reward'] = rewards.max().item()
        
        if 'total_loss' in loss_dict:
            total = loss_dict['total_loss']
            log_metrics['train/loss_total'] = total.item() if isinstance(total, torch.Tensor) else float(total)
        if 'policy_loss' in loss_dict:
            pol = loss_dict['policy_loss']
            log_metrics['train/loss_policy'] = pol.item() if isinstance(pol, torch.Tensor) else float(pol)
        if 'policy_entropy' in loss_dict and getattr(self.grpo_core, 'entropy_coef', 0.0) != 0:
            ent = loss_dict['policy_entropy']
            ent_val = ent.item() if isinstance(ent, torch.Tensor) else float(ent)
            log_metrics['train/loss_entropy'] = -self.grpo_core.entropy_coef * ent_val
        if 'kl_loss' in loss_dict and getattr(self.grpo_core, 'beta', 0.0) != 0:
            kl = loss_dict['kl_loss']
            kl_val = kl.item() if isinstance(kl, torch.Tensor) else float(kl)
            log_metrics['train/loss_kl'] = self.grpo_core.beta * kl_val
        if getattr(self.grpo_core, 'gdcr_coef', 0.0) != 0:
            if 'gdcr/mean_match' in epoch_losses and epoch_losses['gdcr/mean_match']:
                recent_mm = epoch_losses['gdcr/mean_match'][-min(10, len(epoch_losses['gdcr/mean_match'])):]
                if isinstance(recent_mm[0], torch.Tensor):
                    mm_raw = torch.stack(recent_mm).mean().item()
                else:
                    mm_raw = float(np.mean(recent_mm))
                log_metrics['train/loss_gdcr_mean'] = self.grpo_core.gdcr_coef * mm_raw
            elif 'gdcr/mean_match' in loss_dict:
                mm = loss_dict['gdcr/mean_match']
                mm_val = mm.item() if isinstance(mm, torch.Tensor) else float(mm)
                log_metrics['train/loss_gdcr_mean'] = self.grpo_core.gdcr_coef * mm_val
        if getattr(self.grpo_core, 'diversity_coef', 0.0) != 0:
            if 'gdcr/diversity' in epoch_losses and epoch_losses['gdcr/diversity']:
                recent_div = epoch_losses['gdcr/diversity'][-min(10, len(epoch_losses['gdcr/diversity'])):]
                if isinstance(recent_div[0], torch.Tensor):
                    div_raw = torch.stack(recent_div).mean().item()
                else:
                    div_raw = float(np.mean(recent_div))
                log_metrics['train/loss_gdcr_div'] = self.grpo_core.diversity_coef * div_raw
            elif 'gdcr/diversity' in loss_dict:
                div = loss_dict['gdcr/diversity']
                div_val = div.item() if isinstance(div, torch.Tensor) else float(div)
                log_metrics['train/loss_gdcr_div'] = self.grpo_core.diversity_coef * div_val
        if 'ratio_mean' in loss_dict:
            log_metrics['train/avg_ratio'] = loss_dict['ratio_mean'].item() if isinstance(loss_dict['ratio_mean'], torch.Tensor) else loss_dict['ratio_mean']
        if 'ratio_std' in loss_dict:
            log_metrics['train/std_ratio'] = loss_dict['ratio_std'].item() if isinstance(loss_dict['ratio_std'], torch.Tensor) else loss_dict['ratio_std']
        
        if optimizer is not None:
            log_metrics['train/learning_rate'] = optimizer.param_groups[0]['lr']
        
        log_metrics['train/global_step'] = self.global_step
        log_metrics['train/epoch'] = self.epoch
        
        swanlab.log(log_metrics, step=self.global_step)
    
    def _maybe_decay_lr(self, optimizer, training_batch):
        """Decay the learning rate once recent mean rewards stay above a threshold.

        Each call appends the mean of ``training_batch['rewards']`` to
        ``self._lr_decay_history`` (trimmed to the last ``2 * lr_decay_window``
        entries).  When the latest ``lr_decay_window`` means are all
        ``>= lr_decay_threshold`` and ``self._lr_decay_applied`` is not yet
        set, every param group's learning rate is multiplied by
        ``lr_decay_factor`` and floored at ``lr_decay_min`` (when that is not
        ``None``).  The event is printed and, if a SwanLab run is active,
        logged.

        Each param group is scaled from its own current learning rate, so
        groups configured with different rates keep their relative ratios
        instead of all being overwritten with the first group's value.

        NOTE(review): a second definition of ``_maybe_decay_lr`` appears later
        in this class body; Python binds the later one, so this copy is
        shadowed.

        Args:
            optimizer: Object exposing ``param_groups`` (a list of dicts with
                an ``'lr'`` entry).  ``None`` turns the call into a no-op.
            training_batch: ``TrajectoryData`` (read via ``as_dict()``) or a
                mapping.  Nothing happens when it has no ``'rewards'`` entry.
        """
        if optimizer is None or self.lr_decay_threshold is None:
            return

        batch_view = training_batch.as_dict() if isinstance(training_batch, TrajectoryData) else training_batch
        if 'rewards' not in batch_view:
            return

        epoch_rewards = batch_view['rewards']
        if isinstance(epoch_rewards, torch.Tensor):
            epoch_mean = epoch_rewards.mean().item()
        else:
            epoch_mean = float(np.mean(epoch_rewards))

        window = self.lr_decay_window
        self._lr_decay_history.append(epoch_mean)
        # Keep the history bounded: two windows' worth of means is retained.
        if len(self._lr_decay_history) > window * 2:
            self._lr_decay_history = self._lr_decay_history[-window * 2:]

        # History keeps being recorded after the decay fired, but the decay
        # itself is not repeated while the flag stays set.
        if self._lr_decay_applied:
            return
        if len(self._lr_decay_history) < window:
            return

        recent_means = self._lr_decay_history[-window:]
        if not all(m >= self.lr_decay_threshold for m in recent_means):
            return

        # The first group's rate is what gets reported in the print/log below.
        old_lr = optimizer.param_groups[0]['lr']
        for group in optimizer.param_groups:
            decayed = group['lr'] * self.lr_decay_factor
            if self.lr_decay_min is not None:
                decayed = max(decayed, self.lr_decay_min)
            group['lr'] = decayed
        new_lr = optimizer.param_groups[0]['lr']

        self._lr_decay_applied = True
        print(
            f"\n📉 [anonymized]: mean_reward>= {self.lr_decay_threshold:.3f} (window={self.lr_decay_window})\n"
            f"   lr: {old_lr:.6g} → {new_lr:.6g}"
        )
        if swanlab is not None and swanlab.run is not None:
            swanlab.log({
                'lr_decay/trigger': 1,
                'lr_decay/old_lr': old_lr,
                'lr_decay/new_lr': new_lr,
                'lr_decay/threshold': self.lr_decay_threshold,
            }, step=self.global_step)
    
    def __del__(self):
        reward_pool = getattr(self, "reward_pool", None)
        if reward_pool is not None:
            reward_pool.close()
            reward_pool.join()
    def _log_detailed_visualization(
        self,
        trajectory_states,
        trajectory_preds,
        trajectory_probs,
        dense_rewards,
        final_rewards,
        log_dir,
        batch_indices
    ):
        """Write per-step molecule images and a reward CSV for selected samples.

        For each index in ``batch_indices`` a ``graph_<idx>`` directory is
        created under ``log_dir``.  Every step ``t`` of that sample's
        predicted trajectory produces ``step_<t>.png`` (the prediction) and,
        when a matching intermediate state is available, ``step_<t>_zt.png``.
        A ``rewards.csv`` with columns ``Step, DenseReward, FinalReward`` is
        written next to the images.  A failure of a single image or of the
        CSV is printed and does not stop the remaining work.

        NOTE(review): a second definition of ``_log_detailed_visualization``
        appears later in this class body; Python binds the later one, so this
        copy is shadowed.

        Args:
            trajectory_states: Optional per-sample sequences of state objects
                exposing ``.X`` and ``.E``; ``None`` skips the ``_zt`` images.
            trajectory_preds: Per-sample sequences of ``(nodes, adjacency)``
                pairs.  Indices at or beyond its length are skipped.
            trajectory_probs: Not used by this method.
            dense_rewards: Optional per-sample, per-step rewards (tensor rows
                or plain sequences).  Missing entries are written as ``0.0``.
            final_rewards: Tensor or sequence of per-sample final rewards; a
                missing entry is written as ``0.0``.
            log_dir: Output directory, created if necessary.
            batch_indices: Iterable of sample indices to visualize.

        Returns:
            None.  Returns early (after printing a warning) when
            ``self.model`` has no ``dataset_info``.
        """
        from src.analysis.visualization import MolecularVisualization

        dataset_info = getattr(self.model, "dataset_info", None)
        if dataset_info is None:
            print("⚠️ [anonymized] dataset_info，[anonymized]")
            return

        vis_tool = MolecularVisualization(remove_h=True, dataset_infos=dataset_info)

        os.makedirs(log_dir, exist_ok=True)

        if torch.is_tensor(final_rewards):
            final_rewards = final_rewards.cpu().tolist()

        def _extract_state_graph(state):
            """Return ``(node_indices, edge_indices)`` as CPU long tensors."""
            if state is None:
                return None, None

            X_state, E_state = state.X, state.E
            if torch.is_tensor(X_state):
                X_state = X_state.detach()
                if X_state.dim() == 3:
                    X_state = X_state.squeeze(0)
            if torch.is_tensor(E_state):
                E_state = E_state.detach()
                if E_state.dim() == 4:
                    E_state = E_state.squeeze(0)

            if torch.is_tensor(X_state) and X_state.dim() == 2:
                # 2-D input is reduced with argmax; rows that sum to zero are
                # not counted as nodes.
                node_idx = X_state.argmax(dim=-1)
                node_mask = X_state.sum(dim=-1) > 0
            else:
                node_idx = X_state
                node_mask = node_idx >= 0 if torch.is_tensor(node_idx) else None

            if torch.is_tensor(E_state) and E_state.dim() == 3:
                edge_idx = E_state.argmax(dim=-1)
            else:
                edge_idx = E_state

            if torch.is_tensor(node_mask):
                n_nodes = int(node_mask.sum().item())
            else:
                n_nodes = len(node_idx) if node_idx is not None else 0

            if n_nodes <= 0:
                return (
                    torch.empty(0, dtype=torch.long),
                    torch.empty(0, 0, dtype=torch.long),
                )

            # The first n_nodes entries are kept (assumes valid nodes come
            # first — TODO confirm against the producer of these states).
            return (
                node_idx[:n_nodes].to(torch.long).cpu(),
                edge_idx[:n_nodes, :n_nodes].to(torch.long).cpu(),
            )

        def _save_graph_image(graph_dir, nodes, adj, filename):
            """Render one graph and move the resulting image to ``filename``."""
            vis_tool.visualize(graph_dir, [(nodes, adj,)], 1, log=None)

            # The rendered file is looked up as molecule_0.png and renamed.
            src_img = os.path.join(graph_dir, "molecule_0.png")
            dst_img = os.path.join(graph_dir, filename)
            if os.path.exists(src_img):
                # os.replace overwrites an existing destination in one call.
                os.replace(src_img, dst_img)

        for b_idx in batch_indices:
            if b_idx >= len(trajectory_preds):
                continue

            traj_pred = trajectory_preds[b_idx]
            traj_states = None
            if trajectory_states is not None and b_idx < len(trajectory_states):
                traj_states = trajectory_states[b_idx]

            # Guarded like final_rewards below: a short dense_rewards falls
            # back to 0.0 per step instead of raising IndexError.
            traj_dense_rewards = None
            if dense_rewards is not None and b_idx < len(dense_rewards):
                dense_row = dense_rewards[b_idx]
                if torch.is_tensor(dense_row):
                    traj_dense_rewards = dense_row.cpu().tolist()
                else:
                    traj_dense_rewards = list(dense_row)

            final_r = final_rewards[b_idx] if b_idx < len(final_rewards) else 0.0

            graph_dir = os.path.join(log_dir, f"graph_{b_idx}")
            os.makedirs(graph_dir, exist_ok=True)

            csv_path = os.path.join(graph_dir, "rewards.csv")
            csv_rows = []

            for t, z1 in enumerate(traj_pred):
                if traj_states is not None and t < len(traj_states):
                    try:
                        zt_nodes, zt_adj = _extract_state_graph(traj_states[t])
                        _save_graph_image(graph_dir, zt_nodes, zt_adj, f"step_{t}_zt.png")
                    except Exception as e:
                        # Best-effort: a failed image must not abort logging.
                        print(f"   ⚠️ Step {t} zt visualization failed: {e}")

                try:
                    nodes = z1[0].cpu() if torch.is_tensor(z1[0]) else z1[0]
                    adj = z1[1].cpu() if torch.is_tensor(z1[1]) else z1[1]
                    _save_graph_image(graph_dir, nodes, adj, f"step_{t}.png")
                except Exception as e:
                    print(f"   ⚠️ Step {t} z1 visualization failed: {e}")

                reward_val = 0.0
                if traj_dense_rewards and t < len(traj_dense_rewards):
                    reward_val = traj_dense_rewards[t]

                csv_rows.append([t, reward_val, final_r])

            try:
                with open(csv_path, "w", newline="") as f:
                    writer = csv.writer(f)
                    writer.writerow(["Step", "DenseReward", "FinalReward"])
                    writer.writerows(csv_rows)
            except Exception as e:
                print(f"   ⚠️ CSV writing failed: {e}")

        print(f"✅ Detailed visualization saved to {log_dir}")
        # NOTE(review): a trailing `return graph_list` used to follow here.
        # `graph_list` is not defined anywhere in this method, so it could
        # only raise NameError; it looks like a stray line from another
        # function — verify that no earlier method lost its return statement.
    
    def state_dict(self) -> Dict:
        """Build a checkpointable snapshot of this object's progress.

        Returns:
            A dict with ``'global_step'`` and ``'epoch'``.  It also carries
            ``'stat_tracker_stats'`` when ``self.grpo_core.stat_tracker`` is
            truthy, and ``'global_p0_buffer'`` when that attribute exists and
            is truthy.  The stored values are the live objects, not copies.
        """
        snapshot = {'global_step': self.global_step, 'epoch': self.epoch}

        tracker = self.grpo_core.stat_tracker
        if tracker:
            snapshot['stat_tracker_stats'] = tracker.stats

        # The buffer attribute may be absent; absent and empty are both skipped.
        p0_buffer = getattr(self, 'global_p0_buffer', None)
        if p0_buffer:
            snapshot['global_p0_buffer'] = p0_buffer

        return snapshot
    
    def load_state_dict(self, state_dict: Dict):
        """Restore progress counters and optional buffers from a snapshot.

        Keys missing from ``state_dict`` leave the current values untouched.

        Args:
            state_dict: Mapping that may contain ``'global_step'``,
                ``'epoch'``, ``'stat_tracker_stats'`` and
                ``'global_p0_buffer'``.
        """
        self.global_step = state_dict.get('global_step', self.global_step)
        self.epoch = state_dict.get('epoch', self.epoch)

        if 'stat_tracker_stats' in state_dict:
            tracker = self.grpo_core.stat_tracker
            # Saved stats are only applied when a tracker is present.
            if tracker:
                tracker.stats = state_dict['stat_tracker_stats']

        if 'global_p0_buffer' not in state_dict:
            return
        self.global_p0_buffer = state_dict['global_p0_buffer']
        print(f"📥 Restored Global p0 Buffer with {len(self.global_p0_buffer)} items.")
    
    def _log_training_metrics_to_swanlab(
        self,
        epoch_losses: Dict,
        grad_norm_before_clip: float,
        grad_norm_after_clip: float,
        loss_dict: Dict,
        training_batch,
        optimizer
    ):
        """Assemble scalar training metrics and send them to SwanLab.

        Logged under the ``train/`` prefix: policy entropy, gradient norms and
        their ratio, reward statistics, the individual loss terms (weighted by
        the matching coefficient on ``self.grpo_core`` where one exists),
        ratio statistics, the first param group's learning rate, and the
        current ``global_step`` / ``epoch``.  Entries whose source key is
        missing are simply omitted.

        The call is a no-op when the optional ``swanlab`` import is
        unavailable (the module name is then bound to ``None``).

        Args:
            epoch_losses: Mapping from metric name to the list of values
                recorded so far; the last (up to) 10 entries of
                ``'gdcr/mean_match'`` and ``'gdcr/diversity'`` are averaged.
            grad_norm_before_clip: Gradient norm before clipping (number or
                single-element tensor).
            grad_norm_after_clip: Gradient norm after clipping (number or
                single-element tensor).
            loss_dict: Loss terms of the current step (tensors or numbers).
            training_batch: ``TrajectoryData`` (read via ``as_dict()``) or a
                mapping; reward statistics are logged when ``'rewards'`` is a
                non-empty tensor.
            optimizer: Optimizer whose first param group's ``'lr'`` is
                logged; may be ``None``.
        """
        # Without the optional dependency there is nowhere to log to; bail out
        # before paying for any tensor -> host transfers.
        if swanlab is None:
            return

        def _scalar(value):
            # Tensor -> Python number; anything else goes through float().
            if isinstance(value, torch.Tensor):
                return value.detach().item()
            return float(value)

        def _aux_term(key):
            # Mean of the last (up to) 10 recorded values for `key`, falling
            # back to the current step's value; None when neither exists.
            if key in epoch_losses and epoch_losses[key]:
                tail = epoch_losses[key][-10:]
                if isinstance(tail[0], torch.Tensor):
                    return torch.stack(tail).mean().item()
                return float(np.mean(tail))
            if key in loss_dict:
                return _scalar(loss_dict[key])
            return None

        core = self.grpo_core
        log_metrics = {}

        if 'policy_entropy' in loss_dict:
            log_metrics['train/policy_entropy'] = _scalar(loss_dict['policy_entropy'])

        norm_before = _scalar(grad_norm_before_clip)
        norm_after = _scalar(grad_norm_after_clip)
        log_metrics['train/grad_norm_before_clip'] = norm_before
        log_metrics['train/grad_norm_after_clip'] = norm_after
        # The epsilon keeps the ratio finite when the unclipped norm is zero.
        log_metrics['train/grad_clip_ratio'] = norm_after / (norm_before + 1e-8)

        batch_view = training_batch.as_dict() if isinstance(training_batch, TrajectoryData) else training_batch

        if 'rewards' in batch_view:
            rewards = batch_view['rewards']
            # min()/max() raise on an empty tensor, so empty batches are skipped.
            if isinstance(rewards, torch.Tensor) and rewards.numel() > 0:
                rewards = rewards.detach()
                if not rewards.is_floating_point():
                    # mean()/std() are not defined for integer dtypes.
                    rewards = rewards.float()
                log_metrics['train/avg_reward'] = rewards.mean().item()
                # The default (unbiased) std of a single element is NaN.
                log_metrics['train/std_reward'] = rewards.std().item() if rewards.numel() > 1 else 0.0
                log_metrics['train/min_reward'] = rewards.min().item()
                log_metrics['train/max_reward'] = rewards.max().item()

        if 'total_loss' in loss_dict:
            log_metrics['train/loss_total'] = _scalar(loss_dict['total_loss'])
        if 'policy_loss' in loss_dict:
            log_metrics['train/loss_policy'] = _scalar(loss_dict['policy_loss'])
        if 'policy_entropy' in loss_dict and getattr(core, 'entropy_coef', 0.0) != 0:
            # Logged with a negative sign, scaled by entropy_coef.
            log_metrics['train/loss_entropy'] = -core.entropy_coef * _scalar(loss_dict['policy_entropy'])
        if 'kl_loss' in loss_dict and getattr(core, 'beta', 0.0) != 0:
            log_metrics['train/loss_kl'] = core.beta * _scalar(loss_dict['kl_loss'])

        if getattr(core, 'gdcr_coef', 0.0) != 0:
            mean_match = _aux_term('gdcr/mean_match')
            if mean_match is not None:
                log_metrics['train/loss_gdcr_mean'] = core.gdcr_coef * mean_match
        if getattr(core, 'diversity_coef', 0.0) != 0:
            diversity = _aux_term('gdcr/diversity')
            if diversity is not None:
                log_metrics['train/loss_gdcr_div'] = core.diversity_coef * diversity

        for source_key, metric_key in (('ratio_mean', 'train/avg_ratio'), ('ratio_std', 'train/std_ratio')):
            if source_key in loss_dict:
                ratio = loss_dict[source_key]
                # Non-tensor ratio values are logged unchanged.
                log_metrics[metric_key] = ratio.item() if isinstance(ratio, torch.Tensor) else ratio

        if optimizer is not None:
            log_metrics['train/learning_rate'] = optimizer.param_groups[0]['lr']

        log_metrics['train/global_step'] = self.global_step
        log_metrics['train/epoch'] = self.epoch

        swanlab.log(log_metrics, step=self.global_step)
    
    def _maybe_decay_lr(self, optimizer, training_batch):
        """Decay the learning rate once recent mean rewards stay above a threshold.

        Each call appends the mean of ``training_batch['rewards']`` to
        ``self._lr_decay_history`` (trimmed to the last ``2 * lr_decay_window``
        entries).  When the latest ``lr_decay_window`` means are all
        ``>= lr_decay_threshold`` and ``self._lr_decay_applied`` is not yet
        set, every param group's learning rate is multiplied by
        ``lr_decay_factor`` and floored at ``lr_decay_min`` (when that is not
        ``None``).  The event is printed and, if a SwanLab run is active,
        logged.

        Each param group is scaled from its own current learning rate, so
        groups configured with different rates keep their relative ratios
        instead of all being overwritten with the first group's value.

        NOTE(review): an earlier definition of ``_maybe_decay_lr`` exists in
        this class body; this later one is the definition Python binds.

        Args:
            optimizer: Object exposing ``param_groups`` (a list of dicts with
                an ``'lr'`` entry).  ``None`` turns the call into a no-op.
            training_batch: ``TrajectoryData`` (read via ``as_dict()``) or a
                mapping.  Nothing happens when it has no ``'rewards'`` entry.
        """
        if optimizer is None or self.lr_decay_threshold is None:
            return

        batch_view = training_batch.as_dict() if isinstance(training_batch, TrajectoryData) else training_batch
        if 'rewards' not in batch_view:
            return

        epoch_rewards = batch_view['rewards']
        if isinstance(epoch_rewards, torch.Tensor):
            epoch_mean = epoch_rewards.mean().item()
        else:
            epoch_mean = float(np.mean(epoch_rewards))

        window = self.lr_decay_window
        self._lr_decay_history.append(epoch_mean)
        # Keep the history bounded: two windows' worth of means is retained.
        if len(self._lr_decay_history) > window * 2:
            self._lr_decay_history = self._lr_decay_history[-window * 2:]

        # History keeps being recorded after the decay fired, but the decay
        # itself is not repeated while the flag stays set.
        if self._lr_decay_applied:
            return
        if len(self._lr_decay_history) < window:
            return

        recent_means = self._lr_decay_history[-window:]
        if not all(m >= self.lr_decay_threshold for m in recent_means):
            return

        # The first group's rate is what gets reported in the print/log below.
        old_lr = optimizer.param_groups[0]['lr']
        for group in optimizer.param_groups:
            decayed = group['lr'] * self.lr_decay_factor
            if self.lr_decay_min is not None:
                decayed = max(decayed, self.lr_decay_min)
            group['lr'] = decayed
        new_lr = optimizer.param_groups[0]['lr']

        self._lr_decay_applied = True
        print(
            f"\n📉 [anonymized]: mean_reward>= {self.lr_decay_threshold:.3f} (window={self.lr_decay_window})\n"
            f"   lr: {old_lr:.6g} → {new_lr:.6g}"
        )
        if swanlab is not None and swanlab.run is not None:
            swanlab.log({
                'lr_decay/trigger': 1,
                'lr_decay/old_lr': old_lr,
                'lr_decay/new_lr': new_lr,
                'lr_decay/threshold': self.lr_decay_threshold,
            }, step=self.global_step)
    
    def __del__(self):
        reward_pool = getattr(self, "reward_pool", None)
        if reward_pool is not None:
            reward_pool.close()
            reward_pool.join()
    def _log_detailed_visualization(
        self,
        trajectory_states,
        trajectory_preds,
        trajectory_probs,
        dense_rewards,
        final_rewards,
        log_dir,
        batch_indices
    ):
        """Write per-step molecule images and a reward CSV for selected samples.

        For each index in ``batch_indices`` a ``graph_<idx>`` directory is
        created under ``log_dir``.  Every step ``t`` of that sample's
        predicted trajectory produces ``step_<t>.png`` (the prediction) and,
        when a matching intermediate state is available, ``step_<t>_zt.png``.
        A ``rewards.csv`` with columns ``Step, DenseReward, FinalReward`` is
        written next to the images.  A failure of a single image or of the
        CSV is printed and does not stop the remaining work.  When
        ``self.model`` has no ``dataset_info`` a warning is printed and
        nothing is written.

        Args:
            trajectory_states: Optional per-sample sequences of state objects
                exposing ``.X`` and ``.E``; ``None`` skips the ``_zt`` images.
            trajectory_preds: Per-sample sequences of ``(nodes, adjacency)``
                pairs.  Indices at or beyond its length are skipped.
            trajectory_probs: Not used by this method.
            dense_rewards: Optional per-sample, per-step rewards (tensor rows
                or plain sequences).  Missing entries are written as ``0.0``.
            final_rewards: Tensor or sequence of per-sample final rewards; a
                missing entry is written as ``0.0``.
            log_dir: Output directory, created if necessary.
            batch_indices: Iterable of sample indices to visualize.
        """
        from src.analysis.visualization import MolecularVisualization

        dataset_info = getattr(self.model, "dataset_info", None)
        if dataset_info is None:
            print("⚠️ [anonymized] dataset_info，[anonymized]")
            return

        vis_tool = MolecularVisualization(remove_h=True, dataset_infos=dataset_info)

        os.makedirs(log_dir, exist_ok=True)

        if torch.is_tensor(final_rewards):
            final_rewards = final_rewards.cpu().tolist()

        def _extract_state_graph(state):
            """Return ``(node_indices, edge_indices)`` as CPU long tensors."""
            if state is None:
                return None, None

            X_state, E_state = state.X, state.E
            if torch.is_tensor(X_state):
                X_state = X_state.detach()
                if X_state.dim() == 3:
                    X_state = X_state.squeeze(0)
            if torch.is_tensor(E_state):
                E_state = E_state.detach()
                if E_state.dim() == 4:
                    E_state = E_state.squeeze(0)

            if torch.is_tensor(X_state) and X_state.dim() == 2:
                # 2-D input is reduced with argmax; rows that sum to zero are
                # not counted as nodes.
                node_idx = X_state.argmax(dim=-1)
                node_mask = X_state.sum(dim=-1) > 0
            else:
                node_idx = X_state
                node_mask = node_idx >= 0 if torch.is_tensor(node_idx) else None

            if torch.is_tensor(E_state) and E_state.dim() == 3:
                edge_idx = E_state.argmax(dim=-1)
            else:
                edge_idx = E_state

            if torch.is_tensor(node_mask):
                n_nodes = int(node_mask.sum().item())
            else:
                n_nodes = len(node_idx) if node_idx is not None else 0

            if n_nodes <= 0:
                return (
                    torch.empty(0, dtype=torch.long),
                    torch.empty(0, 0, dtype=torch.long),
                )

            # The first n_nodes entries are kept (assumes valid nodes come
            # first — TODO confirm against the producer of these states).
            return (
                node_idx[:n_nodes].to(torch.long).cpu(),
                edge_idx[:n_nodes, :n_nodes].to(torch.long).cpu(),
            )

        def _save_graph_image(graph_dir, nodes, adj, filename):
            """Render one graph and move the resulting image to ``filename``."""
            vis_tool.visualize(graph_dir, [(nodes, adj,)], 1, log=None)

            # The rendered file is looked up as molecule_0.png and renamed.
            src_img = os.path.join(graph_dir, "molecule_0.png")
            dst_img = os.path.join(graph_dir, filename)
            if os.path.exists(src_img):
                # os.replace overwrites an existing destination in one call.
                os.replace(src_img, dst_img)

        for b_idx in batch_indices:
            if b_idx >= len(trajectory_preds):
                continue

            traj_pred = trajectory_preds[b_idx]
            traj_states = None
            if trajectory_states is not None and b_idx < len(trajectory_states):
                traj_states = trajectory_states[b_idx]

            # Guarded like final_rewards below: a short dense_rewards falls
            # back to 0.0 per step instead of raising IndexError.
            traj_dense_rewards = None
            if dense_rewards is not None and b_idx < len(dense_rewards):
                dense_row = dense_rewards[b_idx]
                if torch.is_tensor(dense_row):
                    traj_dense_rewards = dense_row.cpu().tolist()
                else:
                    traj_dense_rewards = list(dense_row)

            final_r = final_rewards[b_idx] if b_idx < len(final_rewards) else 0.0

            graph_dir = os.path.join(log_dir, f"graph_{b_idx}")
            os.makedirs(graph_dir, exist_ok=True)

            csv_path = os.path.join(graph_dir, "rewards.csv")
            csv_rows = []

            for t, z1 in enumerate(traj_pred):
                if traj_states is not None and t < len(traj_states):
                    try:
                        zt_nodes, zt_adj = _extract_state_graph(traj_states[t])
                        _save_graph_image(graph_dir, zt_nodes, zt_adj, f"step_{t}_zt.png")
                    except Exception as e:
                        # Best-effort: a failed image must not abort logging.
                        print(f"   ⚠️ Step {t} zt visualization failed: {e}")

                try:
                    nodes = z1[0].cpu() if torch.is_tensor(z1[0]) else z1[0]
                    adj = z1[1].cpu() if torch.is_tensor(z1[1]) else z1[1]
                    _save_graph_image(graph_dir, nodes, adj, f"step_{t}.png")
                except Exception as e:
                    print(f"   ⚠️ Step {t} z1 visualization failed: {e}")

                reward_val = 0.0
                if traj_dense_rewards and t < len(traj_dense_rewards):
                    reward_val = traj_dense_rewards[t]

                csv_rows.append([t, reward_val, final_r])

            try:
                with open(csv_path, "w", newline="") as f:
                    writer = csv.writer(f)
                    writer.writerow(["Step", "DenseReward", "FinalReward"])
                    writer.writerows(csv_rows)
            except Exception as e:
                print(f"   ⚠️ CSV writing failed: {e}")

        print(f"✅ Detailed visualization saved to {log_dir}")
