{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28f00531",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import utils.args_parser  as argtools\n",
    "import pytorch_lightning as pl\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6076a2c",
   "metadata": {},
   "source": [
    "# LOAD CONFIG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5ce904cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "use_custom_dataset = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "babcb0b2",
   "metadata": {},
   "source": [
    "### Option 1: Datasets from the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "89eee561",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not use_custom_dataset:\n",
    "    print('Using dataset from the paper')\n",
    "    dataset_file =  os.path.join('_params', 'dataset_adult.yaml')\n",
    "    model_file =   os.path.join('_params', 'model_carefl.yaml')\n",
    "    trainer_file =   os.path.join('_params', 'trainer.yaml')\n",
    "\n",
    "    yaml_file = ''\n",
    "    \n",
    "    if yaml_file == '':\n",
    "        cfg = argtools.parse_args(dataset_file)\n",
    "        cfg.update(argtools.parse_args(model_file))\n",
    "        cfg.update(argtools.parse_args(trainer_file))\n",
    "    else:\n",
    "        cfg = argtools.parse_args(yaml_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f11780b1",
   "metadata": {},
   "source": [
    "### Option 2: New dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bae3aa70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using custom dataset\n"
     ]
    }
   ],
   "source": [
    "if use_custom_dataset:\n",
    "    print('Using custom dataset')\n",
    "    model_file =   os.path.join('_params', 'model_carefl.yaml')\n",
    "    trainer_file =   os.path.join('_params', 'trainer.yaml')\n",
    "\n",
    "    yaml_file = ''\n",
    "    if yaml_file == '':\n",
    "        cfg = argtools.parse_args(model_file)\n",
    "        cfg.update(argtools.parse_args(trainer_file))\n",
    "    else:\n",
    "        cfg = argtools.parse_args(yaml_file)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "125e8421",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Config for new dataset\n",
    "\n",
    "cfg['dataset'] = {\n",
    "    'name': '2nodes',\n",
    "    'params1': {},\n",
    "    'params2': {}\n",
    "}\n",
    "\n",
    "cfg['dataset']['params1'] = {\n",
    "    'data_dir': '../Data',\n",
    "    'batch_size': 1000,\n",
    "    'num_workers': 0\n",
    "}\n",
    "\n",
    "cfg['dataset']['params2'] = {\n",
    "    'num_samples_tr': 5000,\n",
    "    'equations_type': 'linear',\n",
    "    'normalize': 'lik',\n",
    "    'likelihood_names': 'd',\n",
    "    'lambda_': 0.05,\n",
    "    'normalize_A': None,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "194f6abc",
   "metadata": {},
   "source": [
    "### You can also update any parameter manually"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "09a87058",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 10\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimizer:\n",
      "\tname: adam\n",
      "\tparams:\n",
      "\t\tlr: 0.005\n",
      "\t\tbetas: [0.9, 0.999]\n",
      "\t\tweight_decay: 1.2e-06\n",
      "scheduler:\n",
      "\tname: exp_lr\n",
      "\tparams:\n",
      "\t\tgamma: 0.99\n",
      "model:\n",
      "\tname: carefl\n",
      "\tparams:\n",
      "\t\tflow_architecture: spline\n",
      "\t\tflow_net_class: mlp\n",
      "\t\tdistr_z: normal\n",
      "\t\tn_layers: 4\n",
      "\t\tn_hidden: 10\n",
      "seed: 10\n",
      "root_dir: results\n",
      "early_stopping: True\n",
      "trainer:\n",
      "\tmax_epochs: 200\n",
      "\tmin_epochs: 50\n",
      "\tlimit_train_batches: 1.0\n",
      "\tlimit_val_batches: 1.0\n",
      "\tlimit_test_batches: 1.0\n",
      "\tcheck_val_every_n_epoch: 10\n",
      "\tprogress_bar_refresh_rate: 1\n",
      "\tflush_logs_every_n_steps: 100\n",
      "\tlog_every_n_steps: 2\n",
      "\tprecision: 32\n",
      "\tterminate_on_nan: True\n",
      "\tauto_select_gpus: True\n",
      "\tdeterministic: True\n",
      "\tweights_summary: None\n",
      "\tgpus: None\n",
      "\tnum_sanity_val_steps: 2\n",
      "\ttrack_grad_norm: -1\n",
      "\tgradient_clip_val: 0.0\n",
      "dataset:\n",
      "\tname: 2nodes\n",
      "\tparams1:\n",
      "\t\tdata_dir: ../Data\n",
      "\t\tbatch_size: 1000\n",
      "\t\tnum_workers: 0\n",
      "\tparams2:\n",
      "\t\tnum_samples_tr: 5000\n",
      "\t\tequations_type: linear\n",
      "\t\tnormalize: lik\n",
      "\t\tlikelihood_names: d\n",
      "\t\tlambda_: 0.05\n",
      "\t\tnormalize_A: None\n",
      "\tparams:\n",
      "\t\tdata_dir: \n",
      "\t\tbatch_size: 1000\n",
      "\t\tnum_workers: 0\n",
      "\t\tnum_samples_tr: 5000\n",
      "\t\tequations_type: linear\n",
      "\t\tnormalize: lik\n",
      "\t\tlikelihood_names: d\n",
      "\t\tlambda_: 0.05\n",
      "\t\tnormalize_A: None\n"
     ]
    }
   ],
   "source": [
    "    \n",
    "cfg['root_dir'] = 'results'\n",
    "cfg['seed'] = 10\n",
    "pl.seed_everything(cfg['seed'])\n",
    "\n",
    "cfg['dataset']['params'] = cfg['dataset']['params1'].copy()\n",
    "cfg['dataset']['params'].update(cfg['dataset']['params2'])\n",
    "\n",
    "cfg['dataset']['params']['data_dir'] = ''\n",
    "\n",
    "cfg['trainer']['limit_train_batches'] = 1.0\n",
    "cfg['trainer']['limit_val_batches'] = 1.0\n",
    "cfg['trainer']['limit_test_batches'] = 1.0\n",
    "cfg['trainer']['check_val_every_n_epoch'] = 10\n",
    "\n",
    "\n",
    "def print_if_not_dict(key, value, extra=''):\n",
    "    if not isinstance(value, dict):\n",
    "        print(f\"{extra}{key}: {value}\")\n",
    "        return True\n",
    "    else:\n",
    "        print(f\"{extra}{key}:\")\n",
    "        False\n",
    "        \n",
    "for key, value in cfg.items():\n",
    "    if not print_if_not_dict(key, value):\n",
    "        for key2, value2 in value.items():\n",
    "            if not print_if_not_dict(key2, value2, extra='\\t'):\n",
    "                for key3, value3 in value2.items():\n",
    "                    print_if_not_dict(key3, value3, extra='\\t\\t')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb1b2408",
   "metadata": {},
   "source": [
    "# LOAD DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e528e48a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "These are datasets from the paper:\n",
      "\tcollider\n",
      "\ttriangle\n",
      "\tloan\n",
      "\tmgraph\n",
      "\tchain\n",
      "\tadult\n",
      "\tgerman\n",
      "\n",
      "Using dataset: 2nodes\n"
     ]
    }
   ],
   "source": [
    "from utils.constants import Cte\n",
    "\n",
    "\n",
    "print('These are datasets from the paper:')\n",
    "for data_name in Cte.DATASET_LIST:\n",
    "    print(f\"\\t{data_name}\")\n",
    "    \n",
    "\n",
    "\n",
    "print(f\"\\nUsing dataset: {cfg['dataset']['name']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8ce78164-21e0-4d4f-8513-3d737ffd81c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os\n",
    "\n",
    "# import pytorch_lightning as pl\n",
    "# import torch\n",
    "# from sklearn import preprocessing\n",
    "# import torch_geometric\n",
    "# #from torch_geometric.data import DataLoader\n",
    "# # from torch_geometric.utils import degree\n",
    "# # from torchvision import transforms as transform_lib\n",
    "\n",
    "# # from data_modules._scalers import MaskedTensorLikelihoodScaler\n",
    "# # from data_modules._scalers import MaskedTensorStandardScaler\n",
    "# # from datasets.transforms import ToTensor\n",
    "# # from utils.constants import Cte\n",
    "\n",
    "\n",
    "\n",
    "# # from datasets.toy import create_toy_dataset\n",
    "# # from utils.distributions import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e63f7e0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if cfg['dataset']['name'] in Cte.DATASET_LIST:\n",
    "    from data_modules.het_scm import HeterogeneousSCMDataModule\n",
    "\n",
    "    dataset_params = cfg['dataset']['params'].copy()\n",
    "    dataset_params['dataset_name'] = cfg['dataset']['name']\n",
    "\n",
    "    data_module = HeterogeneousSCMDataModule(**dataset_params)\n",
    "    data_module.prepare_data()\n",
    "\n",
    "elif cfg['dataset']['name']  == '2nodes':\n",
    "    from data_modules.my_toy_scm import MyToySCMDataModule\n",
    "    from utils.distributions import *\n",
    "    \n",
    "    dataset_params = cfg['dataset']['params'].copy()\n",
    "    dataset_params['dataset_name'] = cfg['dataset']['name']\n",
    "    \n",
    "    dataset_params['nodes_to_intervene'] = ['x1']\n",
    "    dataset_params['structural_eq'] = {'x1': lambda u1: u1,\n",
    "                                            'x2': lambda u2, x1: u2 + x1}\n",
    "    dataset_params['noises_distr'] = {'x1': Normal(0,1),\n",
    "                                           'x2': Normal(0,1)}\n",
    "    dataset_params['adj_edges'] = {'x1': ['x2'],\n",
    "                                        'x2': []}\n",
    "    \n",
    "    data_module = MyToySCMDataModule(**dataset_params)\n",
    "    data_module.prepare_data()\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93320b26",
   "metadata": {},
   "source": [
    "# LOAD MODEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cf2dc585-3059-4823-8cb6-622a0401516f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/prchao/miniconda3/envs/vaca/lib/python3.9/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
      "  warnings.warn(out)\n",
      "/workplace/prchao/VACA/datasets/transforms.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  return torch.tensor(sample)\n"
     ]
    }
   ],
   "source": [
    "data_loader = data_module.train_dataloader()\n",
    "\n",
    "#data_module.batch_size = bs\n",
    "\n",
    "batch = next(iter(data_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fd39c7d5-f374-41cf-8467-9be19e5ebd17",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataBatch(x=[2000, 1], edge_index=[2, 3000], edge_attr=[3000, 3], u=[1000, 2], mask=[2000, 1], node_ids=[2000, 2], num_nodes=2000, batch=[2000], ptr=[1001])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "14e9ca3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Using model: carefl\n"
     ]
    }
   ],
   "source": [
    "model = None\n",
    "model_params = cfg['model']['params'].copy()\n",
    "\n",
    "print(f\"\\nUsing model: {cfg['model']['name']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "dcb87ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# VACA\n",
    "if cfg['model']['name'] == Cte.VACA:\n",
    "    from models.vaca.vaca import VACA\n",
    "\n",
    "    model_params['is_heterogeneous'] = data_module.is_heterogeneous\n",
    "    model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "    model_params['deg'] = data_module.get_deg(indegree=True)\n",
    "    model_params['num_nodes'] = data_module.num_nodes\n",
    "    model_params['edge_dim'] = data_module.edge_dimension\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "\n",
    "    model = VACA(**model_params)\n",
    "    model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "# VACA with PIWAE\n",
    "elif cfg['model']['name'] == Cte.VACA_PIWAE:\n",
    "    from models.vaca.vaca_piwae import VACA_PIWAE\n",
    "\n",
    "    model_params['is_heterogeneous'] = data_module.is_heterogeneous\n",
    "\n",
    "    model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "    model_params['deg'] = data_module.get_deg(indegree=True)\n",
    "    model_params['num_nodes'] = data_module.num_nodes\n",
    "    model_params['edge_dim'] = data_module.edge_dimension\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "\n",
    "    model = VACA_PIWAE(**model_params)\n",
    "    model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "\n",
    "\n",
    "\n",
    "# MultiCVAE\n",
    "elif cfg['model']['name'] == Cte.MCVAE:\n",
    "    from models.multicvae.multicvae import MCVAE\n",
    "\n",
    "    model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "    model_params['topological_node_dims'] = data_module.train_dataset.get_node_columns_in_X()\n",
    "    model_params['topological_parents'] = data_module.topological_parents\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "    model_params['num_epochs_per_nodes'] = int(\n",
    "        np.floor((cfg['trainer']['max_epochs'] / len(data_module.topological_nodes))))\n",
    "    model = MCVAE(**model_params)\n",
    "    model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "    cfg['early_stopping'] = False\n",
    "\n",
    "# CAREFL\n",
    "elif cfg['model']['name'] == Cte.CARELF:\n",
    "    from models.carefl.carefl import CAREFL\n",
    "\n",
    "    model_params['node_per_dimension_list'] = data_module.train_dataset.node_per_dimension_list\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "    model = CAREFL(**model_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "75a391bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/prchao/miniconda3/envs/vaca/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py:1726: LightningDeprecationWarning: Argument `mode` in `LightningModule.summarize` is deprecated in v1.4 and will be removed in v1.6. Use `max_depth=1` to replicate `mode=top` behavior.\n",
      "  rank_zero_deprecation(\n",
      "\n",
      "  | Name       | Type                 | Params\n",
      "----------------------------------------------------\n",
      "0 | flow_model | NormalizingFlowModel | 752   \n",
      "----------------------------------------------------\n",
      "752       Trainable params\n",
      "0         Non-trainable params\n",
      "752       Total params\n",
      "0.003     Total estimated model params size (MB)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "model.summarize()\n",
    "model.set_optim_params(optim_params=cfg['optimizer'],\n",
    "                       sched_params=cfg['scheduler'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c97e0086",
   "metadata": {},
   "source": [
    "# LOAD EVALUATOR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a1aa7677",
   "metadata": {},
   "outputs": [],
   "source": [
    "from models._evaluator import MyEvaluator\n",
    "\n",
    "evaluator = MyEvaluator(model=model,\n",
    "                        intervention_list=data_module.train_dataset.get_intervention_list(),\n",
    "                        scaler=data_module.scaler\n",
    "                        )\n",
    "model.set_my_evaluator(evaluator=evaluator)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f165b692-1069-453a-aa04-5a1266ae4df5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[({'x1': -1.02}, '-1_sigma'),\n",
       " ({'x1': -0.52}, '-0.5_sigma'),\n",
       " ({'x1': -0.02}, '0_sigma'),\n",
       " ({'x1': 0.47}, '0.5_sigma'),\n",
       " ({'x1': 0.97}, '1_sigma')]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_module.train_dataset.get_intervention_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "781f3176",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Intervention name: -1_sigma\n",
      "\tx1: -1.02\n",
      "Intervention name: -0.5_sigma\n",
      "\tx1: -0.52\n",
      "Intervention name: 0_sigma\n",
      "\tx1: -0.02\n",
      "Intervention name: 0.5_sigma\n",
      "\tx1: 0.47\n",
      "Intervention name: 1_sigma\n",
      "\tx1: 0.97\n"
     ]
    }
   ],
   "source": [
    "for intervention in data_module.train_dataset.get_intervention_list():\n",
    "    inter_dict, name = intervention\n",
    "    print(f'Intervention name: {name}')\n",
    "    for node_name, value in inter_dict.items():\n",
    "        print(f\"\\t{node_name}: {value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e512cc3e",
   "metadata": {},
   "source": [
    "# PREPARE TRAINING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8c0a45fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is training activated? True\n",
      "Is loading activated? True\n"
     ]
    }
   ],
   "source": [
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "from pytorch_lightning.loggers.tensorboard import TensorBoardLogger\n",
    "\n",
    "\n",
    "is_training = True\n",
    "load = True\n",
    "\n",
    "print(f'Is training activated? {is_training}')\n",
    "print(f'Is loading activated? {load}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a2290222",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Save dir: results/2nodes_5000_linear_lik_d_0.05_None/carefl/spline_mlp_normal_4_10/adam/0.005_0.9_0.999_1.2e-06_exp_lr_0.99/10\n"
     ]
    }
   ],
   "source": [
    "if yaml_file == '':\n",
    "    if (cfg['dataset']['name'] in [Cte.GERMAN]) and (cfg['dataset']['params3']['train_kfold'] == True):\n",
    "        save_dir = argtools.mkdir(os.path.join(cfg['root_dir'],\n",
    "                                               argtools.get_experiment_folder(cfg),\n",
    "                                               str(cfg['seed']), str(cfg['dataset']['params3']['kfold_idx'])))\n",
    "    else:\n",
    "        save_dir = argtools.mkdir(os.path.join(cfg['root_dir'],\n",
    "                                               argtools.get_experiment_folder(cfg),\n",
    "                                               str(cfg['seed'])))\n",
    "else:\n",
    "    save_dir = os.path.join(*yaml_file.split('/')[:-1])\n",
    "print(f'Save dir: {save_dir}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "66f4b953",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ckpt_file: None\n"
     ]
    }
   ],
   "source": [
    "logger = TensorBoardLogger(save_dir=save_dir, name='logs', default_hp_metric=False)\n",
    "\n",
    "out = logger.log_hyperparams(argtools.flatten_cfg(cfg))\n",
    "\n",
    "save_dir_ckpt = argtools.mkdir(os.path.join(save_dir, 'ckpt'))\n",
    "if load:\n",
    "    ckpt_file = argtools.newest(save_dir_ckpt)\n",
    "else:\n",
    "    ckpt_file = None\n",
    "callbacks = []\n",
    "\n",
    "print(f\"ckpt_file: {ckpt_file}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "b50da25b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/prchao/miniconda3/envs/vaca/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:487: LightningDeprecationWarning: Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5. Please use `every_n_epochs` instead.\n",
      "  rank_zero_deprecation(\n",
      "GPU available: True, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No ckpt files in results/2nodes_5000_linear_lik_d_0.05_None/carefl/spline_mlp_normal_4_10/adam/0.005_0.9_0.999_1.2e-06_exp_lr_0.99/10/ckpt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/prchao/miniconda3/envs/vaca/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1303: UserWarning: GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.\n",
      "  rank_zero_warn(\n"
     ]
    }
   ],
   "source": [
    "if is_training:\n",
    "    checkpoint = ModelCheckpoint(period=1,\n",
    "                                 monitor=model.monitor(),\n",
    "                                 mode=model.monitor_mode(),\n",
    "                                 save_top_k=1,\n",
    "                                 save_last=True,\n",
    "                                 filename='checkpoint-{epoch:02d}',\n",
    "                                 dirpath=save_dir_ckpt)\n",
    "    callbacks = [checkpoint]\n",
    "\n",
    "    \n",
    "    if cfg['early_stopping']:\n",
    "        early_stopping = EarlyStopping(model.monitor(), mode=model.monitor_mode(), min_delta=0.0, patience=50)\n",
    "        callbacks.append(early_stopping)\n",
    "    trainer = pl.Trainer(logger=logger, callbacks=callbacks, **cfg['trainer'])\n",
    "    \n",
    "if load:\n",
    "    if ckpt_file is None:\n",
    "        print(f'No ckpt files in {save_dir_ckpt}')\n",
    "    else:\n",
    "        print(f'\\nLoading from: {ckpt_file}')\n",
    "        if is_training:\n",
    "            trainer = pl.Trainer(logger=logger, callbacks=callbacks, resume_from_checkpoint=ckpt_file,\n",
    "                             **cfg['trainer'])\n",
    "        else:\n",
    "\n",
    "            model = model.load_from_checkpoint(ckpt_file, **model_params)\n",
    "            evaluator.set_model(model)\n",
    "            model.set_my_evaluator(evaluator=evaluator)\n",
    "\n",
    "            if cfg['model']['name'] in [Cte.VACA_PIWAE, Cte.VACA, Cte.MCVAE]:\n",
    "                model.set_random_train_sampler(data_module.get_random_train_sampler())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "e3ba5327-cbe8-401b-a1db-108cf3c6f5d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = os.path.join(save_dir,\"logs\")\n",
    "\n",
    "if not os.path.exists(path):\n",
    "    os.makedirs(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "eb106f1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation sanity check: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/prchao/miniconda3/envs/vaca/lib/python3.9/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
      "  warnings.warn(out)\n",
      "/home/prchao/miniconda3/envs/vaca/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:105: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  rank_zero_warn(\n",
      "/workplace/prchao/VACA/datasets/transforms.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  return torch.tensor(sample)\n",
      "/workplace/prchao/VACA/data_modules/_scalers.py:82: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  x_norm = torch.tensor(x_norm)\n",
      "Global seed set to 10\n",
      "/home/prchao/miniconda3/envs/vaca/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:105: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2a288a95bfd428ab73bf62685fc42f4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: -1it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/prchao/miniconda3/envs/vaca/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:283: LightningDeprecationWarning: The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3. `outputs` parameter has been deprecated. Support for the old signature will be removed in v1.5\n",
      "  self._warning_cache.deprecation(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving yaml: results/2nodes_5000_linear_lik_d_0.05_None/carefl/spline_mlp_normal_4_10/adam/0.005_0.9_0.999_1.2e-06_exp_lr_0.99/10/hparams_full.yaml\n"
     ]
    }
   ],
   "source": [
    "if is_training:\n",
    "    trainer.fit(model, data_module)\n",
    "    # save_yaml(model.get_arguments(), file_path=os.path.join(save_dir, 'hparams_model.yaml'))\n",
    "    argtools.save_yaml(cfg, file_path=os.path.join(save_dir, 'hparams_full.yaml'))\n",
    "    # %% Testing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37cc92e9",
   "metadata": {},
   "source": [
    "# TESTING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "9f3d74fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n",
    "# params = int(sum([p.numel() for p in model_parameters]))\n",
    "\n",
    "# model.eval()\n",
    "# model.freeze()  # IMPORTANT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "a3d04804",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# output_valid = model.evaluate(dataloader=data_module.val_dataloader(),\n",
    "#                               name='valid',\n",
    "#                               save_dir=save_dir,\n",
    "#                               plots=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "cb3ed4fe-6f78-4889-835f-4393e6ea7e76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# path = os.path.join(save_dir,\"images\")\n",
    "\n",
    "# if not os.path.exists(path):\n",
    "#     os.makedirs(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "221c3eac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# output_test = model.evaluate(dataloader=data_module.test_dataloader(),\n",
    "#                              name='test',\n",
    "#                              save_dir=save_dir,\n",
    "#                              plots=True)\n",
    "# output_valid.update(output_test)\n",
    "\n",
    "# output_valid.update(argtools.flatten_cfg(cfg))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "e11558e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import json\n",
    "# output_valid.update({'ckpt_file': ckpt_file,\n",
    "#                      'num_parameters': params})\n",
    "\n",
    "# with open(os.path.join(save_dir, 'output.json'), 'w') as f:\n",
    "#     json.dump(output_valid, f)\n",
    "# print(f'Experiment folder: {save_dir}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47fe8a36",
   "metadata": {},
   "source": [
    "# Custom interventions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "95583513",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(x=[6, 1], edge_index=[2, 9], edge_attr=[9, 3], u=[3, 2], mask=[6, 1], node_ids=[6, 2], x_i=[6, 1], edge_index_i=[2, 9], edge_attr_i=[9, 3], num_nodes=6, batch=[6], ptr=[4])\n"
     ]
    }
   ],
   "source": [
    "# Intervention do(x1 = 10); the value is given before normalizing.\n",
    "x_I = {'x1': 10}\n",
    "\n",
    "# Temporarily use a batch size of 3 while building the loader. The try/finally\n",
    "# restores the original value even if a call below raises; otherwise a failed run\n",
    "# would leave batch_size at 3 and a re-run would save that as the value to restore.\n",
    "bs = data_module.batch_size\n",
    "data_module.batch_size = 3\n",
    "try:\n",
    "    data_loader = data_module.test_dataloader()\n",
    "    data_loader.dataset.set_intervention(x_I)\n",
    "    # Request the loader again after the intervention is registered on the dataset.\n",
    "    data_loader = data_module.test_dataloader()\n",
    "finally:\n",
    "    data_module.batch_size = bs\n",
    "\n",
    "batch = next(iter(data_loader))\n",
    "print(batch)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "9ae239fd-effc-4da0-aaac-deecd677b06d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10.1029],\n",
       "        [ 2.2218],\n",
       "        [10.1029],\n",
       "        [ 0.0525],\n",
       "        [10.1029],\n",
       "        [ 1.2740]])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch.x_i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dd4ffc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-batch interventional prediction.\n",
    "# Not every model class implements `get_intervention`: running this cell with the\n",
    "# CAREFL model raised AttributeError. Check for the method first so that\n",
    "# Restart & Run All is not aborted here.\n",
    "if hasattr(model, 'get_intervention'):\n",
    "    x_hat, z = model.get_intervention(batch,\n",
    "                                      x_I=data_loader.dataset.x_I,\n",
    "                                      nodes_list=data_loader.dataset.nodes_list,\n",
    "                                      return_type='sample',  # 'mean' or 'sample'\n",
    "                                      use_aggregated_posterior=False,\n",
    "                                      normalize=True)\n",
    "\n",
    "    print(f\"Original: {batch.x.flatten()}\")\n",
    "    print(f\"Intervened: {batch.x_i.flatten()}\")\n",
    "    print(f\"Reconstructed: {x_hat.flatten()}\")\n",
    "else:\n",
    "    print(f\"{type(model).__name__} does not implement get_intervention; skipping this cell.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "459b0cfc",
   "metadata": {},
   "source": [
    "# Custom counterfactuals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a43a1a99-6c4f-487b-9021-ddbead5550a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9e24fb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Intervention do(x1 = 10.0); the value is given before normalizing.\n",
    "# Data-generating process noted by the author: x1 -> x2, with x2 = x1 + N(0,1).\n",
    "x_I = {'x1': 10.0}\n",
    "\n",
    "# Temporarily use a batch size of 1 while building the loader. The try/finally\n",
    "# restores the original value even if a call below raises; otherwise a failed run\n",
    "# would leave batch_size at 1 and a re-run would save that as the value to restore.\n",
    "bs = data_module.batch_size\n",
    "data_module.batch_size = 1\n",
    "try:\n",
    "    data_loader = data_module.test_dataloader()\n",
    "    data_loader.dataset.set_intervention(x_I, is_noise=False)\n",
    "    # Request the loader again after the intervention is registered on the dataset.\n",
    "    data_loader = data_module.test_dataloader()\n",
    "finally:\n",
    "    data_module.batch_size = bs\n",
    "\n",
    "batch = next(iter(data_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b04e229-34d5-4c0c-9102-97a7f211b67a",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "453f81ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counterfactual query for do(x1 = 10.0), run through the model on `data_loader`.\n",
    "x_I = {'x1': 10.0}\n",
    "cf_options = dict(x_I=x_I, is_noise=False, num_batches=1, normalize=False)\n",
    "vaca_pred, gt_cf, factual = model.get_counterfactual_distr(data_loader, **cf_options)\n",
    "\n",
    "# Show the 'all' entry of the prediction next to `factual`.\n",
    "vaca_pred['all'], factual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c22079af-ef5e-4840-8c89-451fbc377a29",
   "metadata": {},
   "outputs": [],
   "source": [
    "gt_cf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68a60dc6-a9c9-440a-8556-79f8a71f582d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
