{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ff575c6-7d56-4c49-9bf7-a8f34b0f8074",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from torch import nn\n",
    "import torch\n",
    "import csv\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.metrics import balanced_accuracy_score\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from fairness.data.objects.list import DATASETS, get_dataset_names\n",
    "from fairness.data.objects.ProcessedData import ProcessedData\n",
    "from sklearn.preprocessing import StandardScaler,MinMaxScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.datasets import fetch_openml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eed85186-6e53-49b4-8e22-54df0538eaf6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from helpers import *\n",
    "from metrics import *\n",
    "from ratioMSE_autoenc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c669825e-0602-4b2e-b2c1-b0d4a8c55ae9",
   "metadata": {},
   "source": [
    "Import the data and train the unfair model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26d617b-f97c-4fc2-85e6-2af7d23cfda4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the 'law_school' dataset through the star-imported project helper `get_datasets`.\n",
    "# NOTE(review): presumably X_train0/X_test0 are the unprocessed features, X_train/X_test the\n",
    "# model-ready features and S_train/S_test the sensitive attribute -- confirm against the helper.\n",
    "X_train0, X_test0, X_train, X_test, y_train, y_test, S_train, S_test, column_names = get_datasets(\"law_school\", n_splits=2, seed=11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2254294c-03cc-4a28-b850-7a93cce82e49",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sanity check: dimensions of the train and test feature matrices.\n",
    "print(X_train.shape, X_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30c44b9e-7f96-4deb-87ac-0e5db572742f",
   "metadata": {},
   "source": [
    "Train a logistic regression on the TRAIN split and evaluate it on the TEST split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91d83bb3-9efb-4b88-8930-9299c2094277",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Baseline classifier: logistic regression fit on the training split (fixed random_state).\n",
    "log_reg = LogisticRegression(max_iter=1000, random_state=42)\n",
    "log_reg.fit(X_train, y_train)\n",
    "\n",
    "# Class predictions on the test split; reused later as the biased predictions.\n",
    "predictions_test_logreg = log_reg.predict(X_test)\n",
    "\n",
    "accuracy = accuracy_score(y_test, predictions_test_logreg)\n",
    "print(\"Accuracy:\", accuracy*100,\"%\")\n",
    "# Disabled: `balanced_accuracy` is not computed in any cell shown here.\n",
    "#print(\"Balanced Accuracy:\", balanced_accuracy)\n",
    "# `p_rule` is a star-imported project metric, called with the predictions and S_test.\n",
    "print(\"p%-ratio (Demographic parity measure) is:\", p_rule(predictions_test_logreg, S_test),\"%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "423da52e-6ce8-4943-b1fe-3b88968fe781",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Column 1 of predict_proba is the probability of log_reg.classes_[1]\n",
    "# (the positive class, assuming binary 0/1 labels -- TODO confirm).\n",
    "proba_train_logreg = log_reg.predict_proba(X_train)[:, 1] \n",
    "proba_test_logreg = log_reg.predict_proba(X_test)[:, 1] "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14e6e40e-a5b3-4c25-9771-8cc5f0337812",
   "metadata": {},
   "source": [
    "Train the debiasing model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c9cdd8b-037b-48ba-8221-9132024693f1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "k = 2 #number of \"concepts\"\n",
    "\n",
    "class NN_r(nn.Module):\n",
    "#Network to predict ratio from concepts\n",
    "    def __init__(self):\n",
    "        super(NN_r, self).__init__()\n",
    "        self.fc1 = nn.Linear(k, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.fc1(x)\n",
    "        return x\n",
    "    \n",
    "class NN_c(nn.Module):\n",
    "#Network to predict concepts from features\n",
    "    def __init__(self):\n",
    "        super(NN_c, self).__init__()\n",
    "        self.fc1 = nn.Linear(X_train.shape[1], k)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.fc1(x)\n",
    "        return x\n",
    "\n",
    "class NN_s(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NN_s, self).__init__()\n",
    "        self.fc1_s = nn.Linear(1, 64)\n",
    "        self.fc2_s = nn.Linear(64, 32)\n",
    "        self.fc3_s = nn.Linear(32, 16)\n",
    "        self.fc4_s = nn.Linear(16, 1)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.fc1_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc2_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc3_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc4_s(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "504d5922-f073-47d5-98b2-636fb0fc38f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the debiasing model; every hyper-parameter is passed by keyword.\n",
    "ratio_model = ratio_debiasing(\n",
    "    learning_rate=0.001,\n",
    "    batch_size=512,\n",
    "    lamb_fair=4,\n",
    "    lamb_sparse=1e-2,\n",
    "    lamb_diversity=1,\n",
    "    lamb_ratio=0.1,\n",
    "    num_epochs=1000,\n",
    "    num_concepts=k,\n",
    "    NN_r=NN_r,\n",
    "    NN_s=NN_s,\n",
    "    NN_c=NN_c,\n",
    "    GPU='cuda:0',\n",
    ")\n",
    "\n",
    "# Train with the logistic-regression probabilities as y_hat; the test split and its\n",
    "# probabilities are handed over in the same call.\n",
    "ratio_model.train(\n",
    "    X_train, y_train, S_train,\n",
    "    y_hat=proba_train_logreg,\n",
    "    X_test=X_test, y_test=y_test, S_test=S_test,\n",
    "    y_hat_test=proba_test_logreg,\n",
    "    plot_losses=False, dem_parity=True, write=False, ablation=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad47e1df-837e-4107-ad81-aac435f09e08",
   "metadata": {},
   "source": [
    "### COMMOD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7d8e064-45cc-4d43-b690-1c6a789b2246",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# COMMOD configuration; every hyper-parameter is passed by keyword.\n",
    "COMMOD = ratio_debiasing(\n",
    "    learning_rate=0.001,\n",
    "    batch_size=10000,\n",
    "    lamb_fair=3,\n",
    "    lamb_sparse=1e-2,\n",
    "    lamb_diversity=1,\n",
    "    lamb_ratio=0.3,\n",
    "    num_epochs=1000,\n",
    "    num_concepts=k,\n",
    "    NN_r=NN_r,\n",
    "    NN_s=NN_s,\n",
    "    NN_c=NN_c,\n",
    "    GPU='cuda:0',\n",
    ")\n",
    "\n",
    "# Train with the logistic-regression probabilities as y_hat (ablation mode switched on).\n",
    "COMMOD.train(\n",
    "    X_train, y_train, S_train,\n",
    "    y_hat=proba_train_logreg,\n",
    "    X_test=X_test, y_test=y_test, S_test=S_test,\n",
    "    y_hat_test=proba_test_logreg,\n",
    "    plot_losses=False, dem_parity=True, write=False, ablation=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83186f63-6b7d-4aee-9476-6d662261e0ca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# COMMOD outputs for both splits, computed from the features and the baseline probabilities.\n",
    "y_COMMOD_test = COMMOD.predict(X_test, proba_test_logreg)\n",
    "y_COMMOD_train = COMMOD.predict(X_train, proba_train_logreg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4becc70f-38c3-4c97-a032-0e0f3e92a724",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Target = 1 where the COMMOD output differs from the logistic-regression decision\n",
    "# (probability >= 0.5), else 0. Assumes COMMOD.predict returns hard 0/1 labels -- TODO confirm.\n",
    "test_target = (y_COMMOD_test != (proba_test_logreg>=0.5).astype(int)).astype(int)\n",
    "# NOTE(review): train_target is computed but never used in this cell.\n",
    "train_target = (y_COMMOD_train != (proba_train_logreg>=0.5).astype(int)).astype(int)\n",
    "\n",
    "results = []\n",
    "\n",
    "\n",
    "# Sweep tree depths 2..10; each depth is repeated 5 times (no random_state is set).\n",
    "for max_depth in range(2, 11):  \n",
    "    \n",
    "    f1_scores = []\n",
    "    \n",
    "    for time in range(5):\n",
    "        classifier = DecisionTreeClassifier(max_depth=max_depth)\n",
    "        # NOTE(review): the tree is fit and scored on the same X_test, so the F1 below is an\n",
    "        # in-sample fit measure, not a held-out score.\n",
    "        classifier.fit(X_test, test_target)\n",
    "\n",
    "        predictions = classifier.predict(X_test)\n",
    "\n",
    "        f1 = f1_score(test_target, predictions)\n",
    "\n",
    "        f1_scores.append(f1)\n",
    "    \n",
    "    # One summary record per depth: mean and standard deviation over the repeats.\n",
    "    results.append({\n",
    "        'max_depth': max_depth,\n",
    "        'f1_mean': np.mean(f1_scores),\n",
    "        'f1_std': np.std(f1_scores)\n",
    "    })\n",
    "\n",
    "# Print the results\n",
    "for result in results:\n",
    "    print(f\"Max Depth: {result['max_depth']}, F1 mean: {result['f1_mean']}, F1 std: {result['f1_std']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74dfb464-ad68-47f3-a4f6-9605cfac4edc",
   "metadata": {},
   "source": [
    "### COMMOD without minimal changes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92b6f6d0-3904-49d0-bb6c-ebf034702358",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "# NOTE(review): absolute, machine-specific paths; they only resolve on the original author's machine.\n",
    "sys.path.append('/Users/z912yg/Library/CloudStorage/OneDrive-AXA/Desktop/fairness-AXA')\n",
    "sys.path.append('/Users/z912yg/Library/CloudStorage/OneDrive-AXA/Desktop/fairness-AXA/part2')\n",
    "sys.path.append('/Users/z912yg/Library/CloudStorage/OneDrive-AXA/Desktop/fairness-AXA/part2/DP')\n",
    "\n",
    "# `helpers` and `metrics` were already star-imported near the top of the notebook.\n",
    "from helpers import *\n",
    "from metrics import *\n",
    "# NOTE(review): the next cell calls `ratio_debiasing` without `lamb_ratio`, so this star import\n",
    "# presumably rebinds `ratio_debiasing` to the ratioNOMSE_autoenc variant -- confirm. Any cell\n",
    "# executed afterwards that uses `ratio_debiasing` gets whichever variant was imported last.\n",
    "from ratioNOMSE_autoenc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2ae0b41-ea66-4555-a8df-58c1a5829738",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Same configuration as COMMOD, except that no `lamb_ratio` is passed.\n",
    "COMMOD_NOMSE = ratio_debiasing(\n",
    "    learning_rate=0.001,\n",
    "    batch_size=10000,\n",
    "    lamb_fair=3,\n",
    "    lamb_sparse=1e-2,\n",
    "    lamb_diversity=1,\n",
    "    num_epochs=1000,\n",
    "    num_concepts=k,\n",
    "    NN_r=NN_r,\n",
    "    NN_s=NN_s,\n",
    "    NN_c=NN_c,\n",
    "    GPU='cuda:0',\n",
    ")\n",
    "\n",
    "# Train with the logistic-regression probabilities as y_hat (ablation mode switched on).\n",
    "COMMOD_NOMSE.train(\n",
    "    X_train, y_train, S_train,\n",
    "    y_hat=proba_train_logreg,\n",
    "    X_test=X_test, y_test=y_test, S_test=S_test,\n",
    "    y_hat_test=proba_test_logreg,\n",
    "    plot_losses=False, dem_parity=True, write=False, ablation=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6707382c-2070-4c32-8546-5712c0b1c47b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Outputs of the variant for both splits, from the features and the baseline probabilities.\n",
    "y_COMMOD_NOMSE_test = COMMOD_NOMSE.predict(X_test, proba_test_logreg)\n",
    "y_COMMOD_NOMSE_train = COMMOD_NOMSE.predict(X_train, proba_train_logreg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e85d36e2-2629-44f2-aec6-125eb44b5540",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Target = 1 where the COMMOD_NOMSE output differs from the logistic-regression decision\n",
    "# (probability >= 0.5), else 0. Assumes predict returns hard 0/1 labels -- TODO confirm.\n",
    "test_target = (y_COMMOD_NOMSE_test != (proba_test_logreg>=0.5).astype(int)).astype(int)\n",
    "# NOTE(review): train_target is computed but never used in this cell.\n",
    "train_target = (y_COMMOD_NOMSE_train != (proba_train_logreg>=0.5).astype(int)).astype(int)\n",
    "\n",
    "results_2 = []\n",
    "\n",
    "\n",
    "# Sweep tree depths 2..10; each depth is repeated 5 times (no random_state is set).\n",
    "for max_depth in range(2, 11):  \n",
    "    \n",
    "    f1_scores = []\n",
    "    \n",
    "    for time in range(5):\n",
    "        classifier = DecisionTreeClassifier(max_depth=max_depth)\n",
    "        # NOTE(review): fit and scored on the same X_test (in-sample F1).\n",
    "        classifier.fit(X_test, test_target)\n",
    "\n",
    "        predictions = classifier.predict(X_test)\n",
    "\n",
    "        f1 = f1_score(test_target, predictions)\n",
    "\n",
    "        f1_scores.append(f1)\n",
    "    \n",
    "    # One summary record per depth: mean and standard deviation over the repeats.\n",
    "    results_2.append({\n",
    "        'max_depth': max_depth,\n",
    "        'f1_mean': np.mean(f1_scores),\n",
    "        'f1_std': np.std(f1_scores)\n",
    "    })\n",
    "\n",
    "# Print the results\n",
    "for result in results_2:\n",
    "    print(f\"Max Depth: {result['max_depth']}, F1 mean: {result['f1_mean']}, F1 std: {result['f1_std']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "932a6251-9d37-4bcd-b689-c4240db50d29",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Pull the x axis (depths) and the mean F1 per depth out of the two result lists.\n",
    "max_depths = [result['max_depth'] for result in results]\n",
    "f1_means = [result['f1_mean'] for result in results]\n",
    "f1_means_2 = [result['f1_mean'] for result in results_2]\n",
    "\n",
    "# Create the plot\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(max_depths, f1_means, marker='o', linestyle='-', color='orange', label='COMMOD', linewidth=3)\n",
    "plt.plot(max_depths, f1_means_2, marker='o', linestyle='-', color='purple', label='COMMOD without min changes', linewidth=3)\n",
    "\n",
    "# Add labels and title\n",
    "plt.xlabel('Max Depth', size=24)\n",
    "plt.ylabel('F1 Score', size=24)\n",
    "plt.title('F1 Score vs. Depth for Decision Tree Classifier', size=26)\n",
    "plt.grid(False)\n",
    "# One tick per evaluated depth. The label size is set through tick_params so it applies to\n",
    "# every tick; xticks(size=...)/yticks(size=...) only restyle the labels existing at call time.\n",
    "plt.xticks(max_depths)\n",
    "plt.tick_params(axis='both', labelsize=20)\n",
    "plt.legend(fontsize=13)\n",
    "\n",
    "# NOTE(review): later figure cells save to this same file name, so the last one run wins.\n",
    "plt.savefig('F1_SCORE_lawschool_NOGRID.pdf')\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41ece613-8c84-42bb-9b97-83a14356be3e",
   "metadata": {
    "tags": []
   },
   "source": [
    "### RBMD (sparsity and diversity = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91faf9e2-1cb9-4126-8867-99bb5214c52a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "# NOTE(review): absolute, machine-specific paths; they only resolve on the original author's machine.\n",
    "sys.path.append('/Users/z912yg/Library/CloudStorage/OneDrive-AXA/Desktop/fairness-AXA')\n",
    "sys.path.append('/Users/z912yg/Library/CloudStorage/OneDrive-AXA/Desktop/fairness-AXA/part1')\n",
    "sys.path.append('/Users/z912yg/Library/CloudStorage/OneDrive-AXA/Desktop/fairness-AXA/part1/DP')\n",
    "\n",
    "# `helpers` and `metrics` were already star-imported near the top of the notebook.\n",
    "from helpers import *\n",
    "from metrics import *\n",
    "# NOTE(review): the RBMD cell below calls `ratio_debiasing` without num_concepts/NN_c, so this\n",
    "# star import presumably rebinds `ratio_debiasing` to the ratioMSE variant -- confirm. Any cell\n",
    "# executed afterwards that uses `ratio_debiasing` gets whichever variant was imported last.\n",
    "from ratioMSE import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26bb04ef-e1cb-45ca-8a43-86094fc7ce44",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class NN_r(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NN_r, self).__init__()\n",
    "        self.fc1 = nn.Linear(X_train.shape[1], 64)\n",
    "        self.fc2 = nn.Linear(64, 32)\n",
    "        self.fc3 = nn.Linear(32, 1)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        #import pdb;pdb.set_trace() \n",
    "        x = self.fc1(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "class NN_s(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NN_s, self).__init__()\n",
    "        self.fc1_s = nn.Linear(1, 64)\n",
    "        self.fc2_s = nn.Linear(64, 32)\n",
    "        self.fc3_s = nn.Linear(32, 16)\n",
    "        self.fc4_s = nn.Linear(16, 1)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.fc1_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc2_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc3_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc4_s(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6e9372c-9503-45aa-90d2-bcc0fd13e36d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# RBMD configuration: no concept network and no sparsity/diversity weights are passed.\n",
    "RBMD = ratio_debiasing(\n",
    "    learning_rate=0.001,\n",
    "    batch_size=10000,\n",
    "    lamb_ratio=0.1,\n",
    "    lamb_fair=3,\n",
    "    num_epochs=1000,\n",
    "    NN_r=NN_r,\n",
    "    NN_s=NN_s,\n",
    "    GPU='cuda:0',\n",
    ")\n",
    "\n",
    "# Train with the logistic-regression probabilities as y_hat (ablation mode switched on).\n",
    "RBMD.train(\n",
    "    X_train, y_train, S_train,\n",
    "    y_hat=proba_train_logreg,\n",
    "    X_test=X_test, y_test=y_test, S_test=S_test,\n",
    "    y_hat_test=proba_test_logreg,\n",
    "    plot_losses=False, dem_parity=True, write=False, ablation=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "943a7497-e5e1-42ac-a3cb-5ee17eca32af",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# RBMD outputs for both splits, computed from the features and the baseline probabilities.\n",
    "y_RBMD_test = RBMD.predict(X_test, proba_test_logreg)\n",
    "y_RBMD_train = RBMD.predict(X_train, proba_train_logreg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0fca64e-0291-420c-93d5-aeb68bc90ae6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Target = 1 where the RBMD output differs from the logistic-regression decision\n",
    "# (probability >= 0.5), else 0. Assumes RBMD.predict returns hard 0/1 labels -- TODO confirm.\n",
    "test_target = (y_RBMD_test != (proba_test_logreg>=0.5).astype(int)).astype(int)\n",
    "# NOTE(review): train_target is computed but never used in this cell.\n",
    "train_target = (y_RBMD_train != (proba_train_logreg>=0.5).astype(int)).astype(int)\n",
    "\n",
    "# NOTE(review): the FRAPPE section further down stores its results under this same name.\n",
    "results_3 = []\n",
    "\n",
    "\n",
    "# Sweep tree depths 2..10; each depth is repeated 5 times (no random_state is set).\n",
    "for max_depth in range(2, 11):  \n",
    "    \n",
    "    f1_scores = []\n",
    "    \n",
    "    for time in range(5):\n",
    "        classifier = DecisionTreeClassifier(max_depth=max_depth)\n",
    "        # NOTE(review): fit and scored on the same X_test (in-sample F1).\n",
    "        classifier.fit(X_test, test_target)\n",
    "\n",
    "        predictions = classifier.predict(X_test)\n",
    "\n",
    "        f1 = f1_score(test_target, predictions)\n",
    "\n",
    "        f1_scores.append(f1)\n",
    "    \n",
    "    # One summary record per depth: mean and standard deviation over the repeats.\n",
    "    results_3.append({\n",
    "        'max_depth': max_depth,\n",
    "        'f1_mean': np.mean(f1_scores),\n",
    "        'f1_std': np.std(f1_scores)\n",
    "    })\n",
    "\n",
    "# Print the results\n",
    "for result in results_3:\n",
    "    print(f\"Max Depth: {result['max_depth']}, F1 mean: {result['f1_mean']}, F1 std: {result['f1_std']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1492e958-d780-4a43-86ae-c94cecee1b5d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Pull the x axis (depths) and the mean F1 per depth out of the three result lists.\n",
    "# NOTE(review): `results_3` is also assigned by the FRAPPE section; run this cell right after\n",
    "# the RBMD sweep, otherwise the third curve is not the RBMD one.\n",
    "max_depths = [result['max_depth'] for result in results]\n",
    "f1_means = [result['f1_mean'] for result in results]\n",
    "f1_means_2 = [result['f1_mean'] for result in results_2]\n",
    "f1_means_3 = [result['f1_mean'] for result in results_3]\n",
    "\n",
    "# Create the plot\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(max_depths, f1_means, marker='o', linestyle='-', color='orange', label='COMMOD', linewidth=3)\n",
    "plt.plot(max_depths, f1_means_2, marker='o', linestyle='-', color='purple', label='COMMOD without min changes', linewidth=3)\n",
    "plt.plot(max_depths, f1_means_3, marker='o', linestyle='-', color='#66b266', label='COMMOD without interpretability constraint', linewidth=3)\n",
    "\n",
    "# Add labels and title\n",
    "plt.xlabel('Max Depth', size=24)\n",
    "plt.ylabel('F1 Score', size=24)\n",
    "plt.title('F1 Score vs. Depth for Decision Tree Classifier', size=26)\n",
    "plt.grid(False)\n",
    "# One tick per evaluated depth. The label size is set through tick_params so it applies to\n",
    "# every tick; xticks(size=...)/yticks(size=...) only restyle the labels existing at call time.\n",
    "plt.xticks(max_depths)\n",
    "plt.tick_params(axis='both', labelsize=20)\n",
    "plt.legend(fontsize=13)\n",
    "\n",
    "plt.savefig('F1_SCORE_lawschool_RBMD_NOGRID.pdf')\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be1ba8d3-717a-4e39-926f-240872bce54c",
   "metadata": {},
   "source": [
    "### FRAPPE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8406b35f-172a-4219-8cad-aedd6d1f6043",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/Users/z912yg/Library/CloudStorage/OneDrive-AXA/Desktop/fairness-AXA')\n",
    "sys.path.append('/Users/z912yg/Library/CloudStorage/OneDrive-AXA/Desktop/fairness-AXA/part1')\n",
    "sys.path.append('/Users/z912yg/Library/CloudStorage/OneDrive-AXA/Desktop/fairness-AXA/part1/DP')\n",
    "\n",
    "from helpers import *\n",
    "from metrics import *\n",
    "from FRAPPE_alg import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc38531d-0830-445f-bc1b-c0b383270a47",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class NN_r(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NN_r, self).__init__()\n",
    "        self.fc1 = nn.Linear(X_train.shape[1], 64)\n",
    "        self.fc2 = nn.Linear(64, 32)\n",
    "        self.fc3 = nn.Linear(32, 1)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        #import pdb;pdb.set_trace() \n",
    "        x = self.fc1(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "class NN_s(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NN_s, self).__init__()\n",
    "        self.fc1_s = nn.Linear(1, 64)\n",
    "        self.fc2_s = nn.Linear(64, 32)\n",
    "        self.fc3_s = nn.Linear(32, 16)\n",
    "        self.fc4_s = nn.Linear(16, 1)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.fc1_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc2_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc3_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc4_s(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3658f666-8dc4-43cc-9660-a843a97ee6f4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# FRAPPE configuration; unlike the ratio_debiasing cells, GPU=None is passed here.\n",
    "frappe_model = FRAPPE(\n",
    "    learning_rate=0.001,\n",
    "    batch_size=10000,\n",
    "    lamb_fair=2,\n",
    "    num_epochs=1000,\n",
    "    NN_r=NN_r,\n",
    "    NN_s=NN_s,\n",
    "    GPU=None,\n",
    ")\n",
    "\n",
    "# Train with the logistic-regression probabilities as y_hat.\n",
    "frappe_model.train(\n",
    "    X_train, y_train, S_train,\n",
    "    y_hat=proba_train_logreg,\n",
    "    X_test=X_test, y_test=y_test, S_test=S_test,\n",
    "    y_hat_test=proba_test_logreg,\n",
    "    plot_losses=False, dem_parity=True, write=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2efe5f4-3e33-43e3-8f96-b905a0a9281f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# FRAPPE outputs for both splits, computed from the features and the baseline probabilities.\n",
    "y_FRAPPE_test = frappe_model.predict(X_test, proba_test_logreg)\n",
    "y_FRAPPE_train = frappe_model.predict(X_train, proba_train_logreg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb791288-d908-47be-ab52-6603c2d6e340",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Target = 1 where the FRAPPE output differs from the logistic-regression decision\n",
    "# (probability >= 0.5), else 0. Assumes predict returns hard 0/1 labels -- TODO confirm.\n",
    "test_target = (y_FRAPPE_test != (proba_test_logreg>=0.5).astype(int)).astype(int)\n",
    "# NOTE(review): train_target is computed but never used in this cell.\n",
    "train_target = (y_FRAPPE_train != (proba_train_logreg>=0.5).astype(int)).astype(int)\n",
    "\n",
    "# NOTE(review): this overwrites the RBMD results that the RBMD section stored in `results_3`;\n",
    "# from here on `results_3` holds the FRAPPE sweep.\n",
    "results_3 = []\n",
    "\n",
    "\n",
    "# Sweep tree depths 2..10; each depth is repeated 5 times (no random_state is set).\n",
    "for max_depth in range(2, 11):  \n",
    "    \n",
    "    f1_scores = []\n",
    "    \n",
    "    for time in range(5):\n",
    "        classifier = DecisionTreeClassifier(max_depth=max_depth)\n",
    "        # NOTE(review): fit and scored on the same X_test (in-sample F1).\n",
    "        classifier.fit(X_test, test_target)\n",
    "\n",
    "        predictions = classifier.predict(X_test)\n",
    "\n",
    "        f1 = f1_score(test_target, predictions)\n",
    "\n",
    "        f1_scores.append(f1)\n",
    "    \n",
    "    # One summary record per depth: mean and standard deviation over the repeats.\n",
    "    results_3.append({\n",
    "        'max_depth': max_depth,\n",
    "        'f1_mean': np.mean(f1_scores),\n",
    "        'f1_std': np.std(f1_scores)\n",
    "    })\n",
    "\n",
    "# Print the results\n",
    "for result in results_3:\n",
    "    print(f\"Max Depth: {result['max_depth']}, F1 mean: {result['f1_mean']}, F1 std: {result['f1_std']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92da50ec-9605-4c25-8b71-fcf001565dbf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Pull the x axis (depths) and the mean F1 per depth out of the result lists;\n",
    "# `results_3` is expected to hold the FRAPPE sweep at this point.\n",
    "max_depths = [result['max_depth'] for result in results]\n",
    "f1_means = [result['f1_mean'] for result in results]\n",
    "f1_means_2 = [result['f1_mean'] for result in results_2]\n",
    "f1_means_3 = [result['f1_mean'] for result in results_3]\n",
    "\n",
    "# Create the plot\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(max_depths, f1_means, marker='o', linestyle='-', color='orange', label='COMMOD', linewidth=3)\n",
    "plt.plot(max_depths, f1_means_2, marker='o', linestyle='-', color='purple', label='COMMOD without min changes', linewidth=3)\n",
    "plt.plot(max_depths, f1_means_3, marker='o', linestyle='-', color='black', label='FRAPPE', linewidth=3)\n",
    "\n",
    "# Add labels and title\n",
    "plt.xlabel('Max Depth', size=24)\n",
    "plt.ylabel('F1 Score', size=24)\n",
    "plt.title('F1 Score vs. Depth for Decision Tree Classifier', size=26)\n",
    "plt.grid(False)\n",
    "# One tick per evaluated depth. The label size is set through tick_params so it applies to\n",
    "# every tick; xticks(size=...)/yticks(size=...) only restyle the labels existing at call time.\n",
    "plt.xticks(max_depths)\n",
    "plt.tick_params(axis='both', labelsize=20)\n",
    "plt.legend(fontsize=13)\n",
    "\n",
    "# NOTE(review): other figure cells save to this same file name, so the last one run wins.\n",
    "plt.savefig('F1_SCORE_lawschool_NOGRID.pdf')\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2313b39e-e603-4527-bebd-26cb54095e1c",
   "metadata": {},
   "source": [
    "### Zhang"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83708da5-9ab8-4f2c-885e-4d2f18f05b2d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from zhang import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d84ec4e-51fc-4429-b904-3f58c354ab6b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class NN_y(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NN_y, self).__init__()\n",
    "        self.fc1 = nn.Linear(X_train.shape[1], 64)\n",
    "        self.fc2 = nn.Linear(64, 32)\n",
    "        self.fc3 = nn.Linear(32, 1)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        #import pdb;pdb.set_trace() \n",
    "        x = self.fc1(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "class NN_s(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NN_s, self).__init__()\n",
    "        self.fc1_s = nn.Linear(1, 64)\n",
    "        self.fc2_s = nn.Linear(64, 32)\n",
    "        self.fc3_s = nn.Linear(32, 16)\n",
    "        self.fc4_s = nn.Linear(16, 1)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.fc1_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc2_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc3_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc4_s(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fdedd74-5a56-4adc-b64a-593b90482794",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Adversarial-debiasing configuration; GPU=None is passed here.\n",
    "adv = ADV_debiasing(\n",
    "    learning_rate=0.001,\n",
    "    batch_size=512,\n",
    "    lamb=5,\n",
    "    num_epochs=1000,\n",
    "    NN_y=NN_y,\n",
    "    NN_s=NN_s,\n",
    "    GPU=None,\n",
    ")\n",
    "\n",
    "# Train on the labels directly; the logistic-regression test predictions are passed\n",
    "# as y_biased_predictions.\n",
    "adv.train(\n",
    "    X_train, y_train, S_train,\n",
    "    X_test, y_test, S_test,\n",
    "    plot_losses=False,\n",
    "    dem_parity=True,\n",
    "    y_biased_predictions=predictions_test_logreg,\n",
    "    write=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdae76be-c413-404c-b954-71ab9d7704ab",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Flatten the model outputs to 1-D so they can be compared element-wise with the baseline.\n",
    "y_zhang_test = adv.predict(X_test).flatten()\n",
    "y_zhang_train = adv.predict(X_train).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d8ee932-fcdc-4b19-839b-d937e4e1ec83",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Target = 1 where the adversarial model's output differs from the logistic-regression decision\n",
    "# (probability >= 0.5), else 0. Assumes adv.predict returns hard 0/1 labels -- TODO confirm.\n",
    "test_target = (y_zhang_test != (proba_test_logreg>=0.5).astype(int)).astype(int)\n",
    "# NOTE(review): train_target is computed but never used in this cell.\n",
    "train_target = (y_zhang_train != (proba_train_logreg>=0.5).astype(int)).astype(int)\n",
    "\n",
    "results_4 = []\n",
    "\n",
    "\n",
    "# Sweep tree depths 2..10; each depth is repeated 5 times (no random_state is set).\n",
    "for max_depth in range(2, 11):  \n",
    "    \n",
    "    f1_scores = []\n",
    "    \n",
    "    for time in range(5):\n",
    "        classifier = DecisionTreeClassifier(max_depth=max_depth)\n",
    "        # NOTE(review): fit and scored on the same X_test (in-sample F1).\n",
    "        classifier.fit(X_test, test_target)\n",
    "\n",
    "        predictions = classifier.predict(X_test)\n",
    "\n",
    "        f1 = f1_score(test_target, predictions)\n",
    "\n",
    "        f1_scores.append(f1)\n",
    "    \n",
    "    # One summary record per depth: mean and standard deviation over the repeats.\n",
    "    results_4.append({\n",
    "        'max_depth': max_depth,\n",
    "        'f1_mean': np.mean(f1_scores),\n",
    "        'f1_std': np.std(f1_scores)\n",
    "    })\n",
    "\n",
    "# Print the results\n",
    "for result in results_4:\n",
    "    print(f\"Max Depth: {result['max_depth']}, F1 mean: {result['f1_mean']}, F1 std: {result['f1_std']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a86a3be8-a9a3-4ab6-885a-6a8518bcb379",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Pull the x axis (depths) and the mean F1 per depth out of the result lists;\n",
    "# `results_3` is expected to hold the FRAPPE sweep and `results_4` the adversarial one.\n",
    "max_depths = [result['max_depth'] for result in results]\n",
    "f1_means = [result['f1_mean'] for result in results]\n",
    "f1_means_2 = [result['f1_mean'] for result in results_2]\n",
    "f1_means_3 = [result['f1_mean'] for result in results_3]\n",
    "f1_means_4 = [result['f1_mean'] for result in results_4]\n",
    "\n",
    "# Create the plot\n",
    "plt.figure(figsize=(10, 7))\n",
    "plt.plot(max_depths, f1_means, marker='o', linestyle='-', color='#ff7f0e', label='COMMOD (Ours)', linewidth=3)\n",
    "#plt.plot(max_depths, f1_means_2, marker='o', linestyle='-', color='purple', label='COMMOD without min changes', linewidth=3)\n",
    "plt.plot(max_depths, f1_means_3, marker='o', linestyle='-', color='black', label='FRAPPÈ (Tifrea et al., 2024)', linewidth=3)\n",
    "plt.plot(max_depths, f1_means_4, marker='o', linestyle='-', color='#1f77b4', label='AdvDebias (Zhang et al., 2018)', linewidth=3)\n",
    "\n",
    "# Add labels and title\n",
    "plt.xlabel('Max Depth', size=24)\n",
    "plt.ylabel('F1 Score', size=24)\n",
    "plt.title('LAW SCHOOL', size=26)\n",
    "plt.grid(False)\n",
    "# One tick per evaluated depth. The label size is set through tick_params so it applies to\n",
    "# every tick; xticks(size=...)/yticks(size=...) only restyle the labels existing at call time.\n",
    "plt.xticks(max_depths)\n",
    "plt.tick_params(axis='both', labelsize=20)\n",
    "plt.legend(fontsize=13)\n",
    "\n",
    "# NOTE(review): earlier figure cells save to this same file name; this cell overwrites them.\n",
    "plt.savefig('F1_SCORE_lawschool_NOGRID.pdf')\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68f0e6f6-b44a-4403-b136-6e9506af9178",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Cells between the first COMMOD section and this one star-import other modules and redefine\n",
    "# NN_r with a feature-sized input layer. Re-bind the concept-based pieces here so this cell\n",
    "# does not depend on which of those cells ran last.\n",
    "# NOTE(review): assumes the `ratio_debiasing` variant taking num_concepts/NN_c/lamb_ratio is the\n",
    "# one exported by ratioMSE_autoenc, as used by the first training cell -- confirm.\n",
    "from ratioMSE_autoenc import ratio_debiasing\n",
    "\n",
    "\n",
    "class NN_r_concepts(nn.Module):\n",
    "    \"\"\"Ratio head on the concepts: one linear layer k -> 1 (same layout as the first NN_r).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(k, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.fc1(x.float())\n",
    "\n",
    "\n",
    "# Initialize an empty list to store DataFrames from each run\n",
    "df_list = []\n",
    "\n",
    "# Ten independent trainings (no seed is set); one weight table is collected per run.\n",
    "for i in range(10):\n",
    "    ratio_model = ratio_debiasing(learning_rate=0.001, batch_size=512, lamb_fair=2, lamb_sparse=1e-2, lamb_diversity=1, \n",
    "                             lamb_ratio=0.1, num_epochs=1000, num_concepts=k, NN_r=NN_r_concepts, NN_s=NN_s, NN_c=NN_c, GPU='cuda:0')\n",
    "\n",
    "    ratio_model.train(X_train, y_train, S_train, y_hat=proba_train_logreg, X_test=X_test, y_test=y_test, S_test=S_test, \n",
    "              y_hat_test=proba_test_logreg, plot_losses=False, dem_parity=True, write=False)\n",
    "    \n",
    "    # `m_NN_c` is read here as the trained concept network (features -> k concepts); the\n",
    "    # variable keeps its historical name `discriminator`.\n",
    "    discriminator = ratio_model.m_NN_c\n",
    "\n",
    "    discriminator_weights = discriminator.state_dict()\n",
    "    # fc1.weight has shape (k, n_features). Copy it to host memory first: .numpy() raises\n",
    "    # TypeError on a CUDA tensor and the model is configured with GPU='cuda:0'.\n",
    "    weights_matrix = discriminator_weights['fc1.weight'].detach().cpu().numpy()\n",
    "\n",
    "    # Long format: one row per (concept, input feature) weight.\n",
    "    weights_flat = weights_matrix.flatten()\n",
    "    indices = np.unravel_index(np.arange(weights_flat.size), weights_matrix.shape)\n",
    "\n",
    "    df = pd.DataFrame({\n",
    "        'NN_Weight': weights_flat,\n",
    "        'Concept': indices[0],\n",
    "        'InputFeatureIndex': indices[1]\n",
    "    })\n",
    "\n",
    "    df['InputFeature'] = df['InputFeatureIndex'].apply(lambda x: column_names[x])\n",
    "    df['AbsWeight'] = df['NN_Weight'].abs()\n",
    "    \n",
    "    # Append the DataFrame to the list\n",
    "    df_list.append(df)\n",
    "    # Show the 15 largest weights (by magnitude) of this run.\n",
    "    df_sorted = df.sort_values(by='AbsWeight', ascending=False)\n",
    "    print(df_sorted.head(15))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a50f1c3-af35-4627-b2ac-2573d8b0c347",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Concatenate all DataFrames from each run\n",
    "df_all_runs = pd.concat(df_list, ignore_index=True)\n",
    "\n",
    "# Calculate the average and standard deviation of the absolute weight for each feature and concept\n",
    "df_stats = df_all_runs.groupby(['Concept', 'InputFeature']).agg(\n",
    "    AvgAbsWeight=('AbsWeight', 'mean'),\n",
    "    StdAbsWeight=('AbsWeight', 'std')\n",
    ").reset_index()\n",
    "\n",
    "# Optionally, sort by Concept and then by InputFeature\n",
    "df_stats = df_stats.sort_values(by=['Concept', 'InputFeature'])\n",
    "\n",
    "print(df_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24f024d9-8055-407f-aa92-4bb0a5edb136",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "means = []\n",
    "stds = []\n",
    "not_sparse = []\n",
    "\n",
    "for i, item in enumerate(df_all_runs['InputFeature'].unique()):\n",
    "    means.append(df_all_runs[df_all_runs['InputFeature']==item]['AbsWeight'].mean())\n",
    "    stds.append(df_all_runs[df_all_runs['InputFeature']==item]['AbsWeight'].std())\n",
    "    not_sparse.append(np.sum(df_all_runs[df_all_runs['InputFeature']==item]['AbsWeight']>=0.01))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cf14c70-dfd9-420f-949d-6422c943ef4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DataFrame for plotting\n",
    "df_stats = pd.DataFrame({\n",
    "    'Feature': column_names,\n",
    "    'Mean': means,\n",
    "    'Std': stds,\n",
    "    'NotSparse': not_sparse\n",
    "})\n",
    "\n",
    "# Plot the histogram with error bars\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.bar(df_stats['Feature'], df_stats['Mean'], yerr=df_stats['Std'], capsize=5, color='skyblue', alpha=0.7)\n",
    "plt.xlabel('Feature')\n",
    "plt.ylabel('Mean Absolute Weight')\n",
    "plt.title('Mean Absolute Weight with Standard Deviation for Each Feature')\n",
    "plt.xticks(rotation=45)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ea8f407-b2b3-41b0-b8c2-fd49297e332c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Plot the histogram with error bars\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.bar(df_stats['Feature'], df_stats['NotSparse'], capsize=5, color='skyblue', alpha=0.7)\n",
    "plt.xlabel('Feature')\n",
    "plt.ylabel('Count')\n",
    "plt.title('Count over 20 concepts when each feature is selected')\n",
    "plt.xticks(rotation=45)\n",
    "\n",
    "# Customize y-ticks to display every 1\n",
    "max_count = df_stats['NotSparse'].max()\n",
    "plt.yticks(np.arange(0, max_count + 1, 2))\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13d0601d-6778-43c6-a0b1-0e26df6a0e63",
   "metadata": {},
   "source": [
    "Remove not when we had 0% of changes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bec3204-0f52-4139-a62f-e9719677c6f2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "means = []\n",
    "stds = []\n",
    "not_sparse = []\n",
    "\n",
    "for i, item in enumerate(df_all_runs_new['InputFeature'].unique()):\n",
    "    means.append(df_all_runs_new[df_all_runs_new['InputFeature']==item]['AbsWeight'].mean())\n",
    "    stds.append(df_all_runs_new[df_all_runs_new['InputFeature']==item]['AbsWeight'].std())\n",
    "    not_sparse.append(np.sum(df_all_runs_new[df_all_runs_new['InputFeature']==item]['AbsWeight']>=0.01))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "571a37c9-0730-481f-b2eb-ac75e25d8f28",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create a DataFrame for plotting\n",
    "df_stats = pd.DataFrame({\n",
    "    'Feature': column_names,\n",
    "    'Mean': means,\n",
    "    'Std': stds,\n",
    "    'NotSparse': not_sparse\n",
    "})\n",
    "\n",
    "# Plot the histogram with error bars\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.bar(df_stats['Feature'], df_stats['Mean'], yerr=df_stats['Std'], capsize=5, color='skyblue', alpha=0.7)\n",
    "plt.xlabel('Feature')\n",
    "plt.ylabel('Mean Absolute Weight')\n",
    "plt.title('Mean Absolute Weight with Standard Deviation for Each Feature')\n",
    "plt.xticks(rotation=45)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "322caa12-b301-4e30-8e15-9b55eabe6d75",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Plot the histogram with error bars\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.bar(df_stats['Feature'], df_stats['NotSparse'], capsize=5, color='skyblue', alpha=0.7)\n",
    "plt.xlabel('Feature')\n",
    "plt.ylabel('Count')\n",
    "plt.title('Count over 12 concepts when each feature is selected')\n",
    "plt.xticks(rotation=45)\n",
    "\n",
    "# Customize y-ticks to display every 1\n",
    "max_count = df_stats['NotSparse'].max()\n",
    "plt.yticks(np.arange(0, max_count + 1, 2))\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da581086-3219-4c14-90b5-25d12b343fcc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "concepts = ratio_model.predict_concepts(X_test)\n",
    "ratios = ratio_model.predict_ratios(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d70d97b8-802f-470f-bc16-c9b7cba27212",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8,6))\n",
    "\n",
    "plt.scatter(concepts[:, 0], concepts[:, 1], c=ratios, cmap='viridis', marker='o')\n",
    "plt.colorbar(label='Ratio r(X)')\n",
    "plt.axhline(0, color='black', linestyle='--', linewidth=1)\n",
    "plt.axvline(0, color='black', linestyle='--', linewidth=1)\n",
    "plt.xlabel('Concept1')\n",
    "plt.ylabel('Concept2')\n",
    "plt.title('Concepts and Ratio value')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba68a3b4-fdcc-448d-b952-19c7394115e9",
   "metadata": {},
   "source": [
    "### TOP-K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58a593c8-561f-4e0b-8e43-7292de84b636",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Assuming df_list is your list of DataFrames\n",
    "filtered_df_list = []\n",
    "\n",
    "for df in df_list:\n",
    "    top_rows = pd.concat([\n",
    "        df[df['Concept'] == 0].nlargest(3, 'AbsWeight'),\n",
    "        df[df['Concept'] == 1].nlargest(3, 'AbsWeight')\n",
    "    ])\n",
    "    filtered_df_list.append(top_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17578fcf-b872-48d9-8393-013f0afaa2c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Concatenate all DataFrames from each run\n",
    "df_all_runs = pd.concat(filtered_df_list, ignore_index=True)\n",
    "df_all_runs.sample(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ead2178-7a13-4dbf-9a40-443c025ba858",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Count occurrences of each InputFeature\n",
    "input_feature_counts = df_all_runs['InputFeature'].value_counts()\n",
    "\n",
    "# Plot the histogram\n",
    "plt.figure(figsize=(10, 6))\n",
    "input_feature_counts.plot(kind='bar')\n",
    "plt.xlabel('Feature')\n",
    "plt.ylabel('Count')\n",
    "plt.title('Histogram of TOP-3 features over 20 concepts')\n",
    "plt.xticks(rotation=90)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d09d6a-d10f-4f02-aa8f-367c2c9ebd0d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Assuming filtered_df_list is already created and contains the filtered DataFrames\n",
    "df_all_runs = pd.concat(filtered_df_list, ignore_index=True)\n",
    "\n",
    "# Create new columns for positive and negative weights\n",
    "df_all_runs['WeightSign'] = df_all_runs['NN_Weight'].apply(lambda x: 'Positive' if x >= 0 else 'Negative')\n",
    "\n",
    "# Separate counts for positive and negative weights\n",
    "positive_counts = df_all_runs[df_all_runs['WeightSign'] == 'Positive']['InputFeature'].value_counts()\n",
    "negative_counts = df_all_runs[df_all_runs['WeightSign'] == 'Negative']['InputFeature'].value_counts()\n",
    "\n",
    "# Combine counts into a single DataFrame\n",
    "counts_df = pd.DataFrame({'Positive': positive_counts, 'Negative': negative_counts}).fillna(0)\n",
    "\n",
    "# Plot the combined counts as a grouped bar plot\n",
    "counts_df.plot(kind='bar', figsize=(12, 6))\n",
    "plt.xlabel('Feature')\n",
    "plt.ylabel('Count')\n",
    "plt.title('Histogram of TOP-3 features with Positive and Negative Weights over 20 concepts')\n",
    "plt.xticks(rotation=90)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03fc8f85-2863-4013-aa63-92a749c6c335",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "feature_counts = {}\n",
    "\n",
    "for df in filtered_df_list:\n",
    "    # Get unique features\n",
    "    features = df['InputFeature'].unique()\n",
    "    for feature in features:\n",
    "        # Filter the rows corresponding to the current feature\n",
    "        feature_df = df[df['InputFeature'] == feature]\n",
    "        concepts = feature_df['Concept'].unique()\n",
    "        # Update counts in the dictionary\n",
    "        if feature not in feature_counts:\n",
    "            feature_counts[feature] = {'one_concept': 0, 'both_concepts': 0}\n",
    "        if len(concepts) == 1:\n",
    "            feature_counts[feature]['one_concept'] += 1\n",
    "        else:\n",
    "            feature_counts[feature]['both_concepts'] += 1\n",
    "            counts_df = pd.DataFrame(feature_counts).T.reset_index().rename(columns={'index': 'Feature'})\n",
    "\n",
    "# Plotting with the same style as the previous plot but with separate columns for each feature\n",
    "plt.figure(figsize=(8, 6))\n",
    "counts_df.set_index('Feature').plot(kind='bar', color=['#1f77b4', '#ff7f0e'])\n",
    "plt.title('Count of Features in One Concept vs. Both Concepts')\n",
    "plt.xlabel('Input Feature')\n",
    "plt.ylabel('Count')\n",
    "plt.legend(['One Concept', 'Both Concepts'])\n",
    "plt.xticks(rotation=45)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff3a4a07-b032-4c45-85a5-47964eb93f79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Initialize a dictionary to hold the counts for each feature with additional conditions\n",
    "feature_counts = {}\n",
    "\n",
    "# Iterate through each dataframe\n",
    "for df in filtered_df_list:\n",
    "    # Get unique features\n",
    "    features = df['InputFeature'].unique()\n",
    "    for feature in features:\n",
    "        # Filter the rows corresponding to the current feature\n",
    "        feature_df = df[df['InputFeature'] == feature]\n",
    "        concepts = feature_df['Concept'].unique()\n",
    "        # Update counts in the dictionary\n",
    "        if feature not in feature_counts:\n",
    "            feature_counts[feature] = {'one_concept': 0, 'both_concepts_same_sign': 0, 'both_concepts_diff_sign': 0}\n",
    "        if len(concepts) == 1:\n",
    "            feature_counts[feature]['one_concept'] += 1\n",
    "        else:\n",
    "            weights = feature_df['NN_Weight']\n",
    "            same_sign = all(weights > 0) or all(weights < 0)\n",
    "            if same_sign:\n",
    "                feature_counts[feature]['both_concepts_same_sign'] += 1\n",
    "            else:\n",
    "                feature_counts[feature]['both_concepts_diff_sign'] += 1\n",
    "\n",
    "# Convert the dictionary to a DataFrame for plotting\n",
    "counts_df = pd.DataFrame(feature_counts).T.reset_index().rename(columns={'index': 'Feature'})\n",
    "\n",
    "# Calculate the combined counts for both concepts\n",
    "counts_df['both_concepts'] = counts_df['both_concepts_same_sign'] + counts_df['both_concepts_diff_sign']\n",
    "\n",
    "# Plotting with two separate columns for each feature\n",
    "plt.figure(figsize=(10, 7))\n",
    "bar_width = 0.35\n",
    "index = np.arange(len(counts_df['Feature']))\n",
    "\n",
    "# Bar plot for one_concept\n",
    "plt.bar(index, counts_df['one_concept'], bar_width, label='One Concept', color='#1f77b4')\n",
    "\n",
    "# Stacked bar plot for both_concepts_same_sign and both_concepts_diff_sign\n",
    "plt.bar(index + bar_width, counts_df['both_concepts_same_sign'], bar_width, label='Both Concepts Same Sign', color='#ff7f0e')\n",
    "plt.bar(index + bar_width, counts_df['both_concepts_diff_sign'], bar_width, bottom=counts_df['both_concepts_same_sign'], label='Both Concepts Different Sign', color='#2ca02c')\n",
    "\n",
    "# Plot details\n",
    "plt.xlabel('Input Feature')\n",
    "plt.ylabel('Count')\n",
    "plt.title('Count of Features in One Concept vs. Both Concepts (Same/Different Sign)')\n",
    "plt.xticks(index + bar_width / 2, counts_df['Feature'], rotation=45)\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1624b755-7581-49d1-9804-a1f2236866a3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "\n",
    "pairs = list(combinations(column_names, 2))\n",
    "\n",
    "pair_dictionary = {}\n",
    "\n",
    "for pair in pairs:\n",
    "    pair_dictionary[pair]=0\n",
    "    \n",
    "\n",
    "for df in filtered_df_list:\n",
    "    # Iterate over each concept (0 and 1)\n",
    "    for concept in [0, 1]:\n",
    "        # Select rows where Concept equals the current concept\n",
    "        subset = df[df['Concept'] == concept]\n",
    "        \n",
    "        # Extract feature names in this subset\n",
    "        feature_names = subset['InputFeature'].unique()\n",
    "        \n",
    "        # Generate all pairs of feature names\n",
    "        pairs = list(combinations(feature_names, 2))\n",
    "        \n",
    "        for pair in pairs:\n",
    "            if pair in list(pair_dictionary.keys()):\n",
    "                pair_dictionary[pair]+=1\n",
    "            else:\n",
    "                pair_dictionary[tuple(reversed(pair))]+=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0819e29e-cfb3-4dd7-879f-58737c610092",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Extract unique feature names\n",
    "features = sorted({feat for pair in pair_dictionary.keys() for feat in pair})\n",
    "\n",
    "# Create a matrix to store the counts\n",
    "num_features = len(features)\n",
    "count_matrix = np.zeros((num_features, num_features), dtype=int)\n",
    "\n",
    "# Fill the matrix with pair counts\n",
    "for i, feat1 in enumerate(features):\n",
    "    for j, feat2 in enumerate(features):\n",
    "        if (feat1, feat2) in pair_dictionary:\n",
    "            count_matrix[i, j] = pair_dictionary[(feat1, feat2)]\n",
    "        elif (feat2, feat1) in pair_dictionary:\n",
    "            count_matrix[i, j] = pair_dictionary[(feat2, feat1)]\n",
    "\n",
    "# Plotting the heatmap with annotations\n",
    "plt.figure(figsize=(10, 8))\n",
    "plt.imshow(count_matrix, cmap='viridis', interpolation='nearest')\n",
    "\n",
    "# Add annotations\n",
    "for i in range(num_features):\n",
    "    for j in range(num_features):\n",
    "        plt.text(j, i, str(count_matrix[i, j]), ha='center', va='center', color='white')\n",
    "\n",
    "plt.colorbar(label='Count of Pairs')\n",
    "plt.title('Pairwise Feature Co-occurrence in Same Concept')\n",
    "plt.xticks(np.arange(num_features), features, rotation=45)\n",
    "plt.yticks(np.arange(num_features), features)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58ee0524-1aec-4714-a287-936c1f88ae03",
   "metadata": {},
   "source": [
    "We want to have an activation value of each concept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e545dec9-2158-4067-9935-76db7c8a7d90",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "k = 2 #number of \"concepts\"\n",
    "\n",
    "class NN_r(nn.Module):\n",
    "#Network to predict ratio from concepts\n",
    "    def __init__(self):\n",
    "        super(NN_r, self).__init__()\n",
    "        self.fc1 = nn.Linear(k, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.fc1(x)\n",
    "        return x\n",
    "    \n",
    "class NN_c(nn.Module):\n",
    "#Network to predict concepts from features\n",
    "    def __init__(self):\n",
    "        super(NN_c, self).__init__()\n",
    "        self.fc1 = nn.Linear(X_train.shape[1], k)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.fc1(x)\n",
    "        x = torch.sigmoid(x)\n",
    "        return x\n",
    "\n",
    "class NN_s(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NN_s, self).__init__()\n",
    "        self.fc1_s = nn.Linear(1, 64)\n",
    "        self.fc2_s = nn.Linear(64, 32)\n",
    "        self.fc3_s = nn.Linear(32, 16)\n",
    "        self.fc4_s = nn.Linear(16, 1)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.fc1_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc2_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc3_s(x)\n",
    "        x = torch.relu(x)\n",
    "        x = self.fc4_s(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41aca5b4-4374-4637-8a9e-71a05b99bcd8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Initialize an empty list to store DataFrames from each run\n",
    "df_list = []\n",
    "\n",
    "for i in range(10):\n",
    "    ratio_model = ratio_debiasing(learning_rate=0.001, batch_size=512, lamb_fair=2, lamb_sparse=1e-2, lamb_diversity=1, \n",
    "                             lamb_ratio=0.1, num_epochs=1000, num_concepts=k, NN_r=NN_r, NN_s=NN_s, NN_c=NN_c, GPU='cuda:0')\n",
    "\n",
    "    ratio_model.train(X_train, y_train, S_train, y_hat=proba_train_logreg, X_test=X_test, y_test=y_test, S_test=S_test, \n",
    "              y_hat_test=proba_test_logreg, plot_losses=False, dem_parity=True, write=False)\n",
    "    \n",
    "    # Assuming the discriminator is stored in an attribute called 'discriminator'\n",
    "    discriminator = ratio_model.m_NN_c\n",
    "\n",
    "    discriminator_weights = discriminator.state_dict()\n",
    "    weights_matrix = discriminator_weights['fc1.weight'].numpy()\n",
    "\n",
    "    weights_flat = weights_matrix.flatten()\n",
    "    indices = np.unravel_index(np.arange(weights_flat.size), weights_matrix.shape)\n",
    "\n",
    "    df = pd.DataFrame({\n",
    "        'NN_Weight': weights_flat,\n",
    "        'Concept': indices[0],\n",
    "        'InputFeatureIndex': indices[1]\n",
    "    })\n",
    "\n",
    "    df['InputFeature'] = df['InputFeatureIndex'].apply(lambda x: column_names[x])\n",
    "    df['AbsWeight'] = df['NN_Weight'].abs()\n",
    "    \n",
    "    # Append the DataFrame to the list\n",
    "    df_list.append(df)\n",
    "    df_sorted = df.sort_values(by='AbsWeight', ascending=False)\n",
    "    print(df_sorted.head(15))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2d5c552-a810-465a-8efe-51e3a627f428",
   "metadata": {},
   "source": [
    "### Trying different parameters..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29840ae4-61bb-410e-b53a-fb9a79266f16",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "ratio_model = ratio_debiasing(learning_rate=0.001, batch_size=512, lamb_fair=3, lamb_sparse=1e-2, lamb_diversity=0.1, \n",
    "                         lamb_ratio=0.01, num_epochs=1000, num_concepts=k, NN_r=NN_r, NN_s=NN_s, NN_c=NN_c, GPU='cuda:0')\n",
    "\n",
    "ratio_model.train(X_train, y_train, S_train, y_hat=proba_train_logreg, X_test=X_test, y_test=y_test, S_test=S_test, \n",
    "          y_hat_test=proba_test_logreg, plot_losses=False, dem_parity=True, write=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "827593b7-f612-4518-b4cb-8e0534786416",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "concepts = ratio_model.predict_concepts(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26ea441-ddb4-4005-8fe6-a07080362258",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "np.min(concepts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9ae6f08-679c-49a0-bd99-f178eff64c7d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
