{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from CITNP.datasets.dataset_generator import InterventionDatasetGenerator\n",
    "from CITNP.models.causalinferencemodel import MoGCausalInferenceModel, LocalLatentCausalInferenceModel\n",
    "import torch\n",
    "from functools import partial\n",
    "from CITNP.utils.datautils import transformer_inference_split_withpadding\n",
    "from CITNP.trainer.causalinferencetrainer import send_to_device\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from CITNP.utils.configs import CausalInfModelConfig\n",
    "\n",
    "\n",
    "# Location of the checkpoint to evaluate (relative path, resolved against the kernel's working directory).\n",
    "model_path = Path(\n",
    "    \"CausalInferenceNeuralProcess/CITNP/experiments/results/DFalldata_neuralgplvm_1000_MTcnp_DM256_DFF1024_NH8_NL4_NC10_BS32_LR0.00050_WR0.02/best_model.pt\"\n",
    ")\n",
    "\n",
    "# Use the GPU when one is visible and fall back to the CPU otherwise, so the cell also runs on CPU-only machines.\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Keyword arguments forwarded unchanged to CausalInfModelConfig.\n",
    "# NOTE(review): presumably these must match the settings the checkpoint was trained with -- confirm.\n",
    "MODEL_ARGS = {\n",
    "    \"d_model\": 256,\n",
    "    \"emb_depth\": 2,\n",
    "    \"dim_feedforward\": 1024,\n",
    "    \"nhead\": 8,\n",
    "    \"dropout\": 0.0,\n",
    "    \"num_layers_encoder\": 4,\n",
    "    \"num_nodes\": 11,\n",
    "    \"device\": DEVICE,\n",
    "    \"dtype\": \"float32\",\n",
    "    \"mean_loss_across_samples\": False,\n",
    "    \"num_mixture_components\": 10,\n",
    "    \"sample_attn_mode\": \"MHCA\",\n",
    "    \"linear_attention\": False,\n",
    "}\n",
    "\n",
    "config = CausalInfModelConfig(**MODEL_ARGS)\n",
    "model = MoGCausalInferenceModel(config=config)\n",
    "\n",
    "# map_location remaps the stored tensors onto DEVICE; a checkpoint saved from a GPU cannot be\n",
    "# deserialised on a machine without CUDA unless it is given.\n",
    "checkpoint = torch.load(model_path, map_location=DEVICE)\n",
    "\n",
    "# Last expression of the cell: its return value reports whether all state-dict keys matched.\n",
    "model.load_state_dict(checkpoint[\"model_state_dict\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|████▌     | 5/11 [00:01<00:01,  4.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Skipping [5]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11/11 [00:02<00:00,  4.91it/s]\n",
      " 82%|████████▏ | 9/11 [00:01<00:00,  6.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Skipping [7]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11/11 [00:02<00:00,  5.35it/s]\n",
      " 18%|█▊        | 2/11 [00:00<00:01,  4.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Skipping [2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11/11 [00:02<00:00,  5.11it/s]\n",
      "  0%|          | 0/11 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Skipping [0]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11/11 [00:02<00:00,  5.16it/s]\n",
      " 45%|████▌     | 5/11 [00:00<00:00,  6.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Skipping [3]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11/11 [00:02<00:00,  5.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import trange\n",
    "from CITNP.utils.datautils import normalise_variable\n",
    "\n",
    "\n",
    "import h5py\n",
    "\n",
    "\n",
    "# Test files to evaluate; one row of `all_dataset_losses` is filled per file.\n",
    "all_data_names = [\n",
    "    \"data_idx0.hdf5\",\n",
    "    \"data_idx1.hdf5\",\n",
    "    \"data_idx2.hdf5\",\n",
    "    \"data_idx3.hdf5\",\n",
    "    \"data_idx4.hdf5\",\n",
    "]\n",
    "\n",
    "NUM_NODES = 11  # outcome indices 0 .. NUM_NODES - 1 are tried for every file\n",
    "NUM_SAMPLES = 500  # number of entries kept along axis 1 of obs_data / int_data\n",
    "\n",
    "# File name -> list of losses, one per evaluated outcome index (in visiting order).\n",
    "all_losses = {}\n",
    "\n",
    "# Row = file, column = evaluated outcome in visiting order. The outcome equal to the\n",
    "# intervention index is skipped, which leaves NUM_NODES - 1 columns per file.\n",
    "all_dataset_losses = np.zeros((len(all_data_names), NUM_NODES - 1))\n",
    "\n",
    "# The model is only evaluated in this cell, so eval mode is set once up front.\n",
    "model.eval()\n",
    "\n",
    "for data_idx, data_name in enumerate(all_data_names):\n",
    "\n",
    "    dataset_path = Path(f\"CausalInferenceNeuralProcess/CITNP/datasets/synth_training_data/sachs/test/{data_name}\")\n",
    "\n",
    "    with h5py.File(dataset_path, \"r\") as f:\n",
    "        obs_data = f[\"obs_data\"][:]\n",
    "        int_data = f[\"int_data\"][:]\n",
    "        intvn_indices = f[\"intvn_indices\"][:]\n",
    "\n",
    "    # Everything in this section is the same for every outcome index, so it is computed once per file.\n",
    "    # `.item()` requires the file to hold exactly one intervention index.\n",
    "    intervened_node = intvn_indices.item()\n",
    "    obs, mean_obs, std_obs = normalise_variable(obs_data[:, :NUM_SAMPLES], axis=1, return_stats=True)\n",
    "    obs = torch.from_numpy(obs).float().unsqueeze(-1)\n",
    "    # NOTE(review): the interventional data is passed to the model and to the loss without\n",
    "    # normalisation, while the context above is normalised -- confirm this is intended.\n",
    "    intv = torch.from_numpy(int_data[:, :NUM_SAMPLES]).float().unsqueeze(-1)\n",
    "    int_idx = torch.from_numpy(intvn_indices).long()\n",
    "\n",
    "    all_outcome_losses = []\n",
    "    for outcome_node in trange(NUM_NODES):\n",
    "\n",
    "        outcome_idx = np.full(1, outcome_node, dtype=int)\n",
    "\n",
    "        # The intervened variable is not evaluated as an outcome.\n",
    "        if outcome_node == intervened_node:\n",
    "            print(f\"Skipping {outcome_idx}\")\n",
    "            continue\n",
    "\n",
    "        # Context statistics of the current outcome variable, used to map predictions back.\n",
    "        obs_outcome_mean = mean_obs[0, 0, outcome_node]\n",
    "        obs_outcome_std = std_obs[0, 0, outcome_node]\n",
    "        outcome_idx = torch.from_numpy(outcome_idx).long()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            output = model(\n",
    "                context_data=obs,\n",
    "                target_data=intv,\n",
    "                intervention_index=int_idx,\n",
    "                outcome_index=outcome_idx,\n",
    "            )\n",
    "            # Undo the normalisation of the prediction: x = x_norm * std + mean. The mean has to be\n",
    "            # scaled before it is shifted; previously it was only shifted.\n",
    "            # NOTE(review): assumes normalise_variable standardises as (x - mean) / std -- confirm.\n",
    "            output.pred_mean *= obs_outcome_std\n",
    "            output.pred_mean += obs_outcome_mean\n",
    "            output.pred_std *= obs_outcome_std\n",
    "            loss = model.calculate_loss(model_output=output, target=intv, outcome_index=outcome_idx, test=True)\n",
    "\n",
    "        # The next free column is the number of outcomes already stored for this file.\n",
    "        all_dataset_losses[data_idx, len(all_outcome_losses)] = loss.item()\n",
    "        all_outcome_losses.append(loss.item())\n",
    "\n",
    "    all_losses[data_name] = all_outcome_losses\n",
    "\n",
    "print(all_losses)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 554.34637451,  422.87591553,  652.53070068,  963.92877197,\n",
       "         664.10968018,  906.12054443,  579.11120605,  583.60083008,\n",
       "         711.69372559,  842.14434814],\n",
       "       [1760.53344727, 2755.8503418 , 3143.1809082 , 1937.08605957,\n",
       "         933.23120117,  781.7635498 , 1148.20239258, 2486.54638672,\n",
       "        3132.38452148, 2193.52734375],\n",
       "       [ 515.57415771,  394.35931396, 3076.33935547,  740.40026855,\n",
       "         648.7097168 ,  671.28662109,  604.38354492,  606.4442749 ,\n",
       "         483.83950806,  689.72290039],\n",
       "       [ 971.4821167 ,  488.81463623,  700.79888916,  665.78570557,\n",
       "        1383.17785645,  771.81231689, 1206.00268555, 1024.95654297,\n",
       "         441.65441895,  900.81243896],\n",
       "       [ 556.42047119,  445.38308716,  653.65509033,  934.12609863,\n",
       "         588.54229736,  474.74893188,  498.15969849,  572.05792236,\n",
       "         416.59762573,  665.28955078]])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_dataset_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 554.34637451,  422.87591553,  652.53070068,  963.92877197,\n",
       "         664.10968018,  906.12054443,  579.11120605,  583.60083008,\n",
       "         711.69372559,  842.14434814],\n",
       "       [1760.53344727, 2755.8503418 , 3143.1809082 , 1937.08605957,\n",
       "         933.23120117,  781.7635498 , 1148.20239258, 2486.54638672,\n",
       "        3132.38452148, 2193.52734375],\n",
       "       [ 515.57415771,  394.35931396, 3076.33935547,  740.40026855,\n",
       "         648.7097168 ,  671.28662109,  604.38354492,  606.4442749 ,\n",
       "         483.83950806,  689.72290039],\n",
       "       [ 971.4821167 ,  488.81463623,  700.79888916,  665.78570557,\n",
       "        1383.17785645,  771.81231689, 1206.00268555, 1024.95654297,\n",
       "         441.65441895,  900.81243896],\n",
       "       [ 556.42047119,  445.38308716,  653.65509033,  934.12609863,\n",
       "         588.54229736,  474.74893188,  498.15969849,  572.05792236,\n",
       "         416.59762573,  665.28955078]])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_dataset_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(998.8821258544922)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean loss over every (file, outcome) entry of the table.\n",
    "np.mean(all_dataset_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(104.87282410971045)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Standard error of the mean over every entry: population std (NumPy's default ddof=0) / sqrt(number of entries).\n",
    "all_dataset_losses.std() / np.sqrt(all_dataset_losses.size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Directory with the baseline result files; every entry in it is read as a CSV.\n",
    "# NOTE(review): the variables are named \"deci\" but the path points at the score_gp baseline -- confirm which one is meant.\n",
    "deci_results_path = Path(\n",
    "    \"CausalInferenceNeuralProcess/baselines/score_gp/results/sachs\"\n",
    ")\n",
    "\n",
    "# os.listdir returns entries in arbitrary order; sort by name so that row i refers to the same file on every run.\n",
    "all_files = sorted(os.listdir(deci_results_path))\n",
    "if not all_files:\n",
    "    raise FileNotFoundError(f\"No result files found in {deci_results_path}\")\n",
    "\n",
    "# One row per result file holding the per-column means of that file. The shape follows the files that\n",
    "# are present instead of a fixed (5, 10) buffer, so a missing file can no longer leave a silent row of zeros.\n",
    "deci_all_data_losses = np.vstack(\n",
    "    [pd.read_csv(deci_results_path / f\"{file}\").mean().to_numpy() for file in all_files]\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "meta_causal_inf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
