import warnings
warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*")
warnings.filterwarnings("ignore", message=".*Boto3 will no longer support Python 3.9.*")
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

import torch
import numpy as np
from typing import List, Tuple, Callable, Optional, Dict, Union
import networkx as nx
from collections import Counter
import hashlib
import pickle
import time
import math

from omegaconf import OmegaConf
from rdkit import Chem, DataStructs
from rdkit import RDLogger
from rdkit.Chem import QED, rdMolDescriptors, AllChem
import sys
from rdkit.Chem import Descriptors
from rdkit.Chem import rdFMCS
from scipy.stats import gmean
import os
from pathlib import Path
import shutil
# Optional dependency: the synthetic-accessibility scorer. Try a plain import
# first; if that fails, look for RDKit's Contrib copy under the interpreter
# prefix and import it from there. ``sascorer`` ends up as ``None`` whenever
# neither attempt succeeds, so callers must check for ``None`` before use.
try:
    import sascorer                                       
except ImportError:
    sascorer = None
    # sys.prefix is the root of the active environment (named conda_prefix
    # here, but no conda-specific lookup is performed).
    conda_prefix = sys.prefix
    contrib_path = os.path.join(conda_prefix, "share", "RDKit", "Contrib", "SA_Score")
    candidate = os.path.join(contrib_path, "sascorer.py")
    if os.path.exists(candidate):
        # Make the Contrib directory importable, without adding it twice.
        if contrib_path not in sys.path:
            sys.path.append(contrib_path)
        try:
            import sascorer
        except ImportError as e:
            # The file exists but still cannot be imported: report and carry on
            # without the scorer.
            print(f"⚠️ [GRPO] [anonymized] sascorer [anonymized]: {e}")
            sascorer = None
from analysis.rdkit_functions import build_molecule, build_molecule_with_partial_charges
from analysis.lead_opt_oracle import LeadOptOracle
from eval_gdpo_docking import gdpo_get_sim_threshold, gdpo_load_train_fps

# Turn off every RDKit logger matching "rdApp.*" for the whole process.
RDLogger.DisableLog("rdApp.*")
# Module-level flag, initially False. It is not read or written in the visible
# part of this file; presumably a one-shot guard for an SA-score fallback
# warning (TODO confirm against the rest of the module).
_SA_FALLBACK_WARNED = False


def resolve_target_task(cfg, default: str = "penalized_logp") -> str:
    """Return the task configured at ``grpo.target_task`` or ``default``.

    The lookup is attempted with ``OmegaConf.select`` first. Any exception
    raised by that call is swallowed and treated as "not found". If nothing
    truthy was found and ``cfg`` is a plain ``dict``, the nested key is read
    with ordinary dictionary access instead.

    Args:
        cfg: Configuration object, a plain ``dict``, or ``None``.
        default: Value returned when no truthy task name is configured.

    Returns:
        The configured task, or ``default`` when ``cfg`` is ``None`` or the
        configured value is missing/falsy.
    """
    if cfg is None:
        return default

    try:
        selected = OmegaConf.select(cfg, "grpo.target_task", default=None)
    except Exception:
        selected = None

    if selected:
        return selected

    # Fallback for plain dictionaries that OmegaConf could not resolve.
    if isinstance(cfg, dict):
        grpo_section = cfg.get("grpo", {})
        if isinstance(grpo_section, dict):
            selected = grpo_section.get("target_task")

    return selected or default


class GaussianModifier:
    """Callable mapping ``x`` to ``exp(-0.5 * ((x - mu) / sigma) ** 2)``.

    The output is 1.0 at ``x == mu`` and decays towards 0 as ``x`` moves away
    from ``mu``; ``sigma`` controls how fast.
    """

    def __init__(self, mu: float, sigma: float):
        self.mu, self.sigma = mu, sigma

    def __call__(self, x: float) -> float:
        """Return the Gaussian score of ``x`` as a built-in ``float``."""
        standardized = (x - self.mu) / self.sigma
        exponent = -0.5 * np.power(standardized, 2)
        return float(np.exp(exponent))






class BaseRewardFunction:
    """Base class for reward functions over generated graphs.

    A graph is an ``(atom_types, edge_types)`` tensor pair. The default
    ``__call__`` looks at topology only: 1.0 for a connected planar graph,
    0.1 for any other non-empty graph, and 0.0 for an empty graph or when
    scoring raises.

    Attributes:
        name: Identifier of the reward function.
        device: Device on which reward tensors are created (CPU by default).
    """

    def __init__(self, name: str = "base", device: Optional[torch.device] = None):
        self.name = name
        # Cache storage and nominal capacity. Nothing in this class fills the
        # cache; subclasses such as PlanarGraphReward use it.
        self._cache = {}
        self._cache_size = 1000
        self.device = device if device is not None else torch.device("cpu")

    def __call__(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
        """Score every graph and return a float32 tensor with one reward each.

        Args:
            graphs: Sequence of ``(atom_types, edge_types)`` pairs.

        Returns:
            1-D tensor on ``self.device`` with ``len(graphs)`` entries.
        """
        rewards = []
        for atom_types, edge_types in graphs:
            try:
                nx_graph = self._convert_tensor_to_networkx_graph(atom_types, edge_types)
                if nx_graph.number_of_nodes() == 0:
                    reward = 0.0
                elif nx.is_connected(nx_graph) and nx.check_planarity(nx_graph)[0]:
                    reward = 1.0
                else:
                    reward = 0.1
                rewards.append(reward)
            except Exception as e:
                # Best effort: a graph that cannot be scored earns nothing.
                print(f"[anonymized]: {e}")
                rewards.append(0.0)

        return torch.tensor(rewards, dtype=torch.float32, device=self.device)

    def _convert_tensor_to_networkx_graph(self, atom_types: torch.Tensor, edge_types: torch.Tensor) -> "nx.Graph":
        """Build an undirected, unweighted, loop-free graph from ``edge_types``.

        A 3-D ``edge_types`` is reduced with ``argmax`` over its last axis; a
        2-D ``edge_types`` is used as is. Any non-zero entry of the
        symmetrised matrix becomes an edge. ``atom_types`` is only used for
        diagnostics when conversion fails.

        Returns:
            The graph, or an empty ``nx.Graph`` if conversion raised.
        """
        try:
            if edge_types.dim() == 3:
                edge_decisions = torch.argmax(edge_types, dim=-1)
            elif edge_types.dim() == 2:
                edge_decisions = edge_types
            else:
                raise ValueError(f"Unsupported edge_types dimension: {edge_types.dim()}")

            # detach(): tensors that still require grad cannot be turned into
            # NumPy arrays directly and used to end up as an empty graph.
            adjacency = edge_decisions.detach().cpu().numpy()

            # Symmetrise, binarise and drop self loops.
            adjacency = ((adjacency + adjacency.T) > 0).astype(int)
            np.fill_diagonal(adjacency, 0)

            return nx.from_numpy_array(adjacency)

        except Exception as e:
            print(f"[anonymized]NetworkX[anonymized]: {e}")
            print(f"  atom_types shape: {atom_types.shape}")
            print(f"  edge_types shape: {edge_types.shape}")
            print(f"  edge_types dim: {edge_types.dim()}")
            return nx.Graph()

    def _compute_graph_hash_for_caching(self, atom_types: torch.Tensor, edge_types: torch.Tensor) -> str:
        """Return a deterministic cache key for a graph.

        The key is ``"<md5 of atom bytes>_<md5 of edge bytes>"``. If the
        tensors cannot be converted to bytes, the MD5 of their string
        representation is used instead, so the key stays stable across
        processes (the built-in ``hash`` of a string is salted per process).
        """
        try:
            atom_hash = hashlib.md5(atom_types.detach().cpu().numpy().tobytes()).hexdigest()
            edge_hash = hashlib.md5(edge_types.detach().cpu().numpy().tobytes()).hexdigest()
            return f"{atom_hash}_{edge_hash}"
        except Exception:
            text = str(atom_types) + str(edge_types)
            return hashlib.md5(text.encode("utf-8", errors="replace")).hexdigest()

class DefaultRewardFunction(BaseRewardFunction):
    """Heuristic reward: the mean of a connectivity term and an atom-diversity term."""

    def __init__(self, device: Optional[torch.device] = None):
        super().__init__("default", device=device)

    def __call__(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
        """Return one float32 reward per ``(atom_types, edge_types)`` pair.

        Connectivity is the edge count divided by ``n_nodes - 1`` and capped at
        1.0. The edge count is half the number of entries whose sum over the
        last axis of ``edge_types`` is positive. Diversity is the number of
        distinct ``argmax`` atom labels divided by ``n_nodes``.
        """

        def _score(atom_types: torch.Tensor, edge_types: torch.Tensor) -> float:
            node_count = atom_types.size(0)

            has_edge = edge_types.sum(dim=-1) > 0
            edge_count = has_edge.sum().item() // 2
            connectivity = min(edge_count / max(1, node_count - 1), 1.0)

            atom_labels = torch.argmax(atom_types, dim=-1)
            distinct_atoms = torch.unique(atom_labels).size(0)
            diversity = distinct_atoms / max(1, node_count)

            return (connectivity + diversity) / 2.0

        scores = [_score(atom_types, edge_types) for atom_types, edge_types in graphs]
        return torch.tensor(scores, dtype=torch.float32, device=self.device)


class PlanarGraphReward(BaseRewardFunction):
    """Reward for generated graphs: a validity gate plus similarity to reference statistics.

    A graph counts as valid when it has at least one node and one edge, is
    connected and is planar. Valid graphs additionally earn weighted
    similarity terms for their degree distribution, clustering-coefficient
    histogram and (optionally) mean orbit counts, each compared against
    reference statistics supplied by the caller or obtained from a datamodule.
    """

    # Not read anywhere in the visible part of this file.
    _WARNED_ORCA = False
    # Number of bins of the clustering-coefficient histogram over [0, 1].
    _FIXED_BINS = 100
    # Weights of the validity term and of the three similarity terms.
    _FIXED_W_VALID = 1.0
    _FIXED_W_DEG = 0.2
    _FIXED_W_CLUS = 0.2
    _FIXED_W_ORB = 0.2
    # Decay rates: each similarity is computed as exp(-scale * distance).
    _FIXED_DEG_SCALE = 5.0
    _FIXED_CLUS_SCALE = 5.0
    _FIXED_ORB_SCALE = 1.0
    # Whether the orbit-count term is enabled at all.
    _FIXED_USE_ORB = True

    def __init__(
        self,
        device: Optional[torch.device] = None,
        *,
        datamodule=None,
        ref_degree_dist: Optional[Union[np.ndarray, List[float]]] = None,
        ref_clustering_hist: Optional[Union[np.ndarray, List[float]]] = None,
        ref_orbit_mean: Optional[Union[np.ndarray, List[float]]] = None,
    ):
        """Set up weights and reference statistics.

        Reference statistics passed explicitly always win. Missing ones are
        looked up through ``datamodule`` (when given): first in the stored
        reference-metrics file and, if the degree or clustering reference is
        still missing, by computing statistics from the training data. Any
        statistic that remains unavailable disables its reward term.

        Args:
            device: Device for the reward tensors.
            datamodule: Optional source of reference statistics.
            ref_degree_dist: Reference degree distribution.
            ref_clustering_hist: Reference clustering-coefficient histogram.
            ref_orbit_mean: Reference mean orbit-count vector.
        """
        super().__init__("planar_graph", device=device)

        # Hyper-parameters are copied from the class-level constants.
        self.bins = int(self._FIXED_BINS)
        self.w_valid = float(self._FIXED_W_VALID)
        self.w_deg = float(self._FIXED_W_DEG)
        self.w_clus = float(self._FIXED_W_CLUS)
        self.w_orb = float(self._FIXED_W_ORB)
        self.deg_scale = float(self._FIXED_DEG_SCALE)
        self.clus_scale = float(self._FIXED_CLUS_SCALE)
        self.orb_scale = float(self._FIXED_ORB_SCALE)
        self.use_orb = bool(self._FIXED_USE_ORB)
        self._cache_size = 2000

        orbit_missing = bool(self.use_orb and ref_orbit_mean is None)
        refs_incomplete = ref_degree_dist is None or ref_clustering_hist is None or orbit_missing

        if refs_incomplete and datamodule is not None:
            # First choice: statistics stored next to the dataset.
            loaded = self._load_reference_stats_from_ref_metrics(datamodule)
            if ref_degree_dist is None:
                ref_degree_dist = loaded.get("ref_degree_dist")
            if ref_clustering_hist is None:
                ref_clustering_hist = loaded.get("ref_clustering_hist")
            if ref_orbit_mean is None:
                ref_orbit_mean = loaded.get("ref_orbit_mean")

            # Second choice: recompute, but only when a mandatory histogram is
            # still missing (a missing orbit vector alone does not trigger it).
            if ref_degree_dist is None or ref_clustering_hist is None:
                print("⏳ [PlanarGraphReward] Reference stats missing from file. Computing from training set (InMemory)...")
                computed = self._compute_stats_from_datamodule(datamodule)
                if ref_degree_dist is None:
                    ref_degree_dist = computed.get("ref_degree_dist")
                if ref_clustering_hist is None:
                    ref_clustering_hist = computed.get("ref_clustering_hist")
                if self.use_orb and ref_orbit_mean is None:
                    ref_orbit_mean = computed.get("ref_orbit_mean")

        if ref_degree_dist is not None:
            self.ref_degree_dist = self._safe_normalize(np.asarray(ref_degree_dist, dtype=np.float64))
        else:
            print("⚠️ [PlanarGraphReward] Missing 'ref_degree_dist'. Disabling degree reward.")
            self.w_deg = 0.0
            self.ref_degree_dist = np.zeros(1)

        if ref_clustering_hist is not None:
            self.ref_clustering_hist = self._safe_normalize(np.asarray(ref_clustering_hist, dtype=np.float64))
        else:
            print("⚠️ [PlanarGraphReward] Missing 'ref_clustering_hist'. Disabling clustering reward.")
            self.w_clus = 0.0
            self.ref_clustering_hist = np.zeros(1)

        if self.use_orb and ref_orbit_mean is not None:
            self.ref_orbit_mean = np.asarray(ref_orbit_mean, dtype=np.float64)
        else:
            if self.use_orb:
                print("⚠️ [PlanarGraphReward] Missing 'ref_orbit_mean'. Disabling orbit reward.")
                self.w_orb = 0.0
                self.use_orb = False
            self.ref_orbit_mean = None
    def state_dict_for_workers(self) -> Dict[str, object]:
        use_orb = bool(self.use_orb and self.ref_orbit_mean is not None)
        return {
            "ref_degree_dist": self.ref_degree_dist.astype(np.float64).tolist(),
            "ref_clustering_hist": self.ref_clustering_hist.astype(np.float64).tolist(),
            "ref_orbit_mean": None if not use_orb else self.ref_orbit_mean.astype(np.float64).tolist(),
        }

    @staticmethod
    def _safe_normalize(vec: np.ndarray, eps: float = 1e-12) -> np.ndarray:
        vec = np.asarray(vec, dtype=np.float64)
        s = float(vec.sum())
        if not np.isfinite(s) or s <= 0:
            return np.zeros_like(vec, dtype=np.float64)
        out = vec / (s + eps)
        out = np.clip(out, 0.0, 1.0)
        out = out / max(eps, float(out.sum()))
        return out

    @staticmethod
    def _js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
        if p.shape != q.shape: return 1.0           
        p = np.asarray(p, dtype=np.float64) + eps
        q = np.asarray(q, dtype=np.float64) + eps
        p = p / float(p.sum())
        q = q / float(q.sum())
        m = 0.5 * (p + q)
        kl_pm = float(np.sum(p * (np.log(p) - np.log(m))))
        kl_qm = float(np.sum(q * (np.log(q) - np.log(m))))
        return 0.5 * (kl_pm + kl_qm)

    def _degree_hist(self, G: nx.Graph) -> np.ndarray:
        """Return the normalised degree histogram of ``G``.

        The histogram is zero-padded or truncated to the length of
        ``self.ref_degree_dist`` so the two can be compared bin by bin.
        """
        counts = np.asarray(nx.degree_histogram(G), dtype=np.float64)
        width = int(self.ref_degree_dist.shape[0])

        # Copy the overlapping prefix into a zero vector of the reference
        # length: pads when too short, truncates when too long.
        fitted = np.zeros(width, dtype=np.float64)
        keep = min(width, counts.shape[0])
        fitted[:keep] = counts[:keep]
        return self._safe_normalize(fitted)

    def _clustering_hist(self, G: nx.Graph) -> np.ndarray:
        """Return the normalised histogram of per-node clustering coefficients.

        Coefficients are binned into ``self.bins`` bins over ``[0, 1]``; the
        result is zero-padded or truncated to the length of
        ``self.ref_clustering_hist``.
        """
        coefficients = list(nx.clustering(G).values())
        counts, _edges = np.histogram(coefficients, bins=self.bins, range=(0.0, 1.0), density=False)
        counts = np.asarray(counts, dtype=np.float64)

        width = int(self.ref_clustering_hist.shape[0])
        fitted = np.zeros(width, dtype=np.float64)
        keep = min(width, counts.shape[0])
        fitted[:keep] = counts[:keep]
        return self._safe_normalize(fitted)

    def _orbit_vec(self, G: nx.Graph) -> Optional[np.ndarray]:
        if not self.use_orb:
            return None
        try:
            from analysis.spectre_utils import orca as _orca
        except Exception:
            return None

        try:
            counts = _orca(G)
            counts = np.asarray(counts, dtype=np.float64)
            if counts.ndim != 2 or counts.shape[0] <= 0:
                return None
            vec = np.sum(counts, axis=0) / float(G.number_of_nodes())
            if self.ref_orbit_mean is not None:
                ref_len = int(self.ref_orbit_mean.shape[0])
                if vec.shape[0] < ref_len:
                    vec = np.pad(vec, (0, ref_len - vec.shape[0]))
                elif vec.shape[0] > ref_len:
                    vec = vec[:ref_len]
            return vec
        except Exception:
            return None

    @staticmethod
    def _graph_hash(edge_types: torch.Tensor) -> str:
        if edge_types.dim() == 3:
            edge_idx = edge_types.detach().to("cpu").argmax(dim=-1).to(torch.uint8).numpy()
        else:
            edge_idx = edge_types.detach().to("cpu").to(torch.uint8).numpy()
        return hashlib.md5(edge_idx.tobytes()).hexdigest()

    def compute_components(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        valid_scores: List[float] = []
        deg_scores: List[float] = []
        clus_scores: List[float] = []
        orb_scores: List[float] = []
        train_scores: List[float] = []
        total_scores: List[float] = []

        for atom_types, edge_types in graphs:
            try:
                key = self._graph_hash(edge_types)
                cached = self._cache.get(key)
                if cached is not None:
                    v, d, c, o, t, r = cached
                    valid_scores.append(v)
                    deg_scores.append(d)
                    clus_scores.append(c)
                    orb_scores.append(o)
                    train_scores.append(t)
                    total_scores.append(r)
                    continue

                G = self._convert_tensor_to_networkx_graph(atom_types, edge_types)
                if G.number_of_nodes() <= 0 or G.number_of_edges() <= 0:
                    valid = 0.0
                else:
                    try:
                        valid = 1.0 if (nx.is_connected(G) and nx.check_planarity(G)[0]) else 0.0
                    except Exception:
                        valid = 0.0

                deg_sim = 0.0
                clus_sim = 0.0
                orb_sim = 0.0

                if valid > 0.0:
                    if self.w_deg > 0:
                        p_deg = self._degree_hist(G)
                        deg_sim = float(np.exp(-self.deg_scale * self._js_divergence(p_deg, self.ref_degree_dist)))

                    if self.w_clus > 0:
                        p_clus = self._clustering_hist(G)
                        clus_sim = float(np.exp(-self.clus_scale * self._js_divergence(p_clus, self.ref_clustering_hist)))

                    if self.use_orb and self.w_orb > 0:
                        vec = self._orbit_vec(G)
                        if vec is not None and self.ref_orbit_mean is not None:
                            dist = float(np.mean(np.abs(vec - self.ref_orbit_mean)))
                            orb_sim = float(np.exp(-self.orb_scale * dist))

                train_reward = float(self.w_deg * deg_sim + self.w_clus * clus_sim + self.w_orb * orb_sim)
                denom = float(self.w_valid + self.w_deg + self.w_clus + self.w_orb)
                if denom <= 0.0:
                    denom = 1.0
                total = float(valid * (self.w_valid + train_reward) / denom)

                valid_scores.append(valid)
                deg_scores.append(deg_sim)
                clus_scores.append(clus_sim)
                orb_scores.append(orb_sim)
                train_scores.append(train_reward)
                total_scores.append(total)

                if len(self._cache) >= self._cache_size:
                    try:
                        self._cache.pop(next(iter(self._cache)))
                    except Exception:
                        self._cache.clear()
                self._cache[key] = (valid, deg_sim, clus_sim, orb_sim, train_reward, total)

            except Exception:
                valid_scores.append(0.0)
                deg_scores.append(0.0)
                clus_scores.append(0.0)
                orb_scores.append(0.0)
                train_scores.append(0.0)
                total_scores.append(0.0)

        def _to(x: List[float]) -> torch.Tensor:
            return torch.tensor(x, dtype=torch.float32, device=self.device)

        return {
            "valid": _to(valid_scores),
            "deg": _to(deg_scores),
            "clus": _to(clus_scores),
            "orb": _to(orb_scores),
            "train": _to(train_scores),
            "total": _to(total_scores),
        }

    def __call__(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
        return self.compute_components(graphs)["total"]

    @staticmethod
    def _as_1d_float_array(x) -> np.ndarray:
        arr = np.asarray(x, dtype=np.float64)
        if arr.ndim != 1:
            raise ValueError(f"expected 1D array, got shape={arr.shape}")
        return arr

    def _load_reference_stats_from_ref_metrics(self, datamodule) -> Dict[str, np.ndarray]:
        try:
            root = datamodule.train_dataloader().dataset.root
        except Exception:
             return {}

        ref_metrics_path = os.path.join(root, "ref_metrics.pkl")
        if hasattr(datamodule, "remove_h"):
            try:
                if bool(datamodule.remove_h):
                    ref_metrics_path = ref_metrics_path.replace(".pkl", "_no_h.pkl")
                else:
                    ref_metrics_path = ref_metrics_path.replace(".pkl", "_h.pkl")
            except Exception:
                pass

        if not os.path.exists(ref_metrics_path):
            return {}

        try:
            with open(ref_metrics_path, "rb") as f:
                payload = pickle.load(f)
        except Exception:
            return {}

        if not isinstance(payload, dict):
            return {}

        candidates: List[dict] = []
        if isinstance(payload.get("planar_reward_stats"), dict):
            candidates.append(payload["planar_reward_stats"])
        if isinstance(payload.get("train"), dict):
            candidates.append(payload["train"])
        candidates.append(payload)

        for cand in candidates:
            try:
                deg_raw = cand.get("ref_degree_dist", cand.get("degree_dist", cand.get("degree_hist")))
                clus_raw = cand.get("ref_clustering_hist", cand.get("clustering_hist"))
                orb_raw = cand.get("ref_orbit_mean", cand.get("orbit_mean"))

                if deg_raw is None or clus_raw is None:
                    continue

                if isinstance(deg_raw, (float, int)) or isinstance(clus_raw, (float, int)):
                     continue

                deg = self._as_1d_float_array(deg_raw)
                clus = self._as_1d_float_array(clus_raw)

                out: Dict[str, np.ndarray] = {
                    "ref_degree_dist": deg,
                    "ref_clustering_hist": clus,
                }
                if orb_raw is not None and not isinstance(orb_raw, (float, int)):
                     out["ref_orbit_mean"] = self._as_1d_float_array(orb_raw)

                return out
            except Exception:
                continue

        return {}

    def _compute_stats_from_datamodule(self, datamodule) -> Dict[str, np.ndarray]:
        """Compute reference statistics from the datamodule's training graphs.

        Every training batch is densified with ``utils.to_dense``, each graph
        is cut down to its valid nodes and converted to NetworkX, and the
        degree histogram, clustering histogram and (if the orbit term is
        enabled) mean orbit vector are aggregated over all graphs.

        Returns:
            Dict with ``ref_degree_dist`` and ``ref_clustering_hist``, plus
            ``ref_orbit_mean`` when at least one orbit vector was obtained.
            ``{}`` when the helper module or the training loader is
            unavailable.
        """
        try:
            from analysis.spectre_utils import degree_worker, clustering_worker, orca
        except ImportError:
            print("⚠️ [PlanarGraphReward] analysis.spectre_utils not found. Cannot compute stats.")
            return {}

        import networkx as nx
        import numpy as np

        def _to_nx_local(X_dense, E_dense):
            # X_dense is accepted for signature symmetry but not used.
            if E_dense.dim() == 3:
                adj = E_dense.argmax(dim=-1).float()
            else:
                adj = E_dense.float()
            adj_np = adj.cpu().numpy()
            G = nx.from_numpy_array(adj_np)
            G.remove_edges_from(nx.selfloop_edges(G))
            return G

        print("   [Compute] Accessing training data loader...")
        try:
            loader = datamodule.train_dataloader()
        except Exception:
            return {}

        graphs_list = []
        from utils import to_dense

        for batch in loader:
            dense_data, node_mask = to_dense(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            X, E = dense_data.X, dense_data.E
            B = X.size(0)
            for i in range(B):
                # NOTE(review): slicing by count assumes the valid nodes of a
                # graph occupy the leading positions of the mask — confirm
                # against to_dense.
                valid_nodes = node_mask[i].sum().item()
                if E.dim() == 4:
                    E_sub = E[i, :valid_nodes, :valid_nodes, :]
                else:
                    E_sub = E[i, :valid_nodes, :valid_nodes]

                G = _to_nx_local(None, E_sub)
                if G.number_of_nodes() > 0:
                    graphs_list.append(G)

        print(f"   [Compute] Collected {len(graphs_list)} valid graphs. Calculating statistics...")

        out = {}

        # Degree distribution: sum per-graph histograms of differing length.
        deg_hists = [degree_worker(G) for G in graphs_list]
        max_len = max([1, *(len(h) for h in deg_hists)])
        deg_sum = np.zeros(max_len)
        for h in deg_hists:
            deg_sum[:len(h)] += h
        if deg_sum.sum() > 0:
            out["ref_degree_dist"] = deg_sum / deg_sum.sum()
        else:
            out["ref_degree_dist"] = np.zeros(1)

        # Clustering histogram: all per-graph histograms share self.bins bins.
        clus_hists = [clustering_worker((G, self.bins)) for G in graphs_list]
        clus_sum = np.sum(clus_hists, axis=0) if clus_hists else None
        if clus_sum is not None and clus_sum.sum() > 0:
            out["ref_clustering_hist"] = clus_sum / clus_sum.sum()
        else:
            out["ref_clustering_hist"] = np.zeros(self.bins)

        if self.use_orb:
            orb_vecs = []
            for G in graphs_list:
                try:
                    # Coerce like _orbit_vec does, so list output is handled too.
                    cnts = np.asarray(orca(G), dtype=np.float64)
                    if G.number_of_nodes() > 0:
                        orb_vecs.append(cnts.sum(axis=0) / G.number_of_nodes())
                except Exception:
                    # Orbit counting is best effort; skip graphs it fails on.
                    continue
            if orb_vecs:
                out["ref_orbit_mean"] = np.mean(orb_vecs, axis=0)

        print(f"   [Compute] Done. ref_degree_dist len={len(out.get('ref_degree_dist', []))}")
        return out


class SBMGraphReward(PlanarGraphReward):
    """Variant of ``PlanarGraphReward`` that rewards connectivity instead of planarity.

    The base reward is the fraction of nodes in the largest connected
    component. A graph is treated as valid (and earns the statistical
    similarity terms) only when that fraction is at least 0.999.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.name = "sbm_graph"

    def compute_components(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Score every graph and return each reward component separately.

        Differences from the parent implementation:
          * the ``valid`` entry holds the connectivity fraction, not a 0/1
            flag;
          * ``total`` is ``connectivity + valid * train`` and is not divided
            by the sum of the weights.

        Results are cached by edge-label hash; a graph whose scoring raises
        gets zeros for every component.

        Returns:
            Dict of float32 tensors on ``self.device``, one entry per graph,
            keyed ``valid``, ``deg``, ``clus``, ``orb``, ``train``, ``total``.
        """
        valid_scores: List[float] = []
        deg_scores: List[float] = []
        clus_scores: List[float] = []
        orb_scores: List[float] = []
        train_scores: List[float] = []
        total_scores: List[float] = []

        for atom_types, edge_types in graphs:
            try:
                key = self._graph_hash(edge_types)
                cached = self._cache.get(key)
                if cached is not None:
                    v, d, c, o, t, r = cached
                    valid_scores.append(v)
                    deg_scores.append(d)
                    clus_scores.append(c)
                    orb_scores.append(o)
                    train_scores.append(t)
                    total_scores.append(r)
                    continue

                G = self._convert_tensor_to_networkx_graph(atom_types, edge_types)
                if G.number_of_nodes() <= 0 or G.number_of_edges() <= 0:
                    valid = 0.0
                    conn_score = 0.0
                else:
                    try:
                        max_cc_len = max(len(cc) for cc in nx.connected_components(G))
                        conn_score = float(max_cc_len) / float(G.number_of_nodes())
                    except Exception:
                        conn_score = 0.0

                    # Tolerance instead of == 1.0 to absorb rounding.
                    valid = 1.0 if conn_score >= 0.999 else 0.0

                deg_sim = 0.0
                clus_sim = 0.0
                orb_sim = 0.0

                base_reward = conn_score

                if valid > 0.0:
                    if self.w_deg > 0:
                        p_deg = self._degree_hist(G)
                        deg_sim = float(np.exp(-self.deg_scale * self._js_divergence(p_deg, self.ref_degree_dist)))

                    if self.w_clus > 0:
                        p_clus = self._clustering_hist(G)
                        clus_sim = float(np.exp(-self.clus_scale * self._js_divergence(p_clus, self.ref_clustering_hist)))

                    if self.use_orb and self.w_orb > 0:
                        vec = self._orbit_vec(G)
                        if vec is not None and self.ref_orbit_mean is not None:
                            dist = float(np.mean(np.abs(vec - self.ref_orbit_mean)))
                            orb_sim = float(np.exp(-self.orb_scale * dist))

                train_reward = float(self.w_deg * deg_sim + self.w_clus * clus_sim + self.w_orb * orb_sim)

                total = float(base_reward + valid * train_reward)

                # The "valid" channel reports the connectivity fraction.
                valid_scores.append(conn_score)
                deg_scores.append(deg_sim)
                clus_scores.append(clus_sim)
                orb_scores.append(orb_sim)
                train_scores.append(train_reward)
                total_scores.append(total)

                # Evict the oldest entry (dicts keep insertion order) when full.
                if len(self._cache) >= self._cache_size:
                    try:
                        self._cache.pop(next(iter(self._cache)))
                    except Exception:
                        self._cache.clear()
                self._cache[key] = (conn_score, deg_sim, clus_sim, orb_sim, train_reward, total)

            except Exception:
                valid_scores.append(0.0)
                deg_scores.append(0.0)
                clus_scores.append(0.0)
                orb_scores.append(0.0)
                train_scores.append(0.0)
                total_scores.append(0.0)

        def _to(x: List[float]) -> torch.Tensor:
            return torch.tensor(x, dtype=torch.float32, device=self.device)

        return {
            "valid": _to(valid_scores),
            "deg": _to(deg_scores),
            "clus": _to(clus_scores),
            "orb": _to(orb_scores),
            "train": _to(train_scores),
            "total": _to(total_scores),
        }


class TreeGraphReward(PlanarGraphReward):
    """Variant of ``PlanarGraphReward`` whose validity criterion is "is a tree".

    Only graphs for which ``nx.is_tree`` holds earn any reward: the base
    reward (fraction of nodes in the largest connected component) plus the
    weighted statistical similarity terms.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.name = "tree_graph"

    def compute_components(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Score every graph and return each reward component separately.

        ``valid`` is 1.0 for trees and 0.0 otherwise. ``total`` is
        ``base + valid * train`` where ``base`` is the connectivity fraction
        for trees and 0.0 for everything else; it is not divided by the sum
        of the weights.

        Results are cached by edge-label hash; a graph whose scoring raises
        gets zeros for every component.

        Returns:
            Dict of float32 tensors on ``self.device``, one entry per graph,
            keyed ``valid``, ``deg``, ``clus``, ``orb``, ``train``, ``total``.
        """
        valid_scores: List[float] = []
        deg_scores: List[float] = []
        clus_scores: List[float] = []
        orb_scores: List[float] = []
        train_scores: List[float] = []
        total_scores: List[float] = []

        for atom_types, edge_types in graphs:
            try:
                key = self._graph_hash(edge_types)
                cached = self._cache.get(key)
                if cached is not None:
                    v, d, c, o, t, r = cached
                    valid_scores.append(v)
                    deg_scores.append(d)
                    clus_scores.append(c)
                    orb_scores.append(o)
                    train_scores.append(t)
                    total_scores.append(r)
                    continue

                G = self._convert_tensor_to_networkx_graph(atom_types, edge_types)
                if G.number_of_nodes() <= 0:
                    conn_score = 0.0
                    is_tree = False
                    valid = 0.0
                else:
                    try:
                        max_cc_len = max(len(cc) for cc in nx.connected_components(G))
                        conn_score = float(max_cc_len) / float(G.number_of_nodes())
                    except Exception:
                        conn_score = 0.0

                    try:
                        is_tree = nx.is_tree(G)
                    except Exception:
                        is_tree = False

                    valid = 1.0 if is_tree else 0.0

                deg_sim = 0.0
                clus_sim = 0.0
                orb_sim = 0.0

                # Non-trees earn nothing, not even partial connectivity credit.
                base_reward = conn_score if is_tree else 0.0

                if valid > 0.0:
                    if self.w_deg > 0:
                        p_deg = self._degree_hist(G)
                        deg_sim = float(np.exp(-self.deg_scale * self._js_divergence(p_deg, self.ref_degree_dist)))

                    if self.w_clus > 0:
                        p_clus = self._clustering_hist(G)
                        clus_sim = float(np.exp(-self.clus_scale * self._js_divergence(p_clus, self.ref_clustering_hist)))

                    if self.use_orb and self.w_orb > 0:
                        vec = self._orbit_vec(G)
                        if vec is not None and self.ref_orbit_mean is not None:
                            dist = float(np.mean(np.abs(vec - self.ref_orbit_mean)))
                            orb_sim = float(np.exp(-self.orb_scale * dist))

                train_reward = float(self.w_deg * deg_sim + self.w_clus * clus_sim + self.w_orb * orb_sim)
                total = float(base_reward + valid * train_reward)

                valid_scores.append(float(valid))
                deg_scores.append(deg_sim)
                clus_scores.append(clus_sim)
                orb_scores.append(orb_sim)
                train_scores.append(train_reward)
                total_scores.append(total)

                # Evict the oldest entry (dicts keep insertion order) when full.
                if len(self._cache) >= self._cache_size:
                    try:
                        self._cache.pop(next(iter(self._cache)))
                    except Exception:
                        self._cache.clear()
                self._cache[key] = (float(valid), deg_sim, clus_sim, orb_sim, train_reward, total)

            except Exception:
                valid_scores.append(0.0)
                deg_scores.append(0.0)
                clus_scores.append(0.0)
                orb_scores.append(0.0)
                train_scores.append(0.0)
                total_scores.append(0.0)

        def _to(x: List[float]) -> torch.Tensor:
            return torch.tensor(x, dtype=torch.float32, device=self.device)

        return {
            "valid": _to(valid_scores),
            "deg": _to(deg_scores),
            "clus": _to(clus_scores),
            "orb": _to(orb_scores),
            "train": _to(train_scores),
            "total": _to(total_scores),
        }
         

class MolecularValidityReward(BaseRewardFunction):

    _DEFAULT_ATOM_DECODER = ["C", "N", "O", "F", "B", "Br", "Cl", "I", "P", "S", "Se", "Si"]
    
    _DEFAULT_TARGET_NODE_DIST = {
        0: 0.74, 1: 0.11, 2: 0.11, 3: 0.014,
        4: 0.0, 5: 0.002, 6: 0.008, 7: 0.0, 8: 0.001, 9: 0.015, 10: 0.0, 11: 0.0
    }
    
    _DEFAULT_TARGET_EDGE_DIST = {
        0: 0.925, 1: 0.036, 2: 0.005, 3: 0.0002, 4: 0.033
    }

    def __init__(
        self,
        atom_decoder: Optional[List[str]] = None,
        device: Optional[torch.device] = None,
        target_node_dist: Optional[Dict[int, float]] = None,
        target_edge_dist: Optional[Dict[int, float]] = None,
        dist_coef: float = 0.0,
        scale_factor: float = 10.0,
        clip_range: float = 2.0,
        edge_dist_factor: float = 1.0,
        precomputed_node_weights: Optional[Dict[int, float]] = None,
        precomputed_edge_weights: Optional[Dict[int, float]] = None,
        conformer_weight: float = 0.5,
        conformer_num: int = 5,
        conformer_eref: float = 1.0,
        conformer_deref: float = 5.0,
        conformer_s1: float = 0.5,
        conformer_s2: float = 2.0,
    ):
        """Store the reward hyper-parameters.

        Args:
            atom_decoder: Atom symbols indexed by atom-type id; a falsy value
                selects ``_DEFAULT_ATOM_DECODER``.
            device: Forwarded to the base class constructor.
            target_node_dist / target_edge_dist: Target type distributions;
                ``None`` selects the class defaults. Both are normalised with
                ``_to_distribution_dict``.
            dist_coef: Multiplier of the distribution term in the final reward.
            scale_factor / clip_range: Scale and clipping bound used when
                distribution weights are derived from batch counts.
            edge_dist_factor: Multiplier of the edge score inside the
                distribution term.
            precomputed_node_weights / precomputed_edge_weights: Optional
                fixed weight tables (cleaned by ``_sanitize_weight_dict``).
            conformer_weight: Multiplier of the conformer score; values
                ``<= 0`` disable conformer scoring.
            conformer_num: Number of conformers requested per molecule.
            conformer_eref / conformer_deref / conformer_s1 / conformer_s2:
                Reference values and decay rates of the conformer score.

        Raises:
            ImportError: If ``build_molecule`` is ``None``.
        """
        super().__init__("molecular_validity", device=device)

        if build_molecule is None:
            raise ImportError("analysis.rdkit_functions [anonymized]。")

        self.atom_decoder = atom_decoder if atom_decoder else self._DEFAULT_ATOM_DECODER

        # Only an explicit ``None`` falls back to the class-level defaults.
        if target_node_dist is None:
            target_node_dist = self._DEFAULT_TARGET_NODE_DIST
        if target_edge_dist is None:
            target_edge_dist = self._DEFAULT_TARGET_EDGE_DIST
        self.target_node_dist = self._to_distribution_dict(target_node_dist)
        self.target_edge_dist = self._to_distribution_dict(target_edge_dist)

        self.dist_coef, self.scale_factor = dist_coef, scale_factor
        self.clip_range, self.edge_dist_factor = clip_range, edge_dist_factor

        self.precomputed_node_weights = self._sanitize_weight_dict(precomputed_node_weights)
        self.precomputed_edge_weights = self._sanitize_weight_dict(precomputed_edge_weights)

        (
            self.conformer_weight,
            self.conformer_num,
            self.conformer_eref,
            self.conformer_deref,
            self.conformer_s1,
            self.conformer_s2,
        ) = (
            conformer_weight,
            conformer_num,
            conformer_eref,
            conformer_deref,
            conformer_s1,
            conformer_s2,
        )

    def __call__(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
        batch_size = len(graphs)
        t_start = time.time()

        t0 = time.time()
        batch_indices, mols, invalid_mask = self._preprocess_batch(graphs)
        t_pre = time.time() - t0

        t0 = time.time()
        base_rewards, qed_scores, sa_scores = self._compute_chem_metrics(mols)
        for i, r in enumerate(base_rewards):
            if r < 0:
                invalid_mask[i] = True
        t_chem = time.time() - t0

        t0 = time.time()
        conf_scores = self._compute_batch_conformer_scores(mols, invalid_mask)
        t_conf = time.time() - t0

        t0 = time.time()
        dist_scores = self._compute_batch_distribution_scores(batch_indices, invalid_mask)
        t_dist = time.time() - t0

        t0 = time.time()
        final_rewards = self._aggregate_final_rewards(
            base_rewards, dist_scores, conf_scores, invalid_mask,
            debug_info=(qed_scores, sa_scores)
        )
        t_aggr = time.time() - t0

        total = time.time() - t_start
        eps = 1e-12
        if total > 0:
            def pct(v):
                return 100.0 * v / max(total, eps)

            print(
                f"[Profile] Batch={batch_size} | Total={total:.3f}s\n"
                f"   >> Preprocess: {t_pre:.3f}s ({pct(t_pre):.1f}%)\n"
                f"   >> Chem Metrics: {t_chem:.3f}s ({pct(t_chem):.1f}%)\n"
                f"   >> Conformer: {t_conf:.3f}s ({pct(t_conf):.1f}%)\n"
                f"   >> Dist Scores: {t_dist:.3f}s ({pct(t_dist):.1f}%)\n"
                f"   >> Aggregate: {t_aggr:.3f}s ({pct(t_aggr):.1f}%)"
            )

        return torch.tensor(final_rewards, dtype=torch.float32, device=self.device)


    def _preprocess_batch(self, graphs):
        batch_indices = []
        mols = []
        invalid_mask = []
        for atom_types, edge_types in graphs:
            idx_pair = self._extract_graph_indices(atom_types, edge_types)
            batch_indices.append(idx_pair)
            mol = self._graph_to_mol(atom_types, edge_types)
            mols.append(mol)
            invalid_mask.append(mol is None)
        return batch_indices, mols, invalid_mask

    def _compute_chem_metrics(self, mols: List[Optional["Chem.Mol"]]):
        base_rewards, qed_scores, sa_scores = [], [], []
        for mol in mols:
            r, qed, sa = self._compute_base_reward_single(mol)
            base_rewards.append(r)
            qed_scores.append(qed)
            sa_scores.append(sa)
        return base_rewards, qed_scores, sa_scores

    def _compute_batch_conformer_scores(self, mols, invalid_mask):
        scores = []
        if self.conformer_weight <= 0:
            return [(0.0, 0.0, 0.0)] * len(mols)

        t_start = time.time()
        timing_acc = {"add_h": 0.0, "embed": 0.0, "optimize": 0.0, "scoring": 0.0}
        valid_calls = 0

        for i, mol in enumerate(mols):
            if invalid_mask[i]:
                scores.append((0.0, 0.0, 0.0))
            else:
                score_tuple, timing = self._compute_conformer_stability(mol)
                scores.append(score_tuple)
                if timing is not None:
                    valid_calls += 1
                    for key in timing_acc:
                        timing_acc[key] += timing.get(key, 0.0)

        total = time.time() - t_start
        if total > 0 and valid_calls > 0:
            def pct(v):
                return 100.0 * v / max(total, 1e-12)

            print(
                f"[Conformer Profile] Batch={len(mols)} | Valid={valid_calls} | Total={total:.3f}s\n"
                f"   >> AddHs: {timing_acc['add_h']:.3f}s ({pct(timing_acc['add_h']):.1f}%)\n"
                f"   >> Embed: {timing_acc['embed']:.3f}s ({pct(timing_acc['embed']):.1f}%)\n"
                f"   >> Optimize: {timing_acc['optimize']:.3f}s ({pct(timing_acc['optimize']):.1f}%)\n"
                f"   >> Scoring: {timing_acc['scoring']:.3f}s ({pct(timing_acc['scoring']):.1f}%)"
            )

        return scores

    def _compute_batch_distribution_scores(self, batch_indices, invalid_mask):
        if self.precomputed_node_weights and self.precomputed_edge_weights:
            node_weights = self.precomputed_node_weights
            edge_weights = self.precomputed_edge_weights
        else:
            node_weights, edge_weights = self._calculate_dynamic_weights(batch_indices)

        dist_scores = []
        for i, (atom_indices, edge_indices_flat) in enumerate(batch_indices):
            if invalid_mask[i]:
                dist_scores.append((0.0, 0.0))
                continue

            n_score = 0.0
            n_nodes = len(atom_indices)
            if n_nodes > 0:
                n_score = sum(node_weights.get(int(idx), -self.clip_range) for idx in atom_indices) / n_nodes
            
            e_score = 0.0
            if len(edge_indices_flat) > 0:
                norm_edges = max(1, n_nodes)
                edge_sum = sum(edge_weights.get(int(idx), -self.clip_range) for idx in edge_indices_flat)
                e_score = edge_sum / norm_edges
            
            dist_scores.append((n_score, e_score))
        return dist_scores

    def _calculate_dynamic_weights(self, batch_indices):
        nc, ec = Counter(), Counter()
        tn, te = 0, 0
        for atom_indices, edge_indices_flat in batch_indices:
            nc.update(int(i) for i in atom_indices)
            ec.update(int(i) for i in edge_indices_flat)
            tn += len(atom_indices)
            te += len(edge_indices_flat)
        return self._compute_weights_from_counts(nc, ec, tn, te)

    def _aggregate_final_rewards(self, base_rewards, dist_scores, conf_scores, invalid_mask, debug_info):
        final_rewards = []
        qed_s, sa_s = debug_info
        
        for i in range(len(base_rewards)):
            base_r = base_rewards[i]
            
            if invalid_mask[i]:
                final_rewards.append(base_r)
                continue

            n_score, e_score = dist_scores[i]
            c_score, _, _ = conf_scores[i]

            dist_term = self.dist_coef * (n_score + self.edge_dist_factor * e_score)
            conf_term = self.conformer_weight * c_score
            
            total_r = base_r + dist_term + conf_term
            
            print(
                f"[Reward] idx={i}, Base={base_r:.2f}, Qed={qed_s[i]:.2f}, Sa={sa_s[i]:.2f}, "
                f"D_n={n_score:.2f}, D_e={e_score:.2f}, C={c_score:.2f}, Tot={total_r:.2f}"
            )
            final_rewards.append(total_r)
            
        return final_rewards


    def _graph_to_mol(self, atom_types, edge_types):
        """Build a molecule from atom/edge type tensors.

        One-hot inputs (2-D atoms, 3-D edges) are reduced with ``argmax``
        before being handed to ``build_molecule``.

        Returns:
            The molecule, or ``None`` when ``build_molecule`` is unavailable,
            the graph has no atoms, the result is empty, or construction fails.
        """
        if build_molecule is None:
            return None
        try:
            at = torch.as_tensor(atom_types).long().cpu()
            et = torch.as_tensor(edge_types).long().cpu()
            if at.dim() == 2:
                at = at.argmax(dim=-1)
            if et.dim() == 3:
                et = et.argmax(dim=-1)
            if at.numel() == 0:
                return None
            mol = build_molecule(at, et, self.atom_decoder)
            return mol if (mol and mol.GetNumAtoms() > 0) else None
        except Exception:
            # Conversion is best-effort, but KeyboardInterrupt/SystemExit must
            # still propagate, hence ``Exception`` rather than a bare except.
            return None

    @staticmethod
    def _compute_base_reward_single(mol) -> Tuple[float, float, float]:
        MIN_R, MAX_R = -1.0, 1.0
        
        if mol is None: return MIN_R, 0.0, 0.0
        
        try:
            if len(Chem.GetMolFrags(mol)) > 1: return MIN_R, 0.0, 0.0
        except: return MIN_R, 0.0, 0.0

        has_hetero = any(a.GetAtomicNum() not in (6, 1) for a in mol.GetAtoms())
        if not has_hetero: return -0.2, 0.0, 0.0 

        try: Chem.SanitizeMol(mol)
        except: return -0.5, 0.0, 0.0

        try: qed = float(QED.qed(mol))
        except: qed = 0.0
        
        global _SA_FALLBACK_WARNED
        raw_sa = 10.0
        if sascorer:
            try: raw_sa = float(sascorer.calculateScore(mol))
            except: pass
        sa_norm = float(np.clip(1.0 - (raw_sa - 1.0) / 9.0, 0.0, 1.0))

        final_reward = 0.6 * qed + 0.4 * sa_norm
        return float(np.clip(final_reward, MIN_R, MAX_R)), qed, sa_norm

    def _compute_conformer_stability(self, mol) -> Tuple[Tuple[float, float, float], Optional[Dict[str, float]]]:
        if self.conformer_weight <= 0 or mol is None:
            return (0.0, 0.0, 0.0), None

        try:
            from rdkit.Chem import AllChem
        except ImportError:
            print("[Debug] AllChem not available")
            return (0.0, 0.0, 0.0), None

        timing = {"add_h": 0.0, "embed": 0.0, "optimize": 0.0, "scoring": 0.0}

        try:
            t0 = time.time()
            mol_h = Chem.AddHs(Chem.Mol(mol), addCoords=True)
            timing["add_h"] += time.time() - t0

            params = AllChem.ETKDGv3()
            params.randomSeed = 42
            params.useRandomCoords = True
            params.maxIterations = 100 
            params.useSmallRingTorsions = False
            
            t0 = time.time()
            cids = AllChem.EmbedMultipleConfs(mol_h, numConfs=self.conformer_num, params=params)
            timing["embed"] += time.time() - t0

            if not cids:
                return (0.0, 0.0, 0.0), timing

            t0 = time.time()
            try:
                if AllChem.MMFFHasAllMoleculeParams(mol_h):
                    res = AllChem.MMFFOptimizeMoleculeConfs(mol_h, numThreads=1, maxIters=500)
                else:
                    res = AllChem.UFFOptimizeMoleculeConfs(mol_h, numThreads=1, maxIters=500)
            except Exception:
                timing["optimize"] += time.time() - t0
                return (0.0, 0.0, 0.0), timing
            timing["optimize"] += time.time() - t0

            t0 = time.time()
            energies = [float(r[1]) for r in res if r[0] == 0]

            if not energies:
                return (0.0, 0.0, 0.0), timing

            E_min, E_max = min(energies), max(energies)
            n_heavy = mol_h.GetNumHeavyAtoms() or 1

            S_energy = math.exp(-self.conformer_s1 * max(0.0, (E_min/n_heavy) - self.conformer_eref))
            S_range = math.exp(-self.conformer_s2 * max(0.0, (E_max - E_min) - self.conformer_deref))
            timing["scoring"] += time.time() - t0
            
            return (0.7 * S_energy + 0.3 * S_range, S_energy, S_range), timing

        except Exception:
            return (0.0, 0.0, 0.0), timing
        
    @staticmethod
    def _extract_graph_indices(atom_types, edge_types):
        if torch.is_tensor(atom_types):
            if atom_types.dim() == 2: atom_types = atom_types.argmax(-1)
            a_idx = atom_types.detach().cpu().numpy()
        else: a_idx = np.array(atom_types)

        if torch.is_tensor(edge_types):
            if edge_types.dim() == 3: edge_types = edge_types.argmax(-1)
            e_idx = edge_types.detach().cpu().numpy()
        else: e_idx = np.array(edge_types)
        return a_idx, e_idx.flatten()

    @staticmethod
    def _to_distribution_dict(d):
        if d is None:
            return {}
        if torch.is_tensor(d):
            if d.numel() == 0:
                return {}
            d = d.detach().cpu().numpy()
        if isinstance(d, (list, tuple, np.ndarray)):
            if len(d) == 0:
                return {}
            d = {i: float(v) for i, v in enumerate(d)}
        if isinstance(d, dict):
            if len(d) == 0:
                return {}
        total = sum(d.values())
        return {int(k): v/total for k, v in d.items()} if total > 0 else {}

    @staticmethod
    def _sanitize_weight_dict(d):
        return {int(k): float(v) for k, v in d.items()} if d else None

    @staticmethod
    def compute_distribution_weights(
        graphs,
        target_node_dist=None,
        target_edge_dist=None,
        scale_factor=10.0,
        clip_range=2.0,
    ):
        from collections import Counter

        eps = 1e-6
        tnd = MolecularValidityReward._to_distribution_dict(
            target_node_dist if target_node_dist is not None else MolecularValidityReward._DEFAULT_TARGET_NODE_DIST
        )
        ted = MolecularValidityReward._to_distribution_dict(
            target_edge_dist if target_edge_dist is not None else MolecularValidityReward._DEFAULT_TARGET_EDGE_DIST
        )

        nc, ec = Counter(), Counter()
        tn = te = 0
        for atom_types, edge_types in graphs:
            at_idx, ed_idx = MolecularValidityReward._extract_graph_indices(atom_types, edge_types)
            nc.update(int(i) for i in at_idx)
            ec.update(int(i) for i in ed_idx)
            tn += len(at_idx)
            te += len(ed_idx)

        def calc_w(counts, total, target_dist):
            w = {}
            for idx in set(target_dist) | set(counts):
                p_tgt = target_dist.get(idx, 0.0)
                p_batch = counts.get(idx, 0) / max(1, total)
                val = np.log(p_tgt + eps) - np.log(p_batch + eps)
                w[idx] = float(np.clip(val * scale_factor, -clip_range, clip_range))
            return w

        return calc_w(nc, tn, tnd), calc_w(ec, te, ted)
    
    def _compute_weights_from_counts(self, nc, ec, tn, te):
        eps = 1e-6
        sf, cr = self.scale_factor, self.clip_range
        tnd, ted = self.target_node_dist, self.target_edge_dist
        
        def calc_w(counts, total, target_dist):
            w = {}
            for idx in set(target_dist) | set(counts):
                p_tgt = target_dist.get(idx, 0.0)
                p_batch = counts.get(idx, 0) / max(1, total)
                val = np.log(p_tgt + eps) - np.log(p_batch + eps)
                w[idx] = float(np.clip(val * sf, -cr, cr))
            return w
        return calc_w(nc, tn, tnd), calc_w(ec, te, ted)

class TargetMPOReward(BaseRewardFunction):
    _DEFAULT_ATOM_DECODER = ["C", "N", "O", "F", "B", "Br", "Cl", "I", "P", "S", "Se", "Si"]

    def __init__(self, target_task: str = "penalized_logp", atom_decoder: Optional[List[str]] = None, device: Optional[torch.device] = None):
        """Create the deprecated target-MPO reward.

        Prints a deprecation notice, stores the lower-cased task name,
        verifies that ``guacamol`` is importable and then selects the scoring
        function via ``_init_task``.

        Raises:
            ImportError: If ``guacamol`` cannot be imported.
            ValueError: Propagated from ``_init_task`` for an unknown task.
        """
        print("⚠️ [WARNING] TargetMPOReward is DEPRECATED. Please switch to specific reward classes like ValsartanSmartsReward.")
        super().__init__("target_mpo", device=device)

        self.atom_decoder = atom_decoder if atom_decoder else self._DEFAULT_ATOM_DECODER
        self.target_task = target_task.lower()

        # Availability probe only; the imported module object itself is unused.
        try:
            from guacamol import standard_benchmarks as _probe
        except ImportError:
            raise ImportError("[anonymized] Guacamol [anonymized]。[anonymized] guacamol [anonymized] reward。")

        self.objective = None
        self._init_task()

    def _init_task(self):
        if self.target_task == "penalized_logp":
            self.score_metric = self._score_penalized_logp
            
        elif self.target_task == "aripiprazole_similarity":
            from guacamol.standard_benchmarks import aripiprazole_similarity
            benchmark = aripiprazole_similarity()
            self.objective = benchmark.objective
            self.score_metric = self._score_guacamol_objective
            
        elif self.target_task == "qed":
            from guacamol.standard_benchmarks import qed_benchmark
            benchmark = qed_benchmark()
            self.objective = benchmark.objective
            self.score_metric = self._score_guacamol_objective
            
        elif self.target_task == "osimertinib_mpo":
            from guacamol.standard_benchmarks import hard_osimertinib
            benchmark = hard_osimertinib()
            self.objective = benchmark.objective
            self.score_metric = self._score_guacamol_objective
            
        elif self.target_task == "fexofenadine_mpo":
            from guacamol.standard_benchmarks import hard_fexofenadine
            benchmark = hard_fexofenadine()
            self.objective = benchmark.objective
            self.score_metric = self._score_guacamol_objective

        elif self.target_task == "ranolazine_mpo":
            from guacamol.standard_benchmarks import ranolazine_mpo
            benchmark = ranolazine_mpo()
            self.objective = benchmark.objective
            self.score_metric = self._score_guacamol_objective

        elif self.target_task == "perindopril_mpo":
            from guacamol.standard_benchmarks import perindopril_rings
            benchmark = perindopril_rings()
            self.objective = benchmark.objective
            self.score_metric = self._score_guacamol_objective

        elif self.target_task == "amlodipine_mpo":
            from guacamol.standard_benchmarks import amlodipine_rings
            benchmark = amlodipine_rings()
            self.objective = benchmark.objective
            self.score_metric = self._score_guacamol_objective

        elif self.target_task == "sitagliptin_mpo":
            from guacamol.standard_benchmarks import sitagliptin_replacement
            benchmark = sitagliptin_replacement()
            self.objective = benchmark.objective
            self.score_metric = self._score_guacamol_objective

        elif self.target_task == "zaleplon_mpo":
            from guacamol.standard_benchmarks import zaleplon_with_other_formula
            benchmark = zaleplon_with_other_formula()
            self.objective = benchmark.objective
            self.score_metric = self._score_guacamol_objective
            
        else:
            raise ValueError(
                f"[anonymized] Target [anonymized]: {self.target_task}。[anonymized]: penalized_logp, aripiprazole_similarity, qed, "
                "osimertinib_mpo, fexofenadine_mpo, ranolazine_mpo, perindopril_mpo, amlodipine_mpo, sitagliptin_mpo, zaleplon_mpo"
            )

    def __call__(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
        rewards = []
        for atom_types, edge_types in graphs:
            mol = self._graph_to_mol(atom_types, edge_types)
            
            if mol is None:
                rewards.append(0.0)
                continue
            
            is_valid = False
            try:
                Chem.SanitizeMol(mol)
                is_valid = True
            except:
                is_valid = False
                
            try:
                if is_valid:
                    frags = Chem.GetMolFrags(mol, asMols=True)
                    if len(frags) > 1:
                        largest_mol = max(frags, key=lambda m: m.GetNumAtoms())
                        try:
                             base_score = self.score_metric(largest_mol)
                        except:
                             base_score = 0.0
                        
                        rewards.append(float(base_score))
                    else:
                        score = self.score_metric(mol)
                        rewards.append(float(score))
                else:
                    rewards.append(0.0)
            except Exception:
                rewards.append(0.0)
                
        return torch.tensor(rewards, dtype=torch.float32, device=self.device)

    def _score_guacamol_objective(self, mol) -> float:
        """Score ``mol`` with ``self.objective`` via its SMILES; ``0.0`` if the SMILES is empty."""
        smiles = Chem.MolToSmiles(mol)
        return self.objective.score(smiles) if smiles else 0.0

    def _score_penalized_logp(self, mol) -> float:
        try:
            logp = rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
        except:
            return -10.0          
        
        sa = 10.0
        if sascorer:
            try:
                sa = sascorer.calculateScore(mol)
            except:
                pass
        
        cycle_score = 0.0
        try:
             cycle_list = mol.GetRingInfo().AtomRings()
             for ring in cycle_list:
                 if len(ring) > 6:
                     cycle_score += 1.0
        except:
            pass
            
        return logp - sa - cycle_score

    def _graph_to_mol(self, atom_types, edge_types):
        """Build a molecule from atom/edge type tensors.

        One-hot inputs (2-D atoms, 3-D edges) are reduced with ``argmax``
        before being handed to ``build_molecule``.

        Returns:
            The molecule, or ``None`` when ``build_molecule`` is unavailable,
            the graph has no atoms, the result is empty, or construction fails.
        """
        if build_molecule is None:
            return None
        try:
            at = torch.as_tensor(atom_types).long().cpu()
            et = torch.as_tensor(edge_types).long().cpu()
            if at.dim() == 2:
                at = at.argmax(dim=-1)
            if et.dim() == 3:
                et = et.argmax(dim=-1)
            if at.numel() == 0:
                return None
            mol = build_molecule(at, et, self.atom_decoder)
            return mol if (mol and mol.GetNumAtoms() > 0) else None
        except Exception:
            # Best-effort conversion; KeyboardInterrupt/SystemExit still propagate.
            return None


class TDCOracleReward(BaseRewardFunction):

    def __init__(
        self,
        oracle_names: Union[str, List[str]],
        atom_decoder: Optional[List[str]] = None,
        aggregation: str = "mean",
        weights: Optional[List[float]] = None,
        minimize: bool = False,
        invalid_score: float = 0.0,
        clip_min: Optional[float] = None,
        clip_max: Optional[float] = None,
        tdc_home: Optional[str] = None,
        device: Optional[torch.device] = None,
    ):
        """Create a reward backed by one or more TDC oracles.

        Args:
            oracle_names: A single oracle name or a list of names.
            atom_decoder: Atom symbols indexed by atom-type id. A falsy value
                selects the same 12-symbol default the sibling reward classes
                use, instead of passing ``None`` on to ``build_molecule``.
            aggregation: How per-oracle scores are combined (lower-cased;
                ``None`` means ``"mean"``).
            weights: Optional per-oracle weights.
            minimize: Negate oracle scores when ``True``.
            invalid_score: Reward assigned to graphs without a valid SMILES.
            clip_min / clip_max: Optional bounds applied to oracle scores.
            tdc_home: Optional TDC data directory; resolved and exported
                before the oracles are created.
            device: Forwarded to the base class constructor.

        Raises:
            ValueError: If ``oracle_names`` is empty.
            ImportError: If the ``tdc`` package cannot be imported.
        """
        super().__init__("tdc_oracle", device=device)
        self.atom_decoder = atom_decoder or [
            "C", "N", "O", "F", "B", "Br", "Cl", "I", "P", "S", "Se", "Si"
        ]
        self.oracle_names = [oracle_names] if isinstance(oracle_names, str) else list(oracle_names)
        if not self.oracle_names:
            raise ValueError("TDCOracleReward: oracle_names [anonymized]")

        self.aggregation = (aggregation or "mean").lower()
        self.weights = weights
        self.minimize = bool(minimize)
        self.invalid_score = float(invalid_score)
        self.clip_min = clip_min
        self.clip_max = clip_max
        self._tdc_home = tdc_home

        # Must run before ``tdc`` is imported and the oracles are created.
        self._maybe_configure_tdc_home(self._tdc_home)

        try:
            from tdc import Oracle
        except ImportError as e:
            raise ImportError(
                "[anonymized] TDC (PyTDC)。[anonymized]：`pip install PyTDC`，[anonymized] import `tdc`。"
            ) from e

        self._oracles = [Oracle(name=name) for name in self.oracle_names]

    @staticmethod
    def _project_root() -> Optional[Path]:
        try:
            return Path(__file__).resolve().parents[1]
        except Exception:
            return None

    @classmethod
    def _resolve_tdc_home(cls, tdc_home: Optional[str]) -> Optional[Path]:
        if tdc_home is None:
            tdc_home = os.environ.get("TDC_HOME") or None

        if tdc_home is not None:
            raw = Path(os.path.expanduser(str(tdc_home)))
            if raw.is_absolute():
                candidate = raw
            else:
                root = cls._project_root()
                candidate = (root / raw) if root is not None else raw

            if candidate.is_dir() and candidate.name == "oracle":
                return candidate.parent
            return candidate

        root = cls._project_root()
        if root is not None and (root / "oracle").is_dir():
            return root

        return None

    @staticmethod
    def _safe_link_or_copy(src: Path, dst: Path) -> None:
        if dst.exists():
            return
        try:
            dst.parent.mkdir(parents=True, exist_ok=True)
        except Exception:
            pass

        try:
            dst.symlink_to(src.name)
            return
        except FileExistsError:
            return
        except Exception:
            pass

        try:
            os.link(str(src), str(dst))
            return
        except FileExistsError:
            return
        except Exception:
            pass

        try:
            tmp = dst.with_suffix(dst.suffix + ".tmp")
            shutil.copy2(str(src), str(tmp))
            os.replace(str(tmp), str(dst))
        except FileExistsError:
            return
        except Exception:
            try:
                if tmp.exists():
                    tmp.unlink()
            except Exception:
                pass

    @classmethod
    def _is_html_corrupt(cls, p: Path) -> bool:
        if not p.is_file():
            return False
        try:
            with open(p, "rb") as f:
                head = f.read(16)
                return len(head) > 0 and head.startswith(b"<")
        except Exception:
            return False

    @classmethod
    def _ensure_oracle_pkls_present(cls, tdc_home: Path, oracle_names: List[str]) -> None:
        oracle_dir = tdc_home / "oracle"
        try:
            oracle_dir.mkdir(parents=True, exist_ok=True)
        except Exception:
            return

        for name in oracle_names:
            expected = oracle_dir / f"{name}.pkl"
            
            aliases = [f"{name}_current.pkl", f"{name}_latest.pkl"]
            
            all_files = [expected] + [oracle_dir / a for a in aliases]
            for p in all_files:
                if p.exists() and cls._is_html_corrupt(p):
                    corrupt_path = p.with_suffix(".pkl.corrupt")
                    print(f"⚠️ [TDC] Found corrupt HTML file at {p}, renaming to {corrupt_path.name}")
                    try:
                        if corrupt_path.exists():
                            corrupt_path.unlink()
                        p.rename(corrupt_path)
                    except Exception as e:
                        print(f"❌ [TDC] Failed to rename corrupt file: {e}")

            valid_existing = [p for p in all_files if p.is_file() and not cls._is_html_corrupt(p)]
            
            if valid_existing:
                chosen = valid_existing[0]
                for target in all_files:
                    if not target.exists():
                        cls._safe_link_or_copy(chosen, target)
                continue

            candidates = sorted(oracle_dir.glob(f"{name}*.pkl"))
            valid_candidates = [p for p in candidates if p.is_file() and not cls._is_html_corrupt(p)]
            
            if not valid_candidates:
                continue

            chosen = max(valid_candidates, key=lambda p: p.stat().st_mtime)
            if os.environ.get("GRPO_DEBUG_TDC_HOME", "0") == "1":
                print(f"✅ [TDC] Restoring {expected.name} and aliases from {chosen.name}")
            
            for target in all_files:
                if not target.exists():
                    cls._safe_link_or_copy(chosen, target)

    def _maybe_configure_tdc_home(self, tdc_home: Optional[str]) -> None:
        resolved = self._resolve_tdc_home(tdc_home)
        if resolved is None:
            return

        os.environ["TDC_HOME"] = str(resolved)
        
        cwd_oracle = Path("oracle").resolve()
        real_oracle = resolved / "oracle"
        
        if real_oracle.is_dir() and not cwd_oracle.exists():
            try:
                cwd_oracle.symlink_to(real_oracle)
                if os.environ.get("GRPO_DEBUG_TDC_HOME", "0") == "1":
                    print(f"🔗 [TDC] Created symlink in CWD: ./oracle -> {real_oracle}")
            except Exception as e:
                if os.environ.get("GRPO_DEBUG_TDC_HOME", "0") == "1":
                    print(f"⚠️ [TDC] Failed to create CWD symlink: {e}")

        try:
            self._ensure_oracle_pkls_present(resolved, self.oracle_names)
        except Exception:
            pass

        if os.environ.get("GRPO_DEBUG_TDC_HOME", "").strip().lower() in ("1", "true", "yes", "y", "on"):
            oracle_dir = resolved / "oracle"
            expected = [oracle_dir / f"{name}.pkl" for name in self.oracle_names]
            details = []
            for p in expected:
                status = "EXISTS" if p.exists() else "MISSING"
                if p.exists():
                    is_html = self._is_html_corrupt(p)
                    status += " (HTML Corrupt!!)" if is_html else " (Valid pkl)"
                details.append(f"{p.name}: {status}")
            
            print(
                f"🔎 [TDC] TDC_HOME={resolved}\n"
                f"   oracle_dir={oracle_dir}\n"
                f"   details: {details}",
                file=sys.stderr,
            )

    def __call__(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
        if len(graphs) == 0:
            return torch.tensor([], dtype=torch.float32, device=self.device)

        valid_indices: List[int] = []
        valid_smiles: List[str] = []

        for i, (atom_types, edge_types) in enumerate(graphs):
            mol = self._graph_to_mol(atom_types, edge_types)
            if mol is None:
                continue
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                continue
            try:
                frags = Chem.GetMolFrags(mol, asMols=True)
                if frags and len(frags) > 1:
                    mol = max(frags, key=lambda m: int(m.GetNumAtoms()))
            except Exception:
                pass
            try:
                smi = Chem.MolToSmiles(mol)
            except Exception:
                smi = None
            if not smi:
                continue
            valid_indices.append(i)
            valid_smiles.append(smi)

        out = np.full((len(graphs),), self.invalid_score, dtype=np.float32)
        if not valid_smiles:
            return torch.tensor(out, dtype=torch.float32, device=self.device)

        scores = np.zeros((len(valid_smiles), len(self._oracles)), dtype=np.float32)
        for j, oracle in enumerate(self._oracles):
            s = None
            try:
                s = oracle(valid_smiles)
            except Exception:
                s = None

            try:
                s_arr = np.asarray(s)
            except Exception:
                s_arr = np.asarray([])

            needs_fallback = (
                s_arr.ndim == 0
                or (s_arr.ndim >= 1 and int(s_arr.shape[0]) != int(len(valid_smiles)))
            )
            if needs_fallback:
                s_arr = np.asarray([oracle(smi) for smi in valid_smiles], dtype=np.float32)
            else:
                s_arr = s_arr.astype(np.float32, copy=False)

            scores[:, j] = s_arr.reshape(-1)

        if self.minimize:
            scores = -scores

        if self.clip_min is not None or self.clip_max is not None:
            scores = np.clip(
                scores,
                a_min=self.clip_min if self.clip_min is not None else -np.inf,
                a_max=self.clip_max if self.clip_max is not None else np.inf,
            )

        if self.weights is not None:
            w = np.asarray(self.weights, dtype=np.float32)
            if w.shape[0] != scores.shape[1]:
                raise ValueError(
                    f"TDCOracleReward: weights [anonymized]({w.shape[0]})[anonymized] oracles [anonymized]({scores.shape[1]})[anonymized]"
                )
        else:
            w = np.ones((scores.shape[1],), dtype=np.float32)

        agg = self.aggregation
        if agg in ("mean", "avg", "average"):
            agg_scores = (scores * w[None, :]).sum(axis=1) / max(float(w.sum()), 1e-12)
        elif agg in ("sum",):
            agg_scores = (scores * w[None, :]).sum(axis=1)
        elif agg in ("min",):
            agg_scores = scores.min(axis=1)
        elif agg in ("max",):
            agg_scores = scores.max(axis=1)
        elif agg in ("geometric_mean", "gmean", "geo"):
            eps = 1e-12
            safe = np.clip(scores, a_min=0.0, a_max=None) + eps
            agg_scores = np.exp((np.log(safe) * w[None, :]).sum(axis=1) / max(float(w.sum()), eps))
        else:
            raise ValueError(f"TDCOracleReward: [anonymized] aggregation='{self.aggregation}'")

        for i, idx in enumerate(valid_indices):
            out[idx] = float(agg_scores[i])

        return torch.tensor(out, dtype=torch.float32, device=self.device)

    def _graph_to_mol(self, atom_types, edge_types):
        """Build a molecule from atom/edge type tensors.

        One-hot inputs (2-D atoms, 3-D edges) are reduced with ``argmax``.
        Returns ``None`` when ``build_molecule`` is unavailable, the graph
        has no atoms, the built molecule is empty, or anything raises.
        """
        if build_molecule is None:
            return None
        try:
            atoms = torch.as_tensor(atom_types).long().cpu()
            bonds = torch.as_tensor(edge_types).long().cpu()
            atoms = atoms.argmax(dim=-1) if atoms.dim() == 2 else atoms
            bonds = bonds.argmax(dim=-1) if bonds.dim() == 3 else bonds
            if atoms.numel() == 0:
                return None
            mol = build_molecule(atoms, bonds, self.atom_decoder)
            if mol and mol.GetNumAtoms() > 0:
                return mol
            return None
        except Exception:
            return None


class GDPODockingReward(BaseRewardFunction):
    """Reward combining QED, SA, novelty and a docking score.

    For every graph that can be turned into a sanitized molecule the reward is

        0.1 * (r_qed + r_sa) + 0.3 * r_nov + 0.5 * r_ds

    where ``r_qed`` is 1.0 when QED > 0.5 (else 0.0), ``r_sa`` is
    ``(10 - SA) / 9``, ``r_nov`` is ``1 - max Tanimoto similarity`` to the
    training fingerprints and ``r_ds`` is the docking energy clipped to
    ``[-20, 0]`` and mapped to ``[0, 1]``.  Docking is only run for molecules
    that pass the QED, SA and similarity gates; all other molecules get
    ``r_ds = 0``.  Graphs that cannot be converted receive a reward of 0.
    """

    _DEFAULT_ATOM_DECODER = ["C", "N", "O", "F", "B", "Br", "Cl", "I", "P", "S", "Se", "Si"]
    _FP_RADIUS = 2
    _FP_BITS = 1024
    # Warn-once latches (become per-instance once set through ``self``).
    _WARNED_DECODER_MISMATCH = False
    _WARNED_BUILD_FAILURE = False

    @staticmethod
    def _project_root() -> Optional[Path]:
        """Return the directory two levels above this file, or ``None`` on failure."""
        try:
            return Path(__file__).resolve().parents[1]
        except Exception:
            return None

    @property
    def _DEFAULT_TRAIN_PT_PATH(self) -> Path:
        """Default location of the processed training set (``train.pt``)."""
        root = self._project_root() or Path.cwd()
        return root / "data" / "zinc" / "full" / "processed" / "train.pt"

    @property
    def _DEFAULT_FPS_CACHE_PATH(self) -> Path:
        """Default location of the pickled training-fingerprint cache."""
        root = self._project_root() or Path.cwd()
        return root / "data" / "zinc" / "full" / "processed" / "train.pt.fps.pkl"

    def __init__(
        self,
        target_name: str,
        atom_decoder: Optional[List[str]] = None,
        device: Optional[torch.device] = None,
        sa_threshold: Optional[float] = None,
        sim_threshold: Optional[float] = None,
        dock_exhaustiveness: Optional[int] = None,
        dock_num_modes: Optional[int] = None,
        dock_timeout: Optional[int] = None,
        dataset_name: Optional[str] = None,
        datadir: Optional[str] = None,
        remove_h: Optional[bool] = None,
    ):
        """Set up the docking oracle and load the training fingerprints.

        Args:
            target_name: Docking target forwarded to ``LeadOptOracle``.
            atom_decoder: Atom symbols indexed by atom class; falls back to
                ``_DEFAULT_ATOM_DECODER`` when empty or ``None``.
            device: Device of the returned reward tensor.
            sa_threshold: Minimum ``r_sa`` required for docking.  Defaults to
                ``(10 - 5) / 9``, i.e. an SA score of 5 under the ``r_sa``
                mapping used in ``__call__``.
            sim_threshold: Molecules whose maximum training similarity is at
                or above this value are not docked.  When ``None`` it is
                obtained from ``gdpo_get_sim_threshold(dataset_name)``.
            dock_exhaustiveness: Forwarded to ``LeadOptOracle``.
            dock_num_modes: Forwarded to ``LeadOptOracle``.
            dock_timeout: Forwarded to ``LeadOptOracle``.
            dataset_name: Dataset name used for the threshold lookup and for
                loading training fingerprints.
            datadir: Dataset directory used for loading training fingerprints.
            remove_h: Forwarded to the fingerprint loader; defaults to ``True``.

        Raises:
            ValueError: If ``target_name`` is empty.
            FileNotFoundError: If no fingerprint source can be found (see
                ``_load_train_fps``).
        """
        super().__init__("gdpo_docking", device=device)
        if not target_name:
            raise ValueError("GDPODockingReward requires target_name")
        self.target_name = str(target_name)
        self.atom_decoder = atom_decoder or self._DEFAULT_ATOM_DECODER
        self.sa_threshold = float(sa_threshold) if sa_threshold is not None else (10.0 - 5.0) / 9.0
        if sim_threshold is None:
            sim_threshold = gdpo_get_sim_threshold(dataset_name or "")
        self.sim_threshold = float(sim_threshold)
        self.dataset_name = str(dataset_name) if dataset_name else None
        self.datadir = str(datadir) if datadir else None
        self.remove_h = bool(remove_h) if remove_h is not None else True
        self.repo_root = self._project_root() or Path.cwd()
        # SMILES -> docking energy; avoids re-docking identical molecules.
        self._dock_cache: Dict[str, float] = {}
        self.oracle = LeadOptOracle(
            target_name=self.target_name,
            exhaustiveness=dock_exhaustiveness,
            num_modes=dock_num_modes,
            dock_timeout=dock_timeout,
        )
        self.train_fps_cache_path = None
        self._train_fps = self._load_train_fps()

    def __call__(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
        """Score a batch of ``(atom_types, edge_types)`` graphs.

        Returns:
            A float32 tensor on ``self.device`` with one reward per graph;
            graphs that do not yield a sanitized molecule score 0.
        """
        if len(graphs) == 0:
            return torch.tensor([], dtype=torch.float32, device=self.device)

        valid_indices: List[int] = []
        valid_smiles: List[str] = []
        components: List[Tuple[float, float, float, float]] = []

        for i, (atom_types, edge_types) in enumerate(graphs):
            prepared = self._prepare_mol(atom_types, edge_types)
            if prepared is None:
                continue
            mol, smi = prepared
            valid_indices.append(i)
            valid_smiles.append(smi)
            components.append(self._property_components(mol))

        out = np.zeros((len(graphs),), dtype=np.float32)
        if valid_smiles:
            self._dock_uncached(valid_smiles, components)
            for idx, smi, (r_qed, r_sa, r_nov, _max_sim) in zip(
                valid_indices, valid_smiles, components
            ):
                # Molecules that were never docked keep an energy of 0 (r_ds = 0).
                energy = np.float32(self._dock_cache.get(smi, 0.0))
                r_ds = -1.0 * float(np.clip(energy, -20.0, 0.0)) / 20.0
                reward = 0.1 * (r_qed + r_sa) + 0.3 * r_nov + 0.5 * r_ds
                out[idx] = float(reward)

        return torch.tensor(out, dtype=torch.float32, device=self.device)

    def _prepare_mol(self, atom_types, edge_types):
        """Return ``(mol, smiles)`` for the largest sanitized fragment, or ``None``."""
        mol = self._graph_to_mol(atom_types, edge_types)
        if mol is None:
            return None
        try:
            if hasattr(mol, "GetMol"):
                mol = mol.GetMol()
            Chem.SanitizeMol(mol)
        except Exception:
            return None
        try:
            frags = Chem.GetMolFrags(mol, asMols=True)
            if frags and len(frags) > 1:
                mol = max(frags, key=lambda m: int(m.GetNumAtoms()))
        except Exception:
            # Fragment splitting is optional; keep the full molecule on failure.
            pass
        try:
            smi = Chem.MolToSmiles(mol)
        except Exception:
            smi = None
        if not smi:
            return None
        return mol, smi

    def _property_components(self, mol) -> Tuple[float, float, float, float]:
        """Compute ``(r_qed, r_sa, r_nov, max_sim)`` for one molecule."""
        global _SA_FALLBACK_WARNED

        try:
            r_qed = 1.0 if float(QED.qed(mol)) > 0.5 else 0.0
        except Exception:
            r_qed = 0.0

        r_sa = 0.0
        if sascorer is None:
            # Without sascorer the SA term stays 0 and the SA gate is skipped.
            if not _SA_FALLBACK_WARNED:
                print("⚠️ [GRPO] sascorer [anonymized]，SA [anonymized]。")
                _SA_FALLBACK_WARNED = True
        else:
            try:
                sa = float(sascorer.calculateScore(mol))
                r_sa = (10.0 - sa) / 9.0
            except Exception:
                r_sa = 0.0

        r_nov = 1.0
        max_sim = 0.0
        try:
            fp = rdMolDescriptors.GetMorganFingerprintAsBitVect(
                mol, self._FP_RADIUS, nBits=self._FP_BITS
            )
            if self._train_fps:
                sims = DataStructs.BulkTanimotoSimilarity(fp, self._train_fps)
                max_sim = float(max(sims)) if sims else 0.0
            r_nov = 1.0 - max_sim
        except Exception:
            r_nov = 0.0

        return r_qed, r_sa, r_nov, max_sim

    def _dock_uncached(
        self,
        smiles_list: List[str],
        components: List[Tuple[float, float, float, float]],
    ) -> None:
        """Dock every gate-passing SMILES not yet cached, once per unique string."""
        pending: List[str] = []
        queued = set()
        for smi, (r_qed, r_sa, _r_nov, max_sim) in zip(smiles_list, components):
            if r_qed <= 0:
                continue
            if sascorer is not None and r_sa < self.sa_threshold:
                continue
            if self.sim_threshold is not None and max_sim >= self.sim_threshold:
                continue
            # Docking is the expensive step: skip anything already cached or
            # already queued from this same batch.
            if smi in self._dock_cache or smi in queued:
                continue
            queued.add(smi)
            pending.append(smi)

        if not pending:
            return

        # zip tolerates an oracle that returns fewer/more results than asked;
        # SMILES without a result stay uncached and are retried next time.
        for smi, energy in zip(pending, self.oracle.score(pending)):
            energy = float(energy)
            # A non-finite energy would turn into a NaN (or maximal) reward;
            # treat it as "no docking signal".
            self._dock_cache[smi] = energy if math.isfinite(energy) else 0.0

    def _load_train_fps(self) -> List:
        """Load the training-set fingerprints used for the novelty term.

        Order of attempts: ``gdpo_load_train_fps`` (when both dataset name and
        data directory are known), a pickled cache next to ``train.pt``, and
        finally recomputation from the processed ``train.pt`` itself (after
        which the cache is written on a best-effort basis).

        Raises:
            FileNotFoundError: If recomputation is needed but ``train.pt`` is
                missing.
        """
        if self.dataset_name and self.datadir:
            try:
                return gdpo_load_train_fps(
                    dataset_name=self.dataset_name,
                    datadir=self.datadir,
                    remove_h=self.remove_h,
                    repo_root=self.repo_root,
                )
            except Exception as exc:
                print(f"⚠️ [GRPO] GDPO train fps load failed: {exc}. Falling back to train.pt.")

        path = self._DEFAULT_TRAIN_PT_PATH
        cache_path = self._resolve_cache_path(path)
        if cache_path is not None and cache_path.is_file():
            try:
                # Tiny files cannot hold a useful pickle; treat them as absent.
                if cache_path.stat().st_size >= 16:
                    # NOTE: pickle executes arbitrary code on load; the cache
                    # file must come from a trusted location.
                    with open(cache_path, "rb") as handle:
                        cached = pickle.load(handle)
                    if isinstance(cached, list) and cached:
                        return cached
            except Exception as exc:
                print(f"⚠️ [GRPO] train fps cache read failed ({cache_path}): {exc}. Recomputing.")

        if not path.is_file():
            raise FileNotFoundError(f"train.pt not found: {path}")

        print(f"🔧 [GRPO] [anonymized] train.pt [anonymized]: {path}")
        fps = self._load_train_fps_from_processed(path)
        print(f"✅ [GRPO] [anonymized]，[anonymized]: {len(fps)}")

        if cache_path is not None:
            try:
                # Directory creation is part of the best-effort write: a
                # read-only data directory must not abort reward construction.
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                with open(cache_path, "wb") as handle:
                    pickle.dump(fps, handle)
            except Exception as exc:
                print(f"⚠️ [GRPO] train fps cache write failed ({cache_path}): {exc}")

        return fps

    def _resolve_cache_path(self, path: Path) -> Optional[Path]:
        """Return the default cache file if it exists, else ``<path>.fps.pkl``."""
        if self._DEFAULT_FPS_CACHE_PATH.is_file():
            return self._DEFAULT_FPS_CACHE_PATH
        return Path(f"{path}.fps.pkl")

    def _load_train_fps_from_processed(self, path: Path) -> List:
        """Compute Morgan fingerprints for every graph stored in ``train.pt``.

        Args:
            path: File holding a ``(data, slices)`` pair as loaded by
                ``torch.load``.

        Returns:
            A list of bit-vector fingerprints; empty when the file does not
            contain usable ``(data, slices)`` content.

        Raises:
            ImportError: If ``torch_geometric`` is not installed.
        """
        try:
            from torch_geometric.data import InMemoryDataset
            from torch_geometric.data.data import Data as PyGData
        except Exception as exc:
            raise ImportError("torch_geometric is required to load processed train.pt") from exc

        data_obj = None
        slices = None
        try:
            from torch.serialization import safe_globals

            with safe_globals([PyGData]):
                data_obj, slices = torch.load(path, map_location="cpu")
        except Exception:
            # NOTE: the fallback unpickles without restrictions; ``path`` must
            # be a trusted file.
            try:
                data_obj, slices = torch.load(path, map_location="cpu", weights_only=False)
            except TypeError:
                # torch versions without the ``weights_only`` keyword.
                data_obj, slices = torch.load(path, map_location="cpu")

        if data_obj is None or slices is None:
            return []
        if not isinstance(slices, dict) or not slices:
            return []

        # Bare dataset shell: only ``data``/``slices`` are needed for ``get``.
        dataset = InMemoryDataset.__new__(InMemoryDataset)
        dataset.data = data_obj
        dataset.slices = slices

        first_key = next(iter(slices))
        total = int(slices[first_key].numel() - 1)
        print(f"🔍 [GRPO] train.pt slices={list(slices.keys())} total={total}")
        fps = []
        skipped = 0
        failed = 0
        fp_error_logged = False
        debug_samples = 3  # how many unconvertible records are described in detail
        for idx in range(total):
            try:
                data = InMemoryDataset.get(dataset, idx)
            except Exception:
                failed += 1
                continue
            mol = self._data_to_mol(data)
            if mol is None:
                skipped += 1
                if debug_samples > 0:
                    try:
                        x_shape = tuple(getattr(data, "x", torch.empty(0)).shape)
                        edge_attr = getattr(data, "edge_attr", None)
                        edge_shape = tuple(edge_attr.shape) if edge_attr is not None else None
                        edge_index = getattr(data, "edge_index", None)
                        edge_index_shape = tuple(edge_index.shape) if edge_index is not None else None
                        print(
                            f"⚠️ [GRPO] data->mol[anonymized] idx={idx} x={x_shape} "
                            f"edge_attr={edge_shape} edge_index={edge_index_shape}"
                        )
                    except Exception:
                        pass
                    debug_samples -= 1
                continue
            try:
                fp = rdMolDescriptors.GetMorganFingerprintAsBitVect(
                    mol, self._FP_RADIUS, nBits=self._FP_BITS
                )
                fps.append(fp)
            except Exception as exc:
                failed += 1
                if not fp_error_logged:
                    try:
                        smi = Chem.MolToSmiles(mol) if mol is not None else None
                    except Exception:
                        smi = None
                    print(f"⚠️ [GRPO] [anonymized]: {exc} smiles={smi}")
                    fp_error_logged = True
                continue
        print(
            f"🔍 [GRPO] train.pt [anonymized]: total={total} fps={len(fps)} "
            f"mol_none={skipped} fp_fail={failed}"
        )
        return fps

    def _data_to_mol(self, data) -> Optional[Chem.Mol]:
        """Convert one dataset record into a sanitized RDKit molecule.

        Reads ``data.x`` (reduced with ``argmax`` over the last dimension),
        ``data.edge_index`` and, when present, ``data.edge_attr``.  Returns
        ``None`` whenever the record cannot be converted.
        """
        if data is None or not hasattr(data, "x"):
            return None
        try:
            atom_types = torch.argmax(data.x, dim=-1).long().cpu()
        except Exception:
            return None

        n_nodes = int(atom_types.numel())
        if n_nodes == 0:
            return None

        atom_dim = int(getattr(data.x, "size", lambda *_: 0)(-1))
        atom_decoder = self.atom_decoder
        if atom_dim and len(atom_decoder) != atom_dim:
            # The stored feature width disagrees with the configured decoder:
            # switch to a built-in decoder of matching size when one exists.
            if atom_dim == 9:
                atom_decoder = ["C", "N", "O", "F", "P", "S", "Cl", "Br", "I"]
            elif atom_dim == 12:
                atom_decoder = self._DEFAULT_ATOM_DECODER
            else:
                if not self._WARNED_DECODER_MISMATCH:
                    print(
                        f"⚠️ [GRPO] atom_decoder size mismatch: decoder={len(self.atom_decoder)} "
                        f"data.x={atom_dim}. Skipping molecules."
                    )
                    self._WARNED_DECODER_MISMATCH = True
                return None
            if not self._WARNED_DECODER_MISMATCH:
                print(
                    f"⚠️ [GRPO] atom_decoder size mismatch: decoder={len(self.atom_decoder)} "
                    f"data.x={atom_dim}. Using fallback decoder."
                )
                self._WARNED_DECODER_MISMATCH = True

        edge_types = torch.zeros((n_nodes, n_nodes), dtype=torch.long)
        if hasattr(data, "edge_index") and data.edge_index is not None:
            edge_index = data.edge_index
            edge_attr = getattr(data, "edge_attr", None)
            if edge_attr is not None:
                edge_vals = torch.argmax(edge_attr, dim=-1) if edge_attr.dim() > 1 else edge_attr
            else:
                edge_vals = None
            try:
                # Convert once to Python lists instead of one ``.item()`` call
                # per edge endpoint.
                sources = edge_index[0].tolist()
                targets = edge_index[1].tolist()
                bonds = edge_vals.tolist() if edge_vals is not None else [1] * len(sources)
                for src, dst, bond in zip(sources, targets, bonds):
                    edge_types[int(src), int(dst)] = int(bond)
            except (IndexError, TypeError, ValueError, RuntimeError):
                # A malformed record is skipped rather than aborting the whole
                # fingerprint load.
                return None

        try:
            mol = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder)
            if hasattr(mol, "GetMol"):
                mol = mol.GetMol()
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                return None
        except Exception as exc:
            # Own latch, so a previous decoder-mismatch warning does not hide
            # the first build failure.
            if not self._WARNED_BUILD_FAILURE:
                print(f"⚠️ [GRPO] build_molecule_with_partial_charges failed: {exc}")
                self._WARNED_BUILD_FAILURE = True
            return None
        return mol if (mol and mol.GetNumAtoms() > 0) else None

    def _graph_to_mol(self, atom_types, edge_types):
        """Convert one generated graph into a molecule via ``build_molecule``.

        ``atom_types`` may hold class indices (1-D) or per-class scores /
        one-hot rows (2-D); ``edge_types`` may hold indices (2-D) or per-class
        scores / one-hot entries (3-D).  Returns ``None`` when the builder is
        unavailable, the graph is empty, or any step raises.
        """
        if build_molecule is None:
            return None

        def _as_indices(values, scored_rank):
            # Reduce score tensors to class indices; pass index tensors through.
            tensor = torch.as_tensor(values).cpu()
            if tensor.dim() == scored_rank:
                # argmax must see the original values: casting to int64 first
                # truncates soft scores to 0 and collapses rows to class 0.
                tensor = tensor.double() if tensor.is_floating_point() else tensor.long()
                tensor = tensor.argmax(dim=-1)
            return tensor.long()

        try:
            at = _as_indices(atom_types, 2)
            et = _as_indices(edge_types, 3)
            if at.numel() == 0:
                return None
            mol = build_molecule(at, et, self.atom_decoder)
            return mol if (mol and mol.GetNumAtoms() > 0) else None
        except Exception:
            return None


class ValsartanSmartsReward(BaseRewardFunction):
    """Reward for molecules containing a fixed substructure template.

    A molecule that contains the template scores 2.0.  Otherwise the reward is
    ``(mcs_atoms / template_atoms) ** 2 + 0.1 * props_avg`` where ``mcs_atoms``
    is the size of the maximum common substructure with the template and
    ``props_avg`` is the mean of three Gaussian-modified descriptors (logP,
    TPSA, Bertz complexity).  Graphs that cannot be converted or sanitized
    score 0.
    """

    _DEFAULT_ATOM_DECODER = ["C", "N", "O", "F", "B", "Br", "Cl", "I", "P", "S", "Se", "Si"]

    def __init__(self, mode: str = "full", atom_decoder: Optional[List[str]] = None, device: Optional[torch.device] = None):
        """Store the configuration and build the template/descriptor resources.

        Args:
            mode: Stored on the instance; not consulted by the scoring code in
                this class.
            atom_decoder: Atom symbols indexed by atom class; falls back to
                ``_DEFAULT_ATOM_DECODER`` when empty or ``None``.
            device: Device of the returned reward tensor.
        """
        super().__init__("valsartan_smarts", device=device)
        self.mode = mode
        self.atom_decoder = atom_decoder or self._DEFAULT_ATOM_DECODER

        self._init_resources()

    def _init_resources(self):
        """Build the substructure template and the descriptor modifiers."""
        self.valsartan_smarts = "CN(C=O)Cc1ccc(c2ccccc2)cc1"
        self.valsartan_mol = Chem.MolFromSmiles(self.valsartan_smarts)
        if self.valsartan_mol is None:
            self.valsartan_mol = Chem.MolFromSmarts(self.valsartan_smarts)
        self.valsartan_query = Chem.MolFromSmarts(self.valsartan_smarts)

        self.valsartan_num_atoms = self.valsartan_mol.GetNumHeavyAtoms()

        # The descriptor targets are taken from this reference molecule.
        sitagliptin_smiles = "NC(CC(=O)N1CCn2c(nnc2C(F)(F)F)C1)Cc1cc(F)c(F)cc1F"
        sitagliptin_mol = Chem.MolFromSmiles(sitagliptin_smiles)

        target_logp = Descriptors.MolLogP(sitagliptin_mol)
        target_tpsa = Descriptors.TPSA(sitagliptin_mol)
        target_bertz = Descriptors.BertzCT(sitagliptin_mol)

        self.valsartan_logp_modifier = GaussianModifier(mu=target_logp, sigma=0.2)
        self.valsartan_tpsa_modifier = GaussianModifier(mu=target_tpsa, sigma=5)
        self.valsartan_bertz_modifier = GaussianModifier(mu=target_bertz, sigma=30)

    def __call__(self, graphs: List[Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
        """Return one float32 reward per ``(atom_types, edge_types)`` graph."""
        rewards = [
            self._reward_for_graph(atom_types, edge_types)
            for atom_types, edge_types in graphs
        ]
        return torch.tensor(rewards, dtype=torch.float32, device=self.device)

    @staticmethod
    def _safe_smiles(mol):
        """Return the SMILES of ``mol`` for log messages, or ``None`` on failure."""
        try:
            return Chem.MolToSmiles(mol)
        except Exception:
            return None

    def _reward_for_graph(self, atom_types, edge_types) -> float:
        """Score a single graph; every failure path yields 0.0."""
        mol = self._graph_to_mol(atom_types, edge_types)
        if mol is None:
            return 0.0

        try:
            if hasattr(mol, "GetMol"):
                mol = mol.GetMol()
            Chem.SanitizeMol(mol)

            # Score only the largest connected fragment.
            frags = Chem.GetMolFrags(mol, asMols=True)
            if frags and len(frags) > 1:
                mol = max(frags, key=lambda m: int(m.GetNumAtoms()))
        except Exception as e:
            smi = self._safe_smiles(mol)
            print(f"⚠️ [ValsartanSmartsReward] SanitizeMol failed: {type(e).__name__}: {e}; smiles={smi}")
            return 0.0

        try:
            # A template hit fixes the reward, so the (up to one second) MCS
            # search and the descriptor scoring are skipped for such molecules.
            if self._matches_template(mol):
                return 2.0

            structural_score = self._structural_score(mol)

            logp_score = self.valsartan_logp_modifier(Descriptors.MolLogP(mol))
            tpsa_score = self.valsartan_tpsa_modifier(Descriptors.TPSA(mol))
            bertz_score = self.valsartan_bertz_modifier(Descriptors.BertzCT(mol))
            props_avg = (logp_score + tpsa_score + bertz_score) / 3.0

            return float(structural_score + 0.1 * props_avg)
        except Exception as e:
            smi = self._safe_smiles(mol)
            print(f"⚠️ [ValsartanSmartsReward] Reward computation failed: {type(e).__name__}: {e}; smiles={smi}")
            return 0.0

    def _matches_template(self, mol) -> bool:
        """Return True when ``mol`` contains the template substructure."""
        try:
            if self.valsartan_query is not None:
                return len(mol.GetSubstructMatches(self.valsartan_query)) > 0
            if self.valsartan_mol is not None:
                return len(mol.GetSubstructMatches(self.valsartan_mol)) > 0
        except Exception:
            return False
        return False

    def _structural_score(self, mol) -> float:
        """Return ``(mcs_atoms / template_atoms) ** 2``, or 0.0 if the MCS search fails."""
        try:
            mcs_res = rdFMCS.FindMCS(
                [mol, self.valsartan_mol],
                bondCompare=rdFMCS.BondCompare.CompareOrder,
                atomCompare=rdFMCS.AtomCompare.CompareElements,
                matchValences=True,
                ringMatchesRingOnly=True,
                completeRingsOnly=True,
                timeout=1,
            )
            mcs_atoms = int(getattr(mcs_res, "numAtoms", 0) or 0)
            if bool(getattr(mcs_res, "canceled", False)):
                # A cancelled search still reports the best match found so far.
                print(
                    f"⚠️ [ValsartanSmartsReward] FindMCS canceled/timeout: "
                    f"mcs_atoms={mcs_atoms}, template_atoms={self.valsartan_num_atoms}"
                )

            ratio = mcs_atoms / max(1, self.valsartan_num_atoms)
            return ratio ** 2.0
        except Exception as e:
            smi = self._safe_smiles(mol)
            print(f"⚠️ [ValsartanSmartsReward] FindMCS failed: {type(e).__name__}: {e}; smiles={smi}")
            return 0.0

    def _graph_to_mol(self, atom_types, edge_types):
        """Convert one generated graph into a molecule via ``build_molecule``.

        ``atom_types`` may hold class indices (1-D) or per-class scores /
        one-hot rows (2-D); ``edge_types`` may hold indices (2-D) or per-class
        scores / one-hot entries (3-D).  Returns ``None`` when the builder is
        unavailable, the graph is empty, or any step raises.
        """
        if build_molecule is None:
            return None

        def _as_indices(values, scored_rank):
            # Reduce score tensors to class indices; pass index tensors through.
            tensor = torch.as_tensor(values).cpu()
            if tensor.dim() == scored_rank:
                # argmax must see the original values: casting to int64 first
                # truncates soft scores to 0 and collapses rows to class 0.
                tensor = tensor.double() if tensor.is_floating_point() else tensor.long()
                tensor = tensor.argmax(dim=-1)
            return tensor.long()

        try:
            at = _as_indices(atom_types, 2)
            et = _as_indices(edge_types, 3)
            if at.numel() == 0:
                return None
            mol = build_molecule(at, et, self.atom_decoder)
            return mol if (mol and mol.GetNumAtoms() > 0) else None
        except Exception:
            # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt
            # and SystemExit still propagate.
            return None


def create_reward_function(
    reward_type: str,
    cfg=None,
    device=None,
    **kwargs
) -> BaseRewardFunction:
    """Build the reward function selected by ``reward_type``.

    Explicit ``kwargs`` take precedence; several settings fall back to
    ``cfg.grpo`` / ``cfg.dataset`` or to ``model.dataset_info`` when the
    keyword is absent.  An unknown ``reward_type`` prints a warning and
    returns ``DefaultRewardFunction``.

    Args:
        reward_type: Case-insensitive reward identifier (e.g. ``"planar"``,
            ``"tdc_oracle"``, ``"gdpo_docking"``, ``"valsartan"``).
        cfg: Optional configuration object exposing ``grpo`` / ``dataset``
            sections.
        device: Device forwarded to the reward function.
        **kwargs: Optional overrides such as ``datamodule``, ``model``,
            ``ref_metrics``, ``atom_decoder``, ``target_task``,
            ``target_name``, ``tdc_*``, ``dock_*``, ``sa_threshold``,
            ``sim_threshold``, ``dataset_name``, ``datadir``, ``remove_h``.

    Returns:
        The constructed reward function.

    Raises:
        ValueError: If a TDC reward is requested without any oracle name, or
            a docking reward without a target name.
    """
    reward_type = reward_type.lower()

    datamodule = kwargs.get('datamodule')
    model = kwargs.get('model')
    ref_metrics = kwargs.get('ref_metrics')
    atom_decoder = kwargs.get('atom_decoder')
    target_node_dist = kwargs.get('target_node_dist')
    target_edge_dist = kwargs.get('target_edge_dist')
    dist_coef = kwargs.get('dist_coef', None)
    dist_scale = kwargs.get('scale_factor') or kwargs.get('dist_scale_factor')
    dist_clip = kwargs.get('clip_range') or kwargs.get('dist_clip_range')
    edge_dist_factor = kwargs.get('edge_dist_factor')
    precomputed_node_weights = kwargs.get('precomputed_node_weights')
    precomputed_edge_weights = kwargs.get('precomputed_edge_weights')
    sa_threshold = kwargs.get("sa_threshold")
    sim_threshold = kwargs.get("sim_threshold")
    dock_exhaustiveness = kwargs.get("dock_exhaustiveness")
    dock_num_modes = kwargs.get("dock_num_modes")
    dock_timeout = kwargs.get("dock_timeout")
    dataset_name = kwargs.get("dataset_name")
    datadir = kwargs.get("datadir")
    remove_h = kwargs.get("remove_h")

    tdc_oracle = kwargs.get("tdc_oracle")
    tdc_oracles = kwargs.get("tdc_oracles")
    tdc_aggregation = kwargs.get("tdc_aggregation")
    tdc_weights = kwargs.get("tdc_weights")
    tdc_minimize = kwargs.get("tdc_minimize")
    tdc_invalid_score = kwargs.get("tdc_invalid_score")
    tdc_clip_min = kwargs.get("tdc_clip_min")
    tdc_clip_max = kwargs.get("tdc_clip_max")
    tdc_home = kwargs.get("tdc_home")

    # Reference statistics may arrive bundled in ``ref_metrics``; individual
    # keyword arguments win when both are given.
    if ref_metrics is not None and isinstance(ref_metrics, dict):
        if "ref_degree_dist" not in kwargs:
            kwargs["ref_degree_dist"] = ref_metrics.get("ref_degree_dist")
        if "ref_clustering_hist" not in kwargs:
            kwargs["ref_clustering_hist"] = ref_metrics.get("ref_clustering_hist")
        if "ref_orbit_mean" not in kwargs:
            kwargs["ref_orbit_mean"] = ref_metrics.get("ref_orbit_mean")

    target_task = kwargs.get("target_task")
    if target_task is None:
        target_task = resolve_target_task(cfg)

    lead_target_name = kwargs.get("target_name")

    if dist_coef is None and cfg is not None and hasattr(cfg, "grpo"):
        try:
            dist_coef = cfg.grpo.get("dist_coef", None)
        except AttributeError:
            dist_coef = getattr(cfg.grpo, "dist_coef", None)
        if dist_coef is None:
            try:
                dist_coef = cfg.grpo.get("reward_dist_coef", None)
            except AttributeError:
                dist_coef = getattr(cfg.grpo, "reward_dist_coef", None)

    # This fallback used to sit after the dispatch below, where every branch
    # has already returned, so ``grpo.edge_dist_factor`` was never applied.
    if edge_dist_factor is None and cfg is not None and hasattr(cfg, "grpo"):
        try:
            edge_dist_factor = cfg.grpo.get("edge_dist_factor", None)
        except AttributeError:
            edge_dist_factor = getattr(cfg.grpo, "edge_dist_factor", None)

    if target_node_dist is None and model is not None and hasattr(model, 'dataset_info'):
        target_node_dist = getattr(model.dataset_info, 'node_types', None)
    if target_edge_dist is None and model is not None and hasattr(model, 'dataset_info'):
        target_edge_dist = getattr(model.dataset_info, 'edge_types', None)

    if atom_decoder is None and model is not None and hasattr(model, 'dataset_info'):
        atom_decoder = getattr(model.dataset_info, 'atom_decoder', None)

    if reward_type == "base":
        print("📊 [anonymized] ([anonymized]/[anonymized])")
        return BaseRewardFunction(device=device)

    elif reward_type == "default":
        print("📊 [anonymized]")
        return DefaultRewardFunction(device=device)

    elif reward_type in ("planar_graph", "planar"):
        return PlanarGraphReward(
            device=device,
            datamodule=datamodule,
            ref_degree_dist=kwargs.get("ref_degree_dist"),
            ref_clustering_hist=kwargs.get("ref_clustering_hist"),
            ref_orbit_mean=kwargs.get("ref_orbit_mean"),
        )

    elif reward_type in ("sbm", "sbm_graph"):
        return SBMGraphReward(
            device=device,
            datamodule=datamodule,
            ref_degree_dist=kwargs.get("ref_degree_dist"),
            ref_clustering_hist=kwargs.get("ref_clustering_hist"),
            ref_orbit_mean=kwargs.get("ref_orbit_mean"),
        )

    elif reward_type in ("tree", "tree_graph"):
        return TreeGraphReward(
            device=device,
            datamodule=datamodule,
            ref_degree_dist=kwargs.get("ref_degree_dist"),
            ref_clustering_hist=kwargs.get("ref_clustering_hist"),
            ref_orbit_mean=kwargs.get("ref_orbit_mean"),
        )

    elif reward_type in ("guacamol_mpo", "guacamol_goal", "target_mpo", "target_goal"):
        return TargetMPOReward(
            target_task=target_task if target_task else "penalized_logp",
            atom_decoder=atom_decoder,
            device=device
        )

    elif reward_type in ("tdc_oracle", "tdc_pmo", "pmo"):
        # Fill every TDC option that was not passed explicitly from cfg.grpo.
        if cfg is not None and hasattr(cfg, "grpo"):
            try:
                tdc_oracle = tdc_oracle or cfg.grpo.get("tdc_oracle", None)
                tdc_oracles = tdc_oracles or cfg.grpo.get("tdc_oracles", None)
                tdc_aggregation = tdc_aggregation or cfg.grpo.get("tdc_aggregation", None)
                tdc_weights = tdc_weights or cfg.grpo.get("tdc_weights", None)
                if tdc_minimize is None:
                    tdc_minimize = cfg.grpo.get("tdc_minimize", None)
                if tdc_invalid_score is None:
                    tdc_invalid_score = cfg.grpo.get("tdc_invalid_score", None)
                if tdc_clip_min is None:
                    tdc_clip_min = cfg.grpo.get("tdc_clip_min", None)
                if tdc_clip_max is None:
                    tdc_clip_max = cfg.grpo.get("tdc_clip_max", None)
                if tdc_home is None:
                    tdc_home = cfg.grpo.get("tdc_home", None)
            except AttributeError:
                pass

        oracle_names = tdc_oracles or tdc_oracle
        if oracle_names is None:
            raise ValueError("[anonymized] TDC reward [anonymized] grpo.tdc_oracle [anonymized] grpo.tdc_oracles")

        return TDCOracleReward(
            oracle_names=oracle_names,
            atom_decoder=atom_decoder,
            aggregation=tdc_aggregation or "mean",
            weights=tdc_weights,
            minimize=bool(tdc_minimize) if tdc_minimize is not None else False,
            invalid_score=float(tdc_invalid_score) if tdc_invalid_score is not None else 0.0,
            clip_min=tdc_clip_min,
            clip_max=tdc_clip_max,
            tdc_home=tdc_home,
            device=device,
        )

    elif reward_type in ("gdpo_docking", "gdpo"):
        if cfg is not None and hasattr(cfg, "grpo"):
            try:
                lead_target_name = lead_target_name or cfg.grpo.get("target_name", None)
            except AttributeError:
                try:
                    lead_target_name = lead_target_name or getattr(cfg.grpo, "target_name", None)
                except Exception:
                    pass
            if sim_threshold is None:
                # The training-time override wins over the evaluation one.
                sim_override = None
                try:
                    sim_override = cfg.grpo.get("gdpo_sim_threshold", None)
                except Exception:
                    sim_override = getattr(cfg.grpo, "gdpo_sim_threshold", None)
                if sim_override is None:
                    try:
                        sim_override = cfg.grpo.get("gdpo_eval_sim_threshold", None)
                    except Exception:
                        sim_override = getattr(cfg.grpo, "gdpo_eval_sim_threshold", None)
                dataset_name = dataset_name or getattr(getattr(cfg, "dataset", None), "name", None)
                sim_threshold = gdpo_get_sim_threshold(dataset_name or "", override=sim_override)

        if dataset_name is None and cfg is not None:
            dataset_name = getattr(getattr(cfg, "dataset", None), "name", None)
        if datadir is None and cfg is not None:
            datadir = getattr(getattr(cfg, "dataset", None), "datadir", None)
        if remove_h is None and cfg is not None:
            remove_h = getattr(getattr(cfg, "dataset", None), "remove_h", None)

        if lead_target_name is None:
            raise ValueError("GDPODockingReward requires grpo.target_name or target_name in kwargs.")

        return GDPODockingReward(
            target_name=str(lead_target_name),
            atom_decoder=atom_decoder,
            device=device,
            sa_threshold=sa_threshold,
            sim_threshold=sim_threshold,
            dock_exhaustiveness=dock_exhaustiveness,
            dock_num_modes=dock_num_modes,
            dock_timeout=dock_timeout,
            dataset_name=dataset_name,
            datadir=datadir,
            remove_h=remove_h,
        )

    elif reward_type in ("molecular_validity", "guacamol_reward", "gracamol_reward", "gracamol"):
        return MolecularValidityReward(
            atom_decoder=atom_decoder,
            device=device,
            target_node_dist=target_node_dist,
            target_edge_dist=target_edge_dist,
            dist_coef=dist_coef if dist_coef is not None else 0.1,
            scale_factor=dist_scale if dist_scale is not None else 10.0,
            clip_range=dist_clip if dist_clip is not None else 2.0,
            edge_dist_factor=edge_dist_factor if edge_dist_factor is not None else 1.0,
            precomputed_node_weights=precomputed_node_weights,
            precomputed_edge_weights=precomputed_edge_weights,
        )

    elif reward_type in ("valsartan_smarts", "valsartan_smarts_easy", "valsartan"):
        return ValsartanSmartsReward(
            mode="easy",
            atom_decoder=atom_decoder,
            device=device
        )

    else:
        print(f"⚠️  [anonymized]: {reward_type}，[anonymized]")
        return DefaultRewardFunction(device=device)
