{
 "cells": [
  {
   "cell_type": "code",
   "id": "ecbb2c6e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:12:35.793959Z",
     "start_time": "2024-10-20T15:12:34.261407Z"
    }
   },
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "import torch\n",
    "import utils\n",
    "from model import CFCN\n",
    "from utils.inference import CausalInference\n",
    "import pandas as pd\n",
    "from datasets import reorder_dag, get_full_ordering"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "id": "c8e703ca9c18a267",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:20:12.307385Z",
     "start_time": "2024-10-20T15:20:12.303901Z"
    }
   },
   "source": [
    "shuffling = 1\n",
    "seed = 2\n",
    "standardize = 0\n",
    "sample_size = 100000\n",
    "batch_size = 50\n",
    "max_iters =  3000\n",
    "eval_interval = 1000\n",
    "eval_iters = 100\n",
    "validation_fraction = 0.1\n",
    "np.random.seed(seed=seed)\n",
    "torch.manual_seed(seed)\n",
    "device = 'cuda'\n",
    "dropout_rate = 0.0\n",
    "learning_rate = 1e-3\n",
    "\n",
    "neurons_per_layer = [3,6, 6, 3]\n"
   ],
   "outputs": [],
   "execution_count": 12
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:20:12.841552Z",
     "start_time": "2024-10-20T15:20:12.836106Z"
    }
   },
   "cell_type": "code",
   "source": [
    "\n",
    "\n",
    "def generate_data(N):\n",
    "    DAGnx = nx.DiGraph()\n",
    "    Uc = np.random.randn(N)\n",
    "    C =  Uc    \n",
    "    Ux = np.random.randn(N)\n",
    "    X =  1 * C + Ux\n",
    "    \n",
    "    Uy = np.random.randn(N)\n",
    "    Y = 0.8 * X  + 1.5 * C + Uy\n",
    "\n",
    "    Y0 = 0.8 * 0 + 1.5 * C + Uy\n",
    "    Y1 = 0.8 * 1  + 1.5 * C +  Uy\n",
    "\n",
    "    all_data_dict = {'X': X,  'C': C, 'Y': Y}\n",
    "\n",
    "    # types can be 'cat' (categorical) 'cont' (continuous) or 'bin' (binary)\n",
    "    var_types = {'X': 'cont',  'C': 'cont', 'Y': 'cont'}\n",
    "\n",
    "    DAGnx.add_edges_from([('X', 'Y'), ('C', 'X'), ('C', 'Y')])\n",
    "    DAGnx = reorder_dag(dag=DAGnx)  # topologically sorted dag\n",
    "    var_names = list(DAGnx.nodes())  # topologically ordered list of variables\n",
    "    all_data = np.stack([all_data_dict[key] for key in var_names], axis=1)\n",
    "    causal_ordering = get_full_ordering(DAGnx)\n",
    "    ordered_var_types = dict(sorted(var_types.items(), key=lambda item: causal_ordering[item[0]]))\n",
    "\n",
    "    return all_data, DAGnx, var_names, causal_ordering, ordered_var_types, Y0, Y1"
   ],
   "id": "6e03b9da",
   "outputs": [],
   "execution_count": 13
  },
  {
   "cell_type": "markdown",
   "id": "de91c75d",
   "metadata": {},
   "source": [
    "## Confounding Example"
   ]
  },
  {
   "cell_type": "code",
   "id": "d86361ab",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:20:14.790340Z",
     "start_time": "2024-10-20T15:20:14.718321Z"
    }
   },
   "source": [
    "_, _, _, _, _, Y0, Y1 = generate_data(N=1000000)\n",
    "ATE = (Y1 - Y0).mean()  # ATE based off a large sample\n"
   ],
   "outputs": [],
   "execution_count": 14
  },
  {
   "cell_type": "code",
   "id": "5d8c3af1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:20:15.011155Z",
     "start_time": "2024-10-20T15:20:15.007566Z"
    }
   },
   "source": [
    "def get_batch(train_data, val_data, split, device, batch_size):\n",
    "    data = train_data if split == 'train' else val_data\n",
    "    ix = torch.randint(0, len(data), (batch_size,))\n",
    "    x = data[ix]\n",
    "    return x.to(device)"
   ],
   "outputs": [],
   "execution_count": 15
  },
  {
   "cell_type": "code",
   "id": "a01274ff",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:20:27.831263Z",
     "start_time": "2024-10-20T15:20:15.295904Z"
    }
   },
   "source": [
    "print('True ATE:', ATE)\n",
    "\n",
    "for i in range(1):\n",
    "\n",
    "    all_data, DAG, var_names, causal_ordering, var_types, Y0, Y1 = generate_data(N=sample_size)\n",
    "    \n",
    "    input_dim = all_data.shape[1]\n",
    "    # prepend the input size to neurons_per_layer if not included in neurons_per_layer\n",
    "    # append the intput size to neurons_per_layer (output) if not included in neurons_per_layer\n",
    "    neurons_per_layer.insert(0, input_dim)\n",
    "    neurons_per_layer.append(input_dim)\n",
    "    \n",
    "    indices = np.arange(0, len(all_data))\n",
    "    np.random.shuffle(indices)\n",
    "    \n",
    "    val_inds = indices[:int(validation_fraction*len(indices))]\n",
    "    train_inds = indices[int(validation_fraction*len(indices)):]\n",
    "    train_data = all_data[train_inds]\n",
    "    val_data = all_data[val_inds]\n",
    "    \n",
    "    train_data, val_data = torch.from_numpy(train_data).float(),  torch.from_numpy(val_data).float()\n",
    "    \n",
    "    model = CFCN(neurons_per_layer=neurons_per_layer, dag=DAG, causal_ordering=causal_ordering, var_types=var_types, dropout_rate=dropout_rate).to(device)\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "    \n",
    "    all_var_losses = {}\n",
    "    for iter_ in range(0, max_iters):\n",
    "        # train and update the model\n",
    "        model.train()\n",
    "    \n",
    "        xb = get_batch(train_data=train_data, val_data=val_data, split='train', device=device, batch_size=batch_size)\n",
    "        xb_mod = torch.clone(xb.detach())\n",
    "        X, loss, loss_dict = model(X=xb, targets=xb_mod, shuffling=shuffling)\n",
    "    \n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    \n",
    "        if iter_ % eval_interval == 0:  # evaluate the loss (no gradients)\n",
    "            for key in loss_dict.keys():\n",
    "                if key not in all_var_losses.keys():\n",
    "                    all_var_losses[key] = []\n",
    "                all_var_losses[key].append(loss_dict[key])\n",
    "    \n",
    "            model.eval()\n",
    "            eval_loss = {}\n",
    "            for split in ['train', 'val']:\n",
    "                losses = torch.zeros(eval_iters)\n",
    "                for k in range(eval_iters):\n",
    "    \n",
    "                    xb = get_batch(train_data=train_data, val_data=val_data, split=split, device=device,\n",
    "                                   batch_size=batch_size)\n",
    "                    xb_mod = torch.clone(xb.detach())\n",
    "                    X, loss, loss_dict = model(X=xb, targets=xb_mod, shuffling=False)\n",
    "                    losses[k] = loss.item()\n",
    "                eval_loss[split] = losses.mean()\n",
    "            print(f\"step {iter_} of {max_iters}: train_loss {eval_loss['train']:.4f}, val loss {eval_loss['val']:.4f}\")\n",
    "    \n",
    "    \n",
    "    df = pd.DataFrame(all_data, columns=var_names)\n",
    "    data_dict = df.to_dict(orient='list')\n",
    "    cause_var = 'X'\n",
    "    effect_var = 'Y'\n",
    "    effect_index = var_names.index(effect_var)\n",
    "        \n",
    "    model.eval()\n",
    "    intervention_nodes_vals_0 = {'X': 0}\n",
    "    intervention_nodes_vals_1 = {'X': 1}\n",
    "    ci = CausalInference(dag=DAG)\n",
    "    D0 = ci.forward(data=all_data, model=model, intervention_nodes_vals=intervention_nodes_vals_0)\n",
    "    D1 = ci.forward(data=all_data, model=model, intervention_nodes_vals=intervention_nodes_vals_1)\n",
    "    \n",
    "    effect_var = 'Y'\n",
    "    effect_index = var_names.index(effect_var)\n",
    "    \n",
    "    est_ATE = (D1[:,effect_index] - D0[:,effect_index]).mean()\n",
    "    print(est_ATE)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True ATE: 0.7999999999999984\n",
      "step 0 of 3000: train_loss 8.7888, val loss 8.7386\n",
      "step 1000 of 3000: train_loss 5.1421, val loss 5.2046\n",
      "step 2000 of 3000: train_loss 4.4889, val loss 4.5251\n",
      "0.8009571178808809\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matthewvowels/GitHub/Causal_Transformer/utils/inference.py:18: UserWarning: No mask has been specified. If padding has been used, the absence of a mask may lead to incorrect results.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8a4eb16",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0df5ebcf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
