{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "from search.FCMBased.lingam import CAMUV\n",
    "from sklearn import linear_model\n",
    "from sklearn.model_selection import train_test_split\n",
    "from utils_PartDAG import sample_intervention_right, MLP, LR, IPS, Model_Multi\n",
    "from causallearn.search.ConstraintBased.PC import pc\n",
    "from demo_SCORE import *\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "def factorial_sum(n):\n",
    "    return 1 if n < 2 else n + factorial_sum(n - 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def BFS_node(L, adj):\n",
    "    visit = []\n",
    "    while len(L) > 0:\n",
    "        i = L[0]\n",
    "        del L[0]    \n",
    "        if i not in visit:\n",
    "            visit.append(i)\n",
    "            for j in np.nonzero(adj[i, :])[0]:\n",
    "                if j not in visit and j not in L:\n",
    "                    L.append(j)\n",
    "    return visit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_graph(Graph):\n",
    "    Graph_new_fair_relax = np.zeros_like(Graph)\n",
    "    Graph_new_fair = np.zeros_like(Graph)\n",
    "    Graph_new_ours = np.zeros_like(Graph)\n",
    "    for i in range(Graph.shape[0]):\n",
    "        for j in range(i + 1, Graph.shape[1]):\n",
    "            if Graph[i, j] * Graph[j, i] > 0 and Graph[i, j] < 0: # undirected edge\n",
    "                Graph_new_fair[i, j] = 1\n",
    "                Graph_new_fair[j, i] = 1\n",
    "                Graph_new_ours[i, j] = 2\n",
    "                Graph_new_ours[j, i] = 2\n",
    "            elif Graph[i, j] * Graph[j, i] < 0 and Graph[i, j] < 0: # i -> j\n",
    "                Graph_new_fair_relax[i, j] = 1\n",
    "                Graph_new_fair[i, j] = 1\n",
    "                Graph_new_ours[i, j] = 1\n",
    "            elif Graph[i, j] * Graph[j, i] > 0 and Graph[i, j] > 0: # double side direction\n",
    "                Graph_new_fair_relax[i, j] = 1\n",
    "                Graph_new_fair[i, j] = 1             \n",
    "                Graph_new_ours[i, j] = 1\n",
    "                Graph_new_fair_relax[j, i] = 1\n",
    "                Graph_new_fair[j, i] = 1             \n",
    "                Graph_new_ours[j, i] = 1                      \n",
    "            elif Graph[i, j] * Graph[j, i] < 0 and Graph[i, j] > 0: # j -> i\n",
    "                Graph_new_fair_relax[j, i] = 1\n",
    "                Graph_new_fair[j, i] = 1 \n",
    "                Graph_new_ours[j, i] = 1\n",
    "    return Graph_new_fair, Graph_new_fair_relax, Graph_new_ours"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1.5\n",
    "\n",
    "Oracle_RMSE = [[], [], [], []]\n",
    "\n",
    "Unaware_RMSE = [[], [], [], []]\n",
    "Unaware_fair = [[], [], [], []]   \n",
    "Full_RMSE = [[], [], [], []]\n",
    "Full_fair = [[], [], [], []]\n",
    "\n",
    "Fair_RMSE = [[], [], [], []]\n",
    "Fair_fair = [[], [], [], []]\n",
    "Fair_relax_RMSE = [[], [], [], []]\n",
    "Fair_relax_fair = [[], [], [], []]\n",
    "\n",
    "Ours_RMSE = [[], [], [], []]\n",
    "Ours_fair = [[], [], [], []]\n",
    "for d in [10, 20, 30, 40]:\n",
    "    if d == 20:\n",
    "        print(d)\n",
    "        print('Ours', np.mean(Ours_RMSE[0]), np.std(Ours_RMSE[0]), \n",
    "              np.mean(Ours_fair[0]), np.std(Ours_fair[0]))\n",
    "        print('Oracle', np.mean(Oracle_RMSE[0]), np.std(Oracle_RMSE[0]), 0.0, 0.0)\n",
    "        print('Unaware', np.mean(Unaware_RMSE[0]), np.std(Unaware_RMSE[0]), np.mean(Unaware_fair[0]), np.std(Unaware_fair[0]))\n",
    "        print('Full', np.mean(Full_RMSE[0]), np.std(Full_RMSE[0]), np.mean(Full_fair[0]), np.std(Full_fair[0]))\n",
    "        print('Fair', np.mean(Fair_RMSE[0]), np.std(Fair_RMSE[0]), np.mean(Fair_fair[0]), np.std(Fair_fair[0]))\n",
    "        print('Fair_relax', np.mean(Fair_relax_RMSE[0]), np.std(Fair_relax_RMSE[0]), np.mean(Fair_relax_fair[0]), np.std(Fair_relax_fair[0]))\n",
    "        \n",
    "    if d == 30:\n",
    "        print(d)\n",
    "        print('Ours', np.mean(Ours_RMSE[1]), np.std(Ours_RMSE[1]), \n",
    "              np.mean(Ours_fair[1]), np.std(Ours_fair[1]))        \n",
    "        print('Oracle', np.mean(Oracle_RMSE[1]), np.std(Oracle_RMSE[1]), 0.0, 0.0)\n",
    "        print('Unaware', np.mean(Unaware_RMSE[1]), np.std(Unaware_RMSE[1]), np.mean(Unaware_fair[1]), np.std(Unaware_fair[1]))\n",
    "        print('Full', np.mean(Full_RMSE[1]), np.std(Full_RMSE[1]), np.mean(Full_fair[1]), np.std(Full_fair[1]))\n",
    "        print('Fair', np.mean(Fair_RMSE[1]), np.std(Fair_RMSE[1]), np.mean(Fair_fair[1]), np.std(Fair_fair[1]))\n",
    "        print('Fair_relax', np.mean(Fair_relax_RMSE[1]), np.std(Fair_relax_RMSE[1]), np.mean(Fair_relax_fair[1]), np.std(Fair_relax_fair[1]))\n",
    "        \n",
    "    if d == 40:\n",
    "        print(d)\n",
    "        print('Ours', np.mean(Ours_RMSE[2]), np.std(Ours_RMSE[2]), \n",
    "              np.mean(Ours_fair[2]), np.std(Ours_fair[2]))\n",
    "        print('Oracle', np.mean(Oracle_RMSE[2]), np.std(Oracle_RMSE[2]), 0.0, 0.0)\n",
    "        print('Unaware', np.mean(Unaware_RMSE[2]), np.std(Unaware_RMSE[2]), np.mean(Unaware_fair[2]), np.std(Unaware_fair[2]))\n",
    "        print('Full', np.mean(Full_RMSE[2]), np.std(Full_RMSE[2]), np.mean(Full_fair[2]), np.std(Full_fair[2]))\n",
    "        print('Fair', np.mean(Fair_RMSE[2]), np.std(Fair_RMSE[2]), np.mean(Fair_fair[2]), np.std(Fair_fair[2]))\n",
    "        print('Fair_relax', np.mean(Fair_relax_RMSE[2]), np.std(Fair_relax_RMSE[2]), np.mean(Fair_relax_fair[2]), np.std(Fair_relax_fair[2]))         \n",
    "    node_names = list(np.arange(d))        \n",
    "\n",
    "    \n",
    "    for num in range(50):\n",
    "        X, adj, noise, beta, Y, node = generate_right(d, 2*d, 1000, noise_std = 1.5, GP=True)\n",
    "        X = np.array(X)\n",
    "        cg = pc(X, alpha = 0.05, node_names = node_names)\n",
    "        \n",
    "        L = list(np.nonzero(adj[node, :])[0])\n",
    "        visit = BFS_node(L, adj)\n",
    "        visit.append(node)\n",
    "        \n",
    "        inter_value = (1 - X[:, node])\n",
    "        Inter_X = sample_intervention_right(d = d, n =1000, noise = noise, beta = beta, inter_value = inter_value, node = node, Y = Y, adj = adj)\n",
    "\n",
    "        oracle_node = np.array(list(set(np.arange(d)) - set(visit) - set({Y}) - set({node})))\n",
    "        full_node = np.array(list(set(np.arange(d)) - set({Y})))\n",
    "        unaware_node = np.array(list(set(np.arange(d)) - set({node}) - set({Y})))\n",
    "\n",
    "        Inter_X = np.array(Inter_X)\n",
    "        total_X = np.c_[X, Inter_X]\n",
    "        node_Y = X[:, Y]\n",
    "        X_train, X_test, Y_train, Y_test = train_test_split(total_X, node_Y, test_size = 0.2, random_state = 0)\n",
    "        X_train = X_train[:, :X.shape[1]]\n",
    "        X_test_inter = X_test[:, X.shape[1]:].copy()\n",
    "        X_test = X_test[:, :X.shape[1]]\n",
    "\n",
    "        X_train_del = np.delete(X_train, [node, Y], 1)\n",
    "        X_test_del = np.delete(X_test, [node, Y], 1)\n",
    "        X_test_inter_del = np.delete(X_test_inter, [node, Y], 1)\n",
    "\n",
    "        if len(oracle_node) == 0:\n",
    "            test_pred_oracle = 0\n",
    "        else:\n",
    "            oracle_X_train = X_train[:, oracle_node]\n",
    "            oracle_X_test = X_test[:, oracle_node]\n",
    "            lm_oracle = linear_model.LinearRegression()\n",
    "            lm_oracle.fit(oracle_X_train, Y_train)\n",
    "            test_pred_oracle  = lm_oracle.predict(oracle_X_test)\n",
    "\n",
    "        Oracle_RMSE[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_oracle-Y_test)**2)))        \n",
    "        \n",
    "        full_X_train = X_train[:, full_node]\n",
    "        full_X_test = X_test[:, full_node]\n",
    "        lm_full = linear_model.LinearRegression()\n",
    "        lm_full.fit(full_X_train, Y_train)\n",
    "        test_pred_full  = lm_full.predict(full_X_test)\n",
    "\n",
    "        Full_RMSE[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_full-Y_test)**2)))\n",
    "\n",
    "        full_X_train_inter = X_test_inter[:, full_node]\n",
    "        test_pred_full_inter = lm_full.predict(full_X_train_inter)\n",
    "        Full_fair[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_full-test_pred_full_inter)**2)))        \n",
    "            \n",
    "        unaware_X_train = X_train[:, unaware_node]\n",
    "        unaware_X_test = X_test[:, unaware_node]\n",
    "        lm_unaware = linear_model.LinearRegression()\n",
    "        lm_unaware.fit(unaware_X_train, Y_train)\n",
    "        test_pred_unaware  = lm_unaware.predict(unaware_X_test)\n",
    "\n",
    "        Unaware_RMSE[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_unaware-Y_test)**2)))\n",
    "\n",
    "        unaware_X_train_inter = X_test_inter[:, unaware_node]\n",
    "        test_pred_unaware_inter = lm_unaware.predict(unaware_X_train_inter) \n",
    "        Unaware_fair[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_unaware-test_pred_unaware_inter)**2)))             \n",
    "            \n",
    "        visit_fair_relax = []\n",
    "        visit_fair = []\n",
    "        Graph = cg.G.graph\n",
    "        Graph_new_fair, Graph_new_fair_relax, Graph_new_ours = generate_graph(Graph)\n",
    "\n",
    "        List_2 = list(np.nonzero(Graph_new_ours[:, node] == 2)[0]) # possible parent node\n",
    "        List_1 = list(np.nonzero(Graph_new_ours[:, node] == 1)[0]) # exist parent node\n",
    "        print('exist parent node', List_1, 'possible parent node', List_2)\n",
    "        possible_backdoor = []\n",
    "        if List_1 == []:\n",
    "            possible_backdoor = List_2.copy()\n",
    "        elif List_2 == []:\n",
    "            possible_backdoor = List_1\n",
    "        else:\n",
    "            for i in List_2:\n",
    "                for j in range(len(List_1)):\n",
    "                    if i in np.nonzero(Graph_new_ours[List_1[j], :])[0] or i in np.nonzero(Graph_new_ours[:, List_1[j]])[0]:\n",
    "                        if j == len(List_1) - 1:\n",
    "                            possible_backdoor.append(i)\n",
    "                    else:\n",
    "                        break\n",
    "\n",
    "        if len(possible_backdoor) > 1 and possible_backdoor != List_1:\n",
    "            N = len(possible_backdoor)\n",
    "            ALL_possible = []\n",
    "            for i in range(2 ** N):\n",
    "                combo = []\n",
    "                for j in range(N):\n",
    "                    if(i>>j)%2:\n",
    "                        combo.append(possible_backdoor[j])\n",
    "                ALL_possible.append(combo)\n",
    "\n",
    "            for i in ALL_possible:\n",
    "                if len(i) > 1:\n",
    "                    temp = Graph_new_ours[:, i]\n",
    "                    temp = temp[i, :]\n",
    "                    if (len(np.nonzero(np.triu(temp))[0]) == factorial_sum(len(i)) - len(i)) or (len(np.nonzero(np.tril(temp))[0]) == factorial_sum(len(i)) - len(i)):\n",
    "                        possible_backdoor.append(i)                 \n",
    "\n",
    "        if possible_backdoor == List_1 and List_1 != []:\n",
    "            possible_backdoor = [possible_backdoor]\n",
    "        else:\n",
    "            for i in range(len(possible_backdoor)):\n",
    "                if type(possible_backdoor[i]) is not list:\n",
    "                    possible_backdoor[i] = [possible_backdoor[i]]\n",
    "                    \n",
    "        if List_1 != [] and possible_backdoor != [List_1]:\n",
    "            for i in possible_backdoor:\n",
    "                i.extend(List_1)\n",
    "            possible_backdoor.insert(0, List_1)\n",
    "\n",
    "        print('possible set', possible_backdoor) # possible set\n",
    "        print('real set', np.nonzero(adj[:, node])) # real set\n",
    "\n",
    "        fair_relax = list(np.nonzero(Graph_new_fair_relax[node, :])[0])    \n",
    "        fair = list(np.nonzero(Graph_new_fair[node, :])[0])\n",
    "\n",
    "        visit_fair_relax = BFS_node(fair_relax, Graph_new_fair_relax)\n",
    "        visit_fair_relax.append(node)\n",
    "        visit_fair = BFS_node(fair, Graph_new_fair)\n",
    "        visit_fair.append(node)\n",
    "\n",
    "        fair_node = np.array(list(set(np.arange(d)) - set(visit_fair) - set({Y}) - set({node})))\n",
    "        fair_relax_node = np.array(list(set(np.arange(d)) - set(visit_fair_relax) - set({Y}) - set({node})))\n",
    "\n",
    "        propensity = np.zeros(X_train.shape[0])\n",
    "\n",
    "        if possible_backdoor == []: \n",
    "            propensity[X_train[:, node] == 0] = np.mean(X_train[:, node] == 0)\n",
    "            propensity[X_train[:, node] == 1] = np.mean(X_train[:, node] == 1)   \n",
    "        else:    \n",
    "            j = 0\n",
    "            for i in possible_backdoor:\n",
    "                prop_estimate_model = MLP(X_train[:, i].reshape(X_train.shape[0], -1), X_train[:, node])\n",
    "                prop_estimate_model.cuda()\n",
    "                prop_estimate_model.fit(lamb = 1e-4)\n",
    "                pred = prop_estimate_model.predict(X_train[:, i].reshape(X_train.shape[0], -1))\n",
    "                if j == 0:\n",
    "                    j += 1\n",
    "                    propensity = pred\n",
    "                else:\n",
    "                    j += 1\n",
    "                    propensity = np.c_[propensity, pred]\n",
    "\n",
    "        propensity = np.clip(propensity.reshape(X_train.shape[0], -1), 0.2, 0.8)\n",
    "\n",
    "        if len(fair_node) == 0:\n",
    "            test_pred_fair = 0\n",
    "        else:\n",
    "            fair_X_train = X_train[:, fair_node]\n",
    "            fair_X_test = X_test[:, fair_node]\n",
    "            lm_fair = linear_model.LinearRegression()\n",
    "            lm_fair.fit(fair_X_train, Y_train)\n",
    "            test_pred_fair  = lm_fair.predict(fair_X_test)\n",
    "\n",
    "        Fair_RMSE[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_fair-Y_test)**2)))\n",
    "\n",
    "        if len(fair_relax_node) == 0:\n",
    "            test_pred_fair_relax = 0\n",
    "        else:\n",
    "            fair_relax_X_train = X_train[:, fair_relax_node]\n",
    "            fair_relax_X_test = X_test[:, fair_relax_node]\n",
    "            lm_fair_relax = linear_model.LinearRegression()\n",
    "            lm_fair_relax.fit(fair_relax_X_train, Y_train)\n",
    "            test_pred_fair_relax  = lm_fair_relax.predict(fair_relax_X_test)\n",
    "\n",
    "        Fair_relax_RMSE[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_fair_relax-Y_test)**2)))            \n",
    "\n",
    "        if len(fair_node) == 0:\n",
    "            test_pred_fair_inter = 0\n",
    "            Fair_fair[int(d/10 - 1)].append(0)\n",
    "        else:\n",
    "            fair_X_train_inter = X_test_inter[:, fair_node]\n",
    "            test_pred_fair_inter = lm_fair.predict(fair_X_train_inter)\n",
    "\n",
    "            Fair_fair[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_fair-test_pred_fair_inter)**2)))\n",
    "\n",
    "        if len(fair_relax_node) == 0:\n",
    "            test_pred_fair_relax_inter = 0\n",
    "            Fair_relax_fair[int(d/10 - 1)].append(0)\n",
    "        else:\n",
    "            fair_relax_X_train_inter = X_test_inter[:, fair_relax_node]\n",
    "            test_pred_fair_relax_inter = lm_fair_relax.predict(fair_relax_X_train_inter)\n",
    "     \n",
    "            Fair_relax_fair[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_fair_relax-test_pred_fair_relax_inter)**2)))            \n",
    "\n",
    "            model = Model_Multi(X_train_del, Y_train, X_train[:, node], propensity, output_dim = propensity.shape[1])\n",
    "            model.cuda()\n",
    "            model.fit(constrain = 0.2, alpha = 15, C = 0.03, lr = 0.01)\n",
    "\n",
    "            test_pred_ours  = model.LR.predict(X_test_del)\n",
    "\n",
    "            Ours_RMSE[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_ours-Y_test)**2)))\n",
    "\n",
    "            test_pred_ours_inter = model.LR.predict(X_test_inter_del)\n",
    "\n",
    "            Ours_fair[int(d/10 - 1)].append(np.sqrt(np.mean((test_pred_ours-test_pred_ours_inter)**2)))    \n",
    "print(d)\n",
    "print('Ours', np.mean(Ours_RMSE[3]), np.std(Ours_RMSE[3]), \n",
    "      np.mean(Ours_fair[3]), np.std(Ours_fair[3]))\n",
    "print('Oracle', np.mean(Oracle_RMSE[3]), np.std(Oracle_RMSE[3]), 0.0, 0.0)\n",
    "print('Unaware', np.mean(Unaware_RMSE[3]), np.std(Unaware_RMSE[3]), np.mean(Unaware_fair[3]), np.std(Unaware_fair[3]))\n",
    "print('Full', np.mean(Full_RMSE[3]), np.std(Full_RMSE[3]), np.mean(Full_fair[3]), np.std(Full_fair[3]))\n",
    "print('Fair', np.mean(Fair_RMSE[3]), np.std(Fair_RMSE[3]), np.mean(Fair_fair[3]), np.std(Fair_fair[3]))\n",
    "print('Fair_relax', np.mean(Fair_relax_RMSE[3]), np.std(Fair_relax_RMSE[3]), np.mean(Fair_relax_fair[3]), np.std(Fair_relax_fair[3]))                        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:graph]",
   "language": "python",
   "name": "conda-env-graph-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
