{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b662aaf7-dedc-4838-9191-734661adfc91",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "module 'tqdm' has no attribute 'auto'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mos\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01margs_parser\u001b[39;00m\u001b[38;5;250m  \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01margtools\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpl\u001b[39;00m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconstants\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Cte\n",
      "File \u001b[0;32m/opt/anaconda3/envs/vaca/lib/python3.10/site-packages/pytorch_lightning/__init__.py:20\u001b[0m\n\u001b[1;32m     17\u001b[0m _PACKAGE_ROOT \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mdirname(\u001b[38;5;18m__file__\u001b[39m)\n\u001b[1;32m     18\u001b[0m _PROJECT_ROOT \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mdirname(_PACKAGE_ROOT)\n\u001b[0;32m---> 20\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m metrics  \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[1;32m     21\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcallbacks\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Callback  \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[1;32m     22\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcore\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m LightningDataModule, LightningModule  \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n",
      "File \u001b[0;32m/opt/anaconda3/envs/vaca/lib/python3.10/site-packages/pytorch_lightning/metrics/__init__.py:15\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# Copyright The PyTorch Lightning team.\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;66;03m# Licensed under the Apache License, Version 2.0 (the \"License\");\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[38;5;66;03m# See the License for the specific language governing permissions and\u001b[39;00m\n\u001b[1;32m     13\u001b[0m \u001b[38;5;66;03m# limitations under the License.\u001b[39;00m\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mclassification\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (  \u001b[38;5;66;03m# noqa: F401\u001b[39;00m\n\u001b[1;32m     16\u001b[0m     Accuracy,\n\u001b[1;32m     17\u001b[0m     AUC,\n\u001b[1;32m     18\u001b[0m     AUROC,\n\u001b[1;32m     19\u001b[0m     AveragePrecision,\n\u001b[1;32m     20\u001b[0m     ConfusionMatrix,\n\u001b[1;32m     21\u001b[0m     F1,\n\u001b[1;32m     22\u001b[0m     FBeta,\n\u001b[1;32m     23\u001b[0m     HammingDistance,\n\u001b[1;32m     24\u001b[0m     IoU,\n\u001b[1;32m     25\u001b[0m     Precision,\n\u001b[1;32m     26\u001b[0m     PrecisionRecallCurve,\n\u001b[1;32m     27\u001b[0m     Recall,\n\u001b[1;32m     28\u001b[0m     ROC,\n\u001b[1;32m     29\u001b[0m     StatScores,\n\u001b[1;32m     30\u001b[0m )\n\u001b[1;32m     31\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetric\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Metric, MetricCollection  \u001b[38;5;66;03m# noqa: F401\u001b[39;00m\n\u001b[1;32m     32\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mregression\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (  \u001b[38;5;66;03m# noqa: F401\u001b[39;00m\n\u001b[1;32m     33\u001b[0m     ExplainedVariance,\n\u001b[1;32m     34\u001b[0m     MeanAbsoluteError,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     39\u001b[0m     SSIM,\n\u001b[1;32m     40\u001b[0m )\n",
      "File \u001b[0;32m/opt/anaconda3/envs/vaca/lib/python3.10/site-packages/pytorch_lightning/metrics/classification/__init__.py:14\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# Copyright The PyTorch Lightning team.\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;66;03m# Licensed under the Apache License, Version 2.0 (the \"License\");\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[38;5;66;03m# See the License for the specific language governing permissions and\u001b[39;00m\n\u001b[1;32m     13\u001b[0m \u001b[38;5;66;03m# limitations under the License.\u001b[39;00m\n\u001b[0;32m---> 14\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mclassification\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01maccuracy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Accuracy  \u001b[38;5;66;03m# noqa: F401\u001b[39;00m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mclassification\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mauc\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AUC  \u001b[38;5;66;03m# noqa: F401\u001b[39;00m\n\u001b[1;32m     16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mclassification\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mauroc\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AUROC  \u001b[38;5;66;03m# noqa: F401\u001b[39;00m\n",
      "File \u001b[0;32m/opt/anaconda3/envs/vaca/lib/python3.10/site-packages/pytorch_lightning/metrics/classification/accuracy.py:16\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# Copyright The PyTorch Lightning team.\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;66;03m# Licensed under the Apache License, Version 2.0 (the \"License\");\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[38;5;66;03m# See the License for the specific language governing permissions and\u001b[39;00m\n\u001b[1;32m     13\u001b[0m \u001b[38;5;66;03m# limitations under the License.\u001b[39;00m\n\u001b[1;32m     14\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtyping\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Any, Callable, Optional\n\u001b[0;32m---> 16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorchmetrics\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Accuracy \u001b[38;5;28;01mas\u001b[39;00m _Accuracy\n\u001b[1;32m     18\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpytorch_lightning\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m deprecated_metrics, void\n\u001b[1;32m     21\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mAccuracy\u001b[39;00m(_Accuracy):\n",
      "File \u001b[0;32m/opt/anaconda3/envs/vaca/lib/python3.10/site-packages/torchmetrics/__init__.py:14\u001b[0m\n\u001b[1;32m     11\u001b[0m _PACKAGE_ROOT \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mdirname(\u001b[38;5;18m__file__\u001b[39m)\n\u001b[1;32m     12\u001b[0m _PROJECT_ROOT \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mdirname(_PACKAGE_ROOT)\n\u001b[0;32m---> 14\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorchmetrics\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m functional  \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorchmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01maudio\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m PIT, SI_SDR, SI_SNR, SNR  \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[1;32m     16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorchmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01maverage\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AverageMeter  \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n",
      "File \u001b[0;32m/opt/anaconda3/envs/vaca/lib/python3.10/site-packages/torchmetrics/functional/__init__.py:60\u001b[0m\n\u001b[1;32m     58\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorchmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctional\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mretrieval\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mreciprocal_rank\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m retrieval_reciprocal_rank\n\u001b[1;32m     59\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorchmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctional\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mself_supervised\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m embedding_similarity\n\u001b[0;32m---> 60\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorchmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctional\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtext\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbert\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m bert_score\n\u001b[1;32m     61\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorchmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctional\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtext\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbleu\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m bleu_score\n\u001b[1;32m     62\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorchmetrics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctional\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtext\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrouge\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m rouge_score\n",
      "File \u001b[0;32m/opt/anaconda3/envs/vaca/lib/python3.10/site-packages/torchmetrics/functional/text/bert.py:247\u001b[0m\n\u001b[1;32m    243\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m idf:\n\u001b[1;32m    244\u001b[0m             \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtokens_idf \u001b[38;5;241m=\u001b[39m tokens_idf \u001b[38;5;28;01mif\u001b[39;00m tokens_idf \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_tokens_idf()\n\u001b[0;32m--> 247\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_get_progress_bar\u001b[39m(dataloader: DataLoader, verbose: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[DataLoader, \u001b[43mtqdm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mauto\u001b[49m\u001b[38;5;241m.\u001b[39mtqdm]:\n\u001b[1;32m    248\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"Helper function returning either the dataloader itself when `verbose = False`, or it wraps the dataloader with\u001b[39;00m\n\u001b[1;32m    249\u001b[0m \u001b[38;5;124;03m    `tqdm.auto.tqdm`, when `verbose = True` to display a progress bar during the embbeddings calculation.\"\"\"\u001b[39;00m\n\u001b[1;32m    250\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m tqdm\u001b[38;5;241m.\u001b[39mauto\u001b[38;5;241m.\u001b[39mtqdm(dataloader) \u001b[38;5;28;01mif\u001b[39;00m verbose \u001b[38;5;28;01melse\u001b[39;00m dataloader\n",
      "\u001b[0;31mAttributeError\u001b[0m: module 'tqdm' has no attribute 'auto'"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import utils.args_parser  as argtools\n",
    "import pytorch_lightning as pl\n",
    "import numpy as np\n",
    "from utils.constants import Cte\n",
    "from data_modules.my_toy_scm import MyToySCMDataModule\n",
    "from utils.distributions import *\n",
    "from data_modules.het_scm import HeterogeneousSCMDataModule\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "from pytorch_lightning.loggers.tensorboard import TensorBoardLogger\n",
    "from models._evaluator import MyEvaluator\n",
    "\n",
    "\n",
    "\n",
    "def create_data(n,\n",
    "                seed,\n",
    "                structural_equations,\n",
    "                noise_distributions,\n",
    "                graph,\n",
    "                name,\n",
    "                equations_type,\n",
    "                cfg = None,\n",
    "                new_data = False):\n",
    "    if cfg == None:\n",
    "        model_file =   os.path.join('_params', 'model_vaca.yaml')\n",
    "        trainer_file =   os.path.join('_params', 'trainer.yaml')\n",
    "\n",
    "\n",
    "        cfg = argtools.parse_args(model_file)\n",
    "        cfg.update(argtools.parse_args(trainer_file))\n",
    "        # Config for new dataset\n",
    "        cfg['dataset'] = {\n",
    "            'params1': {},\n",
    "            'params2': {}\n",
    "        }\n",
    "\n",
    "        cfg['dataset']['params1'] = {\n",
    "            'data_dir': '../Data',\n",
    "            'batch_size': 1000,\n",
    "            'num_workers': 0\n",
    "        }\n",
    "\n",
    "        cfg['dataset']['params2'] = {\n",
    "            'num_samples_tr': None,\n",
    "            'equations_type': 'linear',\n",
    "            'normalize': 'lik',\n",
    "            'likelihood_names': 'd',\n",
    "            'lambda_': 0.05,\n",
    "            'normalize_A': None,\n",
    "        }\n",
    "\n",
    "        cfg['root_dir'] = 'results'\n",
    "        cfg['seed'] = None\n",
    "        pl.seed_everything(cfg['seed'])\n",
    "\n",
    "        cfg['dataset']['params'] = cfg['dataset']['params1'].copy()\n",
    "        cfg['dataset']['params'].update(cfg['dataset']['params2'])\n",
    "\n",
    "        cfg['dataset']['params']['data_dir'] = ''\n",
    "\n",
    "        cfg['trainer']['limit_train_batches'] = 1.0\n",
    "        cfg['trainer']['limit_val_batches'] = 1.0\n",
    "        cfg['trainer']['limit_test_batches'] = 1.0\n",
    "        cfg['trainer']['check_val_every_n_epoch'] = 10\n",
    "        cfg['dataset']['params'] = equations_type\n",
    "        cfg['dataset']['name'] = name\n",
    "        cfg['dataset']['params']['num_samples_tr'] = n\n",
    "        cfg['seed'] = seed\n",
    "\n",
    "    intervene_nodes = []\n",
    "    adj_edges = {}\n",
    "\n",
    "    for node in g.nodes:\n",
    "        if g.out_degree[node] > 0:\n",
    "            intervene_nodes.append(node)\n",
    "        adj_edges[node] = (list(g.neighbors(node)))\n",
    "\n",
    "    dataset_params = cfg['dataset']['params'].copy()\n",
    "    dataset_params['dataset_name'] = cfg['dataset']['name']\n",
    "\n",
    "    dataset_params['nodes_to_intervene'] = intervene_nodes\n",
    "    dataset_params['structural_eq'] = structural_equations\n",
    "    dataset_params['noises_distr'] = noise_distributions\n",
    "\n",
    "    dataset_params['adj_edges'] = adj_edges\n",
    "\n",
    "    data_module = MyToySCMDataModule(**dataset_params)\n",
    "    data_module.prepare_data(new_data = new_data)\n",
    "    return data_module,cfg\n",
    "\n",
    "def get_train_data(data_module):\n",
    "    orig_bs = data_module.batch_size\n",
    "    data_module.batch_size = 1\n",
    "    train = data_module.train_dataloader()\n",
    "    n_points = len(data_module.train_dataloader())\n",
    "    data_module.batch_size = n_points\n",
    "    train = data_module.train_dataloader()\n",
    "    batch = next(iter(train))\n",
    "    data_module.batch_size = orig_bs\n",
    "    return batch.x.view(n,-1)\n",
    "\n",
    "\n",
    "def create_vaca_model(cfg):\n",
    "\n",
    "    model = None\n",
    "    model_params = cfg['model']['params'].copy()\n",
    "\n",
    "    #print(f\"\\nUsing model: {cfg['model']['name']}\")\n",
    "\n",
    "\n",
    "    # VACA\n",
    "    if cfg['model']['name'] == Cte.VACA:\n",
    "        from models.vaca.vaca import VACA\n",
    "\n",
    "        model_params['is_heterogeneous'] = data_module.is_heterogeneous\n",
    "        model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "        model_params['deg'] = data_module.get_deg(indegree=True)\n",
    "        model_params['num_nodes'] = data_module.num_nodes\n",
    "        model_params['edge_dim'] = data_module.edge_dimension\n",
    "        model_params['scaler'] = data_module.scaler\n",
    "\n",
    "        model = VACA(**model_params)\n",
    "        model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "    # VACA with PIWAE\n",
    "    elif cfg['model']['name'] == Cte.VACA_PIWAE:\n",
    "        from models.vaca.vaca_piwae import VACA_PIWAE\n",
    "\n",
    "        model_params['is_heterogeneous'] = data_module.is_heterogeneous\n",
    "\n",
    "        model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "        model_params['deg'] = data_module.get_deg(indegree=True)\n",
    "        model_params['num_nodes'] = data_module.num_nodes\n",
    "        model_params['edge_dim'] = data_module.edge_dimension\n",
    "        model_params['scaler'] = data_module.scaler\n",
    "\n",
    "        model = VACA_PIWAE(**model_params)\n",
    "        model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "\n",
    "\n",
    "    # MultiCVAE\n",
    "    elif cfg['model']['name'] == Cte.MCVAE:\n",
    "        from models.multicvae.multicvae import MCVAE\n",
    "\n",
    "        model_params['likelihood_x'] = data_module.likelihood_list\n",
    "\n",
    "        model_params['topological_node_dims'] = data_module.train_dataset.get_node_columns_in_X()\n",
    "        model_params['topological_parents'] = data_module.topological_parents\n",
    "        model_params['scaler'] = data_module.scaler\n",
    "        model_params['num_epochs_per_nodes'] = int(\n",
    "            np.floor((cfg['trainer']['max_epochs'] / len(data_module.topological_nodes))))\n",
    "        model = MCVAE(**model_params)\n",
    "        model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "        cfg['early_stopping'] = False\n",
    "\n",
    "    # CAREFL\n",
    "    elif cfg['model']['name'] == Cte.CAREFL:\n",
    "        from models.carefl.carefl import CAREFL\n",
    "\n",
    "        model_params['node_per_dimension_list'] = data_module.train_dataset.node_per_dimension_list\n",
    "        model_params['scaler'] = data_module.scaler\n",
    "        model = CAREFL(**model_params)\n",
    "\n",
    "    model.set_optim_params(optim_params=cfg['optimizer'],\n",
    "                           sched_params=cfg['scheduler'])\n",
    "    return model\n",
    "\n",
    "def fit_vaca(model,cfg, data_module):\n",
    "\n",
    "\n",
    "    is_training = True\n",
    "    load = True\n",
    "\n",
    "    print(f'Is training activated? {is_training}')\n",
    "    print(f'Is loading activated? {load}')\n",
    "    if yaml_file == '':\n",
    "        if (cfg['dataset']['name'] in [Cte.GERMAN]) and (cfg['dataset']['params3']['train_kfold'] == True):\n",
    "            save_dir = argtools.mkdir(os.path.join(cfg['root_dir'],\n",
    "                                                   argtools.get_experiment_folder(cfg),\n",
    "                                                   str(cfg['seed']), str(cfg['dataset']['params3']['kfold_idx'])))\n",
    "        else:\n",
    "            save_dir = argtools.mkdir(os.path.join(cfg['root_dir'],\n",
    "                                                   argtools.get_experiment_folder(cfg),\n",
    "                                                   str(cfg['seed'])))\n",
    "    else:\n",
    "        save_dir = os.path.join(*yaml_file.split('/')[:-1])\n",
    "\n",
    "    logger = TensorBoardLogger(save_dir=save_dir, name='logs', default_hp_metric=False)\n",
    "\n",
    "    out = logger.log_hyperparams(argtools.flatten_cfg(cfg))\n",
    "\n",
    "    save_dir_ckpt = argtools.mkdir(os.path.join(save_dir, 'ckpt'))\n",
    "    if load:\n",
    "        ckpt_file = argtools.newest(save_dir_ckpt)\n",
    "    else:\n",
    "        ckpt_file = None\n",
    "    callbacks = []\n",
    "\n",
    "\n",
    "\n",
    "    evaluator = MyEvaluator(model=model,\n",
    "                            intervention_list=data_module.train_dataset.get_intervention_list(),\n",
    "                            scaler=data_module.scaler\n",
    "                            )\n",
    "    model.set_my_evaluator(evaluator=evaluator)\n",
    "\n",
    "    if is_training:\n",
    "        checkpoint = ModelCheckpoint(period=1,\n",
    "                                     monitor=model.monitor(),\n",
    "                                     mode=model.monitor_mode(),\n",
    "                                     save_top_k=1,\n",
    "                                     save_last=True,\n",
    "                                     filename='checkpoint-{epoch:02d}',\n",
    "                                     dirpath=save_dir_ckpt)\n",
    "        callbacks = [checkpoint]\n",
    "\n",
    "\n",
    "        if cfg['early_stopping']:\n",
    "            early_stopping = EarlyStopping(model.monitor(), mode=model.monitor_mode(), min_delta=0.0, patience=50)\n",
    "            callbacks.append(early_stopping)\n",
    "        trainer = pl.Trainer(logger=logger, callbacks=callbacks, **cfg['trainer'])\n",
    "\n",
    "    if load:\n",
    "        if ckpt_file is None:\n",
    "            print(f'No ckpt files in {save_dir_ckpt}')\n",
    "        else:\n",
    "            print(f'\\nLoading from: {ckpt_file}')\n",
    "            if is_training:\n",
    "                trainer = pl.Trainer(logger=logger, callbacks=callbacks, resume_from_checkpoint=ckpt_file,\n",
    "                                 **cfg['trainer'])\n",
    "            else:\n",
    "\n",
    "                model = model.load_from_checkpoint(ckpt_file, **model_params)\n",
    "                evaluator.set_model(model)\n",
    "                model.set_my_evaluator(evaluator=evaluator)\n",
    "\n",
    "                if cfg['model']['name'] in [Cte.VACA_PIWAE, Cte.VACA, Cte.MCVAE]:\n",
    "                    model.set_random_train_sampler(data_module.get_random_train_sampler())\n",
    "    path = os.path.join(save_dir,\"logs\")\n",
    "\n",
    "    if not os.path.exists(path):\n",
    "        os.makedirs(path)\n",
    "\n",
    "    if is_training:\n",
    "        trainer.fit(model, data_module)\n",
    "        # save_yaml(model.get_arguments(), file_path=os.path.join(save_dir, 'hparams_model.yaml'))\n",
    "        argtools.save_yaml(cfg, file_path=os.path.join(save_dir, 'hparams_full.yaml'))\n",
    "        # %% Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ffd2e10-4d64-4fd6-98da-74c855c580fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def vaca_cf(model, cfg, intervention, n_samples = 100):\n",
    "    n = cfg['dataset']['params']['num_samples_tr']\n",
    "    cfg['dataset']['params']['num_samples_tr'] = n_samples\n",
    "    data_module = create_data(cfg,new_data=True)\n",
    "\n",
    "    data_module.batch_size = 1\n",
    "    x_I = intervention  # Intervention before normalizing\n",
    "    data_loader = data_module.test_dataloader()\n",
    "    data_loader.dataset.set_intervention(x_I,is_noise=False)\n",
    "    data_loader = data_module.test_dataloader()\n",
    "\n",
    "    batch = next(iter(data_loader))\n",
    "    cfg['dataset']['params']['num_samples_tr'] = n\n",
    "    vaca_pred, gt_cf, factual = model.get_counterfactual_distr(data_loader,\n",
    "                                            x_I=x_I,\n",
    "                                            is_noise = False,\n",
    "                                            num_batches= 1,\n",
    "                                            normalize=False,\n",
    "                                            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c7601066-1abb-4ada-987e-3ffe8cc659e4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'intervened': tensor([[10.],\n",
       "         [10.]]),\n",
       " 'children': tensor([[9.2288],\n",
       "         [8.3016]]),\n",
       " 'all': tensor([[10.0000,  9.2288],\n",
       "         [10.0000,  8.3016]])}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cfg['dataset']['params']['num_samples_tr']=10\n",
    "data_module = create_data(cfg,new_data=True)\n",
    "bs = data_module.batch_size\n",
    "data_module.batch_size = 1\n",
    "x_I = {'x1': 10.0}  # Intervention before normalizing\n",
    "data_loader = data_module.test_dataloader()\n",
    "data_loader.dataset.set_intervention(x_I,is_noise=False)\n",
    "data_loader = data_module.test_dataloader()\n",
    "data_module.batch_size = bs\n",
    "\n",
    "batch = next(iter(data_loader))\n",
    "\n",
    "vaca_pred, gt_cf, factual = model.get_counterfactual_distr(data_loader,\n",
    "                                        x_I=x_I,\n",
    "                                        is_noise = False,\n",
    "                                        num_batches= 1,\n",
    "                                        normalize=False,\n",
    "                                        )\n",
    "\n",
    "gt_cf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cec9deb1-a1f3-4226-9056-fb6970c8ca00",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'intervened': tensor([[10.],\n",
       "         [10.]]),\n",
       " 'children': tensor([[10.5870],\n",
       "         [10.4335]]),\n",
       " 'all': tensor([[10.0000, 10.5870],\n",
       "         [10.0000, 10.4335]])}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gt_cf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "504bd977-e54d-4138-bf8d-d27eb0fc953c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'all': tensor([[ 2.4700,  3.0570],\n",
       "         [-0.3974,  0.0362]])}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "factual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "ec34046b-d379-4eda-97a5-f1680ecaf9e0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "6cb9c3d7-e6c1-4f2b-809e-80f042810fb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "a52efaa2-926b-49ed-ba0c-4e68e05daed5",
   "metadata": {},
   "outputs": [],
   "source": [
    "temp = batch.x.view(-1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "6c5ce012-070b-4ad4-8ddd-d0b02f4e2530",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5491],\n",
       "        [-1.5926],\n",
       "        [-0.6784],\n",
       "        ...,\n",
       "        [-0.0165],\n",
       "        [ 0.7259],\n",
       "        [ 0.4806]])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch.x"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vaca",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
