{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.linear_model import Ridge, RidgeCV\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.metrics import accuracy_score, mean_squared_error\n",
    "from sklearn.base import clone\n",
    "from copy import deepcopy\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from pathlib import Path\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "# Apply seaborn's default theme to every figure drawn below.\n",
    "sns.set_theme()\n",
    "# Font size used by the plotting helpers for axis labels, tick labels and legends.\n",
    "FONTSIZE = 20\n",
    "# Maps result-column names to the LaTeX strings used as axis labels.\n",
    "METRIC_DICT = {'cf_effect':r'$TE$',\n",
    "                'cf_effect0':r'$TE_0$',\n",
    "                'cf_effect1':r'$TE_1$'}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x):\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "def cf_eval(y, y_cf, a):\n",
    "    a = a.squeeze()\n",
    "    mask1 = (a == 0)\n",
    "    mask2 = (a == 1)\n",
    "    \n",
    "    cf_effect = np.abs(y_cf - y)\n",
    "    o1 = cf_effect[mask1]\n",
    "    o2 = cf_effect[mask2]\n",
    "    return np.sum(cf_effect) / cf_effect.shape[0], np.sum(o1) / o1.shape[0], np.sum(o2) / o2.shape[0]\n",
    "\n",
    "def pcf_mix(y_score, ycf_score, a, is_cf=False):\n",
    "    # attribute corresponding to y\n",
    "    a_0_indices = a == 0\n",
    "    a_1_indices = a == 1\n",
    "    a_0_ratio = np.sum(a_0_indices) / len(a)\n",
    "    a_1_ratio = 1-a_0_ratio\n",
    "    if is_cf is True:\n",
    "        # we need to use the ratio in the real data\n",
    "        a_0_ratio, a_1_ratio = a_1_ratio, a_0_ratio\n",
    "\n",
    "    y_output = np.zeros_like(y_score.ravel())\n",
    "    y_output[a_0_indices] = y_score[a_0_indices] * a_0_ratio + ycf_score[a_0_indices] * a_1_ratio\n",
    "    y_output[a_1_indices] = y_score[a_1_indices] * a_1_ratio + ycf_score[a_1_indices] * a_0_ratio\n",
    "\n",
    "    return y_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cfe_classifier(data_dict, clf):\n",
    "    \"\"\"Fit ``clf`` on the ``u_hat`` features only and score it on the test split.\n",
    "\n",
    "    Returns ``(train_acc, test_acc, cf_effect, cf_effect0, cf_effect1, clf)``; the\n",
    "    three effect values compare predictions on ``u_hat`` against ``u_cf_hat``.\n",
    "    \"\"\"\n",
    "    train_split, test_split = data_dict['train'], data_dict['test']\n",
    "\n",
    "    # ========= Training ========= #\n",
    "    train_features = train_split['u_hat']\n",
    "    train_labels = train_split['y'].ravel()\n",
    "    clf.fit(train_features, train_labels)\n",
    "    train_acc = accuracy_score(train_labels, clf.predict(train_features))\n",
    "\n",
    "    # ========= Testing ========= #\n",
    "    pred_factual = clf.predict(test_split['u_hat'])\n",
    "    test_acc = accuracy_score(test_split['y'].ravel(), pred_factual.ravel())\n",
    "    pred_counter = clf.predict(test_split['u_cf_hat'])\n",
    "\n",
    "    effect_all, effect_a0, effect_a1 = cf_eval(pred_factual, pred_counter, test_split['a'])\n",
    "    return train_acc, test_acc, effect_all, effect_a0, effect_a1, clf\n",
    "\n",
    "def cfr_classifier(data_dict, clf):\n",
    "    \"\"\"Fit ``clf`` on ``u_hat`` joined with the average of an input and its counterfactual.\n",
    "\n",
    "    Returns ``(train_acc, test_acc, cf_effect, cf_effect0, cf_effect1, clf)``.\n",
    "    \"\"\"\n",
    "    def build_features(split, u_key, x_key, x_twin_key):\n",
    "        # Latent estimate next to the mean of an observation and its twin.\n",
    "        averaged = (split[x_key] + split[x_twin_key]) / 2\n",
    "        return np.concatenate([split[u_key], averaged], axis=1)\n",
    "\n",
    "    train_split = data_dict['train']\n",
    "    test_split = data_dict['test']\n",
    "\n",
    "    # ========= Training ========= #\n",
    "    train_features = build_features(train_split, 'u_hat', 'x', 'x_cf_uhat')\n",
    "    train_labels = train_split['y'].ravel()\n",
    "    clf.fit(train_features, train_labels)\n",
    "    train_acc = accuracy_score(train_labels, clf.predict(train_features))\n",
    "\n",
    "    # ========= Testing ========= #\n",
    "    pred_factual = clf.predict(build_features(test_split, 'u_hat', 'x', 'x_cf_uhat'))\n",
    "    test_acc = accuracy_score(test_split['y'].ravel(), pred_factual.ravel())\n",
    "\n",
    "    pred_counter = clf.predict(build_features(test_split, 'u_cf_hat', 'x_cf', 'x_cf_cf_uhat'))\n",
    "    effect_all, effect_a0, effect_a1 = cf_eval(pred_factual, pred_counter, test_split['a'])\n",
    "\n",
    "    return train_acc, test_acc, effect_all, effect_a0, effect_a1, clf\n",
    "\n",
    "def erm_classifier(data_dict, clf):\n",
    "    \"\"\"Fit ``clf`` on the columns ``[x, a]`` and score it on the test split.\n",
    "\n",
    "    Counterfactual predictions use ``[x_cf, a_cf]``. Returns\n",
    "    ``(train_acc, test_acc, cf_effect, cf_effect0, cf_effect1, clf)``.\n",
    "    \"\"\"\n",
    "    def stack(split, x_key, a_key):\n",
    "        # Feature columns followed by the attribute column.\n",
    "        return np.concatenate([split[x_key], split[a_key]], axis=1)\n",
    "\n",
    "    train_split = data_dict['train']\n",
    "    test_split = data_dict['test']\n",
    "\n",
    "    # ========= Training ========= #\n",
    "    train_features = stack(train_split, 'x', 'a')\n",
    "    train_labels = train_split['y'].ravel()\n",
    "    clf.fit(train_features, train_labels)\n",
    "    train_acc = accuracy_score(train_labels, clf.predict(train_features))\n",
    "\n",
    "    # ========= Testing ========= #\n",
    "    pred_factual = clf.predict(stack(test_split, 'x', 'a'))\n",
    "    test_acc = accuracy_score(test_split['y'].ravel(), pred_factual.ravel())\n",
    "\n",
    "    pred_counter = clf.predict(stack(test_split, 'x_cf', 'a_cf'))\n",
    "    effect_all, effect_a0, effect_a1 = cf_eval(pred_factual, pred_counter, test_split['a'])\n",
    "\n",
    "    return train_acc, test_acc, effect_all, effect_a0, effect_a1, clf\n",
    "\n",
    "def pcf_mix(y_score, ycf_score, a, is_cf=False):\n",
    "    # attribute corresponding to y\n",
    "    a_0_indices = a == 0\n",
    "    a_1_indices = a == 1\n",
    "    a_0_ratio = np.sum(a_0_indices) / len(a)\n",
    "    a_1_ratio = 1-a_0_ratio\n",
    "    if is_cf is True:\n",
    "        # we need to use the ratio in the real data\n",
    "        a_0_ratio, a_1_ratio = a_1_ratio, a_0_ratio\n",
    "\n",
    "    y_output = np.zeros_like(y_score)\n",
    "    y_output[a_0_indices] = y_score[a_0_indices] * a_0_ratio + ycf_score[a_0_indices] * a_1_ratio\n",
    "    y_output[a_1_indices] = y_score[a_1_indices] * a_1_ratio + ycf_score[a_1_indices] * a_0_ratio\n",
    "\n",
    "    return y_output\n",
    "\n",
    "def pcf_classifier(data_dict, clf):\n",
    "    \"\"\"Fit ``clf`` on ``[x, a]`` and predict from probabilities mixed by ``pcf_mix``.\n",
    "\n",
    "    The class with the largest mixed probability is the prediction. Returns\n",
    "    ``(train_acc, test_acc, cf_effect, cf_effect0, cf_effect1, clf)``.\n",
    "    \"\"\"\n",
    "    train_split = data_dict['train']\n",
    "    test_split = data_dict['test']\n",
    "\n",
    "    def proba(x_key, a_key):\n",
    "        # Class probabilities on test features joined with the given attribute column.\n",
    "        return clf.predict_proba(np.concatenate([test_split[x_key], test_split[a_key]], axis=1))\n",
    "\n",
    "    # ======= Training ======= #\n",
    "    train_features = np.concatenate([train_split['x'], train_split['a']], axis=1)\n",
    "    train_labels = train_split['y'].ravel()\n",
    "    clf.fit(train_features, train_labels)\n",
    "    train_acc = accuracy_score(train_labels, clf.predict(train_features))\n",
    "\n",
    "    # ======= Testing: factual prediction ======= #\n",
    "    factual_score = proba('x', 'a')\n",
    "    factual_twin_score = proba('x_cf_uhat', 'a_cf')\n",
    "    pred_factual = pcf_mix(factual_score, factual_twin_score, test_split['a'].ravel()).argmax(axis=1)\n",
    "    test_acc = accuracy_score(test_split['y'].ravel(), pred_factual.ravel())\n",
    "\n",
    "    # ======= Testing: counterfactual prediction ======= #\n",
    "    counter_score = proba('x_cf', 'a_cf')\n",
    "    counter_twin_score = proba('x_cf_cf_uhat', 'a')\n",
    "    pred_counter = pcf_mix(counter_score, counter_twin_score,\n",
    "                           test_split['a_cf'].ravel(), is_cf=True).argmax(axis=1)\n",
    "\n",
    "    effect_all, effect_a0, effect_a1 = cf_eval(pred_factual, pred_counter, test_split['a'])\n",
    "\n",
    "    return train_acc, test_acc, effect_all, effect_a0, effect_a1, clf\n",
    "\n",
    "def pcfaug_classifier(data_dict, clf):\n",
    "    \"\"\"Like ``pcf_classifier`` but trained on factual rows plus counterfactual rows.\n",
    "\n",
    "    The training set stacks ``[x, a]`` on top of ``[x_cf_uhat, a_cf]`` and repeats\n",
    "    the labels ``y`` for the second half. Returns\n",
    "    ``(train_acc, test_acc, cf_effect, cf_effect0, cf_effect1, clf)``.\n",
    "    \"\"\"\n",
    "    train_split = data_dict['train']\n",
    "    test_split = data_dict['test']\n",
    "\n",
    "    def proba(x_key, a_key):\n",
    "        # Class probabilities on test features joined with the given attribute column.\n",
    "        return clf.predict_proba(np.concatenate([test_split[x_key], test_split[a_key]], axis=1))\n",
    "\n",
    "    # ======= Training ======= #\n",
    "    factual_rows = np.concatenate([train_split['x'], train_split['a']], axis=1)\n",
    "    twin_rows = np.concatenate([train_split['x_cf_uhat'], train_split['a_cf']], axis=1)\n",
    "    train_features = np.concatenate([factual_rows, twin_rows], axis=0)\n",
    "    train_labels = np.concatenate([train_split['y'], train_split['y']], axis=0).ravel()\n",
    "    clf.fit(train_features, train_labels)\n",
    "    train_acc = accuracy_score(train_labels, clf.predict(train_features))\n",
    "\n",
    "    # ======= Testing: factual prediction ======= #\n",
    "    factual_score = proba('x', 'a')\n",
    "    factual_twin_score = proba('x_cf_uhat', 'a_cf')\n",
    "    pred_factual = pcf_mix(factual_score, factual_twin_score, test_split['a'].ravel()).argmax(axis=1)\n",
    "    test_acc = accuracy_score(test_split['y'].ravel(), pred_factual.ravel())\n",
    "\n",
    "    # ======= Testing: counterfactual prediction ======= #\n",
    "    counter_score = proba('x_cf', 'a_cf')\n",
    "    counter_twin_score = proba('x_cf_cf_uhat', 'a')\n",
    "    pred_counter = pcf_mix(counter_score, counter_twin_score,\n",
    "                           test_split['a_cf'].ravel(), is_cf=True).argmax(axis=1)\n",
    "\n",
    "    effect_all, effect_a0, effect_a1 = cf_eval(pred_factual, pred_counter, test_split['a'])\n",
    "\n",
    "    return train_acc, test_acc, effect_all, effect_a0, effect_a1, clf\n",
    "\n",
    "def erm_ana_classifer(data_dict, \n",
    "                        dataset_type,\n",
    "                        w_a,\n",
    "                        ):\n",
    "    \"\"\"Score the closed-form predictor on the test split; nothing is trained.\n",
    "\n",
    "    The score is ``sigmoid(2*x - w_a*a)`` for 'linear' and\n",
    "    ``sigmoid(x**3 + x - w_a*a)`` for 'cubic', thresholded at 0.5.\n",
    "\n",
    "    Args:\n",
    "        data_dict: dataset dict; only ``data_dict['test']`` is read.\n",
    "        dataset_type: 'linear' or 'cubic'.\n",
    "        w_a: weight of the attribute ``a`` in the score.\n",
    "\n",
    "    Returns:\n",
    "        ``(None, acc, cf_effect, cf_effect0, cf_effect1, None)``. The two ``None``\n",
    "        slots keep the tuple layout of the trained classifiers.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``dataset_type`` is neither 'linear' nor 'cubic'.\n",
    "    \"\"\"\n",
    "    test_dat = data_dict[\"test\"]\n",
    "\n",
    "    def hard_label(x, a):\n",
    "        # Closed-form score for the chosen dataset, thresholded to a 0/1 label.\n",
    "        if dataset_type == 'linear':\n",
    "            score = sigmoid(2 * x - w_a * a)\n",
    "        elif dataset_type == 'cubic':\n",
    "            score = sigmoid(x ** 3 + x - w_a * a)\n",
    "        else:\n",
    "            raise ValueError('dataset_type not recognized')\n",
    "        return (score > 0.5).astype(float)\n",
    "\n",
    "    # ======= factual pred ======= #\n",
    "    y_factual = hard_label(test_dat['x'], test_dat['a'])\n",
    "    # Report accuracy like every other method: this slot is plotted as\n",
    "    # 1 - test_err, so an RMSE here would not be comparable.\n",
    "    acc = accuracy_score(test_dat[\"y\"].ravel(), y_factual.ravel())\n",
    "\n",
    "    # ======= counter pred ======= #\n",
    "    y_counter = hard_label(test_dat['x_cf'], test_dat['a_cf'])\n",
    "    cf_effect, cf_effect0, cf_effect1 = cf_eval(y_factual, y_counter, a = test_dat[\"a\"])\n",
    "    return None, acc, cf_effect, cf_effect0, cf_effect1, None\n",
    "\n",
    "def pcf_ana_classifer(data_dict, \n",
    "                        dataset_type,\n",
    "                        w_a,\n",
    "                        ):\n",
    "    \"\"\"Closed-form predictor whose scores are mixed with ``pcf_mix`` before thresholding.\n",
    "\n",
    "    Uses ``sigmoid(2*x - w_a*a)`` for 'linear' and ``sigmoid(x**3 + x - w_a*a)`` for\n",
    "    'cubic'; raises ``ValueError`` for any other ``dataset_type``. Returns\n",
    "    ``(None, acc, cf_effect, cf_effect0, cf_effect1, None)``.\n",
    "    \"\"\"\n",
    "    test_split = data_dict['test']\n",
    "\n",
    "    def flat_score(x_key, a_key):\n",
    "        # Closed-form score on the named test columns, flattened to 1-D.\n",
    "        x_val, a_val = test_split[x_key], test_split[a_key]\n",
    "        if dataset_type == 'linear':\n",
    "            logit = 2 * x_val - w_a * a_val\n",
    "        elif dataset_type == 'cubic':\n",
    "            logit = x_val ** 3 + x_val - w_a * a_val\n",
    "        else:\n",
    "            raise ValueError('dataset_type not recognized')\n",
    "        return sigmoid(logit).ravel()\n",
    "\n",
    "    # ======= factual pred ======= #\n",
    "    factual_mix = pcf_mix(flat_score('x', 'a'),\n",
    "                          flat_score('x_cf_uhat', 'a_cf'),\n",
    "                          test_split['a'].ravel())\n",
    "    pred_factual = (factual_mix > 0.5).astype(float)\n",
    "    acc = accuracy_score(test_split['y'].ravel(), pred_factual.ravel())\n",
    "\n",
    "    # ======= counter pred ======= #\n",
    "    counter_mix = pcf_mix(flat_score('x_cf', 'a_cf'),\n",
    "                          flat_score('x_cf_cf_uhat', 'a'),\n",
    "                          test_split['a_cf'].ravel(), is_cf=True)\n",
    "    pred_counter = (counter_mix > 0.5).astype(float)\n",
    "    effect_all, effect_a0, effect_a1 = cf_eval(pred_factual, pred_counter, a=test_split['a'])\n",
    "\n",
    "    return None, acc, effect_all, effect_a0, effect_a1, None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_reg_dataset(dataset_type, \n",
    "                    w_epsx=0, \n",
    "                    w_epsy=1, \n",
    "                    w_a=2, \n",
    "                    a_freq=0.6,\n",
    "                    num_samples=2000,\n",
    "                    err_std=0,\n",
    "                    err_bias=0,\n",
    "                    seed=0):\n",
    "    \n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    u_distribution = torch.distributions.Normal(0,1)\n",
    "    a_distribution = torch.distributions.Bernoulli(torch.tensor([a_freq]))\n",
    "    epsx_distribution = torch.distributions.Normal(0,1)\n",
    "    epsy_distribution = torch.distributions.Normal(0,1)\n",
    "\n",
    "    dataset = {}\n",
    "    for split in ['train', 'test']:\n",
    "        dataset[split] = {}\n",
    "        u = u_distribution.sample((num_samples,1))\n",
    "        a = a_distribution.sample((num_samples,))\n",
    "        eps_x = epsx_distribution.sample((num_samples,1))\n",
    "        eps_y = epsy_distribution.sample((num_samples,1))\n",
    "        a_cf = 1-a\n",
    "\n",
    "        if dataset_type == 'linear':\n",
    "            x = w_a * a + u + w_epsx * eps_x\n",
    "            x_cf = w_a*a_cf + u + w_epsx * eps_x\n",
    "            y = torch.bernoulli(torch.sigmoid(x + u + w_epsy * eps_y))\n",
    "\n",
    "        elif dataset_type == 'cubic': \n",
    "            x = w_a * a + u + w_epsx * eps_x\n",
    "            x_cf = w_a*a_cf + u + w_epsx * eps_x\n",
    "            y =  torch.bernoulli(torch.sigmoid(x**3 + u + w_epsy * eps_y))\n",
    "        else:\n",
    "            raise ValueError('Invalid dataset type')\n",
    "        \n",
    "        dataset[split]['x'] = x\n",
    "        dataset[split]['y'] = y\n",
    "        dataset[split]['a'] = a\n",
    "        dataset[split]['u'] = u\n",
    "        dataset[split]['a_cf'] = a_cf\n",
    "        dataset[split]['x_cf'] = x_cf\n",
    "\n",
    "        # prepare data as algorithm input\n",
    "        dataset[split]['u_hat'] = u + torch.randn_like(u) * err_std + err_bias\n",
    "        dataset[split]['u_cf_hat'] = u + torch.randn_like(u) * err_std + err_bias\n",
    "        dataset[split]['x_cf_uhat'] = x_cf + torch.randn_like(x_cf) * err_std + err_bias\n",
    "        dataset[split]['x_cf_cf_uhat'] = x + torch.randn_like(x) * err_std + err_bias\n",
    "        \n",
    "    for split in ['train', 'test']:\n",
    "        for key in dataset[split].keys():\n",
    "            dataset[split][key] = dataset[split][key].numpy()\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exp 1 - GT Estimation Error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval(all_res,\n",
    "         dataset_type = 'linear',\n",
    "            clf_name = 'Ridge',\n",
    "            w_a = 2,\n",
    "            w_epsx = 0,\n",
    "            w_epsy = 1,\n",
    "            a_freq = 0.7,\n",
    "            err_std=0,\n",
    "            err_bias=0):\n",
    "    \"\"\"Benchmark every method on five seeded datasets and append one row per run.\n",
    "\n",
    "    The name shadows Python's built-in ``eval``; it is kept because the\n",
    "    experiment cells call the function under this name.\n",
    "\n",
    "    Args:\n",
    "        all_res: DataFrame of earlier results (may be empty, or None).\n",
    "        dataset_type: 'linear' or 'cubic', forwarded to ``gen_reg_dataset``.\n",
    "        clf_name: 'ridge', 'mlp', 'tree', 'knn' or 'svm' (case-insensitive).\n",
    "        w_a, w_epsx, w_epsy, a_freq, err_std, err_bias: forwarded to\n",
    "            ``gen_reg_dataset``; ``w_a`` is also passed to the analytic methods.\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame holding the rows of ``all_res`` followed by one row per\n",
    "        (repeat, method) pair, i.e. 5 x 7 new rows.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``clf_name`` is not one of the supported names.\n",
    "    \"\"\"\n",
    "    # Case-insensitive lookup so the default 'Ridge' selects the 'ridge' model.\n",
    "    clf_key = str(clf_name).lower()\n",
    "\n",
    "    def make_predictor(seed):\n",
    "        # Build a fresh, unfitted estimator for one (repeat, method) run.\n",
    "        if clf_key == 'ridge':\n",
    "            return LogisticRegression(penalty=\"l2\", solver=\"liblinear\",random_state=seed)\n",
    "        if clf_key == 'mlp':\n",
    "            return MLPClassifier(hidden_layer_sizes=(20,20),max_iter=2000,activation='tanh',random_state=seed)\n",
    "        if clf_key == 'tree':\n",
    "            # Must be a classifier: the PCF methods call predict_proba.\n",
    "            return DecisionTreeClassifier(random_state=seed)\n",
    "        if clf_key == 'knn':\n",
    "            return KNeighborsClassifier()\n",
    "        if clf_key == 'svm':\n",
    "            return SVC(probability=True)\n",
    "        raise ValueError('Invalid clf_name')\n",
    "\n",
    "    methods = [('cfr', cfr_classifier),\n",
    "               ('cfe', cfe_classifier),\n",
    "               ('erm', erm_classifier),\n",
    "               ('pcf', pcf_classifier),\n",
    "               ('pcfaug', pcfaug_classifier),\n",
    "               ('ermana', erm_ana_classifer),\n",
    "               ('pcfana', pcf_ana_classifer)]\n",
    "\n",
    "    records = []\n",
    "    for repeat in range(5):\n",
    "        np.random.seed(repeat)\n",
    "        dataset = gen_reg_dataset(dataset_type,\n",
    "            w_epsx=w_epsx,\n",
    "            w_epsy=w_epsy,\n",
    "            w_a=w_a,\n",
    "            a_freq=a_freq,\n",
    "            err_std=err_std,\n",
    "            err_bias=err_bias,\n",
    "            seed=repeat)\n",
    "\n",
    "        for method, classifier in methods:\n",
    "            predictor = make_predictor(repeat)\n",
    "            if method in ('ermana', 'pcfana'):\n",
    "                # Analytic methods need the data-generating parameters, not a model.\n",
    "                outcome = classifier(dataset, dataset_type=dataset_type, w_a=w_a)\n",
    "            else:\n",
    "                outcome = classifier(dataset, predictor)\n",
    "            train_err, test_err, cf_effect, cf_effect0, cf_effect1, _ = outcome\n",
    "\n",
    "            records.append({\n",
    "                'repeat': repeat,\n",
    "                'dataset_type': dataset_type,\n",
    "                'clf': clf_name,\n",
    "                'w_a': w_a,\n",
    "                'w_epsx': w_epsx,\n",
    "                'w_epsy': w_epsy,\n",
    "                'a_freq': a_freq,\n",
    "                'method': method,\n",
    "                'train_err': train_err,\n",
    "                'test_err': test_err,\n",
    "                'cf_effect': cf_effect,\n",
    "                'cf_effect0': cf_effect0,\n",
    "                'cf_effect1': cf_effect1,\n",
    "                'std': err_std,\n",
    "                'bias': err_bias,\n",
    "            })\n",
    "\n",
    "    # Build the new rows once: DataFrame.append was removed in pandas 2.0 and\n",
    "    # growing a frame row by row is quadratic.\n",
    "    new_rows = pd.DataFrame(records)\n",
    "    if all_res is None or len(all_res) == 0:\n",
    "        return new_rows\n",
    "    return pd.concat([all_res, new_rows], ignore_index=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def vis_alg(all_res, save_dir=None, legend=False, mute=()):\n",
    "    \"\"\"Scatter-plot mean error against mean counterfactual effect, one point per method.\n",
    "\n",
    "    Args:\n",
    "        all_res: result table with at least the columns 'method', 'test_err'\n",
    "            and 'cf_effect'; it is not modified.\n",
    "        save_dir: path prefix; the figure is written to\n",
    "            ``f'{save_dir}_cf_effect.png'`` (missing parent folders are created).\n",
    "            Nothing is saved when this is falsy.\n",
    "        legend: whether to draw the legend.\n",
    "        mute: method keys (e.g. 'pcfaug') to leave out of the plot.\n",
    "    \"\"\"\n",
    "    # Filter first, then copy, so the column assignments below work on an\n",
    "    # independent frame rather than on a slice of the caller's data.\n",
    "    shown = all_res[~all_res['method'].isin(list(mute or []))].copy()\n",
    "\n",
    "    display_names = {\n",
    "        'cfe': 'CFU',\n",
    "        'cfr': 'CFR',\n",
    "        'erm': 'ERM',\n",
    "        'pcf': 'PCF',\n",
    "        'pcfaug': 'PCFAug',\n",
    "        'ermana': 'ERM-Ana',\n",
    "        'pcfana': 'PCF-Ana',\n",
    "    }\n",
    "    shown['method'] = shown['method'].replace(display_names)\n",
    "    shown['error'] = 1-shown['test_err']\n",
    "    # numeric_only: text columns such as 'clf' cannot be averaged, and newer\n",
    "    # pandas raises instead of silently dropping them.\n",
    "    shown = shown.groupby(by='method').mean(numeric_only=True).reset_index()\n",
    "\n",
    "    for col in ['cf_effect']:\n",
    "        fig, ax = plt.subplots(figsize=(6,6))\n",
    "        # One marker style and colour per method.\n",
    "        sns.scatterplot(data=shown, x=col, y='error', style='method', hue='method', s=200, ax=ax,legend=legend)\n",
    "        ax.set_xlabel(METRIC_DICT[col],fontsize=FONTSIZE)\n",
    "        ax.set_ylabel('Error',fontsize=FONTSIZE)\n",
    "        if legend:\n",
    "            ax.legend(fontsize=FONTSIZE,markerscale=2, bbox_to_anchor=(0.95, 1), loc='upper left')\n",
    "        plt.xticks(fontsize=FONTSIZE, rotation=30)\n",
    "        plt.yticks(fontsize=FONTSIZE)\n",
    "        if save_dir:\n",
    "            out_path = Path(f'{save_dir}_{col}.png')\n",
    "            # savefig does not create folders; make sure the target exists.\n",
    "            out_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "            plt.savefig(out_path, bbox_inches='tight',dpi=200)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf_name = 'knn'\n",
    "\n",
    "# Same experiment on both synthetic datasets; only the second figure carries the legend.\n",
    "for dataset_type, show_legend in [('linear', False), ('cubic', True)]:\n",
    "    res = eval(pd.DataFrame(),\n",
    "               dataset_type=dataset_type,\n",
    "               clf_name=clf_name,\n",
    "               w_a=2,\n",
    "               w_epsx=0)\n",
    "\n",
    "    vis_alg(res,\n",
    "            save_dir=f'./figures/synthetic_cls/gt_{dataset_type}_{clf_name}',\n",
    "            mute=['ermana','pcfana','pcfaug'],\n",
    "            legend=show_legend)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf_name = 'mlp'\n",
    "\n",
    "# Same experiment on both synthetic datasets; only the second figure carries the legend.\n",
    "for dataset_type, show_legend in [('linear', False), ('cubic', True)]:\n",
    "    res = eval(pd.DataFrame(),\n",
    "               dataset_type=dataset_type,\n",
    "               clf_name=clf_name,\n",
    "               w_a=2,\n",
    "               w_epsx=0)\n",
    "\n",
    "    vis_alg(res,\n",
    "            save_dir=f'./figures/synthetic_cls/gt_{dataset_type}_{clf_name}',\n",
    "            mute=['ermana','pcfana','pcfaug'],\n",
    "            legend=show_legend)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exp 1.2 - CF Estimation Error  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def vis_alg(all_res, save_dir=None, legend=False, mute=()):\n",
    "    \"\"\"Scatter-plot mean error against mean counterfactual effect per (method, std).\n",
    "\n",
    "    Redefines the earlier ``vis_alg``: results are averaged per method and noise\n",
    "    level 'std', the marker style encodes the method and the colour encodes 'std'.\n",
    "\n",
    "    Args:\n",
    "        all_res: result table with at least the columns 'method', 'std',\n",
    "            'test_err' and 'cf_effect'; it is not modified.\n",
    "        save_dir: path prefix; the figure is written to\n",
    "            ``f'{save_dir}_cf_effect.png'`` (missing parent folders are created).\n",
    "            Nothing is saved when this is falsy.\n",
    "        legend: whether to draw the legend.\n",
    "        mute: method keys (e.g. 'pcfaug') to leave out of the plot.\n",
    "    \"\"\"\n",
    "    # Filter first, then copy, so the column assignments below work on an\n",
    "    # independent frame rather than on a slice of the caller's data.\n",
    "    shown = all_res[~all_res['method'].isin(list(mute or []))].copy()\n",
    "\n",
    "    display_names = {\n",
    "        'cfe': 'CFU',\n",
    "        'cfr': 'CFR',\n",
    "        'erm': 'ERM',\n",
    "        'pcf': 'PCF',\n",
    "        'pcfaug': 'PCFAug',\n",
    "        'ermana': 'ERM-Ana',\n",
    "        'pcfana': 'PCF-Ana',\n",
    "    }\n",
    "    shown['method'] = shown['method'].replace(display_names)\n",
    "    shown['error'] = 1-shown['test_err']\n",
    "    # numeric_only: text columns such as 'clf' cannot be averaged, and newer\n",
    "    # pandas raises instead of silently dropping them.\n",
    "    shown = shown.groupby(by=['method','std']).mean(numeric_only=True).reset_index()\n",
    "\n",
    "    for col in ['cf_effect']:\n",
    "        fig, ax = plt.subplots(figsize=(6,6))\n",
    "        # Marker style distinguishes methods, colour distinguishes noise levels.\n",
    "        sns.scatterplot(data=shown, x=col, y='error', style='method', hue='std', s=200, ax=ax, palette='deep',legend=legend)\n",
    "        ax.set_xlabel(METRIC_DICT[col],fontsize=FONTSIZE)\n",
    "        ax.set_ylabel('Error',fontsize=FONTSIZE)\n",
    "        if legend:\n",
    "            ax.legend(fontsize=FONTSIZE,markerscale=2, bbox_to_anchor=(0.95, 1), loc='upper left')\n",
    "        plt.xticks(fontsize=FONTSIZE, rotation=30)\n",
    "        plt.yticks(fontsize=FONTSIZE)\n",
    "        if save_dir:\n",
    "            out_path = Path(f'{save_dir}_{col}.png')\n",
    "            # savefig does not create folders; make sure the target exists.\n",
    "            out_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "            plt.savefig(out_path, bbox_inches='tight',dpi=200)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### EST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf_name = 'knn'\n",
    "bias = 0\n",
    "\n",
    "# Sweep the estimation-noise level on each dataset; the legend goes on the second figure only.\n",
    "for dataset_type, show_legend in [('linear', False), ('cubic', True)]:\n",
    "    res = pd.DataFrame()\n",
    "    for err_std in [0,0.001,0.01,0.1]:\n",
    "        res = eval(res,\n",
    "                   dataset_type=dataset_type,\n",
    "                   clf_name=clf_name,\n",
    "                   w_a=2,\n",
    "                   w_epsx=0,\n",
    "                   err_std=err_std,\n",
    "                   err_bias=bias)\n",
    "\n",
    "    vis_alg(res,\n",
    "            save_dir=f'./figures/synthetic_cls/eststd_b{bias}_{dataset_type}_{clf_name}',\n",
    "            mute=['ermana','pcfaug','pcfana'],\n",
    "            legend=show_legend)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf_name = 'knn'\n",
    "bias = 0.001\n",
    "\n",
    "# Sweep the estimation-noise level on each dataset; the legend goes on the second figure only.\n",
    "for dataset_type, show_legend in [('linear', False), ('cubic', True)]:\n",
    "    res = pd.DataFrame()\n",
    "    for err_std in [0,0.001,0.01,0.1]:\n",
    "        res = eval(res,\n",
    "                   dataset_type=dataset_type,\n",
    "                   clf_name=clf_name,\n",
    "                   w_a=2,\n",
    "                   w_epsx=0,\n",
    "                   err_std=err_std,\n",
    "                   err_bias=bias)\n",
    "\n",
    "    vis_alg(res,\n",
    "            save_dir=f'./figures/synthetic_cls/eststd_b{bias}_{dataset_type}_{clf_name}',\n",
    "            mute=['ermana','pcfaug','pcfana'],\n",
    "            legend=show_legend)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analytic Solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf_name = 'knn'\n",
    "\n",
    "# Noise sweep without bias; only the 'pcf' and 'pcfana' methods are left unmuted.\n",
    "for dataset_type, show_legend in [('linear', False), ('cubic', True)]:\n",
    "    res = pd.DataFrame()\n",
    "    for err_std in [0,0.001,0.01,0.1]:\n",
    "        res = eval(res,\n",
    "                   dataset_type=dataset_type,\n",
    "                   clf_name=clf_name,\n",
    "                   w_a=2,\n",
    "                   w_epsx=0,\n",
    "                   err_std=err_std,\n",
    "                   err_bias=0)\n",
    "\n",
    "    vis_alg(res,\n",
    "            save_dir=f'./figures/synthetic_cls/eststdana_{dataset_type}_{clf_name}',\n",
    "            mute=['ermana','erm','cfr','cfe','pcfaug'],\n",
    "            legend=show_legend)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf_name = 'knn'\n",
    "bias = 0.001\n",
    "\n",
    "# Noise sweep with a constant bias; only the 'pcf' and 'pcfana' methods are left unmuted.\n",
    "for dataset_type, show_legend in [('linear', False), ('cubic', True)]:\n",
    "    res = pd.DataFrame()\n",
    "    for err_std in [0,0.001,0.01,0.1]:\n",
    "        res = eval(res,\n",
    "                   dataset_type=dataset_type,\n",
    "                   clf_name=clf_name,\n",
    "                   w_a=2,\n",
    "                   w_epsx=0,\n",
    "                   err_std=err_std,\n",
    "                   err_bias=bias)\n",
    "\n",
    "    vis_alg(res,\n",
    "            save_dir=f'./figures/synthetic_cls/eststdana_b{bias}_{dataset_type}_{clf_name}',\n",
    "            mute=['ermana','erm','cfr','cfe','pcfaug'],\n",
    "            legend=show_legend)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "causal",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
