{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "203faf65",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def get_results_multiple_file(base_file_name, result_path, num_files):\n",
    "    all_nlpd_values = np.zeros(num_files)\n",
    "    for i in range(num_files):\n",
    "        filename = f\"{base_file_name}{i}.csv\"\n",
    "        file_path = result_path / filename\n",
    "\n",
    "        if file_path.is_file():\n",
    "            df = pd.read_csv(file_path)\n",
    "            all_nlpd_values[i] = df[\"NLL\"].values[-1]\n",
    "        else:\n",
    "            print(f\"Warning: File {result_path / filename} not found. Skipping.\")\n",
    "            raise FileNotFoundError\n",
    "    return all_nlpd_values\n",
    "\n",
    "\n",
    "def get_all_possible_results_multiple_file(base_file_name, result_path, num_files):\n",
    "    all_nlpd_values = []\n",
    "    for i in range(num_files):\n",
    "        filename = f\"{base_file_name}{i}.csv\"\n",
    "        file_path = result_path / filename\n",
    "\n",
    "        if file_path.is_file():\n",
    "            df = pd.read_csv(file_path)\n",
    "            all_nlpd_values.append(df[\"NLL\"].values[-1])\n",
    "        else:\n",
    "            print(f\"Warning: File {result_path / filename} not found. Skipping.\")\n",
    "            continue\n",
    "    return all_nlpd_values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3554319",
   "metadata": {},
   "source": [
    "# ARCO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cac1848",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_path = Path(\n",
    "   \"CausalInferenceNeuralProcess/baselines/arco-dibs-gp/arco_results/20var_ER4_neuralgplvm_1000\"\n",
    ")\n",
    "base_filename_prefix = \"20var_ER4_neuralgplvm_1000_arco_nlpd_data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "659ffeb0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "56416855",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_nlpd_values = get_results_multiple_file(\n",
    "    base_file_name=base_filename_prefix,\n",
    "    result_path=result_path,\n",
    "    num_files=100\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "926c1e5c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([900.4196167 , 708.49072266, 641.77844238, 701.66888428,\n",
       "       690.51257324, 702.22192383, 718.3203125 , 657.3885498 ,\n",
       "       715.38934326, 646.6463623 , 672.58563232, 669.3503418 ,\n",
       "       707.809021  , 712.92895508, 628.16101074, 687.25421143,\n",
       "       665.19744873, 820.10180664, 702.2479248 , 725.4440918 ,\n",
       "       676.45153809, 666.15240479, 679.71905518, 668.16687012,\n",
       "       708.86254883, 705.7769165 , 706.54119873, 696.47753906,\n",
       "       663.91680908, 714.14160156, 749.46392822, 707.21014404,\n",
       "       696.84710693, 680.33099365, 725.87640381, 707.45916748,\n",
       "       711.12963867, 690.36743164, 696.22143555, 638.76751709,\n",
       "       745.19250488, 710.04833984, 693.51281738, 654.46032715,\n",
       "       716.7713623 , 721.94219971, 720.19665527, 711.25866699,\n",
       "       721.09179688, 710.42932129, 723.61773682, 683.16137695,\n",
       "       663.57269287, 685.62091064, 737.72296143, 708.70330811,\n",
       "       721.77026367, 715.60516357, 613.77905273, 717.16247559,\n",
       "       638.14294434, 685.51965332, 625.63201904, 669.5536499 ,\n",
       "       692.90545654, 659.07745361, 738.21478271, 710.87255859,\n",
       "       726.4354248 , 745.49822998, 710.79418945, 707.65252686,\n",
       "       664.09594727, 696.58612061, 704.07098389, 730.82299805,\n",
       "       723.93444824, 694.99658203, 675.66625977, 656.63238525,\n",
       "       714.50915527, 697.65771484, 666.81433105, 756.07354736,\n",
       "       648.31115723, 754.47186279, 715.22607422, 639.19750977,\n",
       "       704.43731689, 653.94873047, 715.45373535, 722.97399902,\n",
       "       757.71105957, 774.4621582 , 714.18725586, 705.49530029,\n",
       "       746.40698242, 698.17803955, 730.07226562, 767.07788086])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_nlpd_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7b3aa6db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean: 701.7918804931641\n",
      "Std: 3.983464928557969\n"
     ]
    }
   ],
   "source": [
    "mean = all_nlpd_values.mean(axis=0)\n",
    "std = all_nlpd_values.std(axis=0)\n",
    "print(f\"Mean: {mean}\")\n",
    "print(f\"Std: {std / (len(all_nlpd_values) ** 0.5)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4aa1878b",
   "metadata": {},
   "source": [
    "# DIBS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a99e886",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "result_path = Path(\n",
    "   \"CausalInferenceNeuralProcess/baselines/arco-dibs-gp/dibs_results/20var_ER4_neuralgplvm_1000\"\n",
    ")\n",
    "base_filename_prefix = \"20var_ER4_neuralgplvm_1000_dibs_nlpd_data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fe894c75",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_nlpd_values = get_results_multiple_file(\n",
    "    base_file_name=base_filename_prefix,\n",
    "    result_path=result_path,\n",
    "    num_files=100\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a3caba01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean: 701.8740936279297\n",
      "Std: 3.99437536236343\n"
     ]
    }
   ],
   "source": [
    "mean = all_nlpd_values.mean(axis=0)\n",
    "std = all_nlpd_values.std(axis=0)\n",
    "print(f\"Mean: {mean}\")\n",
    "print(f\"Std: {std / (len(all_nlpd_values) ** 0.5)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2205eef2",
   "metadata": {},
   "source": [
    "# DECI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b106b88",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_path = Path(\n",
    "   \"CausalInferenceNeuralProcess/baselines/deci/results/20var_ER4_neuralgplvm_1000\"\n",
    ")\n",
    "base_filename_prefix = \"20var_ER4_neuralgplvm_1000_deci_nlpd_data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "19b18b1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_nlpd_values = get_all_possible_results_multiple_file(\n",
    "    base_file_name=base_filename_prefix,\n",
    "    result_path=result_path,\n",
    "    num_files=100\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ee872992",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[np.float64(815.888284642452), np.float64(675.390988599457), np.float64(804.3301091746721), np.float64(715.8373967239331), np.float64(709.4041696838099), np.float64(729.1030544715059), np.float64(739.5035627703975), np.float64(608.6341641685368), np.float64(742.486322519575), np.float64(624.0593801930588), np.float64(683.1152136448533), np.float64(675.2560106789913), np.float64(677.800251236721), np.float64(748.0394605329881), np.float64(628.822768739862), np.float64(656.3074235911799), np.float64(605.7606044244154), np.float64(683.7067998302689), np.float64(722.8354688151767), np.float64(721.2066180767297), np.float64(702.6709007537985), np.float64(695.5151479238808), np.float64(724.6249751078548), np.float64(655.4753830209537), np.float64(704.9540321464217), np.float64(757.209334779983), np.float64(644.7029977971802), np.float64(720.7598055414468), np.float64(596.8726171297924), np.float64(745.8987815189209), np.float64(810.2505235951162), np.float64(631.9052348462262), np.float64(709.9489253609505), np.float64(625.9930564129793), np.float64(734.4331055847604), np.float64(751.2265954606131), np.float64(759.9268887236319), np.float64(709.2385230670499), np.float64(707.2832580009957), np.float64(614.3951696112714), np.float64(665.7035229110886), np.float64(739.4605702410502), np.float64(563.1898758709281), np.float64(595.5381067581363), np.float64(740.5276366318856), np.float64(666.9552830328992), np.float64(911.4900530206406), np.float64(730.7622857184626), np.float64(742.4514825352283), np.float64(631.5453753567903), np.float64(623.2174049430803), np.float64(724.9268781980578), np.float64(595.0354494840313), np.float64(569.0896284274422), np.float64(650.5420746759821), np.float64(738.6440705683818), np.float64(725.7142484512187), np.float64(664.2687083324914), np.float64(558.5600963312456), np.float64(739.3402686186322), np.float64(551.6865178229622), np.float64(669.7969061182444), np.float64(578.4414961397218), np.float64(708.8082112961317), np.float64(710.069517439982), np.float64(602.5216761790869), np.float64(754.705439975155), np.float64(685.7795071417328), np.float64(622.5789117227374), np.float64(659.9634914566724), np.float64(705.4585961461448), np.float64(749.2921847264957), np.float64(688.2775559756913), np.float64(733.3406189656105), np.float64(655.5364928872086), np.float64(736.4852706017646), np.float64(768.9310680267834), np.float64(701.113837887622), np.float64(700.5060132074204), np.float64(653.7796755387052), np.float64(686.5406528131709), np.float64(620.6569138287433), np.float64(559.5583830626387), np.float64(739.1381845034806), np.float64(574.1107771020676), np.float64(802.4984025151952), np.float64(710.8712348852997), np.float64(595.7304916996386), np.float64(737.7921911248122), np.float64(587.0532822565059), np.float64(763.2719942732161), np.float64(739.3146513328136), np.float64(608.5794540070573), np.float64(709.9615007583154), np.float64(749.8956332115836), np.float64(743.5079829412181), np.float64(644.1549537501263), np.float64(564.3075354850198), np.float64(664.9191473692373), np.float64(643.8102392666497)]\n",
      "100\n"
     ]
    }
   ],
   "source": [
    "print(all_nlpd_values)\n",
    "print(len(all_nlpd_values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ebf7bbf1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean: 686.3048292445274\n",
      "Std: 6.708186822719017\n"
     ]
    }
   ],
   "source": [
    "mean = np.mean(all_nlpd_values)\n",
    "std = np.std(all_nlpd_values)\n",
    "print(f\"Mean: {mean}\")\n",
    "print(f\"Std: {std / (len(all_nlpd_values) ** 0.5)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4fd577a1",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'list' object has no attribute 'mean'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[18], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m mean \u001b[38;5;241m=\u001b[39m \u001b[43mall_nlpd_values\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m(axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m      2\u001b[0m std \u001b[38;5;241m=\u001b[39m all_nlpd_values\u001b[38;5;241m.\u001b[39mstd(axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMean: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmean\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'mean'"
     ]
    }
   ],
   "source": [
    "mean = all_nlpd_values.mean(axis=0)\n",
    "std = all_nlpd_values.std(axis=0)\n",
    "print(f\"Mean: {mean}\")\n",
    "print(f\"Std: {std / (len(all_nlpd_values) ** 0.5)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95a232b4",
   "metadata": {},
   "source": [
    "# NOGAM+GP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fe20a2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_path = Path(\n",
    "    \"CausalInferenceNeuralProcess/baselines/score_gp/results/20var_ER4_neuralgplvm_1000\"\n",
    ")\n",
    "base_filename_prefix = \"20var_ER4_neuralgplvm_1000_nogamgp_nlpd_data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8e559628",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_nlpd_values = get_results_multiple_file(\n",
    "    base_file_name=base_filename_prefix,\n",
    "    result_path=result_path,\n",
    "    num_files=100\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bc5142a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean: 942.7254290771484\n",
      "Std: 23.787232614580596\n"
     ]
    }
   ],
   "source": [
    "mean = all_nlpd_values.mean(axis=0)\n",
    "std = all_nlpd_values.std(axis=0)\n",
    "print(f\"Mean: {mean}\")\n",
    "print(f\"Std: {std / (len(all_nlpd_values) ** 0.5)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6842cd1d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "meta_causal_inf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
