{
 "cells": [
  {
   "cell_type": "code",
   "id": "44ca21de",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:22:21.179533Z",
     "start_time": "2024-10-20T15:22:19.844341Z"
    }
   },
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "import torch\n",
    "import utils\n",
    "from model import CFCN\n",
    "from utils.inference import CausalInference\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from datasets import reorder_dag, get_full_ordering"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "id": "fb2fd00d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:22:49.788671Z",
     "start_time": "2024-10-20T15:22:49.704957Z"
    }
   },
   "source": [
    "shuffling = 0\n",
    "seed = 1\n",
    "standardize = 0\n",
    "sample_size = 10000\n",
    "batch_size = 50\n",
    "max_iters =  6000\n",
    "eval_interval = 100\n",
    "eval_iters = 100\n",
    "validation_fraction = 0.1\n",
    "np.random.seed(seed=seed)\n",
    "torch.manual_seed(seed)\n",
    "device = 'cuda'\n",
    "dropout_rate = 0.0\n",
    "learning_rate = 1e-3\n",
    "\n",
    "neurons_per_layer = [3,6,3]\n",
    "\n",
    "def generate_data_mediation(N):\n",
    "    DAGnx = nx.DiGraph()\n",
    "    \n",
    "    Ux = np.random.randn(N)\n",
    "    X =  Ux\n",
    "\n",
    "    Um = np.random.randn(N)\n",
    "    M =  0.5 * X + Um\n",
    "\n",
    "    Uy = np.random.randn(N)\n",
    "    Y =  0.8 * M + Uy\n",
    "\n",
    "    M0 = 0.5 * 0 + Um \n",
    "    M1 = 0.5 * 1 + Um\n",
    "\n",
    "    Y0 = 0.8 * M0 + Uy \n",
    "    Y1 = 0.8 * M1 +  Uy \n",
    "\n",
    "    all_data_dict = {'X': X, 'M': M, 'Y': Y}\n",
    "\n",
    "    # types can be 'cat' (categorical) 'cont' (continuous) or 'bin' (binary)\n",
    "    var_types = {'X': 'cont', 'M': 'cont', 'Y': 'cont'}\n",
    "\n",
    "    DAGnx.add_edges_from([('X', 'M'), ('M', 'Y')])\n",
    "    DAGnx = reorder_dag(dag=DAGnx)  # topologically sorted dag\n",
    "    var_names = list(DAGnx.nodes())  # topologically ordered list of variables\n",
    "    all_data = np.stack([all_data_dict[key] for key in var_names], axis=1)\n",
    "    causal_ordering = get_full_ordering(DAGnx)\n",
    "    ordered_var_types = dict(sorted(var_types.items(), key=lambda item: causal_ordering[item[0]]))\n",
    "\n",
    "    return all_data, DAGnx, var_names, causal_ordering, ordered_var_types, Y0, Y1\n",
    "\n",
    "_, _, _, _, _, Y0, Y1 = generate_data_mediation(N=1000000)\n",
    "ATE = (Y1 - Y0).mean()  # ATE based off a large sample\n",
    "all_data, DAG, var_names, causal_ordering, var_types, Y0, Y1 = generate_data_mediation(N=sample_size)\n",
    "print(var_names, ATE)\n",
    "\n",
    "input_dim = all_data.shape[1]\n",
    "\n",
    "# prepend the input size to neurons_per_layer if not included in neurons_per_layer\n",
    "# append the intput size to neurons_per_layer (output) if not included in neurons_per_layer\n",
    "neurons_per_layer = [6,12, 12, 6, 6]\n",
    "\n",
    "\n",
    "indices = np.arange(0, len(all_data))\n",
    "np.random.shuffle(indices)\n",
    "\n",
    "val_inds = indices[:int(validation_fraction*len(indices))]\n",
    "train_inds = indices[int(validation_fraction*len(indices)):]\n",
    "train_data = all_data[train_inds]\n",
    "val_data = all_data[val_inds]\n",
    "\n",
    "train_data, val_data = torch.from_numpy(train_data).float(),  torch.from_numpy(val_data).float()\n",
    "\n",
    "model = CFCN(neurons_per_layer=neurons_per_layer, dag=DAG, causal_ordering=causal_ordering, var_types=var_types, dropout_rate=dropout_rate).to(device)\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['X', 'M', 'Y'] 0.3999999999999992\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "id": "bed0fbc4",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-10-20T15:22:58.442941Z",
     "start_time": "2024-10-20T15:22:58.373605Z"
    }
   },
   "source": [
    "def get_batch(train_data, val_data, split, device, batch_size):\n",
    "    data = train_data if split == 'train' else val_data\n",
    "    ix = torch.randint(0, len(data), (batch_size,))\n",
    "    x = data[ix]\n",
    "    return x.to(device)\n",
    "\n",
    "all_var_losses = {}\n",
    "for iter_ in range(0, max_iters):\n",
    "    # train and update the model\n",
    "    model.train()\n",
    "\n",
    "    xb = get_batch(train_data=train_data, val_data=val_data, split='train', device=device, batch_size=batch_size)\n",
    "    xb_mod = torch.clone(xb.detach())\n",
    "    X, loss, loss_dict = model(X=xb, targets=xb_mod, shuffling=shuffling)\n",
    "\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "\n",
    "    if iter_ % eval_interval == 0:  # evaluate the loss (no gradients)\n",
    "        for key in loss_dict.keys():\n",
    "            if key not in all_var_losses.keys():\n",
    "                all_var_losses[key] = []\n",
    "            all_var_losses[key].append(loss_dict[key])\n",
    "\n",
    "        model.eval()\n",
    "        eval_loss = {}\n",
    "        for split in ['train', 'val']:\n",
    "            losses = torch.zeros(eval_iters)\n",
    "            for k in range(eval_iters):\n",
    "\n",
    "                xb = get_batch(train_data=train_data, val_data=val_data, split=split, device=device,\n",
    "                               batch_size=batch_size)\n",
    "                xb_mod = torch.clone(xb.detach())\n",
    "                X, loss, loss_dict = model(X=xb, targets=xb_mod, shuffling=False)\n",
    "                losses[k] = loss.item()\n",
    "            eval_loss[split] = losses.mean()\n",
    "        print(f\"step {iter_} of {max_iters}: train_loss {eval_loss['train']:.4f}, val loss {eval_loss['val']:.4f}\")"
   ],
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "mat1 and mat2 shapes cannot be multiplied (50x3 and 6x12)",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mRuntimeError\u001B[0m                              Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[5], line 14\u001B[0m\n\u001B[1;32m     12\u001B[0m xb \u001B[38;5;241m=\u001B[39m get_batch(train_data\u001B[38;5;241m=\u001B[39mtrain_data, val_data\u001B[38;5;241m=\u001B[39mval_data, split\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtrain\u001B[39m\u001B[38;5;124m'\u001B[39m, device\u001B[38;5;241m=\u001B[39mdevice, batch_size\u001B[38;5;241m=\u001B[39mbatch_size)\n\u001B[1;32m     13\u001B[0m xb_mod \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mclone(xb\u001B[38;5;241m.\u001B[39mdetach())\n\u001B[0;32m---> 14\u001B[0m X, loss, loss_dict \u001B[38;5;241m=\u001B[39m \u001B[43mmodel\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mxb\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtargets\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mxb_mod\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mshuffling\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mshuffling\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     16\u001B[0m optimizer\u001B[38;5;241m.\u001B[39mzero_grad(set_to_none\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m)\n\u001B[1;32m     17\u001B[0m loss\u001B[38;5;241m.\u001B[39mbackward()\n",
      "File \u001B[0;32m~/anaconda3/envs/cat_env/lib/python3.9/site-packages/torch/nn/modules/module.py:1532\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m   1530\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)  \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[1;32m   1531\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m-> 1532\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/anaconda3/envs/cat_env/lib/python3.9/site-packages/torch/nn/modules/module.py:1541\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m   1536\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[1;32m   1537\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[1;32m   1538\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[1;32m   1539\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[1;32m   1540\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[0;32m-> 1541\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m   1543\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m   1544\u001B[0m     result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n",
      "File \u001B[0;32m~/GitHub/Causal_Transformer/CFCN/model.py:179\u001B[0m, in \u001B[0;36mCFCN.forward\u001B[0;34m(self, X, targets, shuffling, verbose)\u001B[0m\n\u001B[1;32m    175\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mwas_shuffled \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[1;32m    177\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i, (linear, activation, dropout) \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(\u001B[38;5;28mzip\u001B[39m(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlayers, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mactivations, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdropout_layers)):\n\u001B[0;32m--> 179\u001B[0m     X \u001B[38;5;241m=\u001B[39m \u001B[43mlinear\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    180\u001B[0m     X \u001B[38;5;241m=\u001B[39m activation(X)\n\u001B[1;32m    181\u001B[0m     X \u001B[38;5;241m=\u001B[39m dropout(X)\n",
      "File \u001B[0;32m~/anaconda3/envs/cat_env/lib/python3.9/site-packages/torch/nn/modules/module.py:1532\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m   1530\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)  \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[1;32m   1531\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m-> 1532\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/anaconda3/envs/cat_env/lib/python3.9/site-packages/torch/nn/modules/module.py:1541\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m   1536\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[1;32m   1537\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[1;32m   1538\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[1;32m   1539\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[1;32m   1540\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[0;32m-> 1541\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m   1543\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m   1544\u001B[0m     result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n",
      "File \u001B[0;32m~/GitHub/Causal_Transformer/CFCN/model.py:86\u001B[0m, in \u001B[0;36mMaskedLinear.forward\u001B[0;34m(self, input)\u001B[0m\n\u001B[1;32m     84\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m):\n\u001B[1;32m     85\u001B[0m     masked_weight \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mweight \u001B[38;5;241m*\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmask\n\u001B[0;32m---> 86\u001B[0m     masked_linear \u001B[38;5;241m=\u001B[39m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mmatmul\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmasked_weight\u001B[49m\u001B[43m)\u001B[49m \u001B[38;5;241m+\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbias\n\u001B[1;32m     87\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m masked_linear\n",
      "\u001B[0;31mRuntimeError\u001B[0m: mat1 and mat2 shapes cannot be multiplied (50x3 and 6x12)"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "id": "539b528e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:23:35.195031Z",
     "start_time": "2024-10-20T15:23:35.127247Z"
    }
   },
   "source": [
    "\n",
    "df = pd.DataFrame(all_data, columns=var_names)\n",
    "data_dict = df.to_dict(orient='list')\n",
    "cause_var = 'X'\n",
    "effect_var = 'M'\n",
    "effect_index = var_names.index(effect_var)\n",
    "\n",
    "\n",
    "model.eval()\n",
    "intervention_nodes_vals_0 = {'X': 0}\n",
    "intervention_nodes_vals_1 = {'X': 1}\n",
    "ci = CausalInference(dag=DAG)\n",
    "D0 = ci.forward(data=all_data, model=model, intervention_nodes_vals=intervention_nodes_vals_0)\n",
    "D1 = ci.forward(data=all_data, model=model, intervention_nodes_vals=intervention_nodes_vals_1)\n",
    "\n",
    "\n",
    "\n",
    "est_ATE = (D1[:,effect_index] - D0[:,effect_index]).mean()\n",
    "print('True ATE X->M:', 0.5, 'Estimated:', est_ATE)"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matthewvowels/GitHub/Causal_Transformer/utils/inference.py:18: UserWarning: No mask has been specified. If padding has been used, the absence of a mask may lead to incorrect results.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "mat1 and mat2 shapes cannot be multiplied (10000x3 and 6x12)",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mRuntimeError\u001B[0m                              Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[6], line 12\u001B[0m\n\u001B[1;32m     10\u001B[0m intervention_nodes_vals_1 \u001B[38;5;241m=\u001B[39m {\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mX\u001B[39m\u001B[38;5;124m'\u001B[39m: \u001B[38;5;241m1\u001B[39m}\n\u001B[1;32m     11\u001B[0m ci \u001B[38;5;241m=\u001B[39m CausalInference(dag\u001B[38;5;241m=\u001B[39mDAG)\n\u001B[0;32m---> 12\u001B[0m D0 \u001B[38;5;241m=\u001B[39m \u001B[43mci\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mforward\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdata\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mall_data\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmodel\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mmodel\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mintervention_nodes_vals\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mintervention_nodes_vals_0\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     13\u001B[0m D1 \u001B[38;5;241m=\u001B[39m ci\u001B[38;5;241m.\u001B[39mforward(data\u001B[38;5;241m=\u001B[39mall_data, model\u001B[38;5;241m=\u001B[39mmodel, intervention_nodes_vals\u001B[38;5;241m=\u001B[39mintervention_nodes_vals_1)\n\u001B[1;32m     17\u001B[0m est_ATE \u001B[38;5;241m=\u001B[39m (D1[:,effect_index] \u001B[38;5;241m-\u001B[39m D0[:,effect_index])\u001B[38;5;241m.\u001B[39mmean()\n",
      "File \u001B[0;32m~/GitHub/Causal_Transformer/utils/inference.py:73\u001B[0m, in \u001B[0;36mCausalInference.forward\u001B[0;34m(self, data, model, intervention_nodes_vals)\u001B[0m\n\u001B[1;32m     71\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i, var \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(\u001B[38;5;28mlist\u001B[39m(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdag\u001B[38;5;241m.\u001B[39mnodes())):\n\u001B[1;32m     72\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m var \u001B[38;5;129;01min\u001B[39;00m vars_to_update:\n\u001B[0;32m---> 73\u001B[0m         preds \u001B[38;5;241m=\u001B[39m \u001B[43mmodel\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mforward\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_numpy\u001B[49m\u001B[43m(\u001B[49m\u001B[43mDprime\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfloat\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mto\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmodel\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdevice\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     74\u001B[0m         Dprime[:, i] \u001B[38;5;241m=\u001B[39m preds[:, i]\u001B[38;5;241m.\u001B[39mdetach()\u001B[38;5;241m.\u001B[39mcpu()\u001B[38;5;241m.\u001B[39mnumpy()\n\u001B[1;32m     75\u001B[0m         \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmask \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:  \u001B[38;5;66;03m# apply the mask to the predictions\u001B[39;00m\n",
      "File \u001B[0;32m~/GitHub/Causal_Transformer/CFCN/model.py:179\u001B[0m, in \u001B[0;36mCFCN.forward\u001B[0;34m(self, X, targets, shuffling, verbose)\u001B[0m\n\u001B[1;32m    175\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mwas_shuffled \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[1;32m    177\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i, (linear, activation, dropout) \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(\u001B[38;5;28mzip\u001B[39m(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlayers, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mactivations, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdropout_layers)):\n\u001B[0;32m--> 179\u001B[0m     X \u001B[38;5;241m=\u001B[39m \u001B[43mlinear\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    180\u001B[0m     X \u001B[38;5;241m=\u001B[39m activation(X)\n\u001B[1;32m    181\u001B[0m     X \u001B[38;5;241m=\u001B[39m dropout(X)\n",
      "File \u001B[0;32m~/anaconda3/envs/cat_env/lib/python3.9/site-packages/torch/nn/modules/module.py:1532\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m   1530\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)  \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[1;32m   1531\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m-> 1532\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/anaconda3/envs/cat_env/lib/python3.9/site-packages/torch/nn/modules/module.py:1541\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m   1536\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[1;32m   1537\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[1;32m   1538\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[1;32m   1539\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[1;32m   1540\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[0;32m-> 1541\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m   1543\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m   1544\u001B[0m     result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n",
      "File \u001B[0;32m~/GitHub/Causal_Transformer/CFCN/model.py:86\u001B[0m, in \u001B[0;36mMaskedLinear.forward\u001B[0;34m(self, input)\u001B[0m\n\u001B[1;32m     84\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m):\n\u001B[1;32m     85\u001B[0m     masked_weight \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mweight \u001B[38;5;241m*\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmask\n\u001B[0;32m---> 86\u001B[0m     masked_linear \u001B[38;5;241m=\u001B[39m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mmatmul\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmasked_weight\u001B[49m\u001B[43m)\u001B[49m \u001B[38;5;241m+\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbias\n\u001B[1;32m     87\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m masked_linear\n",
      "\u001B[0;31mRuntimeError\u001B[0m: mat1 and mat2 shapes cannot be multiplied (10000x3 and 6x12)"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "code",
   "id": "30f62f5e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:23:36.379988Z",
     "start_time": "2024-10-20T15:23:36.357503Z"
    }
   },
   "source": [
    "\n",
    "df = pd.DataFrame(all_data, columns=var_names)\n",
    "data_dict = df.to_dict(orient='list')\n",
    "cause_var = 'X'\n",
    "effect_var = 'Y'\n",
    "effect_index = var_names.index(effect_var)\n",
    "ci = CausalInference(model=model, device=device)\n",
    "\n",
    "model.eval()\n",
    "ci = CausalInference(dag=DAG)\n",
    "D0 = ci.forward(data=all_data, model=model, intervention_nodes_vals=intervention_nodes_vals_0)\n",
    "D1 = ci.forward(data=all_data, model=model, intervention_nodes_vals=intervention_nodes_vals_1)\n",
    "\n",
    "est_ATE = (D1[:,effect_index] - D0[:,effect_index]).mean()\n",
    "print('True ATE X->M:', 0.5, 'Estimated:', est_ATE)"
   ],
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "__init__() got an unexpected keyword argument 'model'",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mTypeError\u001B[0m                                 Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[7], line 6\u001B[0m\n\u001B[1;32m      4\u001B[0m effect_var \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mY\u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[1;32m      5\u001B[0m effect_index \u001B[38;5;241m=\u001B[39m var_names\u001B[38;5;241m.\u001B[39mindex(effect_var)\n\u001B[0;32m----> 6\u001B[0m ci \u001B[38;5;241m=\u001B[39m \u001B[43mCausalInference\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmodel\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mmodel\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdevice\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mdevice\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m      8\u001B[0m model\u001B[38;5;241m.\u001B[39meval()\n\u001B[1;32m      9\u001B[0m ci \u001B[38;5;241m=\u001B[39m CausalInference(dag\u001B[38;5;241m=\u001B[39mDAG)\n",
      "\u001B[0;31mTypeError\u001B[0m: __init__() got an unexpected keyword argument 'model'"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3abd2ebf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
