{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "11c6429c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import Ridge\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "import matplotlib.pyplot as plt\n",
    "from joblib import Parallel, delayed\n",
    "from tqdm import tqdm\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a950de2",
   "metadata": {},
   "source": [
    "## Main functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "4bbef26d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class KernelRegression():\n",
    "    \"\"\" Kernel regression: ridge on explicit feature maps, penalty picked by grid search.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    label_type : {'cont', 'multi_cont'}\n",
    "        'cont' fits a single ridge model on y as passed in; 'multi_cont'\n",
    "        fits one independent ridge model per column of y.\n",
    "    kernel : {'linear', 'poly'}\n",
    "        'linear' leaves the inputs unchanged; 'poly' stacks the element-wise\n",
    "        powers x**1, ..., x**p side by side (no interaction terms).\n",
    "    p : int, optional\n",
    "        Highest power of the 'poly' feature map; fit() requires it for 'poly'.\n",
    "    standardize : bool\n",
    "        If True, inputs go through a StandardScaler fitted on the training\n",
    "        inputs before the feature map is applied.\n",
    "    \"\"\"\n",
    "\n",
    "    # Ridge penalties tried by the grid search (shared by both label types).\n",
    "    ALPHA_GRID = [1e-3, 1e-2, 1e-1, 1, 1e2, 1e3, 1e4, 1e5, 1e6]\n",
    "\n",
    "    def __init__(self, label_type, kernel,\n",
    "                 p=None, standardize=True):\n",
    "        \"\"\" Constructor. Raises ValueError for an unknown label_type or kernel. \"\"\"\n",
    "\n",
    "        if label_type not in ['cont', 'multi_cont']:\n",
    "            raise ValueError('Label type unknown. Choose from cont or multi_cont')\n",
    "        # Reject unknown kernels up front: fit() would otherwise succeed without\n",
    "        # a feature map and predict() would fail on the missing `kernel_fit`.\n",
    "        if kernel not in ['linear', 'poly']:\n",
    "            raise ValueError('Kernel unknown. Choose from linear or poly')\n",
    "\n",
    "        self.label_type = label_type\n",
    "        self.kernel = kernel\n",
    "        self.p = p\n",
    "        self.standardize = standardize\n",
    "        # Filled in by fit(): one Ridge estimator ('cont') or a list of them\n",
    "        # ('multi_cont').\n",
    "        self.model = None\n",
    "\n",
    "    def _poly_features(self, x):\n",
    "        \"\"\" Stacks x**1, ..., x**p column-wise \"\"\"\n",
    "        return np.hstack([x**i for i in range(1, self.p + 1)])\n",
    "\n",
    "    def _fit_ridge(self, x, y):\n",
    "        \"\"\" Grid-searches the ridge penalty and returns the refitted best estimator \"\"\"\n",
    "        ridge = Ridge(random_state=23854)\n",
    "        search = GridSearchCV(ridge, {'alpha': self.ALPHA_GRID}, refit=True)\n",
    "        return search.fit(x, y).best_estimator_\n",
    "\n",
    "    def fit(self, x_or, y):\n",
    "        \"\"\" Fits with cross validation\n",
    "\n",
    "        x_or holds the raw inputs (one row per sample) and y the targets.\n",
    "        For 'multi_cont', y must be 2-D because it is fitted column by column.\n",
    "        Raises ValueError if the kernel is 'poly' and p is None.\n",
    "        \"\"\"\n",
    "\n",
    "        #---preprocessing\n",
    "        if self.standardize:\n",
    "            xscaler = StandardScaler().fit(x_or)\n",
    "            # Kept as a callable attribute so predict() reuses the training scaling.\n",
    "            self.xscaler = xscaler.transform\n",
    "            x = self.xscaler(x_or)\n",
    "        else:\n",
    "            x = x_or.copy()\n",
    "\n",
    "        if self.kernel == 'linear':\n",
    "            self.kernel_fit = lambda x: x\n",
    "        elif self.kernel == 'poly':\n",
    "            if self.p is None:\n",
    "                raise ValueError('Need polynomial value')\n",
    "            self.kernel_fit = self._poly_features\n",
    "        else:\n",
    "            # Only reachable if `kernel` was reassigned after construction.\n",
    "            raise ValueError('Kernel unknown. Choose from linear or poly')\n",
    "        x = self.kernel_fit(x)\n",
    "\n",
    "        #---fitting\n",
    "        if self.label_type == 'cont':\n",
    "            self.model = self._fit_ridge(x, y)\n",
    "        elif self.label_type == 'multi_cont':\n",
    "            # One independent search and model per output column.\n",
    "            self.model = [self._fit_ridge(x, y[:, j].reshape(-1, 1))\n",
    "                          for j in range(y.shape[1])]\n",
    "\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\" Predicts targets for raw inputs x using the fitted preprocessing and model(s)\n",
    "\n",
    "        Returns whatever the ridge model predicts for 'cont'; for 'multi_cont'\n",
    "        the per-output predictions are stacked as columns.\n",
    "        Raises ValueError if fit() has not been called.\n",
    "        \"\"\"\n",
    "\n",
    "        if self.model is None:\n",
    "            raise ValueError('Model has not been fitted. Call fit() first.')\n",
    "\n",
    "        if self.standardize:\n",
    "            x = self.xscaler(x)\n",
    "        x = self.kernel_fit(x)\n",
    "\n",
    "        if self.label_type == 'cont':\n",
    "            yhat = self.model.predict(x)\n",
    "        elif self.label_type == 'multi_cont':\n",
    "            yhat = np.hstack([cmod.predict(x).reshape(-1, 1) for cmod in self.model])\n",
    "\n",
    "        return yhat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "439b83fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MBM():\n",
    "    \"\"\" Two-stage regressor: a teacher maps x to m, a student maps the predicted m to y.\n",
    "\n",
    "    Both stages are KernelRegression models; the kernel_*/p_* arguments are\n",
    "    forwarded to the teacher (tch) and student (stud) respectively, and\n",
    "    standardize is forwarded to both.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, kernel_tch='poly',\n",
    "                 kernel_stud='poly', p_tch = None,\n",
    "                 p_stud = None, standardize=True):\n",
    "        \"\"\" Constructor for MBM \"\"\"\n",
    "        self.kernel_tch = kernel_tch\n",
    "        self.kernel_stud = kernel_stud\n",
    "        self.p_tch = p_tch\n",
    "        self.p_stud = p_stud\n",
    "        self.standardize = standardize\n",
    "\n",
    "        # Must exist before fit() so that predict() can raise its\n",
    "        # 'not fitted' error instead of an AttributeError.\n",
    "        self.model_tch = None\n",
    "        self.model_stud = None\n",
    "\n",
    "    @staticmethod\n",
    "    def _as_columns(m_hat):\n",
    "        \"\"\" Turns 1-D teacher predictions into a single feature column \"\"\"\n",
    "        return m_hat.reshape(-1, 1) if m_hat.ndim == 1 else m_hat\n",
    "\n",
    "    def fit(self, x, y, m):\n",
    "        \"\"\" Fits teacher and student\n",
    "\n",
    "        The teacher is fitted on (x, m); the student is then fitted on the\n",
    "        teacher's predictions for the same x, with y as the target. A target\n",
    "        with more than one array dimension is fitted column by column\n",
    "        ('multi_cont'), otherwise as a single output ('cont').\n",
    "        \"\"\"\n",
    "\n",
    "        # First teacher fit\n",
    "        self.model_tch = KernelRegression(\n",
    "            label_type='multi_cont' if m.ndim > 1 else 'cont',\n",
    "            kernel=self.kernel_tch, p = self.p_tch,\n",
    "            standardize=self.standardize)\n",
    "\n",
    "        self.model_tch.fit(x, m)\n",
    "        # A 1-D mediator gives 1-D predictions; the student needs 2-D features.\n",
    "        m_hat = self._as_columns(self.model_tch.predict(x))\n",
    "\n",
    "        # Fit model to predict y from m_hat\n",
    "        self.model_stud = KernelRegression(\n",
    "            label_type='multi_cont' if y.ndim > 1 else 'cont',\n",
    "            kernel=self.kernel_stud, p = self.p_stud,\n",
    "            standardize=self.standardize)\n",
    "        self.model_stud.fit(m_hat, y)\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\" Predicts y given x by first predicting m and then y\n",
    "\n",
    "        Raises ValueError if fit() has not been called.\n",
    "        \"\"\"\n",
    "\n",
    "        if self.model_tch is None or self.model_stud is None:\n",
    "            raise ValueError(\"Models have not been fitted. Call fit() first.\")\n",
    "\n",
    "        m_hat = self._as_columns(self.model_tch.predict(x))\n",
    "        y_hat = self.model_stud.predict(m_hat)\n",
    "\n",
    "        return y_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "679e7536",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TIPMI():\n",
    "    \"\"\" Teacher/student regressor with a 2-fold cross-fitted teacher.\n",
    "\n",
    "    Two teachers learn y from m, each on one half of the training rows, and\n",
    "    each predicts y for the half it did not see. A student then learns to\n",
    "    reproduce those held-out predictions from x, and only the student is\n",
    "    used at prediction time. All models are KernelRegression instances.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, kernel_tch='poly',\n",
    "                 kernel_stud='poly', p_tch = None,\n",
    "                 p_stud = None, standardize=True):\n",
    "        \"\"\" Constructor for TIPMI \"\"\"\n",
    "\n",
    "        self.kernel_tch = kernel_tch\n",
    "        self.kernel_stud = kernel_stud\n",
    "        self.p_tch = p_tch\n",
    "        self.p_stud = p_stud\n",
    "        self.standardize = standardize\n",
    "\n",
    "        # Must exist before fit() so that predict() can raise its\n",
    "        # 'not fitted' error instead of an AttributeError.\n",
    "        self.model_tch1 = None\n",
    "        self.model_tch2 = None\n",
    "        self.model_stud = None\n",
    "\n",
    "    def _new_regressor(self, y, kernel, p):\n",
    "        \"\"\" Builds an unfitted KernelRegression whose label type matches y's dimensionality \"\"\"\n",
    "        return KernelRegression(\n",
    "            label_type='multi_cont' if y.ndim > 1 else 'cont',\n",
    "            kernel=kernel, p=p,\n",
    "            standardize=self.standardize)\n",
    "\n",
    "    def fit(self, x, y, m):\n",
    "        \"\"\" Fits teacher and student\n",
    "\n",
    "        x, y and m are row-aligned arrays. Rows are split at the midpoint in\n",
    "        their given order (no shuffling); y may be 1-D or 2-D.\n",
    "        \"\"\"\n",
    "\n",
    "        #2x cross fitting\n",
    "        n_samp = x.shape[0] // 2\n",
    "        # Row-only slicing works for 1-D and 2-D targets alike.\n",
    "        m_first, m_second = m[:n_samp], m[n_samp:]\n",
    "        y_first, y_second = y[:n_samp], y[n_samp:]\n",
    "\n",
    "        # First teacher fit: trained on the first half, predicts the second\n",
    "        self.model_tch1 = self._new_regressor(y, self.kernel_tch, self.p_tch)\n",
    "        self.model_tch1.fit(m_first, y_first)\n",
    "        y_hat1 = self.model_tch1.predict(m_second)\n",
    "\n",
    "        # Second teacher fit: trained on the second half, predicts the first\n",
    "        self.model_tch2 = self._new_regressor(y, self.kernel_tch, self.p_tch)\n",
    "        self.model_tch2.fit(m_second, y_second)\n",
    "        y_hat2 = self.model_tch2.predict(m_first)\n",
    "\n",
    "        # Restore the original row order: first-half predictions come from\n",
    "        # the second teacher and vice versa.\n",
    "        y_hat = np.concatenate([y_hat2, y_hat1], axis=0)\n",
    "\n",
    "        # Fit student model\n",
    "        self.model_stud = self._new_regressor(y, self.kernel_stud, self.p_stud)\n",
    "        self.model_stud.fit(x, y_hat)\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\" Predicts y from the student model\n",
    "\n",
    "        Raises ValueError if fit() has not been called.\n",
    "        \"\"\"\n",
    "\n",
    "        if self.model_stud is None:\n",
    "            raise ValueError(\"Models have not been fitted. Call fit() first.\")\n",
    "\n",
    "        y_hat = self.model_stud.predict(x)\n",
    "\n",
    "        return y_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "b16694d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _draw_test_split(rng, n, m_dim, noise_level, vty, ytm, mtx, vtx):\n",
    "    \"\"\" Draws one evaluation set (x, y) of n rows for fixed structural coefficients.\n",
    "\n",
    "    vty, ytm, mtx and vtx are the (1, 1) coefficient arrays of the edges\n",
    "    v->y, y_clean->m_1, m_1->x and v->x. The random draws happen in the\n",
    "    order v, y_clean, m_1 noise, m_rest, x noise; keep that order, since the\n",
    "    simulated data for a given seed depends on it.\n",
    "    \"\"\"\n",
    "    # generate v\n",
    "    v = rng.normal(0, 1, (n, 1))\n",
    "\n",
    "    # generate y\n",
    "    y_clean = rng.normal(0, 1, (n, 1))\n",
    "    y = np.dot(v, vty) + y_clean\n",
    "\n",
    "    # generate m (only the first mediator coordinate depends on y)\n",
    "    m_1 = np.dot(y_clean, ytm) + rng.normal(0, noise_level, (n, 1))\n",
    "    m_rest = rng.normal(0, 1, (n, m_dim-1))\n",
    "\n",
    "    # generate x\n",
    "    x_clean = np.dot(m_1, mtx)\n",
    "    x_short = np.dot(v, vtx)\n",
    "    if m_dim > 1:\n",
    "        x_redundant = np.cbrt(m_rest) + rng.normal(0, noise_level, (n, m_dim -1))\n",
    "        x_redundant = x_redundant + x_short\n",
    "    else:\n",
    "        x_redundant = x_short\n",
    "\n",
    "    return np.hstack([x_clean, x_redundant]), y\n",
    "\n",
    "\n",
    "def main_anti_causal_shifted(n_train, i, m_dim):\n",
    "    \"\"\" Runs one simulation replicate and scores TIPMI and MBM on it.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_train : int\n",
    "        Number of training rows.\n",
    "    i : int\n",
    "        Replicate index; together with n_train it fixes the random seed, so\n",
    "        the same (n_train, i) pair uses the same seed for every m_dim.\n",
    "    m_dim : int\n",
    "        Number of mediator columns (only the first one depends on y).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        One row per model with columns i, n, model, mse_ood, mse_id, m_dim.\n",
    "        Despite their names, mse_ood and mse_id hold the ROOT mean squared\n",
    "        error on the out-of-distribution and in-distribution test sets.\n",
    "    \"\"\"\n",
    "    n_test_id = 100\n",
    "    n_test_ood = 100\n",
    "    noise_level = 1\n",
    "\n",
    "    rng = np.random.RandomState((i+1)*n_train)\n",
    "\n",
    "    # ------------------------ #\n",
    "    # -----Training data------ #\n",
    "    # ------------------------ #\n",
    "    # Coefficient draws are interleaved with the data draws below; the order\n",
    "    # of the rng calls determines the simulated data for a given seed.\n",
    "\n",
    "    # generate v\n",
    "    v = rng.normal(0, 1, (n_train, 1))\n",
    "\n",
    "    # generate y\n",
    "    vty = rng.normal(0, 1, (1, 1))\n",
    "    y_short = np.dot(v, vty)\n",
    "    y_clean = rng.normal(0, 1, (n_train, 1))\n",
    "    y = y_short + y_clean\n",
    "\n",
    "    # generate m\n",
    "    ytm = rng.normal(0, 1, (1, 1))\n",
    "    m_1 = np.dot(y_clean, ytm) + rng.normal(0, noise_level, (n_train, 1))\n",
    "    m_rest = rng.normal(0, 1, (n_train, m_dim-1))\n",
    "    if m_dim > 1:\n",
    "        m = np.hstack([m_1, m_rest])\n",
    "    else:\n",
    "        m = m_1.copy()\n",
    "\n",
    "    # generate x\n",
    "    mtx = rng.normal(0, 1,  (1, 1))\n",
    "    x_clean = np.dot(m_1,mtx)\n",
    "\n",
    "    vtx = rng.normal(0, 1, (1, 1))\n",
    "    x_short = np.dot(v, vtx)\n",
    "    if m_dim > 1:\n",
    "        x_redundant = np.cbrt(m_rest) + rng.normal(0, noise_level, (n_train, m_dim -1))\n",
    "        x_redundant = x_redundant + x_short\n",
    "    else:\n",
    "        x_redundant = x_short\n",
    "\n",
    "    x = np.hstack([x_clean, x_redundant])\n",
    "\n",
    "    # ------------------------ #\n",
    "    # -----Testing data------- #\n",
    "    # ------------------------ #\n",
    "\n",
    "    # ---- in distribution: same coefficients as training\n",
    "    x_id, y_id = _draw_test_split(rng, n_test_id, m_dim, noise_level,\n",
    "                                  vty, ytm, mtx, vtx)\n",
    "\n",
    "    # ---- out of distribution: the v->y coefficient flips sign, v->x does not\n",
    "    x_ood, y_ood = _draw_test_split(rng, n_test_ood, m_dim, noise_level,\n",
    "                                    -1.0 * vty, ytm, mtx, vtx)\n",
    "\n",
    "    # fit and predict each model (TIPMI first, then MBM) and score it\n",
    "    candidates = [\n",
    "        ('TIPMI (Ours)', TIPMI(kernel_tch='poly', kernel_stud='poly', p_tch=5, p_stud=5)),\n",
    "        ('MBM (Ours)', MBM(kernel_tch='poly', kernel_stud='poly', p_tch=5, p_stud=5)),\n",
    "    ]\n",
    "\n",
    "    per_model = []\n",
    "    for label, model in candidates:\n",
    "        model.fit(x, y, m)\n",
    "        yh_id = model.predict(x_id)\n",
    "        yh_ood = model.predict(x_ood)\n",
    "        per_model.append(pd.DataFrame({'i': i, 'n': n_train,\n",
    "                                       'model': label,\n",
    "                                       'mse_ood': np.sqrt(np.mean((yh_ood - y_ood)**2)),\n",
    "                                       'mse_id': np.sqrt(np.mean((yh_id - y_id)**2)),\n",
    "                                       'm_dim': m_dim}, index = [0]))\n",
    "\n",
    "    res = pd.concat(per_model, ignore_index=True)\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6585e30",
   "metadata": {},
   "source": [
    "## Run anticausal with shift"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bb14fb0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n_iters = 100\n",
    "n_train= 500\n",
    "m_dims = [1, 2, 5, 10, 20, 30, 40, 50]\n",
    "\n",
    "# One task per (replicate, mediator dimension) pair, replicate-major.\n",
    "tasks = [(i, m_dim) for i in range(n_iters) for m_dim in m_dims]\n",
    "\n",
    "# --- Run in parallel\n",
    "# A single progress bar over all tasks (instead of a new bar per replicate);\n",
    "# it advances as tasks are handed to joblib.\n",
    "results = Parallel(n_jobs=-1, backend=\"loky\")(\n",
    "    delayed(main_anti_causal_shifted)(n_train, i, m_dim=m_dim)\n",
    "    for i, m_dim in tqdm(tasks)\n",
    ")\n",
    "\n",
    "df = pd.concat(results, ignore_index = True)\n",
    "# The file name follows n_train so that changing it cannot mislabel the output.\n",
    "df.to_csv(f'sim_n_train{n_train}_anticausal_shift.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8c24a502",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a frame without the replicate index and sample size columns.\n",
    "agg = df.drop(['i', 'n'], axis = 1)\n",
    "\n",
    "# Quartiles of both error columns for every (model, m_dim) group; the\n",
    "# quantile level is moved from the row index into the columns.\n",
    "grouped_errors = agg.groupby(['model', 'm_dim'])[['mse_id', 'mse_ood']]\n",
    "quantiles = grouped_errors.quantile([0.25, 0.5, 0.75])\n",
    "quantiles = quantiles.unstack(level=2).reset_index()\n",
    "\n",
    "# Flatten the MultiIndex columns: <metric>_<q25|median|q75>\n",
    "quantiles.columns = ['model', 'm_dim'] + [\n",
    "    f'{metric}_{suffix}'\n",
    "    for metric in ('mse_id', 'mse_ood')\n",
    "    for suffix in ('q25', 'median', 'q75')\n",
    "]\n",
    "\n",
    "# Keep only the mediator dimensions that are plotted.\n",
    "plotted_dims = [1, 2, 5, 10, 20, 30, 40, 50]\n",
    "quantiles = quantiles[quantiles.m_dim.isin(plotted_dims)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f3692e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render all figure text through LaTeX, in bold.\n",
    "plt.rc('text', usetex=True)\n",
    "plt.rc(\"font\", weight='bold')\n",
    "plt.rc(\"axes\", labelweight='bold')\n",
    "# The preamble is assigned exactly once: rcParams keeps only the last value\n",
    "# written, so an earlier amsmath-only assignment had no effect.\n",
    "plt.rcParams['text.latex.preamble'] = r'\\usepackage{sfmath} \\boldmath'\n",
    "\n",
    "\n",
    "def add_break_marks(ax, where=\"bottom\", size=0.015, **linekw):\n",
    "    \"\"\"Draw the two diagonal ticks that mark a broken axis on one edge of ax.\n",
    "\n",
    "    One tick is centred on the left end of the chosen edge and one on the\n",
    "    right end; both are drawn in axes coordinates with clipping disabled.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax : matplotlib Axes\n",
    "        Axes that receives the marks.\n",
    "    where : {'bottom', 'top'}\n",
    "        Edge of ax on which the marks are centred.\n",
    "    size : float\n",
    "        Half-width of each mark as a fraction of the axes width.\n",
    "    **linekw\n",
    "        Extra keyword arguments forwarded to ax.plot (e.g. color, linewidth).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If where is neither 'bottom' nor 'top'.\n",
    "    \"\"\"\n",
    "    # Validate first: an unknown edge used to leave `y` undefined and fail\n",
    "    # later with an UnboundLocalError.\n",
    "    if where not in (\"bottom\", \"top\"):\n",
    "        raise ValueError(f\"where must be 'bottom' or 'top', got {where!r}\")\n",
    "\n",
    "    fig = ax.figure\n",
    "    fig.canvas.draw_idle()\n",
    "\n",
    "    bbox = ax.get_position()\n",
    "    dx = size\n",
    "    # Vertical half-height scaled by the axes box aspect (width / height),\n",
    "    # presumably so the tick keeps the same slant for any panel shape.\n",
    "    dy = size * (bbox.width / bbox.height)\n",
    "\n",
    "    if where == \"bottom\":\n",
    "        y = (-dy, +dy)\n",
    "    else:\n",
    "        y = (1 - dy, 1 + dy)\n",
    "\n",
    "    kw = dict(transform=ax.transAxes, clip_on=False)\n",
    "    kw.update(linekw)\n",
    "\n",
    "    ax.plot((-dx, +dx), y, **kw)\n",
    "    ax.plot((1 - dx, 1 + dx), y, **kw)\n",
    "\n",
    "\n",
    "models = ['TIPMI (Ours)', 'MBM (Ours)']\n",
    "model_colors = {'TIPMI (Ours)': '#882255', 'MBM (Ours)': '#88CCEE'}\n",
    "\n",
    "# Categorical x axis: one evenly spaced slot per mediator dimension, in\n",
    "# increasing order.\n",
    "m_dims = sorted(quantiles.m_dim.unique())\n",
    "x = list(range(len(m_dims)))\n",
    "\n",
    "# Two stacked panels sharing the x axis emulate a broken y axis: the same\n",
    "# curves are drawn in both and each panel shows a different y range.\n",
    "fig, (ax_upper, ax_lower) = plt.subplots(\n",
    "    2, 1, sharex=True,\n",
    "    gridspec_kw={'height_ratios': [1, 1]},\n",
    "    figsize=(6, 4)\n",
    ")\n",
    "\n",
    "for model in models:\n",
    "    color = model_colors[model]\n",
    "    # Sort explicitly so that entry k of every series belongs to m_dims[k]\n",
    "    # rather than relying on the row order of `quantiles`.\n",
    "    model_data = quantiles[quantiles.model == model].sort_values('m_dim')\n",
    "    median = model_data['mse_ood_median'].values\n",
    "    q25 = model_data['mse_ood_q25'].values\n",
    "    q75 = model_data['mse_ood_q75'].values\n",
    "\n",
    "    for ax in (ax_upper, ax_lower):\n",
    "        ax.plot(x, median, label=model, color=color, marker='o')\n",
    "        ax.fill_between(x, q25, q75, color=color, alpha=0.2)\n",
    "\n",
    "\n",
    "# y ranges of the two panels; values between 2.5 and 3.0 are not shown.\n",
    "ax_lower.set_ylim(0.8, 2.5)\n",
    "ax_upper.set_ylim(3.0, 30.0)\n",
    "\n",
    "# Hide the facing spines so the two panels read as one interrupted axis.\n",
    "ax_upper.spines['bottom'].set_visible(False)\n",
    "ax_lower.spines['top'].set_visible(False)\n",
    "ax_upper.tick_params(labeltop=False)\n",
    "for ax in (ax_upper, ax_lower):\n",
    "    ax.grid(True, which=\"both\", linestyle=\"--\", linewidth=0.5)\n",
    "\n",
    "plt.tight_layout()\n",
    "fig.canvas.draw_idle()\n",
    "add_break_marks(ax_upper, where=\"bottom\", size=0.015, color='k', linewidth=1.2)\n",
    "add_break_marks(ax_lower, where=\"top\",    size=0.015, color='k', linewidth=1.2)\n",
    "\n",
    "ax_lower.set_xlabel(r'\\textbf{Mediator dimension size}', fontsize=14)\n",
    "ax_upper.set_ylabel(r'\\textbf{Root Mean Squared Error (RMSE)}', fontsize=14)\n",
    "# Move the shared y label down to the seam between the two panels.\n",
    "ax_upper.yaxis.set_label_coords(-0.1, 0.0)\n",
    "ax_upper.set_title(r'\\textbf{RMSE vs. Mediator Dimension Size (Median with IQR Bands)}', fontsize=14)\n",
    "\n",
    "ax_lower.set_xticks(x)\n",
    "ax_lower.set_xticklabels(m_dims)\n",
    "\n",
    "for ax in (ax_lower, ax_upper):\n",
    "    ax.tick_params(axis='both', labelsize=14)\n",
    "\n",
    "ax_upper.legend(fontsize=14, ncol=3)\n",
    "\n",
    "# Save through the figure object instead of the pyplot current-figure state.\n",
    "fig.savefig('tipmi_sim.pdf', dpi=300, bbox_inches='tight', pad_inches=0)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
