{
 "cells": [
  {
   "cell_type": "code",
   "id": "b0949a58-1135-4b21-9068-ee725ae55a96",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:06:01.861235Z",
     "start_time": "2024-10-20T15:06:00.467283Z"
    }
   },
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "import torch\n",
    "from model import CaT\n",
    "from CaT.datasets import reorder_dag, get_full_ordering\n",
    "from utils.inference import CausalInference\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "seed = 1\n",
    "standardize = 0\n",
    "sample_size = 50000\n",
    "batch_size = 100\n",
    "max_iters = 6000\n",
    "eval_interval = 200\n",
    "eval_iters = 100\n",
    "validation_fraction = 0.3\n",
    "np.random.seed(seed=seed)\n",
    "torch.manual_seed(seed)\n",
    "device = 'cuda'\n",
    "dropout_rate = 0.0\n",
    "learning_rate = 5e-3\n",
    "ff_n_embed = 6\n",
    "num_heads = 2\n",
    "n_layers = 2\n",
    "embed_dim = 5\n",
    "head_size = 6\n",
    "d = 5\n",
    "\n",
    "\n",
    "def generate_data_mediation(N, d=5):\n",
    "    DAGnx = nx.DiGraph()\n",
    "    \n",
    "    Ux = np.random.randn(N,d)\n",
    "    X =  Ux\n",
    "\n",
    "    Um = np.random.randn(N,d)\n",
    "    M =  0.2 * X + Um\n",
    "\n",
    "    Uy = np.random.randn(N,d)\n",
    "    Y =  0.7 * M + 0.1 * Uy\n",
    "\n",
    "    M0 = 0.2 * 0 + Um \n",
    "    M1 = 0.2 * 1 + Um\n",
    "\n",
    "    Y0 = 0.7 * M0 + 0.1 * Uy \n",
    "    Y1 = 0.7 * M1 + 0.1 * Uy   # total effect   = 0.7 * 0.2 = 0.14\n",
    "\n",
    "    all_data_dict = {'X': X, 'M': M, 'Y': Y}\n",
    "\n",
    "    # types can be 'cat' (categorical) 'cont' (continuous) or 'bin' (binary)\n",
    "    var_types = {'X': 'cont', 'M': 'cont', 'Y': 'cont'}\n",
    "\n",
    "    DAGnx.add_edges_from([('X', 'M'), ('M', 'Y')])\n",
    "    DAGnx = reorder_dag(dag=DAGnx)  # topologically sorted dag\n",
    "    var_names = list(DAGnx.nodes())  # topologically ordered list of variables\n",
    "    all_data = np.stack([all_data_dict[key] for key in var_names], axis=1)\n",
    "    causal_ordering = get_full_ordering(DAGnx)\n",
    "\n",
    "    return all_data, DAGnx, var_names, causal_ordering, var_types, Y0, Y1"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "id": "5e1c09c9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:06:02.229801Z",
     "start_time": "2024-10-20T15:06:01.955246Z"
    }
   },
   "source": [
    "d=3\n",
    "_, _, _, _, _, Y0, Y1 = generate_data_mediation(N=1000000, d=d)\n",
    "ATE = (Y1 - Y0).mean(0)  # multi-dim ATE based off a large sample\n",
    "all_data, DAGnx, var_names, causal_ordering, var_types, Y0, Y1 = generate_data_mediation(N=sample_size, d=d)\n",
    "print(var_names, ATE)\n",
    "print(all_data.shape)\n",
    "\n",
    "indices = np.arange(0, len(all_data))\n",
    "np.random.shuffle(indices)\n",
    "\n",
    "val_inds = indices[:int(validation_fraction*len(indices))]\n",
    "train_inds = indices[int(validation_fraction*len(indices)):]\n",
    "train_data = all_data[train_inds]\n",
    "val_data = all_data[val_inds]\n",
    "train_data, val_data = torch.from_numpy(train_data).float(),  torch.from_numpy(val_data).float()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['X', 'M', 'Y'] [0.14 0.14 0.14]\n",
      "(50000, 3, 3)\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "id": "63c1b6ba",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:06:04.260639Z",
     "start_time": "2024-10-20T15:06:03.546256Z"
    }
   },
   "source": [
    "input_dim = all_data.shape[2]\n",
    "\n",
    "model = CaT(input_dim=input_dim,\n",
    "                    dropout_rate=dropout_rate,\n",
    "                    head_size=head_size,\n",
    "                    num_heads=num_heads,\n",
    "                    ff_n_embed=ff_n_embed,\n",
    "                    embed_dim= embed_dim,\n",
    "                    dag=DAGnx,\n",
    "                    causal_ordering=causal_ordering,\n",
    "                    n_layers=n_layers,\n",
    "                    device=device,\n",
    "                    var_types=var_types, activation_function='Swish'\n",
    "                    ).to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "id": "d32adc12",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:07:15.919883Z",
     "start_time": "2024-10-20T15:06:09.186520Z"
    }
   },
   "source": [
    "def get_batch(train_data, val_data, split, device, batch_size):\n",
    "    data = train_data if split == 'train' else val_data\n",
    "    ix = torch.randint(0, len(data), (batch_size,))\n",
    "    x = data[ix]\n",
    "    return x.to(device)\n",
    "\n",
    "all_var_losses = {}\n",
    "for iter_ in range(0, max_iters):\n",
    "    # train and update the model\n",
    "    model.train()\n",
    "\n",
    "    xb = get_batch(train_data=train_data, val_data=val_data, split='train', device=device, batch_size=batch_size)\n",
    "    xb_mod = torch.clone(xb.detach())\n",
    "    X, loss, loss_dict = model(X=xb, targets=xb_mod)\n",
    "\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "\n",
    "    if iter_ % eval_interval == 0:  # evaluate the loss (no gradients)\n",
    "        for key in loss_dict.keys():\n",
    "            if key not in all_var_losses.keys():\n",
    "                all_var_losses[key] = []\n",
    "            all_var_losses[key].append(loss_dict[key])\n",
    "\n",
    "        model.eval()\n",
    "        eval_loss = {}\n",
    "        for split in ['train', 'val']:\n",
    "            losses = torch.zeros(eval_iters)\n",
    "            for k in range(eval_iters):\n",
    "\n",
    "                xb = get_batch(train_data=train_data, val_data=val_data, split=split, device=device,\n",
    "                               batch_size=batch_size)\n",
    "                xb_mod = torch.clone(xb.detach())\n",
    "                X, loss, loss_dict = model(X=xb, targets=xb_mod)\n",
    "                losses[k] = loss.item()\n",
    "            eval_loss[split] = losses.mean()\n",
    "        model.train()\n",
    "        print(f\"step {iter_} of {max_iters}: train_loss {eval_loss['train']:.4f}, val loss {eval_loss['val']:.4f}\")\n",
    "    "
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0 of 6000: train_loss 4.1337, val loss 4.1061\n",
      "step 200 of 6000: train_loss 1.0236, val loss 0.9932\n",
      "step 400 of 6000: train_loss 1.0141, val loss 0.9928\n",
      "step 600 of 6000: train_loss 1.0146, val loss 0.9888\n",
      "step 800 of 6000: train_loss 1.0202, val loss 0.9930\n",
      "step 1000 of 6000: train_loss 1.0089, val loss 0.9970\n",
      "step 1200 of 6000: train_loss 1.0301, val loss 0.9861\n",
      "step 1400 of 6000: train_loss 1.0110, val loss 0.9848\n",
      "step 1600 of 6000: train_loss 1.0164, val loss 0.9980\n",
      "step 1800 of 6000: train_loss 1.0131, val loss 0.9892\n",
      "step 2000 of 6000: train_loss 1.0106, val loss 0.9999\n",
      "step 2200 of 6000: train_loss 1.0086, val loss 1.0064\n",
      "step 2400 of 6000: train_loss 1.0161, val loss 0.9968\n",
      "step 2600 of 6000: train_loss 1.0180, val loss 0.9868\n",
      "step 2800 of 6000: train_loss 1.0250, val loss 1.0077\n",
      "step 3000 of 6000: train_loss 1.0170, val loss 0.9947\n",
      "step 3200 of 6000: train_loss 1.0192, val loss 1.0000\n",
      "step 3400 of 6000: train_loss 1.0028, val loss 0.9935\n",
      "step 3600 of 6000: train_loss 1.0101, val loss 1.0053\n",
      "step 3800 of 6000: train_loss 1.0121, val loss 0.9924\n",
      "step 4000 of 6000: train_loss 1.0138, val loss 0.9978\n",
      "step 4200 of 6000: train_loss 1.0141, val loss 0.9894\n",
      "step 4400 of 6000: train_loss 1.0069, val loss 0.9957\n",
      "step 4600 of 6000: train_loss 1.0136, val loss 0.9962\n",
      "step 4800 of 6000: train_loss 1.0065, val loss 0.9896\n",
      "step 5000 of 6000: train_loss 1.0087, val loss 0.9968\n",
      "step 5200 of 6000: train_loss 1.0231, val loss 0.9890\n",
      "step 5400 of 6000: train_loss 1.0100, val loss 0.9924\n",
      "step 5600 of 6000: train_loss 1.0179, val loss 0.9991\n",
      "step 5800 of 6000: train_loss 1.0115, val loss 1.0019\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "id": "62f2ec5e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:07:16.166604Z",
     "start_time": "2024-10-20T15:07:15.920971Z"
    }
   },
   "source": [
    "model.eval()\n",
    "inf = CausalInference(dag=DAGnx)\n",
    "\n",
    "int_nodes_vals0 = {'X':np.array([0.0,])}\n",
    "int_nodes_vals1 = {'X':np.array([1.0,])}\n",
    "effect_var = 'Y'\n",
    "effect_index = var_names.index(effect_var)\n",
    "\n",
    "preds0 = inf.forward(all_data, model=model, intervention_nodes_vals=int_nodes_vals0)\n",
    "preds1 = inf.forward(all_data, model=model, intervention_nodes_vals=int_nodes_vals1)\n",
    "ATE_pred = (preds1[:,effect_index,:] - preds0[:,effect_index,:]).mean(0)\n",
    "eATE = np.abs(ATE_pred - ATE)\n",
    "print('ATE:', ATE, 'est ATE:', ATE_pred, 'error:', eATE)\n",
    "\n",
    "preds = model(train_data.to(device)).detach().cpu().numpy()\n",
    "\n",
    "plt.scatter(train_data[:,effect_index,-1].detach().cpu().numpy(), preds[:, effect_index, -1])"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ATE: [0.14 0.14 0.14] est ATE: [0.13401228 0.15703369 0.13112091] error: [0.00598772 0.01703369 0.00887909]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matthewvowels/GitHub/Causal_Transformer/utils/inference.py:18: UserWarning: No mask has been specified. If padding has been used, the absence of a mask may lead to incorrect results.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7fe88924d9d0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAGdCAYAAAAvwBgXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAymklEQVR4nO3dfXSU9Z3//9ckJBMIyUC4SygBQsDaiGChRgPVcme9W9S2enZrVbCuqwju19LjEbq1Mac/N9p6Tv2uUkq/urD7Q2p7VERajCIqfmvBVCKLMWoLBkEyMUBkJgYzwZnr+wdOTMh9Mtd8rpnr+ThnzmEmV3K9GcF58bl5fzyWZVkCAAAwIMV0AQAAwL0IIgAAwBiCCAAAMIYgAgAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMGWK6gJ5EIhHV1dUpKytLHo/HdDkAAKAPLMtSU1OTxo8fr5SUnsc8HB1E6urqlJ+fb7oMAAAwAIcPH9aECRN6vMbRQSQrK0vS6d9Idna24WoAAEBfBINB5efnt32O98TRQSQ6HZOdnU0QAQAgwfRlWQWLVQEAgDEEEQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGOLqhGQAAsEc4YqmytlENTS0am5Wh4oIcpabE/1w3gggAAC5TUe1X2dYa+QMtba/l+TJUurhIl03Pi2stTM0AAOAiFdV+LdtY1SGESFJ9oEXLNlapotof13oIIgAAuEQ4Yqlsa42sLr4Wfa1sa43Cka6usAdBBAAAl6isbew0EtKeJckfaFFlbWPcaiKIAADgEg1N3YeQgVwXC7YGkbVr12rGjBnKzs5Wdna2SkpK9Pzzz9t5SwAA0I2xWRkxvS4WbA0iEyZM0AMPPKA9e/bozTff1IIFC3T11VfrnXfesfO2AACgC8UFOcrzZai7Tboend49U1yQE7eabA0iixcv1hVXXKFp06bprLPO0v3336/hw4dr9+7ddt4WAAB0ITXFo9LFRZLUKYxEn5cuLoprP5G4rREJh8N68skn1dzcrJKSki6vCYVCCgaDHR4AACB2Lpuep7U3zFKur+P0S64vQ2tvmBX3PiK2NzR7++23VVJSopaWFg0fPlybN29WUVFRl9eWl5errKzM7pIAAHC1y6bn6ZKiXEd0VvVYlmXrZuHW1lYdOnRIgUBATz31lB577DHt3LmzyzASCoUUCoXangeDQeXn5ysQCCg7O9vOMgEAQIwEg0H5fL4+fX7bHkTOtGjRIhUWFmrdunW9Xtuf3wgAAHCG/nx+x72PSCQS6TDqAQAA3MvWNSKrV6/W5ZdfrokTJ6qpqUmbNm3Sq6++qhdeeMHO2wIAgARhaxBpaGjQTTfdJL/fL5/PpxkzZuiFF17QJZdcYudtAQBAgrA1iDz++ON2/ngAAJDgbN++CwCAm4UjliO2yToVQQQAAJtUVPtVtrWmw4m3eb4MlS4uinvjMKfi9F0AAGxQUe3Xso1VHUKIJNUHWrRsY5Uqqv2GKnMWgggAADEWjlgq21qjrhp1RV8r21qjcCSurbwciakZAABirLK2sdNISHuWJH+gRbs/OK4Uj8fV60cIIgAAxFhDU/chpL3lT1TpxGen2p67cf0IUzMAAMTY2KyM3i+SOoQQqe/rR8IRS7sOHNeWvUe068DxhJ7iYUQEAIAYKy7IUZ4vQ/WBli7XiXTHkuTR6fUjlxTldjlNk2w7cRgRAQAgxlJTPCpdXNSvEBIVXT9SWdvY6WvJuBOHIAIAgA0uKcrViGFpA/7++sBnHaZfWj+PJOVOHKZmAACwQWVto06cPNX7hd34+Z/eVWNza9vznMz0Ds/P1H4kpaRw1IDvG2+MiAAAYIO+7pzpzpmho6cQEsv7xhtBBAAAG/R150yy3HegCCIAANggunOmP+3JMr2pA76fR6d3zxQX5Az4Z5hAEAEAwAapKR5dNTOvXztn0lIG9rEcDTuli4sSrjMrQQQAABts21enda/V9ut7zmxw1p2czI67cXJ9GVp7w6yE7CPCrhkAAPohHLFUWdvY4/kw2/b5teJ3bw3o548YmqbAZ6e6HEnx6HTo2Hn3fO358JOkOKOGIAIAQB/1patpRbVfd2yqGvA9bp5boIdf+ps8Uocw0n76JX1ISkJt0e0JUzMAAHyhpzNc+tLVNByxVLa1ZkD3ji42XbFgqtbeMEu5vo67XxJ5+qUnjIgAAJJWX6ZRonoa7bikKLfHrqbR82E+PN7cKaj0xZmLTS+bnqdLinL7XHsi81iW5dhesMFgUD6fT4FAQNnZ2abLAQAkkP4cDhcd7TjzAzH6sX/Xomn61Ut/t63WRD60riv9+fwmiAAAkk5vwaL9FEc4YumbD77c7UiGR5JvaFqfd7T01Y8WTdPk0ZlJOdrRn89vpmYAAEkluk6jt2mUBWeP054PP9Hr+4/2OJ1iqe/bavsixSM9+v2v64oZ42P2MxMZQQQAkFQqaxt7DRb+QIsuLN/R5/NbYunR78/SFTOSYwomFggiAICk0tdD3+IdQkYMTdMD3zs3adaBxApBBACQVJx66NuaH8zS3KmjTZfhOPQRAQAkldmTRnZqgW5ani9DF05JjgZksUYQAQAkjYpqv771y1fU2BzbHS6DlYiH0cULUzMAgIQXjlh69OW/29rrYyBGDktT+XdZF9ITgggAIKFVVPt133M1qg/2v6OpXYZ7h2jtD2ZpztTRjIT0giACAEhY3TUuM+3T0OcakppCCOkD1ogAABJST43LnKCv24jdjiACAEhIvTUuM82p24idhqkZAEBCcuqIg0dSru/0+THoHSMiAICEVPG233QJnURXhLBdt+8IIgCAhHPrf/9Vz7/zsdEaRgxLU252x+mXXF9Gh5N90TumZgAACSMcsfS/X/qbttc0mC5FJ06e0hO3zFJKikcNTS0am3V6OoaRkP4hiAAAEsLpfiHvqD4YMl1Km2PNIV193ldMl5HQCCIAAMdzar+Q0cO9pktIeKwRAQA4mqP7hTiyqMTCiAgAwHHCEUu7PziuXQeO66NPTjq2X8ixZudMEyUqgggAwFEqqv1a9czbOnHSWSfodoWmZYNn69RMeXm5zj//fGVlZWns2LG65ppr9P7779t5SwBAAquo9uv2jVWODyEeSXk0LYsJW4PIzp07tXz5cu3evVvbt2/XqVOn9O1vf1vNzc123hYAkIDCEUv3PVdjuoxe0bQstmydmqmoqOjwfMOGDRo7dqz27Nmjiy++2M5bAwASQDhiqbK2UQ1NLWoIhlQfdOZakPZyfRkqXVxE07IYiesakUAgIEnKyel6KCsUCikU+nLhTzAYjEtdAID4O90XpMbR4cM7JEV7f/Zt7T18gqZlNolbEIlEIrrrrrs0d+5cTZ8+vctrysvLVVZWFq+SAACGRNeCON0d8wo1ND1VJYWjTJeStDyWZcVlF/SyZcv0/PPP689//rMmTJjQ5TVdjYjk5+crEAgoOzs7HmUCAGzQfjuuJUvrX6/VydaI6bJ6NHJYmt786SWMfgxAMBiUz+fr0+d3XEZEVqxYoT/+8Y967bXXug0hkuT1euX10qUOAJJJIm3HjfJIKv/uuYSQOLA1iFiWpTvvvFObN2/Wq6++qoKCAjtvBwBwkHDE0qMv79evXvqb6VL6JY/FqHFlaxBZvny5Nm3apC1btigrK0v19fWSJJ/Pp6FDh9p5awCAQU48oK4nK+YXatq4LBajGmDrGhGPp+v/kOvXr9fSpUt7/f7+zDEBAMyKbsXdXlOv/3z9oOly+uV3t17IgtQYcswakTitgwUAGLZtX53+7dlqfZJA60Ck02tBcumQahRnzQAABqV8W43WvVZruox+o0OqMxBEAAADtm2fPyFDiESHVKcgiAAA+qR9O/axWRmaPWmkfrql2nRZfeaRdNeiszR59DAWpToIQQQA0KuKar/KttbIH/iyHXtWxhA1tXxusKq+GzE0TQ9871xGPxyIIAIA6FFFtV/LNlbpzO0HiRJCJGnND2Zp7tTRpstAF1JMFwAAcK5wxFLZ1ppOISSR5PkydOEUtuY6FSMiAIBOoutBXt9/rMN0TCJiV4yzEUQAAB10tR4kEaV4pEe/P4t1IQ5HEAEASErcs2G6c1PJJF0xgxDidAQRAEDCnQ3TF5eeQwhJBAQRAHCp6DqQl2rq9XiCnQ3TmzzaticMgggAuFCyrAM5E23bEw9BBABcpru+IMmAtu2JhyACAC6SDH1BznTvlV/T6CwvbdsTFEEEAFyksrYxqaZj8nwZWjq3gPCRwOisCgAu4j/xmekSYsYj1oIkA0ZEACCJnHlCbnTnSGVto7bX1GvTGx8arnBgPB7JajeflMdakKRBEAGAJNHVThjf0CH6PGKpORQ2WNng3Hvl13RjyWTt+fCTDgGLkZDkQBABgCTQ3U6YwGeJc0JuV9qvASkp5OC6ZEQQAYAE1H4KZvRwr0q3VCfVTpgo1oAkP4IIACSYZG1GdqYfLZrGGhAXIIgAQAJJ5mZk7eX5MrRiwTTTZSAOCCIAkCCSsRnZmWjR7j4EEQBIEMnWjKwrtGh3H4IIACSIhqbkDCEjh6XpZ4vPUW4223LdiCACAA7WfnfMsaaQ6XJiKho3yr97LiMgLkYQAQCH6mp3TIpHiiTJIhGmYSARRADAkbrbHZPoISRnWJru/Yci5fqGMg0DSQQRAHCcZN4d03jylHJ9Q+mSijacvgsADrP7g+NJvTsmWRfdYmAIIgDgIBXVft3xRJXpMmw1NivDdAlwEKZmAMAhKqr9un1j8oYQj04vUC0uyDFdChyEEREAcIBwxNKqZ942XYZt6JiK7jAiAgAGhSOWdh84rodfel8nTp4yXY5t2KqL7hBEAMCQbfvqdPfT+9QcCpsuJeY8knIy0/XTK7/GVl30iCACAAaUb6vRutdqTZdhi2jcuP870xkBQa8IIgBgs/Zt2sdmZagh2JK0IURiGgb9QxABgBhrHzwOHjup31UeUn0weXtnjBg6RGt+MFvHPg1pbBYH16F/CCIAEENdnQ+T7B743gzNnTradBlIUAQRAIiR7s6HSVYjhqXpAU7OxSARRABgEKLTMPXBFv38j++4IoSMGJamm+cUaMWCqUzBYNAIIgAwQG6ahpnxlWwtnVOgvBFsxUVs2dpZ9bXXXtPixYs1fvx4eTwePfvss3beDgDiJjoN44YQkufL0Obl39R3Z09QSeEoQghiytYg0tzcrJkzZ2rNmjV23gYA4iocsVS2tSbpp2E8Xzxoyw472To1c/nll+vyyy+38xYAEHeVtY2uGAmhHwjiwVFrREKhkEKhUNvzYDBosBoA6FpDU/KHEEl66NqZmjuNbbmwl6NO3y0vL5fP52t75Ofnmy4JADoZm5VhuoS4ONYc6v0iYJAcFURWr16tQCDQ9jh8+LDpkgCgk+KCHPkyHDWgbAu3BC6Y5ai/SV6vV16v13QZANClaM+Q7TX1ag0n71JVj06vDykuyDFdClzAUUEEAJyqotqv+56rSeozY6QvT85lpwzixdYg8umnn2r//v1tz2tra7V3717l5ORo4sSJdt4aAGJm2z6/7thUZboMW6R4pEi7wR12yiDebA0ib775pubPn9/2fOXKlZKkJUuWaMOGDXbeGgBiYtu+Oi3f9JbpMmzz6PdnaWRmuhqaWjg5F0bYGkTmzZsny0reeVQAya2i2q87kjSEeCStuX6WrpjByAfMYo0IALQTXZBa98lJ/duWatPl2MaSNDIz3XQZAEEEAKTTAeTRl/frP//8gQItn5suJy7c0pgNzkYQAeA60VGP6LqIT5pD+smz1Tpx8pTp0uKKPiFwAoIIAFepqParbGuNK86K6Q59QuAkBBEArlFR7deyjVVJf2puT+gTAqchiABwhXDEUtnWGleHEIk+IXAegggAV6isbXT1dMwtcydrUVEufULgOAQRAK6wvabedAlG5DECAocjiABIeuGIpWf31pkuI65WzJ+quVNHMwICxyOIAEhq4Yile576HzU2t5ouJS6iO2J+dMlZBBAkBIIIgKTQvjfI6OFeyZJ2vPex/vDmR/o05I4GZeyIQSIiiABIeG7sDXJhwUgdPH5S9cFQ22vsiEEiIogASGhu7A0yYlianri1RJI6dIhlPQgSEUEEQMJya2+QB757blvgKCkcZbgaYHBSTBcAAAPltt4gOZlp+s0Ns5h6QVJhRARAwnLT6bGjMtO1a/VCpQ/h349ILgQRAI515im5Z66B+N87/mawuviI/m7v/850QgiSEkEEgCN1tROmfZfQ56qO6IOjJw1WGB/shEGyI4gAcJzudsL4Ay26fWOV/nV+oda8esBIbXbLTE/Rb286X8c+DbETBq5AEAHgKH3ZCfMfryRnCJGkX147U3OnjjZdBhA3TDgCcBS37YRp77aLC3TFjPGmywDiihERAI7ipp0wUcO9Q/SL783QFTNYBwL3IYgAcJSxWRmmS4ibFI/0XzcXa87U0awDgWsxNQPAMcIRS0/sPmi6jLiJWNKQ1BRCCFyNEREAjlBR7dfKP/yPTraGTZcSV26cigLaI4gAMK6i2q/bN1aZLsMIN01FAV0hiAAwKhyxtOqZt02XEXcenW5WVlyQY7oUwCiCCIC4a9+6fef7DTpx8pTpkuIquiKkdHER60PgegQRADHV2/kwXbVudxvatgNfIogAiJnezofprnW7m9x75de0dG4BIyHAF9i+CyAmoiHjzJGO+kCLlm2s0tb/qdNPNle7OoRI0ugsLyEEaIcREQCD1tP5MNHX/vV3b7k+hEjskgHOxIgIgEHry/kwhJDT01TskgE6IogAGDSacvXOI3bJAF1hagbAoDHd0LM8dskA3SKIABi04oIc5fkyVB9oce0UzLisdJUuPkcjM72qD7ao8dOQcjLTlesb2mkLM4AvEUQADFpqikeli4tc1ab919fP0sjM9G77pQDoG4IIgEELRyy9528yXUZcpA9J0X/803lMswAxQhAB0KueuqVWVPt133PvqD4YMlyl/S49Z6x+/YNvMPIBxBBBBECPeuqWKsk1nVIXfW2M1t14vukygKRDEAHQre5asvsDLbp9Y5WGe4ckfQjxSPrniwr0b1cWmS4FSEoEEQBdCkcsrXrm7R6Dxqehz+NWjwnXnDdev7h2ptKH0HIJsEtc/natWbNGkydPVkZGhi644AJVVlbG47YABigcsXTPU/+jEydPmS7FqH88fyIhBLCZ7X/Dfv/732vlypUqLS1VVVWVZs6cqUsvvVQNDQ123xrAGcIRS7sOHNeWvUe068BxhSOdxzsqqv2a+8DLeqrqiIEKnYN27EB8eCzLsnWK94ILLtD555+vRx99VJIUiUSUn5+vO++8U6tWrerxe4PBoHw+nwKBgLKzs+0sE0h6PS06jW5F7W5NiNt4JK29YRZbdIEB6s/nt60jIq2trdqzZ48WLVr05Q1TUrRo0SLt2rWr0/WhUEjBYLDDA8DgRQPGmQfT1QdatGxjlSqq/T2eoOsmeb4MQggQR7YuVj127JjC4bDGjRvX4fVx48bpvffe63R9eXm5ysrK7CwJcJ2eAoal0//6v++5d3S48bNeT9BNZgvOHqNbLyqkQyoQZ45ahbV69WoFAoG2x+HDh02XBCS8ytrGHgOGJak+GNL9296NX1EOMtw7RL++fpb+c2mxSgpHEUKAOLN1RGT06NFKTU3Vxx9/3OH1jz/+WLm5uZ2u93q98nq9dpYEuE5Dk3tHOXqSmZ6qf7m4UCsWTCV8AAbZOiKSnp6u2bNna8eOHW2vRSIR7dixQyUlJXbeGsAXxmZlmC7Bcf51fqH23Xep/teiaYQQwDDbG5qtXLlSS5Ys0Te+8Q0VFxfr4YcfVnNzs26++Wa7bw1AUnFBjvJ8Ga5e/9HerRdN1spLzzZdBoAv2B5E/vEf/1FHjx7Vz372M9XX1+u8885TRUVFpwWsAOyRmuLRvVcW6Y5NVaZLMe4fzs3Tv115jukyALQTlxbvK1as0IoVK+JxKwBdGJmZbroER7jkHP4BBDiNo3bNALAHC1ZPY70M4Dwcege4wMFjJ02XYJRHUi4t2wFHYkQESHLhiKXfVR4yXYYx0T0xpYuL2CEDOBAjIkASCkcsVdY2qqGpRceaQqoPumdqJiczTY3NX54anHvGeToAnIUgAiSZrg63c4Po9MvOu+drz4efqKGpRWOzMmjZDjgcQQRIIm49Pbf99Ev6kBSVFI4yWg+AviOIAAmo/dRL9F/9klx7ei7TL0DiIogACaarqZeczDRdUDDKddMxC88eo3/mxFwgoRFEgATS3dRLY/MpPV9db6QmUxaePUaPLy02XQaAQWL7LpAgwhHLtVMvXfnniwpNlwAgBhgRARJEZW2j66ZeukJzMiC5MCICJAg39QLpDs3JgOTDiAjgcOGIpUdf/rt++9oHpksxjt0xQPIhiAAOVlHt16pn3taJk6d6vzhJrZhfqGnjsmhOBiQpggjgUNv21emOTW+ZLsO4uVPH0KAMSGIEEcCBtu3zuz6EsCgVcAeCCOAQ4Yil3R8c18bdH7quJ8iZWJQKuAdBBHAAt64FmT1phG6ZO0U//1PHTrEsSgXcgyACGFZR7dftG6tMlxF3N144UT+/5lxJ0qXTczudncNICOAOBBHAoHDE0n3P1Zguw4jJozLbfp2a4mFBKuBSBBEgDro6LTc1xaNHX/67axuV5Qz3mi4BgAMQRACbdXVabp4vQ1dOz9Vjrx80V5hhudkZpksA4AAEEcBG3Z2W6w+0uDqE5LEtF8AXOGsGsAmn5XbNI7blAvgSIyKATTgtt7M8tuUCOANBBLBJQ5N7Q8i/LpiqksLRqg98psbmVuUM9yo3m225ADojiAA2GZvl3sWYhWOHsx0XQJ+wRgSIsXDE0q4Dx1Uf+Ew5memmyzHCzSEMQP8wIgLEUFdbdd2Eg+oA9BdBBBiArhqUba+p73KrrltwUB2AgSCIAP3U1ahHbnaGWj4PuzaESBxUB2BgCCJAP2zbV6c7Nr3V6XU3tWk/L3+Ebp47WaOHeyVLOtYc4qA6AANGEAH6aNs+v1b8rnMIcZt7LjubHTEAYoYgAvRBRbVfd2yqMl2GcbRmBxBrBBGgF9FW7W7GQlQAdiGIAL2gVTsLUQHYhyAC9MKtrdpzMtN07z+cQ2t2ALYiiAC9OHis2XQJcRWNG//+nXMZAQFgO4II0IOKar9+9dLfTZcRV0zDAIgnggjQDTctUr1r4TQVjMmkHwiAuCOIAF0IRyxteL3WFYtUb7u4QHddcpbpMgC4FEEEOENFtV/3PVfjim6p/2vhNP2IEALAINuCyP33368//elP2rt3r9LT03XixAm7bgUMSvsD7A4eO6lfvfQ30yXFzZQxmaZLAOBytgWR1tZWXXfddSopKdHjjz9u122AQenqADs3GZuVYboEAC5nWxApKyuTJG3YsMGuWwCDUlHt17KNVa48Mdej07tjaNcOwDRHrREJhUIKhUJtz4PBoMFqkMxaP4/oJ5vfdm0IkWjXDsAZUkwX0F55ebl8Pl/bIz8/33RJSELb9vk1+//brsbmU6ZLMSLXl6G1N8yiTwgAR+jXiMiqVav04IMP9njNu+++q7PPPntAxaxevVorV65sex4MBgkjiKnybTVa91qt6TLiZuSwNN1/zbkamZmuhqYW+oQAcJx+BZEf//jHWrp0aY/XTJkyZcDFeL1eeb3eAX8/0JM/7j3imhDiHeLRHfOmasWCaYQOAI7WryAyZswYjRkzxq5agJiLbs19saZeG14/aLoc22UMSdHt3yrUnQsJIAASg22LVQ8dOqTGxkYdOnRI4XBYe/fulSRNnTpVw4cPt+u2QJtt+/z66ZZqNTa3mi4lLq6d9RU9eO1MAgiAhGJbEPnZz36m//qv/2p7/vWvf12S9Morr2jevHl23RaQ5L61ILnZXkIIgIRk266ZDRs2yLKsTg9CCOy2bV+da0KI54vHfVedQwgBkJActX0XGKxwxNJPt1SbLsM2vqEdBzHZigsg0TmqoRkwWJW1jUnZHyQtRXrk+lm6pCi37VwctuICSAYEESSVhqbkPDPmP5cW66KzTu9YKykcZbgaAIgdpmaQVJLxELcRw9I0Z+po02UAgC0YEUHCi/YKaWhq0QvV9abLibkHvnsu0y8AkhZBBAmtotqvsq018geSb0omN9ur+646h4WoAJIaQQQJq6Lar2Ubq5LmBN1LzxmnpXMKWIgKwFUIIkhI4Yilsq01SRNCJOmmCyezEBWA67BYFQmpsrYxqaZjRgxL04WEEAAuRBBBQkq2bbosSAXgVgQRJKRk2aabm+3Vb+iMCsDFWCOChBHdpus/8ZmeqTpsupwBuWvhNF0wZRQLUgHgCwQRJIRk2KY73DtEdy6cRvAAgHYIInCc9g3KxmZl6JPmVi3flPjbdH/xvRmEEAA4A0EEjlJR7dd9z72j+mDIdCkxddvFBbpiButAAOBMBBE4RkW1X7dvrDJdRkwN96bqF9+boStmjDddCgA4EkEERkWnYeqDLVr99D7T5cTENTPHa0LOMJUUjtKFU0YxHQMAPSCIIO6i4WN7Tb2e3VunxuZW0yXF1PyvjdXV533FdBkAkBAIIoirZNj90ptk6XECAPFAEEHcJNshdWfySMr1ne4NAgDoG4IIbBeOWNp94LhWPf12UocQSSpdXMSaEADoB4IIbOWGqRjp9EhI6eIiWrUDQD8RRGCbZJ6K8UiyJP1w7mRdUpRLq3YAGCCCCAbszA6o7T+MwxFLZVtrkjKESIyAAECsEEQwIF1NueS1+3CurG1MyumYEUPTtOYHs+gPAgAxQhBBv3U35eIPtOj2jVW6a+FUvXU4YKQ2O3kkPfC9czV36mjTpQBA0iCIoF/6MuXy8I79casnXvKYigEAWxBE0C/JOuXSlVGZ6br6vPEsRgUAGxFE0C8NTckZQr6eP0IrLzlLKR6PjjWHOi2+BQDYgyCCfknG9uUXFozQk7fNNV0GALhSiukCkFiKC3KU50ueMOLxSP99S4npMgDAtQgi6Ld/Oj/fdAkx8y8XFSh9CH8NAMAUpmbQZ8nUrj3FI916UYFWX1FkuhQAcDWCCPokmdq1XzvrK/r3785gJAQAHIAggl4lS7t2eoEAgPMQRNCrRO8dwsF0AOBcBBH0KpF7h/z6+q/rihnjTZcBAOgGQQS9StTeIb++fpaumME0DAA4Gav10KtE6x2S58vQb24ghABAImBEBD0KRyxV1jbqium5evz1g6bL6WBcVrr+5eJCXX/BJO09fEINTS20ZgeABEMQQbec3jfk1bsXaGh6qiSppHCU4WoAAAPB1Ay6FO0b4tQQIkl7D58wXQIAYJBsCyIHDx7ULbfcooKCAg0dOlSFhYUqLS1Va2urXbfEAIQjlnYdOK4te49o14HjCkcsfdYa1t1P7XN835BE3s0DADjNtqmZ9957T5FIROvWrdPUqVNVXV2tW2+9Vc3NzXrooYfsui26EV3r0X4dxfaa+k5TL95Uj0Jhp0eQ0xJ1Nw8A4Esey7Li9qnzy1/+UmvXrtUHH3zQp+uDwaB8Pp8CgYCys7Ntri55dbXWY8SwNJ04ecpgVQPnkZTry9Cf71nAolQAcKD+fH7HdbFqIBBQTk5Ot18PhUIKhUJtz4PBYDzKSmrdnRGTyCFEkkoXFxFCACAJxG2x6v79+/XII4/otttu6/aa8vJy+Xy+tkd+fvIcN29CspwR016uL0Nrb5jFeTEAkCT6PTWzatUqPfjggz1e8+677+rss89ue37kyBF961vf0rx58/TYY491+31djYjk5+czNTNAuw4c1/f/z27TZQxaTmaa7v2Hc5SbTY8QAEgEtk7N/PjHP9bSpUt7vGbKlCltv66rq9P8+fM1Z84c/fa3v+3x+7xer7xeb39LQjcSfVdJNG78+3fOZQQEAJJUv4PImDFjNGbMmD5de+TIEc2fP1+zZ8/W+vXrlZJC25J4SpRdJeOy0vX94kna8JeDOvHZl2tXcn0ZKl1cRAgBgCRm22LVI0eOaN68eZo0aZIeeughHT16tO1rubm5dt0W7UTPiKkPtDh2nYhHUtnV03XZ9DzduXBapy3GTMMAQHKzLYhs375d+/fv1/79+zVhwoQOX4vjjmFXS03xqHRxkW7fWNXtNWmpHp0y1DckN9ur+646p23EIzXFQ6t2AHAZ2+ZKli5dKsuyunzAOUyFkMun5+r1VQuZdgEAl+PQuyTTvoPq6Eyv7nuuxnRJXbrhwklMuwAACCLJpKLar/ueq1F9MAF2yzAwBgAQQSRpVFT7e1wL4jTHmkO9XwQASHrsp00C4YilVc+8bbqMfkmUrcUAAHsxIpIEdn9wPGHOjokeWFdc0P2ZQwAA9yCIJJj2i1GjvTY27v7QdFnyDR2iwGef93gNB9YBAM5EEEkgFdV+lW2tkT/w5WLU3OwMHf3U7OLUW+ZO1k+uLOoQkD5pDunnf3q3Y610SgUAnIEgkiAqqv1atrGq02YTJ+yQWVSU22Uzskun59EpFQDQI4JIAghHLJVtrXHcjtfe1nvQKRUA0Bt2zSSAytrGDlMcTsJ6DwDAYDAikgAampwXQvJY7wEAiAGCSAJwWs+NHy2aphULpjESAgAYNIJIAiguyNGIYWnGe4UwCgIAiDWCSII42Ro2ct9rZ03QRWeNZtcLAMAWBJEE8B87/qbWzyMx/7nDvUP0aaj7JmSZ3lQ9eO0MwgcAwDYEEYdr/TyidTs/iPnPzRmWpkgv+4HTUtlUBQCwF580DlZR7deF5TvUYsNoyAVTcnTis57XnJw4eUqVtY0xvzcAAFGMiNikqzNh+jPF0V0n1djpWy1O3DoMAEgeBBEbdHUmzJk7TqJBpT7wmRqbW5Uz3Kvc7C+7lPa3k+pw7xDNO2uU/vj2x326/o0+jnQ4beswACC5EERirNszYQItWraxSmtvmCVJnYJKVJ4vQ/90/sR+d1L9xfdm6NLpuXrl/RfU3IcdNo3NrcrJTNMnzae6DDy9tW8HACAWWCMSQz2dCRN9bdUzb2vZxqpug4Y/0KJfvfS3ft331osKNDIzXX/cV6d/ubiwz9/3nfO+IqnzJE30Oe3bAQB2Y0Qkhno7E8aSYtqULCsjVf90fr7+uM+v//N/a9tez/SmqjnU+6jIoqJcnV+Q02l0JpfGZQCAOCGIxFA8F3aOykxX2eIi3fnk3k4jML2FkPbTLqkpHl1SlDuohbUAAAwUQSSG4rWw0yPp51dP18//1PuCVo/U4Zqupl1SUzwqKRwV+0IBAOgFa0RiqLggR3m+jD5ujB2YFI+05vqva2Rmep8WtI7MTO/wPNeXobU3zGLaBQDgCIyIxFBqikeli4u0bGNVp5GIWIlY0shMb5+nge698mvK9Q1l2gUA4EgEkRi7bHqe1t4wq9vtubEQDRV9kesbyrQLAMCxCCI2uGx6XtsC0Oer/frvXR/G9OdHRzbyfBmqD7TQBwQAkLBYI2KT6ALQy/u4FiMzPVX//w+LlZvd/RoTj043PItOr5QuLmp7/czrJPqAAACcjyBis+jIRW/SUlPU3Pq57ruq7+EiOg2Ue8bPZ0EqACBReCzLsu9ctUEKBoPy+XwKBALKzs42Xc6A9eUAu2jQ6K4F/Jln1bQ32AP2AACIpf58fhNE4qSi2q/7nntH9cFQj9fl+TL053sWSBLhAgCQkPrz+c1i1Ti5bHqesjLS9IPH3ujxOn+gRZW1jSopHMVuFwBA0mONSBwd+7Tn0ZCo7TX1NlcCAIAzEETiqK+9P7bsrVM44tgZMwAAYoYgEkfFBTnKyUzr9brjza2qrG2MQ0UAAJhFELFBOGJp14Hj2rL3iHYdON42upGa4tF3zvtKn35GPE/yBQDAFBarxlhFtb/HrbeLinL1+OsHe/058TrJFwAAkxgRiaFt++p0+8aqTmfM1AdatGxjlSqq/b2e0Nu+eyoAAMmOIBIj2/b5teJ3b3X5teiy07KtNZJEa3YAAL5AEImBimq/7thUpZ42ulj6skcIrdkBADiNNSKDFI5YbSMdfRFdhNr+hF66pwIA3MrWEZGrrrpKEydOVEZGhvLy8nTjjTeqrq7OzlvGXWVtY6c1IT1pvwg1ekLv1ed9RSWFowghAADXsTWIzJ8/X3/4wx/0/vvv6+mnn9aBAwd07bXX2nnLuOvPNlsWoQIA0JGtUzM/+tGP2n49adIkrVq1Stdcc41OnTqltLTeG3slgv5ss2URKgAAHcVtsWpjY6OeeOIJzZkzJ2lCiKRet+NKUopH+vX1X2cRKgAAZ7A9iNxzzz3KzMzUqFGjdOjQIW3ZsqXba0OhkILBYIeH06WmeLrdjhv16Pdn6YoZ4+NXFAAACaLfQWTVqlXyeDw9Pt5777226++++2699dZbevHFF5WamqqbbrpJltX1Ptfy8nL5fL62R35+/sB/Z3HU3XbcPF+GfnPDLF0xg5EQAAC64rG6SwXdOHr0qI4fP97jNVOmTFF6enqn1z/66CPl5+frL3/5i0pKSjp9PRQKKRQKtT0PBoPKz89XIBBQdnZ2f8o0Ihyx2I4LAHC9YDAon8/Xp8/vfi9WHTNmjMaMGTOgwiKRiCR1CBvteb1eeb3eAf1sJ4huxwUAAH1j266ZN954Q3/961/1zW9+UyNHjtSBAwd07733qrCwsMvREAAA4D62LVYdNmyYnnnmGS1cuFBf/epXdcstt2jGjBnauXNnQo96AACA2LFtROTcc8/Vyy+/bNePBwAASYBD7wAAgDGuPPSO3S0AADiD64JIRbVfZVtrOhxUl+fLUOniIjqfAgAQZ66amqmo9mvZxqpOp+XWB1q0bGOVKqr9hioDAMCdXBNEwhFLZVtr1FX3tuhrZVtrFI70q78bAAAYBNcEkcraxk4jIe1ZkvyBFlXWNsavKAAAXM41QaShqfsQMpDrAADA4LkmiIzNyuj9on5cBwAABs81QaS4IEd5vgx1t0nXo9O7Z4oLcuJZFgAAruaaIJKa4lHp4iJJ6hRGos9LFxfRTwQAgDhyTRCRpMum52ntDbOU6+s4/ZLry9DaG2bRRwQAgDhzXUOzy6bn6ZKiXDqrAgDgAK4LItLpaZqSwlGmywAAwPVcNTUDAACchSACAACMIYgAAABjCCIAAMAYgggAADCGIAIAAIwhiAAAAGMIIgAAwBiCCAAAMMbRnVUty5IkBYNBw5UAAIC+in5uRz/He+LoINLU1CRJys/PN1wJAADor6amJvl8vh6v8Vh9iSuGRCIR1dXVKSsrSx6PmUPpgsGg8vPzdfjwYWVnZxupwel4j3rG+9M73qOe8f70jveoZ/F+fyzLUlNTk8aPH6+UlJ5XgTh6RCQlJUUTJkwwXYYkKTs7mz/cveA96hnvT+94j3rG+9M73qOexfP96W0kJIrFqgAAwBiCCAAAMIYg0guv16vS0lJ5vV7TpTgW71HPeH96x3vUM96f3vEe9czJ74+jF6sCAIDkxogIAAAwhiACAACMIYgAAABjCCIAAMAYgkg/XHXVVZo4caIyMjKUl5enG2+8UXV1dabLcoyDBw/qlltuUUFBgYYOHarCwkKVlpaqtbXVdGmOcf/992vOnDkaNmyYRowYYbocR1izZo0mT56sjIwMXXDBBaqsrDRdkmO89tprWrx4scaPHy+Px6Nnn33WdEmOUl5ervPPP19ZWVkaO3asrrnmGr3//vumy3KUtWvXasaMGW2NzEpKSvT888+bLqsDgkg/zJ8/X3/4wx/0/vvv6+mnn9aBAwd07bXXmi7LMd577z1FIhGtW7dO77zzjn71q1/pN7/5jX7yk5+YLs0xWltbdd1112nZsmWmS3GE3//+91q5cqVKS0tVVVWlmTNn6tJLL1VDQ4Pp0hyhublZM2fO1Jo1a0yX4kg7d+7U8uXLtXv3bm3fvl2nTp3St7/9bTU3N5suzTEmTJigBx54QHv27NGbb76pBQsW6Oqrr9Y777xjurQvWRiwLVu2WB6Px2ptbTVdimP94he/sAoKCkyX4Tjr16+3fD6f6TKMKy4utpYvX972PBwOW+PHj7fKy8sNVuVMkqzNmzebLsPRGhoaLEnWzp07TZfiaCNHjrQee+wx02W0YURkgBobG/XEE09ozpw5SktLM12OYwUCAeXk5JguAw7U2tqqPXv2aNGiRW2vpaSkaNGiRdq1a5fBypCoAoGAJPH/nG6Ew2E9+eSTam5uVklJiely2hBE+umee+5RZmamRo0apUOHDmnLli2mS3Ks/fv365FHHtFtt91muhQ40LFjxxQOhzVu3LgOr48bN0719fWGqkKiikQiuuuuuzR37lxNnz7ddDmO8vbbb2v48OHyer26/fbbtXnzZhUVFZkuq43rg8iqVavk8Xh6fLz33ntt1999991666239OKLLyo1NVU33XSTrCRvTtvf90iSjhw5ossuu0zXXXedbr31VkOVx8dA3h8AsbV8+XJVV1frySefNF2K43z1q1/V3r179cYbb2jZsmVasmSJampqTJfVxvUt3o8eParjx4/3eM2UKVOUnp7e6fWPPvpI+fn5+stf/uKoYa5Y6+97VFdXp3nz5unCCy/Uhg0blJKS3Hl3IH+GNmzYoLvuuksnTpywuTrnam1t1bBhw/TUU0/pmmuuaXt9yZIlOnHiBKONZ/B4PNq8eXOH9wqnrVixQlu2bNFrr72mgoIC0+U43qJFi1RYWKh169aZLkWSNMR0AaaNGTNGY8aMGdD3RiIRSVIoFIplSY7Tn/foyJEjmj9/vmbPnq3169cnfQiRBvdnyM3S09M1e/Zs7dixo+3DNRKJaMeOHVqxYoXZ4pAQLMvSnXfeqc2bN+vVV18lhPRRJBJx1OeW64NIX73xxhv661//qm9+85saOXKkDhw4oHvvvVeFhYVJPRrSH0eOHNG8efM0adIkPfTQQzp69Gjb13Jzcw1W5hyHDh1SY2OjDh06pHA4rL1790qSpk6dquHDh5stzoCVK1dqyZIl+sY3vqHi4mI9/PDDam5u1s0332y6NEf49NNPtX///rbntbW12rt3r3JycjRx4kSDlTnD8uXLtWnTJm3ZskVZWVlta4t8Pp+GDh1quDpnWL16tS6//HJNnDhRTU1N2rRpk1599VW98MILpkv7ktlNO4lj37591vz5862cnBzL6/VakydPtm6//Xbro48+Ml2aY6xfv96S1OUDpy1ZsqTL9+eVV14xXZoxjzzyiDVx4kQrPT3dKi4utnbv3m26JMd45ZVXuvzzsmTJEtOlOUJ3/79Zv3696dIc44c//KE1adIkKz093RozZoy1cOFC68UXXzRdVgeuXyMCAADMSf4JfAAA4FgEEQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGAMQQQAABhDEAEAAMb8PxoYuYQR9IC0AAAAAElFTkSuQmCC"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "id": "c381d7b5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:07:24.298336Z",
     "start_time": "2024-10-20T15:07:24.296480Z"
    }
   },
   "source": "",
   "outputs": [],
   "execution_count": 6
  },
  {
   "cell_type": "code",
   "id": "ed695e3f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:07:25.335363Z",
     "start_time": "2024-10-20T15:07:25.291733Z"
    }
   },
   "source": [
    "model.eval()\n",
    "inf = CausalInference(dag=DAGnx)\n",
    "\n",
    "int_nodes_vals0 = {'X':np.array([0.0,])}\n",
    "int_nodes_vals1 = {'X':np.array([1.0,])}\n",
    "effect_var = 'M'\n",
    "effect_index = var_names.index(effect_var)\n",
    "\n",
    "preds0 = inf.forward(all_data, model=model, intervention_nodes_vals=int_nodes_vals0)\n",
    "preds1 = inf.forward(all_data, model=model, intervention_nodes_vals=int_nodes_vals1)\n",
    "ATE = np.array([0.5, 0.5, 0.5])\n",
    "ATE_pred = (preds1[:,effect_index,:] - preds0[:,effect_index,:]).mean(0)\n",
    "eATE = np.abs(ATE_pred - ATE)\n",
    "print('est ATE:', ATE_pred, 'error:', eATE)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "est ATE: [0.19732073 0.20942023 0.18575056] error: [0.30267927 0.29057977 0.31424944]\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "id": "7ed31067",
   "metadata": {},
   "source": [],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
