{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7ad2aea-b906-4dd8-b23f-b08282e7ba1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from scipy.stats import cumfreq\n",
    "import os\n",
    "from numpy.linalg import svd, det\n",
    "from sklearn.utils import resample\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import SGDClassifier, SGDRegressor, Perceptron, LogisticRegression\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39c69867-d342-4e20-914d-fe98d6f8133d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def TPR(Y_p, Y_t, tag):\n",
    "    \"\"\"Fraction of the samples whose true label is `tag` that were also predicted as `tag`.\n",
    "\n",
    "    Y_p (predictions) and Y_t (true labels) are indexed with a boolean mask,\n",
    "    so both are expected to be NumPy arrays of the same length.\n",
    "    \"\"\"\n",
    "    is_tag = (Y_t == tag)\n",
    "    predicted_for_tag = Y_p[is_tag]\n",
    "    return np.mean(predicted_for_tag == tag)\n",
    "\n",
    "def get_TPR_gap(Y_pred, Y_true, Y_gender):\n",
    "    \"\"\"Per-class true-positive-rate gap between the two groups, plus the RMS of those gaps.\n",
    "\n",
    "    Returns a tuple (tpr_gaps, rms_gap). tpr_gaps maps every unique value of Y_true to\n",
    "    TPR(group coded 1) - TPR(group coded 0); rms_gap is the root mean square of those values.\n",
    "    \"\"\"\n",
    "    group_0 = (Y_gender == 0)  # coded as male in this notebook\n",
    "    group_1 = (Y_gender == 1)  # coded as female in this notebook\n",
    "    pred_0, true_0 = Y_pred[group_0], Y_true[group_0]\n",
    "    pred_1, true_1 = Y_pred[group_1], Y_true[group_1]\n",
    "\n",
    "    tpr_gaps = {}\n",
    "    for tag in np.unique(Y_true):\n",
    "        tpr_gaps[tag] = TPR(pred_1, true_1, tag) - TPR(pred_0, true_0, tag)\n",
    "\n",
    "    squared_gaps = [gap**2 for gap in tpr_gaps.values()]\n",
    "    rms_gap = np.sqrt(np.mean(squared_gaps))\n",
    "\n",
    "    return tpr_gaps, rms_gap\n",
    "\n",
    "def DP_metric(y_pred, sensitive_attr):\n",
    "    \"\"\"Mean absolute demographic-parity gap over the predicted classes.\n",
    "\n",
    "    For every class value that occurs in `y_pred`, the rate at which it is predicted in the\n",
    "    group with sensitive_attr == 0 is compared with the rate in the group with\n",
    "    sensitive_attr == 1; the absolute differences are averaged over the classes.\n",
    "    \"\"\"\n",
    "    y_pred = np.asarray(y_pred)\n",
    "    sensitive_attr = np.asarray(sensitive_attr)\n",
    "\n",
    "    preds_group_0 = y_pred[sensitive_attr == 0]\n",
    "    preds_group_1 = y_pred[sensitive_attr == 1]\n",
    "\n",
    "    gaps = [\n",
    "        np.abs(np.mean(preds_group_0 == c) - np.mean(preds_group_1 == c))\n",
    "        for c in np.unique(y_pred)\n",
    "    ]\n",
    "\n",
    "    return np.mean(gaps)\n",
    "\n",
    "def MCDP_metric(Y_pred, Y_true, Y_gender, num_bins=10000):\n",
    "    \"\"\"RMS over score columns of the largest gap between the two groups' empirical score CDFs.\n",
    "\n",
    "    Y_pred is a 2-D array with one score column per class. Scores are binned on [0, 1] into\n",
    "    `num_bins` equal-width bins and the cumulative counts of the groups Y_gender == 0 and\n",
    "    Y_gender == 1 are normalised by the group sizes. Only the first K-1 of the K columns are\n",
    "    evaluated, and Y_true is accepted but not used.\n",
    "    \"\"\"\n",
    "    n_columns = Y_pred.shape[1]\n",
    "    in_group_0 = (Y_gender == 0)\n",
    "    in_group_1 = (Y_gender == 1)\n",
    "\n",
    "    def group_cdf(scores):\n",
    "        # Cumulative bin counts scaled to a CDF; an empty group yields an all-zero curve.\n",
    "        binned = cumfreq(scores, numbins=num_bins, defaultreallimits=(0, 1))\n",
    "        if len(scores) > 0:\n",
    "            return binned.cumcount / len(scores)\n",
    "        return np.zeros(num_bins)\n",
    "\n",
    "    squared_gaps = []\n",
    "    for k in range(n_columns - 1):\n",
    "        column = Y_pred[:, k]\n",
    "        cdf_0 = group_cdf(column[in_group_0])\n",
    "        cdf_1 = group_cdf(column[in_group_1])\n",
    "        largest_gap = np.max(np.abs(cdf_0 - cdf_1))\n",
    "        squared_gaps.append(largest_gap**2)\n",
    "\n",
    "    return np.sqrt(np.mean(squared_gaps))\n",
    "\n",
    "def pairwise_distances(x, y=None):\n",
    "    \"\"\"Squared Euclidean distances between the rows of `x` and the rows of `y`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : array-like, shape (n, d)\n",
    "        One point per row; a 1-D input is treated as a single row.\n",
    "    y : array-like, shape (m, d), optional\n",
    "        Second set of points, promoted to 2-D like `x`. Defaults to `x`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray, shape (n, m)\n",
    "        Entry (i, j) is ||x_i - y_j||^2. No square root is taken.\n",
    "    \"\"\"\n",
    "    x = np.atleast_2d(x)\n",
    "    # Promote y the same way as x so that a single 1-D point works for either argument.\n",
    "    y = x if y is None else np.atleast_2d(y)\n",
    "    x_norm = np.sum(x ** 2, axis=1).reshape(-1, 1)\n",
    "    y_norm = np.sum(y ** 2, axis=1).reshape(1, -1)\n",
    "    dist = x_norm + y_norm - 2 * np.dot(x, y.T)\n",
    "    # The expansion |x|^2 + |y|^2 - 2 x.y can round to tiny negative numbers;\n",
    "    # a squared distance is never negative, so clamp at zero.\n",
    "    return np.maximum(dist, 0)\n",
    "\n",
    "def dcor(X, Y):\n",
    "    \"\"\"Mean elementwise product of the double-centred pairwise_distances matrices of X and Y.\n",
    "\n",
    "    Both matrices come from pairwise_distances and are centred on both sides with\n",
    "    H = I - (1/n) * ones. The value is returned as is, without any normalisation by the\n",
    "    corresponding quantities of X with itself or Y with itself.\n",
    "    \"\"\"\n",
    "    n = X.shape[0]\n",
    "    centering = np.eye(n) - np.ones((n, n)) / n\n",
    "    centred_X = centering @ pairwise_distances(X) @ centering\n",
    "    centred_Y = centering @ pairwise_distances(Y) @ centering\n",
    "    return np.sum(centred_X * centred_Y) / (n * n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "929b51f1-0ac3-4be5-a7e8-bdd45d619b11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_bank_marketing_data(path=\"./data/bank\", sensitive_attribute=\"age\"):\n",
    "    \"\"\"Read the bank marketing CSV and split it into features, label and sensitive attribute.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    path : str\n",
    "        Directory that contains ``bank-additional-full.csv`` (semicolon separated).\n",
    "    sensitive_attribute : str\n",
    "        Numeric column used as the protected attribute; it is binarised as ``value >= 25``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X : pandas.DataFrame\n",
    "        Feature columns. The listed categorical columns have dtype ``string``; every other\n",
    "        column is cast to ``float32``. ``y``, ``age`` and `sensitive_attribute` are removed.\n",
    "    y : pandas.DataFrame\n",
    "        One column ``y`` with \"yes\" replaced by 1 and \"no\" replaced by 0.\n",
    "    s : pandas.DataFrame\n",
    "        One column holding 1 where the sensitive attribute is >= 25 and 0 otherwise.\n",
    "    \"\"\"\n",
    "    df = pd.read_csv(os.path.join(path, \"bank-additional-full.csv\"), sep=\";\")\n",
    "    categorical_features = [\"job\", \"marital\", \"education\", \"default\", \"housing\", \"loan\", \"contact\", \"month\", \"day_of_week\", \"poutcome\"]\n",
    "\n",
    "    y = df[\"y\"].replace({\"yes\": 1, \"no\": 0}).to_frame()\n",
    "    s = (df[sensitive_attribute] >= 25).astype(int).to_frame()\n",
    "\n",
    "    # The label and age are always removed (as before). The chosen sensitive attribute is\n",
    "    # removed as well, so it cannot leak into the features when it is not \"age\".\n",
    "    dropped = list(dict.fromkeys([\"y\", \"age\", sensitive_attribute]))\n",
    "    X = df.drop(columns=dropped)\n",
    "\n",
    "    X[categorical_features] = X[categorical_features].astype(\"string\")\n",
    "\n",
    "    # Every column that is not categorical (string) is numeric: store it as float32.\n",
    "    numeric_cols = X.select_dtypes(exclude=\"string\").columns\n",
    "    X[numeric_cols] = X[numeric_cols].astype(\"float32\")\n",
    "\n",
    "    return X, y, s\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "758e1dd9-b79b-48e1-8b70-15756f5018e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the features (X_all), the 0/1 label (Y_all) and the binarised age indicator (Z_all).\n",
    "X_all, Y_all, Z_all = load_bank_marketing_data(path=\"./data/bank\", sensitive_attribute=\"age\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d756860-9edb-4091-b4eb-3c78e24dc2c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode every string-typed column; the remaining columns pass through unchanged.\n",
    "categorical_cols = X_all.select_dtypes(\"string\").columns\n",
    "if len(categorical_cols) > 0:\n",
    "    X_all = pd.get_dummies(X_all, columns=categorical_cols)\n",
    "\n",
    "\n",
    "# Number of feature columns after encoding and number of distinct label values.\n",
    "n_features = X_all.shape[1]\n",
    "n_classes = len(np.unique(Y_all))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4a096fb-f9d4-47cc-9bc5-0b2bdc00fd53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardise every feature column: subtract its mean and divide by its standard deviation.\n",
    "# NOTE(review): the statistics are taken over the full data set, i.e. before the train/test split.\n",
    "col_mean = np.mean(X_all, axis=0)\n",
    "col_std = np.std(X_all, axis=0)\n",
    "# A constant column has zero spread; dividing by it would turn the whole column into NaN,\n",
    "# so such columns are divided by 1 and simply end up as all zeros after centring.\n",
    "col_std = col_std.where(col_std > 0, 1.0)\n",
    "X_all = (X_all - col_mean) / col_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5903881a-eff5-4ea7-b4c3-6f8610f15e7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 80/20 split with a fixed seed; features become a float32 array, label and sensitive attribute flat 1-D arrays.\n",
    "X_train, X_test, Y_train, Y_test, Z_train, Z_test = train_test_split(\n",
    "        X_all.values.astype('float32'), Y_all.values.reshape(-1), Z_all.values.reshape(-1), test_size=0.2, random_state=42\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb6a7bf-1cc7-4cb3-b760-1072f945464b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: logistic regression for the label Y on all features, then test accuracy and the\n",
    "# two group-fairness gaps (DP on hard predictions, MCDP on predicted probabilities).\n",
    "# NOTE(review): no random_state is passed, so the saga fit may differ between runs - confirm whether that matters.\n",
    "clf_Y = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', max_iter = 100)\n",
    "\n",
    "clf_Y.fit(X_train, Y_train)\n",
    "Y_pred_orig = clf_Y.predict(X_test)\n",
    "print(f\"Original Accuracy {clf_Y.score(X_test, Y_test)}.\\n\")\n",
    "print(f\"Original DP Gaps {DP_metric(Y_pred_orig, Z_test)}.\\n\") \n",
    "print(f\"Original MCDP Gaps {MCDP_metric(clf_Y.predict_proba(X_test), Y_test, Z_test, num_bins=100)}.\\n\") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f22e1235-eeec-4393-be1f-cfcf98c16bbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the hard training labels by the baseline model's class probabilities (one column per class);\n",
    "# these serve as the response for the MSAVE / ladle steps further down.\n",
    "Y_train_new = clf_Y.predict_proba(X_train)\n",
    "Y_train_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d1e1742-9330-4f60-82f3-11df0f345298",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of the first probability column. A bare `hist` is not defined by any import in\n",
    "# this notebook, so the histogram is drawn through pyplot.\n",
    "plt.hist(Y_train_new[:, 0])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54ea0c4d-7126-4367-b28b-349d691868d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the second probability column of the baseline Y model on the training set.\n",
    "plt.hist(Y_train_new[:,1], bins=10, edgecolor='black')\n",
    "plt.xlabel('Value')\n",
    "plt.ylabel('Frequency')\n",
    "plt.title('Histogram of Data')\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce251bc4-4af4-4013-b0ea-534a86dce995",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "023835f0-583b-4109-b0ed-5c77f0104c9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same model family, now predicting the sensitive attribute Z from the features; the printed\n",
    "# test accuracy shows how well Z can be recovered from X.\n",
    "# NOTE(review): no random_state is passed, so the saga fit may differ between runs - confirm whether that matters.\n",
    "clf_Z = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', max_iter = 100)\n",
    "\n",
    "clf_Z.fit(X_train, Z_train)\n",
    "print(clf_Z.score(X_test, Z_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "790b0776-c9e1-4dca-9742-1b3cb40d800d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class probabilities of the Z model on the training set, used as the response for the Z-side MSAVE / ladle steps.\n",
    "Z_train_new = clf_Z.predict_proba(X_train)\n",
    "Z_train_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65ac13e7-edd2-498b-8718-45f2fcca95a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the first probability column of the Z model on the training set.\n",
    "plt.hist(Z_train_new[:,0], bins=10, edgecolor='black')\n",
    "plt.xlabel('Value')\n",
    "plt.ylabel('Frequency')\n",
    "plt.title('Histogram of Data')\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a2a3a37-b770-458e-bc52-a38816553b20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show (number of training rows, number of feature columns).\n",
    "np.shape(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c0266d5-1207-4fb8-bb41-1187a2329f32",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d4bed85-357a-4fe3-9f3f-e1fb0f882701",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inv_sqrt_matrix(A, eps=1e-10):\n",
    "    \"\"\"Inverse square root of a symmetric matrix, computed from its eigendecomposition.\n",
    "\n",
    "    Eigenvalues below `eps` are replaced by `eps` before inverting, which keeps the result\n",
    "    finite when A is singular or has slightly negative eigenvalues.\n",
    "    \"\"\"\n",
    "    w, Q = np.linalg.eigh(A)\n",
    "    w = np.where(w < eps, eps, w)\n",
    "    inv_root = 1.0 / np.sqrt(w)\n",
    "    # Scaling the columns of Q is the same as multiplying by diag(inv_root) on the right.\n",
    "    return (Q * inv_root) @ Q.T\n",
    "\n",
    "def inv_matrix(A, eps=1e-10):\n",
    "    \"\"\"Inverse of a symmetric matrix, computed from its eigendecomposition.\n",
    "\n",
    "    Eigenvalues below `eps` are replaced by `eps` before inverting, so a singular input\n",
    "    produces a large but finite result instead of a division by zero.\n",
    "    \"\"\"\n",
    "    w, Q = np.linalg.eigh(A)\n",
    "    w = np.where(w < eps, eps, w)\n",
    "    reciprocal = 1.0 / w\n",
    "    # Scaling the columns of Q is the same as multiplying by diag(reciprocal) on the right.\n",
    "    return (Q * reciprocal) @ Q.T\n",
    "    \n",
    "def covariance(X):\n",
    "    \"\"\"Sample covariance matrix of the columns of X (rows are observations), divided by n - 1.\"\"\"\n",
    "    centred = X - np.mean(X, axis=0)\n",
    "    dof = centred.shape[0] - 1\n",
    "    return (centred.T @ centred) / dof"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38789491-9434-4101-b846-02b8935a6009",
   "metadata": {},
   "outputs": [],
   "source": [
    "def low_rank_approximation(A, r):\n",
    "    \"\"\"Rank-r approximation of A obtained by keeping the r largest singular triplets.\"\"\"\n",
    "    left, singular_values, right_t = np.linalg.svd(A, full_matrices=False)\n",
    "    # Scaling the kept left vectors column-wise equals multiplying by diag(singular_values[:r]).\n",
    "    scaled_left = left[:, :r] * singular_values[:r]\n",
    "    return scaled_left @ right_t[:r, :]\n",
    "\n",
    "def eig_decompose_sorted_filtered(M, tol=1e-20):\n",
    "    \"\"\"Eigendecompose M, drop near-zero components and sort the rest in descending order.\n",
    "\n",
    "    The general eig solver is used; eigenvalues are reduced to the absolute value of their real\n",
    "    part and eigenvectors to their real part. Components whose value is not above `tol` are\n",
    "    removed. Returns (values, vectors) with the eigenvectors stored as columns.\n",
    "    \"\"\"\n",
    "    raw_vals, raw_vecs = np.linalg.eig(M)\n",
    "    abs_real_vals = np.abs(np.real(raw_vals))\n",
    "    real_vecs = np.real(raw_vecs)\n",
    "\n",
    "    keep = abs_real_vals > tol\n",
    "    abs_real_vals, real_vecs = abs_real_vals[keep], real_vecs[:, keep]\n",
    "\n",
    "    order = np.argsort(-abs_real_vals)\n",
    "    return abs_real_vals[order], real_vecs[:, order]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4074828d-7453-4463-b30b-31f1533d8c04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.covariance import LedoitWolf\n",
    "\n",
    "def MSAVE(X, Y, n_slices=None, random_state=None):\n",
    "    \"\"\"Sliced-average-variance style candidate matrix for a multivariate response.\n",
    "\n",
    "    X is multiplied by the inverse square root of its sample covariance, the rows are grouped\n",
    "    into slices by k-means clustering of Y, and for every slice with at least two rows the\n",
    "    matrix (I - Cov_slice)^2 is accumulated, weighted by the slice's share of the samples.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : ndarray, shape (n, p)\n",
    "        Predictors, one observation per row.\n",
    "    Y : ndarray, shape (n, d)\n",
    "        Response used for slicing (2-D, as required by KMeans).\n",
    "    n_slices : int, optional\n",
    "        Number of k-means clusters. Defaults to p + 1, the value that was previously fixed.\n",
    "    random_state : int, RandomState instance or None, optional\n",
    "        Forwarded to KMeans. None (default) keeps the previous unseeded behaviour.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray, shape (p, p)\n",
    "        The accumulated candidate matrix, expressed in the whitened coordinates of X.\n",
    "    \"\"\"\n",
    "    n, p = X.shape\n",
    "    if n_slices is None:\n",
    "        n_slices = p + 1\n",
    "    M_total = np.zeros((p, p))\n",
    "\n",
    "    # Whitening without centring is sufficient here: np.cov below centres within each slice.\n",
    "    X_scaled = X @ inv_sqrt_matrix(covariance(X))\n",
    "\n",
    "    labels = KMeans(n_clusters=n_slices, random_state=random_state).fit(Y).labels_\n",
    "\n",
    "    identity = np.eye(p)\n",
    "    for j in range(n_slices):\n",
    "        idx = (labels == j)\n",
    "        slice_size = np.sum(idx)\n",
    "        # A covariance needs at least two observations; smaller slices contribute nothing.\n",
    "        if slice_size < 2:\n",
    "            continue\n",
    "        delta = identity - np.cov(X_scaled[idx], rowvar=False)\n",
    "        M_total += (slice_size / n) * delta @ delta\n",
    "    return M_total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe74a10f-7035-486a-b89f-cfa3f9c6e2a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from numpy.linalg import svd, det\n",
    "from sklearn.utils import resample\n",
    "\n",
    "def ladle_estimator(M_Y, X, Y):\n",
    "    \"\"\"Pick a dimension by combining eigenvalue size with bootstrap eigenvector agreement.\n",
    "\n",
    "    M_Y is the candidate matrix for the full sample. Each of the 30 bootstrap rounds rebuilds\n",
    "    the matrix with MSAVE on a resample of (X, Y) and compares its leading eigenvectors with\n",
    "    those of M_Y. Returns the index that minimises the sum of the two normalised curves.\n",
    "\n",
    "    NOTE(review): the ladle estimator is usually described with a variability term of the form\n",
    "    1 - |det(...)| and an extra entry for dimension 0; this code accumulates the smallest\n",
    "    singular value instead and starts at one direction - confirm that this variant is intended.\n",
    "    \"\"\"\n",
    "    n, p = X.shape\n",
    "    # Largest number of leading directions that is examined.\n",
    "    r = p-1   \n",
    "    D_Y, U_Y = eig_decompose_sorted_filtered(M_Y)\n",
    "    Bhat_Y = U_Y[:, :(r+1)]\n",
    "    lam_Y = D_Y[:r]\n",
    "    # NOTE(review): if the tolerance filter removes eigenvalues, lam_Y can be shorter than fn_Y\n",
    "    # and the sum at the end would no longer line up - verify against real inputs.\n",
    "\n",
    "    fn_Y = np.zeros(r)\n",
    "    n_boot = 30\n",
    "\n",
    "    for _ in range(n_boot):\n",
    "        # Resample with replacement, capped at 5000 rows per bootstrap round.\n",
    "        Xs, Ys = resample(X, Y, replace=True, n_samples = np.min([n,5000]))\n",
    "        M_Ys = MSAVE(Xs, Ys)\n",
    "        _, U_star = eig_decompose_sorted_filtered(M_Ys)\n",
    "        Bstar_Y = U_star[:, :(r+1)]\n",
    "\n",
    "        for i in range(0, r):\n",
    "            # Smallest singular value of the cross-product between the first i+1 bootstrap\n",
    "            # eigenvectors and the first i+1 full-sample eigenvectors.\n",
    "            MM = Bstar_Y[:, :(i+1)].T @ Bhat_Y[:, :(i+1)]\n",
    "            _, d_MM, _ = svd(MM)\n",
    "            fn_Y[i] += np.min(d_MM)\n",
    "           \n",
    "    # Average over the bootstrap rounds, then rescale both curves by 1 + their own sum.\n",
    "    fn_Y /= n_boot\n",
    "    phi0_Y = lam_Y / (1 + np.sum(lam_Y))\n",
    "    f0_Y = fn_Y / (1 + np.sum(fn_Y))\n",
    "    d_Y = np.argmin(f0_Y + phi0_Y)\n",
    "    return d_Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deed9415-2526-40ec-8aae-e80f898a7340",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5590969d-7079-46e8-931f-ca15ae585809",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate matrix for the Y probabilities on the full training set; its singular values D_Y are shown for inspection.\n",
    "M_Y = MSAVE(X_train, Y_train_new)\n",
    "U_Y, D_Y, _ = svd(M_Y)\n",
    "D_Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7465102b-b5df-4ddc-91a3-c918e172eb6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate how many leading directions of M_Y to keep, using the bootstrap-based ladle_estimator.\n",
    "dim_Y = ladle_estimator(M_Y, X_train, Y_train_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d289f624-0248-4760-910c-4f6c61ec7643",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the estimated dimension for Y.\n",
    "dim_Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2381d30b-07f2-4ac5-b407-366254042ee6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fa22a9d-f8a6-47c4-ab0d-c9f243f50048",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate matrix for the Z probabilities on the full training set; its singular values D_Z are shown for inspection.\n",
    "M_Z = MSAVE(X_train, Z_train_new)\n",
    "U_Z, D_Z, _ = svd(M_Z)\n",
    "D_Z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cea49f83-fbca-43d3-8c67-086517e4002d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate how many leading directions of M_Z to keep, using the bootstrap-based ladle_estimator.\n",
    "dim_Z = ladle_estimator(M_Z, X_train, Z_train_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d933d64b-ddae-405b-b9ef-11a6778cdd9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the estimated dimension for Z.\n",
    "dim_Z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e940d43-65ec-4871-a3a3-780375631640",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample covariance matrix of the training features.\n",
    "CovX = covariance(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1dd3a4c-9db2-423c-accf-7ab24d3fb08b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of feature columns in the training matrix.\n",
    "p = X_train.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5c4ba29-3e51-4a55-89ec-8402980f18b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the feature dimension p.\n",
    "p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecbeeaca-6e1a-4b06-ae66-763113ace27b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def projection(V):\n",
    "    \"\"\"Projection matrix V (V^T V)^{-1} V^T onto the column space of V.\n",
    "\n",
    "    The Gram matrix is inverted with inv_matrix, which floors small eigenvalues.\n",
    "    \"\"\"\n",
    "    gram = V.T @ V\n",
    "    return V @ inv_matrix(gram) @ V.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c4caaec-9810-4326-a735-0a23e4ffb153",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The leading dim_Z eigenvectors of M_Z are taken as the basis V_Z of the Z-related subspace.\n",
    "# NOTE(review): M_Z_r is not referenced again in the cells visible here - confirm it is still needed.\n",
    "M_Z_r = low_rank_approximation(M_Z, dim_Z)\n",
    "S_Z, V_Z = eig_decompose_sorted_filtered(M_Z)\n",
    "V_Z = V_Z[:,:dim_Z]\n",
    "#P_Z = projection(inv_sqrt_matrix(CovX) @ V_Z)\n",
    "P_Z = projection(V_Z)\n",
    "# Q_Z projects onto the orthogonal complement of span(V_Z); the p - dim_Z leading\n",
    "# eigenvectors of its rank-truncated version are kept as the complement basis V_Z_c.\n",
    "Q_Z = np.eye(p) - P_Z\n",
    "Q_Z_r = low_rank_approximation(Q_Z, p-dim_Z)\n",
    "S_Z_c, V_Z_c = eig_decompose_sorted_filtered(Q_Z_r)\n",
    "V_Z_c = V_Z_c[:,:p-dim_Z]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "424a90e7-8cae-4a9f-bcab-e4c03522ff43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate matrix for Y computed on the features expressed in the V_Z_c basis (the complement of the Z subspace).\n",
    "M_Y_only = MSAVE(X_train @ V_Z_c, Y_train_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "790809f0-e8fb-4098-983e-208b362515c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The bootstrap inside ladle_estimator has to rebuild the same matrix it is compared against.\n",
    "# M_Y_only was computed from Y_train_new, so Y_train_new (not Z_train_new) is the response here.\n",
    "dim_Y_only = ladle_estimator(M_Y_only, X_train @ V_Z_c, Y_train_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ab5e577-0f8b-42b2-a39b-195494d211b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "M_Y_only_r = low_rank_approximation(M_Y_only, dim_Y_only)\n",
    "S_Y_only, V_Y_only = eig_decompose_sorted_filtered(M_Y_only_r)\n",
    "# Keep only the leading dim_Y_only eigenvectors, as is done for V_Z and V_YZ. The remaining\n",
    "# columns belong to the numerically-zero eigenvalues of the rank-truncated matrix; keeping them\n",
    "# would make the Y-only directions span the whole Z-complement instead of the Y-specific part.\n",
    "V_Y_only = V_Y_only[:, :dim_Y_only]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e365a9a8-0be2-4711-902c-97853868bf83",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a22f0b9-a3ac-4218-a58c-4d012c3be913",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the Y-only eigenvectors from V_Z_c coordinates back to the original p-dimensional feature space.\n",
    "direction_Y_only = V_Z_c @ V_Y_only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a799f69-2431-415c-a7ad-63fb2fc297f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the estimated number of Y-only directions.\n",
    "dim_Y_only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7fdc868-32f7-4747-b3cd-8d43609b8f60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eigenvectors of direction_Y_only @ direction_Y_only.T, sorted by eigenvalue, give a basis in feature space for the span of direction_Y_only.\n",
    "S_Y_only_after, V_Y_only_after = eig_decompose_sorted_filtered(direction_Y_only @ direction_Y_only.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "486badad-83f4-45b3-9956-1222429fb88f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "296f1998-05ce-48a4-bbce-0632e077760c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directions inside the Z subspace that are also informative for Y (the shared part).\n",
    "M_YZ = MSAVE(X_train @ V_Z, Y_train_new)\n",
    "# The bootstrap inside ladle_estimator has to rebuild the same matrix it is compared against.\n",
    "# M_YZ was computed from Y_train_new, so Y_train_new (not Z_train_new) is the response here.\n",
    "dim_YZ = ladle_estimator(M_YZ, X_train @ V_Z, Y_train_new)\n",
    "M_YZ_r = low_rank_approximation(M_YZ, dim_YZ)\n",
    "S_YZ, V_YZ = eig_decompose_sorted_filtered(M_YZ_r)\n",
    "V_YZ = V_YZ[:,:dim_YZ]\n",
    "# Map the shared directions from V_Z coordinates back to the original feature space.\n",
    "direction_YZ = V_Z @ V_YZ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52f1cfb4-cb80-4240-ac92-03f44d7d5ede",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the estimated number of directions shared between Y and Z.\n",
    "dim_YZ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05ca619b-aacb-4a01-8e0a-d8511ce12792",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Sweep over the number of retained directions: first add Y-only directions one at a time,\n",
    "# then add shared Y/Z directions on top of all Y-only ones. For every setting the features are\n",
    "# projected, a Y classifier and a Z classifier are refitted, and accuracy/fairness are recorded.\n",
    "results = []\n",
    "\n",
    "for tt in range(1, dim_Y_only + dim_YZ + 1):\n",
    "    row = {}\n",
    "    if tt < dim_Y_only + 1:\n",
    "        print(f\"Include: {tt} only Y dimension and 0 shared dimension\")\n",
    "        # Slice ends are exclusive: stopping at dim_Y_only selects exactly tt columns.\n",
    "        # (Ending at dim_Y_only + 1 pulled in one column more than is logged and recorded.)\n",
    "        Directions = projection(V_Y_only_after[:, (dim_Y_only - tt):dim_Y_only])\n",
    "        row['Y_only_dim'] = tt\n",
    "        row['YZ_shared_dim'] = 0\n",
    "    else:\n",
    "        shared_dim = tt - dim_Y_only\n",
    "        print(f\"Include: {shared_dim} shared dimension\")\n",
    "        Directions = projection(V_Y_only_after[:, :dim_Y_only]) + projection(direction_YZ[:, (dim_YZ - shared_dim):dim_YZ])\n",
    "        row['Y_only_dim'] = dim_Y_only\n",
    "        row['YZ_shared_dim'] = shared_dim\n",
    "\n",
    "    # Project once per setting and reuse the result for both classifiers and all metrics.\n",
    "    X_train_proj = X_train @ Directions\n",
    "    X_test_proj = X_test @ Directions\n",
    "\n",
    "    # Train Y classifier\n",
    "    clf_Y_after = LogisticRegression(solver=\"saga\", multi_class='multinomial', max_iter=20)\n",
    "    clf_Y_after.fit(X_train_proj, Y_train)\n",
    "    Y_pred_after = clf_Y_after.predict(X_test_proj)\n",
    "\n",
    "    # Accuracy and fairness metrics\n",
    "    acc = clf_Y_after.score(X_test_proj, Y_test)\n",
    "    dp_gap = DP_metric(Y_pred_after, Z_test)\n",
    "    mcdp_gap = MCDP_metric(clf_Y_after.predict_proba(X_test_proj), Y_test, Z_test, num_bins=100)\n",
    "\n",
    "    row['Y_acc'] = acc\n",
    "    row['DP_gap'] = dp_gap\n",
    "    row['MCDP_gap'] = mcdp_gap\n",
    "\n",
    "    print(f\"After Accuracy: {acc}\")\n",
    "    print(f\"After DP Gap: {dp_gap}\")\n",
    "    print(f\"After MCDP Gap: {mcdp_gap}\")\n",
    "\n",
    "    # Train Z classifier: how well the sensitive attribute can still be recovered after projection\n",
    "    clf_Z_after = LogisticRegression(warm_start=True, penalty='l2', solver=\"saga\", multi_class='multinomial', max_iter=20)\n",
    "    clf_Z_after.fit(X_train_proj, Z_train)\n",
    "    z_acc = clf_Z_after.score(X_test_proj, Z_test)\n",
    "    row['Z_acc'] = z_acc\n",
    "    print(f\"Z accuracy: {z_acc}\")\n",
    "\n",
    "    results.append(row)\n",
    "\n",
    "# Save to CSV\n",
    "results_df = pd.DataFrame(results)\n",
    "results_df.to_csv(\"Bank_results.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ceb38c7-13eb-42d9-9d9f-edc80aca9c45",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a568720-95f2-4344-b2d6-a9cc06c050e0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd27b779-d3ca-4c5d-b8fb-e2bd3963fb73",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "540eeb26-132b-469c-951c-e5a2928fc341",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97fc28fc-050e-4f84-9e03-c69c93f099b0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9141650-775e-4770-ab1d-a1fca4cf7f9b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e993908-d429-48a3-9601-b6a68285b6e6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
