{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab2164dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "from itertools import cycle\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb9fac32",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dfTotensor(df):\n",
    "    r\"\"\"\n",
    "    Functs: - given a DataFrame, convert it into torch Tensor\n",
    "    \"\"\"\n",
    "    return (torch.from_numpy(df.values)).float()\n",
    "\n",
    "def print_nnmodule(f):\n",
    "    r\"\"\"\n",
    "    Functs: - print weight and bias of a nn.module\n",
    "    \"\"\"\n",
    "    if isinstance(f.function, nn.Linear):\n",
    "        print(f.function.weight)\n",
    "        print(f.function.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fe4e906",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearF():\n",
    "    r\"\"\"\n",
    "    Functs: - nn.Linear(in_dim,out_dim)\n",
    "            - this is prepared for Y_hat = f_S_prime(S_prime, do M) and S_1=f_regen1(M,Y)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_dim, baseinit, out_dim=1):\n",
    "        super(LinearF, self).__init__()\n",
    "\n",
    "        self.in_dim = in_dim\n",
    "        self.function = nn.Linear(in_dim, out_dim)\n",
    "        self.trained = False\n",
    "\n",
    "        nn.init.constant_(self.function.weight, 1)\n",
    "        nn.init.constant_(self.function.bias, 0)\n",
    "\n",
    "    def fit(self, covariates, target, num_iters, lr):\n",
    "        r\"\"\"\n",
    "        Functs: - train the model in covariates~target\n",
    "        \"\"\"\n",
    "        covariates = dfTotensor(covariates)\n",
    "        target = dfTotensor(target)\n",
    "\n",
    "        optimizer = torch.optim.SGD(self.function.parameters(), lr=lr)\n",
    "        loss_func = nn.MSELoss()\n",
    "\n",
    "        for itera in range(num_iters + 1):\n",
    "            prediction = self.function(covariates)\n",
    "            \n",
    "            loss = loss_func(prediction, target)\n",
    "            if itera % (num_iters // 2) == 0:\n",
    "                print('iteration: {:d}, loss: {:.7f}'.format(int(itera), float(loss)))\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "        self.trained = True\n",
    "        for param in self.function.parameters():\n",
    "            param.requires_grad = False\n",
    "    \n",
    "    def predict(self, covariates):\n",
    "        r\"\"\"\n",
    "        Functs: - a simple version of predict\n",
    "                - accept torch.Tensor as input and return torch.Tensor\n",
    "        \"\"\"\n",
    "        assert self.trained, 'LinearF must be trained befored prediction'\n",
    "\n",
    "        return self.function(covariates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3bd1940",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleLinearF(nn.Module):\n",
    "    r\"\"\"\n",
    "    Functs: - a simple wrapper for nn.Linear()\n",
    "            - it does NOT have .fit() or .predict()\n",
    "            - just define FCs and a forward() function\n",
    "            - this is specically prepared for the estimation of pred_M = f_J_theta(PA_M)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, baseinit, in_dim=1, out_dim=1):\n",
    "        super(SimpleLinearF, self).__init__()\n",
    "\n",
    "        self.function = nn.Linear(in_dim, out_dim)\n",
    "        self.trained = False\n",
    "\n",
    "        if baseinit:\n",
    "            nn.init.constant_(self.function.weight, 1)\n",
    "            nn.init.constant_(self.function.bias, 0)\n",
    "\n",
    "    def forward(self, covariates):\n",
    "        return self.function(covariates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e851352",
   "metadata": {},
   "outputs": [],
   "source": [
    "class OptNodeSets():\n",
    "    \"\"\"\n",
    "    Under the simulation graph Fig-2\n",
    "\n",
    "    since we remove Y -> S_1, several things need to be modified:\n",
    "    - S_1 = f_regen1(M)\n",
    "    - in the estimation of f_S_prime, X_do_star becomes {M,S_1},\n",
    "      so we only need to shuffle X_do_star jointly; no regeneration is needed\n",
    "    - in the estimation of hstar, when regenerating S_1, Y is not needed anymore\n",
    "\n",
    "    Functs: - For a given training split, e.g. 12589\n",
    "                1. read from <base>/12589.csv\n",
    "                2. for a given S' in S_prime_all\n",
    "                    - Estimation of f_S'\n",
    "                        1. estimate S_1 = f_regen1(M)\n",
    "                        2. shuffle X_do* = {S_1,M} jointly, no regeneration\n",
    "                        3. train Y = f_S'(S',do(M)) on the shuffled samples\n",
    "                    - Estimation of h*(S')\n",
    "                        1. generate samples from P(J_theta)\n",
    "                             - replace M by J_theta(PA_M = Y)\n",
    "                             - regenerate S_1 by f_regen1(M)\n",
    "                        2. calculate Y_hat = f_S'(S',do(M)), where S',M~P(J_theta)\n",
    "                        3. maximize MSE(Y, Y_hat) over theta\n",
    "                3. for the given S' and the trained f_S'\n",
    "                    - read every test set <base>/12589/*.csv\n",
    "                    - Y_hat = f_S'(X_S',X_M)\n",
    "                    - compute maxMSEError = max over test sets of mean((Y-Y_hat)^2)\n",
    "\n",
    "            - we hope to, in this simulation, verify:\n",
    "                1.(least request) S*=S_all\n",
    "                2.(necessary) the estimated generalized error ~= the real worst case error on test-sets\n",
    "                              (that's what we failed to show in ADNI and IMPCgene, since in the real world\n",
    "                              the number of test sets is limited and we may not have the worst case test-sets)\n",
    "\n",
    "    NOTE: - input order to f_regen, f_S_prime, and f_J_theta must stay the same all the time\n",
    "          - we always use the order of S_i, M, Y\n",
    "    \"\"\"\n",
    "    DEFAULT_BASE = '/home/anonymous/data/CausallyInvariant_output/Simulation/FindOptSets/Example4'\n",
    "    NORM_VARS = ['S_1', 'S_2', 'S_3', 'M', 'Y']\n",
    "\n",
    "    def __init__(self, trainsplit, seed=1234, need_norm=False, base=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            trainsplit (str): split name; training data is <base>/<trainsplit>.csv and\n",
    "                              test sets are read from the folder <base>/<trainsplit>\n",
    "            seed (int): random_state of the joint (S_1, M) shuffle in estimate_f_S_prime\n",
    "            need_norm (bool): if True, standardize NORM_VARS with the training mean/std;\n",
    "                              the same statistics are reused on the test sets in test()\n",
    "            base (str or None): data folder, defaults to DEFAULT_BASE\n",
    "        \"\"\"\n",
    "        self.trainsplit = trainsplit\n",
    "        base = self.DEFAULT_BASE if base is None else base\n",
    "\n",
    "        trainfilename = os.path.join(base, '{}.csv'.format(trainsplit))\n",
    "        self.trainDF = pd.read_csv(trainfilename)\n",
    "        self.seed = seed\n",
    "\n",
    "        self.need_norm = need_norm\n",
    "        # var -> (mean, std) of the training split, kept so test() applies the same transform\n",
    "        self.norm_stats = dict()\n",
    "        if self.need_norm:\n",
    "            for var in self.NORM_VARS:\n",
    "                mean = self.trainDF[var].mean()\n",
    "                std = self.trainDF[var].std()\n",
    "                self.norm_stats[var] = (mean, std)\n",
    "                self.trainDF[var] = (self.trainDF[var] - mean) / std\n",
    "\n",
    "        self.testfolder = os.path.join(base, '{}'.format(trainsplit))\n",
    "\n",
    "        self.f_regen1 = LinearF(in_dim=1, baseinit=True)\n",
    "        self.estimate_f_regen1()\n",
    "\n",
    "    def estimate_f_regen1(self, ):\n",
    "        \"\"\"\n",
    "        Functs: - learn S_1 = self.f_regen1(M) on the training split\n",
    "                - we have checked the learned params are 1*M, estimation of f_regen1 is okay\n",
    "        \"\"\"\n",
    "        print('Estimating f_regen1 ...')\n",
    "\n",
    "        X = self.trainDF[['M']]\n",
    "        Y = self.trainDF[['S_1']]\n",
    "\n",
    "        self.f_regen1.fit(X, Y, num_iters=1000, lr=0.01)\n",
    "\n",
    "    def estimate_f_S_prime(self, ):\n",
    "        \"\"\"\n",
    "        Functs: - sample from p* by shuffling X_do*={S_1,M} jointly\n",
    "                - no regeneration is needed\n",
    "                - train Y=f_S_prime(S_prime,do X_M) in p*\n",
    "        \"\"\"\n",
    "        print('Estimating f_S_prime ...')\n",
    "\n",
    "        # shuffle S_1 and M with ONE shared permutation: each row keeps its (S_1, M) pair,\n",
    "        # so S_1 stays consistent with M, while the link to S_2, S_3 and Y is broken\n",
    "        shufTrainDF = self.trainDF.copy()\n",
    "        shufTrainDF[['S_1', 'M']] = shufTrainDF[['S_1', 'M']].sample(frac=1, random_state=self.seed).values\n",
    "\n",
    "        # train Y=f_S_prime(S',do M)\n",
    "        XX = shufTrainDF[self.S_prime + ['M']]\n",
    "        YY = shufTrainDF[['Y']]\n",
    "\n",
    "        self.f_S_prime.fit(XX, YY, num_iters=2000, lr=0.001)\n",
    "\n",
    "    def estimate_hstar_S_prime(self, num_iters, lr):\n",
    "        \"\"\"\n",
    "        Functs: - estimate h*(S') by adversarially optimizing f_J_theta\n",
    "                1. generate samples from P(J_theta)\n",
    "                    - replace M by f_J_theta(Y)\n",
    "                    - regenerate its descendant by S_1=f_regen1(M)\n",
    "                2. calculate Y_hat = f_S_prime(S', do M)\n",
    "                3. maximize MSE(Y, Y_hat) over theta by minimizing its negation\n",
    "\n",
    "        Args:\n",
    "            num_iters (int): number of iterations (num_iters + 1 SGD steps are taken)\n",
    "            lr (float): SGD learning rate for f_J_theta\n",
    "\n",
    "        Returns:\n",
    "            list of float: the (positive) MSE recorded every max(1, num_iters // 10) iterations\n",
    "        \"\"\"\n",
    "        print('Estimating hstar_S_prime ...')\n",
    "\n",
    "        # S_2, S_3 and Y are unchanged in P(J_theta), so extract them once\n",
    "        S_2 = dfTotensor(self.trainDF[['S_2']])\n",
    "        S_3 = dfTotensor(self.trainDF[['S_3']])\n",
    "        Y = dfTotensor(self.trainDF[['Y']])\n",
    "\n",
    "        optimizer = torch.optim.SGD(self.f_J_theta.parameters(), lr=lr)\n",
    "        loss_func = nn.MSELoss()\n",
    "        loss_log = list()\n",
    "        # max(1, ...) avoids ZeroDivisionError for small num_iters\n",
    "        log_every = max(1, num_iters // 10)\n",
    "        print_every = max(1, num_iters // 2)\n",
    "\n",
    "        for itera in range(num_iters + 1):\n",
    "            # replace M by f_J_theta(PA_M), where PA_M = Y\n",
    "            pred_M = self.f_J_theta(Y)\n",
    "            # regenerate S_1 from the new M\n",
    "            pred_S_1 = self.f_regen1.predict(pred_M)\n",
    "\n",
    "            # inputs of f_S_prime in the order S', M\n",
    "            S_all = {'S_1': pred_S_1, 'S_2': S_2, 'S_3': S_3}\n",
    "            S_prime_M = [S_all[S] for S in self.S_prime] + [pred_M]\n",
    "            pred_Y = self.f_S_prime.predict(torch.cat(S_prime_M, dim=1))\n",
    "\n",
    "            # we want to maximize the MSE, so minimize its negation\n",
    "            loss = - loss_func(Y, pred_Y)\n",
    "\n",
    "            if itera % log_every == 0:\n",
    "                loss_log.append(- float(loss.detach()))\n",
    "            if itera % print_every == 0:\n",
    "                print('iteration: {:d}, loss: {:.7f}'.format(int(itera), - float(loss)))\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "        self.f_J_theta.trained = True\n",
    "        for param in self.f_J_theta.function.parameters():\n",
    "            param.requires_grad = False\n",
    "\n",
    "        return loss_log\n",
    "\n",
    "    def estimate(self, S_prime, num_iters=2000, lr=0.05):\n",
    "        \"\"\"\n",
    "        Functs: - for a given S_prime (a list of str), estimate f_S_prime and then f_J_theta\n",
    "\n",
    "        Args:\n",
    "            S_prime (list of str): candidate node names; must not contain 'M'\n",
    "            num_iters (int): iterations for estimate_hstar_S_prime, default 2000\n",
    "            lr (float): learning rate for estimate_hstar_S_prime, default 0.05\n",
    "\n",
    "        Returns:\n",
    "            list of float: the MSE trajectory returned by estimate_hstar_S_prime\n",
    "        \"\"\"\n",
    "        assert 'M' not in S_prime, 'Mistake: S_prime must not contain any M'\n",
    "\n",
    "        print('Train_split: {}, S_prime: {}'.format(self.trainsplit, ','.join(S_prime)))\n",
    "        self.S_prime = S_prime\n",
    "\n",
    "        # Y = f_S_prime(S_prime,do M)\n",
    "        self.f_S_prime = LinearF(in_dim=len(S_prime) + 1, baseinit=True)\n",
    "        self.estimate_f_S_prime()\n",
    "\n",
    "        # optimize over one f_J_theta, SimpleLinearF default in_dim=1\n",
    "        self.f_J_theta = SimpleLinearF(baseinit=True)\n",
    "        mse_log = self.estimate_hstar_S_prime(num_iters=num_iters, lr=lr)\n",
    "\n",
    "        return mse_log\n",
    "\n",
    "    def test(self, S_prime):\n",
    "        \"\"\"\n",
    "        Functs: - after estimation on a given S_prime\n",
    "                - predict on every test-set csv with the trained f_S_prime\n",
    "                - return the worst-case (max) MSE over the test-sets\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the test folder contains no .csv file\n",
    "        \"\"\"\n",
    "        assert S_prime == self.S_prime\n",
    "        assert self.f_S_prime.trained\n",
    "        assert self.f_J_theta.trained\n",
    "\n",
    "        error_log = list()\n",
    "        # only csv files (skips e.g. .ipynb_checkpoints); sorted for a deterministic order\n",
    "        testfiles = sorted(name for name in os.listdir(self.testfolder) if name.endswith('.csv'))\n",
    "\n",
    "        for filename in testfiles:\n",
    "            testDF = pd.read_csv(os.path.join(self.testfolder, filename))\n",
    "            if self.need_norm:\n",
    "                # reuse the training statistics so inputs match what f_S_prime was fit on\n",
    "                for var in self.NORM_VARS:\n",
    "                    mean, std = self.norm_stats[var]\n",
    "                    testDF[var] = (testDF[var] - mean) / std\n",
    "\n",
    "            X_test = dfTotensor(testDF[self.S_prime + ['M']])\n",
    "            Y_test = dfTotensor(testDF[['Y']])\n",
    "\n",
    "            Y_pred = self.f_S_prime.predict(X_test)\n",
    "            error_log.append(torch.mean((Y_test - Y_pred) ** 2).item())\n",
    "\n",
    "        if not error_log:\n",
    "            raise ValueError('no test csv files found in {}'.format(self.testfolder))\n",
    "        # worst-case error, as documented (np.std of the errors was returned before)\n",
    "        return float(np.max(error_log))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9342f5a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "BASE = '/home/anonymous/data/CausallyInvariant_output/Simulation/FindOptSets/Example4'\n",
    "# training splits are the non-csv entries of BASE (each <split> folder holds that split's test sets);\n",
    "# sorted() because os.listdir order is arbitrary, which would make any sampling from this list irreproducible\n",
    "trainsplits = sorted(\n",
    "    filename for filename in os.listdir(BASE)\n",
    "    if 'csv' not in filename and 'ipynb_checkpoints' not in filename\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a9ca431",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reproducibly sample the training splits to run (a dedicated seeded RNG instead of the unseeded global one);\n",
    "# min() avoids ValueError when fewer than N_TRAINSPLITS splits exist\n",
    "N_TRAINSPLITS = 10\n",
    "SPLIT_SEED = 1234\n",
    "trainsplits = random.Random(SPLIT_SEED).sample(trainsplits, min(N_TRAINSPLITS, len(trainsplits)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9393124",
   "metadata": {},
   "outputs": [],
   "source": [
    "S_prime_all = [[],['S_1'],['S_2'],['S_3'],\n",
    "              ['S_1','S_2'],['S_1','S_3'],['S_2','S_3'],\n",
    "              ['S_1','S_2','S_3']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da241fcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# one slot per candidate S': the empty set is keyed 'empty', others by comma-joined node names\n",
    "recorder = {}\n",
    "for S_prime in S_prime_all:\n",
    "    save_name = 'empty' if not S_prime else ','.join(S_prime)\n",
    "    recorder[save_name] = {'h_stars': [], 'test_errors': []}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ac7d755",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# fit h*(S') and evaluate on the test sets for every candidate S' on each sampled split\n",
    "for trainsplit in trainsplits:\n",
    "    optnodeset = OptNodeSets(trainsplit=trainsplit)\n",
    "    for S_prime in S_prime_all:\n",
    "        save_name = 'empty' if not S_prime else ','.join(S_prime)\n",
    "        # estimate() must precede test(): test() asserts the models are trained\n",
    "        h_star = optnodeset.estimate(S_prime)\n",
    "        test_error = optnodeset.test(S_prime)\n",
    "\n",
    "        slot = recorder[save_name]\n",
    "        slot['h_stars'].append(np.asarray(h_star).reshape(-1, 1))\n",
    "        slot['test_errors'].append(test_error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47b4933d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab33fe3b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (Pytorch)",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
