{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "LmbWsxo_dM8O"
      },
      "outputs": [],
      "source": [
        "from copy import deepcopy\n",
        "\n",
        "import numpy as np\n",
        "np.random.seed(seed=0)\n",
        "import pandas as pd\n",
        "import networkx as nx\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.nn.functional import mse_loss\n",
        "import torch.autograd as autograd\n",
        "from torch.autograd import Variable\n",
        "\n",
        "import torch.cuda as cutorch"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def centering(K):\n",
        "  \"\"\"Double-center a square matrix: returns Q @ K @ Q with Q = I - 1 1^T / n.\n",
        "\n",
        "  Q is built with K's dtype and device, so float64 and GPU inputs work.\n",
        "  \"\"\"\n",
        "  n = K.shape[0]\n",
        "  I = torch.eye(n, dtype=K.dtype, device=K.device)\n",
        "  unit = torch.ones((n, n), dtype=K.dtype, device=K.device)\n",
        "  Q = I - unit / n\n",
        "  return Q @ K @ Q\n",
        "\n",
        "def rbf(X, sigma=None):\n",
        "  \"\"\"Gaussian (RBF) kernel matrix over the rows of X.\n",
        "\n",
        "  K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma^2)). If sigma is None it is chosen by\n",
        "  the median heuristic: sqrt of the median non-zero squared pairwise distance.\n",
        "  \"\"\"\n",
        "  GX = X @ X.T\n",
        "  # Squared pairwise distances ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.\n",
        "  KX = torch.diag(GX) - GX + (torch.diag(GX) - GX).T\n",
        "  if sigma is None:\n",
        "    nonzero = KX[KX != 0]\n",
        "    # All points identical: every entry becomes exp(0) = 1 whatever sigma is.\n",
        "    sigma = torch.sqrt(torch.median(nonzero)) if nonzero.numel() > 0 else 1.\n",
        "  # Scale and exponentiate on both paths; an explicit sigma must yield a kernel too.\n",
        "  return torch.exp(KX * (-0.5 / sigma / sigma))\n",
        "\n",
        "def HSIC(X, Y):\n",
        "  \"\"\"HSIC-style dependence statistic between the rows of X and of Y.\n",
        "\n",
        "  Multiplies the centered RBF kernels (median bandwidth) element-wise, drops the\n",
        "  diagonal, sums, and scales by 1 / sqrt(n_X * n_Y). X and Y must have the same\n",
        "  number of rows.\n",
        "  \"\"\"\n",
        "  centered = centering(rbf(X)) * centering(rbf(Y))\n",
        "  off_diag = 1. - torch.eye(centered.shape[0], dtype=centered.dtype, device=centered.device)\n",
        "  return (1. / np.sqrt(X.shape[0] * Y.shape[0])) * torch.sum(off_diag * centered)"
      ],
      "metadata": {
        "id": "2sG38KsKIh_w"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zpuFnAGbFxQy"
      },
      "source": [
        "## Notes\n",
        "\n",
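        "* Phi independence: the TCRI penalty pushes the causal features $\\Phi$ (`featurizer`) to be conditionally independent of the second feature set $\\Psi$ (`ds_featurizer`) given $Y$, i.e. $\\Phi \\perp \\Psi \\mid Y$."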
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_2wX7W6llcbu"
      },
      "source": [
        "# Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Vi_oLUxzkcdL"
      },
      "source": [
        "## Data Helpers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "6oJahGZBxJCo"
      },
      "outputs": [],
      "source": [
        "# From https://github.com/xunzheng/notears/blob/master/notears/utils.py\n",
        "# (modified: returns torch tensors split into covariates and the estimand).\n",
        "from scipy.special import expit as sigmoid\n",
        "\n",
        "def simulate_linear_sem(W, n, sem_type, noise_scale, y_noise_scale=0.25, estimand=1):\n",
        "    \"\"\"\n",
        "    Simulate samples from linear SEM with specified type of noise.\n",
        "    For uniform, noise z ~ uniform(-a, a), where a = noise_scale.\n",
        "    Parameters\n",
        "    ----------\n",
        "    W: np.ndarray\n",
        "        [d, d] weighted adj matrix of DAG; W[i, j] != 0 is an edge i -> j.\n",
        "    n: int\n",
        "        Number of samples, n=inf mimics population risk.\n",
        "    sem_type: str\n",
        "        gauss, exp, gumbel, uniform, logistic.\n",
        "    noise_scale: float, sequence of length d, or None\n",
        "        Scale parameter of noise distribution in linear SEM (None means 1 for every node).\n",
        "    y_noise_scale: float\n",
        "        Noise scale used for the estimand node, overriding its noise_scale entry.\n",
        "    estimand: int\n",
        "        Index of the node returned as the target y.\n",
        "\n",
        "    Return\n",
        "    ------\n",
        "    (X, y): tuple of torch.FloatTensor\n",
        "        X is [n, d-1] (every node except the estimand, in index order),\n",
        "        y is [n] (the estimand node).\n",
        "        If n=inf (sem_type 'gauss' only) a [d, d] np.ndarray\n",
        "        sqrt(d) * diag(scale) @ inv(I - W) is returned instead.\n",
        "    \"\"\"\n",
        "    def _simulate_single_equation(X, w, scale):\n",
        "        \"\"\"X: [n, num of parents], w: [num of parents], x: [n]\"\"\"\n",
        "        if sem_type == 'gauss':\n",
        "            z = np.random.normal(scale=scale, size=n)\n",
        "            x = X @ w + z\n",
        "        elif sem_type == 'exp':\n",
        "            z = np.random.exponential(scale=scale, size=n)\n",
        "            x = X @ w + z\n",
        "        elif sem_type == 'gumbel':\n",
        "            z = np.random.gumbel(scale=scale, size=n)\n",
        "            x = X @ w + z\n",
        "        elif sem_type == 'uniform':\n",
        "            z = np.random.uniform(low=-scale, high=scale, size=n)\n",
        "            x = X @ w + z\n",
        "        elif sem_type == 'logistic':\n",
        "            # Binary node; the noise scale is not used.\n",
        "            x = np.random.binomial(1, sigmoid(X @ w)) * 1.0\n",
        "        else:\n",
        "            raise ValueError('Unknown sem type. In a linear model, \\\n",
        "                              the options are as follows: gauss, exp, \\\n",
        "                              gumbel, uniform, logistic.')\n",
        "        return x\n",
        "\n",
        "    d = W.shape[0]\n",
        "    if noise_scale is None:\n",
        "        scale_vec = np.ones(d)\n",
        "    elif np.isscalar(noise_scale):\n",
        "        scale_vec = noise_scale * np.ones(d)\n",
        "    else:\n",
        "        if len(noise_scale) != d:\n",
        "            raise ValueError('noise scale must be a scalar or has length d')\n",
        "        scale_vec = noise_scale\n",
        "    # from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is the replacement.\n",
        "    G_nx = nx.from_numpy_array(W, create_using=nx.DiGraph)\n",
        "    if not nx.is_directed_acyclic_graph(G_nx):\n",
        "        raise ValueError('W must be a DAG')\n",
        "    if np.isinf(n):  # population risk for linear gauss SEM\n",
        "        if sem_type == 'gauss':\n",
        "            # make 1/d X'X = true cov\n",
        "            X = np.sqrt(d) * np.diag(scale_vec) @ np.linalg.inv(np.eye(d) - W)\n",
        "            return X\n",
        "        else:\n",
        "            raise ValueError('population risk not available')\n",
        "    # empirical risk\n",
        "    ordered_vertices = list(nx.topological_sort(G_nx))\n",
        "    assert len(ordered_vertices) == d\n",
        "    X = np.zeros([n, d])\n",
        "    # Topological order guarantees every parent column is filled before its children.\n",
        "    for j in ordered_vertices:\n",
        "        parents = list(G_nx.predecessors(j))\n",
        "        if j == estimand:\n",
        "            X[:, j] = _simulate_single_equation(X[:, parents], W[parents, j], y_noise_scale)\n",
        "        else:\n",
        "            X[:, j] = _simulate_single_equation(X[:, parents], W[parents, j], scale_vec[j])\n",
        "    return (torch.tensor(np.delete(X,estimand, 1)).float(), torch.tensor(X[:,estimand]).float())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eWqRKVxAlfuk"
      },
      "source": [
        "## Graph Structure"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 319
        },
        "id": "Y5ZAOUuRdRHH",
        "outputId": "6da25b5b-fda9-45db-be0f-7a4851b63c96"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAb4AAAEuCAYAAADx63eqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deVxU9f4/8BcwwICIuCCo4MoqqTEgorigWZhkVhoDitvXJa/a4pKCqKGIyNW6kUbmcskkFbUscc0FzA0FwZ1NzRRNA5SQlGVgfn909ZcLiLJ8Zua8nn8OZw6v7vXBa96fOed89NRqtRpEREQSoS86ABERUX1i8RERkaSw+IiISFJYfEREJCksPiIikhQWHxERSQqLj4iIJIXFR0REksLiIyIiSWHxERGRpLD4iIhIUlh8REQkKSw+IiKSFBYfERFJCouPiIgkhcVHRESSwuIjIiJJYfEREZGksPiIiEhSWHxERCQpLD4iIpIUFh8REUmKTHQAIiLSPnlFJdhyMgcZNwtRWKyCuVwGJ2tzvOtmg6ZmxqLjVUlPrVarRYcgIiLtcPpaAb5MvIiDWbkAgBJVxcOfyWX6UAPwdrTEpD526GJrIShl1Vh8RERULbFJVxC+MwPFqnJU1Rx6eoBcZoCQgU4I9Gxbb/mqi0udRET0TH+XXjrul1U881i1GrhfVo7wnekAoHHlx4mPiIiqdPpaAfxXJeF+Wfkjr5ffv4v8nVEovpIGfRNzNO4zCg1cvB85xsTQAHETPNHZRnOWPXlVJxERVenLxIsoVpU/8frtn7+CnoEhbN6PRbNBM5D/czRKc3975JhiVTmiEy/WV9RqYfEREVGl8opKcDAr94nv9CpKi3Ev8ygsegdC38gEclsXmNp1w1/nEx45Tq0GEjJzkV9UUo+pq8biIyKiSm05mfPU11W3r0NP3wCGTVo9fM2weTuUPTbxAYAegC2pTz+PCCw+IiKqVMbNwkduWXigouw+9IxNHnlN39gUFaX3nzi2WFWBjN/v1lnG58XiIyKiShUWq576ur6hCdQlj5acuuQe9I1Mnnp8YXFZrWd7USw+IiKqlLn86Xe9yZq0grqiHGW3rz98rfSPX2Fo2aaS8xjWSb4XweIjIqJKOVmbw1j2ZFXoG8lh6tgdBYe+Q0VpMYpzLuDexeNo4NL3iWPlMn04tWhYH3GrhcVHRESVGupmU+nPmrw2CWpVKXKWDUfetiVo+tokGD1l4lMDGKqo/Dz1jU9uISKiSjUzM0YfB0vsTb/1xC0NBiYN0XzInCrfr6cH9HW01KgHV3PiIyKiKk32toNcZvBC75XLDDDJ266WE9UMi4+IiKrUxdYCIQOdYGL4fJVhYqiPkIFOGvW4MoDFR0RE1RDo2RbTX2kPqEqg94xj9fT+fkZnyEBnjXtANcDv+IiIqJpuHIxDl7w8tHhlJBIyc6GHv29Of+DBfnx9HS0xydtO4ya9B7g7AxERPVNubi6cnZ1x4sQJtG/fHvlFJdiSmoOM3++isLgM5nJDOLVoiKEK7sBOREQ6YOrUqVCpVFi2bJnoKDXG4iMioir99ttvUCgUuHDhAqysrETHqTEWHxERVWn06NFo3bo1FixYIDpKreDFLUREVKlz585h165dyMrKEh2l1nDiIyKiSr311lvo3bs3pk2bJjpKreHER0RET3X06FGkpaVh48aNoqPUKt7ATkRET1Cr1QgKCkJoaCjkcrnoOLWKxUdERE/YvXs38vLyMGLECNFRah2Lj4iIHlFRUYHg4GCEh4dDJtO9b8RYfERE9IiNGzdCLpfjrbfeEh2lTvCqTiIieqi0tBTOzs5Ys2YNvL29RcepE5z4iIjooVWrVsHe3l5nSw/gxEdERP9TVFQEe3t77Ny5E66urqLj1BlOfEREBACIioqCt7e3TpcewImPiIgA5Ofnw9HREUlJSbCzsxMdp06x+I
iICDNmzMC9e/cQHR0tOkqdY/EREUnctWvX8PLLL+PcuXNo0aKF6Dh1jsVHRCRxY8eOhZWVFRYtWiQ6Sr3QvVvyiYio2tLT0xEfH69T2w49C6/qJCKSsDlz5uDjjz+GhYWF6Cj1hkudREQSdfz4cQwdOhRZWVkwMTERHafecOIjIpKgB9sOffLJJ5IqPYDFR0QkSXv37sWNGzcwevRo0VHqHYuPiEhiKioqEBQUpLPbDj0Li4+ISGI2b94MAwMDDBkyRHQUIXhxCxGRhJSVlaFjx45YsWIFXnnlFdFxhODER0QkIWvWrEHbtm0lW3oAJz4iIsm4d+8e7O3t8dNPP8Hd3V10HGE48RERScQXX3wBLy8vSZcewImPiEgSbt++DUdHRxw5cgQODg6i4wjFiY+ISAIiIyPxzjvvSL70AE58REQ67/r16+jcuTPOnDmDVq1aiY4jHIuPiEjHTZgwAY0bN0ZkZKToKBqBxUdEpMMyMzPRs2dPZGVloXHjxqLjaAR+x0dEpMPmzp2L6dOns/T+gRMfEZGOSklJweDBg5GdnQ1TU1PRcTQGJz4iIh0VHByMefPmsfQew+IjItJB+/btw5UrV/B///d/oqNoHBYfEZGOUavVCA4OxsKFC2FoaCg6jsZh8RER6Zjvv/8eFRUVePfdd0VH0Ui8uIWISIeoVCq4uLhg2bJleO2110TH0Uic+IiIdEhMTAxatWqFV199VXQUjcWJj4hIR9y/fx/29vb44Ycf4OHhITqOxuLER0SkI5YvX45u3bqx9J6BEx8RkQ4oKCiAvb09fvnlFzg7O4uOo9E48RER6YB///vfePPNN1l61cCJj4hIy/3+++946aWXcOrUKdja2oqOo/FYfEREWu5f//oXGjRogKVLl4qOohW0pvjyikqw5WQOMm4WorBYBXO5DE7W5njXzQZNzYxFxyMiEiI7Oxvdu3dHZmYmmjZtKjqOVtD44jt9rQBfJl7EwaxcAECJquLhz+QyfagBeDtaYlIfO3SxtRCUkohIjICAALz00ksICQkRHUVraHTxxSZdQfjODBSrylFVSj09QC4zQMhAJwR6tq23fEREIqWmpsLX1xfZ2dkwMzMTHUdraGzx/V166bhfVvHsg//HxFAfIQOdWX5EJAkDBgzAoEGDMHnyZNFRtIpGFt/pawXwX5WE+2XlD19Tq8qQ/3M0iq+cQkVxEWQW1mjcZxRMOrg/8l4TQwPETfBEZxsuexKR7kpISMC4ceOQnp4OIyMj0XG0ikbex/dl4kUUq8ofeU1dUQ5Zw2awHrYYtlPjYNF7BHJ/ioSq4NYjxxWryhGdeLE+4xIR1Su1Wo2goCCEhYWx9F6AxhVfXlEJDmblPvGdnr6RHBa9hkNmYQU9PX2Y2nlA1sgKJTcfLTm1GkjIzEV+UUk9piYiqj8//vgjSkpK4O/vLzqKVtK44ttyMqdax5X/dQdlt6/DyLL1Ez/TA7AltXrnISLSJiqVCiEhIYiIiIC+vsb9CdcKGve/WsbNwkduWXgadbkKeduWwqzTKzBs+uRTCopVFcj4/W5dRSQiEubbb7+FpaUlBgwYIDqK1pKJDvC4wmJVlT9XqyuQt/1TwECGJq9OrOI8ZbUdjYhIqOLiYoSGhiIuLg56enqi42gtjZv4zOWVd7FarUb+zi9Q/lcBLN+eDT2Dyo81lxvWRTwiImGio6OhUCjQvXt30VG0msZNfE7W5jCW3XzqcuftPV+iLP8arPwXQt+w8seUyWX6cGrRsC5jEhHVqz///BOLFy9GQkKC6ChaT+Pu48srKoFX5IEnik/15x+4/tX/AQaG0NM3ePh6kwGTYebS95FjjWX6ODqrH5/hSUQ6Y+7cubh27Rq++eYb0VG0nsYVHwBMWJeCvem3qnxMWWX09ACfjlZYEej+7IOJiLTArVu30LFjR6SmpqJNmzai42g9jfuODwAme9tBLjN49oFPoVaVouDoJsTFxeHChQtQqaq+WIaISNOFhY
Vh5MiRLL1aopETH/Diz+pUpWzGxV0xkMvlkMlkKC4uxrZt2/D666/XYVoiorpx+fJldO3aFRkZGbC0tBQdRydobPEBL7Y7Q2fTQnh4eKCk5O8nt7Ro0QKZmZlo2JAXuxCR9gkMDISDgwPmzZsnOorO0OjiA4AzOQWITryIhMxc6OHvm9MfeLAfX19HS0zytnv4YGpfX1/s2rUL+vr6aNOmDeLj49GxY0cx/wFERC/o9OnT8PHxQXZ2Nj+81yKNL74H8otKsCU1Bxm/30VhcRnM5YZwatEQQxVP7sB+/vx5dOrUCZ9//jnMzMwwa9YsREZGYsyYMbzpk4i0hq+vL3x8fPDBBx+IjqJTtKb4nte5c+fg4uICPT09nD9/Hn5+fnB1dcVXX33FT05EpPF++eUXjBo1ChkZGTA25q1ZtUkjr+qsDS+99NLD6c7FxQXJycmQy+Vwc3PDqVOnBKcjIqrcg22HFixYwNKrAzpbfI8zNTXF6tWr8cknn+DVV19FdHQ0dHTYJSItFx8fj7t372LYsGGio+gknV3qrEpWVhaUSiU6dOiA1atXw8KCu7UTkWYoLy9Hly5dsHjxYrzxxhui4+gkyUx8/+Tg4IBjx47B2toaCoUCJ06cEB2JiAgAEBsbCwsLC/j6+oqOorMkOfH90w8//ICJEyciKCgIU6dO5VWfRCRMSUkJHB0dERsbi549e4qOo7MkX3wA8Ouvv8Lf3x/NmzfHN998g6ZNm4qOREQSFBUVhX379iE+Pl50FJ0myaXOx7Vr1w6HDh2Ck5MTXF1dcfjwYdGRiEhi7t69i4iICISHh4uOovM48T1mx44dGDt2LN5//30EBQXBwODFHpZNRPQ8QkNDcenSJaxbt050FJ3H4nuKnJwcDBs2DMbGxli3bh2sra1FRyIiHfbHH3/A2dkZKSkpaNeuneg4Oo9LnU9hY2ODAwcOwNPTEwqFAvv27RMdiYh02KJFizB8+HCWXj3hxPcM+/btw8iRIzF27Fh88sknkMlkoiMRkQ65cuUK3NzccOHCBVhZWYmOIwksvmq4efMmRowYgZKSEqxfvx42NjaiIxGRjhg1ahTatm2L+fPni44iGVzqrAZra2vs2bMHAwYMgLu7O3bs2CE6EhHpgHPnzmH37t2YPn266CiSwonvOR06dAjDhw+HUqlEeHg4jIyMREciIi315ptvom/fvpg6daroKJLC4nsBeXl5GD16NPLy8rBx40a0bdtWdCQi0jJHjhzBsGHDkJmZCblcLjqOpHCp8wU0a9YM8fHx8PPzg4eHB77//nvRkYhIizzYdmj+/PksPQE48dXQiRMn4O/vj4EDB2Lp0qX8R0xEz7Rjxw7MnDkTZ86c4UMyBODEV0MeHh5ITU3FzZs30b17d2RlZYmOREQarKKiAsHBwQgPD2fpCcLiqwUWFhbYvHkzxo8fDy8vL6xfv150JCLSUBs2bECDBg0wePBg0VEki0udtezUqVNQKpXo1asXvvjiC5iamoqOREQaorS0FE5OToiJiUGfPn1Ex5EsTny17OWXX0ZKSgqKi4vRtWtXnD9/XnQkItIQK1euhKOjI0tPME58dUStVuObb77BzJkzERkZiTFjxnCTWyIJKyoqgr29PXbt2oWXX35ZdBxJY/HVsQsXLsDPzw9dunTBihUr0LBhQ9GRiEiAsLAwpKen8xoADcClzjrWsWNHnDhxAqampnBzc0NaWproSERUz/Ly8hAVFYUFCxaIjkJg8dULU1NTrFq1CqGhoXjttdfw5ZdfgoM2kXRERERAqVTCzs5OdBQClzrrXXZ2NpRKJdq1a4c1a9bAwsJCdCQiqkNXr16Fq6srzp07hxYtWoiOQ+DEV+/s7e1x7NgxtGrVCq6urjh+/LjoSERUh0JDQzFx4kSWngbhxCfQ1q1b8d5772HmzJmYNm0a9PX5OYRIl1y4cAHe3t7Izs5Go0aNRMeh/2HxCXblyhX4+/ujadOmWLt2LZo1ayY6Eh
HVkrfffhteXl6YMWOG6Cj0DxwxBGvbti0OHToEFxcXuLq64pdffhEdiYhqQVJSElJSUjB58mTRUegxnPg0yK5duzBmzBhMmTIFwcHBfIAtkZZSq9Xo27cvRowYgbFjx4qOQ49h8WmY69evY9iwYTA0NERsbCysra1FRyKi57R792589NFHOHfuHGQymeg49BgudWqYVq1aYf/+/fDy8oJCocDevXtFRyKi5/DPbYdYepqJxaeBZDIZ5s+fj3Xr1mH06NEICQmBSqUSHYuIqmHTpk0wNDTEO++8IzoKVYJLnRru1q1bGDFiBO7fv4/169fD1tZWdCQiqkRpaSk6duyIlStXol+/fqLjUCU48Wk4Kysr7N69G76+vnB3d8f27dtFRyKiSqxZswbt27dn6Wk4Tnxa5MiRIwgICMC7776LiIgIGBkZiY5ERP/z119/wd7eHvHx8XBzcxMdh6rAiU+LeHl5IS0tDdnZ2ejZsyd+/fVX0ZGI6H+ioqLQq1cvlp4W4MSnhdRqNaKiorBo0SJER0dj6NChoiMRSVp+fj4cHR1x9OhRODg4iI5Dz8Di02LJyclQKpV4/fXX8emnn0Iul4uORCRJM2fORGFhIVasWCE6ClUDi0/L/fnnnxg3bhyys7OxadMmftokqmc5OTno0qULzp49i5YtW4qOQ9XA7/i0XKNGjbBp0yZMnDgRXl5eiI2NFR2JSFLmz5+P8ePHs/S0CCc+HXL69GkolUr06NEDy5YtQ4MGDURHItJpGRkZ6NWrF7KystC4cWPRcaiaOPHpkC5duiAlJQUqlQpdu3bFuXPnREci0mlz5szBjBkzWHpahsWnY8zMzPDtt99i1qxZ6Nu3L1avXg0O9US1Lzk5GceOHcP7778vOgo9Jy516rD09HT4+fmhU6dOWLFiBczNzUVHItIZ/fv3h5+fHyZMmCA6Cj0nTnw6zNnZGSdOnEDDhg3h5uaG1NRU0ZGIdMLevXtx9epVjBkzRnQUegEsPh1nYmKCr7/+GmFhYfDx8cHy5cu59ElUAw+2HVq4cCEMDQ1Fx6EXwOKTCH9/fxw7dgwxMTEYMmQI7ty5IzoSkVb6/vvvAYBPTNJiLD4JsbOzw9GjR2FrawtXV1ckJSWJjkSkVcrKyhASEoKIiAjo6/PPp7bi/3MSY2xsjKioKHz++ecYPHgwlixZgoqKCtGxiLRCTEwMbG1t0b9/f9FRqAZ4VaeE/fbbbwgICICFhQXWrl0LS0tL0ZGINNa9e/dgb2+PrVu3wsPDQ3QcqgFOfBLWpk0bHDx4EJ07d4ZCocDBgwdFRyLSWMuXL0f37t1ZejqAEx8BAHbv3o0xY8Zg0qRJmD17NgwMDERHItIYd+7cgYODAw4dOgQnJyfRcaiGWHz00I0bNzBs2DAYGBggNjYWLVq0EB2JSCMEBwcjNzcXq1evFh2FagGXOumhli1bYv/+/Q93kd67d6/oSETC3bhxAytXrkRoaKjoKFRLOPHRUyUkJGDEiBEYNWoU5s+fD5lMJjoSkRATJ05Ew4YNsWTJEtFRqJaw+KhSf/zxB0aOHImioiJs2LABtra2oiMR1avs7Gx0794dmZmZaNq0qeg4VEu41EmVat68OXbu3IlBgwbB3d0d8fHxoiMR1au5c+di2rRpLD0dw4mPquXo0aMICAjAkCFDsHjxYhgZGYmORFSnTp48iUGDBiE7O5ubOusYTnxULT169EBaWhouXboELy8vXL58WXQkojo1e/ZszJkzh6Wng1h8VG1NmjTBjz/+iMDAQHh6emLz5s2iIxHViQMHDuDSpUsYP3686ChUB7jUSS8kJSUFSqUSPj4++OyzzyCXy0VHIqoVarUanp6e+OijjxAQECA6DtUBTnz0Qtzd3ZGamor8/Hx069YNmZmZoiMR1YqtW7eitLQUSqVSdBSqIyw+emGNGjXCxo0bMWnSJPTs2RPr1q0THYmoRlQqFbcdkgAudVKtOHPmDJRKJTw9PbF8+X
JeEEBaac2aNVi3bh0SEhKgp6cnOg7VEX6koVrRuXNnJCcnQ61Ww93dHWfPnhUdiei53L9/H6GhoVi8eDFLT8ex+KjWmJmZ4ZtvvkFQUBD69euHVatWgQsKpC2io6Ph7u4OT09P0VGojnGpk+pERkYG/Pz84OLigq+//hrm5uaiIxFV6s8//4S9vT0SEhLg4uIiOg7VMU58VCecnJxw/PhxNGrUCAqFAidPnhQdiahSS5Ysga+vL0tPIjjxUZ2Li4vDlClTMHfuXLz//vv8/oQ0ys2bN+Hi4oK0tDS0bt1adByqByw+qheXLl2CUqmEra0t1qxZgyZNmoiORAQAmDJlCoyMjPDZZ5+JjkL1hEudVC86dOiAI0eOoG3btlAoFDh27JjoSES4dOkSNm7ciNmzZ4uOQvWIEx/Vu23btmH8+PGYPn06ZsyYwRuFSZjhw4fDyckJc+fOFR2F6hGLj4S4evUqAgICYG5ujrVr16J58+aiI5HEnDp1CgMGDMDFixdhZmYmOg7VI37UJiFat26NxMREvPzyy1AoFEhMTBQdiSRm9uzZCAkJYelJECc+Em7Pnj0YPXo0Jk6ciDlz5sDAwEB0JNJxBw8exOjRo5GRkQFjY2PRcaiesfhII9y4cQOBgYEAgNjYWLRs2VJwItJVarUaXl5emDRp0sN/cyQtXOokjdCyZUvs3bsX3t7ecHNzw549e0RHIh21bds2FBUVca89CePERxonMTERgYGBGDFiBBYsWABDQ0PRkUhHlJeXo3PnzoiMjMQbb7whOg4JwomPNI63tzfS0tJw6tQpeHt74+rVq6IjkY6IjY1FkyZN4OvrKzoKCcTiI41kaWmJHTt2YPDgwejatSu2bdsmOhJpueLiYsybNw8RERF8bJ7EcamTNN6xY8cQEBCAt956C5GRkbwKj17I559/jv379yM+Pl50FBKMxUda4fbt2xg7diyuXbuGuLg4dOjQQXQk0iKFhYWwt7fHvn370KlTJ9FxSDAudZJWaNKkCX744QeMGjUKnp6e2LRpk+hIpEU+/fRT+Pj4sPQIACc+0kInT56EUqlE//798Z///AcmJiaiI5EG++OPP+Ds7IyUlBS0a9dOdBzSAJz4SOu4ubkhNTUVBQUF6NatGzIyMkRHIg0WHh6OwMBAlh49xImPtJZarcbq1asxe/ZsfPrppxg5cqToSKRhfv31V7i7uyM9PZ0PQqeHWHyk9c6ePQulUgkPDw8sX76cDx2mh0aOHIl27dph/vz5oqOQBuFSJ2m9Tp06ITk5Gfr6+ujatSvOnDkjOhJpgLNnz2LPnj2YPn266CikYVh8pBMaNGiA//73v5g9ezZeeeUVfP311+BihrTNnj0bwcHBMDc3Fx2FNAyXOknnZGZmws/PD05OTli5ciUaNWokOhLVs8OHD2P48OHIzMyEXC4XHYc0DCc+0jmOjo5ISkpC06ZNoVAokJKSIjoS1SO1Wo2goCAsWLCApUdPxeIjnWRiYoLo6GgsXrwYAwcORFRUFJc+JWLnzp24c+cO99qjSnGpk3Te5cuXoVQq0bJlS8TExKBJkyaiI1EdKS8vh6urK8LCwjB48GDRcUhDceIjnde+fXscOXIEHTp0gKurK44ePSo6EtWRDRs2wMzMDG+++aboKKTBOPGRpMTHx2PcuHGYNm0aPv74Y+jr87OfrigtLYWjoyPWrl2L3r17i45DGozFR5Jz7do1BAQEwMzMDN9++y2f6KEjli1bhl27dmHnzp2io5CG48ddkhxbW1skJibCzc0NCoUCCQkJoiNRDd29exfh4eGIiIgQHYW0ACc+krSff/4Zo0ePxoQJEzB37lwYGBiIjkQvICwsDBkZGfjuu+9ERyEtwOIjyfv9998RGBiIiooKfPfdd2jZsqXoSPQccnNz4ezsjOPHj3ODYqoWLnWS5LVo0QI///wz+vXrBzc3N+zevVt0JHoOERER8Pf3Z+lRtXHiI/qHgwcPYvjw4QgMDE
RYWBgMDQ1FR6IqXL16Fa6urjh//jysra1FxyEtweIjekxubi5GjRqFgoICbNiwAW3atBEdiSoxZswYtGrVCgsXLhQdhbQIlzqJHmNpaYnt27fjnXfegYeHB3766SfRkegpzp8/jx07duDjjz8WHYW0DCc+oiokJSUhICAAgwcPRmRkJIyNjUVHov95++230bNnT+63R8+NEx9RFTw9PZGamoqrV6+iR48euHjxouhIBODYsWNISUnBpEmTREchLcTiI3qGxo0b4/vvv8eYMWPQvXt3bNy4UXQkSXuw7VBoaChMTExExyEtxKVOoueQmpoKpVKJfv364fPPP+cfXgF2796NqVOn4uzZs5DJZKLjkBbixEf0HBQKBU6ePInCwkJ4eHggPT1ddCRJqaioQFBQEMLDw1l69MJYfETPydzcHOvXr8eHH36I3r17Y+3ataIjSUZcXByMjY3x9ttvi45CWoxLnUQ1cO7cOSiVSri5uSE6OhpmZmaiI+ms0tJSdOzYEatWrULfvn1FxyEtxomPqAZeeuklnDhxAoaGhnB3d8fp06dFR9JZa9asQYcOHVh6VGOc+IhqSWxsLKZOnYqwsDC899570NPTEx1JZ/z111+wt7fH9u3boVAoRMchLcfiI6pFWVlZ8PPzg4ODA1atWoVGjRqJjqQTFi1ahDNnzvBWEqoVXOokqkUODg5ISkqCpaUlFAoFkpOTRUfSevn5+fjss88QFhYmOgrpCE58RHVky5YtmDRpEoKDg/HRRx9x6fMFffzxx7h79y5WrFghOgrpCBYfUR26fPky/P39YW1tjZiYGDRt2lR0JK2Sk5ODLl264OzZs9wgmGoNlzqJ6lD79u1x+PBh2Nvbw9XVFUeOHBEdSavMnz8fEyZMYOlRreLER1RPtm/fjnHjxuHDDz/ErFmzoK/Pz51VycjIQK9evZCVlYXGjRuLjkM6hMVHVI9ycnIQEBAAU1NTfPvtt7CyshIdSWMNHToUHh4emDlzpugopGP4kZOoHtnY2CAhIQFdu3aFQqHAgQMHREfSSMnJyUhKSsKUKVNERyEdxImPSJC9e/di1KhRGD9+PObNmwcDAwPRkTSCWq1G//79oVQqMWHCBNFxSAex+IgEunnzJgIDA6FSqSO9/QEAAAfzSURBVPDdd9+hVatWoiMJt3fvXkyZMgXnz5/nDgxUJ7jUSSSQtbU19uzZg1dffRVubm7YuXOn6EhCVVRUIDg4GAsXLmTpUZ1h8REJZmBggJCQEGzatAnvvfceZs6cibKyMtGxhNiyZQsAYMiQIYKTkC7jUieRBsnLy8OoUaNw+/ZtbNiwAW3bthUdqd6UlZXBxcUF0dHR6N+/v+g4pMM48RFpkGbNmiE+Pv7hpfxbt24VHanexMTEoHXr1iw9qnOc+Ig01PHjx+Hv749BgwZhyZIlMDY2Fh2pzty7dw/29vb48ccf0bVrV9FxSMdx4iPSUN26dUNaWhquX7+OHj164OLFi6Ij1Zlly5ahR48eLD2qF5z4iDScWq3GV199hU8++QTLli2Dv7+/6Ei16s6dO3BwcMDhw4fh6OgoOg5JAIuPSEukpaVBqVSiT58+iIqKgqmpqehItSI4OBh5eXlYtWqV6CgkESw+Ii1y9+5dTJw4EadPn8amTZvQsWNH0ZFq5MaNG+jUqRNOnz4NGxsb0XFIIvgdH5EWadiwIWJjYzFt2jT06dMHMTEx0ObPrgsWLMDYsWNZelSvOPERaanz58/Dz88PCoUC0dHRaNiwoehIzyU7Oxvdu3dHVlYWmjRpIjoOSQgnPiIt5eLiguTkZBgbG8Pd3R2nTp0SHem5zJkzB9OmTWPpUb3jxEekA9avX48PP/wQCxYswMSJE6Gnpyc6UpVOnjyJQYMGITs7Gw0aNBAdhySGxUekI7Kzs+Hn5wc7OzusWrUKFhYWoiNVysfHB2+99Rb+9a9/iY5CEsSlTiIdYW9vj2PHjsHa2hoKhQLJycmiIz3VgQ
MHcOnSJYwbN050FJIoFh+RDpHL5Vi2bBmWLl0KX19ffPbZZxp11adarUZQUBAWLlwIQ0ND0XFIolh8RDronXfewfHjxxEXF4c333wT+fn5oiMBALZu3YqysjL4+fmJjkISxuIj0lHt2rXDoUOH4OTkBFdXVxw+fFhoHpVKhdmzZyMiIgL6+vzTQ+LwXx+RDjMyMsKSJUvw1VdfYejQoVi0aBEqKiqEZFm7di1atGgBHx8fIb+f6AFe1UkkETk5ORg2bBjkcjnWrVsHKyurevvd9+/fh4ODAzZv3gxPT896+71ET8OJj0gibGxscODAAXh6ekKhUGD//v319ru//PJLuLu7s/RII3DiI5Kg/fv3Y+TIkRg7dizmzZsHmUxWZ7+roKAADg4OSExM1PqHapNuYPERSdStW7cQGBiI0tJSrF+/Hq1atarV80+ePBkODg64efMmbt26hf/+97+1en6iF8WlTiKJsrKywp49e+Dj4wM3Nzfs3LmzVs+/detWzJo1C5GRkejcubOwi2qIHseJj4hw6NAhDB8+HEqlEuHh4TAyMqrxOS0tLZGXlwcA0NfXxwcffID//Oc/NT4vUU1x4iMi9OrVC6mpqUhPT0fv3r1x5cqVGp+zpKQEAGBqagpvb28EBwfX+JxEtYHFR0QAgGbNmiE+Ph5+fn7w8PDADz/8UKPz3b9/HwYGBvj3v/+Nffv2oXnz5rWUlKhmuNRJRE84ceIE/P394evriyVLlkAulz/1uLyiEmw5mYOMm4UoLFbBXC6Dk7U53nWzwZA3fLB06VK4u7vXc3qiqrH4iOipCgoKMH78eFy6dAlxcXGwt7d/+LPT1wrwZeJFHMzKBQCUqP7/hStymT7UALwdLTGpjx262Gru9kgkTQahoaGhokMQkeaRy+V49913oVarERgYCBsbG3Tq1AmxSVfwYdwpZP1xF6oKNcorHv3s/OC1y3l/4cdTN2BhIkNnG5YfaQ5OfET0TKdOnYJSqcTrHyzCnlumuF9W/VsTTAz1ETLQGYGebesuINFzYPERUbWk/JqLwJgUFP+j9ApPxuOvs/tRmnsFDZz7oNkbU5/6XhNDA8RN8OTkRxqBV3USUbWsPPzbI9/lAYDMrCka9VDCrPOrVb63WFWO6MSLdRmPqNpYfET0THlFJTiYlYvH14dMHXvA1KE79E3Mq3y/Wg0kZOYiv6ikDlMSVQ+Lj4ieacvJnBqfQw/AltSan4eoplh8RPRMGTcLn1jmfF7Fqgpk/H63lhIRvTgWHxE9U2GxqpbOU1Yr5yGqCRYfET2Tubx29uszlxvWynmIaoLFR0TP5GRtDmPZk38u1BXlUKtKgYpyQF0BtaoU6oryp55DLtOHU4uGdR2V6Jl4Hx8RPVNeUQm8Ig888T1fwaHv8OeRDY+81sgrABa9hj9xDmOZPo7O6oemZsZ1mpXoWVh8RFQtE9alYG/6rSduaagOPT3Ap6MVVgTygdUkHpc6iahaJnvbQS4zeKH3ymUGmORtV8uJiF4Mi4+IqqWLrQVCBjrBxPD5/mz8/axOJz6ujDRG7VyqRUSS8OBB0+E7M1CsKq9y2VNP7+9JL2SgEx9QTRqF3/ER0XM7k1OA6MSLSMjMhR7+vjn9gQf78fV1tMQkbztOeqRxWHxE9MLyi0qwJTUHGb/fRWFxGczlhnBq0RBDFTa8epM0FouPiIgkhRe3EBGRpLD4iIhIUlh8REQkKSw+IiKSFBYfERFJCouPiIgkhcVHRESSwuIjIiJJYfEREZGksPiIiEhSWHxERCQpLD4iIpIUFh8REUkKi4+IiCSFxUdERJLC4iMiIklh8RERkaSw+IiISFJYfEREJCksPiIikhQWHxERScr/A4IyS0iNBj/XAAAAAElFTkSuQmCC\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "G = nx.DiGraph()\n",
        "G.add_nodes_from(range(3))\n",
        "# Chain DAG 0 -> 1 -> 2 with unit edge weights; node 1 is the estimand Y\n",
        "# (the default `estimand` of simulate_linear_sem).\n",
        "G.add_edges_from([(0, 1), (1, 2)], weight=1.)\n",
        "\n",
        "# Dense weighted adjacency matrix: entry [i, j] is the weight of edge i -> j.\n",
        "G_adj = nx.to_numpy_array(G)\n",
        "\n",
        "nx.draw(G, with_labels=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "a_jdQSGcdfJX",
        "outputId": "3d85db11-327b-4e2f-b889-a36bfdb378d7"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[0., 1., 0.],\n",
              "       [0., 0., 1.],\n",
              "       [0., 0., 0.]])"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ],
      "source": [
        "# Weighted adjacency matrix of G: entry [i, j] is the weight of edge i -> j.\n",
        "G_adj"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uqnxpnLzlpG-"
      },
      "source": [
        "## Generate Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nr6CNgVzlszd"
      },
      "source": [
        "### Specification"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "aah7qrzOl0ou"
      },
      "outputs": [],
      "source": [
        "# Samples drawn per training environment.\n",
        "N = 5_000\n",
        "# Number of observed covariates: every graph node except the estimand.\n",
        "d = len(G_adj) - 1\n",
        "# Output dimension of the featurizers.\n",
        "o = 1\n",
        "\n",
        "num_train_envs = 2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5N4TRR1Glx9G"
      },
      "source": [
        "### Generation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "YOsvt8E4l6MX",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b4a7a905-00c8-4e9d-d993-9cb155483a4a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "noise_vecs: [0.1, 1.0]\n"
          ]
        }
      ],
      "source": [
        "# Noise scale of the non-estimand nodes in each training environment (one entry per environment).\n",
        "train_noise_vecs = [0.1, 1.0]\n",
        "assert len(train_noise_vecs) == num_train_envs, 'need one noise scale per training environment'\n",
        "print('noise_vecs:', train_noise_vecs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "yWqvchK8mIUh"
      },
      "outputs": [],
      "source": [
        "# One (X, y) tensor pair per training environment; the environments differ only in\n",
        "# the exponential noise scale of the non-estimand nodes.\n",
        "data = {\n",
        "    'train': [\n",
        "        simulate_linear_sem(G_adj, N, 'exp', noise_scale=scale)\n",
        "        for scale in train_noise_vecs\n",
        "    ],\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Myn3o8-FjY5D"
      },
      "source": [
        "# Models"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Seed torch so the shared initial featurizer weights (deep-copied by every TCRI run)\n",
        "# are reproducible across kernel restarts.\n",
        "torch.manual_seed(0)\n",
        "featurizer = nn.Linear(d, o, bias=False)"
      ],
      "metadata": {
        "id": "5E_BByVd4vTx"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cQhK6Yrsu28i"
      },
      "source": [
        "## TCRI Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "91oVS0yV-SX-"
      },
      "outputs": [],
      "source": [
        "class TCRI():\n",
        "  \"\"\"Trainer for a linear featurizer with ERM, IRMv1-style and TCRI penalties.\n",
        "\n",
        "  Two featurizers start as copies of the same linear map:\n",
        "    * featurizer (phi) feeds a fixed all-ones linear predictor, giving y_hat;\n",
        "    * ds_featurizer (psi) is used through the least-squares projection of y onto\n",
        "      its column span, giving ds_y_hat.\n",
        "  The per-environment objective (see loss_fn) is\n",
        "    alpha * MSE(y_hat, y) + lambd * IRM + beta * CI + (1 - alpha) * MSE(ds_y_hat, y),\n",
        "  where CI measures the dependence of phi and psi after regressing out y.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, base_featurizer=None, lr=0.01):\n",
        "    \"\"\"\n",
        "    Args:\n",
        "      base_featurizer: nn.Linear copied into both featurizers. Defaults to the\n",
        "        notebook-level `featurizer`, so `TCRI()` behaves as before.\n",
        "      lr: Adam learning rate for the two featurizers.\n",
        "    \"\"\"\n",
        "    if base_featurizer is None:\n",
        "      base_featurizer = featurizer\n",
        "    self.featurizer = deepcopy(base_featurizer)\n",
        "    self.ds_featurizer = deepcopy(base_featurizer)\n",
        "\n",
        "    # Fixed all-ones predictor on top of phi. It is not in the optimizer, so freeze it\n",
        "    # to stop it from accumulating unused gradients.\n",
        "    out_dim = base_featurizer.out_features\n",
        "    self.predictor = nn.Linear(out_dim, out_dim, bias=False)\n",
        "    with torch.no_grad():\n",
        "      self.predictor.weight.fill_(1.)\n",
        "    self.predictor.weight.requires_grad_(False)\n",
        "\n",
        "    self.loss_fn_ = nn.MSELoss()\n",
        "    self.optimizer = torch.optim.Adam(list(self.featurizer.parameters()) +\n",
        "                                      list(self.ds_featurizer.parameters()),\n",
        "                                      lr=lr)\n",
        "\n",
        "  @staticmethod\n",
        "  def _accelerator():\n",
        "    \"\"\"CUDA device when available, otherwise the CPU (training does not require a GPU).\"\"\"\n",
        "    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "  def _to_gpu(self):\n",
        "    \"\"\"Move all modules to the accelerator device (the CPU when CUDA is unavailable).\"\"\"\n",
        "    device = self._accelerator()\n",
        "    self.featurizer = self.featurizer.to(device)\n",
        "    self.ds_featurizer = self.ds_featurizer.to(device)\n",
        "    self.predictor = self.predictor.to(device)\n",
        "\n",
        "  def _to_cpu(self):\n",
        "    \"\"\"Move all modules back to the CPU.\"\"\"\n",
        "    self.featurizer = self.featurizer.cpu()\n",
        "    self.ds_featurizer = self.ds_featurizer.cpu()\n",
        "    self.predictor = self.predictor.cpu()\n",
        "\n",
        "  @staticmethod\n",
        "  def _project(z, v):\n",
        "    \"\"\"Least-squares projection of v onto the column span of z: z (z^T z)^{-1} z^T v.\n",
        "\n",
        "    Solves the small normal equations instead of materialising the [n, n] hat matrix\n",
        "    that left-to-right evaluation of z @ inv(z^T z) @ z^T @ v would build.\n",
        "    \"\"\"\n",
        "    return z @ torch.linalg.solve(z.T @ z, z.T @ v)\n",
        "\n",
        "  def _IRM(self, y_hat, y):\n",
        "    \"\"\"IRMv1-style penalty: product of the gradients, w.r.t. a dummy scale fixed at 1,\n",
        "    of the MSE on the even-indexed rows and of the MSE on the odd-indexed rows.\"\"\"\n",
        "    scale = torch.ones((), device=y_hat.device, requires_grad=True)\n",
        "\n",
        "    loss_1 = mse_loss(y_hat[::2] * scale, y[::2])\n",
        "    loss_2 = mse_loss(y_hat[1::2] * scale, y[1::2])\n",
        "\n",
        "    grad_1 = autograd.grad(loss_1, [scale], create_graph=True)[0]\n",
        "    # Each half-batch contributes its own gradient (loss_1 must not be reused here).\n",
        "    grad_2 = autograd.grad(loss_2, [scale], create_graph=True)[0]\n",
        "\n",
        "    return torch.sum(grad_1 * grad_2)\n",
        "\n",
        "  def _TCRI(self, x, y, z):\n",
        "    \"\"\"Conditional-dependence penalty for X _||_ Y | Z.\n",
        "\n",
        "    Regresses z out of x and of y by least squares and returns the norm of the\n",
        "    residual cross-product res_x^T res_y divided by the number of rows.\n",
        "    (HSIC(res_x, res_y) would be a kernel-based alternative.)\n",
        "    \"\"\"\n",
        "    res_x = x - self._project(z, x)\n",
        "    res_y = y - self._project(z, y)\n",
        "    return torch.norm(res_x.T @ res_y) / res_x.shape[0]\n",
        "\n",
        "  def loss_fn(self, y_hat, y, ds_y_hat, x, z, kwargs):\n",
        "    \"\"\"Total loss for one environment.\n",
        "\n",
        "    Args:\n",
        "      y_hat: predictions from phi. ds_y_hat: projection of y onto psi.\n",
        "      x: phi features. z: psi features.\n",
        "      kwargs: dict of loss weights 'alpha' (default 1.), 'beta' (default 0.) and\n",
        "        'lambd' (default 0.).\n",
        "\n",
        "    Returns:\n",
        "      (loss, 0.) -- the second element is a placeholder kept for callers.\n",
        "    \"\"\"\n",
        "    alpha = kwargs.get('alpha', 1.)\n",
        "    beta = kwargs.get('beta', 0.)\n",
        "    lambd = kwargs.get('lambd', 0.)\n",
        "\n",
        "    irm = 0. if lambd == 0 else self._IRM(y_hat, y)\n",
        "    # phi _||_ psi | y: test the two feature sets conditionally on the target.\n",
        "    tcri = 0. if beta == 0 else self._TCRI(x=x, y=z, z=y)\n",
        "\n",
        "    mse_c = self.loss_fn_(y_hat.ravel(), y.ravel())\n",
        "    mse_ac = self.loss_fn_(ds_y_hat.ravel(), y.ravel())\n",
        "\n",
        "    return alpha*mse_c + lambd*irm + beta*tcri + (1.-alpha)*mse_ac, 0.\n",
        "\n",
        "  def train(self, labeled_batches, unlabeled_batches=None, steps=100, update_steps=-1, **kwargs):\n",
        "    \"\"\"Full-batch training over all environments.\n",
        "\n",
        "    Args:\n",
        "      labeled_batches: list of (x, y) pairs, one per environment.\n",
        "      unlabeled_batches: unused; kept for interface compatibility.\n",
        "      steps: number of optimizer steps.\n",
        "      update_steps: print the loss every `update_steps` steps when > 0.\n",
        "      **kwargs: loss weights forwarded to loss_fn (alpha, beta, lambd).\n",
        "\n",
        "    Returns:\n",
        "      (self, log): log['losses'] is the mean loss per step; log['errors'] and\n",
        "      log['ds_errors'] are the MSE of y_hat and ds_y_hat averaged over all\n",
        "      environments, measured before that step's update.\n",
        "    \"\"\"\n",
        "    if not labeled_batches:\n",
        "      raise ValueError('labeled_batches must contain at least one (x, y) pair')\n",
        "    log = {'losses': [], 'errors': [], 'ds_errors': []}\n",
        "    device = self._accelerator()\n",
        "    self._to_gpu()\n",
        "    try:\n",
        "      for step in range(steps):\n",
        "        self.optimizer.zero_grad()\n",
        "        loss = torch.zeros((), device=device)\n",
        "        errors, ds_errors = [], []\n",
        "        for x, y in labeled_batches:\n",
        "          x = x.to(device)\n",
        "          y = y.reshape(-1, 1).to(device)\n",
        "\n",
        "          phi = self.featurizer(x)\n",
        "          psi = self.ds_featurizer(x)\n",
        "          y_hat = self.predictor(phi)\n",
        "          ds_y_hat = self._project(psi, y)\n",
        "\n",
        "          loss_, _ = self.loss_fn(y_hat, y, ds_y_hat=ds_y_hat, x=phi, z=psi, kwargs=kwargs)\n",
        "          loss = loss + loss_\n",
        "          # Log every environment, not just the last one visited by the loop.\n",
        "          errors.append(mse_loss(y_hat.detach(), y).item())\n",
        "          ds_errors.append(mse_loss(ds_y_hat.detach(), y).item())\n",
        "\n",
        "        loss = loss / len(labeled_batches)\n",
        "        log['losses'].append(loss.item())\n",
        "        log['errors'].append(sum(errors) / len(errors))\n",
        "        log['ds_errors'].append(sum(ds_errors) / len(ds_errors))\n",
        "\n",
        "        if update_steps > 0 and (step % update_steps == 0):\n",
        "          print(\"step {} - loss = {:.6f}\".format(step, loss.item()))\n",
        "\n",
        "        loss.backward()\n",
        "        self.optimizer.step()\n",
        "    finally:\n",
        "      # Always return the modules to the CPU, even if training raises.\n",
        "      self._to_cpu()\n",
        "    return self, log\n",
        "\n",
        "  def predict(self, x):\n",
        "    \"\"\"Predict y for x (CPU tensors after train) with phi and the fixed predictor.\"\"\"\n",
        "    return self.predictor(self.featurizer(x))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vjWkqsRdkQwT"
      },
      "source": [
        "## Experiments"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
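        "### ERM\n",
        "\n",
        "Closed-form least-squares fits on all covariates ($X$), on the parent of $Y$ only ($X_c$), and on the child of $Y$ only ($X_e$), followed by gradient-based ERM with the `TCRI` trainer (both penalties off)."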
      ],
      "metadata": {
        "id": "TvxGl4Iq74Zq"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_bFmJ9cavBxU",
        "outputId": "72ca4daf-ab56-41ba-8bc8-ee9e711c4e20"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "X:  [[0.8377477 ]\n",
            " [0.18158905]]\n",
            "Xc:  [[1.130098]]\n",
            "Xe:  [[0.56378365]]\n"
          ]
        }
      ],
      "source": [
        "def _ols(A, b):\n",
        "  \"\"\"Ordinary least-squares coefficients, computed as (A^T A)^{-1} A^T b.\"\"\"\n",
        "  return torch.inverse(A.T @ A) @ A.T @ b\n",
        "\n",
        "# Pool the training environments. After dropping the estimand (node 1), column 0 is\n",
        "# node 0 (parent of Y) and column 1 is node 2 (child of Y).\n",
        "X = torch.concat([feats for feats, _ in data['train']], 0)\n",
        "Xc = torch.concat([feats[:, :1] for feats, _ in data['train']], 0)\n",
        "Xe = torch.concat([feats[:, 1:] for feats, _ in data['train']], 0)\n",
        "y = torch.concat([target for _, target in data['train']], 0).reshape(-1, 1)\n",
        "\n",
        "w, wc, we = (_ols(A, y) for A in (X, Xc, Xe))\n",
        "\n",
        "for label, coef in ((\"X: \", w), (\"Xc: \", wc), (\"Xe: \", we)):\n",
        "  print(label, coef.detach().numpy())"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Loss weights: alpha * MSE_phi + lambd * IRM + beta * CI + (1 - alpha) * MSE_psi,\n",
        "# where MSE_psi scores the projection of y onto psi.\n",
        "lambd, beta, alpha = 0., 0., 1  # ERM: both penalties off, phi-only MSE\n",
        "\n",
        "erm_logs = list(TCRI().train(data['train'], steps=500, alpha=alpha, beta=beta, lambd=lambd))\n",
        "erm, erm_log = erm_logs\n",
        "\n",
        "print('ERM Featurizer: {}'.format(erm.featurizer.weight.detach().numpy()))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VymHc7xiaX2v",
        "outputId": "c070fffd-552f-48fb-ef96-57cca53a7c60"
      },
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ERM Featurizer: [[0.8377481 0.1815888]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### IRM"
      ],
      "metadata": {
        "id": "j2f4RmzRnLeR"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Loss weights: alpha * MSE_phi + lambd * IRM + beta * CI + (1 - alpha) * MSE_psi.\n",
        "lambd, beta, alpha = 0.1, 0, 1  # IRM: IRM penalty on, CI penalty off, phi-only MSE\n",
        "\n",
        "irm_logs = list(TCRI().train(data['train'], steps=500, alpha=alpha, beta=beta, lambd=lambd))\n",
        "irm, irm_log = irm_logs\n",
        "\n",
        "print('IRM Causal Featurizer: {}'.format(irm.featurizer.weight.detach().numpy()))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ByvHA5UlaXSj",
        "outputId": "a4247856-6781-4497-c080-97ee01442de9"
      },
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "IRM Causal Featurizer: [[0.8338259  0.18038203]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hkvu7t80qQ9S"
      },
      "source": [
        "### TCRI"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "RlFjmVXS7xYO",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ba98754d-33ad-46b9-e41f-a6bfe5789113"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "TCRI Causal Featurizer: [[1.1137584  0.01073858]]\n"
          ]
        }
      ],
      "source": [
        "# Loss weights: alpha * MSE_phi + lambd * IRM + beta * CI + (1 - alpha) * MSE_psi.\n",
        "lambd, beta, alpha = 0.1, 10, 0.75  # TCRI: IRM and CI penalties on, 75/25 phi/psi MSE mix\n",
        "\n",
        "tcri_logs = list(TCRI().train(data['train'], steps=500, alpha=alpha, beta=beta, lambd=lambd))\n",
        "tcri, tcri_log = tcri_logs\n",
        "\n",
        "print('TCRI Causal Featurizer: {}'.format(tcri.featurizer.weight.detach().numpy()))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}