{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "0",
      "metadata": {},
      "source": [
        "# Parity Analysis\n",
        "\n",
        "An ANOVA-style global test answers:\n",
        "\n",
        "**Null hypothesis (H₀):**  \n",
        " All groups (say, the 8 models) have the same expected accuracy/error.\n",
        "\n",
        "**Alternative (H₁):**  \n",
        " At least one group’s mean differs from the others.\n"
      ]
    },
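    {
      "cell_type": "markdown",
      "id": "anova-formulas",
      "metadata": {},
      "source": [
        "For reference, `anova_f_stat` below computes the standard one-way ANOVA quantities. With $k$ groups, $n_j$ rows in group $j$, per-row errors $y_{ij}$, group means $\\bar{y}_j$, grand mean $\\bar{y}$, and $N = \\sum_j n_j$:\n",
        "\n",
        "$$\n",
        "SS_B = \\sum_{j=1}^{k} n_j \\left(\\bar{y}_j - \\bar{y}\\right)^2, \\qquad SS_W = \\sum_{j=1}^{k} \\sum_{i=1}^{n_j} \\left(y_{ij} - \\bar{y}_j\\right)^2\n",
        "$$\n",
        "\n",
        "$$\n",
        "F = \\frac{SS_B / (k-1)}{SS_W / (N-k)}, \\qquad \\eta^2 = \\frac{SS_B}{SS_B + SS_W}\n",
        "$$\n",
        "\n",
        "The *p*-value is then taken from label permutations rather than from the F distribution, so the test does not assume normally distributed, equal-variance errors; it assumes only that rows are exchangeable across groups under H₀."
      ]
    },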
    {
      "cell_type": "markdown",
      "id": "9df8bc92",
      "metadata": {},
      "source": [
        "## **Configuration:** \n",
        "Paths and column names are set in the next code cell. For your own data or prediction file, edit `PREDICTIONS_PATH`, `PRED_COL`, `GOLD_COL`, and `CATEGORIES` there. See [README_extended.md](../../README_extended.md#parity-analysis-notebook) for details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1",
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "from typing import List, Tuple\n",
        "\n",
        "# global defs\n",
        "PREDICTIONS_PATH = \"../../outputs/predictions/gemini_all_rows_fr+fs_modelslant.csv\"\n",
        "PRED_COL = \"gemini_fr+fs_avg\"\n",
        "GOLD_COL = \"representation_rating\"\n",
        "\n",
        "CATEGORIES = [\"Sex\",\"Ethnicity simplified\",\"U.s. political affiliation\",\"selection_position\",\"model\"]\n",
        "METRICS = [(\"MAE\",\"abs_err\"), (\"MSE\",\"squared_err\")]\n",
        "\n",
        "# Load data\n",
        "df = pd.read_csv(PREDICTIONS_PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "2",
      "metadata": {},
      "outputs": [],
      "source": [
        "\n",
        "# Keep only rows with ground truth and non-null predictions\n",
        "df = df.copy()\n",
        "df = df[pd.notnull(df[GOLD_COL]) & pd.notnull(df[PRED_COL])]\n",
        "\n",
        "# Ensure numeric\n",
        "df[GOLD_COL] = pd.to_numeric(df[GOLD_COL], errors=\"coerce\")\n",
        "df[PRED_COL] = pd.to_numeric(df[PRED_COL], errors=\"coerce\")\n",
        "df = df[pd.notnull(df[GOLD_COL]) & pd.notnull(df[PRED_COL])]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "3",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Per-row metrics\n",
        "df[\"correct\"] = (df[GOLD_COL] == df[PRED_COL]).astype(float)\n",
        "df[\"abs_err\"] = (df[PRED_COL] - df[GOLD_COL]).abs()\n",
        "df[\"squared_err\"] = (df[PRED_COL] - df[GOLD_COL])**2\n"
      ]
    },
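    {
      "cell_type": "markdown",
      "id": "toy-metrics-example",
      "metadata": {},
      "source": [
        "Averaging `abs_err` within a group gives that group's MAE, and averaging `squared_err` gives its MSE, so the ANOVA on these per-row columns asks whether the group-level MAE (or MSE) differs. A minimal sketch on a made-up toy frame; `gold`, `pred`, and `group` are hypothetical stand-ins for `GOLD_COL`, `PRED_COL`, and one of the `CATEGORIES`:\n",
        "\n",
        "```python\n",
        "import pandas as pd\n",
        "\n",
        "# Made-up toy data; 'gold', 'pred' and 'group' are hypothetical column names\n",
        "toy = pd.DataFrame({\n",
        "    'gold':  [1.0, 2.0, 3.0, 4.0, 5.0, 1.0],\n",
        "    'pred':  [1.0, 3.0, 3.0, 2.0, 5.0, 4.0],\n",
        "    'group': ['a', 'a', 'a', 'b', 'b', 'b'],\n",
        "})\n",
        "\n",
        "# Per-row errors, mirroring the metrics cell above\n",
        "toy['abs_err'] = (toy['pred'] - toy['gold']).abs()\n",
        "toy['squared_err'] = (toy['pred'] - toy['gold']) ** 2\n",
        "\n",
        "# Group means of the per-row errors are the per-group MAE and MSE\n",
        "per_group = toy.groupby('group')[['abs_err', 'squared_err']].mean()\n",
        "print(per_group)  # a: MAE = MSE = 1/3; b: MAE = 5/3, MSE = 13/3\n",
        "```\n",
        "\n",
        "Squaring weights large errors more heavily (group `b` above), so the two metrics can support different conclusions, as the `U.s. political affiliation` rows in the results show."
      ]
    },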
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "4",
      "metadata": {},
      "outputs": [],
      "source": [
        "\n",
        "def anova_f_stat(groups: List[np.ndarray]) -> Tuple[float, float, float]:\n",
        "    \"\"\"Return (F, eta2, grand_mean).\"\"\"\n",
        "    # Filter out any empty groups\n",
        "    groups = [g.astype(float) for g in groups if len(g) > 0]\n",
        "    k = len(groups)\n",
        "    ns = np.array([len(g) for g in groups], dtype=float)\n",
        "    N = ns.sum()\n",
        "    grand_mean = np.sum([g.sum() for g in groups]) / N if N > 0 else np.nan\n",
        "    # Between-group sum of squares\n",
        "    ssb = np.sum(ns * (np.array([g.mean() for g in groups]) - grand_mean) ** 2)\n",
        "    # Within-group sum of squares\n",
        "    ssw = np.sum([np.sum((g - g.mean())**2) for g in groups])\n",
        "    # Degrees of freedom\n",
        "    dfb = k - 1\n",
        "    dfw = N - k\n",
        "    if dfb <= 0 or dfw <= 0 or ssw <= 0:\n",
        "        return np.nan, np.nan, grand_mean\n",
        "    msb = ssb / dfb\n",
        "    msw = ssw / dfw\n",
        "    F = msb / msw if msw > 0 else np.inf\n",
        "    # Effect size (eta squared)\n",
        "    sst = ssb + ssw\n",
        "    eta2 = ssb / sst if sst > 0 else np.nan\n",
        "    return F, eta2, grand_mean\n",
        "\n",
        "def permutation_anova(values: np.ndarray, labels: np.ndarray, n_perm: int = 5000, rng=None):\n",
        "    \"\"\"Permutation ANOVA retaining group sizes. Returns observed F, p-value, eta2, group sizes.\"\"\"\n",
        "    if rng is None:\n",
        "        rng = np.random.default_rng(123)\n",
        "    # Build groups\n",
        "    uniq, inv = np.unique(labels, return_inverse=True)\n",
        "    groups = [values[inv == i] for i in range(len(uniq))]\n",
        "    # Observed F\n",
        "    F_obs, eta2_obs, _ = anova_f_stat(groups)\n",
        "    # Permutations\n",
        "    n_perm_eff = int(n_perm)\n",
        "    F_perm = np.empty(n_perm_eff, dtype=float)\n",
        "    # Precompute group sizes\n",
        "    sizes = [len(g) for g in groups]\n",
        "    for b in range(n_perm_eff):\n",
        "        shuffled = rng.permutation(values)\n",
        "        # allocate in order of sizes\n",
        "        idx = 0\n",
        "        perm_groups = []\n",
        "        for s in sizes:\n",
        "            perm_groups.append(shuffled[idx:idx+s])\n",
        "            idx += s\n",
        "        F_perm[b], _, _ = anova_f_stat(perm_groups)\n",
        "    # p-value (>= observed), add 1 for unbiased estimate\n",
        "    p_val = (np.sum(F_perm >= F_obs) + 1.0) / (n_perm_eff + 1.0)\n",
        "    return F_obs, p_val, eta2_obs, sizes\n"
      ]
    },
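    {
      "cell_type": "markdown",
      "id": "permutation-sketch",
      "metadata": {},
      "source": [
        "`permutation_anova` shuffles the pooled values and refills groups of the original sizes, which is equivalent to shuffling the group labels. Under H₀ the labels are exchangeable, so the *p*-value is the share of shuffles whose F is at least the observed F; the add-one correction counts the observed labeling as one of the permutations, so *p* is never exactly 0. Below is a self-contained sketch of the same idea on made-up toy data. The names `f_ratio` and `perm_p_value` are hypothetical, and the sketch returns `nan` when the observed F is undefined (for example, zero within-group variance), because every comparison against `nan` is false and would otherwise produce a spuriously small *p*-value.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def f_ratio(values, labels):\n",
        "    # One-way ANOVA F ratio and eta squared for values grouped by labels\n",
        "    groups = [values[labels == g] for g in np.unique(labels)]\n",
        "    k, n = len(groups), len(values)\n",
        "    grand = values.mean()\n",
        "    ssb = sum(len(g) * (g.mean() - grand) ** 2 for g in groups)\n",
        "    ssw = sum(((g - g.mean()) ** 2).sum() for g in groups)\n",
        "    if k < 2 or n <= k or ssw == 0:\n",
        "        return np.nan, np.nan  # F undefined\n",
        "    f = (ssb / (k - 1)) / (ssw / (n - k))\n",
        "    return f, ssb / (ssb + ssw)\n",
        "\n",
        "def perm_p_value(values, labels, n_perm=2000, seed=0):\n",
        "    # Shuffle labels (group sizes stay fixed) and count F values >= the observed F\n",
        "    rng = np.random.default_rng(seed)\n",
        "    f_obs, _ = f_ratio(values, labels)\n",
        "    if np.isnan(f_obs):\n",
        "        return np.nan\n",
        "    hits = 0\n",
        "    for _ in range(n_perm):\n",
        "        f_b, _ = f_ratio(values, rng.permutation(labels))\n",
        "        hits += int(f_b >= f_obs)  # False when f_b is nan\n",
        "    # Add-one correction: the observed labeling counts as one permutation\n",
        "    return (hits + 1) / (n_perm + 1)\n",
        "\n",
        "values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])\n",
        "labels = np.array(['a', 'a', 'a', 'b', 'b', 'b'])\n",
        "print(f_ratio(values, labels))       # F = 13.5, eta^2 = 13.5 / 17.5 (about 0.771)\n",
        "print(perm_p_value(values, labels))  # close to 0.1: only 2 of the 20 equally likely splits reach F = 13.5\n",
        "```"
      ]
    },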
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "5",
      "metadata": {},
      "outputs": [],
      "source": [
        "\n",
        "# Run permutation ANOVA for each category x metric\n",
        "rows = []\n",
        "rng = np.random.default_rng(123)\n",
        "for cat in CATEGORIES:\n",
        "    for metric_name, col in METRICS:\n",
        "        sub = df[[cat, col]].dropna().copy()\n",
        "        sub = sub[sub[cat] != \"DATA_EXPIRED\"]  # <-- skip expired groups\n",
        "        # Skip if less than 2 groups with data\n",
        "        if sub[cat].nunique() < 2:\n",
        "            continue\n",
        "        values = sub[col].values.astype(float)\n",
        "        labels = sub[cat].values.astype(str)\n",
        "        F_obs, p_val, eta2, sizes = permutation_anova(values, labels, n_perm=5000, rng=rng)\n",
        "        rows.append({\n",
        "            \"category\": cat,\n",
        "            \"metric\": metric_name,\n",
        "            \"F_obs\": F_obs,\n",
        "            \"p_perm\": p_val,\n",
        "            \"eta_squared\": eta2,\n",
        "            \"n_groups\": int(sub[cat].nunique()),\n",
        "            \"group_sizes\": str(sorted([int(s) for s in sizes]))\n",
        "        })\n",
        "\n",
        "perm_anova_df = pd.DataFrame(rows).sort_values([\"category\",\"metric\"]).reset_index(drop=True)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "6",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>category</th>\n",
              "      <th>metric</th>\n",
              "      <th>F_obs</th>\n",
              "      <th>p_perm</th>\n",
              "      <th>eta_squared</th>\n",
              "      <th>n_groups</th>\n",
              "      <th>group_sizes</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>Ethnicity simplified</td>\n",
              "      <td>MAE</td>\n",
              "      <td>1.781883</td>\n",
              "      <td>0.126775</td>\n",
              "      <td>9.896427e-04</td>\n",
              "      <td>5</td>\n",
              "      <td>[696, 720, 888, 960, 3936]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>Ethnicity simplified</td>\n",
              "      <td>MSE</td>\n",
              "      <td>1.717011</td>\n",
              "      <td>0.141372</td>\n",
              "      <td>9.536476e-04</td>\n",
              "      <td>5</td>\n",
              "      <td>[696, 720, 888, 960, 3936]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>Sex</td>\n",
              "      <td>MAE</td>\n",
              "      <td>0.001555</td>\n",
              "      <td>0.976205</td>\n",
              "      <td>2.160021e-07</td>\n",
              "      <td>2</td>\n",
              "      <td>[3528, 3672]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>Sex</td>\n",
              "      <td>MSE</td>\n",
              "      <td>0.599623</td>\n",
              "      <td>0.441912</td>\n",
              "      <td>8.329713e-05</td>\n",
              "      <td>2</td>\n",
              "      <td>[3528, 3672]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>U.s. political affiliation</td>\n",
              "      <td>MAE</td>\n",
              "      <td>5.287438</td>\n",
              "      <td>0.003799</td>\n",
              "      <td>1.467189e-03</td>\n",
              "      <td>3</td>\n",
              "      <td>[2160, 2184, 2856]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>U.s. political affiliation</td>\n",
              "      <td>MSE</td>\n",
              "      <td>2.494403</td>\n",
              "      <td>0.092382</td>\n",
              "      <td>6.926982e-04</td>\n",
              "      <td>3</td>\n",
              "      <td>[2160, 2184, 2856]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>model</td>\n",
              "      <td>MAE</td>\n",
              "      <td>2.266550</td>\n",
              "      <td>0.026795</td>\n",
              "      <td>2.201185e-03</td>\n",
              "      <td>8</td>\n",
              "      <td>[900, 900, 900, 900, 900, 900, 900, 900]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>model</td>\n",
              "      <td>MSE</td>\n",
              "      <td>3.133390</td>\n",
              "      <td>0.002999</td>\n",
              "      <td>3.040468e-03</td>\n",
              "      <td>8</td>\n",
              "      <td>[900, 900, 900, 900, 900, 900, 900, 900]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>selection_position</td>\n",
              "      <td>MAE</td>\n",
              "      <td>4.226002</td>\n",
              "      <td>0.017397</td>\n",
              "      <td>1.173001e-03</td>\n",
              "      <td>3</td>\n",
              "      <td>[1072, 2584, 3544]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>selection_position</td>\n",
              "      <td>MSE</td>\n",
              "      <td>6.979119</td>\n",
              "      <td>0.001400</td>\n",
              "      <td>1.935698e-03</td>\n",
              "      <td>3</td>\n",
              "      <td>[1072, 2584, 3544]</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "                     category metric     F_obs    p_perm   eta_squared  \\\n",
              "0        Ethnicity simplified    MAE  1.781883  0.126775  9.896427e-04   \n",
              "1        Ethnicity simplified    MSE  1.717011  0.141372  9.536476e-04   \n",
              "2                         Sex    MAE  0.001555  0.976205  2.160021e-07   \n",
              "3                         Sex    MSE  0.599623  0.441912  8.329713e-05   \n",
              "4  U.s. political affiliation    MAE  5.287438  0.003799  1.467189e-03   \n",
              "5  U.s. political affiliation    MSE  2.494403  0.092382  6.926982e-04   \n",
              "6                       model    MAE  2.266550  0.026795  2.201185e-03   \n",
              "7                       model    MSE  3.133390  0.002999  3.040468e-03   \n",
              "8          selection_position    MAE  4.226002  0.017397  1.173001e-03   \n",
              "9          selection_position    MSE  6.979119  0.001400  1.935698e-03   \n",
              "\n",
              "   n_groups                               group_sizes  \n",
              "0         5                [696, 720, 888, 960, 3936]  \n",
              "1         5                [696, 720, 888, 960, 3936]  \n",
              "2         2                              [3528, 3672]  \n",
              "3         2                              [3528, 3672]  \n",
              "4         3                        [2160, 2184, 2856]  \n",
              "5         3                        [2160, 2184, 2856]  \n",
              "6         8  [900, 900, 900, 900, 900, 900, 900, 900]  \n",
              "7         8  [900, 900, 900, 900, 900, 900, 900, 900]  \n",
              "8         3                        [1072, 2584, 3544]  \n",
              "9         3                        [1072, 2584, 3544]  "
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "perm_anova_df"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "7",
      "metadata": {},
      "source": [
        "In the paper, we observe the following results:\n",
        "\n",
        "| Category              | Metric | F       | p_perm   | η²      | # Groups |\n",
        "|----------------------|--------|---------|----------|---------|----------|\n",
        "| Ethnicity (simplified) | MAE   | 1.78    | 0.127    | 0.0010  | 5        |\n",
        "|                      | MSE    | 1.72    | 0.141    | 0.0010  | 5        |\n",
        "| Sex                  | MAE    | 0.00    | 0.976    | 0.0000  | 2        |\n",
        "|                      | MSE    | 0.60    | 0.442    | 0.0001  | 2        |\n",
        "| Political party      | MAE    | **5.29** | **0.004** | **0.0015** | 3     |\n",
        "|                      | MSE    | 2.49    | 0.092    | 0.0007  | 3        |\n",
        "| Model                | MAE    | **2.27** | **0.027** | **0.0022** | 8     |\n",
        "|                      | MSE    | **3.13** | **0.003** | **0.0030** | 8     |\n",
        "| Stance (selection)   | MAE    | **4.23** | **0.017** | **0.0012** | 3     |\n",
        "|                      | MSE    | **6.98** | **0.001** | **0.0019** | 3     |\n",
        "\n",
        "\n",
        "### Interpretation:\n",
        "\n",
        "* The `p_perm` column is your permutation-based *p*\\-value.\n",
        "\n",
        "* `eta_squared` is the effect size \\= % of variance explained by the grouping (tiny numbers, but still useful).\n",
        "\n",
        "* **Threshold:** usually p \\< 0.05 \\= significant evidence of differences.\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "oeenv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
