{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import logging\n",
    "import sys, os\n",
    "import time\n",
    "import warnings\n",
    "import torch\n",
    "import argparse\n",
    "import torch.nn as nn\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "from datetime import datetime\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from torch.nn import MSELoss, L1Loss\n",
    "import torch.distributed as dist\n",
    "import torch.multiprocessing as mp\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP\n",
    "from torch.serialization import save\n",
    "from gnn.model.metric import EarlyStopping\n",
    "from gnn.model.gated_solv_network import GatedGCNSolvationNetwork, InteractionMap, SelfInteractionMap\n",
    "from gnn.data.dataset import SolvationDataset, train_validation_test_split, solvent_split, element_split, substructure_split, stratified_solvent_split, stratified_split\n",
    "from gnn.data.dataloader import DataLoaderSolvation\n",
    "from gnn.data.grapher import HeteroMoleculeGraph\n",
    "from gnn.data.featurizer import (\n",
    "    SolventAtomFeaturizer,\n",
    "    BondAsNodeFeaturizerFull,\n",
    "    SolventGlobalFeaturizer,\n",
    ")\n",
    "from gnn.data.solvent_graph import HeteroMoleculeGraph2\n",
    "from gnn.data.dataset import load_mols_labels\n",
    "from gnn.utils import (\n",
    "    load_checkpoints,\n",
    "    pickle_load,\n",
    "    save_checkpoints,\n",
    "    seed_torch,\n",
    "    pickle_dump,\n",
    "    yaml_dump,\n",
    ")\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "ls=[]\n",
    "p=[]\n",
    "def accuracy(pred, label):\n",
    "    with torch.no_grad():\n",
    "        # 将未经过 Sigmoid 的输出进行 Sigmoid 激活\n",
    "        pred_probs = torch.sigmoid(pred)\n",
    "        # 将概率值转换为二分类的预测值，即大于等于0.5的为1，小于0.5的为0\n",
    "        pred_labels = (pred_probs >= 0.5).float()\n",
    "        # 计算正确预测的数量\n",
    "        correct = (pred_labels == label).sum().item()\n",
    "        # 计算总样本数量\n",
    "        total = label.size(0)\n",
    "        # 计算准确率\n",
    "        acc = correct\n",
    "        #print(acc)\n",
    "    return acc\n",
    "def parse_args():\n",
    "    parser = argparse.ArgumentParser(description=\"GatedSolvationNetwork\")\n",
    "\n",
    "    # input files and global variables\n",
    "    # parser.add_argument('--dataset-file', type=str, default=\"data/Deepddi.csv\")\n",
    "    # parser.add_argument('--dataset-pickle', type=str, default=\"data/Deepddi.csv\")\n",
    "    \n",
    "    parser.add_argument('--dataset-file', type=str, default=\"data/CHCHMiner.csv\")\n",
    "    parser.add_argument('--dataset-pickle', type=str, default=\"data/CHCHMiner.csv\")\n",
    "    \n",
    "    parser.add_argument('--dielectric-constants', type=str, default=None)\n",
    "    parser.add_argument('--molecular-refractivity', type=bool, default=False)\n",
    "    parser.add_argument('--molecular-volume', type=bool, default=False)\n",
    "\n",
    "    # output dir\n",
    "    parser.add_argument('--save-dir', type=str, default=\"result_model/train_file\")\n",
    "\n",
    "    # training params\n",
    "    parser.add_argument('--random-seed', type=int, default=50)\n",
    "    parser.add_argument('--feature-scaling', type=bool, default=True)\n",
    "    parser.add_argument('--solvent-split', type=str, default=None)\n",
    "    parser.add_argument('--solvent-stratified-split', type=str, default=None)\n",
    "    parser.add_argument('--solvent-stratified-frac', type=float, default=0.1)\n",
    "    parser.add_argument('--stratified-split', type=bool, default=False)\n",
    "    parser.add_argument('--element-split', type=str, default=None)\n",
    "    parser.add_argument('--scaffold-split', type=bool, default=False)\n",
    "    parser.add_argument('--attention-map', type=str, default=None)\n",
    "    parser.add_argument('--partial-charges', type=str, default=None)\n",
    "\n",
    "\n",
    "    # embedding layer\n",
    "    parser.add_argument(\"--embedding-size\", type=int, default=48)\n",
    "\n",
    "    # gated layer\n",
    "    parser.add_argument(\"--gated-num-layers\", type=int, default=3)\n",
    "    parser.add_argument(\"--gated-hidden-size\", type=int, nargs=\"+\", default=[400])\n",
    "    parser.add_argument(\"--gated-num-fc-layers\", type=int, default=3)\n",
    "    parser.add_argument(\"--gated-graph-norm\", type=int, default=0)\n",
    "    parser.add_argument(\"--gated-batch-norm\", type=int, default=0)\n",
    "    parser.add_argument(\"--gated-activation\", type=str, default=\"LeakyReLU\")\n",
    "    parser.add_argument(\"--gated-residual\", type=int, default=1)\n",
    "    parser.add_argument(\"--gated-dropout\", type=float, default=0.0)\n",
    "\n",
    "    # readout layer\n",
    "    parser.add_argument(\n",
    "        \"--num-lstm-iters\",\n",
    "        type=int,\n",
    "        default=6,\n",
    "        help=\"number of iterations for the LSTM in set2set readout layer\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--num-lstm-layers\",\n",
    "        type=int,\n",
    "        default=3,\n",
    "        help=\"number of layers for the LSTM in set2set readout layer\",\n",
    "    )\n",
    "\n",
    "    # fc layer\n",
    "    parser.add_argument(\"--fc-num-layers\", type=int, default=4)\n",
    "    parser.add_argument(\"--fc-hidden-size\", type=int, nargs=\"+\", default=[300])\n",
    "    parser.add_argument(\"--fc-batch-norm\", type=int, default=0)\n",
    "    parser.add_argument(\"--fc-activation\", type=str, default=\"LeakyReLU\")\n",
    "    parser.add_argument(\"--fc-dropout\", type=float, default=0.5)\n",
    "\n",
    "    # training\n",
    "    parser.add_argument(\"--start-epoch\", type=int, default=1)\n",
    "    parser.add_argument(\"--epochs\", type=int, default=100, help=\"number of epochs\")\n",
    "    parser.add_argument(\"--batch-size\", type=int, default=16, help=\"batch size\")\n",
    "    parser.add_argument(\"--lr\", type=float, default=0.0001, help=\"learning rate\")\n",
    "    parser.add_argument(\"--weight-decay\", type=float, default=0.0, help=\"weight decay\")\n",
    "    parser.add_argument(\"--restore\", type=int, default=0, help=\"read checkpoints\")\n",
    "    parser.add_argument(\"--load-dataset\", type=int, default=0, help=\"read dataset\")\n",
    "    parser.add_argument(\n",
    "        \"--dataset-state-dict-filename\", type=str, default=\"dataset_state_dict.pkl\"\n",
    "    )\n",
    "    # gpu\n",
    "    parser.add_argument(\n",
    "        \"--gpu\", type=int, default=2, help=\"GPU index. None to use CPU.\"\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--distributed\",\n",
    "        type=int,\n",
    "        default=0,\n",
    "        help=\"DDP training, --gpu is ignored if this is True\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--num-gpu\",\n",
    "        type=int,\n",
    "        default=1,\n",
    "        help=\"Number of GPU to use in distributed mode; ignored otherwise.\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--dist-url\",\n",
    "        default=\"tcp://localhost:13456\",\n",
    "        type=str,\n",
    "        help=\"url used to set up distributed training\",\n",
    "    )\n",
    "    \n",
    "    parser.add_argument(\"--dist-backend\", type=str, default=\"nccl\")\n",
    "\n",
    "    # output file (needed by hypertunity)\n",
    "    parser.add_argument(\"--output_file\", type=str, default=\"results.pkl\")\n",
    "\n",
    "    args = parser.parse_args(args=[])\n",
    "    if len(args.gated_hidden_size) == 1:\n",
    "        args.gated_hidden_size = args.gated_hidden_size * args.gated_num_layers\n",
    "    else:\n",
    "        assert len(args.gated_hidden_size) == args.gated_num_layers, (\n",
    "            \"length of `gat-hidden-size` should be equal to `num-gat-layers`, but got \"\n",
    "            \"{} and {}.\".format(args.gated_hidden_size, args.gated_num_layers)\n",
    "        )\n",
    "\n",
    "    if len(args.fc_hidden_size) == 1:\n",
    "        val = 2 * args.gated_hidden_size[-1]\n",
    "        args.fc_hidden_size = [max(val // 2 ** i, 8) for i in range(args.fc_num_layers)]\n",
    "    else:\n",
    "        assert len(args.fc_hidden_size) == args.fc_num_layers, (\n",
    "            \"length of `fc-hidden-size` should be equal to `num-fc-layers`, but got \"\n",
    "            \"{} and {}.\".format(args.fc_hidden_size, args.fc_num_layers)\n",
    "        )\n",
    "    return args\n",
    "\n",
    "def train(optimizer, model, nodes,nodes1, data_loader, loss_fn, accuracy_fn, device=None):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        accuracy_fn (function): the function should be using a `sum` reduction method.\n",
    "    \"\"\"\n",
    "\n",
    "    model.train()\n",
    "\n",
    "    epoch_loss = 0.0\n",
    "    accuracy = 0.0\n",
    "    count = 0.0\n",
    "\n",
    "    for it, (solute_batched_graph, solvent_batched_graph, label) in enumerate(data_loader):\n",
    "        solute_feats = {nt: solute_batched_graph.nodes[nt].data[\"feat\"] for nt in nodes1}\n",
    "        solvent_feats = {nt: solvent_batched_graph.nodes[nt].data[\"feat\"] for nt in nodes}\n",
    "        target = torch.squeeze(label[\"value\"])\n",
    "        #print(target)\n",
    "        solute_norm_atom = label[\"solute_norm_atom\"]\n",
    "        solute_norm_bond = label[\"solute_norm_bond\"]\n",
    "        solvent_norm_atom = label[\"solvent_norm_atom\"]\n",
    "        solvent_norm_bond = label[\"solvent_norm_bond\"]\n",
    "        #stdev = label[\"scaler_stdev\"]\n",
    "\n",
    "        if device is not None:\n",
    "            solute_feats = {k: v.to(device) for k, v in solute_feats.items()}\n",
    "            solvent_feats = {k: v.to(device) for k, v in solvent_feats.items()}\n",
    "            target = target.to(device)\n",
    "            solute_norm_atom = solute_norm_atom.to(device)\n",
    "            solute_norm_bond = solute_norm_bond.to(device)\n",
    "            solvent_norm_atom = solvent_norm_atom.to(device)\n",
    "            solvent_norm_bond = solvent_norm_bond.to(device)\n",
    "            #stdev = stdev.to(device)\n",
    "        #print(solute_feats)\n",
    "        pred,feats_list,loss_vq = model(solute_batched_graph, solvent_batched_graph, solute_feats, \n",
    "                     solvent_feats, solute_norm_atom, solute_norm_bond, \n",
    "                     solvent_norm_atom, solvent_norm_bond)\n",
    "        pred = pred.view(-1)\n",
    "        target = target.view(-1)\n",
    "        #print(pred)\n",
    "        #print(\"********\")\n",
    "        #print(target)\n",
    "        #print(loss_1)\n",
    "        #loss = loss_fn(pred, target)+loss_1\n",
    "        pred_total_loss=0\n",
    "        pred_total_loss_squared = 0\n",
    "        for _outputs in feats_list:\n",
    "            _outputs = _outputs.view(-1)\n",
    "            _loss = loss_fn(_outputs, target)\n",
    "            #_loss = loss_function_BCE(_outputs, samples[2].reshape(-1, 1).to(self.device).float()).mean()\n",
    "            pred_total_loss += _loss\n",
    "            pred_total_loss_squared += (_loss)**2\n",
    "        # 计算所有损失的平均值\n",
    "        pred_average_loss = pred_total_loss / len(feats_list)\n",
    "        pred_loss_variance = (pred_total_loss_squared / len(feats_list)) - (pred_average_loss ** 2)\n",
    "        \n",
    "        loss = pred_average_loss+pred_loss_variance+loss_vq\n",
    "        #loss = loss_vq\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        epoch_loss += loss.detach().item()\n",
    "        accuracy += accuracy_fn(pred, target)\n",
    "        count += len(target)\n",
    "    \n",
    "    epoch_loss /= it + 1\n",
    "    accuracy /= count\n",
    "\n",
    "    return epoch_loss, accuracy\n",
    "\n",
    "def evaluate(model, nodes,nodes1, data_loader, accuracy_fn, scaler = None, device=None, return_preds=False):\n",
    "    \"\"\"\n",
    "    Evaluate the accuracy of a validation set of test set.\n",
    "    Args:\n",
    "        accuracy_fn (function): the function should be using a `sum` reduction method.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        accuracy = 0.0\n",
    "        count = 0.0\n",
    "\n",
    "        preds = []\n",
    "        y_true = []\n",
    "\n",
    "        for solute_batched_graph, solvent_batched_graph, label in data_loader:\n",
    "            solute_feats = {nt: solute_batched_graph.nodes[nt].data[\"feat\"] for nt in nodes1}\n",
    "            solvent_feats = {nt: solvent_batched_graph.nodes[nt].data[\"feat\"] for nt in nodes}\n",
    "            target = torch.squeeze(label[\"value\"])\n",
    "            #stdev = label[\"scaler_stdev\"]\n",
    "            solvent_norm_atom = label[\"solvent_norm_atom\"]\n",
    "            solvent_norm_bond = label[\"solvent_norm_bond\"]\n",
    "            solute_norm_atom = label[\"solute_norm_atom\"]\n",
    "            solute_norm_bond = label[\"solute_norm_bond\"]\n",
    "\n",
    "            if device is not None:\n",
    "                solute_feats = {k: v.to(device) for k, v in solute_feats.items()}\n",
    "                solvent_feats = {k: v.to(device) for k, v in solvent_feats.items()}\n",
    "                target = target.to(device)\n",
    "                solute_norm_atom = solute_norm_atom.to(device)\n",
    "                solute_norm_bond = solute_norm_bond.to(device)\n",
    "                solvent_norm_atom = solvent_norm_atom.to(device)\n",
    "                solvent_norm_bond = solvent_norm_bond.to(device)\n",
    "\n",
    "            pred,feats_list,loss_vq  = model(solute_batched_graph, solvent_batched_graph, solute_feats, \n",
    "                     solvent_feats, solute_norm_atom, solute_norm_bond, \n",
    "                     solvent_norm_atom, solvent_norm_bond)\n",
    "            pred = pred.view(-1)\n",
    "            target = target.view(-1)\n",
    "\n",
    "            # Inverse scaler\n",
    "            if scaler is not None:\n",
    "                pred = scaler.inverse_transform(pred.cpu())\n",
    "                pred = pred.to(device)\n",
    "\n",
    "            accuracy += accuracy_fn(pred, target)\n",
    "            count += len(target)\n",
    "            #print(\"----------------------\")\n",
    "            #print(pred)\n",
    "            #print(\"===========\")\n",
    "            #print(target)\n",
    "            #print(\"----------------------\")\n",
    "            batch_pred = pred.tolist()\n",
    "            batch_target = target.tolist()\n",
    "            preds.extend(batch_pred)\n",
    "            y_true.extend(batch_target)\n",
    "\n",
    "    if return_preds:\n",
    "        return y_true, preds\n",
    "\n",
    "    else:\n",
    "        return accuracy / count\n",
    "#from gnn.data.grapher import HeteroMoleculeGraph\n",
    "#from gnn.data.featurizer import (\n",
    "#    SolventAtomFeaturizer,\n",
    "#    BondAsNodeFeaturizerFull,\n",
    "#    SolventGlobalFeaturizer,\n",
    "#)\n",
    "def grapher(dielectric_constant=None, mol_volume=False, mol_refract=False, partial_charges=None,lable=False):\n",
    "    atom_featurizer = SolventAtomFeaturizer(partial_charges=partial_charges)\n",
    "    bond_featurizer = BondAsNodeFeaturizerFull(length_featurizer=None, dative=False)\n",
    "    global_featurizer = SolventGlobalFeaturizer(dielectric_constant=dielectric_constant, mol_volume=mol_volume, mol_refract=mol_refract)\n",
    "    if lable:\n",
    "        grapher = HeteroMoleculeGraph(atom_featurizer, bond_featurizer, global_featurizer, self_loop=True)\n",
    "    else:\n",
    "        grapher = HeteroMoleculeGraph2(atom_featurizer, bond_featurizer, global_featurizer, self_loop=True)\n",
    "\n",
    "    return grapher\n",
    "\n",
    "def main_worker(gpu, world_size, args):\n",
    "    \n",
    "    # Explicitly setting seed to ensure the same dataset split and models created in\n",
    "    # two processes (when distributed) start from the same random weights and biases\n",
    "    random_seed = args.random_seed\n",
    "    seed_torch(random_seed)\n",
    "\n",
    "    args.gpu = gpu\n",
    "\n",
    "    if not args.distributed or (args.distributed and args.gpu == 0):\n",
    "        print(\"\\n\\nStart training at: \", datetime.now())\n",
    "\n",
    "    if args.save_dir is None:\n",
    "        args.save_dir = os.getcwd()\n",
    "\n",
    "    if args.distributed:\n",
    "        dist.init_process_group(\n",
    "            args.dist_backend,\n",
    "            init_method = args.dist_url,\n",
    "            world_size = world_size,\n",
    "            rank = args.gpu\n",
    "        )\n",
    "    \n",
    "    if args.restore:\n",
    "        dataset_state_dict_filename = args.dataset_state_dict_filename\n",
    "\n",
    "        if dataset_state_dict_filename is None:\n",
    "            warnings.warn(\"Restore with `args.dataset_state_dict_filename` set to None.\")\n",
    "        elif not Path(dataset_state_dict_filename).exists():\n",
    "            warnings.warn(\n",
    "                f\"`{dataset_state_dict_filename} not found; set \"\n",
    "                f\"args.dataset_state_dict_filename` to None\"\n",
    "            )\n",
    "            dataset_state_dict_filename = None\n",
    "    else:\n",
    "        dataset_state_dict_filename = None\n",
    "\n",
    "    # Load molecules and labels from file\n",
    "    mols, labels = load_mols_labels(args.dataset_file)\n",
    "\n",
    "    if args.load_dataset:\n",
    "        data_dict = args.dataset_pickle\n",
    "        dataset = pickle_load(data_dict)\n",
    "    \n",
    "    else:\n",
    "        if args.dielectric_constants is not None:\n",
    "            dc_file = Path(args.dielectric_constants)        \n",
    "            dataset = SolvationDataset(\n",
    "                solute_grapher = grapher(mol_volume = args.molecular_volume,\n",
    "                                        mol_refract = args.molecular_refractivity,\n",
    "                                        partial_charges=args.partial_charges),\n",
    "                solvent_grapher = grapher(dielectric_constant=True,\n",
    "                                        mol_volume = args.molecular_volume,\n",
    "                                        mol_refract = args.molecular_refractivity,\n",
    "                                        partial_charges=args.partial_charges),\n",
    "                molecules = mols,\n",
    "                labels = labels,\n",
    "                solute_extra_features = None,\n",
    "                solvent_extra_features=dc_file,\n",
    "                feature_transformer = False,\n",
    "                label_transformer= False,\n",
    "                state_dict_filename=dataset_state_dict_filename)\n",
    "\n",
    "        else:\n",
    "            dataset = SolvationDataset(\n",
    "                solute_grapher = grapher(mol_volume=args.molecular_volume, mol_refract = args.molecular_refractivity, partial_charges=args.partial_charges),\n",
    "                solvent_grapher = grapher(mol_volume=args.molecular_volume, mol_refract = args.molecular_refractivity, partial_charges=args.partial_charges,lable=True),\n",
    "                molecules = mols,\n",
    "                labels = labels,\n",
    "                solute_extra_features = None,\n",
    "                solvent_extra_features = None,\n",
    "                feature_transformer = False,\n",
    "                label_transformer= False,\n",
    "                state_dict_filename=dataset_state_dict_filename\n",
    "                )\n",
    "\n",
    "    # Save the solute and solvent graphers for loading datasets later\n",
    "    pickle_dump([dataset.solute_grapher, dataset.solvent_grapher], os.path.join(args.save_dir,\"graphers.pkl\"))\n",
    "\n",
    "    best = np.finfo(np.float32).max\n",
    "    os.makedirs(args.save_dir, exist_ok=True)\n",
    "\n",
    "    # Split data: random, solvent-based split, element-based, or scaffold-based split\n",
    "\n",
    "    possible_solvents = ['hexane', 'water', 'acetone', 'ethanol', 'benzene', 'ethylacetate',\n",
    "               'dichloromethane', 'acetonitrile', 'thf', 'dmso', 'dmf', 'octanol', 'hexadecane', 'cyclohexane']\n",
    "\n",
    "    if (args.solvent_split is None) and (args.element_split is None) and (args.solvent_stratified_split is None) and (args.stratified_split is False) and (args.scaffold_split is False):\n",
    "        print(f'Splitting data using random seed {random_seed}')\n",
    "        trainset, valset, testset = train_validation_test_split(\n",
    "            dataset, validation=0.1, test=0.1, random_seed=args.random_seed)\n",
    "    \n",
    "    elif args.solvent_split is not None:\n",
    "        assert args.solvent_split in possible_solvents, \"Solvent unavailable! Choose from: hexane, cyclohexane, water, acetone, ethanol, benzene, ethylacetate, dichloromethane, acetonitrile, thf, dmso\"\n",
    "        print(f'Using compounds with {args.solvent_split} solvent as test data.')\n",
    "        trainset, valset, testset = solvent_split(\n",
    "            dataset, args.solvent_split, random_seed=args.random_seed)\n",
    "    elif args.solvent_stratified_split is not None:\n",
    "        assert args.solvent_stratified_split in possible_solvents, \"Solvent unavailable! Choose from: hexane, cyclohexane, water, acetone, ethanol, benzene, ethylacetate, dichloromethane, acetonitrile, thf, dmso\"\n",
    "        print(f'Using {1-args.solvent_stratified_frac}% of {args.solvent_stratified_split} solvent as test data.')\n",
    "        trainset, valset, testset = stratified_solvent_split(\n",
    "            dataset, args.solvent_stratified_split, frac=args.solvent_stratified_frac, random_seed=args.random_seed)\n",
    "    \n",
    "    elif args.scaffold_split is True:\n",
    "        trainset, valset, testset = substructure_split(\n",
    "            dataset, random_seed=args.random_seed)\n",
    "    \n",
    "    elif args.stratified_split is True:\n",
    "        trainset, valset, testset = stratified_split(\n",
    "            dataset, random_seed=args.random_seed)\n",
    "    \n",
    "    elif args.element_split is not None: # element split\n",
    "        possible_elems = ['Br', 'Cl', 'F', 'I', 'N', 'O', 'S']\n",
    "        elem = args.element_split\n",
    "        assert elem in possible_elems, \"Element unavailable! Choose from: 'Br', 'Cl', 'F', 'I', 'N', 'O', 'S'\"\n",
    "        print(f'Placing all solutes with {elem} atoms into the test dataset.')\n",
    "        trainset, valset, testset = element_split(dataset, elem, random_seed=args.random_seed)\n",
    "\n",
    "    # Scale training dataset features\n",
    "    if args.feature_scaling:\n",
    "        solute_features_scaler,solvent_features_scaler= trainset.normalize_features()\n",
    "        #solute_features_scaler, solvent_features_scaler = trainset.normalize_features()\n",
    "        valset.normalize_features(solute_features_scaler, solvent_features_scaler)\n",
    "        testset.normalize_features(solute_features_scaler, solvent_features_scaler)\n",
    "        #testset.normalize_features(solute_features_scaler)\n",
    "    else:\n",
    "        solute_features_scaler, solvent_features_scaler = None, None\n",
    "    \n",
    "    #label_scaler = trainset.normalize_labels()\n",
    "    label_scaler = None\n",
    "    if not args.distributed or (args.distributed and args.gpu == 0):\n",
    "        torch.save(dataset.state_dict(), os.path.join(args.save_dir, args.dataset_state_dict_filename))\n",
    "        print(\n",
    "            \"Trainset size: {}, valset size: {}: testset size: {}.\".format(\n",
    "                len(trainset), len(valset), len(testset)\n",
    "            )\n",
    "        )\n",
    "    if args.distributed:\n",
    "        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)\n",
    "    else:\n",
    "        train_sampler = None\n",
    "    \n",
    "    train_loader = DataLoaderSolvation(\n",
    "        trainset,\n",
    "        batch_size = args.batch_size,\n",
    "        shuffle = (train_sampler is None),\n",
    "        sampler = train_sampler\n",
    "    )\n",
    "    # larger val and test set batch_size is faster but needs more memory\n",
    "    # adjust the batch size of val and test set to fit memory\n",
    "    bs = max(len(valset) // 10, 1)\n",
    "    val_loader = DataLoaderSolvation(valset, batch_size=bs, shuffle=False)\n",
    "    bs = max(len(testset) // 10, 1)\n",
    "    test_loader = DataLoaderSolvation(testset, batch_size=bs, shuffle=False)\n",
    "    ### model\n",
    "    feature_names = [\"atom\", \"bond\", \"global\"]\n",
    "    solute_feature_names = [\"atom\", \"bond\",\"atom2\", \"bond2\", \"global\"]\n",
    "    set2set_ntypes_direct = [\"global\"]\n",
    "    solute_feature_size = dataset.feature_sizes[0]\n",
    "    solute_feature_size ={'bond': 11,'atom': 41,'atom2': 41, 'bond2': 11, 'global': 3}\n",
    "    solvent_feature_size = dataset.feature_sizes[1]\n",
    "    args.solute_feature_size = solute_feature_size\n",
    "    args.solvent_feature_size = solvent_feature_size\n",
    "    args.set2set_ntypes_direct = set2set_ntypes_direct\n",
    "    # save args\n",
    "    if not args.distributed or (args.distributed and args.gpu == 0):\n",
    "        yaml_dump(args, os.path.join(args.save_dir, \"train_args.yaml\"))\n",
    "\n",
    "    if args.attention_map == 'cross':\n",
    "        model = InteractionMap(\n",
    "            solute_in_feats=args.solute_feature_size,\n",
    "            solvent_in_feats=args.solvent_feature_size,\n",
    "            embedding_size=args.embedding_size,\n",
    "            gated_num_layers=args.gated_num_layers,\n",
    "            gated_hidden_size=args.gated_hidden_size,\n",
    "            gated_num_fc_layers=args.gated_num_fc_layers,\n",
    "            gated_graph_norm=args.gated_graph_norm,\n",
    "            gated_batch_norm=args.gated_batch_norm,\n",
    "            gated_activation=args.gated_activation,\n",
    "            gated_residual=args.gated_residual,\n",
    "            gated_dropout=args.gated_dropout,\n",
    "            num_lstm_iters=args.num_lstm_iters,\n",
    "            num_lstm_layers=args.num_lstm_layers,\n",
    "            set2set_ntypes_direct=args.set2set_ntypes_direct,\n",
    "            fc_num_layers=args.fc_num_layers,\n",
    "            fc_hidden_size=args.fc_hidden_size,\n",
    "            fc_batch_norm=args.fc_batch_norm,\n",
    "            fc_activation=args.fc_activation,\n",
    "            fc_dropout=args.fc_dropout,\n",
    "            outdim=1,\n",
    "            conv=\"GatedGCNConv\",\n",
    "        )\n",
    "        \n",
    "    elif args.attention_map == 'self':\n",
    "        model = SelfInteractionMap(\n",
    "            solute_in_feats=args.solute_feature_size,\n",
    "            solvent_in_feats=args.solvent_feature_size,\n",
    "            embedding_size=args.embedding_size,\n",
    "            gated_num_layers=args.gated_num_layers,\n",
    "            gated_hidden_size=args.gated_hidden_size,\n",
    "            gated_num_fc_layers=args.gated_num_fc_layers,\n",
    "            gated_graph_norm=args.gated_graph_norm,\n",
    "            gated_batch_norm=args.gated_batch_norm,\n",
    "            gated_activation=args.gated_activation,\n",
    "            gated_residual=args.gated_residual,\n",
    "            gated_dropout=args.gated_dropout,\n",
    "            num_lstm_iters=args.num_lstm_iters,\n",
    "            num_lstm_layers=args.num_lstm_layers,\n",
    "            set2set_ntypes_direct=args.set2set_ntypes_direct,\n",
    "            fc_num_layers=args.fc_num_layers,\n",
    "            fc_hidden_size=args.fc_hidden_size,\n",
    "            fc_batch_norm=args.fc_batch_norm,\n",
    "            fc_activation=args.fc_activation,\n",
    "            fc_dropout=args.fc_dropout,\n",
    "            outdim=1,\n",
    "            conv=\"GatedGCNConv\",\n",
    "        )\n",
    "    else:\n",
    "        model = GatedGCNSolvationNetwork(\n",
    "            solute_in_feats=args.solute_feature_size,\n",
    "            solvent_in_feats=args.solvent_feature_size,\n",
    "            embedding_size=args.embedding_size,\n",
    "            gated_num_layers=args.gated_num_layers,\n",
    "            gated_hidden_size=args.gated_hidden_size,\n",
    "            gated_num_fc_layers=args.gated_num_fc_layers,\n",
    "            gated_graph_norm=args.gated_graph_norm,\n",
    "            gated_batch_norm=args.gated_batch_norm,\n",
    "            gated_activation=args.gated_activation,\n",
    "            gated_residual=args.gated_residual,\n",
    "            gated_dropout=args.gated_dropout,\n",
    "            num_lstm_iters=args.num_lstm_iters,\n",
    "            num_lstm_layers=args.num_lstm_layers,\n",
    "            set2set_ntypes_direct=args.set2set_ntypes_direct,\n",
    "            fc_num_layers=args.fc_num_layers,\n",
    "            fc_hidden_size=args.fc_hidden_size,\n",
    "            fc_batch_norm=args.fc_batch_norm,\n",
    "            fc_activation=args.fc_activation,\n",
    "            fc_dropout=args.fc_dropout,\n",
    "            outdim=1,\n",
    "            conv=\"GatedGCNConv\",\n",
    "        )\n",
    "    # if not args.distributed or (args.distributed and args.gpu == 0):\n",
    "    #     print(model)\n",
    "\n",
    "    print(f'Model type: {type(model)}')\n",
    "\n",
    "    if args.gpu is not None:\n",
    "        model.to(args.gpu)\n",
    "    if args.distributed:\n",
    "        ddp_model = DDP(model, device_ids=[args.gpu])\n",
    "        ddp_model.feature_before_fc = model.feature_before_fc\n",
    "        model = ddp_model\n",
    "    ### optimizer, loss, and accuracy\n",
    "    optimizer = torch.optim.Adam(\n",
    "        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n",
    "    )\n",
    "    loss_func = nn.BCEWithLogitsLoss()\n",
    "    ### learning rate scheduler and stopper\n",
    "    scheduler = ReduceLROnPlateau(\n",
    "        optimizer, mode=\"min\", factor=0.4, patience=50, verbose=True\n",
    "    )\n",
    "    stopper = EarlyStopping(patience=150)\n",
    "    # load checkpoint\n",
    "    state_dict_objs = {\"model\": model, \"optimizer\": optimizer, \"scheduler\": scheduler}\n",
    "    if args.restore:\n",
    "        try:\n",
    "            if args.gpu is None:\n",
    "                checkpoint = load_checkpoints(state_dict_objs, save_dir=args.save_dir, filename=\"checkpoint.pkl\")\n",
    "            else:\n",
    "                # Map model to be loaded to specified single gpu.\n",
    "                loc = \"cuda:{}\".format(args.gpu)\n",
    "                checkpoint = load_checkpoints(\n",
    "                    state_dict_objs, map_location=loc, save_dir=args.save_dir, filename=\"checkpoint.pkl\"\n",
    "                )\n",
    "            args.start_epoch = checkpoint[\"epoch\"]\n",
    "            best = checkpoint[\"best\"]\n",
    "            print(f\"Successfully load checkpoints, best {best}, epoch {args.start_epoch}\")\n",
    "        except FileNotFoundError as e:\n",
    "            warnings.warn(str(e) + \" Continue without loading checkpoints.\")\n",
    "            pass\n",
    "    # start training\n",
    "    if not args.distributed or (args.distributed and args.gpu == 0):\n",
    "        print(\"\\n\\n# Epoch     Loss         TrainAcc        ValAcc     Time (s)\")\n",
    "        sys.stdout.flush()\n",
    "    for epoch in range(args.start_epoch, args.epochs):\n",
    "        ti = time.time()\n",
    "        # In distributed mode, calling the set_epoch method is needed to make shuffling\n",
    "        # work; each process will use the same random seed otherwise.\n",
    "        if args.distributed:\n",
    "            train_sampler.set_epoch(epoch)\n",
    "        # train\n",
    "        loss, train_acc = train(\n",
    "            optimizer, model, feature_names, solute_feature_names,train_loader, loss_func, accuracy, args.gpu)\n",
    "        # bad, we get nan\n",
    "        if np.isnan(loss):\n",
    "            print(\"\\n\\nBad, we get nan for loss. Exiting\")\n",
    "            sys.stdout.flush()\n",
    "            sys.exit(1)\n",
    "        # evaluate\n",
    "        val_acc = evaluate(model, feature_names, solute_feature_names,val_loader, accuracy, label_scaler, args.gpu)\n",
    "        if stopper.step(val_acc):\n",
    "            pickle_dump(best, os.path.join(args.save_dir, args.output_file))  # save results for hyperparam tune\n",
    "            break\n",
    "        scheduler.step(val_acc)\n",
    "        is_best = val_acc < best\n",
    "        if is_best:\n",
    "            best = val_acc\n",
    "        # save checkpoint\n",
    "        if not args.distributed or (args.distributed and args.gpu == 0):\n",
    "            misc_objs = {\"best\": best, \"epoch\": epoch}\n",
    "            scaler_objs = {'label_scaler': {\n",
    "                            'means': label_scaler.mean,\n",
    "                            'stds': label_scaler.std\n",
    "                            } if label_scaler is not None else None,\n",
    "                            'solute_features_scaler': {\n",
    "                            'means': solute_features_scaler.mean,\n",
    "                            'stds': solute_features_scaler.std\n",
    "                            } if solute_features_scaler is not None else None,\n",
    "                            'solvent_features_scaler': {\n",
    "                            'means': solvent_features_scaler.mean,\n",
    "                            'stds': solvent_features_scaler.std\n",
    "                            } if solvent_features_scaler is not None else None}\n",
    "            save_checkpoints(\n",
    "                state_dict_objs,\n",
    "                misc_objs,\n",
    "                scaler_objs,\n",
    "                is_best,\n",
    "                msg=f\"epoch: {epoch}, score {val_acc}\",\n",
    "                save_dir=args.save_dir)\n",
    "            tt = time.time() - ti\n",
    "            print(\n",
    "                \"{:5d}   {:12.6e}   {:12.6e}   {:12.6e}   {:.2f}\".format(\n",
    "                    epoch, loss, train_acc, val_acc, tt\n",
    "                )\n",
    "            )\n",
    "            ls.append( val_acc)\n",
    "            if epoch % 10 == 0:\n",
    "                sys.stdout.flush()\n",
    "    # load best to calculate test accuracy\n",
    "    if args.gpu is None:\n",
    "        checkpoint = load_checkpoints(state_dict_objs, args.save_dir, filename=\"best_checkpoint.pkl\")\n",
    "    else:\n",
    "        # Map model to be loaded to specified single  gpu.\n",
    "        loc = \"cuda:{}\".format(args.gpu)\n",
    "        checkpoint = load_checkpoints(\n",
    "            state_dict_objs, map_location=loc, save_dir=args.save_dir, filename=\"best_checkpoint.pkl\"\n",
    "        )\n",
    "    \n",
    "    if not args.distributed or (args.distributed and args.gpu == 0):\n",
    "        test_acc = evaluate(model, feature_names,solute_feature_names, test_loader, accuracy, label_scaler, args.gpu)\n",
    "        y_true, y_pred = evaluate(model, feature_names,solute_feature_names, test_loader, accuracy, \n",
    "                                    label_scaler, args.gpu, return_preds=True)\n",
    "        \n",
    "        print(len(y_true))\n",
    "        print(len(y_pred))\n",
    "        print(\"\\n#Test MAE: {:12.6e} \\n\".format(test_acc))\n",
    "        print(\"\\n#Test RMSE: {:12.6e} \\n\".format(mean_squared_error(y_true, y_pred, squared=False)))\n",
    "        print(\"\\nFinish training at:\", datetime.now())\n",
    "        p.append(mean_squared_error(y_true, y_pred, squared=False))\n",
    "        results_dict = {'y_true': y_true, 'y_pred': y_pred}\n",
    "        pickle_dump(results_dict, os.path.join(args.save_dir, f'seed_{random_seed}_test_results.pkl'))\n",
    "\n",
    "\n",
    "def main():\n",
    "    args = parse_args()\n",
    "    print(args)\n",
    "\n",
    "    if args.save_dir is not None:\n",
    "        os.makedirs(args.save_dir, exist_ok=True)\n",
    "\n",
    "    logging.basicConfig(\n",
    "    filename=os.path.join(args.save_dir, '{}.log'.format(\n",
    "        datetime.now().strftime(\"gnn_%Y_%m_%d-%I_%M_%p\"))),\n",
    "    format=\"%(asctime)s:%(name)s:%(levelname)s: %(message)s\",\n",
    "    level=logging.INFO,\n",
    "    )\n",
    "\n",
    "    if args.distributed:\n",
    "        # DDP\n",
    "        world_size = torch.cuda.device_count() if args.num_gpu is None else args.num_gpu\n",
    "        mp.spawn(main_worker, nprocs=world_size, args=(world_size, args))\n",
    "\n",
    "    else:\n",
    "        # train on CPU or a single GPU\n",
    "        main_worker(args.gpu, None, args)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    #torch.autograd.set_detect_anomaly(True)\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "zs",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
