{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as data_utils\n",
    "from pathlib import Path\n",
    "from sklearn.metrics import r2_score\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "import os\n",
    "import numpy as np\n",
    "import logging.handlers\n",
    "import random\n",
    "import sys\n",
    "import sklearn\n",
    "\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import statsmodels.api as sm\n",
    "import torch\n",
    "import torch.utils.data as utils\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.linear_model import Ridge, RidgeCV\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "from sklearn.metrics import accuracy_score, mean_squared_error\n",
    "\n",
    "import numpy as np\n",
    "import statsmodels.api as sm\n",
    "from statsmodels.stats.outliers_influence import variance_inflation_factor\n",
    "\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from sklearn.base import BaseEstimator, RegressorMixin\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "import yaml\n",
    "import argparse\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "sys.path.append('../')\n",
    "from VAE.model import CVAE, DCEVAE\n",
    "from dataset import SimLaw\n",
    "\n",
    "sns.set_theme()\n",
    "from copy import deepcopy\n",
    "from utils import erm_classifier, pcf_classifier, cfe_classifier, cfr_classifier, pcfaug_classifier\n",
    "from utils import infer_u, gen_x, pcf_mix, cf_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args:\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    u_dim = 3\n",
    "    r_dim = 8 \n",
    "    d_dim = 2\n",
    "    act_fn = 'Tanh'\n",
    "    use_label = False\n",
    "    dataset = 'law'\n",
    "args = Args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "FONTSIZE = 20\n",
    "METRIC_DICT = {'cf_effect':r'$TE$',\n",
    "                'cf_effect0':r'$TE_0$',\n",
    "                'cf_effect1':r'$TE_1$'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cvae_prepare_data(datasets, model):\n",
    "\n",
    "    data_dict = {}\n",
    "    model.eval()\n",
    "\n",
    "    for split in ['train', 'test']:\n",
    "        cur_data_dict = {}\n",
    "        dataset = datasets[split]\n",
    "\n",
    "        r,d,a,y,u = dataset.r, dataset.d, dataset.a, dataset.y, dataset.u\n",
    "        x = torch.cat([r,d], dim=1)\n",
    "        x_cf, a_cf, y_cf = dataset.x_cf, dataset.a_cf, dataset.y_cf \n",
    "        \n",
    "\n",
    "        x = x.to(DEVICE)\n",
    "        r = r.to(DEVICE)\n",
    "        d = d.to(DEVICE)\n",
    "        u = u.to(DEVICE)\n",
    "        y = y.to(DEVICE)\n",
    "        a = a.to(DEVICE)\n",
    "        x_cf = x_cf.to(DEVICE)\n",
    "        y_cf = y_cf.to(DEVICE)\n",
    "        a_cf = a_cf.to(DEVICE)\n",
    "\n",
    "        cur_data_dict['x'] = x\n",
    "        cur_data_dict['y'] = y \n",
    "        cur_data_dict['a'] = a\n",
    "        cur_data_dict['u'] = u\n",
    "        cur_data_dict['x_cf'] = x_cf    \n",
    "        cur_data_dict['y_cf'] = y_cf\n",
    "        cur_data_dict['a_cf'] = a_cf\n",
    "\n",
    "        u_hat = infer_u(model, r, d, a)\n",
    "        u_cf_hat = infer_u(model, x_cf[:,:3], x_cf[:,3:], a_cf)\n",
    "        x_cf_uhat = gen_x(model, u_hat, a_cf)\n",
    "        x_cf_cf_uhat = gen_x(model, u_cf_hat, a)\n",
    "\n",
    "\n",
    "        cur_data_dict['u_hat'] = u_hat\n",
    "        cur_data_dict['u_cf_hat'] = u_cf_hat\n",
    "        cur_data_dict['x_cf_uhat'] = x_cf_uhat\n",
    "        cur_data_dict['x_cf_cf_uhat'] = x_cf_cf_uhat\n",
    "        # if split == 'train':\n",
    "        #     cur_data_dict['y_cf_uhat']= y\n",
    "            \n",
    "        for key, data in cur_data_dict.items():\n",
    "            cur_data_dict[key] = data.detach().cpu().numpy()\n",
    "            \n",
    "        data_dict[split] = cur_data_dict\n",
    "        \n",
    "\n",
    "    return data_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ermsoft_classifier(data_dict, clf, alpha=1):\n",
    "    train_dat = data_dict[\"train\"]\n",
    "    test_dat = data_dict[\"test\"]\n",
    "    \n",
    "    inputs = np.concatenate([train_dat[\"x\"],\n",
    "                             train_dat['a']],axis=1)\n",
    "    y = train_dat[\"y\"].ravel()\n",
    "\n",
    "    clf.fit(inputs, y)\n",
    "    train_acc = mean_squared_error(y, clf.predict(inputs),squared=False)\n",
    "\n",
    "\n",
    "\n",
    "    y_factual = clf.predict(np.concatenate([test_dat[\"x\"],\n",
    "                                            test_dat['a']],axis=1))\n",
    "    \n",
    "    acc = mean_squared_error(test_dat[\"y\"].ravel(), y_factual.ravel(),squared=False)\n",
    "\n",
    "    y_counter = clf.predict(np.concatenate([test_dat[\"x_cf\"],\n",
    "                                            test_dat['a_cf']],axis=1))\n",
    "    a = test_dat[\"a\"]\n",
    "    cf_effect, cf_effect0, cf_effect1 = cf_eval(y_factual, y_counter, a)\n",
    "\n",
    "    return train_acc, acc, cf_effect, cf_effect0, cf_effect1, clf\n",
    "\n",
    "def cfesoft_classifier(data_dict, clf, alpha=1):\n",
    "\n",
    "    train_dat = data_dict[\"train\"]\n",
    "    test_dat = data_dict[\"test\"]\n",
    "\n",
    "    clf_erm = deepcopy(clf)\n",
    "    clf = deepcopy(clf)\n",
    "    clf_erm.fit(np.concatenate([train_dat[\"x\"],\n",
    "                             train_dat['a']],axis=1), \n",
    "                train_dat[\"y\"].ravel())\n",
    "    \n",
    "    inputs =  train_dat[\"u_hat\"]\n",
    "    y = train_dat[\"y\"].ravel()\n",
    "    clf.fit(inputs, y)\n",
    "    train_acc = mean_squared_error(y, clf.predict(inputs),squared=False)\n",
    "\n",
    "\n",
    "    y_factual_cfe = clf.predict(test_dat[\"u_hat\"])\n",
    "    y_factual_erm = clf_erm.predict(np.concatenate([test_dat[\"x\"],\n",
    "                                            test_dat['a']],axis=1))\n",
    "    y_factual = alpha * y_factual_cfe + (1-alpha) * y_factual_erm\n",
    "    acc = mean_squared_error(test_dat[\"y\"].ravel(), y_factual.ravel(),squared=False)\n",
    "\n",
    "\n",
    "    y_counter_cfe = clf.predict(test_dat[\"u_cf_hat\"])\n",
    "    y_counter_erm = clf_erm.predict(np.concatenate([test_dat[\"x_cf\"],\n",
    "                                            test_dat['a_cf']],axis=1))\n",
    "    y_counter = alpha * y_counter_cfe + (1-alpha) * y_counter_erm\n",
    "    a = test_dat[\"a\"]\n",
    "    cf_effect, cf_effect0, cf_effect1 = cf_eval(y_factual, y_counter, a)\n",
    "\n",
    "    return train_acc, acc, cf_effect, cf_effect0, cf_effect1, clf\n",
    "\n",
    "def cfrsoft_classifier(data_dict, clf, alpha=1):\n",
    "\n",
    "    train_dat = data_dict[\"train\"]\n",
    "    test_dat = data_dict[\"test\"]\n",
    "\n",
    "    clf_erm = deepcopy(clf)\n",
    "    clf = deepcopy(clf)\n",
    "    clf_erm.fit(np.concatenate([train_dat[\"x\"],\n",
    "                             train_dat['a']],axis=1), \n",
    "                train_dat[\"y\"].ravel())\n",
    "    \n",
    "    inputs = np.concatenate([train_dat[\"u_hat\"], \n",
    "                        (train_dat[\"x\"] + train_dat[\"x_cf_uhat\"]) / 2], axis=1)\n",
    "    y = train_dat[\"y\"].ravel()\n",
    "    clf.fit(inputs, y)\n",
    "    train_acc = mean_squared_error(y, clf.predict(inputs),squared=False)\n",
    "\n",
    "    y_factual_cfr = clf.predict(np.concatenate([\n",
    "        test_dat[\"u_hat\"],\n",
    "        (test_dat[\"x\"] + test_dat[\"x_cf_uhat\"]) / 2\n",
    "    ], axis=1))\n",
    "    y_factual_erm = clf_erm.predict(np.concatenate([test_dat[\"x\"],\n",
    "                                            test_dat['a']],axis=1))\n",
    "    y_factual = alpha * y_factual_cfr + (1-alpha) * y_factual_erm\n",
    "    acc = mean_squared_error(test_dat[\"y\"].ravel(), y_factual.ravel(),squared=False)\n",
    "\n",
    "    y_counter_cfr = clf.predict(np.concatenate([\n",
    "        test_dat[\"u_cf_hat\"],\n",
    "        (test_dat[\"x_cf\"] + test_dat[\"x_cf_cf_uhat\"]) / 2\n",
    "    ], axis=1))\n",
    "    a = test_dat[\"a\"]\n",
    "    y_counter_erm = clf_erm.predict(np.concatenate([test_dat[\"x_cf\"],\n",
    "                                            test_dat['a_cf']],axis=1))\n",
    "    y_counter = alpha * y_counter_cfr + (1-alpha) * y_counter_erm\n",
    "    cf_effect, cf_effect0, cf_effect1 = cf_eval(y_factual, y_counter, a)\n",
    "\n",
    "    return train_acc, acc, cf_effect, cf_effect0, cf_effect1, clf\n",
    "\n",
    "\n",
    "def pcfsoft_classifier(data_dict, clf, alpha=1):\n",
    "    train_dat = data_dict[\"train\"]\n",
    "    test_dat = data_dict[\"test\"]\n",
    "\n",
    "    clf_erm = deepcopy(clf)\n",
    "    clf = deepcopy(clf)\n",
    "    clf_erm.fit(np.concatenate([train_dat[\"x\"],\n",
    "                             train_dat['a']],axis=1), \n",
    "                train_dat[\"y\"].ravel())\n",
    "    \n",
    "    inputs = np.concatenate([train_dat[\"x\"],\n",
    "                             train_dat['a']],axis=1)\n",
    "    y = train_dat[\"y\"].ravel()\n",
    "    clf.fit(inputs, y)\n",
    "    train_acc = mean_squared_error(y, clf.predict(inputs),squared=False)\n",
    "\n",
    "    # ======= factual pred ======= #\n",
    "    y_factual_score = clf.predict(np.concatenate([test_dat[\"x\"],\n",
    "                                                test_dat['a']],axis=1))\n",
    "    y_factual_cf_score = clf.predict(np.concatenate([test_dat[\"x_cf_uhat\"],\n",
    "                                                     test_dat['a_cf']],axis=1))\n",
    "    \n",
    "    y_factual_pcf = pcf_mix(y_factual_score, y_factual_cf_score, test_dat['a'].ravel())\n",
    "\n",
    "    y_factual_erm = clf_erm.predict(np.concatenate([test_dat[\"x\"],\n",
    "                                            test_dat['a']],axis=1))\n",
    "    y_factual = alpha * y_factual_pcf + (1-alpha) * y_factual_erm\n",
    "    acc = mean_squared_error(test_dat[\"y\"].ravel(), y_factual.ravel(), squared=False)\n",
    "\n",
    "    # ======= counter pred ======= #\n",
    "    y_counter_score = clf.predict(np.concatenate([test_dat[\"x_cf\"],\n",
    "                                                        test_dat['a_cf']],axis=1))\n",
    "    y_counter_cf_score = clf.predict(np.concatenate([test_dat[\"x_cf_cf_uhat\"],\n",
    "                                                        test_dat['a']],axis=1))\n",
    "    y_counter_pcf = pcf_mix(y_counter_score, y_counter_cf_score, test_dat['a_cf'].ravel(),is_cf=True)\n",
    "    y_counter_erm = clf_erm.predict(np.concatenate([test_dat[\"x_cf\"],\n",
    "                                            test_dat['a_cf']],axis=1))\n",
    "    y_counter = alpha * y_counter_pcf + (1-alpha) * y_counter_erm\n",
    "\n",
    "    a = test_dat[\"a\"]\n",
    "    cf_effect, cf_effect0, cf_effect1 = cf_eval(y_factual, y_counter, a)\n",
    "\n",
    "    return train_acc, acc, cf_effect, cf_effect0, cf_effect1, clf\n",
    "\n",
    "def pcfaugsoft_classifier(data_dict, clf, alpha=1):\n",
    "    train_dat = data_dict[\"train\"]\n",
    "    test_dat = data_dict[\"test\"]\n",
    "    \n",
    "    clf_erm = deepcopy(clf)\n",
    "    clf = deepcopy(clf)\n",
    "    clf_erm.fit(np.concatenate([train_dat[\"x\"],\n",
    "                             train_dat['a']],axis=1), \n",
    "                train_dat[\"y\"].ravel())\n",
    "\n",
    "    inputs = np.concatenate([\n",
    "    np.concatenate([train_dat[\"x\"],\n",
    "                             train_dat['a']],axis=1),\n",
    "    np.concatenate([train_dat[\"x_cf_uhat\"],\n",
    "                             train_dat['a_cf']],axis=1)],axis=0)\n",
    "    y = np.concatenate([train_dat[\"y\"],train_dat[\"y\"]],axis=0).ravel()\n",
    "    clf.fit(inputs, y)\n",
    "\n",
    "    train_acc = mean_squared_error(y, clf.predict(inputs),squared=False)\n",
    "\n",
    "    # ======= factual pred ======= #\n",
    "    y_factual_score = clf.predict(np.concatenate([test_dat[\"x\"],\n",
    "                                                test_dat['a']],axis=1))\n",
    "    y_factual_cf_score = clf.predict(np.concatenate([test_dat[\"x_cf_uhat\"],\n",
    "                                                     test_dat['a_cf']],axis=1))\n",
    "    \n",
    "    y_factual_pcf = pcf_mix(y_factual_score, y_factual_cf_score, test_dat['a'].ravel())\n",
    "    \n",
    "    y_factual_erm = clf_erm.predict(np.concatenate([test_dat[\"x\"],\n",
    "                                            test_dat['a']],axis=1))\n",
    "    y_factual = alpha * y_factual_pcf + (1-alpha) * y_factual_erm\n",
    "    acc = mean_squared_error(test_dat[\"y\"].ravel(), y_factual.ravel(), squared=False)\n",
    "\n",
    "    # ======= counter pred ======= #\n",
    "    y_counter_score = clf.predict(np.concatenate([test_dat[\"x_cf\"],\n",
    "                                                        test_dat['a_cf']],axis=1))\n",
    "    y_counter_cf_score = clf.predict(np.concatenate([test_dat[\"x_cf_cf_uhat\"],\n",
    "                                                        test_dat['a']],axis=1))\n",
    "    y_counter_pcf = pcf_mix(y_counter_score, y_counter_cf_score, test_dat['a_cf'].ravel(),is_cf=True)\n",
    "    y_counter_erm = clf_erm.predict(np.concatenate([test_dat[\"x_cf\"],\n",
    "                                            test_dat['a_cf']],axis=1))\n",
    "    y_counter = alpha * y_counter_pcf + (1-alpha) * y_counter_erm\n",
    "\n",
    "    a = test_dat[\"a\"]\n",
    "    cf_effect, cf_effect0, cf_effect1 = cf_eval(y_factual, y_counter, a)\n",
    "\n",
    "    return train_acc, acc, cf_effect, cf_effect0, cf_effect1, clf\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_ild(global_dir,\n",
    "             all_res,\n",
    "             regressor_type='ridge'\n",
    "             ):\n",
    "\n",
    "    seed_list = [1,2,3,4,5]\n",
    "    for seed in seed_list:\n",
    "\n",
    "        args.seed = seed\n",
    "        train_set = SimLaw(seed=1, split='train', root=global_dir)\n",
    "        test_set = SimLaw(seed=1, split='test', root=global_dir)\n",
    "        datasets = {'train':train_set, 'test':test_set}\n",
    "\n",
    "        model_dir = global_dir / f'law/est/cvae/a_r_1.0_a_d_1.0_a_y_1.0_a_f_0.0_u_3_run_{seed}_use_label_True/model.pth' \n",
    "        model = CVAE(r_dim=args.r_dim,\n",
    "                        d_dim=args.d_dim, \n",
    "                        sens_dim=1, \n",
    "                        label_dim=1, \n",
    "                        args=args).to(args.device)\n",
    "        model.load_state_dict(torch.load(model_dir))\n",
    "\n",
    "        data_dict = cvae_prepare_data(datasets, model)\n",
    "        for lamb in [0, 0.2, 0.4, 0.6, 0.8, 1]:\n",
    "            for method, classifier in zip(['cfr', 'cfe','pcf', 'pcfaug','erm'],\n",
    "                                        [cfrsoft_classifier, cfesoft_classifier, pcfsoft_classifier, pcfaugsoft_classifier, ermsoft_classifier]):\n",
    "                if regressor_type == 'linear':\n",
    "                    predictor = LinearRegression()\n",
    "                elif regressor_type == 'ridge':\n",
    "                    predictor = RidgeCV(alphas=[0.1,1,10,100,1000,10000])\n",
    "                elif regressor_type == 'mlp':\n",
    "                    predictor = MLPRegressor(hidden_layer_sizes=(5,5),max_iter=2000,activation='tanh')\n",
    "                elif regressor_type == 'tree':\n",
    "                    predictor = DecisionTreeRegressor()\n",
    "                elif regressor_type == 'knn':\n",
    "                    predictor = KNeighborsRegressor(n_neighbors=10)\n",
    "                else:\n",
    "                    raise ValueError(\"Invalid regressor type\")\n",
    "                train_acc, acc, cf_effect, cf_effect0, cf_effect1, clf = classifier(data_dict, predictor, alpha=lamb)\n",
    "                res = dict()\n",
    "                res['seed'] = seed\n",
    "                res['train_error'] = train_acc\n",
    "                res['test_error'] = acc\n",
    "                res['method'] = method\n",
    "                res['clf'] = clf\n",
    "                res['cf_effect'] = cf_effect\n",
    "                res['cf_effect0'] = cf_effect0\n",
    "                res['cf_effect1'] = cf_effect1\n",
    "                res['lamb'] = lamb\n",
    "                all_res = all_res.append(res, ignore_index=True)\n",
    "\n",
    "    return all_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def vis_ild_alg(all_res, save_dir=None,mute=[]):\n",
    "# Split the 'method' column to separate 'method' and 'group'\n",
    "    \n",
    "    all_res = all_res.copy()\n",
    "\n",
    "    #display(all_res)\n",
    "    if len(mute)>0:\n",
    "        for mm in mute:\n",
    "            all_res = all_res[all_res['method']!=mm]\n",
    "    #display(all_res)\n",
    "    replace_dict = {}\n",
    "    replace_dict['cda'] = 'CDA'\n",
    "    replace_dict['cfe'] = 'CFU'\n",
    "    replace_dict['cfr'] = 'CFR'\n",
    "    replace_dict['erm'] = 'ERM'\n",
    "    replace_dict['ermaug'] = 'ERM-A'\n",
    "    replace_dict['npcf'] = 'nPCF'\n",
    "    replace_dict['pcf'] = 'PCF'\n",
    "    replace_dict['pcfy'] = 'PCFy'\n",
    "    replace_dict['pcfycf'] = 'PCFycf'\n",
    "    replace_dict['pcfaug'] = 'PCF-CRM'\n",
    "    for alpha in [0,0.2,0.4,0.6,0.8,1]:\n",
    "        replace_dict[alpha] = f'PCFaug-{alpha}'\n",
    "\n",
    "    all_res['method'] = all_res['method'].replace(replace_dict)\n",
    "    #all_res['test_error'] = all_res['acc']\n",
    "    all_res = all_res.groupby(by=['method','lamb']).mean().reset_index()\n",
    "    #all_res = all_res.sort_values(by=['method', 'lamb'])\n",
    "    lamb_order = sorted(all_res['lamb'].unique())\n",
    "    print(lamb_order)\n",
    "\n",
    "        \n",
    "    all_res['style'] = all_res['method']\n",
    "\n",
    "    # Define the plot size\n",
    "    if save_dir:\n",
    "        save_dir = Path(save_dir)\n",
    "\n",
    "    for col in ['cf_effect', 'cf_effect0', 'cf_effect1']:\n",
    "        fig, ax = plt.subplots(figsize=(10,6))\n",
    "        # Create the scatter plot with unique styles\n",
    "        sns.scatterplot(data=all_res, x=col, y='test_error', style='method', hue='lamb', s=200, ax=ax)\n",
    "\n",
    "        for method in all_res['method'].unique():\n",
    "            method_data = all_res[all_res['method'] == method]\n",
    "            #sns.lineplot(data=method_data, x=col, y='test_error', ax=ax, label=method, legend=False)\n",
    "            ax.plot(method_data[col], method_data['test_error'], label='_nolegend_', linestyle=':',alpha=0.5)\n",
    "        \n",
    "        #ax.set_title(col)\n",
    "        ax.set_xlabel(METRIC_DICT[col],fontsize=FONTSIZE)\n",
    "        ax.set_ylabel('RMSE',fontsize=FONTSIZE)\n",
    "        ax.legend(fontsize=FONTSIZE,markerscale=2, bbox_to_anchor=(1, 1), loc='upper left')\n",
    "        plt.xticks(fontsize=FONTSIZE, rotation=30)\n",
    "        plt.yticks(fontsize=FONTSIZE)\n",
    "\n",
    "        #plt.tight_layout()\n",
    "        # plt.show()\n",
    "        if save_dir:\n",
    "            plt.savefig(f'{save_dir}_{col}.png', bbox_inches='tight',dpi=200)\n",
    "            plt.show()\n",
    "        else:\n",
    "            plt.show()\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "all_res = pd.DataFrame()\n",
    "\n",
    "# ILD\n",
    "clf_name = 'mlp'\n",
    "global_dir = Path('../VAE/saved/final')\n",
    "all_res = eval_ild(global_dir,\n",
    "                 all_res,\n",
    "                 regressor_type=clf_name)\n",
    "vis_ild_alg(all_res, save_dir=f'../figures/law/estaug_{clf_name}',mute=['pcf'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "vis_ild_alg(all_res, save_dir=f'../figures/law/estaug_pcfcrm_{clf_name}',\n",
    "            mute=['cfr','cfe','erm'])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "causal",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
