{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d:\\Microsoft VS Code\\PyCodes\\RA_Fairness\\fair_dummies/data/\n"
     ]
    }
   ],
   "source": [
    "import sys, os\n",
    "sys.path.append(os.path.abspath(os.path.join('../..')))\n",
    "cur_path = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))\n",
    "sys.path.append(cur_path)\n",
    "base_path = cur_path + '/data/'\n",
    "print(base_path)\n",
    "import numpy as np\n",
    "import get_dataset\n",
    "import importlib\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch import distributions\n",
    "from torch.nn.parameter import Parameter\n",
    "import torch.utils.data as data_utils\n",
    "from collections import namedtuple\n",
    "import functools\n",
    "import random\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "os.environ['R_HOME'] = r\"D:\\R\\R-4.3.0\" # \"/vast/palmer/apps/avx2/software/R/4.3.0-foss-2020b/lib64/R\"\n",
    "os.environ['R_USER'] = r\"D:\\anaconda3\\Lib\\site-packages\\rpy2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# importlib.reload(utility_functions)\n",
    "# importlib.reload(fair_dummies_learning)\n",
    "\n",
    "\n",
    "from rpy2.robjects.packages import importr\n",
    "KPC = importr('KPC')\n",
    "kernlab = importr('kernlab')\n",
    "import rpy2.robjects\n",
    "from rpy2.robjects import FloatVector\n",
    "\n",
    "# specified_density_tr = utility_functions.MAF_density_estimation(Y, A, Y, A)\n",
    "# specified_density_te = utility_functions.MAF_density_estimation(np.concatenate((Y_cal, Y_test)), np.concatenate((A_cal, A_test)), Y_test, A_test) \n",
    "# print(\"density estimation done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'d:\\\\Microsoft VS Code\\\\PyCodes\\\\RA_Fairness\\\\fair_dummies\\\\continuous-fairness-master'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.path.abspath('../..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from facl.independence.density_estimation.pytorch_kde import kde\n",
    "from facl.independence.hgr import chi_2_cond, hgr_cond\n",
    "\n",
    "\n",
    "def chi_squared_kde(X, Y, Z):\n",
    "    \"\"\"Return ``chi_2_cond(X, Y, Z, kde)``.\n",
    "\n",
    "    The conditional version is used because we are going to optimize for EO.\n",
    "    \"\"\"\n",
    "    penalty = chi_2_cond(X, Y, Z, kde)\n",
    "    return penalty\n",
    "\n",
    "#You can also use the actual HGR computation function but when one of the variables is binary, the result will be the same\n",
    "#def hgr_cond(X, Y, Z):\n",
    "#    return hgr(X, Y, Z, kde)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We download and preprocess the Adult dataset from UCI, as in https://github.com/jmikko/fair_ERM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.2134431 ,  0.28120807, -2.70196443, ..., -0.21878026,\n",
       "        -0.49453079, -5.16482102],\n",
       "       [-0.2134431 ,  1.05515285,  1.22785281, ..., -0.21878026,\n",
       "        -1.74376299,  0.26299899],\n",
       "       [-1.25716323, -1.30130918,  0.17990154, ..., -0.21878026,\n",
       "         2.58690862,  0.26299899],\n",
       "       ...,\n",
       "       [ 1.87399717,  0.12995572, -0.34407409, ..., -0.21878026,\n",
       "        -0.07812006,  0.26299899],\n",
       "       [-0.2134431 ,  0.33704024,  0.17990154, ..., -0.21878026,\n",
       "        -0.49453079,  0.26299899],\n",
       "       [-0.2134431 ,  0.78559083, -1.39202535, ..., -0.21878026,\n",
       "        -0.07812006,  0.26299899]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from examples.data_loading import read_dataset\n",
    "\n",
    "# encoded_data, to_protect, encoded_data_test, to_protect_test = read_dataset(name='adult', fold=1)\n",
    "# encoded_data.head()\n",
    "\n",
    "# Dataset name and seed forwarded to the project loader below.\n",
    "dataset, seed = \"adult\", 123\n",
    "\n",
    "# The loader's result is unpacked into (X, A, Y) triples for the train, calibration and test splits.\n",
    "(\n",
    "    X, A, Y,\n",
    "    X_cal, A_cal, Y_cal,\n",
    "    X_test, A_test, Y_test,\n",
    ") = get_dataset.get_train_test_data(base_path, dataset, seed, dim=3)\n",
    "X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define a very simple neural network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NetRegression(nn.Module):\n",
    "    def __init__(self, input_size, num_classes):\n",
    "        super(NetRegression, self).__init__()\n",
    "        size = 64\n",
    "        self.first = nn.Linear(input_size, size)\n",
    "        self.last = nn.Linear(size, num_classes)       \n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = F.selu( self.first(x) )\n",
    "        out = self.last(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A few helper functions to compute performance metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def EntropyToProba(entropy): #Only for X Tensor of dimension 2\n",
    "    return entropy[:,1].exp() / entropy.exp().sum(dim=1)\n",
    "\n",
    "def calc_accuracy(outputs,Y): #Care outputs are going to be in dimension 2\n",
    "    max_vals, max_indices = torch.max(outputs,1)\n",
    "    acc = (max_indices == Y).sum().numpy()/max_indices.size()[0]\n",
    "    return acc\n",
    "\n",
    "def results_on_test(model, criterion, encoded_data_test, to_protect_test):\n",
    "    \"\"\"Evaluate ``model`` on a held-out frame and collect loss, accuracy and fairness metrics.\n",
    "\n",
    "    ``encoded_data_test`` must expose a ``'Target'`` column (the labels); every other\n",
    "    column is fed to the model. ``to_protect_test`` is converted to a tensor and compared\n",
    "    against 0 and 1 to form the two groups.\n",
    "\n",
    "    Returns a dict with the keys ``'loss'``, ``'accuracy'``, ``'balanced_acc'``,\n",
    "    ``'di'`` (ratio of positive-prediction rates, group 1 over group 0) and ``'deo'``\n",
    "    (absolute gap in positive-prediction rates among rows with target 1).\n",
    "    \"\"\"\n",
    "    target = torch.tensor(encoded_data_test['Target'].values.astype(np.longlong)).long()\n",
    "    features = torch.tensor(encoded_data_test.drop('Target', axis = 1).values.astype(np.float32))\n",
    "    group = torch.Tensor(to_protect_test)\n",
    "    in_group0 = group == 0\n",
    "    in_group1 = group == 1\n",
    "\n",
    "    logits = model(features).detach()\n",
    "    loss = criterion(logits, target)\n",
    "\n",
    "    # A row counts as a positive prediction when its column-1 probability exceeds 0.5.\n",
    "    predicted_pos = EntropyToProba(logits) > 0.5\n",
    "    is_pos = target == 1\n",
    "\n",
    "    def rate(selected, population):\n",
    "        # Share of `population` rows that are also in `selected` (both boolean masks).\n",
    "        return selected.sum().float() / population.sum().float()\n",
    "\n",
    "    acc_group0 = calc_accuracy(logits[in_group0], target[in_group0])\n",
    "    acc_group1 = calc_accuracy(logits[in_group1], target[in_group1])\n",
    "\n",
    "    p1 = rate(in_group1 & predicted_pos, in_group1)\n",
    "    p0 = rate(in_group0 & predicted_pos, in_group0)\n",
    "    o1 = rate(in_group1 & predicted_pos & is_pos, in_group1 & is_pos)\n",
    "    o2 = rate(in_group0 & predicted_pos & is_pos, in_group0 & is_pos)\n",
    "\n",
    "    return {\n",
    "        'loss': loss.item(),\n",
    "        'accuracy': calc_accuracy(logits, target),\n",
    "        'balanced_acc': (acc_group0 + acc_group1)/2,\n",
    "        'di': (p1 / p0).item(),\n",
    "        'deo': (o1 - o2).abs().item(),\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Lenovo\\AppData\\Local\\Temp\\ipykernel_38036\\462570054.py:46: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen/native/IndexingUtils.h:28.)\n",
      "  br = train_data[foo, : ]\n",
      "C:\\Users\\Lenovo\\AppData\\Local\\Temp\\ipykernel_38036\\462570054.py:47: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen/native/IndexingUtils.h:28.)\n",
      "  pr = train_protect[foo, :]\n",
      "C:\\Users\\Lenovo\\AppData\\Local\\Temp\\ipykernel_38036\\462570054.py:48: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen/native/IndexingUtils.h:28.)\n",
      "  yr = train_target[foo].float()\n",
      "d:\\anaconda3\\lib\\site-packages\\torch\\functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen\\native\\TensorShape.cpp:3484.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results on test set\n",
      "misclassification error = 0.20298507462686566\n"
     ]
    }
   ],
   "source": [
    "from fair_dummies import fair_dummies_learning\n",
    "from fair_dummies import utility_functions\n",
    "importlib.reload(fair_dummies_learning)\n",
    "importlib.reload(utility_functions)\n",
    "importlib.reload(get_dataset)\n",
    "\n",
    "# Hyper Parameters\n",
    "input_size = X.shape[1]   # one input unit per feature column\n",
    "num_classes = 2\n",
    "num_epochs = 200\n",
    "batch_size = 128          # mini-batch size for the prediction loss\n",
    "batchRenyi = 256          # expected size of the fairness-penalty mini-batch\n",
    "learning_rate = 1e-2\n",
    "lambda_renyi = 1          # weight of the fairness penalty in the total loss\n",
    "\n",
    "# Bundle the settings so the training loop can read everything from `config`.\n",
    "cfg_factory = namedtuple(\n",
    "    'Config',\n",
    "    ['model', 'batch_size', 'num_epochs', 'lambda_renyi', 'batchRenyi',\n",
    "     'learning_rate', 'input_size', 'num_classes'],\n",
    ")\n",
    "config = cfg_factory(\n",
    "    model=NetRegression,\n",
    "    batch_size=batch_size,\n",
    "    num_epochs=num_epochs,\n",
    "    lambda_renyi=lambda_renyi,\n",
    "    batchRenyi=batchRenyi,\n",
    "    learning_rate=learning_rate,\n",
    "    input_size=input_size,\n",
    "    num_classes=num_classes,\n",
    ")\n",
    "\n",
    "verbose = False\n",
    "\n",
    "model = config.model(config.input_size, config.num_classes)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=0)\n",
    "\n",
    "# train_target = torch.tensor(encoded_data['Target'].values.astype(np.longlong)).long()\n",
    "# train_data = torch.tensor(encoded_data.drop('Target', axis = 1).values.astype(np.float32))\n",
    "# train_protect = torch.tensor(to_protect).float()\n",
    "\n",
    "# Training tensors: integer labels, float32 features, float protected attribute.\n",
    "train_data = torch.tensor(X.astype(np.float32))\n",
    "train_target = torch.tensor(Y.astype(np.longlong)).long()\n",
    "train_protect = torch.tensor(A).float()\n",
    "train_tensor = data_utils.TensorDataset(train_data, train_target)\n",
    "\n",
    "# Probability with which each training row joins the fairness mini-batch, so that its\n",
    "# expected size is config.batchRenyi. It does not change between iterations.\n",
    "frac = config.batchRenyi / train_data.shape[0]\n",
    "n_batches = train_data.shape[0] // config.batch_size\n",
    "\n",
    "for epoch in range(config.num_epochs):\n",
    "    train_loader = data_utils.DataLoader(dataset = train_tensor, batch_size = config.batch_size, shuffle = True)\n",
    "    for i, (x, y) in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(x)\n",
    "\n",
    "        # Select a Renyi regularization mini-batch and compute the value of the model on it.\n",
    "        # The mask is boolean: indexing with a uint8 (.byte()) mask is deprecated in PyTorch\n",
    "        # and emitted a UserWarning on every indexing line.\n",
    "        foo = torch.bernoulli(frac * torch.ones(train_data.shape[0])).bool()\n",
    "        br = train_data[foo, :]\n",
    "        pr = train_protect[foo, :]\n",
    "        yr = train_target[foo].float()\n",
    "        ren_outs = model(br)\n",
    "\n",
    "        # Usual prediction loss on the current mini-batch.\n",
    "        loss = criterion(outputs, y)\n",
    "\n",
    "        # Fairness penalty: chi_2_cond on the column-1 probability and the protected attribute,\n",
    "        # with the label passed as the third argument.\n",
    "        delta = EntropyToProba(ren_outs)\n",
    "        #r2 = chi_squared_kde( delta, pr[yr==1.])\n",
    "        r2 = chi_2_cond(delta, pr, yr, kde, mode = \"sum\")\n",
    "\n",
    "        loss += config.lambda_renyi * r2\n",
    "\n",
    "        #In Adam we trust\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # NOTE(review): this break means only the first mini-batch of every epoch is used,\n",
    "        # i.e. one optimizer step per epoch. Kept as-is; confirm it is intentional.\n",
    "        break\n",
    "    if verbose:\n",
    "        # The batch count comes from train_data; `encoded_data` is never defined in this\n",
    "        # notebook, so referencing it here raised NameError as soon as verbose was enabled.\n",
    "        print ('Epoch: [%d/%d], Batch: [%d/%d], Loss: %.4f, Accuracy: %.4f, Fairness penalty: %.4f'  % (epoch+1, config.num_epochs, i, n_batches,\n",
    "                loss.item(),calc_accuracy(outputs,y),\n",
    "                r2.item()\n",
    "                    ))\n",
    "        #print( results_on_test(model, criterion, encoded_data_test, to_protect_test) )\n",
    "\n",
    "print(\"Results on test set\")\n",
    "# results_on_test(model, criterion, encoded_data_test, to_protect_test)\n",
    "\n",
    "def _positive_proba(features):\n",
    "    \"\"\"Run ``model`` on a NumPy feature array; return the flat, detached column-1 probabilities.\"\"\"\n",
    "    inputs = torch.tensor(features.astype(np.float32))\n",
    "    return EntropyToProba(model(inputs)).detach().flatten()\n",
    "\n",
    "def _two_column(proba_pos):\n",
    "    \"\"\"Place ``1 - p`` and ``p`` side by side as two columns.\"\"\"\n",
    "    return np.concatenate([(1 - proba_pos)[:,None], proba_pos[:,None]], axis = 1)\n",
    "\n",
    "proba_cal = _positive_proba(X_cal).numpy()\n",
    "proba_test = _positive_proba(X_test).numpy()\n",
    "\n",
    "# Share of test points whose prediction, thresholded at 0.5, differs from the label.\n",
    "n_wrong = sum((proba_test > 0.5).astype(float) != Y_test)\n",
    "misclassification = n_wrong / Y_test.shape[0]\n",
    "print(\"misclassification error = \" + str(misclassification))\n",
    "\n",
    "Yhat_out_cal = _two_column(proba_cal)\n",
    "Yhat_out_test = _two_column(proba_test)\n",
    "\n",
    "outputs = _positive_proba(X_test)\n",
    "# r2 = np.max(hgr_cond(outputs, torch.tensor(A_test.astype(np.float32)), torch.tensor(Y_test.astype(np.float32)), kde))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.038778169853435136\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def _to_r_matrix(values, n_rows, n_cols):\n",
    "    \"\"\"Wrap a flat sequence of floats as an R matrix with the given number of rows and columns.\"\"\"\n",
    "    return rpy2.robjects.r.matrix(FloatVector(values), nrow=n_rows, ncol=n_cols)\n",
    "\n",
    "# 2-D arrays are transposed before flattening, so values are handed over column by column.\n",
    "rYhat = _to_r_matrix(Yhat_out_test.T.flatten(), Yhat_out_test.shape[0], Yhat_out_test.shape[1])\n",
    "rZ = _to_r_matrix(A_test.T.flatten(), A_test.shape[0], A_test.shape[1]) # rpy2.robjects.r.matrix\n",
    "rY = _to_r_matrix(Y_test, A_test.shape[0], 1)\n",
    "\n",
    "# res = FOCI.codec(Y = rYhat, Z = rZ, X = rY)\n",
    "# stat = FOCI.codec # KPC.KPCgraph KPC.KPCRKHS FOCI.codec\n",
    "# res_ = stat(Y = rYhat, X = rY, Z = rZ)[0]\n",
    "# print(res_)\n",
    "# res_list = np.zeros(100)\n",
    "# for i in range(100):\n",
    "#     At_test = specified_At[i]\n",
    "#     rZt = rpy2.robjects.r.matrix(FloatVector(At_test.T.flatten()), nrow=A_test.shape[0], ncol=A_test.shape[1])\n",
    "#     res_list[i] = stat(Y = rYhat, X = rY, Z = rZt)[0]\n",
    "# p_val = 1.0/(100+1) * (1 + sum(res_list >= res_))\n",
    "# cur_codec.append(0)\n",
    "# cur_pval[\"codec\"].append(0)\n",
    "\n",
    "# Statistic taken from the R KPC package; other candidates tried: KPC.KPCRKHS, FOCI.codec\n",
    "stat = KPC.KPCgraph\n",
    "kpc_result = stat(Y=rYhat, X=rY, Z=rZ, Knn=1)\n",
    "res_ = kpc_result[0]\n",
    "print(res_)\n",
    "# p_val = utility_functions.fair_dummies_test_classification(Yhat_out_cal,\n",
    "#                                                         A_cal,\n",
    "#                                                         Y_cal,\n",
    "#                                                         Yhat_out_test,\n",
    "#                                                         A_test,\n",
    "#                                                         Y_test,\n",
    "#                                                         num_reps = 1,\n",
    "#                                                         num_p_val_rep=1000,\n",
    "#                                                         reg_func_name=\"RF\",\n",
    "#                                                         test_type = \"MAF_CPT_onA\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "49cb93f377a7abe7414b7b0f21fb3017538004a126cf690fb524202736b7fb92"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
