{
 "cells": [
  {
   "cell_type": "code",
   "id": "44ca21de",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:23:58.348597Z",
     "start_time": "2024-10-20T15:23:57.036194Z"
    }
   },
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "import torch\n",
    "import utils\n",
    "from model import CFCN\n",
    "from utils.inference import CausalInference\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from datasets import reorder_dag, get_full_ordering\n"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "id": "02898e95-c037-4289-97f1-2172b04046d6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:24:08.372593Z",
     "start_time": "2024-10-20T15:24:08.232626Z"
    }
   },
   "source": [
    "shuffling = 0\n",
    "seed = 1\n",
    "standardize = 0\n",
    "sample_size = 1000000\n",
    "batch_size = 50\n",
    "max_iters =  6000\n",
    "eval_interval = 100\n",
    "eval_iters = 100\n",
    "validation_fraction = 0.1\n",
    "np.random.seed(seed=seed)\n",
    "torch.manual_seed(seed)\n",
    "device = 'cuda'\n",
    "dropout_rate = 0.0\n",
    "learning_rate = 1e-3\n",
    "\n",
    "neurons_per_layer = [3,6, 6, 6, 3]\n",
    "\n",
    "def generate_data_mediation(N):\n",
    "    DAGnx = nx.DiGraph()\n",
    "    \n",
    "    Ux = np.random.randn(N)\n",
    "    X =  Ux\n",
    "\n",
    "    Um = np.random.randn(N)\n",
    "    M =  0.9 * X + Um    \n",
    "\n",
    "    Uy = np.random.randn(N)\n",
    "    Y =  0.6 * M + 1.2 * X + Uy\n",
    "\n",
    "    M0 = 0.9 * 0 + Um \n",
    "    M1 = 0.9 * 1 + Um\n",
    "\n",
    "    Y0 = 0.6 * M0 +  1.2 * 0 + Uy \n",
    "    Y1 = 0.6 * M1 +  1.2 * 1 + Uy \n",
    "\n",
    "    # X-> M = 0.9\n",
    "    # X-> Y = 1.2 \n",
    "    # M -> Y = 0.6 \n",
    "    # partial effect = 0.9*0.6 = .54\n",
    "    # total effect = .54 + 1.2 = 1.74\n",
    "\n",
    "    all_data_dict = {'X': X, 'M': M, 'Y': Y}\n",
    "\n",
    "    # types can be 'cat' (categorical) 'cont' (continuous) or 'bin' (binary)\n",
    "    var_types = {'X': 'cont', 'M': 'cont', 'Y': 'cont'}\n",
    "\n",
    "    DAGnx.add_edges_from([('X', 'M'), ('M', 'Y')])\n",
    "    DAGnx = reorder_dag(dag=DAGnx)  # topologically sorted dag\n",
    "    var_names = list(DAGnx.nodes())  # topologically ordered list of variables\n",
    "    all_data = np.stack([all_data_dict[key] for key in var_names], axis=1)\n",
    "    causal_ordering = get_full_ordering(DAGnx)\n",
    "    ordered_var_types = dict(sorted(var_types.items(), key=lambda item: causal_ordering[item[0]]))\n",
    "\n",
    "    return all_data, DAGnx, var_names, causal_ordering, ordered_var_types, Y0, Y1\n",
    "\n",
    "_, _, _, _, _, Y0, Y1 = generate_data_mediation(N=1000000)\n",
    "ATE = (Y1 - Y0).mean()  # ATE based off a large sample\n",
    "all_data, DAG, var_names, causal_ordering, var_types, Y0, Y1 = generate_data_mediation(N=sample_size)\n",
    "print(var_names, ATE)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['X', 'M', 'Y'] 1.740000000000002\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "id": "b79b49ae-b268-4914-b8bb-095ecf3bf607",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-20T15:24:09.528780Z",
     "start_time": "2024-10-20T15:24:09.525683Z"
    }
   },
   "source": [
    "\n",
    "def get_batch(train_data, val_data, split, device, batch_size):\n",
    "    data = train_data if split == 'train' else val_data\n",
    "    ix = torch.randint(0, len(data), (batch_size,))\n",
    "    x = data[ix]\n",
    "    return x.to(device)"
   ],
   "outputs": [],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "id": "bed0fbc4",
   "metadata": {
    "scrolled": true,
    "jupyter": {
     "is_executing": true
    },
    "ExecuteTime": {
     "start_time": "2024-10-20T15:24:37.207118Z"
    }
   },
   "source": [
    "for i in range(10):\n",
    "    all_data, DAG, var_names, causal_ordering, var_types, Y0, Y1 = generate_data_mediation(N=sample_size)\n",
    "    \n",
    "    input_dim = all_data.shape[1]\n",
    "    \n",
    "    indices = np.arange(0, len(all_data))\n",
    "    np.random.shuffle(indices)\n",
    "    \n",
    "    val_inds = indices[:int(validation_fraction*len(indices))]\n",
    "    train_inds = indices[int(validation_fraction*len(indices)):]\n",
    "    train_data = all_data[train_inds]\n",
    "    val_data = all_data[val_inds]\n",
    "    \n",
    "    train_data, val_data = torch.from_numpy(train_data).float(),  torch.from_numpy(val_data).float()\n",
    "    \n",
    "    model = CFCN(neurons_per_layer=neurons_per_layer, dag=DAG, causal_ordering=causal_ordering, var_types=var_types, dropout_rate=dropout_rate).to(device)\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "    \n",
    "    \n",
    "    all_var_losses = {}\n",
    "    for iter_ in range(0, max_iters):\n",
    "        # train and update the model\n",
    "        model.train()\n",
    "    \n",
    "        xb = get_batch(train_data=train_data, val_data=val_data, split='train', device=device, batch_size=batch_size)\n",
    "        xb_mod = torch.clone(xb.detach())\n",
    "        X, loss, loss_dict = model(X=xb, targets=xb_mod, shuffling=shuffling)\n",
    "    \n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    \n",
    "        if iter_ % eval_interval == 0:  # evaluate the loss (no gradients)\n",
    "            for key in loss_dict.keys():\n",
    "                if key not in all_var_losses.keys():\n",
    "                    all_var_losses[key] = []\n",
    "                all_var_losses[key].append(loss_dict[key])\n",
    "    \n",
    "            model.eval()\n",
    "            eval_loss = {}\n",
    "            for split in ['train', 'val']:\n",
    "                losses = torch.zeros(eval_iters)\n",
    "                for k in range(eval_iters):\n",
    "    \n",
    "                    xb = get_batch(train_data=train_data, val_data=val_data, split=split, device=device,\n",
    "                                   batch_size=batch_size)\n",
    "                    xb_mod = torch.clone(xb.detach())\n",
    "                    X, loss, loss_dict = model(X=xb, targets=xb_mod, shuffling=False)\n",
    "                    losses[k] = loss.item()\n",
    "                eval_loss[split] = losses.mean()\n",
    "            print(f\"step {iter_} of {max_iters}: train_loss {eval_loss['train']:.4f}, val loss {eval_loss['val']:.4f}\")\n",
    "    \n",
    "    \n",
    "    df = pd.DataFrame(all_data, columns=var_names)\n",
    "    data_dict = df.to_dict(orient='list')\n",
    "    cause_var = 'X'\n",
    "    effect_var = 'Y'\n",
    "    effect_index = var_names.index(effect_var)\n",
    "    ci = CausalInference(dag=DAG)\n",
    "    \n",
    "    model.eval()\n",
    "    intervention_nodes_vals_0 = {'X': 0}\n",
    "    intervention_nodes_vals_1 = {'X': 1}\n",
    "    D0 = ci.forward(data=all_data, model=model, intervention_nodes_vals=intervention_nodes_vals_0)\n",
    "    D1 = ci.forward(data=all_data, model=model, intervention_nodes_vals=intervention_nodes_vals_1)\n",
    "    \n",
    "    \n",
    "    est_ATE = (D1[:,effect_index] - D0[:,effect_index]).mean()\n",
    "    print(est_ATE)"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matthewvowels/GitHub/Causal_Transformer/utils/inference.py:18: UserWarning: No mask has been specified. If padding has been used, the absence of a mask may lead to incorrect results.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8558830618858337\n"
     ]
    }
   ],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "539b528e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
