{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28f00531",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import utils.args_parser  as argtools\n",
    "import pytorch_lightning as pl\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6076a2c",
   "metadata": {},
   "source": [
    "# LOAD CONFIG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5ce904cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "use_custom_dataset = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "babcb0b2",
   "metadata": {},
   "source": [
    "### Option 1: Datasets from the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "89eee561",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 1: build the configuration from the dataset, model and trainer YAML files.\n",
    "if not use_custom_dataset:\n",
    "    print('Using dataset from the paper')\n",
    "    dataset_file = os.path.join('_params', 'dataset_adult.yaml')\n",
    "    model_file = os.path.join('_params', 'model_vaca.yaml')\n",
    "    trainer_file = os.path.join('_params', 'trainer.yaml')\n",
    "\n",
    "    # Empty string: merge the three files above. Otherwise: parse this single file instead.\n",
    "    yaml_file = ''\n",
    "\n",
    "    if yaml_file != '':\n",
    "        cfg = argtools.parse_args(yaml_file)\n",
    "    else:\n",
    "        cfg = argtools.parse_args(dataset_file)\n",
    "        for extra_file in (model_file, trainer_file):\n",
    "            cfg.update(argtools.parse_args(extra_file))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f11780b1",
   "metadata": {},
   "source": [
    "### Option 2: New dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bae3aa70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using custom dataset\n"
     ]
    }
   ],
   "source": [
    "# Option 2: only the model and trainer YAML files are parsed here (no dataset file).\n",
    "if use_custom_dataset:\n",
    "    print('Using custom dataset')\n",
    "    model_file = os.path.join('_params', 'model_vaca.yaml')\n",
    "    trainer_file = os.path.join('_params', 'trainer.yaml')\n",
    "\n",
    "    # Empty string: merge the two files above. Otherwise: parse this single file instead.\n",
    "    yaml_file = ''\n",
    "    if yaml_file != '':\n",
    "        cfg = argtools.parse_args(yaml_file)\n",
    "    else:\n",
    "        cfg = argtools.parse_args(model_file)\n",
    "        cfg.update(argtools.parse_args(trainer_file))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "125e8421",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Config for new dataset.\n",
    "# Guarded so that a configuration loaded with use_custom_dataset = False is not overwritten.\n",
    "if use_custom_dataset:\n",
    "    cfg['dataset'] = {\n",
    "        'name': '2nodes',\n",
    "        # params1 and params2 are merged into cfg['dataset']['params'] further down.\n",
    "        'params1': {\n",
    "            'data_dir': '../Data',\n",
    "            'batch_size': 1000,\n",
    "            'num_workers': 0,\n",
    "        },\n",
    "        'params2': {\n",
    "            'num_samples_tr': 5000,\n",
    "            'equations_type': 'linear',\n",
    "            'normalize': 'lik',\n",
    "            'likelihood_names': 'd',\n",
    "            'lambda_': 0.05,\n",
    "            'normalize_A': None,\n",
    "        },\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "194f6abc",
   "metadata": {},
   "source": [
    "### You can also update any parameter manually"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "09a87058",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 10\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimizer:\n",
      "\tname: adam\n",
      "\tparams:\n",
      "\t\tlr: 0.005\n",
      "\t\tbetas: [0.9, 0.999]\n",
      "\t\tweight_decay: 1.2e-06\n",
      "scheduler:\n",
      "\tname: exp_lr\n",
      "\tparams:\n",
      "\t\tgamma: 0.99\n",
      "model:\n",
      "\tname: vaca\n",
      "\tparams:\n",
      "\t\tarchitecture: dgnn\n",
      "\t\testimator: elbo\n",
      "\t\th_dim_list_dec: [8, 8]\n",
      "\t\th_dim_list_enc: [16]\n",
      "\t\tz_dim: 4\n",
      "\t\tdistr_z: normal\n",
      "\t\tdropout_adj_rate: 0.0\n",
      "\t\tdropout_adj_pa_rate: 0.2\n",
      "\t\tdropout_adj_pa_prob_keep_self: 0.0\n",
      "\t\tresidual: 0.0\n",
      "\t\tnorm_categorical: 0\n",
      "seed: 10\n",
      "root_dir: results\n",
      "early_stopping: True\n",
      "trainer:\n",
      "\tmax_epochs: 20\n",
      "\tmin_epochs: 10\n",
      "\tlimit_train_batches: 1.0\n",
      "\tlimit_val_batches: 1.0\n",
      "\tlimit_test_batches: 1.0\n",
      "\tcheck_val_every_n_epoch: 10\n",
      "\tprogress_bar_refresh_rate: 1\n",
      "\tflush_logs_every_n_steps: 100\n",
      "\tlog_every_n_steps: 2\n",
      "\tprecision: 32\n",
      "\tterminate_on_nan: True\n",
      "\tauto_select_gpus: True\n",
      "\tdeterministic: True\n",
      "\tweights_summary: None\n",
      "\tgpus: None\n",
      "\tnum_sanity_val_steps: 2\n",
      "\ttrack_grad_norm: -1\n",
      "\tgradient_clip_val: 0.0\n",
      "dataset:\n",
      "\tname: 2nodes\n",
      "\tparams1:\n",
      "\t\tdata_dir: ../Data\n",
      "\t\tbatch_size: 1000\n",
      "\t\tnum_workers: 0\n",
      "\tparams2:\n",
      "\t\tnum_samples_tr: 5000\n",
      "\t\tequations_type: linear\n",
      "\t\tnormalize: lik\n",
      "\t\tlikelihood_names: d\n",
      "\t\tlambda_: 0.05\n",
      "\t\tnormalize_A: None\n",
      "\tparams:\n",
      "\t\tdata_dir: \n",
      "\t\tbatch_size: 1000\n",
      "\t\tnum_workers: 0\n",
      "\t\tnum_samples_tr: 5000\n",
      "\t\tequations_type: linear\n",
      "\t\tnormalize: lik\n",
      "\t\tlikelihood_names: d\n",
      "\t\tlambda_: 0.05\n",
      "\t\tnormalize_A: None\n"
     ]
    }
   ],
   "source": [
    "    \n",
    "# Output root and global RNG seed.\n",
    "cfg.update({'root_dir': 'results', 'seed': 10})\n",
    "pl.seed_everything(cfg['seed'])\n",
    "\n",
    "# Merge both parameter groups into a single dict and blank out data_dir.\n",
    "cfg['dataset']['params'] = {**cfg['dataset']['params1'],\n",
    "                            **cfg['dataset']['params2'],\n",
    "                            'data_dir': ''}\n",
    "\n",
    "for limit_key in ('limit_train_batches', 'limit_val_batches', 'limit_test_batches'):\n",
    "    cfg['trainer'][limit_key] = 1.0\n",
    "cfg['trainer']['check_val_every_n_epoch'] = 10\n",
    "\n",
    "\n",
    "def print_if_not_dict(key, value, extra=''):\n",
    "    if not isinstance(value, dict):\n",
    "        print(f\"{extra}{key}: {value}\")\n",
    "        return True\n",
    "    else:\n",
    "        print(f\"{extra}{key}:\")\n",
    "        False\n",
    "        \n",
    "# Print the config tree, at most three levels deep, one extra tab per level.\n",
    "for section, section_value in cfg.items():\n",
    "    if print_if_not_dict(section, section_value):\n",
    "        continue\n",
    "    for name, entry in section_value.items():\n",
    "        if print_if_not_dict(name, entry, extra='\\t'):\n",
    "            continue\n",
    "        for leaf_name, leaf in entry.items():\n",
    "            print_if_not_dict(leaf_name, leaf, extra='\\t\\t')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb1b2408",
   "metadata": {},
   "source": [
    "# LOAD DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e528e48a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "These are datasets from the paper:\n",
      "\tcollider\n",
      "\ttriangle\n",
      "\tloan\n",
      "\tmgraph\n",
      "\tchain\n",
      "\tadult\n",
      "\tgerman\n",
      "\n",
      "Using dataset: 2nodes\n"
     ]
    }
   ],
   "source": [
    "from utils.constants import Cte\n",
    "\n",
    "\n",
    "print('These are datasets from the paper:')\n",
    "for data_name in Cte.DATASET_LIST:\n",
    "    print(f\"\\t{data_name}\")\n",
    "    \n",
    "\n",
    "\n",
    "print(f\"\\nUsing dataset: {cfg['dataset']['name']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8ce78164-21e0-4d4f-8513-3d737ffd81c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os\n",
    "\n",
    "# import pytorch_lightning as pl\n",
    "# import torch\n",
    "# from sklearn import preprocessing\n",
    "# import torch_geometric\n",
    "# #from torch_geometric.data import DataLoader\n",
    "# # from torch_geometric.utils import degree\n",
    "# # from torchvision import transforms as transform_lib\n",
    "\n",
    "# # from data_modules._scalers import MaskedTensorLikelihoodScaler\n",
    "# # from data_modules._scalers import MaskedTensorStandardScaler\n",
    "# # from datasets.transforms import ToTensor\n",
    "# # from utils.constants import Cte\n",
    "\n",
    "\n",
    "\n",
    "# # from datasets.toy import create_toy_dataset\n",
    "# # from utils.distributions import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e63f7e0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the data module that matches cfg['dataset']['name'].\n",
    "dataset_name = cfg['dataset']['name']\n",
    "dataset_params = cfg['dataset']['params'].copy()\n",
    "dataset_params['dataset_name'] = dataset_name\n",
    "\n",
    "if dataset_name in Cte.DATASET_LIST:\n",
    "    from data_modules.het_scm import HeterogeneousSCMDataModule\n",
    "\n",
    "    data_module = HeterogeneousSCMDataModule(**dataset_params)\n",
    "\n",
    "elif dataset_name == '2nodes':\n",
    "    from data_modules.my_toy_scm import MyToySCMDataModule\n",
    "    from utils.distributions import Normal\n",
    "\n",
    "    # Two-node graph x1 -> x2 with x1 = u1 and x2 = u2 + x1.\n",
    "    dataset_params['nodes_to_intervene'] = ['x1']\n",
    "    dataset_params['structural_eq'] = {'x1': lambda u1: u1,\n",
    "                                       'x2': lambda u2, x1: u2 + x1}\n",
    "    dataset_params['noises_distr'] = {'x1': Normal(0,1),\n",
    "                                      'x2': Normal(0,1)}\n",
    "    dataset_params['adj_edges'] = {'x1': ['x2'],\n",
    "                                   'x2': []}\n",
    "\n",
    "    data_module = MyToySCMDataModule(**dataset_params)\n",
    "\n",
    "else:\n",
    "    # Fail here instead of with a NameError on `data_module` in a later cell.\n",
    "    raise ValueError(f\"Unknown dataset: {dataset_name!r}\")\n",
    "\n",
    "data_module.prepare_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93320b26",
   "metadata": {},
   "source": [
    "# LOAD MODEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cf2dc585-3059-4823-8cb6-622a0401516f",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_loader = data_module.train_dataloader()\n",
    "\n",
    "#data_module.batch_size = bs\n",
    "\n",
    "batch = next(iter(data_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fd39c7d5-f374-41cf-8467-9be19e5ebd17",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataBatch(x=[2000, 1], edge_index=[2, 3000], edge_attr=[3000, 3], u=[1000, 2], mask=[2000, 1], node_ids=[2000, 2], num_nodes=2000, batch=[2000], ptr=[1001])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14e9ca3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Using model: vaca\n"
     ]
    }
   ],
   "source": [
    "model = None\n",
    "model_params = cfg['model']['params'].copy()\n",
    "\n",
    "print(f\"\\nUsing model: {cfg['model']['name']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcb87ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Instantiate the model selected by cfg['model']['name'].\n",
    "# VACA\n",
    "if cfg['model']['name'] == Cte.VACA:\n",
    "    from models.vaca.vaca import VACA\n",
    "\n",
    "    model_params['is_heterogeneous'] = data_module.is_heterogeneous\n",
    "    model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "    model_params['deg'] = data_module.get_deg(indegree=True)\n",
    "    model_params['num_nodes'] = data_module.num_nodes\n",
    "    model_params['edge_dim'] = data_module.edge_dimension\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "\n",
    "    model = VACA(**model_params)\n",
    "    model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "\n",
    "# VACA with PIWAE\n",
    "elif cfg['model']['name'] == Cte.VACA_PIWAE:\n",
    "    from models.vaca.vaca_piwae import VACA_PIWAE\n",
    "\n",
    "    model_params['is_heterogeneous'] = data_module.is_heterogeneous\n",
    "    model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "    model_params['deg'] = data_module.get_deg(indegree=True)\n",
    "    model_params['num_nodes'] = data_module.num_nodes\n",
    "    model_params['edge_dim'] = data_module.edge_dimension\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "\n",
    "    model = VACA_PIWAE(**model_params)\n",
    "    model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "\n",
    "# MultiCVAE\n",
    "elif cfg['model']['name'] == Cte.MCVAE:\n",
    "    from models.multicvae.multicvae import MCVAE\n",
    "\n",
    "    model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "    model_params['topological_node_dims'] = data_module.train_dataset.get_node_columns_in_X()\n",
    "    model_params['topological_parents'] = data_module.topological_parents\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "    # The epoch budget is split evenly across the nodes (rounded down).\n",
    "    model_params['num_epochs_per_nodes'] = int(\n",
    "        np.floor((cfg['trainer']['max_epochs'] / len(data_module.topological_nodes))))\n",
    "    model = MCVAE(**model_params)\n",
    "    model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "    cfg['early_stopping'] = False\n",
    "\n",
    "# CAREFL\n",
    "elif cfg['model']['name'] == Cte.CAREFL:\n",
    "    from models.carefl.carefl import CAREFL\n",
    "\n",
    "    model_params['node_per_dimension_list'] = data_module.train_dataset.node_per_dimension_list\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "    model = CAREFL(**model_params)\n",
    "\n",
    "else:\n",
    "    # Fail early: otherwise `model` stays None and the next cell crashes with an AttributeError.\n",
    "    raise ValueError(f\"Unknown model: {cfg['model']['name']!r}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75a391bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/prchao/miniconda3/envs/vaca/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py:1726: LightningDeprecationWarning: Argument `mode` in `LightningModule.summarize` is deprecated in v1.4 and will be removed in v1.6. Use `max_depth=1` to replicate `mode=top` behavior.\n",
      "  rank_zero_deprecation(\n",
      "\n",
      "  | Name  | Type       | Params\n",
      "-------------------------------------\n",
      "0 | model | VACAModule | 1.2 K \n",
      "-------------------------------------\n",
      "1.2 K     Trainable params\n",
      "0         Non-trainable params\n",
      "1.2 K     Total params\n",
      "0.005     Total estimated model params size (MB)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "model.summarize()\n",
    "model.set_optim_params(optim_params=cfg['optimizer'],\n",
    "                       sched_params=cfg['scheduler'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c97e0086",
   "metadata": {},
   "source": [
    "# LOAD EVALUATOR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1aa7677",
   "metadata": {},
   "outputs": [],
   "source": [
    "from models._evaluator import MyEvaluator\n",
    "\n",
    "# Evaluator built from the intervention list reported by the training dataset,\n",
    "# then attached to the model.\n",
    "evaluator = MyEvaluator(model=model,\n",
    "                        intervention_list=data_module.train_dataset.get_intervention_list(),\n",
    "                        scaler=data_module.scaler\n",
    "                        )\n",
    "model.set_my_evaluator(evaluator=evaluator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f165b692-1069-453a-aa04-5a1266ae4df5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[({'x1': -1.02}, '-1_sigma'),\n",
       " ({'x1': 0.47}, '0.5_sigma'),\n",
       " ({'x1': 0.08}, '0.1_sigma'),\n",
       " ({'x1': 0.08}, '0.1_sigma'),\n",
       " ({'x1': 0.47}, '0.5_sigma'),\n",
       " ({'x1': 0.97}, '1_sigma')]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_module.train_dataset.get_intervention_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "781f3176",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Intervention name: -1_sigma\n",
      "\tx1: -1.02\n",
      "Intervention name: 0.5_sigma\n",
      "\tx1: 0.47\n",
      "Intervention name: 0.1_sigma\n",
      "\tx1: 0.08\n",
      "Intervention name: 0.1_sigma\n",
      "\tx1: 0.08\n",
      "Intervention name: 0.5_sigma\n",
      "\tx1: 0.47\n",
      "Intervention name: 1_sigma\n",
      "\tx1: 0.97\n"
     ]
    }
   ],
   "source": [
    "for intervention in data_module.train_dataset.get_intervention_list():\n",
    "    inter_dict, name = intervention\n",
    "    print(f'Intervention name: {name}')\n",
    "    for node_name, value in inter_dict.items():\n",
    "        print(f\"\\t{node_name}: {value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e512cc3e",
   "metadata": {},
   "source": [
    "# PREPARE TRAINING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c0a45fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is training activated? False\n",
      "Is loading activated? True\n"
     ]
    }
   ],
   "source": [
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "from pytorch_lightning.loggers.tensorboard import TensorBoardLogger\n",
    "\n",
    "\n",
    "is_training = False\n",
    "load = True\n",
    "\n",
    "print(f'Is training activated? {is_training}')\n",
    "print(f'Is loading activated? {load}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2290222",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Save dir: results/2nodes_5000_linear_lik_d_0.05_None/vaca/dgnn_elbo_8_8_16_4_normal_0.0_0.2_0.0_0.0_0/adam/0.005_0.9_0.999_1.2e-06_exp_lr_0.99/10\n"
     ]
    }
   ],
   "source": [
    "if yaml_file == '':\n",
    "    # Experiment folder: <root_dir>/<experiment folder>/<seed>[/<kfold_idx>].\n",
    "    save_dir_parts = [cfg['root_dir'], argtools.get_experiment_folder(cfg), str(cfg['seed'])]\n",
    "    # params3 is only read for the German dataset; `and` short-circuits for every other dataset.\n",
    "    if (cfg['dataset']['name'] in [Cte.GERMAN]) and (cfg['dataset']['params3']['train_kfold'] == True):\n",
    "        save_dir_parts.append(str(cfg['dataset']['params3']['kfold_idx']))\n",
    "    save_dir = argtools.mkdir(os.path.join(*save_dir_parts))\n",
    "else:\n",
    "    # os.path.dirname keeps the leading '/' of an absolute path and accepts a bare file name;\n",
    "    # splitting on '/' and re-joining dropped the root and raised TypeError without any '/'.\n",
    "    save_dir = os.path.dirname(yaml_file)\n",
    "print(f'Save dir: {save_dir}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f4b953",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ckpt_file: results/2nodes_5000_linear_lik_d_0.05_None/vaca/dgnn_elbo_8_8_16_4_normal_0.0_0.2_0.0_0.0_0/adam/0.005_0.9_0.999_1.2e-06_exp_lr_0.99/10/ckpt/checkpoint-epoch=19.ckpt\n"
     ]
    }
   ],
   "source": [
    "# TensorBoard logger; the flattened config is logged as hyperparameters.\n",
    "logger = TensorBoardLogger(save_dir=save_dir, name='logs', default_hp_metric=False)\n",
    "out = logger.log_hyperparams(argtools.flatten_cfg(cfg))\n",
    "\n",
    "save_dir_ckpt = argtools.mkdir(os.path.join(save_dir, 'ckpt'))\n",
    "# Only look for an existing checkpoint when loading was requested.\n",
    "ckpt_file = argtools.newest(save_dir_ckpt) if load else None\n",
    "callbacks = []\n",
    "\n",
    "print(f\"ckpt_file: {ckpt_file}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50da25b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loading from: results/2nodes_5000_linear_lik_d_0.05_None/vaca/dgnn_elbo_8_8_16_4_normal_0.0_0.2_0.0_0.0_0/adam/0.005_0.9_0.999_1.2e-06_exp_lr_0.99/10/ckpt/checkpoint-epoch=19.ckpt\n"
     ]
    }
   ],
   "source": [
    "# Callbacks and the Trainer are only needed when training.\n",
    "if is_training:\n",
    "    checkpoint = ModelCheckpoint(period=1,\n",
    "                                 monitor=model.monitor(),\n",
    "                                 mode=model.monitor_mode(),\n",
    "                                 save_top_k=1,\n",
    "                                 save_last=True,\n",
    "                                 filename='checkpoint-{epoch:02d}',\n",
    "                                 dirpath=save_dir_ckpt)\n",
    "    callbacks = [checkpoint]\n",
    "\n",
    "    if cfg['early_stopping']:\n",
    "        early_stopping = EarlyStopping(model.monitor(), mode=model.monitor_mode(), min_delta=0.0, patience=50)\n",
    "        callbacks.append(early_stopping)\n",
    "\n",
    "    # Build the Trainer exactly once (it used to be built a second time when resuming).\n",
    "    # Passing None means there is nothing to resume from.\n",
    "    trainer = pl.Trainer(logger=logger, callbacks=callbacks,\n",
    "                         resume_from_checkpoint=ckpt_file if load else None,\n",
    "                         **cfg['trainer'])\n",
    "\n",
    "if load:\n",
    "    if ckpt_file is None:\n",
    "        print(f'No ckpt files in {save_dir_ckpt}')\n",
    "    else:\n",
    "        print(f'\\nLoading from: {ckpt_file}')\n",
    "        if not is_training:\n",
    "            # Evaluation only: restore the model and re-attach the evaluator and sampler to it.\n",
    "            model = model.load_from_checkpoint(ckpt_file, **model_params)\n",
    "            evaluator.set_model(model)\n",
    "            model.set_my_evaluator(evaluator=evaluator)\n",
    "\n",
    "            if cfg['model']['name'] in [Cte.VACA_PIWAE, Cte.VACA, Cte.MCVAE]:\n",
    "                model.set_random_train_sampler(data_module.get_random_train_sampler())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3ba5327-cbe8-401b-a1db-108cf3c6f5d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = os.path.join(save_dir,\"logs\")\n",
    "\n",
    "# exist_ok=True avoids the check-then-create race of a separate os.path.exists() test.\n",
    "os.makedirs(path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb106f1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "if is_training:\n",
    "    trainer.fit(model, data_module)\n",
    "    # save_yaml(model.get_arguments(), file_path=os.path.join(save_dir, 'hparams_model.yaml'))\n",
    "    argtools.save_yaml(cfg, file_path=os.path.join(save_dir, 'hparams_full.yaml'))\n",
    "    # %% Testing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37cc92e9",
   "metadata": {},
   "source": [
    "# TESTING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f3d74fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n",
    "# params = int(sum([p.numel() for p in model_parameters]))\n",
    "\n",
    "# model.eval()\n",
    "# model.freeze()  # IMPORTANT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d04804",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# output_valid = model.evaluate(dataloader=data_module.val_dataloader(),\n",
    "#                               name='valid',\n",
    "#                               save_dir=save_dir,\n",
    "#                               plots=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb3ed4fe-6f78-4889-835f-4393e6ea7e76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# path = os.path.join(save_dir,\"images\")\n",
    "\n",
    "# if not os.path.exists(path):\n",
    "#     os.makedirs(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "221c3eac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# output_test = model.evaluate(dataloader=data_module.test_dataloader(),\n",
    "#                              name='test',\n",
    "#                              save_dir=save_dir,\n",
    "#                              plots=True)\n",
    "# output_valid.update(output_test)\n",
    "\n",
    "# output_valid.update(argtools.flatten_cfg(cfg))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e11558e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import json\n",
    "# output_valid.update({'ckpt_file': ckpt_file,\n",
    "#                      'num_parameters': params})\n",
    "\n",
    "# with open(os.path.join(save_dir, 'output.json'), 'w') as f:\n",
    "#     json.dump(output_valid, f)\n",
    "# print(f'Experiment folder: {save_dir}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47fe8a36",
   "metadata": {},
   "source": [
    "# Custom interventions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "95583513",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(x=[6, 1], edge_index=[2, 9], edge_attr=[9, 3], u=[3, 2], mask=[6, 1], node_ids=[6, 2], x_i=[6, 1], edge_index_i=[2, 9], edge_attr_i=[9, 3], num_nodes=6, batch=[6], ptr=[4])\n"
     ]
    }
   ],
   "source": [
    "bs = data_module.batch_size\n",
    "x_I = {'x1': 10}  # Intervention before normalizing\n",
    "try:\n",
    "    # Temporarily use batches of 3 samples; the loader is rebuilt after the intervention is set.\n",
    "    data_module.batch_size = 3\n",
    "    data_loader = data_module.test_dataloader()\n",
    "    data_loader.dataset.set_intervention(x_I)\n",
    "    data_loader = data_module.test_dataloader()\n",
    "finally:\n",
    "    # Always restore the batch size, so a failure above cannot leave the override in the kernel.\n",
    "    data_module.batch_size = bs\n",
    "\n",
    "batch = next(iter(data_loader))\n",
    "\n",
    "\n",
    "\n",
    "print(batch)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "9ae239fd-effc-4da0-aaac-deecd677b06d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10.1029],\n",
       "        [ 2.2218],\n",
       "        [10.1029],\n",
       "        [ 0.0525],\n",
       "        [10.1029],\n",
       "        [ 1.2740]])"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch.x_i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "6dd4ffc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original: tensor([ 2.5131,  2.2218, -0.3770,  0.0525,  0.5278,  1.2740])\n",
      "Intervened: tensor([10.1029,  2.2218, 10.1029,  0.0525, 10.1029,  1.2740])\n",
      "Reconstructed: tensor([10.0055,  1.2454, 10.0055,  0.8707, 10.0055,  0.5242])\n"
     ]
    }
   ],
   "source": [
    "x_hat, z = model.get_intervention(batch,\n",
    "                         x_I=data_loader.dataset.x_I,\n",
    "                         nodes_list=data_loader.dataset.nodes_list,\n",
    "                         return_type = 'sample', # mean or sample\n",
    "                         use_aggregated_posterior = False,\n",
    "                         normalize = True)\n",
    "\n",
    "print(f\"Original: {batch.x.flatten()}\")\n",
    "print(f\"Intervened: {batch.x_i.flatten()}\")\n",
    "print(f\"Reconstructed: {x_hat.flatten()}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "459b0cfc",
   "metadata": {},
   "source": [
    "# Custom counterfactuals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a43a1a99-6c4f-487b-9021-ddbead5550a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "f9e24fb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/prchao/miniconda3/envs/vaca/lib/python3.9/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
      "  warnings.warn(out)\n"
     ]
    }
   ],
   "source": [
    "bs = data_module.batch_size\n",
    "x_I = {'x1': 10.0}  # Intervention before normalizing\n",
    "try:\n",
    "    # Temporarily use batches of 1 sample; the loader is rebuilt after the intervention is set.\n",
    "    data_module.batch_size = 1\n",
    "    data_loader = data_module.test_dataloader()\n",
    "    data_loader.dataset.set_intervention(x_I,is_noise=False)\n",
    "    data_loader = data_module.test_dataloader()\n",
    "finally:\n",
    "    # Always restore the batch size, so a failure above cannot leave the override in the kernel.\n",
    "    data_module.batch_size = bs\n",
    "\n",
    "batch = next(iter(data_loader))\n",
    "\n",
    "\n",
    "\n",
    "# print(batch)\n",
    "# x1-> x2 \n",
    "# x2 = x1 + N(0,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "3b04e229-34d5-4c0c-9102-97a7f211b67a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataBatch(x=[2, 1], edge_index=[2, 3], edge_attr=[3, 3], u=[1, 2], mask=[2, 1], node_ids=[2, 2], x_i=[2, 1], edge_index_i=[2, 3], edge_attr_i=[3, 3], num_nodes=2, batch=[2], ptr=[2])"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "453f81ba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[9.9033, 3.1029],\n",
       "         [9.9033, 0.7779]]),\n",
       " {'all': tensor([[ 2.4700,  3.0570],\n",
       "          [-0.3974,  0.0362]])})"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_I = {'x1': 10.0}\n",
    "vaca_pred, gt_cf, factual = model.get_counterfactual_distr(data_loader,\n",
    "                                        x_I=x_I,\n",
    "                                        is_noise = False,\n",
    "                                        num_batches= 1,\n",
    "                                        normalize=False,\n",
    "                                        )\n",
    "\n",
    "vaca_pred['all'], factual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "c22079af-ef5e-4840-8c89-451fbc377a29",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'intervened': tensor([[10.],\n",
       "         [10.]]),\n",
       " 'children': tensor([[10.5870],\n",
       "         [10.4335]]),\n",
       " 'all': tensor([[10.0000, 10.5870],\n",
       "         [10.0000, 10.4335]])}"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gt_cf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68a60dc6-a9c9-440a-8556-79f8a71f582d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
