{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79059b40",
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import gamma\n",
    "import numpy as np\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.datasets import load_breast_cancer\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# reproducibility\n",
    "def set_seed(s=0):\n",
    "    random.seed(s)\n",
    "    np.random.seed(s)\n",
    "\n",
    "# data_loading\n",
    "data = load_breast_cancer()\n",
    "X_all, y_all = data.data, data.target\n",
    "X_all = StandardScaler().fit_transform(X_all)\n",
    "y_all = 1 - y_all  \n",
    "\n",
    "\n",
    "# train/test split (stratified)\n",
    "X_tr, X_te, y_tr, y_te = train_test_split(\n",
    "    X_all, y_all, test_size=0.2, stratify=y_all, random_state=123\n",
    ")\n",
    "\n",
    "# IID_partition\n",
    "def make_clients(X, y, n_clients=10):\n",
    "    client_data = {i: {\"x0\": None, \"x1\": None} for i in range(n_clients)}\n",
    "    idx_0 = np.where(y == 0)[0]\n",
    "    idx_1 = np.where(y == 1)[0]\n",
    "    np.random.shuffle(idx_0)\n",
    "    np.random.shuffle(idx_1)\n",
    "    for i in range(n_clients):\n",
    "        client_data[i][\"x0\"] = X[idx_0[i::n_clients]]\n",
    "        client_data[i][\"x1\"] = X[idx_1[i::n_clients]]\n",
    "    return client_data\n",
    "\n",
    "# Loss/grads\n",
    "def softplus(z):\n",
    "    return np.logaddexp(0.0, z)\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1.0 / (1.0 + np.exp(-z))\n",
    "    # return 0.5 * (1.0 + np.tanh(z * 0.5))\n",
    "\n",
    "# class-0 (objective)\n",
    "def loss_class0(w, x0):\n",
    "    if x0.size == 0: return 0.0\n",
    "    z = x0 @ w\n",
    "    return np.mean(softplus(z))\n",
    "\n",
    "def grad_class0(w, x0):\n",
    "    if x0.size == 0: return np.zeros_like(w)\n",
    "    z = x0 @ w\n",
    "    s = sigmoid(z)\n",
    "    return (s[:, None] * x0).mean(axis=0)\n",
    "\n",
    "# class-1 (constraint)\n",
    "def loss_class1(w, x1):\n",
    "    if x1.size == 0: return 0.0\n",
    "    z = x1 @ w\n",
    "    return np.mean(softplus(-z))\n",
    "\n",
    "def grad_class1(w, x1):\n",
    "    if x1.size == 0: return np.zeros_like(w)\n",
    "    z = x1 @ w\n",
    "    s = sigmoid(z)\n",
    "    return ((s - 1.0)[:, None] * x1).mean(axis=0)\n",
    "\n",
    "# switching rule\n",
    "def sigma_beta(x, beta=1.0):\n",
    "    # x = g_avg - epsilon\n",
    "    return np.clip(1.0 + beta * x, 0.0, 1.0)\n",
    "\n",
    "#Compressors\n",
    "\n",
    "def compress_identity(v):\n",
    "    return v\n",
    "\n",
    "def compress_randk(v, k, rng):\n",
    "    d = v.size\n",
    "    if k >= d: return v\n",
    "    out = np.zeros_like(v)\n",
    "    idx = rng.choice(d, size=k, replace=False)\n",
    "    scale = d / float(k)\n",
    "    out[idx] = scale * v[idx]\n",
    "    return out\n",
    "\n",
    "def compress_topk(v, k):\n",
    "    \"\"\"Top-K compression\"\"\"\n",
    "    d = v.size\n",
    "    if k >= d:\n",
    "        return v.copy()\n",
    "    out = np.zeros_like(v)\n",
    "    idx = np.argpartition(np.abs(v), -k)[-k:]\n",
    "    out[idx] = v[idx] \n",
    "    return out\n",
    "\n",
    "# evaluation\n",
    "def test_accuracy(w, X, y):\n",
    "    z = X @ w\n",
    "    p = sigmoid(z)\n",
    "    yhat = (p >= 0.5).astype(int)\n",
    "    return (yhat == y).mean()\n",
    "\n",
    "def test_accuracy_all(w, X, y):\n",
    "    z = X @ w\n",
    "    p = sigmoid(z)\n",
    "    yhat = (p >= 0.5).astype(int)\n",
    "    overall = (yhat == y).mean()\n",
    "    acc0 = (yhat[y==0] == 0).mean() if np.any(y==0) else np.nan\n",
    "    acc1 = (yhat[y==1] == 1).mean() if np.any(y==1) else np.nan\n",
    "    return overall, acc0, acc1\n",
    "\n",
    "# Federated run\n",
    "def run_trial(\n",
    "    seed,\n",
    "    betas,                 \n",
    "    n_clients=10,\n",
    "    num_rounds=500,\n",
    "    local_epochs=5,\n",
    "    stepsize=0.1,          \n",
    "    epsilon=0.1,\n",
    "    participation_rate=1.0,\n",
    "    compressor=\"topk\",     \n",
    "    k_frac=0.1,\n",
    "    comp_on=True,          \n",
    "    downlink_ef=True       \n",
    "):\n",
    "    print(\"federated training\")\n",
    "    set_seed(seed)\n",
    "    clients = make_clients(X_tr, y_tr, n_clients=n_clients)\n",
    "\n",
    "    d = X_tr.shape[1]\n",
    "    print(f\"seed: {seed}, d: {d}\")\n",
    "\n",
    "    if np.isscalar(stepsize):\n",
    "        eta_vec = [float(stepsize)] * len(betas)\n",
    "    else:\n",
    "        eta_vec = stepsize\n",
    "\n",
    "    rng_k = np.random.default_rng(seed + 999)\n",
    "\n",
    "    if comp_on:\n",
    "        e_client = {j: np.zeros(d, dtype=np.float64) for j in range(n_clients)}  # e_j^0\n",
    "    else:\n",
    "        e_client = {}\n",
    "\n",
    "    results = {}\n",
    "    for cnt, beta_val in enumerate(betas):\n",
    "        set_seed(seed)\n",
    "        w = np.random.randn(d)  \n",
    "        x = w.copy()            # server x_t (used when comp_on)\n",
    "\n",
    "        f_hist, g_hist, feas_hist = [], [], []\n",
    "        g_hist_selected = []\n",
    "        f_hist_selected = []\n",
    "        f_t = np.mean([loss_class0(w, clients[i][\"x0\"]) for i in range(n_clients)])\n",
    "        g_t = np.mean([loss_class1(w, clients[i][\"x1\"]) for i in range(n_clients)])\n",
    "        feasible = (g_t <= epsilon)\n",
    "        f_hist.append(f_t); g_hist.append(g_t); g_hist_selected.append(g_t); f_hist_selected.append(f_t)\n",
    "        viol_count = 0\n",
    "        feas_hist.append(1.0 if feasible else 0.0)\n",
    "        viol_count += (not feasible)\n",
    "        eta = float(eta_vec[cnt])\n",
    "\n",
    "        comm_hist = []\n",
    "        cum_coords = 0\n",
    "\n",
    "        for t in range(num_rounds):\n",
    "            k_clients = max(1, int(participation_rate * n_clients))\n",
    "            selected = random.sample(range(n_clients), k=k_clients)\n",
    "\n",
    "            # constriant check\n",
    "            g_local = []\n",
    "            for j in selected:\n",
    "                x1 = clients[j][\"x1\"]\n",
    "                g_local.append(loss_class1(w, x1))\n",
    "            g_avg = float(np.mean(g_local)) if len(g_local) else 0.0\n",
    "\n",
    "            feasible = (g_avg <= epsilon)\n",
    "            viol_count += (not feasible)\n",
    "\n",
    "            # Switching rule\n",
    "            if beta_val is None:\n",
    "                sigma_t = 1.0 if (g_avg > epsilon) else 0.0\n",
    "            else:\n",
    "                sigma_t = float(sigma_beta(g_avg - epsilon, beta_val))\n",
    "\n",
    "            # Local client updates\n",
    "            msgs_to_server = []\n",
    "            for j in selected:\n",
    "                x0, x1 = clients[j][\"x0\"], clients[j][\"x1\"]\n",
    "                wl = w.copy()\n",
    "                delta_sum = np.zeros_like(w)\n",
    "\n",
    "                for _ in range(local_epochs):\n",
    "                    gf = grad_class0(wl, x0)\n",
    "                    gg = grad_class1(wl, x1)\n",
    "\n",
    "                    if beta_val is None:\n",
    "                        nu = gf if feasible else gg\n",
    "                    else:\n",
    "                        nu = (1.0 - sigma_t) * gf + sigma_t * gg\n",
    "\n",
    "                    delta_sum += nu\n",
    "                    wl = wl - eta * nu\n",
    "\n",
    "                Delta_j_t = delta_sum\n",
    "\n",
    "                if comp_on:\n",
    "                    tmp = e_client[j] + Delta_j_t\n",
    "                    k = max(1, int(k_frac * tmp.size))\n",
    "                    if compressor == \"topk\":\n",
    "                        v_j_t = compress_topk(tmp, k)\n",
    "                    elif compressor == \"randk\":\n",
    "                        v_j_t = compress_randk(tmp, k, rng_k)\n",
    "                    elif compressor == \"id\":\n",
    "                        v_j_t = tmp\n",
    "                    else:\n",
    "                        raise ValueError\n",
    "                    e_client[j] = tmp - v_j_t\n",
    "                    msgs_to_server.append(v_j_t)\n",
    "                else:\n",
    "                    msgs_to_server.append(Delta_j_t)\n",
    "\n",
    "            # === (28–39) Server aggregation & update\n",
    "            if not msgs_to_server:\n",
    "                print(\"empty-->no clients participated\")\n",
    "                continue\n",
    "\n",
    "            if comp_on:\n",
    "                v_t = np.mean(np.stack(msgs_to_server, axis=0), axis=0)\n",
    "                x_next = x - eta * v_t\n",
    "\n",
    "                if downlink_ef:\n",
    "                    diff = (x_next - w)\n",
    "                    k0 = max(1, int(k_frac * diff.size))\n",
    "                    if compressor == \"topk\":\n",
    "                        c0 = compress_topk(diff, k0)\n",
    "                    elif compressor == \"randk\":\n",
    "                        c0 = compress_randk(diff, k0, rng_k)\n",
    "                    elif compressor == \"id\":\n",
    "                        c0 = diff\n",
    "                    else:\n",
    "                        raise ValueError(\"compressor must be 'topk', 'randk', or 'id'\")\n",
    "                    w = w + c0\n",
    "                else:\n",
    "                    w = x_next.copy()\n",
    "\n",
    "                x = x_next.copy()\n",
    "            else:\n",
    "                mean_delta = np.mean(np.stack(msgs_to_server, axis=0), axis=0)\n",
    "                w = w - eta * mean_delta\n",
    "\n",
    "            # comm. accounting\n",
    "            # Uplink\n",
    "            if comp_on:\n",
    "              if compressor == \"topk\":\n",
    "                  k_val = max(1, int(k_frac*d))\n",
    "                  uplink = (k_val * 2) * len(selected)\n",
    "              elif compressor == \"randk\":\n",
    "                  k_val = max(1, int(k_frac*d))\n",
    "                  uplink = k_val * len(selected)\n",
    "              elif compressor == \"id\":\n",
    "                  uplink = d * len(selected)\n",
    "              else:\n",
    "                  raise ValueError(\"invalid compressor\")\n",
    "            else:\n",
    "                uplink = d * len(selected)\n",
    "            # Downlink\n",
    "            if comp_on and downlink_ef:\n",
    "              if compressor == \"topk\":\n",
    "                  k_val = max(1, int(k_frac*d))\n",
    "                  downlink = (k_val * 2) * n_clients\n",
    "              elif compressor == \"randk\":\n",
    "                  k_val = max(1, int(k_frac*d))\n",
    "                  downlink = k_val * n_clients\n",
    "              elif compressor == \"id\":\n",
    "                  downlink = d * n_clients\n",
    "              else:\n",
    "                  raise ValueError(\"invalid compressor\")\n",
    "            else:\n",
    "                downlink = d * n_clients \n",
    "            cum_coords += uplink + downlink\n",
    "            comm_hist.append(cum_coords)\n",
    "\n",
    "            # logging on clients\n",
    "            f_t = np.mean([loss_class0(w, clients[i][\"x0\"]) for i in range(n_clients)])\n",
    "            g_t = np.mean([loss_class1(w, clients[i][\"x1\"]) for i in range(n_clients)])\n",
    "\n",
    "            f_selected = np.mean([loss_class0(w, clients[i][\"x0\"]) for i in selected])\n",
    "            g_selected = np.mean([loss_class1(w, clients[i][\"x1\"]) for i in selected])\n",
    "\n",
    "            f_hist.append(f_t); g_hist.append(g_t); feas_hist.append(1.0 if feasible else 0.0)\n",
    "            g_hist_selected.append(g_selected)\n",
    "            f_hist_selected.append(f_selected)\n",
    "\n",
    "        # Final test\n",
    "        acc_te, acc_0, acc_1 = test_accuracy_all(w, X_te, y_te)\n",
    "        label = \"Hard\" if beta_val is None else f\"β={beta_val}\"\n",
    "        idx_0_test = np.where(y_te == 0)[0]\n",
    "        idx_1_test = np.where(y_te == 1)[0]\n",
    "\n",
    "        results[label] = {\n",
    "            \"f_hist\": np.array(f_hist),\n",
    "            \"g_hist\": np.array(g_hist),\n",
    "            \"f_hist_selected\": np.array(f_hist_selected),\n",
    "            \"g_hist_selected\": np.array(g_hist_selected),\n",
    "            \"feas_hist\": np.array(feas_hist),\n",
    "            \"violations\": int(viol_count),\n",
    "            \"test_acc\": float(acc_te),\n",
    "            \"comm_hist\": np.array(comm_hist), \n",
    "            \"test_acc_0\": float(acc_0),\n",
    "            \"test_acc_1\": float(acc_1),\n",
    "            \"objective_test\": float(loss_class0(w, X_te[idx_0_test])),\n",
    "            \"constraint_test\": float(loss_class1(w, X_te[idx_1_test])),\n",
    "        }\n",
    "    return results\n",
    "\n",
    "#  centralized_full\n",
    "def run_centralized(seed, betas, num_rounds=500, stepsize=0.01, epsilon=0.1):\n",
    "    print('centralized_training')\n",
    "    set_seed(seed)\n",
    "    d = X_tr.shape[1]\n",
    "    results = {}\n",
    "    if np.isscalar(stepsize):\n",
    "        eta_vec = [float(stepsize)] * len(betas)\n",
    "    else:\n",
    "        eta_vec = stepsize\n",
    "\n",
    "    for cnt, beta_val in enumerate(betas):\n",
    "        set_seed(seed)\n",
    "        w = np.random.randn(d)\n",
    "\n",
    "        f_hist, g_hist = [], []\n",
    "        f_hist.append(loss_class0(w, X_tr[y_tr == 0]))\n",
    "        g_hist.append(loss_class1(w, X_tr[y_tr == 1]))\n",
    "        comm_hist, cum_coords = [], 0\n",
    "        eta = float(eta_vec[cnt])\n",
    "        viol_count = 0\n",
    "        feas_hist = []\n",
    "        feasible = (loss_class1(w, X_tr[y_tr == 1]) <= epsilon)\n",
    "        feas_hist.append(1.0 if feasible else 0.0)\n",
    "        viol_count += (not feasible)\n",
    "\n",
    "        for t in range(num_rounds):\n",
    "            gf = grad_class0(w, X_tr[y_tr == 0])\n",
    "            gg = grad_class1(w, X_tr[y_tr == 1])\n",
    "            g_val = loss_class1(w, X_tr[y_tr == 1])\n",
    "            feasible = (g_val <= epsilon)\n",
    "            viol_count += (not feasible)\n",
    "            if beta_val is None:\n",
    "                nu = gf if feasible else gg\n",
    "            else:\n",
    "                sigma_t = float(sigma_beta(g_val - epsilon, beta_val))\n",
    "                nu = (1.0 - sigma_t) * gf + sigma_t * gg\n",
    "\n",
    "            w = w - eta * nu\n",
    "\n",
    "            f_t = loss_class0(w, X_tr[y_tr == 0])\n",
    "            g_t = loss_class1(w, X_tr[y_tr == 1])\n",
    "            f_hist.append(f_t); g_hist.append(g_t)\n",
    "            feas_hist.append(1.0 if feasible else 0.0)\n",
    "\n",
    "            cum_coords += 2 * d \n",
    "            comm_hist.append(cum_coords)\n",
    "\n",
    "        acc_te, acc_0, acc_1 = test_accuracy_all(w, X_te, y_te)\n",
    "        label = \"Hard\" if beta_val is None else f\"β={beta_val}\"\n",
    "        results[label] = {\n",
    "            \"f_hist\": np.array(f_hist),\n",
    "            \"g_hist\": np.array(g_hist),\n",
    "            \"feas_hist\": np.array(feas_hist),\n",
    "            \"violations\": int(viol_count),\n",
    "            \"test_acc\": float(acc_te),\n",
    "            \"comm_hist\": np.array(comm_hist),\n",
    "            \"test_acc_0\": float(acc_0),\n",
    "            \"test_acc_1\": float(acc_1),\n",
    "            \"objective_test\": float(loss_class0(w, X_te[y_te == 0])),\n",
    "            \"constraint_test\": float(loss_class1(w, X_te[y_te == 1])),\n",
    "        }\n",
    "    return results\n",
    "\n",
    "# multiple seed runs\n",
    "seeds = [123]                  \n",
    "betas = [None]            \n",
    "num_rounds = 500\n",
    "local_epochs = 5\n",
    "stepsize = [0.01]\n",
    "epsilon = 0.1\n",
    "participation_rate = 1.0\n",
    "compressor = \"topk\"                \n",
    "k_frac = 1.0\n",
    "comp_on = False\n",
    "downlink_ef = False\n",
    "n_clients=10\n",
    "\n",
    "#federated run\n",
    "all_trials = []\n",
    "for s in seeds:\n",
    "    set_seed(s)\n",
    "    trial_res = run_trial(\n",
    "        seed=s,\n",
    "        betas=betas,\n",
    "        n_clients=n_clients,\n",
    "        num_rounds=num_rounds,\n",
    "        local_epochs=local_epochs,\n",
    "        stepsize=stepsize,\n",
    "        epsilon=epsilon,\n",
    "        participation_rate=participation_rate,\n",
    "        compressor=compressor,\n",
    "        k_frac=k_frac,\n",
    "        comp_on=comp_on,\n",
    "        downlink_ef=downlink_ef,\n",
    "    )\n",
    "    all_trials.append(trial_res)\n",
    "\n",
    "# centralized run\n",
    "central_trials = []\n",
    "centralized_stepsize = 0.3\n",
    "for s in seeds:\n",
    "    cent_res = run_centralized(\n",
    "        seed=s,\n",
    "        betas=betas,\n",
    "        num_rounds=num_rounds*local_epochs,\n",
    "        stepsize=centralized_stepsize,\n",
    "        epsilon=epsilon\n",
    "    )\n",
    "    central_trials.append(cent_res)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
