{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import os\n",
                "import warnings\n",
                "warnings.filterwarnings('ignore')\n",
                "\n",
                "import numpy as np\n",
                "import pandas as pd\n",
                "import seaborn as sns\n",
                "\n",
                "sns.set_theme(palette=\"colorblind\", style='whitegrid', font_scale=1.25)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "METRIC_MAP = {\n",
                "    \"Demographic_Parity\": \"Dem_Parity\",\n",
                "    \"Equal_Opportunity\": \"Eq_Opp\", \n",
                "    \"Predictive_Equality\": \"Pred_Eq\",\n",
                "}"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "result_files = [os.path.join(dp, f) for dp, _, fn in os.walk(os.path.expanduser(\"fairness_trials\")) for f in fn]\n",
                "\n",
                "all_results = []\n",
                "for rfile in result_files:\n",
                "    arr = rfile.split(\"/\")[1:]\n",
                "    dataset_sens_attrs, client_type, _, metric, cp_method = arr\n",
                "    if \"Fitzpatrick\" not in dataset_sens_attrs and client_type != \"8_clients\":\n",
                "        continue\n",
                "    df = pd.read_csv(rfile, header=0)\n",
                "    df = df.dropna()\n",
                "\n",
                "    df['c'] = df[\"c\"].astype(str)\n",
                "    df = df[df['c'] != '0.05'] # remove c=0.05\n",
                "    df = df[df['sketch_method'] != 'ddsketch'] # remove ddsketch runs\n",
                "    df[\"client_type\"] = client_type\n",
                "    df[\"dataset_sens_attrs\"] = dataset_sens_attrs\n",
                "    df[\"cp_method\"] = cp_method.split(\".\")[0]\n",
                "    df[\"fairness_metric\"] = METRIC_MAP[metric] if metric in METRIC_MAP else metric\n",
                "\n",
                "    formulation_columns = df.filter(regex=r'^client_formulation_\\d+$', axis='columns')\n",
                "\n",
                "    # Count the number of 0s and 1s in each of these columns\n",
                "    count_zeros = (formulation_columns == 0).sum(axis=1)\n",
                "    count_ones = (formulation_columns == 1).sum(axis=1)\n",
                "\n",
                "    df[\"num_client_formulation_0\"] = count_zeros\n",
                "    df[\"num_client_formulation_1\"] = count_ones\n",
                "\n",
                "    all_results.append(df)\n",
                "\n",
                "res_df = pd.concat(all_results, ignore_index=True)\n",
                "\n",
                "res_df = res_df.drop(columns=[\"Unnamed: 0\"])\n",
                "\n",
                "res_df[\"efficiency\"] = res_df[\"efficiency\"].apply(lambda x: round(float(x[len(\"tensor(\"):-1]), 3))\n",
                "res_df[\"coverage\"] = res_df[\"coverage\"].apply(lambda x: round(float(x[len(\"tensor(\"):-1]), 3))\n",
                "res_df[\"violation\"] = res_df[\"violation\"].apply(lambda x: round(float(x[len(\"tensor(\"):-1]), 3))\n",
                "\n",
                "res_df[\"base_efficiency\"] = res_df[\"base_efficiency\"].apply(lambda x: round(float(x[len(\"tensor(\"):-1]), 3))\n",
                "res_df[\"base_coverage\"] = res_df[\"base_coverage\"].apply(lambda x: round(float(x[len(\"tensor(\"):-1]), 3))\n",
                "res_df[\"base_violation\"] = res_df[\"base_violation\"].apply(lambda x: round(float(x[len(\"tensor(\"):-1]), 3))\n",
                "\n",
                "our_res_df = res_df[['c', 'efficiency', 'coverage',\n",
                "       'violation', 'dataset_sens_attrs', 'client_type', 'use_mle', 'cp_method',\n",
                "       'fairness_metric','num_client_formulation_0', 'num_client_formulation_1']]\n",
                "\n",
                "our_res_df[\"type\"] = \"ours\"\n",
                "\n",
                "base_res_df = res_df[['c', 'base_efficiency', 'base_coverage',\n",
                "       'base_violation', 'dataset_sens_attrs', 'client_type', 'use_mle', 'cp_method',\n",
                "       'fairness_metric','num_client_formulation_0', 'num_client_formulation_1']].rename(columns={'base_efficiency': 'efficiency', 'base_coverage': 'coverage', 'base_violation': 'violation'})\n",
                "base_res_df[\"type\"] = \"base\"\n",
                "\n",
                "res_df = pd.concat([our_res_df,base_res_df],ignore_index=True)\n",
                "\n",
                "\n",
                "grouped_res_df = res_df.groupby((['c', 'dataset_sens_attrs', 'client_type', 'use_mle', 'cp_method', 'fairness_metric', 'type', 'num_client_formulation_0', 'num_client_formulation_1']))\n",
                "\n",
                "mean_res_df = grouped_res_df.mean(numeric_only=True).reset_index()\n",
                "\n",
                "res_df = mean_res_df[['c', 'dataset_sens_attrs', 'client_type', 'use_mle', 'cp_method', 'fairness_metric', 'type', 'num_client_formulation_0', 'num_client_formulation_1']]\n",
                "\n",
                "res_df[\"eff_mean\"] = mean_res_df[\"efficiency\"].round(4)\n",
                "\n",
                "res_df[\"violation_mean\"] = mean_res_df[\"violation\"].round(4)\n",
                "\n",
                "res_df[\"formulation\"] = np.where(\n",
                "    res_df[\"num_client_formulation_0\"] == 0, \"Enhanced_Privacy\", \n",
                "    np.where(res_df[\"num_client_formulation_1\"] == 0, \"Communication_Efficient\", \"Hybrid\")\n",
                ")\n",
                "res_df[\"form_split\"] = res_df[\"num_client_formulation_0\"].astype(str) + \",\" + res_df[\"num_client_formulation_1\"].astype(str)\n",
                "res_df = res_df.drop(columns=[\"num_client_formulation_0\", \"num_client_formulation_1\"])\n",
                "\n",
                "res_df"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "res_df[(res_df[\"dataset_sens_attrs\"] == \"Fitzpatrick\") & (res_df[\"fairness_metric\"] == \"Pred_Eq\") & (res_df[\"cp_method\"] == \"aps\") & (res_df[\"client_type\"] == \"8_clients\") & (res_df[\"use_mle\"] == False) & (res_df[\"c\"] == \"0.1\")]"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import os\n",
                "os.makedirs(\"./figures\", exist_ok=True)\n",
                "os.makedirs(\"./processed_trials\", exist_ok=True)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "melted_res_df = res_df.melt([\"c\", \"dataset_sens_attrs\", \"client_type\", \"use_mle\", \"cp_method\", \"fairness_metric\", \"type\", \"formulation\"], var_name=\"stat_type\", value_name=\"stats\")\n",
                "for dataset_sens_attrs in res_df[\"dataset_sens_attrs\"].unique():\n",
                "    for use_mle in [False, True]:\n",
                "        df = melted_res_df[(melted_res_df[\"dataset_sens_attrs\"] == dataset_sens_attrs) & (melted_res_df[\"use_mle\"] == use_mle)]\n",
                "        df[\"index\"] = f\"{dataset_sens_attrs}_{use_mle}\"\n",
                "        results_table = df.groupby([\"index\", 'c', \"fairness_metric\", \"client_type\", \"cp_method\", \"type\",\"formulation\", \"stat_type\"])[\"stats\"].apply(lambda x: x.values[0]).reset_index().pivot(index=[\"index\", \"c\", \"fairness_metric\"], columns=[\"formulation\",\"client_type\", \"cp_method\", \"type\", \"stat_type\"], values=[\"stats\"]).droplevel(0, axis=0)\n",
                "        results_table.to_excel(f\"./processed_trials/{dataset_sens_attrs}_{use_mle}.xlsx\")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import matplotlib.pyplot as plt\n",
                "import seaborn as sns\n",
                "\n",
                "METRICS = list(map(lambda x: METRIC_MAP[x], [\"Demographic_Parity\", \"Equal_Opportunity\", \"Predictive_Equality\"]))\n",
                "\n",
                "C = [0.1, 0.15, 0.2]\n",
                "for dataset_sens_attrs in res_df[\"dataset_sens_attrs\"].unique():\n",
                "    for client_type in res_df[res_df[\"dataset_sens_attrs\"] == dataset_sens_attrs][\"client_type\"].unique():\n",
                "        for use_mle in [False, True]:\n",
                "            df = res_df[(res_df[\"dataset_sens_attrs\"] == dataset_sens_attrs) & (res_df[\"use_mle\"] == use_mle) & (res_df[\"client_type\"] == client_type)]\n",
                "            \n",
                "            eff_lim = 4\n",
                "            if \"ACSEducation\" in dataset_sens_attrs:\n",
                "                eff_lim = 6\n",
                "            elif \"Fitzpatrick\" in dataset_sens_attrs:\n",
                "                eff_lim = 9\n",
                "            else:\n",
                "                eff_lim = 4\n",
                "            \n",
                "            if \"Pokec\" in dataset_sens_attrs and not use_mle:\n",
                "                CP_METHODS = [\"aps\", \"daps\", \"raps_rand_no_mod\"]\n",
                "            else:\n",
                "                CP_METHODS = [\"aps\", \"raps_rand_no_mod\"]\n",
                "\n",
                "            filtered_df = df[(df[\"fairness_metric\"].isin(METRICS)) & (df[\"cp_method\"].isin(CP_METHODS))]\n",
                "            for formulation in filtered_df[\"formulation\"].dropna().unique():\n",
                "                fdf = filtered_df[filtered_df[\"formulation\"] == formulation].copy()\n",
                "                if fdf.empty:\n",
                "                    continue\n",
                "\n",
                "                # Plot efficiency\n",
                "                grid = sns.catplot(\n",
                "                    data=fdf[fdf[\"type\"] == \"ours\"],\n",
                "                    x=\"fairness_metric\",\n",
                "                    y=\"eff_mean\",\n",
                "                    hue=\"c\",\n",
                "                    col=\"cp_method\",\n",
                "                    kind=\"bar\",\n",
                "                    order=METRICS,\n",
                "                    ci=None,\n",
                "                    sharex=False,\n",
                "                )\n",
                "\n",
                "                grid.set(xlabel=\"\")\n",
                "                grid.set(ylabel=\"Efficiency\")\n",
                "                grid.set(ylim=(0, eff_lim))\n",
                "                for key, ax in grid.axes_dict.items():\n",
                "                  ax.set_title(key.split(\"_\")[0].upper())\n",
                "                sns.move_legend(grid, \"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=len(C))\n",
                "                grid.legend.set_title(\"Closeness Criterion:\")\n",
                "\n",
                "                # Plot violation\n",
                "                grid2 = sns.catplot(\n",
                "                    data=fdf[fdf[\"type\"] == \"ours\"],\n",
                "                    x=\"fairness_metric\",\n",
                "                    y=\"violation_mean\",\n",
                "                    hue=\"c\",\n",
                "                    col=\"cp_method\",\n",
                "                    kind=\"bar\",\n",
                "                    order=METRICS,\n",
                "                    ci=None,\n",
                "                    sharex=False,\n",
                "                )\n",
                "                grid2.set(xlabel=\"\")\n",
                "                grid2.set(ylabel=\"Actual Fairness Disparity\")\n",
                "\n",
                "                grid2.set(ylim=(0, 0.5), yticks=[0, 0.1, 0.2, 0.3, 0.4, 0.5])\n",
                "                for key, ax in grid2.axes_dict.items():\n",
                "                  ax.set_title(key.split(\"_\")[0].upper())\n",
                "                sns.move_legend(grid2, \"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=len(C))\n",
                "                grid2.legend.set_title(\"Closeness Criterion:\")\n",
                "\n",
                "                # Add the base FCP method lines and points\n",
                "\n",
                "                axes_eff = grid.axes.flatten() if isinstance(grid.axes, np.ndarray) else [grid.axes]\n",
                "                axes_viol = grid2.axes.flatten() if isinstance(grid2.axes, np.ndarray) else [grid2.axes]\n",
                "                print(dataset_sens_attrs, client_type, use_mle, formulation)\n",
                "                axes_idx = 0\n",
                "                for i in range(len(CP_METHODS)):\n",
                "                    base_cp_df = fdf[(fdf[\"type\"] == \"base\") & (fdf[\"cp_method\"] == CP_METHODS[i])]\n",
                "                    if base_cp_df.empty:\n",
                "                        continue\n",
                "                    \n",
                "                    # Efficiency base line\n",
                "                    ax_eff = axes_eff[axes_idx]\n",
                "                    patches = ax_eff.patches\n",
                "\n",
                "                    for ix, a in enumerate(patches):\n",
                "                        x_start = a.get_x()\n",
                "                        width = a.get_width()\n",
                "                        if width == 0: \n",
                "                            continue\n",
                "\n",
                "                        eff_mean = base_cp_df[(base_cp_df[\"fairness_metric\"] == METRICS[ix % len(METRICS)]) & (base_cp_df[\"client_type\"] == client_type)][\"eff_mean\"].mean()\n",
                "                        ax_eff.plot([x_start, x_start + width], [eff_mean] * 2, '-', c='k')\n",
                "\n",
                "                    # Violation base line and marker\n",
                "                    ax_viol = axes_viol[axes_idx]\n",
                "                    patches = ax_viol.patches\n",
                "\n",
                "                    for ix, a in enumerate(patches):\n",
                "                        x_start = a.get_x()\n",
                "                        width = a.get_width()\n",
                "                        if width == 0: \n",
                "                            continue\n",
                "\n",
                "                        violation_mean = base_cp_df[(base_cp_df[\"fairness_metric\"] == METRICS[ix % len(METRICS)]) & (base_cp_df[\"client_type\"] == client_type)][\"violation_mean\"].mean()\n",
                "                        ax_viol.plot([x_start, x_start + width], [violation_mean] * 2, '--', c='k')\n",
                "                        ax_viol.plot(x_start + width / 2, float(C[ix // len(METRICS)]), 'o', c='k', markersize=3.5)\n",
                "                    \n",
                "                    axes_idx += 1\n",
                "                    \n",
                "                plt.tight_layout()\n",
                "                form_slug = str(formulation).replace(\" \", \"_\")\n",
                "                # Save the figure\n",
                "                grid.savefig(f\"./figures/{str(dataset_sens_attrs).strip('_')}_{client_type}_{use_mle}_{form_slug}_efficiency.pdf\", bbox_inches='tight')\n",
                "                grid2.savefig(f\"./figures/{str(dataset_sens_attrs).strip('_')}_{client_type}_{use_mle}_{form_slug}_violation.pdf\", bbox_inches='tight')\n",
                "\n",
                "                plt.show() \n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import matplotlib.pyplot as plt\n",
                "import seaborn as sns\n",
                "\n",
                "METRICS = list(map(lambda x: METRIC_MAP[x], [\"Demographic_Parity\", \"Predictive_Equality\"]))\n",
                "\n",
                "C = [0.1, 0.15, 0.2]\n",
                "CP_METHODS = [\"raps_rand_no_mod\"]\n",
                "CLIENT_TYPES = [\"1_clients\", \"2_clients\", \"4_clients\", \"8_clients\"]\n",
                "eff_lim = 9\n",
                "formulation = \"Communication_Efficient\"\n",
                "use_mle = False\n",
                "fitz_df = res_df[\n",
                "    (res_df[\"dataset_sens_attrs\"] == \"Fitzpatrick\") & (res_df[\"cp_method\"].isin(CP_METHODS)) & (res_df[\"client_type\"].isin(CLIENT_TYPES))\n",
                "    & (res_df[\"formulation\"] == formulation) & (res_df[\"fairness_metric\"].isin(METRICS))\n",
                "    ]\n",
                "\n",
                "fitz_df\n",
                "\n",
                "## Combine eff_mean and violation_mean into a single string \"eff / viol\"\n",
                "fitz_df[\"eff/viol\"] = fitz_df.apply(lambda x: f\"{x['eff_mean']:.3f} / {x['violation_mean']:.3f}\", axis=1)\n",
                "\n",
                "display(fitz_df)\n",
                "\n",
                "for use_mle in [False, True]:\n",
                "    df = fitz_df[(fitz_df[\"use_mle\"] == use_mle)]\n",
                "\n",
                "    filtered_df = df[(df[\"fairness_metric\"].isin(METRICS)) & (df[\"cp_method\"].isin(CP_METHODS))]\n",
                "    fdf = filtered_df[filtered_df[\"formulation\"] == formulation].copy()\n",
                "    if fdf.empty:\n",
                "        continue\n",
                "\n",
                "    # Plot efficiency\n",
                "    grid = sns.catplot(\n",
                "        data=fdf[fdf[\"type\"] == \"ours\"],\n",
                "        x=\"fairness_metric\",\n",
                "        y=\"eff_mean\",\n",
                "        hue=\"c\",\n",
                "        col=\"client_type\",\n",
                "        kind=\"bar\",\n",
                "        order=METRICS,\n",
                "        ci=None,\n",
                "        sharex=False,\n",
                "    )\n",
                "\n",
                "    grid.set(xlabel=\"\")\n",
                "    grid.set(ylabel=\"Efficiency\")\n",
                "    grid.set(ylim=(0, eff_lim))\n",
                "    for key, ax in grid.axes_dict.items():\n",
                "      ax.set_title(key.split(\"_\")[0].upper())\n",
                "    sns.move_legend(grid, \"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=len(C))\n",
                "    grid.legend.set_title(\"Closeness Criterion:\")\n",
                "\n",
                "    # Plot violation\n",
                "    grid2 = sns.catplot(\n",
                "        data=fdf[fdf[\"type\"] == \"ours\"],\n",
                "        x=\"fairness_metric\",\n",
                "        y=\"violation_mean\",\n",
                "        hue=\"c\",\n",
                "        col=\"client_type\",\n",
                "        kind=\"bar\",\n",
                "        order=METRICS,\n",
                "        ci=None,\n",
                "        sharex=False,\n",
                "    )\n",
                "    grid2.set(xlabel=\"\")\n",
                "    grid2.set(ylabel=\"Actual Fairness Disparity\")\n",
                "\n",
                "    grid2.set(ylim=(0, 0.5), yticks=[0, 0.1, 0.2, 0.3, 0.4, 0.5])\n",
                "    for key, ax in grid2.axes_dict.items():\n",
                "      ax.set_title(key.split(\"_\")[0].upper())\n",
                "    sns.move_legend(grid2, \"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=len(C))\n",
                "    grid2.legend.set_title(\"Closeness Criterion:\")\n",
                "\n",
                "    # Add the base FCP method lines and points\n",
                "\n",
                "    axes_eff = grid.axes.flatten() if isinstance(grid.axes, np.ndarray) else [grid.axes]\n",
                "    axes_viol = grid2.axes.flatten() if isinstance(grid2.axes, np.ndarray) else [grid2.axes]\n",
                "    print(\"Fitzpatrick comparison\", use_mle, formulation)\n",
                "    axes_idx = 0\n",
                "    for i in range(len(CLIENT_TYPES)):\n",
                "        base_cp_df = fdf[(fdf[\"type\"] == \"base\") & (fdf[\"cp_method\"] == CP_METHODS[0])]\n",
                "        if base_cp_df.empty:\n",
                "            continue\n",
                "        \n",
                "        # Efficiency base line\n",
                "        ax_eff = axes_eff[axes_idx]\n",
                "        patches = ax_eff.patches\n",
                "\n",
                "        for ix, a in enumerate(patches):\n",
                "            x_start = a.get_x()\n",
                "            width = a.get_width()\n",
                "            if width == 0: \n",
                "                continue\n",
                "\n",
                "            eff_mean = base_cp_df[(base_cp_df[\"fairness_metric\"] == METRICS[ix % len(METRICS)]) & (base_cp_df[\"client_type\"] == CLIENT_TYPES[i])][\"eff_mean\"].mean()\n",
                "            ax_eff.plot([x_start, x_start + width], [eff_mean] * 2, '-', c='k')\n",
                "\n",
                "        # Violation base line and marker\n",
                "        ax_viol = axes_viol[axes_idx]\n",
                "        patches = ax_viol.patches\n",
                "\n",
                "        for ix, a in enumerate(patches):\n",
                "            x_start = a.get_x()\n",
                "            width = a.get_width()\n",
                "            if width == 0: \n",
                "                continue\n",
                "\n",
                "            violation_mean = base_cp_df[(base_cp_df[\"fairness_metric\"] == METRICS[ix % len(METRICS)]) & (base_cp_df[\"client_type\"] == client_type)][\"violation_mean\"].mean()\n",
                "            ax_viol.plot([x_start, x_start + width], [violation_mean] * 2, '--', c='k')\n",
                "            ax_viol.plot(x_start + width / 2, float(C[ix // len(METRICS)]), 'o', c='k', markersize=3.5)\n",
                "        \n",
                "        axes_idx += 1\n",
                "        \n",
                "    plt.tight_layout()\n",
                "    form_slug = str(formulation).replace(\" \", \"_\")\n",
                "    # Save the figure\n",
                "    grid.savefig(f\"./figures/Fitzpatrick_client_comparison_{use_mle}_{form_slug}_efficiency.pdf\", bbox_inches='tight')\n",
                "    grid2.savefig(f\"./figures/Fitzpatrick_client_comparison_{use_mle}_{form_slug}_violation.pdf\", bbox_inches='tight')\n",
                "\n",
                "    plt.show() \n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import matplotlib.pyplot as plt\n",
                "import seaborn as sns\n",
                "\n",
                "METRICS = list(map(lambda x: METRIC_MAP[x], [\"Demographic_Parity\", \"Equal_Opportunity\", \"Predictive_Equality\"]))\n",
                "\n",
                "C = [0.1, 0.15, 0.2]\n",
                "cp_method = \"raps_rand_no_mod\"\n",
                "client_type = \"continental_all\"\n",
                "eff_lim = 6\n",
                "formulation = \"Communication_Efficient\"\n",
                "USE_MLE = [False, True]\n",
                "    \n",
                "fdf = res_df[(res_df[\"dataset_sens_attrs\"] == \"ACSEducation\") & (res_df[\"fairness_metric\"].isin(METRICS)) & (res_df[\"cp_method\"] == cp_method) & (res_df[\"client_type\"] == client_type) & (res_df[\"formulation\"] == formulation)]\n",
                "\n",
                "# Plot efficiency\n",
                "grid = sns.catplot(\n",
                "    data=fdf[fdf[\"type\"] == \"ours\"],\n",
                "    x=\"fairness_metric\",\n",
                "    y=\"eff_mean\",\n",
                "    hue=\"c\",\n",
                "    col=\"use_mle\",\n",
                "    kind=\"bar\",\n",
                "    order=METRICS,\n",
                "    ci=None,\n",
                "    sharex=False,\n",
                ")\n",
                "\n",
                "grid.set(xlabel=\"\")\n",
                "grid.set(ylabel=\"Efficiency\")\n",
                "grid.set(ylim=(0, eff_lim))\n",
                "for key, ax in grid.axes_dict.items():\n",
                "  ax.set_title(f\"MLE: {key}\")\n",
                "sns.move_legend(grid, \"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=len(C))\n",
                "grid.legend.set_title(\"Closeness Criterion:\")\n",
                "\n",
                "# Plot violation\n",
                "grid2 = sns.catplot(\n",
                "    data=fdf[fdf[\"type\"] == \"ours\"],\n",
                "    x=\"fairness_metric\",\n",
                "    y=\"violation_mean\",\n",
                "    hue=\"c\",\n",
                "    col=\"use_mle\",\n",
                "    kind=\"bar\",\n",
                "    order=METRICS,\n",
                "    ci=None,\n",
                "    sharex=False,\n",
                ")\n",
                "grid2.set(xlabel=\"\")\n",
                "grid2.set(ylabel=\"Actual Fairness Disparity\")\n",
                "\n",
                "grid2.set(ylim=(0, 0.5), yticks=[0, 0.1, 0.2, 0.3, 0.4, 0.5])\n",
                "for key, ax in grid2.axes_dict.items():\n",
                "  ax.set_title(f\"MLE: {key}\")\n",
                "sns.move_legend(grid2, \"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=len(C))\n",
                "grid2.legend.set_title(\"Closeness Criterion:\")\n",
                "\n",
                "# Add the base FCP method lines and points\n",
                "\n",
                "axes_eff = grid.axes.flatten() if isinstance(grid.axes, np.ndarray) else [grid.axes]\n",
                "axes_viol = grid2.axes.flatten() if isinstance(grid2.axes, np.ndarray) else [grid2.axes]\n",
                "print(\"ACSEducation MLE comparison\", use_mle, formulation)\n",
                "axes_idx = 0\n",
                "for i in range(len(USE_MLE)):\n",
                "    base_cp_df = fdf[(fdf[\"type\"] == \"base\")]\n",
                "    if base_cp_df.empty:\n",
                "        continue\n",
                "    \n",
                "    # Efficiency base line\n",
                "    ax_eff = axes_eff[axes_idx]\n",
                "    patches = ax_eff.patches\n",
                "\n",
                "    for ix, a in enumerate(patches):\n",
                "        x_start = a.get_x()\n",
                "        width = a.get_width()\n",
                "        if width == 0: \n",
                "            continue\n",
                "\n",
                "        eff_mean = base_cp_df[(base_cp_df[\"fairness_metric\"] == METRICS[ix % len(METRICS)]) & (base_cp_df[\"client_type\"] == client_type) &  (base_cp_df[\"use_mle\"] == USE_MLE[i])][\"eff_mean\"].mean()\n",
                "        ax_eff.plot([x_start, x_start + width], [eff_mean] * 2, '-', c='k')\n",
                "\n",
                "    # Violation base line and marker\n",
                "    ax_viol = axes_viol[axes_idx]\n",
                "    patches = ax_viol.patches\n",
                "\n",
                "    for ix, a in enumerate(patches):\n",
                "        x_start = a.get_x()\n",
                "        width = a.get_width()\n",
                "        if width == 0: \n",
                "            continue\n",
                "\n",
                "        violation_mean = base_cp_df[(base_cp_df[\"fairness_metric\"] == METRICS[ix % len(METRICS)]) & (base_cp_df[\"client_type\"] == client_type) & (base_cp_df[\"use_mle\"] == USE_MLE[i])][\"violation_mean\"].mean()\n",
                "        ax_viol.plot([x_start, x_start + width], [violation_mean] * 2, '--', c='k')\n",
                "        ax_viol.plot(x_start + width / 2, float(C[ix // len(METRICS)]), 'o', c='k', markersize=3.5)\n",
                "    \n",
                "    axes_idx += 1\n",
                "    \n",
                "plt.tight_layout()\n",
                "form_slug = str(formulation).replace(\" \", \"_\")\n",
                "# Save the figure\n",
                "grid.savefig(f\"./figures/ACSEducation_mle_comparison_{form_slug}_efficiency.pdf\", bbox_inches='tight')\n",
                "grid2.savefig(f\"./figures/ACSEducation_mle_comparison_{form_slug}_violation.pdf\", bbox_inches='tight')\n",
                "\n",
                "plt.show() \n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import matplotlib.pyplot as plt\n",
                "import seaborn as sns\n",
                "\n",
                "METRICS = list(map(lambda x: METRIC_MAP[x], [\"Demographic_Parity\", \"Predictive_Equality\"]))\n",
                "\n",
                "C = [0.1, 0.15, 0.2]\n",
                "dataset_sens_attrs = \"Fitzpatrick\"\n",
                "cp_method = \"aps\"\n",
                "client_type = \"8_clients\"\n",
                "eff_lim = 9\n",
                "FORMULATIONS = [\"Communication_Efficient\", \"Enhanced_Privacy\", \"Hybrid\"]\n",
                "use_mle = False\n",
                "\n",
                "fdf = res_df[(res_df[\"dataset_sens_attrs\"] == dataset_sens_attrs) & (res_df[\"fairness_metric\"].isin(METRICS)) & (res_df[\"cp_method\"] == cp_method) & (res_df[\"client_type\"] == client_type) & (res_df[\"use_mle\"] == use_mle)]\n",
                "display(fdf)\n",
                "# Plot efficiency\n",
                "grid = sns.catplot(\n",
                "    data=fdf[fdf[\"type\"] == \"ours\"],\n",
                "    x=\"fairness_metric\",\n",
                "    y=\"eff_mean\",\n",
                "    hue=\"c\",\n",
                "    col=\"formulation\",\n",
                "    kind=\"bar\",\n",
                "    order=METRICS,\n",
                "    ci=None,\n",
                "    sharex=False,\n",
                ")\n",
                "\n",
                "grid.set(xlabel=\"\")\n",
                "grid.set(ylabel=\"Efficiency\")\n",
                "grid.set(ylim=(0, eff_lim))\n",
                "for key, ax in grid.axes_dict.items():\n",
                "  ax.set_title(key.upper())\n",
                "sns.move_legend(grid, \"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=len(C))\n",
                "grid.legend.set_title(\"Closeness Criterion:\")\n",
                "\n",
                "# Plot violation\n",
                "grid2 = sns.catplot(\n",
                "    data=fdf[fdf[\"type\"] == \"ours\"],\n",
                "    x=\"fairness_metric\",\n",
                "    y=\"violation_mean\",\n",
                "    hue=\"c\",\n",
                "    col=\"formulation\",\n",
                "    kind=\"bar\",\n",
                "    order=METRICS,\n",
                "    ci=None,\n",
                "    sharex=False,\n",
                ")\n",
                "grid2.set(xlabel=\"\")\n",
                "grid2.set(ylabel=\"Actual Fairness Disparity\")\n",
                "\n",
                "grid2.set(ylim=(0, 0.5), yticks=[0, 0.1, 0.2, 0.3, 0.4, 0.5])\n",
                "for key, ax in grid2.axes_dict.items():\n",
                "  ax.set_title(key.upper())\n",
                "sns.move_legend(grid2, \"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=len(C))\n",
                "grid2.legend.set_title(\"Closeness Criterion:\")\n",
                "\n",
                "# Add the base FCP method lines and points\n",
                "\n",
                "axes_eff = grid.axes.flatten() if isinstance(grid.axes, np.ndarray) else [grid.axes]\n",
                "axes_viol = grid2.axes.flatten() if isinstance(grid2.axes, np.ndarray) else [grid2.axes]\n",
                "axes_idx = 0\n",
                "for i in range(len(FORMULATIONS)):\n",
                "    print(dataset_sens_attrs, \" formulation \", FORMULATIONS[i])\n",
                "\n",
                "    base_cp_df = fdf[(fdf[\"type\"] == \"base\")]\n",
                "    if base_cp_df.empty:\n",
                "        continue\n",
                "    \n",
                "    # Efficiency base line\n",
                "    ax_eff = axes_eff[axes_idx]\n",
                "    patches = ax_eff.patches\n",
                "\n",
                "    for ix, a in enumerate(patches):\n",
                "        x_start = a.get_x()\n",
                "        width = a.get_width()\n",
                "        if width == 0: \n",
                "            continue\n",
                "\n",
                "        eff_mean = base_cp_df[(base_cp_df[\"fairness_metric\"] == METRICS[ix % len(METRICS)]) & (base_cp_df[\"client_type\"] == client_type) & (base_cp_df[\"formulation\"] == FORMULATIONS[i])][\"eff_mean\"].mean()\n",
                "        ax_eff.plot([x_start, x_start + width], [eff_mean] * 2, '-', c='k')\n",
                "\n",
                "    # Violation base line and marker\n",
                "    ax_viol = axes_viol[axes_idx]\n",
                "    patches = ax_viol.patches\n",
                "\n",
                "    for ix, a in enumerate(patches):\n",
                "        x_start = a.get_x()\n",
                "        width = a.get_width()\n",
                "        if width == 0: \n",
                "            continue\n",
                "            \n",
                "        violation_mean = base_cp_df[(base_cp_df[\"fairness_metric\"] == METRICS[ix % len(METRICS)]) & (base_cp_df[\"client_type\"] == client_type) & (base_cp_df[\"formulation\"] == FORMULATIONS[i])][\"violation_mean\"].mean()\n",
                "        ax_viol.plot([x_start, x_start + width], [violation_mean] * 2, '--', c='k')\n",
                "        ax_viol.plot(x_start + width / 2, float(C[ix // len(METRICS)]), 'o', c='k', markersize=3.5)\n",
                "    \n",
                "    axes_idx += 1\n",
                "    \n",
                "plt.tight_layout()\n",
                "form_slug = str(FORMULATIONS[i]).replace(\" \", \"_\")\n",
                "# Save the figure\n",
                "grid.savefig(f\"./figures/{dataset_sens_attrs}_formulation_comparison_{client_type}_efficiency.pdf\", bbox_inches='tight')\n",
                "grid2.savefig(f\"./figures/{dataset_sens_attrs}_formulation_comparison_{client_type}_violation.pdf\", bbox_inches='tight')\n",
                "\n",
                "plt.show() \n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "df = res_df[\n",
                "    (res_df[\"fairness_metric\"] == \"Pred_Eq\") & \n",
                "    (res_df[\"cp_method\"] == \"aps\") & \n",
                "    (res_df[\"use_mle\"] == False) &\n",
                "    (res_df[\"client_type\"] == \"8_clients\")\n",
                "]\n",
                "df.sort_values(by=[\"c\", \"form_split\", \"client_type\", \"formulation\"])\n",
                "\n",
                "\n",
                "import matplotlib.pyplot as plt\n",
                "import seaborn as sns\n",
                "\n",
                "METRICS = list(map(lambda x: METRIC_MAP[x], [\"Demographic_Parity\", \"Predictive_Equality\"]))\n",
                "\n",
                "C = [0.1, 0.15, 0.2]\n",
                "dataset_sens_attrs = \"Fitzpatrick\"\n",
                "cp_method = \"aps\"\n",
                "client_type = \"8_clients\"\n",
                "eff_lim = 8\n",
                "FORMULATIONS = [f\"{i},{8 - i}\" for i in range(9)]\n",
                "use_mle = False\n",
                "\n",
                "fdf = res_df[(res_df[\"dataset_sens_attrs\"] == dataset_sens_attrs) & (res_df[\"fairness_metric\"].isin(METRICS)) & (res_df[\"cp_method\"] == cp_method) & (res_df[\"client_type\"] == client_type) & (res_df[\"use_mle\"] == use_mle)]\n",
                "display(fdf)\n",
                "# Plot efficiency\n",
                "grid = sns.catplot(\n",
                "    data=fdf[fdf[\"type\"] == \"ours\"],\n",
                "    x=\"fairness_metric\",\n",
                "    y=\"eff_mean\",\n",
                "    hue=\"c\",\n",
                "    col=\"form_split\",\n",
                "    kind=\"bar\",\n",
                "    order=METRICS,\n",
                "    ci=None,\n",
                "    sharex=False,\n",
                ")\n",
                "\n",
                "grid.set(xlabel=\"\")\n",
                "grid.set(ylabel=\"Efficiency\")\n",
                "grid.set(ylim=(0, eff_lim))\n",
                "for key, ax in grid.axes_dict.items():\n",
                "  if key[0] == \"0\":\n",
                "      ax.set_title(\"Enhanced_Privacy\")\n",
                "  elif key[0] == \"8\":\n",
                "      ax.set_title(\"Communication_Efficient\")\n",
                "  else:\n",
                "      ax.set_title(f\"Hybrid ({key})\")\n",
                "sns.move_legend(grid, \"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=len(C))\n",
                "grid.legend.set_title(\"Closeness Criterion:\")\n",
                "\n",
                "# Plot violation\n",
                "grid2 = sns.catplot(\n",
                "    data=fdf[fdf[\"type\"] == \"ours\"],\n",
                "    x=\"fairness_metric\",\n",
                "    y=\"violation_mean\",\n",
                "    hue=\"c\",\n",
                "    col=\"form_split\",\n",
                "    kind=\"bar\",\n",
                "    order=METRICS,\n",
                "    ci=None,\n",
                "    sharex=False,\n",
                ")\n",
                "grid2.set(xlabel=\"\")\n",
                "grid2.set(ylabel=\"Actual Fairness Disparity\")\n",
                "\n",
                "grid2.set(ylim=(0, 0.5), yticks=[0, 0.1, 0.2, 0.3, 0.4, 0.5])\n",
                "for key, ax in grid2.axes_dict.items():\n",
                "  if key[0] == \"0\":\n",
                "      ax.set_title(\"Enhanced_Privacy\")\n",
                "  elif key[0] == \"8\":\n",
                "      ax.set_title(\"Communication_Efficient\")\n",
                "  else:\n",
                "      ax.set_title(f\"Hybrid ({key})\")\n",
                "sns.move_legend(grid2, \"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=len(C))\n",
                "grid2.legend.set_title(\"Closeness Criterion:\")\n",
                "\n",
                "# Add the base FCP method lines and points\n",
                "\n",
                "axes_eff = grid.axes.flatten() if isinstance(grid.axes, np.ndarray) else [grid.axes]\n",
                "axes_viol = grid2.axes.flatten() if isinstance(grid2.axes, np.ndarray) else [grid2.axes]\n",
                "axes_idx = 0\n",
                "for i in range(len(FORMULATIONS)):\n",
                "    print(dataset_sens_attrs, \" formulation \", FORMULATIONS[i])\n",
                "\n",
                "    base_cp_df = fdf[(fdf[\"type\"] == \"base\")]\n",
                "    if base_cp_df.empty:\n",
                "        continue\n",
                "    \n",
                "    # Efficiency base line\n",
                "    ax_eff = axes_eff[axes_idx]\n",
                "    patches = ax_eff.patches\n",
                "\n",
                "    for ix, a in enumerate(patches):\n",
                "        x_start = a.get_x()\n",
                "        width = a.get_width()\n",
                "        if width == 0: \n",
                "            continue\n",
                "\n",
                "        eff_mean = base_cp_df[(base_cp_df[\"fairness_metric\"] == METRICS[ix % len(METRICS)]) & (base_cp_df[\"client_type\"] == client_type) & (base_cp_df[\"form_split\"] == FORMULATIONS[i])][\"eff_mean\"].mean()\n",
                "        ax_eff.plot([x_start, x_start + width], [eff_mean] * 2, '-', c='k')\n",
                "\n",
                "    # Violation base line and marker\n",
                "    ax_viol = axes_viol[axes_idx]\n",
                "    patches = ax_viol.patches\n",
                "\n",
                "    for ix, a in enumerate(patches):\n",
                "        x_start = a.get_x()\n",
                "        width = a.get_width()\n",
                "        if width == 0: \n",
                "            continue\n",
                "            \n",
                "        violation_mean = base_cp_df[(base_cp_df[\"fairness_metric\"] == METRICS[ix % len(METRICS)]) & (base_cp_df[\"client_type\"] == client_type) & (base_cp_df[\"form_split\"] == FORMULATIONS[i])][\"violation_mean\"].mean()\n",
                "        ax_viol.plot([x_start, x_start + width], [violation_mean] * 2, '--', c='k')\n",
                "        ax_viol.plot(x_start + width / 2, float(C[ix // len(METRICS)]), 'o', c='k', markersize=3.5)\n",
                "    \n",
                "    axes_idx += 1\n",
                "    \n",
                "plt.tight_layout()\n",
                "form_slug = str(FORMULATIONS[i]).replace(\" \", \"_\")\n",
                "# Save the figure\n",
                "grid.savefig(f\"./figures/{dataset_sens_attrs}_form_split_comparison_{client_type}_efficiency.pdf\", bbox_inches='tight')\n",
                "grid2.savefig(f\"./figures/{dataset_sens_attrs}_form_split_comparison_{client_type}_violation.pdf\", bbox_inches='tight')\n",
                "\n",
                "plt.show() \n"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "fedcf",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.11.11"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
