import warnings
# Install warning filters before the heavier third-party imports below run.
# The first two filters match specific messages by regular expression.
warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*")
warnings.filterwarnings("ignore", message=".*Boto3 will no longer support Python 3.9.*")
# NOTE: the next two filters hide *every* FutureWarning and UserWarning for the
# whole process, not only those raised by third-party libraries.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", message=".*PossibleUserWarning.*")

import sys
import os
# Put the parent of this file's directory at the front of the module search
# path (presumably so project packages resolve when run as a script - confirm).
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from datetime import datetime
import time
import torch
import hydra
from hydra.core.hydra_config import HydraConfig
import numpy as np
import scipy

# Compatibility shims: each alias is installed only when the attribute is
# missing, so environments that still provide it are left untouched.
if not hasattr(scipy, 'errstate'):
    # Fall back to NumPy's context manager under the scipy name.
    scipy.errstate = np.errstate
from utils import ensure_legacy_aliases

# Project helper; judging by its name it installs further legacy aliases
# (its behavior is defined in utils, not visible here).
ensure_legacy_aliases()
# Restore the legacy scalar aliases np.float_ / np.complex_ when the installed
# NumPy no longer ships them (they are absent in NumPy 2.x).
if not hasattr(np, "float_"):
    np.float_ = np.float64
if not hasattr(np, "complex_"):
    np.complex_ = np.complex128
import random
from omegaconf import DictConfig, OmegaConf

# Default PL_DISABLE_WANDB to "1" unless the caller already set it. This is done
# before importing pytorch_lightning - presumably the variable is read at import
# time; TODO confirm.
os.environ.setdefault("PL_DISABLE_WANDB", "1")
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, Callback
from pytorch_lightning.utilities.warnings import PossibleUserWarning

# SwanLab is optional: when it is not installed the name is bound to None and
# every use below is guarded by an ``is not None`` check.
try:
    import swanlab
except ImportError:
    swanlab = None

from grpo_lightning_module import (
    GRPOLightningModule,
    FlowGRPODataModule,
    create_grpo_lightning_module
)

from graph_discrete_flow_model import GraphDiscreteFlowModel
from grpo_rewards import resolve_target_task

# Also silence Lightning's PossibleUserWarning by category, now that the class
# has been imported.
warnings.filterwarnings("ignore", category=PossibleUserWarning)

# Process-wide setting for float32 matrix multiplications; see the
# torch.set_float32_matmul_precision documentation for what 'medium' permits.
torch.set_float32_matmul_precision('medium')


def _patch_torch_load_weights_only_default():
    """Make ``torch.load`` default to ``weights_only=False`` process-wide.

    The installed wrapper only fills in the keyword when the caller did not
    pass it, so explicit choices are respected. A marker attribute on the
    wrapper makes repeated calls a no-op instead of stacking wrappers.
    """
    import functools as _functools
    import torch as _torch

    current_load = _torch.load
    if getattr(current_load, "_flow_grpo_patched", False):
        # An earlier call already installed the wrapper.
        return

    @_functools.wraps(current_load)
    def load_with_legacy_default(*args, **kwargs):
        # Only the implicit default changes; an explicit keyword wins.
        kwargs.setdefault("weights_only", False)
        return current_load(*args, **kwargs)

    load_with_legacy_default._flow_grpo_patched = True
    _torch.load = load_with_legacy_default

def _should_strict_resume(cfg: DictConfig) -> bool:
    """Tell whether a resume checkpoint must be validated before training.

    Reads ``cfg.grpo.strict_resume`` (default ``True``) and coerces it to
    ``bool``. Any error while reading the config is treated as "strict".
    """
    try:
        configured = cfg.grpo.get("strict_resume", True)
        strict = bool(configured)
    except Exception:
        # A config without a usable ``grpo`` section falls back to strict mode.
        strict = True
    return strict

def _validate_resume_checkpoint(flow_grpo_module: pl.LightningModule, ckpt_path: str, cfg: DictConfig) -> None:
    """Fail fast when a resume checkpoint does not match the current module.

    The checkpoint's state dict is compared with ``flow_grpo_module.state_dict()``
    by key set and by tensor shape. A fixed set of keys (see ``compat_missing``
    and the ``model.sampling_metrics.`` prefix) is tolerated when missing or
    shape-mismatched. Unexpected keys are never tolerated.

    Args:
        flow_grpo_module: Module whose ``state_dict()`` defines the expected
            keys and shapes.
        ckpt_path: Path of the checkpoint file to inspect (loaded on CPU).
        cfg: Run configuration. Accepted for call-site symmetry; not read here.

    Raises:
        ValueError: If the checkpoint payload is not a state-dict mapping, or if
            it has missing, unexpected or shape-mismatched keys.
    """
    try:
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except TypeError:
        # torch.load rejected the ``weights_only`` keyword; retry without it.
        checkpoint = torch.load(ckpt_path, map_location="cpu")

    state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint
    if not hasattr(state_dict, "keys"):
        # Without this guard a list/tensor payload died later with an opaque
        # AttributeError on ``.keys()``.
        raise ValueError(
            "Resume checkpoint is incompatible with current model. "
            f"Expected a state-dict mapping but found {type(state_dict).__name__}."
        )

    if isinstance(state_dict, dict):
        # Add a "model."-prefixed copy of every key that lacks the prefix; the
        # original key is kept, and an existing prefixed key is never replaced.
        normalized_state_dict = dict(state_dict)
        for key, value in state_dict.items():
            if not key.startswith("model."):
                normalized_state_dict.setdefault(f"model.{key}", value)
        state_dict = normalized_state_dict

    model_state = flow_grpo_module.state_dict()
    model_keys = set(model_state.keys())
    ckpt_keys = set(state_dict.keys())

    # Keys that may be absent from (or differently shaped in) the checkpoint.
    compat_missing = {
        "model.p0_node_dist",
        "model.p0_edge_dist",
        "model.node_count_prob",
        "model.node_count_buffer_rewards",
        "model.node_count_buffer_nodes",
        "model.node_count_buffer_filled",
    }

    def _is_compat_key(key: str) -> bool:
        return key in compat_missing or key.startswith("model.sampling_metrics.")

    missing = [key for key in sorted(model_keys - ckpt_keys) if not _is_compat_key(key)]
    unexpected = sorted(ckpt_keys - model_keys)

    shape_mismatch = []
    for key in sorted(model_keys & ckpt_keys):
        if _is_compat_key(key):
            continue
        v_ckpt = state_dict.get(key)
        v_model = model_state.get(key)
        # Only entries that expose ``.shape`` on both sides are compared.
        if hasattr(v_ckpt, "shape") and hasattr(v_model, "shape"):
            if tuple(v_ckpt.shape) != tuple(v_model.shape):
                shape_mismatch.append(key)

    if missing or unexpected or shape_mismatch:
        raise ValueError(
            "Resume checkpoint is incompatible with current model. "
            f"missing={len(missing)}, unexpected={len(unexpected)}, shape_mismatch={len(shape_mismatch)}. "
            f"Example missing={missing[:5]}, unexpected={unexpected[:5]}, shape_mismatch={shape_mismatch[:5]}"
        )

    if "grpo_trainer_state" not in checkpoint:
        print(
            "Resume checkpoint missing 'grpo_trainer_state'; GRPO trainer buffers will be reinitialized."
        )

def _pack_components(dataset_infos, train_metrics, sampling_metrics,
                     visualization_tools, extra_features, domain_features):
    """Bundle the per-dataset objects in the key order used for ``model_kwargs``."""
    return {
        "dataset_infos": dataset_infos,
        "train_metrics": train_metrics,
        "sampling_metrics": sampling_metrics,
        "visualization_tools": visualization_tools,
        "extra_features": extra_features,
        "domain_features": domain_features,
    }


def _assemble_graph_components(cfg, datamodule, dataset_infos, sampling_metrics):
    """Build the metrics/visualization/feature objects shared by the non-molecular graph datasets."""
    from analysis.visualization import NonMolecularVisualization
    from metrics.abstract_metrics import TrainAbstractMetricsDiscrete
    from models.extra_features import DummyExtraFeatures, ExtraFeatures

    train_metrics = TrainAbstractMetricsDiscrete()
    visualization_tools = NonMolecularVisualization(dataset_name=cfg.dataset.name)
    extra_features = ExtraFeatures(
        cfg.model.extra_features,
        cfg.model.rrwp_steps,
        dataset_info=dataset_infos,
    )
    domain_features = DummyExtraFeatures()
    dataset_infos.compute_input_output_dims(
        datamodule=datamodule,
        extra_features=extra_features,
        domain_features=domain_features,
    )
    return datamodule, _pack_components(
        dataset_infos, train_metrics, sampling_metrics,
        visualization_tools, extra_features, domain_features,
    )


def _build_spectre_components(cfg, dataset_config):
    """Assemble components for the sbm / comm20 / planar / tree datasets."""
    from datasets.spectre_dataset import (
        SpectreGraphDataModule,
        SpectreDatasetInfos,
    )
    from analysis.spectre_utils import (
        PlanarSamplingMetrics,
        SBMSamplingMetrics,
        Comm20SamplingMetrics,
        TreeSamplingMetrics,
    )

    metrics_by_name = {
        "sbm": SBMSamplingMetrics,
        "comm20": Comm20SamplingMetrics,
        "planar": PlanarSamplingMetrics,
        "tree": TreeSamplingMetrics,
    }
    datamodule = SpectreGraphDataModule(cfg)
    sampling_metrics = metrics_by_name[dataset_config["name"]](datamodule)
    dataset_infos = SpectreDatasetInfos(datamodule, dataset_config)
    return _assemble_graph_components(cfg, datamodule, dataset_infos, sampling_metrics)


def _build_my_tree_components(cfg, dataset_config):
    """Assemble components for the ``my_tree`` dataset."""
    from datasets.my_tree_dataset import (
        MyTreeGraphDataModule,
        MyTreeDatasetInfos,
    )
    from analysis.spectre_utils import TreeSamplingMetrics

    datamodule = MyTreeGraphDataModule(cfg)
    sampling_metrics = TreeSamplingMetrics(datamodule)
    dataset_infos = MyTreeDatasetInfos(datamodule, dataset_config)
    return _assemble_graph_components(cfg, datamodule, dataset_infos, sampling_metrics)


def _build_my_planar_components(cfg, dataset_config):
    """Assemble components for the ``my_planar`` dataset."""
    from datasets.my_planar_dataset import (
        MyPlanarGraphDataModule,
        MyPlanarDatasetInfos,
    )
    from analysis.spectre_utils import PlanarSamplingMetrics

    datamodule = MyPlanarGraphDataModule(cfg)
    sampling_metrics = PlanarSamplingMetrics(datamodule)
    dataset_infos = MyPlanarDatasetInfos(datamodule, dataset_config)
    return _assemble_graph_components(cfg, datamodule, dataset_infos, sampling_metrics)


def _build_tls_components(cfg):
    """Assemble components for the ``tls`` dataset (extra features are optional there)."""
    from datasets import tls_dataset
    from metrics.tls_metrics import TLSSamplingMetrics
    from analysis.visualization import NonMolecularVisualization
    from metrics.abstract_metrics import TrainAbstractMetricsDiscrete
    from models.extra_features import DummyExtraFeatures, ExtraFeatures

    datamodule = tls_dataset.TLSDataModule(cfg)
    dataset_infos = tls_dataset.TLSInfos(datamodule=datamodule)

    train_metrics = TrainAbstractMetricsDiscrete()
    if cfg.model.extra_features is not None:
        extra_features = ExtraFeatures(
            cfg.model.extra_features,
            cfg.model.rrwp_steps,
            dataset_info=dataset_infos,
        )
    else:
        extra_features = DummyExtraFeatures()
    domain_features = DummyExtraFeatures()

    sampling_metrics = TLSSamplingMetrics(datamodule)
    visualization_tools = NonMolecularVisualization(dataset_name=cfg.dataset.name)

    dataset_infos.compute_input_output_dims(
        datamodule=datamodule,
        extra_features=extra_features,
        domain_features=domain_features,
    )
    return datamodule, _pack_components(
        dataset_infos, train_metrics, sampling_metrics,
        visualization_tools, extra_features, domain_features,
    )


def _build_mock_zinc_datamodule(cfg):
    """Create a ZINC datamodule whose three splits each hold one hard-coded two-node graph."""
    import pathlib

    from torch_geometric.data import Data
    from torch_geometric.data import InMemoryDataset
    from datasets.abstract_dataset import MolecularDataModule

    class _MockZINCDataset(InMemoryDataset):
        def __init__(self, stage: str, root: str):
            self.stage = stage
            super().__init__(root)
            self.data, self.slices = None, None

        def len(self):
            return 1

        def get(self, idx):
            # Same graph for every index: 2 nodes (9 node channels, channel 0
            # set), one edge stored in both directions (4 edge channels,
            # channel 1 set), and an empty target.
            x = torch.zeros((2, 9), dtype=torch.float)
            x[:, 0] = 1.0
            edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
            edge_attr = torch.zeros((2, 4), dtype=torch.float)
            edge_attr[:, 1] = 1.0
            y = torch.zeros((1, 0), dtype=torch.float)
            return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=0)

        @property
        def raw_file_names(self):
            return []

        @property
        def processed_file_names(self):
            return []

        def download(self):
            pass

        def process(self):
            pass

    base_path = pathlib.Path(os.path.realpath(__file__)).parents[1]
    root_path = os.path.join(base_path, getattr(cfg.dataset, "datadir", "data/zinc/"))
    split_datasets = {
        "train": _MockZINCDataset("train", root=root_path),
        "val": _MockZINCDataset("val", root=root_path),
        "test": _MockZINCDataset("test", root=root_path),
    }
    return MolecularDataModule(cfg, split_datasets)


def _load_molecular_dataset(cfg, dataset_config):
    """Return ``(datamodule, dataset_infos, dataset_smiles)`` for the selected molecular dataset."""
    name = dataset_config["name"]

    if "qm9" in name:
        from datasets import qm9_dataset

        datamodule = qm9_dataset.QM9DataModule(cfg)
        dataset_infos = qm9_dataset.QM9infos(datamodule=datamodule, cfg=cfg)
        dataset_smiles = qm9_dataset.get_smiles(
            cfg=cfg,
            datamodule=datamodule,
            dataset_infos=dataset_infos,
            evaluate_datasets=False,
        )
    elif name in ["guacamol", "guacamol_mpo"]:
        split = getattr(cfg.dataset, "split", "ood")
        if split == "ood":
            from src.datasets import guacamol_dataset_ood as guacamol_dataset
        else:
            from src.datasets import guacamol_dataset as guacamol_dataset

        datamodule = guacamol_dataset.GuacamolDataModule(cfg)
        dataset_infos = guacamol_dataset.Guacamolinfos(datamodule, cfg)

        if dataset_config.get("empty", False) or name == "guacamol_mpo":
            # No reference SMILES are loaded for empty / MPO runs.
            dataset_smiles = {"train": [], "val": [], "test": []}
        else:
            dataset_smiles = guacamol_dataset.get_smiles(
                raw_dir=datamodule.train_dataset.raw_dir,
                filter_dataset=cfg.dataset.filter,
            )
    elif name == "moses":
        from datasets import moses_dataset

        datamodule = moses_dataset.MosesDataModule(cfg)
        dataset_infos = moses_dataset.MOSESinfos(datamodule, cfg)
        dataset_smiles = moses_dataset.get_smiles(
            raw_dir=datamodule.train_dataset.raw_dir,
            filter_dataset=cfg.dataset.filter,
        )
    elif "zinc" in name:
        from datasets import zinc_dataset

        use_empty = bool(getattr(cfg.dataset, "empty", False))
        if use_empty:
            datamodule = _build_mock_zinc_datamodule(cfg)
        else:
            datamodule = zinc_dataset.ZINCDataModule(cfg)
        dataset_infos = zinc_dataset.ZINCinfos(datamodule=datamodule, cfg=cfg)
        if use_empty:
            dataset_smiles = {"train": [], "val": [], "test": []}
        else:
            dataset_smiles = zinc_dataset.get_smiles(
                cfg=cfg,
                datamodule=datamodule,
                dataset_infos=dataset_infos,
                evaluate_datasets=False,
            )
    else:
        raise ValueError("Dataset not implemented")

    return datamodule, dataset_infos, dataset_smiles


def _build_molecular_components(cfg, dataset_config):
    """Assemble components for qm9 / guacamol / guacamol_mpo / moses / zinc-like datasets."""
    from metrics.molecular_metrics import SamplingMolecularMetrics
    from metrics.molecular_metrics_discrete import TrainMolecularMetricsDiscrete
    from models.extra_features import ExtraFeatures
    from models.extra_features_molecular import ExtraMolecularFeatures
    from analysis.visualization import MolecularVisualization

    datamodule, dataset_infos, dataset_smiles = _load_molecular_dataset(cfg, dataset_config)

    extra_features = ExtraFeatures(
        cfg.model.extra_features,
        cfg.model.rrwp_steps,
        dataset_info=dataset_infos,
    )
    domain_features = ExtraMolecularFeatures(dataset_infos=dataset_infos)

    dataset_infos.compute_input_output_dims(
        datamodule=datamodule,
        extra_features=extra_features,
        domain_features=domain_features,
    )

    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
    add_virtual_states = "absorbing" == cfg.model.transition
    sampling_metrics = SamplingMolecularMetrics(
        dataset_infos, dataset_smiles, cfg, add_virtual_states=add_virtual_states
    )
    visualization_tools = MolecularVisualization(
        cfg.dataset.remove_h, dataset_infos=dataset_infos
    )
    return datamodule, _pack_components(
        dataset_infos, train_metrics, sampling_metrics,
        visualization_tools, extra_features, domain_features,
    )


def create_datamodule_and_model_components(cfg: DictConfig):
    """Build the datamodule and the keyword arguments used to construct the model.

    The dataset family is chosen from ``cfg["dataset"]["name"]``. Project modules
    for that family are imported lazily inside the matching helper, so only the
    selected dataset's dependencies are imported.

    Args:
        cfg: Run configuration. ``cfg.dataset`` and ``cfg.model`` are read for
            every dataset, ``cfg.general.conditional`` when assembling the
            result.

    Returns:
        ``(datamodule, model_kwargs)`` where ``model_kwargs`` holds
        ``dataset_infos``, ``train_metrics``, ``sampling_metrics``,
        ``visualization_tools``, ``extra_features``, ``domain_features`` and
        ``test_labels`` (``datamodule.test_labels`` for conditional qm9 runs,
        otherwise ``None``).

    Raises:
        NotImplementedError: If the dataset name is not one of the supported
            families.
    """
    dataset_config = cfg["dataset"]
    name = dataset_config["name"]

    if name in ["sbm", "comm20", "planar", "tree"]:
        datamodule, components = _build_spectre_components(cfg, dataset_config)
    elif name in ["qm9", "guacamol", "guacamol_mpo", "moses"] or "zinc" in name:
        datamodule, components = _build_molecular_components(cfg, dataset_config)
    elif name == "tls":
        datamodule, components = _build_tls_components(cfg)
    elif name == "my_tree":
        datamodule, components = _build_my_tree_components(cfg, dataset_config)
    elif name == "my_planar":
        datamodule, components = _build_my_planar_components(cfg, dataset_config)
    else:
        raise NotImplementedError("Unknown dataset {}".format(cfg["dataset"]))

    dataset_infos = components["dataset_infos"]
    if not bool(getattr(cfg.dataset, "empty", False)):
        dataset_infos.compute_reference_metrics(
            datamodule=datamodule,
            sampling_metrics=components["sampling_metrics"],
        )
    else:
        # "empty" runs skip the reference-metric computation and store
        # placeholders instead.
        dataset_infos.ref_metrics = {"val": None, "test": None}

    model_kwargs = dict(components)
    model_kwargs["test_labels"] = (
        datamodule.test_labels
        if ("qm9" in cfg.dataset.name and cfg.general.conditional)
        else None
    )

    return datamodule, model_kwargs

def _initialize_distributed_environment_for_checkpoint_loading():
    """Ensure a ``torch.distributed`` process group exists.

    Nothing happens when a group is already initialized (or when this torch
    build has no distributed support). When ``WORLD_SIZE`` is present in the
    environment, the group is created with the NCCL backend and retried with
    Gloo if that fails. Otherwise ``MASTER_ADDR``/``MASTER_PORT``/``RANK``/
    ``WORLD_SIZE`` are set for a single local process and a Gloo group is
    created. A failure on that path is reported but does not abort the run.

    Raises:
        KeyError, ValueError: If ``WORLD_SIZE`` is set but ``RANK`` is missing
            or either variable is not an integer.
    """
    if not torch.distributed.is_available():
        print("⚠️ torch.distributed is not available; skipping process-group initialization.")
        return
    if torch.distributed.is_initialized():
        return

    if "WORLD_SIZE" in os.environ:
        # Validate the launcher-provided variables up front so a bad
        # environment fails with a clear KeyError/ValueError.
        int(os.environ["RANK"])
        int(os.environ["WORLD_SIZE"])
        try:
            torch.distributed.init_process_group(backend="nccl")
        except Exception as nccl_error:
            print(f"⚠️ NCCL process-group initialization failed ({nccl_error}); retrying with gloo.")
            torch.distributed.init_process_group(backend="gloo")
        return

    # No launcher environment: describe a single local process.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    try:
        torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1)
    except Exception as gloo_error:
        # Best effort, as before - but say so instead of failing silently.
        print(f"⚠️ Single-process gloo initialization failed and was skipped: {gloo_error}")

def _load_pretrained_weights_into_grpo_module(grpo_module, pretrained_path: str):
    """Load pretrained weights into ``grpo_module`` and make all parameters trainable.

    Checkpoint keys are used as-is when any of them already carries a GRPO-style
    prefix (``model.model.``, ``model.p0_`` or ``model.node_count_``); otherwise
    every key is prefixed with ``model.``. Keys under ``model.sampling_metrics.``
    are dropped before loading. Loading is non-strict, so mismatched keys do not
    raise - they are summarized on stdout instead.

    Args:
        grpo_module: Module receiving the weights (``load_state_dict`` and
            ``parameters`` are used).
        pretrained_path: Path of the checkpoint file (loaded on CPU).

    Raises:
        ValueError: If the checkpoint does not contain a state-dict ``dict``.
    """
    checkpoint = torch.load(pretrained_path, map_location='cpu', weights_only=False)

    # A non-dict checkpoint has no ``.get``; route it to the ValueError below
    # rather than crashing with AttributeError.
    raw_state = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint
    if not isinstance(raw_state, dict):
        raise ValueError("Invalid checkpoint format: missing state_dict dict")

    grpo_prefixes = ("model.model.", "model.p0_", "model.node_count_")
    already_grpo_prefixed = any(k.startswith(grpo_prefixes) for k in raw_state)

    if already_grpo_prefixed:
        remapped_state_dict = dict(raw_state)
        print("✅ [anonymized]GRPO/[anonymized]checkpoint[anonymized]（[anonymized]'model.'[anonymized]），[anonymized]。")
    else:
        remapped_state_dict = {f"model.{k}": v for k, v in raw_state.items()}

    keys_to_drop = [k for k in remapped_state_dict if k.startswith("model.sampling_metrics.")]
    if keys_to_drop:
        print("⚠️ [anonymized]（[anonymized]）:")
        for k in keys_to_drop:
            print(f"   - {k}")
            remapped_state_dict.pop(k, None)

    load_result = grpo_module.load_state_dict(remapped_state_dict, strict=False)
    # strict=False never raises on key mismatches, so report them explicitly;
    # otherwise a checkpoint that matches nothing would load silently.
    missing_keys = list(getattr(load_result, "missing_keys", ()))
    unexpected_keys = list(getattr(load_result, "unexpected_keys", ()))
    if missing_keys or unexpected_keys:
        print(
            "⚠️ Pretrained checkpoint only partially matches the module: "
            f"missing={len(missing_keys)}, unexpected={len(unexpected_keys)}. "
            f"Example missing={missing_keys[:5]}, unexpected={unexpected_keys[:5]}"
        )

    for param in grpo_module.parameters():
        param.requires_grad = True
    print("✅ [anonymized]requires_grad=True")


def _lookup_atom_decoder(datamodule, model_kwargs):
    """Return the atom decoder from the datamodule's dataset infos, else from ``model_kwargs``."""
    dataset_infos = getattr(datamodule, "dataset_infos", None)
    if dataset_infos is None:
        dataset_infos = model_kwargs["dataset_infos"]
    return dataset_infos.atom_decoder


def _graphs_to_smiles(graph_list, atom_decoder):
    """Convert ``(atom_types, edge_types)`` pairs to SMILES; callers treat falsy entries as invalid."""
    from analysis.rdkit_functions import mol2smiles, build_molecule

    smiles = []
    for atom_types, edge_types in graph_list:
        if isinstance(atom_types, torch.Tensor):
            atom_types = atom_types.cpu()
        if isinstance(edge_types, torch.Tensor):
            edge_types = edge_types.cpu()
        mol = build_molecule(atom_types, edge_types, atom_decoder)
        smiles.append(mol2smiles(mol))
    return smiles


def _goal_directed_task_tag(cfg):
    """Derive a file-name-safe task label (at most 80 characters) for the best-molecules file."""
    task = resolve_target_task(cfg, default=None)
    if not task:
        tdc_tag = OmegaConf.select(cfg, "grpo.tdc_oracle") or OmegaConf.select(cfg, "grpo.tdc_oracles")
        if tdc_tag is not None:
            if isinstance(tdc_tag, str):
                # A single oracle name must stay intact; iterating it would
                # join its characters ("qed" -> "q_e_d").
                task = tdc_tag
            elif isinstance(tdc_tag, (list, tuple)):
                task = "_".join(str(x) for x in tdc_tag)
            else:
                try:
                    task = "_".join(str(x) for x in list(tdc_tag))
                except Exception:
                    task = str(tdc_tag)
    if not task:
        task = "goal_directed"
    task = str(task)
    for ch in [" ", "/", "\\", ":", ";", ","]:
        task = task.replace(ch, "_")
    return task[:80]


def _run_flow_grpo_test_only(cfg: DictConfig):
    """Sample graphs from a checkpoint and evaluate them without training.

    The checkpoint at ``cfg.general.test_only`` is loaded into a freshly built
    GRPO module, ``final_model_samples_to_generate * num_sample_fold`` graphs are
    sampled, and rewards are computed when the module has a reward function.
    For goal-directed rewards (reward type containing "guacamol", "tdc" or
    "pmo") only validity/uniqueness and top-k reward statistics are reported and
    the 100 best molecules are written to ``best_molecules_<task>.txt``;
    otherwise the model's ``evaluate_samples`` is used. All results are written
    to ``test_epoch<epoch>_res_<eta>_<rdb>.txt`` in the current directory.

    Args:
        cfg: Run configuration (``general``, ``grpo``, ``train`` and ``sample``
            sections are read).

    Returns:
        Dict of logged metrics.

    Raises:
        ValueError: If ``general.test_only`` is empty.
        FileNotFoundError: If the checkpoint path does not exist.
    """
    # The seed is deliberately time-based, so repeated runs sample differently.
    random_seed = int(time.time()) % (2**31)
    pl.seed_everything(random_seed)
    print(f"test_only [anonymized]: {random_seed}")

    ckpt_path = cfg.general.get("test_only")
    if not ckpt_path:
        raise ValueError("general.test_only [anonymized] ckpt [anonymized]，[anonymized] test_only。")
    ckpt_path = os.path.expanduser(ckpt_path)
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"[anonymized] checkpoint: {ckpt_path}")

    datamodule, model_kwargs = create_datamodule_and_model_components(cfg)
    try:
        datamodule.setup(stage="fit")
    except Exception as setup_error:
        # Best effort, as before - but make the failure visible.
        print(f"⚠️ datamodule.setup(stage='fit') failed and was skipped: {setup_error}")
    flow_grpo_module = create_grpo_lightning_module(
        cfg=cfg,
        model_kwargs=model_kwargs,
        datamodule=datamodule,
        total_steps=cfg.grpo.total_steps,
    )

    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)
    has_double_model_prefix = any(k.startswith("model.model.") for k in state_dict.keys())
    if has_double_model_prefix:
        flow_grpo_module.load_state_dict(state_dict, strict=False)
        print(f"📥 [anonymized] Flow-GRPO/[anonymized] checkpoint（[anonymized]'model.model.'[anonymized]）: {ckpt_path}")
    else:
        print(f"📥 [anonymized]GraphDiscreteFlowModel[anonymized]checkpoint（[anonymized]'model.model.'[anonymized]），[anonymized]flow[anonymized]: {ckpt_path}")
        _load_pretrained_weights_into_grpo_module(flow_grpo_module, ckpt_path)

    use_gpu = cfg.general.gpus > 0 and torch.cuda.is_available()
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    flow_grpo_module.to(device)
    flow_grpo_module.eval()

    # No Trainer is involved here, so the fit-start hook is invoked by hand.
    flow_grpo_module.on_fit_start()

    total_to_generate = (
        cfg.general.final_model_samples_to_generate * cfg.general.num_sample_fold
    )
    samples_left = total_to_generate
    batch_size = 2 * cfg.train.batch_size
    total_steps = flow_grpo_module._get_forward_steps()
    print(f"🧪 test_only [anonymized]: {total_steps}")

    all_graphs = []

    while samples_left > 0:
        cur_bs = min(samples_left, batch_size)
        graphs, node_mask, *_ = flow_grpo_module.grpo_trainer.sample_graphs_with_trajectory_tracking(
            batch_size=cur_bs,
            seed=None,
            total_inference_steps=total_steps,
            force_same_start=False,
            group_size_for_same_start=None,
        )
        batch_graphs = flow_grpo_module.grpo_trainer._convert_placeholder_to_graph_list_cpu(
            graphs, node_mask, as_tensor=True
        )
        all_graphs.extend(batch_graphs)
        samples_left -= cur_bs

    graph_list = all_graphs[:total_to_generate]

    reward_type_str = str(cfg.grpo.reward_type).lower()
    is_goal_directed_reward = any(k in reward_type_str for k in ("guacamol", "tdc", "pmo"))
    reward_stats = {}
    # SMILES for ``graph_list``; converted at most once and shared by the top-k
    # analysis and the validity/uniqueness section.
    smiles_list = None
    if getattr(flow_grpo_module, "reward_function", None) is not None:
        try:
            rewards = flow_grpo_module.grpo_trainer._compute_rewards_multiprocess_sync(
                graph_list,
                timeout=getattr(flow_grpo_module.grpo_trainer, "eval_timeout_seconds", 1800),
                context="test_only",
            )
            if rewards.numel() > 0:
                rewards_cpu = rewards.detach().cpu()
                reward_stats = {
                    "grpo/reward_mean": float(rewards_cpu.mean().item()),
                    "grpo/reward_std": float(rewards_cpu.std().item()) if rewards_cpu.numel() > 1 else 0.0,
                    "grpo/reward_min": float(rewards_cpu.min().item()),
                    "grpo/reward_max": float(rewards_cpu.max().item()),
                }
                print("Test-only GRPO[anonymized]:")
                for k, v in reward_stats.items():
                    print(f"   {k}: {v}")

                if is_goal_directed_reward:
                    print("🏆 Performing goal-directed Top-K Analysis...")
                    smiles_list = _graphs_to_smiles(
                        graph_list, _lookup_atom_decoder(datamodule, model_kwargs)
                    )

                    # Keep only molecules with a SMILES string, best reward first.
                    scored_mols = [
                        (s, r) for s, r in zip(smiles_list, rewards_cpu.tolist()) if s
                    ]
                    scored_mols.sort(key=lambda x: x[1], reverse=True)

                    top_k_stats = {}
                    if len(scored_mols) > 0:
                        top_k_stats["grpo/top1_score"] = scored_mols[0][1]
                        top_k_stats["grpo/top10_mean"] = np.mean([x[1] for x in scored_mols[:10]])
                        top_k_stats["grpo/top100_mean"] = np.mean([x[1] for x in scored_mols[:100]])

                        print("   Top-1 Score:", top_k_stats["grpo/top1_score"])
                        print("   Top-10 Mean:", top_k_stats["grpo/top10_mean"])
                        print("   Top-100 Mean:", top_k_stats["grpo/top100_mean"])

                        reward_stats.update(top_k_stats)

                        task_for_filename = _goal_directed_task_tag(cfg)
                        best_mols_file = os.path.join(os.getcwd(), f"best_molecules_{task_for_filename}.txt")
                        with open(best_mols_file, "w", encoding="utf-8") as f:
                            f.write("Rank\tScore\tSMILES\n")
                            for i, (s, r) in enumerate(scored_mols[:100]):
                                f.write(f"{i+1}\t{r:.4f}\t{s}\n")
                        print(f"💾 Top 100 molecules saved to: {best_mols_file}")

            else:
                print("⚠️ Test-only GRPO[anonymized]")
        except Exception as e:
            # Reward evaluation is best effort; the remaining metrics still run.
            print(f"⚠️ Test-only GRPO[anonymized]: {e}")
    else:
        print("⚠️ [anonymized]GRPO[anonymized]，[anonymized]reward[anonymized]")

    to_log = {}

    if is_goal_directed_reward:
        print("🏆 Goal-directed Mode: Skipping full distribution metrics, computing VUN + Score only.")

        if smiles_list is None:
            smiles_list = _graphs_to_smiles(
                graph_list, _lookup_atom_decoder(datamodule, model_kwargs)
            )
        valid_smiles = [smi for smi in smiles_list if smi]

        validity = len(valid_smiles) / len(graph_list) if len(graph_list) > 0 else 0
        unique_smiles = set(valid_smiles)
        uniqueness = len(unique_smiles) / len(valid_smiles) if len(valid_smiles) > 0 else 0
        # NOTE: novelty is a fixed placeholder; it is not computed against any
        # reference set here.
        novelty = 1.0

        to_log["Validity"] = validity
        to_log["Uniqueness"] = uniqueness
        to_log["Novelty"] = novelty

        print(f"   Validity: {validity:.4f}")
        print(f"   Uniqueness: {uniqueness:.4f}")
        print(f"   Novelty: {novelty:.4f}")

    else:
        model = flow_grpo_module.model
        model.sampling_metrics.reset()
        to_log = model.evaluate_samples(
            samples=graph_list,
            labels=None,
            is_test=True,
        )

    if reward_stats:
        to_log.update(reward_stats)

    filename = os.path.join(
        os.getcwd(),
        f"test_epoch{flow_grpo_module.current_epoch}_res_{cfg.sample.eta}_{cfg.sample.rdb}.txt",
    )
    with open(filename, "w", encoding="utf-8") as file:
        for key, value in to_log.items():
            file.write(f"{key}: {value}\n")

    print(f"✅ Test-only [anonymized]. Mode: {'Goal-Directed' if is_goal_directed_reward else 'Distribution'}")
    for k, v in to_log.items():
        print(f"   {k}: {v}")

    return to_log

def run_flow_grpo_training_pipeline(cfg: DictConfig):
    """Run Flow-GRPO training, or test-only evaluation, for a composed config.

    Training path: build the datamodule/components and the GRPO Lightning
    module, configure a ``Trainer``, then either resume from
    ``grpo.resume_from_checkpoint`` (validated first unless strict resume is
    disabled) or optionally load ``grpo.pretrained_checkpoint`` and fit from
    there. The final checkpoint is saved in Hydra's output directory.

    Args:
        cfg: Run configuration (``dataset``, ``general``, ``train`` and ``grpo``
            sections are read).

    Returns:
        The trained GRPO module, or - when ``general.test_only`` is set - the
        metrics dict returned by ``_run_flow_grpo_test_only``.

    Raises:
        Exception: Anything raised while building components, constructing the
            module, or training is printed and re-raised unchanged.
    """
    print(f" [anonymized]: {cfg.dataset.name}")
    print(f" [anonymized]: {cfg.grpo.reward_type}")
    _patch_torch_load_weights_only_default()

    swanlab_mode = cfg.general.get('swanlab', 'disabled')
    if swanlab is not None and swanlab_mode != 'disabled':
        try:
            config_dict = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
            swanlab.init(
                project=f"Flow-GRPO-{cfg.dataset.name}",
                experiment_name=cfg.general.name,
                config=config_dict,
                mode=swanlab_mode,
            )
        except Exception as e:
            # Experiment tracking is optional; continue without it.
            print(f"⚠️ SwanLab[anonymized]: {e}")

    pl.seed_everything(cfg.train.seed)

    use_gpu = cfg.general.gpus > 0 and torch.cuda.is_available()
    if use_gpu:
        available_gpus = min(cfg.general.gpus, torch.cuda.device_count())
        for i in range(available_gpus):
            print(f"   GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("🖥️ [anonymized]CPU")
        available_gpus = 0

    # Test-only mode builds its own datamodule and components, so branch before
    # building them here; otherwise they would be constructed twice.
    if cfg.general.get("test_only"):
        print("🧪 [anonymized] general.test_only，[anonymized] test-only [anonymized]（GRPO [anonymized] + main [anonymized]）")
        return _run_flow_grpo_test_only(cfg)

    try:
        datamodule, model_kwargs = create_datamodule_and_model_components(cfg)
    except Exception as e:
        print(f"❌ [anonymized]: {e}")
        raise

    try:
        flow_grpo_module = create_grpo_lightning_module(
            cfg=cfg,
            model_kwargs=model_kwargs,
            datamodule=datamodule,
            total_steps=cfg.grpo.total_steps,
        )
    except Exception as e:
        print(f"❌ Flow-GRPO Lightning[anonymized]: {e}")
        raise

    flow_grpo_datamodule = FlowGRPODataModule(
        num_epochs=cfg.grpo.total_steps,
        batch_size=1
    )

    callbacks = []

    trainer_kwargs = {
        # Training length is bounded by steps, not epochs.
        'max_steps': cfg.grpo.total_steps,
        'max_epochs': -1,
        'accumulate_grad_batches': 1,

        # Validation is disabled entirely.
        'check_val_every_n_epoch': None,
        'val_check_interval': None,
        'num_sanity_val_steps': 0,

        'log_every_n_steps': 10,
        'enable_progress_bar': True,
        'enable_model_summary': True,

        'callbacks': callbacks,

        'fast_dev_run': cfg.general.name == "debug",

        'deterministic': False,
        'benchmark': True,
        'logger': False,
    }

    if use_gpu:
        trainer_kwargs.update({
            'accelerator': "gpu",
            'devices': available_gpus,
            'precision': cfg.get('mixed_precision', '32'),
            'strategy': 'ddp' if available_gpus > 1 else 'auto'
        })
    else:
        trainer_kwargs.update({
            'accelerator': "cpu",
            'devices': 1,
        })

    trainer = Trainer(**trainer_kwargs)

    try:
        ckpt_path = cfg.grpo.get('resume_from_checkpoint')
        if ckpt_path and os.path.exists(ckpt_path):
            if _should_strict_resume(cfg):
                _validate_resume_checkpoint(flow_grpo_module, ckpt_path, cfg)
            print(f"📥 [anonymized]checkpoint[anonymized]: {ckpt_path}")
            print(f"   [anonymized]（[anonymized]、[anonymized]、[anonymized]）")
            trainer.fit(
                model=flow_grpo_module,
                datamodule=flow_grpo_datamodule,
                ckpt_path=ckpt_path
            )
        else:
            if ckpt_path:
                # Still falls back to a fresh run, but no longer silently.
                print(f"⚠️ grpo.resume_from_checkpoint not found, training without resume: {ckpt_path}")
            _initialize_distributed_environment_for_checkpoint_loading()

            pretrained_path = cfg.grpo.get('pretrained_checkpoint')
            if pretrained_path and os.path.exists(pretrained_path):
                print(f"📥 [anonymized]: {pretrained_path}")
                _load_pretrained_weights_into_grpo_module(flow_grpo_module, pretrained_path)
            elif pretrained_path:
                print(f"⚠️ grpo.pretrained_checkpoint not found, no pretrained weights loaded: {pretrained_path}")

            trainer.fit(
                model=flow_grpo_module,
                datamodule=flow_grpo_datamodule
            )

        print("\n✅ Flow-GRPO[anonymized]!")

        save_dir = HydraConfig.get().runtime.output_dir
        current_time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        final_checkpoint_path = os.path.join(save_dir, f"final_model_{current_time_str}.ckpt")
        trainer.save_checkpoint(final_checkpoint_path)
        print(f"💾 [anonymized]: {final_checkpoint_path}")

    except Exception as e:
        print(f"❌ Flow-GRPO[anonymized]: {e}")
        import traceback
        traceback.print_exc()
        raise

    finally:
        if use_gpu:
            torch.cuda.empty_cache()
        # getattr keeps a missing ``run`` attribute from raising inside
        # ``finally`` and masking the original exception.
        if swanlab is not None and getattr(swanlab, "run", None) is not None:
            swanlab.finish()

    return flow_grpo_module


@hydra.main(version_base="1.3", config_path="../configs", config_name="config")
def main(cfg: DictConfig):
    """Hydra entry point: run the Flow-GRPO pipeline for the composed config.

    Any exception is echoed to stdout and then re-raised.
    """
    try:
        run_flow_grpo_training_pipeline(cfg)
    except Exception as error:
        print(f"❌ Flow-GRPO[anonymized]: {error}")
        raise error


# Run the Hydra-decorated entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
