{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28f00531",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import utils.args_parser  as argtools\n",
    "import pytorch_lightning as pl\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6076a2c",
   "metadata": {},
   "source": [
    "# LOAD CONFIG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ce904cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toggle read by the two config cells below:\n",
    "# False -> Option 1 (dataset from the paper), True -> Option 2 (new dataset)\n",
    "use_custom_dataset = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "babcb0b2",
   "metadata": {},
   "source": [
    "### Option 1: Datasets from the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89eee561",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not use_custom_dataset:\n",
    "    print('Using dataset from the paper')\n",
    "    dataset_file = os.path.join('_params', 'dataset_adult.yaml')\n",
    "    model_file = os.path.join('_params', 'model_vaca.yaml')\n",
    "    trainer_file = os.path.join('_params', 'trainer.yaml')\n",
    "\n",
    "    # Empty string: build cfg from the three files above.\n",
    "    # Anything else: parse that single YAML file instead.\n",
    "    yaml_file = ''\n",
    "\n",
    "    if yaml_file != '':\n",
    "        cfg = argtools.parse_args(yaml_file)\n",
    "    else:\n",
    "        # Dataset config first, then model and trainer merged in via cfg.update\n",
    "        cfg = argtools.parse_args(dataset_file)\n",
    "        cfg.update(argtools.parse_args(model_file))\n",
    "        cfg.update(argtools.parse_args(trainer_file))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f11780b1",
   "metadata": {},
   "source": [
    "### Option 2: New dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae3aa70",
   "metadata": {},
   "outputs": [],
   "source": [
    "if use_custom_dataset:\n",
    "    print('Using custom dataset')\n",
    "    model_file = os.path.join('_params', 'model_vaca.yaml')\n",
    "    trainer_file = os.path.join('_params', 'trainer.yaml')\n",
    "\n",
    "    # Empty string: build cfg from the model and trainer files above.\n",
    "    # Anything else: parse that single YAML file instead.\n",
    "    yaml_file = ''\n",
    "\n",
    "    if yaml_file != '':\n",
    "        cfg = argtools.parse_args(yaml_file)\n",
    "    else:\n",
    "        # No dataset file here: the dataset section is filled in by the next cell\n",
    "        cfg = argtools.parse_args(model_file)\n",
    "        cfg.update(argtools.parse_args(trainer_file))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "125e8421",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Config for new dataset.\n",
    "# Guarded by use_custom_dataset so that a full run with Option 1 keeps the\n",
    "# paper dataset section instead of having it overwritten by this cell.\n",
    "if use_custom_dataset:\n",
    "    cfg['dataset'] = {\n",
    "        'name': '2nodes',\n",
    "        # params1 and params2 are merged into cfg['dataset']['params'] further down\n",
    "        'params1': {\n",
    "            'data_dir': '../Data',\n",
    "            'batch_size': 1000,\n",
    "            'num_workers': 0,\n",
    "        },\n",
    "        'params2': {\n",
    "            'num_samples_tr': 5000,\n",
    "            'equations_type': 'linear',\n",
    "            'normalize': 'lik',\n",
    "            'likelihood_names': 'd',\n",
    "            'lambda_': 0.05,\n",
    "            'normalize_A': None,\n",
    "        },\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "194f6abc",
   "metadata": {},
   "source": [
    "### You can also update any parameter manually"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09a87058",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg['root_dir'] = 'results'\n",
    "cfg['seed'] = 0\n",
    "pl.seed_everything(cfg['seed'])\n",
    "\n",
    "# Merge both parameter groups into the single dict read by the data-module cell\n",
    "cfg['dataset']['params'] = cfg['dataset']['params1'].copy()\n",
    "cfg['dataset']['params'].update(cfg['dataset']['params2'])\n",
    "\n",
    "cfg['dataset']['params']['data_dir'] = ''\n",
    "\n",
    "cfg['trainer']['limit_train_batches'] = 1.0\n",
    "cfg['trainer']['limit_val_batches'] = 1.0\n",
    "cfg['trainer']['limit_test_batches'] = 1.0\n",
    "cfg['trainer']['check_val_every_n_epoch'] = 10\n",
    "\n",
    "\n",
    "def print_if_not_dict(key, value, extra=''):\n",
    "    \"\"\"Print one config entry and report whether it was a leaf value.\n",
    "\n",
    "    Args:\n",
    "        key: Name of the entry.\n",
    "        value: Value of the entry; a dict is treated as a nested section.\n",
    "        extra: Prefix written before the key (used for indentation).\n",
    "\n",
    "    Returns:\n",
    "        True when value is not a dict and was printed as 'key: value'.\n",
    "        False when value is a dict; only 'key:' is printed and the caller\n",
    "        is expected to iterate over the nested entries.\n",
    "    \"\"\"\n",
    "    if isinstance(value, dict):\n",
    "        print(f\"{extra}{key}:\")\n",
    "        # Explicit return: the bare expression used before made this branch return None\n",
    "        return False\n",
    "    print(f\"{extra}{key}: {value}\")\n",
    "    return True\n",
    "\n",
    "\n",
    "# Show the config, descending at most three levels\n",
    "for key, value in cfg.items():\n",
    "    if not print_if_not_dict(key, value):\n",
    "        for key2, value2 in value.items():\n",
    "            if not print_if_not_dict(key2, value2, extra='\\t'):\n",
    "                for key3, value3 in value2.items():\n",
    "                    print_if_not_dict(key3, value3, extra='\\t\\t')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb1b2408",
   "metadata": {},
   "source": [
    "# LOAD DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e528e48a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.constants import Cte\n",
    "\n",
    "\n",
    "# List every dataset name registered in Cte.DATASET_LIST\n",
    "print('These are datasets from the paper:')\n",
    "for data_name in Cte.DATASET_LIST:\n",
    "    print(f\"\\t{data_name}\")\n",
    "    \n",
    "\n",
    "\n",
    "# Dataset currently selected in cfg\n",
    "print(f\"\\nUsing dataset: {cfg['dataset']['name']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e63f7e0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both data modules receive the merged dataset parameters plus the dataset name\n",
    "dataset_params = cfg['dataset']['params'].copy()\n",
    "dataset_params['dataset_name'] = cfg['dataset']['name']\n",
    "\n",
    "if cfg['dataset']['name'] in Cte.DATASET_LIST:\n",
    "    from data_modules.het_scm import HeterogeneousSCMDataModule\n",
    "\n",
    "    data_module = HeterogeneousSCMDataModule(**dataset_params)\n",
    "\n",
    "elif cfg['dataset']['name'] == '2nodes':\n",
    "    from data_modules.my_toy_scm import MyToySCMDataModule\n",
    "    # Explicit import instead of a wildcard: Normal is the only name used here\n",
    "    from utils.distributions import Normal\n",
    "\n",
    "    # Two-node SCM: x1 = u1, x2 = u2 + x1, with the single edge x1 -> x2\n",
    "    dataset_params['nodes_to_intervene'] = ['x1']\n",
    "    dataset_params['structural_eq'] = {\n",
    "        'x1': lambda u1: u1,\n",
    "        'x2': lambda u2, x1: u2 + x1,\n",
    "    }\n",
    "    dataset_params['noises_distr'] = {\n",
    "        'x1': Normal(0, 1),\n",
    "        'x2': Normal(0, 1),\n",
    "    }\n",
    "    dataset_params['adj_edges'] = {\n",
    "        'x1': ['x2'],\n",
    "        'x2': [],\n",
    "    }\n",
    "\n",
    "    data_module = MyToySCMDataModule(**dataset_params)\n",
    "\n",
    "else:\n",
    "    # Fail here with a clear message rather than with a NameError on data_module later\n",
    "    raise ValueError(f\"Unknown dataset name: {cfg['dataset']['name']!r}\")\n",
    "\n",
    "data_module.prepare_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93320b26",
   "metadata": {},
   "source": [
    "# LOAD MODEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14e9ca3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model stays None until the next cell builds the model named in cfg\n",
    "model = None\n",
    "# Work on a copy so the entries added below do not end up in cfg\n",
    "model_params = cfg['model']['params'].copy()\n",
    "\n",
    "print(f\"\\nUsing model: {cfg['model']['name']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcb87ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# VACA and VACA with PIWAE: both take the same data-module-derived arguments\n",
    "if cfg['model']['name'] in (Cte.VACA, Cte.VACA_PIWAE):\n",
    "    model_params['is_heterogeneous'] = data_module.is_heterogeneous\n",
    "    model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "    model_params['deg'] = data_module.get_deg(indegree=True)\n",
    "    model_params['num_nodes'] = data_module.num_nodes\n",
    "    model_params['edge_dim'] = data_module.edge_dimension\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "\n",
    "    if cfg['model']['name'] == Cte.VACA:\n",
    "        from models.vaca.vaca import VACA\n",
    "\n",
    "        model = VACA(**model_params)\n",
    "    else:\n",
    "        from models.vaca.vaca_piwae import VACA_PIWAE\n",
    "\n",
    "        model = VACA_PIWAE(**model_params)\n",
    "    model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "\n",
    "# MultiCVAE\n",
    "elif cfg['model']['name'] == Cte.MCVAE:\n",
    "    from models.multicvae.multicvae import MCVAE\n",
    "\n",
    "    model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "    model_params['topological_node_dims'] = data_module.train_dataset.get_node_columns_in_X()\n",
    "    model_params['topological_parents'] = data_module.topological_parents\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "    # Split the epoch budget evenly across the nodes (rounded down)\n",
    "    model_params['num_epochs_per_nodes'] = int(\n",
    "        np.floor((cfg['trainer']['max_epochs'] / len(data_module.topological_nodes))))\n",
    "    model = MCVAE(**model_params)\n",
    "    model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "    # Early stopping is switched off for this model\n",
    "    cfg['early_stopping'] = False\n",
    "\n",
    "# CAREFL\n",
    "elif cfg['model']['name'] == Cte.CARELF:\n",
    "    from models.carefl.carefl import CAREFL\n",
    "\n",
    "    model_params['node_per_dimension_list'] = data_module.train_dataset.node_per_dimension_list\n",
    "    model_params['scaler'] = data_module.scaler\n",
    "    model = CAREFL(**model_params)\n",
    "\n",
    "else:\n",
    "    # Fail here rather than with an AttributeError on model = None in a later cell\n",
    "    raise ValueError(f\"Unknown model name: {cfg['model']['name']!r}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75a391bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Show the model summary, then hand the optimizer and scheduler sections of cfg to the model\n",
    "model.summarize()\n",
    "model.set_optim_params(optim_params=cfg['optimizer'],\n",
    "                       sched_params=cfg['scheduler'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c97e0086",
   "metadata": {},
   "source": [
    "# LOAD EVALUATOR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1aa7677",
   "metadata": {},
   "outputs": [],
   "source": [
    "from models._evaluator import MyEvaluator\n",
    "\n",
    "# The evaluator gets the model, the interventions defined by the training dataset\n",
    "# and the data module's scaler; it is then attached back to the model\n",
    "evaluator = MyEvaluator(model=model,\n",
    "                        intervention_list=data_module.train_dataset.get_intervention_list(),\n",
    "                        scaler=data_module.scaler\n",
    "                        )\n",
    "model.set_my_evaluator(evaluator=evaluator)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "781f3176",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each entry of the intervention list is a (node -> value dict, name) pair\n",
    "for inter_dict, name in data_module.train_dataset.get_intervention_list():\n",
    "    print(f'Intervention name: {name}')\n",
    "    for node_name, value in inter_dict.items():\n",
    "        print(f\"\\t{node_name}: {value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e512cc3e",
   "metadata": {},
   "source": [
    "# PREPARE TRAINING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c0a45fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "from pytorch_lightning.loggers.tensorboard import TensorBoardLogger\n",
    "\n",
    "\n",
    "# is_training: build a Trainer and call fit in the cells below\n",
    "# load: look for the newest checkpoint in the experiment folder and use it\n",
    "is_training = False\n",
    "load = True\n",
    "\n",
    "print(f'Is training activated? {is_training}')\n",
    "print(f'Is loading activated? {load}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2290222",
   "metadata": {},
   "outputs": [],
   "source": [
    "if yaml_file == '':\n",
    "    # Results go under <root_dir>/<experiment folder>/<seed>, with one extra\n",
    "    # <kfold_idx> level for the German dataset when train_kfold is enabled\n",
    "    if (cfg['dataset']['name'] in [Cte.GERMAN]) and (cfg['dataset']['params3']['train_kfold'] == True):\n",
    "        save_dir = argtools.mkdir(os.path.join(cfg['root_dir'],\n",
    "                                               argtools.get_experiment_folder(cfg),\n",
    "                                               str(cfg['seed']), str(cfg['dataset']['params3']['kfold_idx'])))\n",
    "    else:\n",
    "        save_dir = argtools.mkdir(os.path.join(cfg['root_dir'],\n",
    "                                               argtools.get_experiment_folder(cfg),\n",
    "                                               str(cfg['seed'])))\n",
    "else:\n",
    "    # cfg came from a single YAML file: reuse the folder that contains it.\n",
    "    # os.path.dirname keeps the leading separator of absolute paths, and a bare\n",
    "    # file name falls back to the current directory instead of raising.\n",
    "    save_dir = os.path.dirname(yaml_file) or os.curdir\n",
    "print(f'Save dir: {save_dir}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f4b953",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TensorBoard logs are written to <save_dir>/logs\n",
    "logger = TensorBoardLogger(save_dir=save_dir, name='logs', default_hp_metric=False)\n",
    "out = logger.log_hyperparams(argtools.flatten_cfg(cfg))\n",
    "\n",
    "# Checkpoints live in <save_dir>/ckpt; only look for one when loading was requested\n",
    "save_dir_ckpt = argtools.mkdir(os.path.join(save_dir, 'ckpt'))\n",
    "ckpt_file = argtools.newest(save_dir_ckpt) if load else None\n",
    "callbacks = []\n",
    "\n",
    "print(f\"ckpt_file: {ckpt_file}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50da25b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the Trainer when training is enabled, and/or use the checkpoint found above when loading is enabled\n",
    "if is_training:\n",
    "    # Keep the best checkpoint for the model's monitored metric plus the last one\n",
    "    checkpoint = ModelCheckpoint(period=1,\n",
    "                                 monitor=model.monitor(),\n",
    "                                 mode=model.monitor_mode(),\n",
    "                                 save_top_k=1,\n",
    "                                 save_last=True,\n",
    "                                 filename='checkpoint-{epoch:02d}',\n",
    "                                 dirpath=save_dir_ckpt)\n",
    "    callbacks = [checkpoint]\n",
    "\n",
    "    \n",
    "    if cfg['early_stopping']:\n",
    "        early_stopping = EarlyStopping(model.monitor(), mode=model.monitor_mode(), min_delta=0.0, patience=50)\n",
    "        callbacks.append(early_stopping)\n",
    "    trainer = pl.Trainer(logger=logger, callbacks=callbacks, **cfg['trainer'])\n",
    "    \n",
    "if load:\n",
    "    if ckpt_file is None:\n",
    "        print(f'No ckpt files in {save_dir_ckpt}')\n",
    "    else:\n",
    "        print(f'\\nLoading from: {ckpt_file}')\n",
    "        if is_training:\n",
    "            # NOTE(review): this replaces the Trainer built above with one that\n",
    "            # resumes from ckpt_file, so the first Trainer is discarded\n",
    "            trainer = pl.Trainer(logger=logger, callbacks=callbacks, resume_from_checkpoint=ckpt_file,\n",
    "                             **cfg['trainer'])\n",
    "        else:\n",
    "\n",
    "            # Evaluation only: rebuild the model from the checkpoint, then point the\n",
    "            # evaluator at the new model object and re-attach it\n",
    "            model = model.load_from_checkpoint(ckpt_file, **model_params)\n",
    "            evaluator.set_model(model)\n",
    "            model.set_my_evaluator(evaluator=evaluator)\n",
    "\n",
    "            # Same sampler setup as when these models were first built\n",
    "            if cfg['model']['name'] in [Cte.VACA_PIWAE, Cte.VACA, Cte.MCVAE]:\n",
    "                model.set_random_train_sampler(data_module.get_random_train_sampler())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb106f1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train, then store the full configuration next to the results\n",
    "if is_training:\n",
    "    trainer.fit(model, data_module)\n",
    "    # save_yaml(model.get_arguments(), file_path=os.path.join(save_dir, 'hparams_model.yaml'))\n",
    "    argtools.save_yaml(cfg, file_path=os.path.join(save_dir, 'hparams_full.yaml'))\n",
    "    # %% Testing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37cc92e9",
   "metadata": {},
   "source": [
    "# TESTING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f3d74fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of trainable parameters (only those with requires_grad set)\n",
    "params = int(sum(p.numel() for p in model.parameters() if p.requires_grad))\n",
    "\n",
    "# Evaluation mode with frozen weights for everything that follows\n",
    "model.eval()\n",
    "model.freeze()  # IMPORTANT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d04804",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the validation loader, without plots\n",
    "output_valid = model.evaluate(dataloader=data_module.val_dataloader(),\n",
    "                              name='valid',\n",
    "                              save_dir=save_dir,\n",
    "                              plots=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "221c3eac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the test loader, with plots\n",
    "output_test = model.evaluate(dataloader=data_module.test_dataloader(),\n",
    "                             name='test',\n",
    "                             save_dir=save_dir,\n",
    "                             plots=True)\n",
    "# output_valid is updated in place: test results first, then the flattened cfg\n",
    "output_valid.update(output_test)\n",
    "\n",
    "output_valid.update(argtools.flatten_cfg(cfg))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e11558e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# Record which checkpoint was used and the trainable-parameter count\n",
    "output_valid.update({'ckpt_file': ckpt_file,\n",
    "                     'num_parameters': params})\n",
    "\n",
    "# Everything collected so far is written to <save_dir>/output.json\n",
    "with open(os.path.join(save_dir, 'output.json'), 'w') as f:\n",
    "    json.dump(output_valid, f)\n",
    "print(f'Experiment folder: {save_dir}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47fe8a36",
   "metadata": {},
   "source": [
    "# Custom interventions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95583513",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Intervention before normalizing (another value to try: {'x1': 2.4721})\n",
    "x_I = {'x1': 0.0}\n",
    "\n",
    "# Temporarily switch the data module to batch_size = 1. The finally clause puts\n",
    "# the original value back even if one of the calls below raises, so the live\n",
    "# kernel is never left with the modified batch size.\n",
    "bs = data_module.batch_size\n",
    "data_module.batch_size = 1\n",
    "try:\n",
    "    data_loader = data_module.test_dataloader()\n",
    "    data_loader.dataset.set_intervention(x_I)\n",
    "    # NOTE(review): the loader is requested again after set_intervention,\n",
    "    # presumably so that it reflects the intervened dataset - confirm\n",
    "    data_loader = data_module.test_dataloader()\n",
    "finally:\n",
    "    data_module.batch_size = bs\n",
    "\n",
    "batch = next(iter(data_loader))\n",
    "\n",
    "print(batch)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dd4ffc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the intervention stored on the loader's dataset through the model\n",
    "x_hat, z = model.get_intervention(\n",
    "    batch,\n",
    "    x_I=data_loader.dataset.x_I,\n",
    "    nodes_list=data_loader.dataset.nodes_list,\n",
    "    return_type='sample',  # mean or sample\n",
    "    use_aggregated_posterior=False,\n",
    "    normalize=True,\n",
    ")\n",
    "\n",
    "print(f\"Original: {batch.x.flatten()}\")\n",
    "print(f\"Intervened: {batch.x_i.flatten()}\")\n",
    "print(f\"Reconstructed: {x_hat.flatten()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "459b0cfc",
   "metadata": {},
   "source": [
    "# Custom counterfactuals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9e24fb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Intervention before normalizing (another value to try: {'x1': 2.4721})\n",
    "x_I = {'x1': 0.0}\n",
    "\n",
    "# Temporarily switch the data module to batch_size = 1. The finally clause puts\n",
    "# the original value back even if one of the calls below raises, so the live\n",
    "# kernel is never left with the modified batch size.\n",
    "bs = data_module.batch_size\n",
    "data_module.batch_size = 1\n",
    "try:\n",
    "    data_loader = data_module.test_dataloader()\n",
    "    data_loader.dataset.set_intervention(x_I)\n",
    "    # NOTE(review): the loader is requested again after set_intervention,\n",
    "    # presumably so that it reflects the intervened dataset - confirm\n",
    "    data_loader = data_module.test_dataloader()\n",
    "finally:\n",
    "    data_module.batch_size = bs\n",
    "\n",
    "batch = next(iter(data_loader))\n",
    "\n",
    "print(batch)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "453f81ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the counterfactual for the intervention stored on the loader's dataset\n",
    "x_CF, z_factual, z_cf_I, z_dec = model.compute_counterfactual(\n",
    "    batch=batch,\n",
    "    x_I=data_loader.dataset.x_I,\n",
    "    nodes_list=data_loader.dataset.nodes_list,\n",
    "    normalize=True,\n",
    "    return_type='sample',\n",
    ")\n",
    "\n",
    "print(f\"Original: {batch.x.flatten()}\")\n",
    "print(f\"Counterfactual: {batch.x_i.flatten()}\")\n",
    "print(f\"Reconstructed: {x_CF.flatten()}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
