{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import random\n",
    "import pandas as pd\n",
    "device = \"cpu\"\n",
    "from copy import deepcopy\n",
    "\n",
    "def set_seed():\n",
    "    random.seed(42)\n",
    "    np.random.seed(42)\n",
    "    torch.manual_seed(42)\n",
    "    torch.cuda.manual_seed_all(42)  # If using CUDA"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset Properties"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "Gauss = np.random.normal\n",
    "var = 5 # High Variance\n",
    "trn_size = 1000 # High Training Size\n",
    "ndim = 6 # Number of dimensions. The base dimensions are always 3.\n",
    "depth = 10 # Depth of the causal graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lin_y(X: np.array, w: np.array, var: float):\n",
    "    \"\"\"\n",
    "    Defines y as a linear function of X and w with Gaussian noise of variance var and mean 0\n",
    "    \"\"\"\n",
    "    if len(w.shape) == 1:\n",
    "        w = w.reshape(-1, 1)\n",
    "    assert X.shape[1] == len(w), \"Input array must be of same shape as the weight vector\"\n",
    "    result = X @ w / X.shape[1]\n",
    "    exog = np.random.randn(len(X)) * np.sqrt(var)\n",
    "    return [result.reshape(-1, 1), exog.reshape(-1, 1)]\n",
    "\n",
    "def lin_X(X: np.ndarray, mat: np.ndarray, var: float):\n",
    "    \"\"\"\n",
    "    Defines X' as a linear function of X and mat with Gaussian noise of variance var and mean 0\n",
    "    \"\"\"\n",
    "    assert X.shape[1] == mat.shape[0], \"Input array must be of same shape as the weight matrix\"\n",
    "    result: np.ndarray = X @ mat / X.shape[1]\n",
    "    exog = Gauss(0, 1, result.shape) * np.sqrt(var)\n",
    "    return [result, exog]\n",
    "\n",
    "def generate_roots(num_samples, ndim, corr):\n",
    "    \"\"\"Generates correlated features. \n",
    "    Base X is assumed to be of 3 dimensions\n",
    "    Remaining dimensions are correlated with the base X with correlation corr\n",
    "    \"\"\"\n",
    "    base_X = Gauss(0, 1, (num_samples, 3))\n",
    "    X = [base_X]\n",
    "    num_rpt = ndim // 3\n",
    "    for i in range(num_rpt - 1):\n",
    "        X.append(base_X + Gauss(0, corr, (num_samples, 3)))\n",
    "    X = np.concatenate(X, axis=1)\n",
    "    assert X.shape == (num_samples, ndim), f\"Shape of X is {X.shape}\"\n",
    "    return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearYModel(nn.Module):\n",
    "    def __init__(self, num_params):\n",
    "        super(LinearYModel, self).__init__()\n",
    "        self.num_params = num_params\n",
    "        self.w = np.random.randn(num_params, 1)\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        self.w = np.linalg.inv(X.T @ X) @ X.T @ y\n",
    "\n",
    "    def predict(self, X):\n",
    "        yhat = X @ self.w\n",
    "        yhat = yhat.reshape(-1, 1)\n",
    "        return yhat\n",
    "\n",
    "    def forward(self, X):\n",
    "        return self.predict(X)\n",
    "    \n",
    "class LinearXModel(nn.Module):\n",
    "    def __init__(self, num_params):\n",
    "        super(LinearXModel, self).__init__()\n",
    "        self.num_params = num_params\n",
    "        self.mat = np.random.randn(num_params, num_params)\n",
    "    \n",
    "    def fit(self, XPrev, Xcurr):\n",
    "        self.mat = np.linalg.inv(XPrev.T @ XPrev) @ XPrev.T @ Xcurr\n",
    "        assert self.mat.shape == (self.num_params, self.num_params), f\"Shape of matrix is {self.mat.shape}\"\n",
    "    \n",
    "    def predict(self, X):\n",
    "        yhat = X @ self.mat\n",
    "        return yhat\n",
    "    \n",
    "    def forward(self, X):\n",
    "        return self.predict(X)\n",
    "\n",
    "# %% Define the Gold DGP Parameters\n",
    "GoldMatrices = []\n",
    "for i in range(depth-1):\n",
    "    GoldMatrices.append(Gauss(0, 1, (ndim, ndim)))\n",
    "w = Gauss(0, 1, (ndim, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate Training and Test Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_datasets(rho):\n",
    "    \"\"\"\n",
    "    Generates the training, validation, test and fixed datasets for the given DGP\n",
    "    \"\"\"\n",
    "    set_seed()\n",
    "    Xval_root, XTst_root = generate_roots(100, ndim, rho), generate_roots(100, ndim, rho)\n",
    "    XFix_root = np.copy(XTst_root)\n",
    "    XTrn_root = generate_roots(trn_size, ndim, rho)\n",
    "    for i in range(len(XTst_root)):\n",
    "        XTst_root[i, np.random.randint(0, 3)] = np.random.uniform(3, 10)\n",
    "        \n",
    "    def gen_inter_data(X_root, matrices, w, var):\n",
    "        int_data = []\n",
    "        int_data.append(lin_X(X_root, matrices[0], var)) # The root nodes have only noise\n",
    "        for mat in matrices[1:]:\n",
    "            X_noise = int_data[-1][0] + int_data[-1][1]\n",
    "            int_data.append(lin_X(X_noise, mat, var))\n",
    "        X_noise = int_data[-1][0] + int_data[-1][1]\n",
    "        int_data.append(lin_y(X_noise, w, var)) \n",
    "        return int_data\n",
    "\n",
    "    data_dicts = {\n",
    "        \"trn_root\": XTrn_root,\n",
    "        \"val_root\": Xval_root,\n",
    "        \"tst_root\": XTst_root,\n",
    "        \"fix_root\": XFix_root,\n",
    "        \"trn\": gen_inter_data(XTrn_root, GoldMatrices, w, var),\n",
    "        \"val\": gen_inter_data(Xval_root, GoldMatrices, w, var),\n",
    "        \"tst\": gen_inter_data(XTst_root, GoldMatrices, w, var),\n",
    "        \"fix\": gen_inter_data(XFix_root, GoldMatrices, w, var)\n",
    "    }\n",
    "    return data_dicts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_models(data_dicts):\n",
    "    \"\"\"\n",
    "    Trains the models for the given data dictionaries\n",
    "    The first model is from the root to the first intermediate node\n",
    "    The rest of the models are from one intermediate node to the next\n",
    "    \"\"\"\n",
    "    XTrn_root = data_dicts[\"trn_root\"]\n",
    "    Xmodels = []\n",
    "    for i in range(depth-1):\n",
    "        Xmodels.append(LinearXModel(ndim))\n",
    "    Ymodel = LinearYModel(ndim)\n",
    "\n",
    "    # First train the 0th model from root to the first intermediate node\n",
    "    trn_Y = data_dicts[\"trn\"][0][0] + data_dicts[\"trn\"][0][1]\n",
    "    bparams = deepcopy(Xmodels[0].mat)\n",
    "    Xmodels[0].fit(XTrn_root, trn_Y)\n",
    "    aparams = deepcopy(Xmodels[0].mat)\n",
    "    print(\"Trained 0th model from Root -> X0\", np.linalg.norm(bparams - aparams))\n",
    "    for i, model in enumerate(Xmodels[1:]):\n",
    "        trn_X = data_dicts[\"trn\"][i][0] + data_dicts[\"trn\"][i][1]\n",
    "        trn_Y = data_dicts[\"trn\"][i+1][0] + data_dicts[\"trn\"][i+1][1]\n",
    "        bparams = deepcopy(model.mat)\n",
    "        model.fit(trn_X, trn_Y)\n",
    "        aparams = deepcopy(model.mat)\n",
    "        print(f\"Trained {i+1}th model from X{i} -> X{i+1}\", np.linalg.norm(bparams - aparams))\n",
    "    trn_X = data_dicts[\"trn\"][i+1][0] + data_dicts[\"trn\"][i+1][1]\n",
    "    trn_Y = data_dicts[\"trn\"][i+2][0] + data_dicts[\"trn\"][i+2][1]\n",
    "    bparams = deepcopy(Ymodel.w)\n",
    "    Ymodel.fit(trn_X, trn_Y)\n",
    "    aparams = deepcopy(Ymodel.w)\n",
    "    print(f\"Trained Y model from X{i+1} -> X{i+2}\", np.linalg.norm(bparams - aparams))\n",
    "    assert len(Xmodels)+1 == depth, \"Number of models trained is not equal to depth\"\n",
    "    return Xmodels, Ymodel\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_preds(Xmodels, Ymodel, X_root, data, val_preds=None, test_preds=None, int=False, CF=False):\n",
    "    preds = []\n",
    "    for i, model in enumerate(Xmodels):\n",
    "        if i == 0:\n",
    "            p = model(X_root)\n",
    "        else:\n",
    "            p = model(p)\n",
    "        if CF == True:\n",
    "            residual = test_preds[i][1]\n",
    "        elif int == True:\n",
    "            # residual = val_preds[i][1]\n",
    "            # np.random.shuffle(residual)\n",
    "            residual = np.zeros_like(val_preds[i][1])\n",
    "        else:\n",
    "            residual = p - data[i][0]\n",
    "        preds.append([p, residual])\n",
    "    p = Ymodel(p)\n",
    "    if CF == True:\n",
    "        residual = test_preds[-1][1]\n",
    "    elif int == True:\n",
    "        # residual = val_preds[-1][1]\n",
    "        # np.random.shuffle(residual)\n",
    "        residual = np.zeros_like(val_preds[-1][1])\n",
    "    else:\n",
    "        residual = p - data[-1][0]\n",
    "    preds.append([p, residual])\n",
    "    return preds\n",
    "\n",
    "# %% Assess the L2 errors\n",
    "def mse_error(preds, data):\n",
    "    errors = []\n",
    "    for i, pred in enumerate(preds):\n",
    "        errors.append(round(np.mean((pred[0] + preds[1] - data[i][0] + data[i][1])**2), 2))\n",
    "    return errors\n",
    "\n",
    "def eval_models(Xmodels, Ymodel, data_dicts):\n",
    "    \"\"\"\n",
    "    Evaluates the models on the given data dictionaries\n",
    "    \"\"\"\n",
    "    Xval_root = data_dicts[\"val_root\"]\n",
    "    XTst_root = data_dicts[\"tst_root\"]\n",
    "    XFix_root = data_dicts[\"fix_root\"]\n",
    "    \n",
    "    # %% Get the predictions for test data and validation data for residuals\n",
    "    val_preds = get_preds(Xmodels, Ymodel, Xval_root, data_dicts[\"val\"])\n",
    "    test_preds = get_preds(Xmodels, Ymodel, XTst_root, data_dicts[\"tst\"])    \n",
    "\n",
    "\n",
    "    # %% Interventional Preds\n",
    "    int_preds = get_preds(Xmodels, Ymodel, XFix_root, data_dicts[\"fix\"], val_preds=val_preds, int=True)\n",
    "    int_errors = mse_error(int_preds, data_dicts[\"fix\"])\n",
    "\n",
    "    # %% Counterfactual Preds\n",
    "    cf_preds = get_preds(Xmodels, Ymodel, XFix_root, data_dicts[\"fix\"], test_preds=test_preds, CF=True)\n",
    "    cf_errors = mse_error(cf_preds, data_dicts[\"fix\"])\n",
    "\n",
    "    df_dict = {\n",
    "        \"Depth\": np.arange(1, depth+1),\n",
    "        \"int_errors\": int_errors,\n",
    "        \"cf_errors\": cf_errors,\n",
    "        \"Reduction_PC\": [round(100 * ((cf - int) / cf), 2) for cf, int in zip(cf_errors, int_errors)]\n",
    "    }\n",
    "    df = pd.DataFrame(df_dict)\n",
    "    df.set_index(\"Depth\", inplace=True)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trained 0th model from Root -> X0 42.52195383148337\n",
      "Trained 1th model from X0 -> X1 7.720633332688531\n",
      "Trained 2th model from X1 -> X2 5.057511244483703\n",
      "Trained 3th model from X2 -> X3 7.497846949499599\n",
      "Trained 4th model from X3 -> X4 4.310451996116022\n",
      "Trained 5th model from X4 -> X5 6.977017609909292\n",
      "Trained 6th model from X5 -> X6 5.796138377138417\n",
      "Trained 7th model from X6 -> X7 5.713012190967711\n",
      "Trained 8th model from X7 -> X8 6.680180541075484\n",
      "Trained Y model from X8 -> X9 1.889158144886001\n",
      "Trained 0th model from Root -> X0 5.521850147519549\n",
      "Trained 1th model from X0 -> X1 7.725684955067254\n",
      "Trained 2th model from X1 -> X2 5.056907725455107\n",
      "Trained 3th model from X2 -> X3 7.497698520491602\n",
      "Trained 4th model from X3 -> X4 4.310497821537817\n",
      "Trained 5th model from X4 -> X5 6.977029353646472\n",
      "Trained 6th model from X5 -> X6 5.7961713591012245\n",
      "Trained 7th model from X6 -> X7 5.713035967101211\n",
      "Trained 8th model from X7 -> X8 6.680176096171011\n",
      "Trained Y model from X8 -> X9 1.8891571821222932\n",
      "|   Depth |   (0.01, 'int_errors') |   (0.01, 'cf_errors') |   (0.01, 'Reduction_PC') |   (1.0, 'int_errors') |   (1.0, 'cf_errors') |   (1.0, 'Reduction_PC') |\n",
      "|--------:|-----------------------:|----------------------:|-------------------------:|----------------------:|---------------------:|------------------------:|\n",
      "|       1 |                   5.36 |                222.64 |                  97.5925 |                  5.39 |                 5.92 |                 8.9527  |\n",
      "|       2 |                   6.23 |                222.69 |                  97.2024 |                  6.22 |                 6.97 |                10.7604  |\n",
      "|       3 |                   6.22 |                225.01 |                  97.2357 |                  6.23 |                 6.66 |                 6.45646 |\n",
      "|       4 |                   5.88 |                223.66 |                  97.371  |                  5.88 |                 6.42 |                 8.41121 |\n",
      "|       5 |                   5.76 |                229.02 |                  97.4849 |                  5.78 |                 6.53 |                11.4855  |\n",
      "|       6 |                   6.24 |                223.33 |                  97.2059 |                  6.25 |                 6.8  |                 8.08824 |\n",
      "|       7 |                   6.66 |                227.53 |                  97.0729 |                  6.63 |                 7.13 |                 7.01262 |\n",
      "|       8 |                   6.07 |                229.74 |                  97.3579 |                  6.08 |                 6.52 |                 6.74847 |\n",
      "|       9 |                   6.58 |                224.61 |                  97.0705 |                  6.56 |                 6.94 |                 5.4755  |\n",
      "|      10 |                   5.09 |                224.29 |                  97.7306 |                  5.09 |                 5.6  |                 9.10714 |\n"
     ]
    }
   ],
   "source": [
    "result_dfs = []\n",
    "for corr in [0.01, 1]:\n",
    "    data_dicts = generate_datasets(corr)\n",
    "    Xmodels, Ymodel = train_models(data_dicts)\n",
    "    df = eval_models(Xmodels, Ymodel, data_dicts)\n",
    "    result_dfs.append(df)\n",
    "\n",
    "# Concatenate all the dfs based on Depth\n",
    "final_df = pd.concat(result_dfs, axis=1, keys=[0.01, 1])\n",
    "print(final_df.to_markdown())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "petshop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
