{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "b662aaf7-dedc-4838-9191-734661adfc91",
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "invalid syntax (2034102531.py, line 24)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  Input \u001b[0;32mIn [42]\u001b[0;36m\u001b[0m\n\u001b[0;31m    if cfg = None:\u001b[0m\n\u001b[0m           ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import utils.args_parser  as argtools\n",
    "import pytorch_lightning as pl\n",
    "import numpy as np\n",
    "from utils.constants import Cte\n",
    "from data_modules.my_toy_scm import MyToySCMDataModule\n",
    "from utils.distributions import *\n",
    "from data_modules.het_scm import HeterogeneousSCMDataModule\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "from pytorch_lightning.loggers.tensorboard import TensorBoardLogger\n",
    "from models._evaluator import MyEvaluator\n",
    "\n",
    "\n",
    "\n",
    "def create_data(n,\n",
    "                seed,\n",
    "                structural_equations,\n",
    "                noise_distributions,\n",
    "                graph,\n",
    "                name,\n",
    "                equations_type,\n",
    "                cfg = None,\n",
    "                new_data = False):\n",
    "    if cfg = None:\n",
    "        model_file =   os.path.join('_params', 'model_vaca.yaml')\n",
    "        trainer_file =   os.path.join('_params', 'trainer.yaml')\n",
    "\n",
    "\n",
    "        cfg = argtools.parse_args(model_file)\n",
    "        cfg.update(argtools.parse_args(trainer_file))\n",
    "        # Config for new dataset\n",
    "        cfg['dataset'] = {\n",
    "            'params1': {},\n",
    "            'params2': {}\n",
    "        }\n",
    "\n",
    "        cfg['dataset']['params1'] = {\n",
    "            'data_dir': '../Data',\n",
    "            'batch_size': 1000,\n",
    "            'num_workers': 0\n",
    "        }\n",
    "\n",
    "        cfg['dataset']['params2'] = {\n",
    "            'num_samples_tr': None,\n",
    "            'equations_type': 'linear',\n",
    "            'normalize': 'lik',\n",
    "            'likelihood_names': 'd',\n",
    "            'lambda_': 0.05,\n",
    "            'normalize_A': None,\n",
    "        }\n",
    "\n",
    "        cfg['root_dir'] = 'results'\n",
    "        cfg['seed'] = None\n",
    "        pl.seed_everything(cfg['seed'])\n",
    "\n",
    "        cfg['dataset']['params'] = cfg['dataset']['params1'].copy()\n",
    "        cfg['dataset']['params'].update(cfg['dataset']['params2'])\n",
    "\n",
    "        cfg['dataset']['params']['data_dir'] = ''\n",
    "\n",
    "        cfg['trainer']['limit_train_batches'] = 1.0\n",
    "        cfg['trainer']['limit_val_batches'] = 1.0\n",
    "        cfg['trainer']['limit_test_batches'] = 1.0\n",
    "        cfg['trainer']['check_val_every_n_epoch'] = 10\n",
    "        cfg['dataset']['params'] = equations_type \n",
    "        cfg['dataset']['name'] = name\n",
    "        cfg['dataset']['params']['num_samples_tr'] = n\n",
    "        cfg['seed'] = seed\n",
    "\n",
    "    intervene_nodes = []\n",
    "    adj_edges = {}\n",
    "    \n",
    "    for node in g.nodes:\n",
    "        if g.out_degree[node] > 0:\n",
    "            intervene_nodes.append(node)\n",
    "        adj_edges[node] = (list(g.neighbors(node)))\n",
    "    \n",
    "    dataset_params = cfg['dataset']['params'].copy()\n",
    "    dataset_params['dataset_name'] = cfg['dataset']['name']\n",
    "\n",
    "    dataset_params['nodes_to_intervene'] = intervene_nodes\n",
    "    dataset_params['structural_eq'] = structural_equations\n",
    "    dataset_params['noises_distr'] = noise_distributions\n",
    "    \n",
    "    dataset_params['adj_edges'] = adj_edges\n",
    "\n",
    "    data_module = MyToySCMDataModule(**dataset_params)\n",
    "    data_module.prepare_data(new_data = new_data)\n",
    "    return data_module,cfg\n",
    "\n",
    "def get_train_data(data_module):\n",
    "    orig_bs = data_module.batch_size\n",
    "    data_module.batch_size = 1\n",
    "    train = data_module.train_dataloader()\n",
    "    n_points = len(data_module.train_dataloader())\n",
    "    data_module.batch_size = n_points\n",
    "    train = data_module.train_dataloader()\n",
    "    batch = next(iter(train))\n",
    "    data_module.batch_size = orig_bs\n",
    "    return batch.x.view(n,-1)\n",
    "\n",
    "\n",
    "def create_vaca_model(cfg):\n",
    " \n",
    "    model = None\n",
    "    model_params = cfg['model']['params'].copy()\n",
    "\n",
    "    #print(f\"\\nUsing model: {cfg['model']['name']}\")\n",
    "\n",
    "\n",
    "    # VACA\n",
    "    if cfg['model']['name'] == Cte.VACA:\n",
    "        from models.vaca.vaca import VACA\n",
    "\n",
    "        model_params['is_heterogeneous'] = data_module.is_heterogeneous\n",
    "        model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "        model_params['deg'] = data_module.get_deg(indegree=True)\n",
    "        model_params['num_nodes'] = data_module.num_nodes\n",
    "        model_params['edge_dim'] = data_module.edge_dimension\n",
    "        model_params['scaler'] = data_module.scaler\n",
    "\n",
    "        model = VACA(**model_params)\n",
    "        model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "    # VACA with PIWAE\n",
    "    elif cfg['model']['name'] == Cte.VACA_PIWAE:\n",
    "        from models.vaca.vaca_piwae import VACA_PIWAE\n",
    "\n",
    "        model_params['is_heterogeneous'] = data_module.is_heterogeneous\n",
    "\n",
    "        model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "        model_params['deg'] = data_module.get_deg(indegree=True)\n",
    "        model_params['num_nodes'] = data_module.num_nodes\n",
    "        model_params['edge_dim'] = data_module.edge_dimension\n",
    "        model_params['scaler'] = data_module.scaler\n",
    "\n",
    "        model = VACA_PIWAE(**model_params)\n",
    "        model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "\n",
    "\n",
    "    # MultiCVAE\n",
    "    elif cfg['model']['name'] == Cte.MCVAE:\n",
    "        from models.multicvae.multicvae import MCVAE\n",
    "\n",
    "        model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "        model_params['topological_node_dims'] = data_module.train_dataset.get_node_columns_in_X()\n",
    "        model_params['topological_parents'] = data_module.topological_parents\n",
    "        model_params['scaler'] = data_module.scaler\n",
    "        model_params['num_epochs_per_nodes'] = int(\n",
    "            np.floor((cfg['trainer']['max_epochs'] / len(data_module.topological_nodes))))\n",
    "        model = MCVAE(**model_params)\n",
    "        model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "        cfg['early_stopping'] = False\n",
    "\n",
    "    # CAREFL\n",
    "    elif cfg['model']['name'] == Cte.CAREFL:\n",
    "        from models.carefl.carefl import CAREFL\n",
    "\n",
    "        model_params['node_per_dimension_list'] = data_module.train_dataset.node_per_dimension_list\n",
    "        model_params['scaler'] = data_module.scaler\n",
    "        model = CAREFL(**model_params)\n",
    "\n",
    "    model.set_optim_params(optim_params=cfg['optimizer'],\n",
    "                           sched_params=cfg['scheduler'])\n",
    "    return model\n",
    "\n",
    "def fit_vaca(model,cfg, data_module):\n",
    "    \n",
    "\n",
    "    is_training = True\n",
    "    load = True\n",
    "\n",
    "    print(f'Is training activated? {is_training}')\n",
    "    print(f'Is loading activated? {load}')\n",
    "    if yaml_file == '':\n",
    "        if (cfg['dataset']['name'] in [Cte.GERMAN]) and (cfg['dataset']['params3']['train_kfold'] == True):\n",
    "            save_dir = argtools.mkdir(os.path.join(cfg['root_dir'],\n",
    "                                                   argtools.get_experiment_folder(cfg),\n",
    "                                                   str(cfg['seed']), str(cfg['dataset']['params3']['kfold_idx'])))\n",
    "        else:\n",
    "            save_dir = argtools.mkdir(os.path.join(cfg['root_dir'],\n",
    "                                                   argtools.get_experiment_folder(cfg),\n",
    "                                                   str(cfg['seed'])))\n",
    "    else:\n",
    "        save_dir = os.path.join(*yaml_file.split('/')[:-1])\n",
    "\n",
    "    logger = TensorBoardLogger(save_dir=save_dir, name='logs', default_hp_metric=False)\n",
    "\n",
    "    out = logger.log_hyperparams(argtools.flatten_cfg(cfg))\n",
    "\n",
    "    save_dir_ckpt = argtools.mkdir(os.path.join(save_dir, 'ckpt'))\n",
    "    if load:\n",
    "        ckpt_file = argtools.newest(save_dir_ckpt)\n",
    "    else:\n",
    "        ckpt_file = None\n",
    "    callbacks = []\n",
    "\n",
    "\n",
    "\n",
    "    evaluator = MyEvaluator(model=model,\n",
    "                            intervention_list=data_module.train_dataset.get_intervention_list(),\n",
    "                            scaler=data_module.scaler\n",
    "                            )\n",
    "    model.set_my_evaluator(evaluator=evaluator)\n",
    "\n",
    "    if is_training:\n",
    "        checkpoint = ModelCheckpoint(period=1,\n",
    "                                     monitor=model.monitor(),\n",
    "                                     mode=model.monitor_mode(),\n",
    "                                     save_top_k=1,\n",
    "                                     save_last=True,\n",
    "                                     filename='checkpoint-{epoch:02d}',\n",
    "                                     dirpath=save_dir_ckpt)\n",
    "        callbacks = [checkpoint]\n",
    "\n",
    "\n",
    "        if cfg['early_stopping']:\n",
    "            early_stopping = EarlyStopping(model.monitor(), mode=model.monitor_mode(), min_delta=0.0, patience=50)\n",
    "            callbacks.append(early_stopping)\n",
    "        trainer = pl.Trainer(logger=logger, callbacks=callbacks, **cfg['trainer'])\n",
    "\n",
    "    if load:\n",
    "        if ckpt_file is None:\n",
    "            print(f'No ckpt files in {save_dir_ckpt}')\n",
    "        else:\n",
    "            print(f'\\nLoading from: {ckpt_file}')\n",
    "            if is_training:\n",
    "                trainer = pl.Trainer(logger=logger, callbacks=callbacks, resume_from_checkpoint=ckpt_file,\n",
    "                                 **cfg['trainer'])\n",
    "            else:\n",
    "\n",
    "                model = model.load_from_checkpoint(ckpt_file, **model_params)\n",
    "                evaluator.set_model(model)\n",
    "                model.set_my_evaluator(evaluator=evaluator)\n",
    "\n",
    "                if cfg['model']['name'] in [Cte.VACA_PIWAE, Cte.VACA, Cte.MCVAE]:\n",
    "                    model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "    path = os.path.join(save_dir,\"logs\")\n",
    "\n",
    "    if not os.path.exists(path):\n",
    "        os.makedirs(path)\n",
    "\n",
    "    if is_training:\n",
    "        trainer.fit(model, data_module)\n",
    "        # save_yaml(model.get_arguments(), file_path=os.path.join(save_dir, 'hparams_model.yaml'))\n",
    "        argtools.save_yaml(cfg, file_path=os.path.join(save_dir, 'hparams_full.yaml'))\n",
    "        # %% Testing"
   ]
  },
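  {
   "cell_type": "markdown",
   "id": "helper-usage-sketch-md",
   "metadata": {},
   "source": [
    "A minimal sketch of how the helpers above could be wired together. The graph is built with `networkx` (whose API `create_data` already relies on), while the structural equations and noise distributions below are illustrative assumptions; the exact formats expected by `MyToySCMDataModule` and the distribution classes exported by `utils.distributions` should be checked before running this end to end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "helper-usage-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: the SCM below (two-node chain x1 -> x2) and the\n",
    "# lambda-based structural equations are assumptions made for demonstration,\n",
    "# not the format guaranteed by MyToySCMDataModule.\n",
    "import networkx as nx\n",
    "\n",
    "graph = nx.DiGraph()\n",
    "graph.add_edge('x1', 'x2')\n",
    "\n",
    "# Hypothetical linear SCM: x1 = u1, x2 = 2 * x1 + u2\n",
    "structural_equations = {\n",
    "    'x1': lambda u1: u1,\n",
    "    'x2': lambda u2, x1: 2 * x1 + u2,\n",
    "}\n",
    "noise_distributions = {}  # e.g. standard-normal noise objects from utils.distributions\n",
    "\n",
    "# data_module, cfg = create_data(n=5000, seed=42,\n",
    "#                                structural_equations=structural_equations,\n",
    "#                                noise_distributions=noise_distributions,\n",
    "#                                graph=graph, name='my_toy_scm',\n",
    "#                                equations_type='linear')\n",
    "# model = create_vaca_model(cfg, data_module)\n",
    "# fit_vaca(model, cfg, data_module)"
   ]
  },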
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ffd2e10-4d64-4fd6-98da-74c855c580fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def vaca_cf(model, cfg, intervention, n_samples = 100):\n",
    "    n = cfg['dataset']['params']['num_samples_tr']\n",
    "    cfg['dataset']['params']['num_samples_tr'] = n_samples\n",
    "    data_module = create_data(cfg,new_data=True)\n",
    "\n",
    "    data_module.batch_size = 1\n",
    "    x_I = intervention  # Intervention before normalizing\n",
    "    data_loader = data_module.test_dataloader()\n",
    "    data_loader.dataset.set_intervention(x_I,is_noise=False)\n",
    "    data_loader = data_module.test_dataloader()\n",
    "\n",
    "    batch = next(iter(data_loader))\n",
    "    cfg['dataset']['params']['num_samples_tr'] = n\n",
    "    vaca_pred, gt_cf, factual = model.get_counterfactual_distr(data_loader,\n",
    "                                            x_I=x_I,\n",
    "                                            is_noise = False,\n",
    "                                            num_batches= 1,\n",
    "                                            normalize=False,\n",
    "                                            )"
   ]
  },
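  {
   "cell_type": "markdown",
   "id": "vaca-cf-usage-md",
   "metadata": {},
   "source": [
    "Illustrative usage of `vaca_cf`, assuming `model` and `cfg` come from the training cells above. The intervention dict maps a node name to the value it is set to before normalization, mirroring the do(x1 = 10) query evaluated in the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "vaca-cf-usage-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only: requires a trained `model` and a populated `cfg` from the cells above.\n",
    "# vaca_pred, gt_cf, factual = vaca_cf(model, cfg, intervention={'x1': 10.0}, n_samples=100)"
   ]
  },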
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c7601066-1abb-4ada-987e-3ffe8cc659e4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'intervened': tensor([[10.],\n",
       "         [10.]]),\n",
       " 'children': tensor([[9.2288],\n",
       "         [8.3016]]),\n",
       " 'all': tensor([[10.0000,  9.2288],\n",
       "         [10.0000,  8.3016]])}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cfg['dataset']['params']['num_samples_tr']=10\n",
    "data_module = create_data(cfg,new_data=True)\n",
    "bs = data_module.batch_size\n",
    "data_module.batch_size = 1\n",
    "x_I = {'x1': 10.0}  # Intervention before normalizing\n",
    "data_loader = data_module.test_dataloader()\n",
    "data_loader.dataset.set_intervention(x_I,is_noise=False)\n",
    "data_loader = data_module.test_dataloader()\n",
    "data_module.batch_size = bs\n",
    "\n",
    "batch = next(iter(data_loader))\n",
    "\n",
    "vaca_pred, gt_cf, factual = model.get_counterfactual_distr(data_loader,\n",
    "                                        x_I=x_I,\n",
    "                                        is_noise = False,\n",
    "                                        num_batches= 1,\n",
    "                                        normalize=False,\n",
    "                                        )\n",
    "\n",
    "gt_cf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cec9deb1-a1f3-4226-9056-fb6970c8ca00",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'intervened': tensor([[10.],\n",
       "         [10.]]),\n",
       " 'children': tensor([[10.5870],\n",
       "         [10.4335]]),\n",
       " 'all': tensor([[10.0000, 10.5870],\n",
       "         [10.0000, 10.4335]])}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gt_cf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "504bd977-e54d-4138-bf8d-d27eb0fc953c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'all': tensor([[ 2.4700,  3.0570],\n",
       "         [-0.3974,  0.0362]])}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "factual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "6cb9c3d7-e6c1-4f2b-809e-80f042810fb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "a52efaa2-926b-49ed-ba0c-4e68e05daed5",
   "metadata": {},
   "outputs": [],
   "source": [
    "temp = batch.x.view(-1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "6c5ce012-070b-4ad4-8ddd-d0b02f4e2530",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5491],\n",
       "        [-1.5926],\n",
       "        [-0.6784],\n",
       "        ...,\n",
       "        [-0.0165],\n",
       "        [ 0.7259],\n",
       "        [ 0.4806]])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch.x"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
