{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8df3ba7",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os, sys\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import datetime\n",
    "import torch\n",
    "import matplotlib.colors as mcolors\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Make the sibling `core` directory (<parent of working directory>/core) importable.\n",
    "current_dir = os.getcwd()\n",
    "parent_dir = os.path.dirname(current_dir)\n",
    "core_dir = os.path.join(parent_dir, 'core')\n",
    "# Guard the append so that re-running this cell does not stack duplicate sys.path entries.\n",
    "if core_dir not in sys.path:\n",
    "    sys.path.append(core_dir)\n",
    "\n",
    "# Import from core and utils directory \n",
    "from algorithms import OnlineMA, OnlineMARegret, OnlineMC\n",
    "from utils import rolling_mean, rolling_vec_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "910e24d6",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Figures are saved in a plots/ folder under the current working directory;\n",
    "# the folder is created when it does not exist yet.\n",
    "plots_dir = os.path.abspath(\"plots\")\n",
    "os.makedirs(plots_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2cf9025",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Load the raw COMPAS table, restricted to the columns and race groups used in this notebook\n",
    "keep_cols = ['compas_screening_date', 'sex', 'race', 'v_decile_score', 'is_recid']\n",
    "keep_races = [\"African-American\", \"Caucasian\", \"Hispanic\"]\n",
    "cutoff = pd.Timestamp('2014-04-01')\n",
    "\n",
    "df = pd.read_csv('./raw_data/compas-scores-two-years.csv')[keep_cols]\n",
    "df = df.loc[df['race'].isin(keep_races)]\n",
    "df = df.assign(compas_screening_date=pd.to_datetime(df['compas_screening_date']))\n",
    "# Keep screenings dated on or before the cutoff; later rows are dropped\n",
    "df = df.loc[df['compas_screening_date'] <= cutoff]\n",
    "\n",
    "# Chronological order with a fresh 0..n-1 index\n",
    "df_req = df.sort_values(by='compas_screening_date', ignore_index=True)\n",
    "# p_tilde is the decile score rescaled by 1/10.\n",
    "# NOTE(review): p_tilde is built from `v_decile_score` while the label used later is `is_recid` -- confirm this pairing is intended.\n",
    "df_req = df_req.assign(p_tilde=df_req['v_decile_score'] / 10.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61b98654",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Hyperparameters: learner window, plot-smoothing window, step size, loss name\n",
    "window, window_plot = 50, 50\n",
    "eta = 0.05\n",
    "loss = \"squared\"\n",
    "\n",
    "# Arrays are indexed by df_req's row positions\n",
    "num_obs = len(df_req)\n",
    "X = np.zeros((num_obs, 1), dtype=float)  # a single all-zero feature column\n",
    "y = df_req['is_recid'].to_numpy(dtype=float)\n",
    "p_tilde_seq = df_req['p_tilde'].to_numpy(dtype=float)\n",
    "# Indicator columns for race, one per category (no category dropped)\n",
    "G_df = pd.get_dummies(df_req['race'].astype('category'), drop_first=False)\n",
    "G = G_df.to_numpy(dtype=float)\n",
    "num_groups = G.shape[1]\n",
    "print(\"number of groups: \", num_groups)\n",
    "\n",
    "# One group per screening date, sorted by date; T is the number of dates\n",
    "by_date = df_req.groupby('compas_screening_date', sort=True)\n",
    "T = len(by_date)\n",
    "\n",
    "# Keyword arguments common to all four learners\n",
    "shared_kwargs = dict(d=X.shape[1], m=num_groups, eta=eta, window_size=window, gamma_pred=0.0, loss=loss)\n",
    "# The non-adaptive variants additionally receive the number of steps T and adaptive=False\n",
    "fixed_kwargs = dict(shared_kwargs, num_time_steps=T, adaptive=False)\n",
    "\n",
    "maonly = OnlineMA(**shared_kwargs)\n",
    "maonly_nonadaptive = OnlineMA(**fixed_kwargs)\n",
    "mareg = OnlineMARegret(**shared_kwargs)\n",
    "mareg_nonadaptive = OnlineMARegret(**fixed_kwargs)\n",
    "\n",
    "# Per-date metric histories\n",
    "ptilde_baseline_ma_losses = []\n",
    "maonly_baseline_ma_losses, maonly_baseline_lreg = [], []\n",
    "maonly_nonadaptive_ma_losses, maonly_nonadaptive_lreg = [], []\n",
    "mareg_ma_losses, mareg_l_reg = [], []\n",
    "mareg_nonadaptive_ma_losses, mareg_nonadaptive_l_reg = [], []\n",
    "\n",
    "# (learner, MA-loss history, l_reg history), listed in the order the learners are updated each step\n",
    "learner_logs = [\n",
    "    (maonly, maonly_baseline_ma_losses, maonly_baseline_lreg),\n",
    "    (maonly_nonadaptive, maonly_nonadaptive_ma_losses, maonly_nonadaptive_lreg),\n",
    "    (mareg_nonadaptive, mareg_nonadaptive_ma_losses, mareg_nonadaptive_l_reg),\n",
    "    (mareg, mareg_ma_losses, mareg_l_reg),\n",
    "]\n",
    "\n",
    "for date, df_d in tqdm(by_date, total=len(by_date)):\n",
    "    # Rows of df_req screened on this date form one batch\n",
    "    idxs = df_d.index.to_numpy()\n",
    "    x_batch, g_batch = X[idxs], G[idxs]                # (b, d), (b, m)\n",
    "    y_batch = y[idxs].astype(float)                    # (b,)\n",
    "    pt_batch = p_tilde_seq[idxs].astype(float)         # (b,)\n",
    "\n",
    "    # p̃ baseline: batch mean of the group-weighted residual (y - p̃), stored as [v, -v]\n",
    "    resid = y_batch - pt_batch\n",
    "    vec_ptilde = np.asarray((g_batch * resid[:, None]).mean(axis=0), dtype=float).ravel()\n",
    "    ptilde_baseline_ma_losses.append(np.concatenate([vec_ptilde, -vec_ptilde]))\n",
    "\n",
    "    # Feed the same batch (with the external p̃) to every learner and record its metrics\n",
    "    for learner, ma_hist, reg_hist in learner_logs:\n",
    "        metrics = learner.update(x_batch, y_batch, g_batch, p_tilde=pt_batch)\n",
    "        ma_hist.append(metrics[\"ma_losses\"])\n",
    "        reg_hist.append(metrics[\"l_reg\"])\n",
    "\n",
    "def _smoothed_norms(vec_losses):\n",
    "    \"\"\"Return the (l2, l_infty) rolling_vec_norm series of `vec_losses` over `window_plot`.\"\"\"\n",
    "    l2_series = rolling_vec_norm(vec_losses, window_plot, norm=\"l2\")\n",
    "    linf_series = rolling_vec_norm(vec_losses, window_plot, norm=\"l_infty\")\n",
    "    return l2_series, linf_series\n",
    "\n",
    "# Smooth every tracked history over window_plot; the two MA-only l_reg lists are replaced by their rolling means\n",
    "ptilde_baseline_ma_l2, ptilde_baseline_ma_linf = _smoothed_norms(ptilde_baseline_ma_losses)\n",
    "maonly_baseline_ma_l2, maonly_baseline_ma_linf = _smoothed_norms(maonly_baseline_ma_losses)\n",
    "maonly_baseline_lreg = rolling_mean(maonly_baseline_lreg, window_plot)\n",
    "maonly_nonadaptive_ma_l2, maonly_nonadaptive_ma_linf = _smoothed_norms(maonly_nonadaptive_ma_losses)\n",
    "maonly_nonadaptive_lreg = rolling_mean(maonly_nonadaptive_lreg, window_plot)\n",
    "mareg_nonadaptive_ma_l2, mareg_nonadaptive_ma_linf = _smoothed_norms(mareg_nonadaptive_ma_losses)\n",
    "mareg_nonadaptive_l_reg = rolling_mean(mareg_nonadaptive_l_reg, window_plot)\n",
    "mareg_ma_l2, mareg_ma_linf = _smoothed_norms(mareg_ma_losses)\n",
    "mareg_lreg = rolling_mean(mareg_l_reg, window_plot)\n",
    "\n",
    "# Plot the smoothed L_infty MA error of every method and save it as a PNG in plots_dir\n",
    "fig, ax = plt.subplots(figsize=(9, 3))\n",
    "ax.plot(ptilde_baseline_ma_linf, label=\"p~ baseline\")\n",
    "ax.plot(maonly_baseline_ma_linf, label=\"MA-only\")\n",
    "ax.plot(maonly_nonadaptive_ma_linf, label=\"MA-only (non-adaptive)\")\n",
    "ax.plot(mareg_nonadaptive_ma_linf, label=\"MA+reg (non-adaptive)\")\n",
    "ax.plot(mareg_ma_linf, label=\"MA+reg\")\n",
    "ax.set_xlabel(\"t\")\n",
    "ax.set_ylabel(r\"$L_\\infty$ MA error\")\n",
    "# Raw f-string with $...$: in a plain f-string \\i is an invalid escape sequence, and without\n",
    "# the dollar signs the title shows the LaTeX source instead of the infinity symbol.\n",
    "ax.set_title(rf\"COMPAS: Smoothed $L_\\infty$ MA error (w_alg={window}, w_plot={window_plot})\")\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "out_inf = os.path.join(plots_dir, f\"compas_ma_error_linf_walg{window}_wplot{window_plot}.png\")\n",
    "fig.savefig(out_inf, dpi=150)\n",
    "# Close the figure once it is on disk so it is not also rendered inline\n",
    "plt.close(fig)\n",
    "\n",
    "# Smoothed L2 MA error of every method, saved as a PNG in plots_dir\n",
    "fig, ax = plt.subplots(figsize=(9, 3))\n",
    "l2_curves = [\n",
    "    (ptilde_baseline_ma_l2, \"p~ baseline\"),\n",
    "    (maonly_baseline_ma_l2, \"MA-only\"),\n",
    "    (maonly_nonadaptive_ma_l2, \"MA-only (non-adaptive)\"),\n",
    "    (mareg_nonadaptive_ma_l2, \"MA+reg (non-adaptive)\"),\n",
    "    (mareg_ma_l2, \"MA+reg\"),\n",
    "]\n",
    "for series, curve_label in l2_curves:\n",
    "    ax.plot(series, label=curve_label)\n",
    "ax.set_xlabel(\"t\")\n",
    "ax.set_ylabel(r\"$L_2$ MA error\")\n",
    "ax.set_title(f\"COMPAS: Smoothed L2 MA error (w_alg={window}, w_plot={window_plot})\")\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "out_l2 = os.path.join(plots_dir, f\"compas_ma_error_l2_walg{window}_wplot{window_plot}.png\")\n",
    "fig.savefig(out_l2, dpi=150)\n",
    "plt.close(fig)\n",
    "\n",
    "# Smoothed l_reg of the four learners (labelled as regret), saved as a PNG in plots_dir\n",
    "fig, ax = plt.subplots(figsize=(9, 3))\n",
    "lreg_curves = [\n",
    "    (maonly_baseline_lreg, \"MA-only\"),\n",
    "    (maonly_nonadaptive_lreg, \"MA-only (non-adaptive)\"),\n",
    "    (mareg_nonadaptive_l_reg, \"MA+reg (non-adaptive)\"),\n",
    "    (mareg_lreg, \"MA+reg\"),\n",
    "]\n",
    "for series, curve_label in lreg_curves:\n",
    "    ax.plot(series, label=curve_label)\n",
    "ax.set_xlabel(\"t\")\n",
    "ax.set_ylabel(\"regret\")\n",
    "ax.set_title(f\"COMPAS: Smoothed regret (w_alg={window}, w_plot={window_plot})\")\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "out_reg = os.path.join(plots_dir, f\"compas_lreg_walg{window}_wplot{window_plot}.png\")\n",
    "fig.savefig(out_reg, dpi=150)\n",
    "plt.close(fig)\n",
    "\n",
    "# Report where the three figures were written\n",
    "print(\"Saved plots:\\n -\", out_inf, \"\\n -\", out_l2, \"\\n -\", out_reg)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
