{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ee32f18c-eda5-4a43-ad00-4d5854fa50ef",
   "metadata": {},
   "source": [
    "# Boston Housing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321f40cc-1561-48a3-8550-94fdb3a37157",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import math, random\n",
    "import matplotlib.pyplot as plt\n",
    "import torch, torch.nn as nn\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "import csv\n",
    "import os\n",
    "\n",
    "\n",
    "def set_global_seed(seed: int):\n",
    "    \"\"\"Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.\n",
    "\n",
    "    Args:\n",
    "        seed: value handed to every seeding call.\n",
    "    \"\"\"\n",
    "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seed_fn(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "    # Trade cuDNN autotuning for run-to-run reproducibility.\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "relu = torch.relu\n",
    "\n",
    "\n",
    "def pos(x):\n",
    "    \"\"\"Elementwise positive part of ``x``, i.e. relu(x).\"\"\"\n",
    "    return relu(x)\n",
    "\n",
    "\n",
    "def neg(x):\n",
    "    \"\"\"Elementwise negative part of ``x``, i.e. relu(-x), so x == pos(x) - neg(x).\"\"\"\n",
    "    return relu(-x)\n",
    "\n",
    "\n",
    "class DeepReLU_Scalar(nn.Module):\n",
    "    \"\"\"Fully connected ReLU network with a scalar linear read-out ``a_L @ u + beta``.\n",
    "\n",
    "    Hidden weights and biases are initialised from normal distributions with\n",
    "    mean 0 and standard deviations ``ws`` and ``bs``. ``u`` and ``beta`` are\n",
    "    standard-normal draws scaled by ``us`` (``ws`` is used when ``us`` is None).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_sizes, ws=0.05, bs=0.05, us=None):\n",
    "        super().__init__()\n",
    "        widths = [input_dim, *hidden_sizes]\n",
    "        self.L = len(hidden_sizes)\n",
    "        self.layers = nn.ModuleList(\n",
    "            [nn.Linear(n_in, n_out, bias=True) for n_in, n_out in zip(widths[:-1], widths[1:])]\n",
    "        )\n",
    "        for layer in self.layers:\n",
    "            nn.init.normal_(layer.weight, 0, ws)\n",
    "            nn.init.normal_(layer.bias, 0, bs)\n",
    "        readout_scale = ws if us is None else us\n",
    "        # widths[-1] is input_dim when there are no hidden layers, so the\n",
    "        # read-out then acts directly on the input.\n",
    "        self.u = nn.Parameter(torch.randn(widths[-1]) * readout_scale)\n",
    "        self.beta = nn.Parameter(torch.randn(1) * readout_scale)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the network output for every row of ``x``.\"\"\"\n",
    "        return self.activations_until(x, self.L) @ self.u + self.beta\n",
    "\n",
    "    def activations_until(self, x, k):\n",
    "        \"\"\"Apply the first ``k`` hidden layers (each followed by ReLU); k=0 returns ``x``.\"\"\"\n",
    "        out = x\n",
    "        for idx in range(k):\n",
    "            out = relu(self.layers[idx](out))\n",
    "        return out\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def z_pair_from_aprev(model: DeepReLU_Scalar, a_prev, l):\n",
    "    \"\"\"Carry layer ``l``'s output through the later layers as a pair (Zp, Zm).\n",
    "\n",
    "    The pair starts as (relu(s_l), 0). Each later layer splits its weight and\n",
    "    bias into positive and negative parts, giving p and q with p - q equal to\n",
    "    the usual pre-activation; since relu(p - q) = max(p, q) - q, the update\n",
    "    (max(p, q), q) keeps Zp - Zm equal to the ordinary activation.\n",
    "\n",
    "    Args:\n",
    "        model: network whose layers are used.\n",
    "        a_prev: input to layer ``l`` (1-based).\n",
    "        l: index of the layer the propagation starts from.\n",
    "\n",
    "    Returns:\n",
    "        (Zp, Zm, s_l) where s_l is the pre-activation of layer ``l``.\n",
    "    \"\"\"\n",
    "    first = model.layers[l - 1]\n",
    "    s_l = a_prev @ first.weight.T + first.bias\n",
    "    z_plus = relu(s_l)\n",
    "    z_minus = torch.zeros_like(s_l)\n",
    "    for k in range(l, model.L):\n",
    "        layer = model.layers[k]\n",
    "        w_pos, w_neg = pos(layer.weight), neg(layer.weight)\n",
    "        p = z_plus @ w_pos.T + z_minus @ w_neg.T + pos(layer.bias)\n",
    "        q = z_minus @ w_pos.T + z_plus @ w_neg.T + neg(layer.bias)\n",
    "        z_plus, z_minus = torch.maximum(p, q), q\n",
    "    return z_plus, z_minus, s_l\n",
    "\n",
    "@torch.no_grad()\n",
    "def AB_Wb(model: DeepReLU_Scalar, x, l, a_prev_cached=None):\n",
    "    \"\"\"Compute A and B (network output = A - B) for the hidden block of layer ``l``.\n",
    "\n",
    "    Args:\n",
    "        model: the scalar network.\n",
    "        x: flattened input batch.\n",
    "        l: 1-based layer index of the block.\n",
    "        a_prev_cached: optional precomputed input to layer ``l``; ignored when l == 1.\n",
    "\n",
    "    Returns:\n",
    "        (A, B, (a_prev, s_l)) with one A and B entry per sample.\n",
    "    \"\"\"\n",
    "    if l == 1:\n",
    "        a_prev = x\n",
    "    elif a_prev_cached is not None:\n",
    "        a_prev = a_prev_cached\n",
    "    else:\n",
    "        a_prev = model.activations_until(x, l - 1)\n",
    "    Zp, Zm, s_l = z_pair_from_aprev(model, a_prev, l)\n",
    "    u_pos, u_neg = pos(model.u), neg(model.u)\n",
    "    # Positive read-out weights keep the (Zp, Zm) roles; negative ones swap them.\n",
    "    A = (Zp * u_pos).sum(dim=1) + (Zm * u_neg).sum(dim=1) + pos(model.beta)\n",
    "    B = (Zm * u_pos).sum(dim=1) + (Zp * u_neg).sum(dim=1) + neg(model.beta)\n",
    "    return A, B, (a_prev, s_l)\n",
    "\n",
    "@torch.no_grad()\n",
    "def AB_u(model: DeepReLU_Scalar, x):\n",
    "    \"\"\"Compute A and B (network output = A - B) for the read-out block (u, beta).\n",
    "\n",
    "    Returns:\n",
    "        (A, B, aL) where aL is the output of the last hidden layer.\n",
    "    \"\"\"\n",
    "    aL = model.activations_until(x, model.L)\n",
    "    A = (aL * pos(model.u)).sum(dim=1) + pos(model.beta)\n",
    "    B = (aL * neg(model.u)).sum(dim=1) + neg(model.beta)\n",
    "    return A, B, aL\n",
    "\n",
    "def AB(model, x, block, a_prev_cached=None):\n",
    "    \"\"\"Dispatch to AB_u or AB_Wb according to ``block``.\n",
    "\n",
    "    Args:\n",
    "        block: ('u', _) for the read-out block, otherwise (name, l) for layer ``l``.\n",
    "\n",
    "    Returns:\n",
    "        (A, B, info) where info holds 'aL' for the read-out block, or\n",
    "        'a_prev', 's_l' and 'l' for a hidden-layer block.\n",
    "    \"\"\"\n",
    "    if block[0] != 'u':\n",
    "        l = block[1]\n",
    "        A, B, (a_prev, s_l) = AB_Wb(model, x, l, a_prev_cached)\n",
    "        return A, B, {'a_prev': a_prev, 's_l': s_l, 'l': l}\n",
    "    A, B, aL = AB_u(model, x)\n",
    "    return A, B, {'aL': aL}\n",
    "\n",
    "\n",
    "def dAB_for(which, A, B, y, reduction='mean'):\n",
    "    \"\"\"Partial derivatives w.r.t. A and B of one piece of the squared-error split.\n",
    "\n",
    "    With c = pos(-y):\n",
    "      * 'f' differentiates 2*A**2 + 2*(B + y + c)**2 + 2*c*A,\n",
    "      * 'g' differentiates (A + B + y + c)**2 + 2*c*B.\n",
    "    Their difference equals (A - B - y)**2 up to terms that do not depend on A or B.\n",
    "\n",
    "    Args:\n",
    "        which: 'f' or 'g'.\n",
    "        A, B, y: tensors of matching shape.\n",
    "        reduction: 'mean' divides both results by A.numel(); any other value\n",
    "            leaves them unscaled.\n",
    "\n",
    "    Returns:\n",
    "        (dA, dB).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``which`` is neither 'f' nor 'g'.\n",
    "    \"\"\"\n",
    "    c = pos(-y)\n",
    "    if which == 'f':\n",
    "        dA, dB = 4 * A + 2 * c, 4 * (B + y + c)\n",
    "    elif which == 'g':\n",
    "        S = A + B + y + c\n",
    "        dA, dB = 2 * S, 2 * S + 2 * c\n",
    "    else:\n",
    "        raise ValueError(\"which must be 'f' or 'g'\")\n",
    "    if reduction != 'mean':\n",
    "        return dA, dB\n",
    "    m = A.numel()\n",
    "    return dA / m, dB / m\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def manual_grad_block(model: DeepReLU_Scalar, x, block, dA, dB, a_prev_cached=None):\n",
    "    \"\"\"Chain per-sample derivatives dA, dB back to the parameters of ``block``.\n",
    "\n",
    "    Args:\n",
    "        model: the scalar network.\n",
    "        x: flattened input batch.\n",
    "        block: ('u', _) for the read-out block, otherwise (name, l) for layer ``l``.\n",
    "        dA, dB: derivatives w.r.t. A and B, one entry per sample.\n",
    "        a_prev_cached: optional precomputed input to layer ``l``; ignored when l == 1.\n",
    "\n",
    "    Returns:\n",
    "        [grad_u, grad_beta] for the read-out block, else [gradW, gradb]; same\n",
    "        order as blk_params.\n",
    "    \"\"\"\n",
    "    if block[0] == 'u':\n",
    "        _, _, aL = AB_u(model, x)\n",
    "        # A depends on pos(u) and B on neg(u): entries with u > 0 receive the dA\n",
    "        # term, entries with u < 0 the negated dB term, and u == 0 gets zero.\n",
    "        u_is_pos = (model.u > 0).to(aL.dtype)\n",
    "        u_is_neg = (model.u < 0).to(aL.dtype)\n",
    "        grad_u = (aL.T @ dA) * u_is_pos + (-(aL.T @ dB)) * u_is_neg\n",
    "\n",
    "        beta_is_pos = (model.beta > 0).to(dA.dtype)\n",
    "        beta_is_neg = (model.beta < 0).to(dA.dtype)\n",
    "        grad_beta = dA.sum() * beta_is_pos + (-dB.sum()) * beta_is_neg\n",
    "        return [grad_u, grad_beta]\n",
    "\n",
    "    l = block[1]\n",
    "    if l == 1:\n",
    "        a_prev = x\n",
    "    elif a_prev_cached is not None:\n",
    "        a_prev = a_prev_cached\n",
    "    else:\n",
    "        a_prev = model.activations_until(x, l - 1)\n",
    "    first = model.layers[l - 1]\n",
    "    s_l = a_prev @ first.weight.T + first.bias\n",
    "    Zp = relu(s_l)\n",
    "    Zm = torch.zeros_like(Zp)\n",
    "\n",
    "    # Forward sweep: for every later layer remember where p >= q (the branch\n",
    "    # selected by the max) together with the split weights.\n",
    "    tape = []\n",
    "    for k in range(l, model.L):\n",
    "        layer = model.layers[k]\n",
    "        Wp, Wm = pos(layer.weight), neg(layer.weight)\n",
    "        p = Zp @ Wp.T + Zm @ Wm.T + pos(layer.bias)\n",
    "        q = Zm @ Wp.T + Zp @ Wm.T + neg(layer.bias)\n",
    "        tape.append(((p >= q).to(Zp.dtype), Wp, Wm))\n",
    "        Zp, Zm = torch.maximum(p, q), q\n",
    "\n",
    "    up, um = pos(model.u), neg(model.u)\n",
    "    dA_col, dB_col = dA.unsqueeze(1), dB.unsqueeze(1)\n",
    "    dZp = dA_col * up.unsqueeze(0) + dB_col * um.unsqueeze(0)\n",
    "    dZm = dA_col * um.unsqueeze(0) + dB_col * up.unsqueeze(0)\n",
    "\n",
    "    # Backward sweep through (Zp, Zm) <- (max(p, q), q).\n",
    "    for M, Wp, Wm in reversed(tape):\n",
    "        dp = M * dZp\n",
    "        dq = dZm + (1 - M) * dZp\n",
    "        dZp, dZm = dp @ Wp + dq @ Wm, dp @ Wm + dq @ Wp\n",
    "\n",
    "    # Zp started as relu(s_l) and Zm as zeros, so only dZp reaches s_l.\n",
    "    ds_l = (s_l > 0).to(s_l.dtype) * dZp\n",
    "    return [ds_l.T @ a_prev, ds_l.sum(dim=0)]\n",
    "\n",
    "\n",
    "def blk_params(model, block):\n",
    "    \"\"\"Return the two parameters updated for ``block``.\n",
    "\n",
    "    ('u', _) selects [model.u, model.beta]; any other block (name, l) selects\n",
    "    [weight, bias] of the 1-based layer ``l``.\n",
    "    \"\"\"\n",
    "    if block[0] != 'u':\n",
    "        layer = model.layers[block[1] - 1]\n",
    "        return [layer.weight, layer.bias]\n",
    "    # beta is returned alongside u so that it is trained too\n",
    "    return [model.u, model.beta]\n",
    "\n",
    "@torch.no_grad()\n",
    "def prefix_aprev(model, x, l):\n",
    "    \"\"\"Return the input to layer ``l``: ``x`` itself for l == 1, else the output of layer l-1.\"\"\"\n",
    "    if l == 1:\n",
    "        return x\n",
    "    return model.activations_until(x, l - 1)\n",
    "\n",
    "def dca_step_batch(model, xb, yb, block, inner=3, lr=5e-3, dev='cpu', grad_mode='manual'):\n",
    "    \"\"\"Run one DCA update of a single parameter block on one mini-batch.\n",
    "\n",
    "    The gradient of the 'g' piece is evaluated once at the current parameters\n",
    "    (``v_list``). Then ``inner`` gradient steps of size ``lr`` are taken on\n",
    "    f - <v, theta>, updating the block's parameters in place.\n",
    "\n",
    "    Args:\n",
    "        model: the scalar network (switched to train mode).\n",
    "        xb, yb: batch inputs and targets; moved to ``dev`` and flattened.\n",
    "        block: ('Wb', l) for hidden layer ``l`` or ('u', None) for the read-out.\n",
    "        inner: number of inner gradient steps.\n",
    "        lr: step size.\n",
    "        dev: device the batch is moved to.\n",
    "        grad_mode: accepted but not read by this function.\n",
    "\n",
    "    Returns:\n",
    "        List of ``inner`` floats: f - <v, theta> evaluated before each step.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    params = blk_params(model, block)\n",
    "    xb = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "    yb = yb.to(dev, non_blocking=True).view(-1)\n",
    "\n",
    "    # The layers before the block are not updated in this call, so their\n",
    "    # output is computed once and reused.\n",
    "    a_prev_cached = None\n",
    "    if block[0] == 'Wb':\n",
    "        a_prev_cached = prefix_aprev(model, xb, block[1])\n",
    "\n",
    "    with torch.no_grad():\n",
    "        A, B, _ = AB(model, xb, block, a_prev_cached)\n",
    "    dA_g, dB_g = dAB_for('g', A, B, yb, reduction='mean')\n",
    "    v_list = manual_grad_block(model, xb, block, dA_g, dB_g, a_prev_cached)\n",
    "\n",
    "    hist = []\n",
    "    with torch.no_grad():\n",
    "        for _ in range(inner):\n",
    "            A, B, _info = AB(model, xb, block, a_prev_cached)\n",
    "            dA_f, dB_f = dAB_for('f', A, B, yb, reduction='mean')\n",
    "            grad_list = manual_grad_block(model, xb, block, dA_f, dB_f, a_prev_cached)\n",
    "            c = pos(-yb)\n",
    "            f_val = (2*(A*A) + 2*(B + yb + c)**2 + 2*c*A).mean()\n",
    "            linear_term = sum((p * v).sum() for p, v in zip(params, v_list))\n",
    "            hist.append((f_val - linear_term).item())\n",
    "            # gradient step on f - <v, theta>\n",
    "            for p, g, v in zip(params, grad_list, v_list):\n",
    "                p -= lr * (g - v)\n",
    "    return hist\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def mse_on_loader(model, loader, dev='cpu'):\n",
    "    \"\"\"Mean squared error of ``model`` over all samples yielded by ``loader``.\n",
    "\n",
    "    Puts the model in eval mode. Returns 0.0 when the loader yields nothing.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    sq_err_total, count = 0.0, 0\n",
    "    for xb, yb in loader:\n",
    "        inputs = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "        targets = yb.to(dev, non_blocking=True).view(-1)\n",
    "        residual = model(inputs) - targets\n",
    "        sq_err_total += (residual**2).sum().item()\n",
    "        count += targets.numel()\n",
    "    return sq_err_total / max(1, count)\n",
    "\n",
    "\n",
    "def train_rand_blocks(model, tr_loader, te_loader, epochs=200, inner=3, lr=5e-3, dev='cpu'):\n",
    "    \"\"\"Train ``model`` with DCA steps on a randomly chosen block per mini-batch.\n",
    "\n",
    "    Args:\n",
    "        model: scalar network; moved to ``dev``.\n",
    "        tr_loader, te_loader: train and test loaders.\n",
    "        epochs: number of passes over ``tr_loader``.\n",
    "        inner: inner gradient steps per batch update.\n",
    "        lr: step size of the inner gradient steps.\n",
    "        dev: device used for training and evaluation.\n",
    "\n",
    "    Returns:\n",
    "        (train_mse, test_mse, inner_curves): per-epoch train MSE, per-epoch test\n",
    "        MSE, and per-epoch lists of the inner objective averaged over batches.\n",
    "    \"\"\"\n",
    "    model.to(dev)\n",
    "    blocks = [('Wb', l) for l in range(1, model.L+1)] + [('u', None)]\n",
    "    train_mse, test_mse, inner_curves = [], [], []\n",
    "    for _ in range(epochs):\n",
    "        sums = [0.0] * inner\n",
    "        nb = 0\n",
    "        for xb, yb in tr_loader:\n",
    "            block = random.choice(blocks)\n",
    "            hist = dca_step_batch(model, xb, yb, block, inner=inner, lr=lr, dev=dev, grad_mode='manual')\n",
    "            for k in range(inner):\n",
    "                sums[k] += hist[k]\n",
    "            nb += 1\n",
    "        # Same guard as mse_on_loader: an empty training loader gives zeros\n",
    "        # instead of a ZeroDivisionError.\n",
    "        avg_inner = [s / max(1, nb) for s in sums]\n",
    "        tr_mse = mse_on_loader(model, tr_loader, dev)\n",
    "        te_mse = mse_on_loader(model, te_loader, dev)\n",
    "        train_mse.append(tr_mse)\n",
    "        test_mse.append(te_mse)\n",
    "        inner_curves.append(avg_inner)\n",
    "    return train_mse, test_mse, inner_curves\n",
    "\n",
    "\n",
    "def train_sgd(model, tr_loader, te_loader, epochs=200, lr=1e-2, weight_decay=0.0, dev='cpu'):\n",
    "    \"\"\"Baseline: train ``model`` with SGD (momentum 0.9) on the mean squared error.\n",
    "\n",
    "    Returns:\n",
    "        (train_mse, test_mse): per-epoch MSE on the two loaders, measured with\n",
    "        mse_on_loader after each epoch.\n",
    "    \"\"\"\n",
    "    model.to(dev)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)\n",
    "    criterion = nn.MSELoss(reduction='mean')\n",
    "    train_mse, test_mse = [], []\n",
    "    for _ in range(epochs):\n",
    "        model.train()\n",
    "        for xb, yb in tr_loader:\n",
    "            inputs = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "            targets = yb.to(dev, non_blocking=True).view(-1)\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            criterion(model(inputs), targets).backward()\n",
    "            optimizer.step()\n",
    "        # Evaluate train first, then test, after every epoch.\n",
    "        for curve, loader in ((train_mse, tr_loader), (test_mse, te_loader)):\n",
    "            curve.append(mse_on_loader(model, loader, dev))\n",
    "    return train_mse, test_mse\n",
    "\n",
    "\n",
    "def load_boston_tensors(device='cpu'):\n",
    "    \"\"\"Load Boston Housing, standardise features and target, return tensors on ``device``.\n",
    "\n",
    "    Tries ``sklearn.datasets.load_boston`` first and falls back to\n",
    "    ``fetch_openml(name='boston', version=1)`` on any failure.\n",
    "\n",
    "    NOTE(review): the mean/std are computed over the full dataset; make_loaders\n",
    "    splits train/test afterwards, so test rows influence the scaling.\n",
    "\n",
    "    Returns:\n",
    "        (X, y) float32 tensors.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if neither source can be loaded.\n",
    "    \"\"\"\n",
    "\n",
    "    def _as_float32(bunch):\n",
    "        return bunch['data'].astype('float32'), bunch['target'].astype('float32')\n",
    "\n",
    "    try:\n",
    "        from sklearn.datasets import load_boston  # deprecated in newer sklearn\n",
    "        X, y = _as_float32(load_boston())\n",
    "    except Exception:\n",
    "        try:\n",
    "            from sklearn.datasets import fetch_openml\n",
    "            X, y = _as_float32(fetch_openml(name='boston', version=1, as_frame=False))\n",
    "        except Exception as e:\n",
    "            raise RuntimeError(\"Could not load Boston Housing dataset. \"\n",
    "                               \"Please install scikit-learn<=1.1 or enable OpenML.\") from e\n",
    "    # standardize features and target; 1e-8 protects against zero variance\n",
    "    X = (X - X.mean(0, keepdims=True)) / (X.std(0, keepdims=True) + 1e-8)\n",
    "    y = (y - y.mean()) / (y.std() + 1e-8)\n",
    "    return torch.from_numpy(X).to(device), torch.from_numpy(y).to(device)\n",
    "\n",
    "def make_loaders(batch_size=64, device='cpu'):\n",
    "    \"\"\"Build shuffled-train / ordered-test loaders from a random 80/20 split.\n",
    "\n",
    "    Args:\n",
    "        batch_size: mini-batch size for both loaders.\n",
    "        device: device the dataset tensors are placed on.\n",
    "\n",
    "    Returns:\n",
    "        (tr_loader, te_loader, n_features).\n",
    "    \"\"\"\n",
    "    X, y = load_boston_tensors(device=device)\n",
    "    N = X.shape[0]\n",
    "    idx = torch.randperm(N)\n",
    "    n_tr = int(0.8 * N)\n",
    "    tr_idx, te_idx = idx[:n_tr], idx[n_tr:]\n",
    "    tr_ds = TensorDataset(X[tr_idx], y[tr_idx])\n",
    "    te_ds = TensorDataset(X[te_idx], y[te_idx])\n",
    "    # Pinning only makes sense for CPU tensors that are later copied to a GPU.\n",
    "    # The tensors here may already live on `device`, and asking the DataLoader\n",
    "    # to pin CUDA tensors fails, so enable it only in the useful case.\n",
    "    pin = X.device.type == 'cpu' and torch.cuda.is_available()\n",
    "    tr_loader = DataLoader(tr_ds, batch_size=batch_size, shuffle=True,  num_workers=0, pin_memory=pin)\n",
    "    te_loader = DataLoader(te_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin)\n",
    "    return tr_loader, te_loader, X.shape[1]\n",
    "\n",
    "\n",
    "def run_monte_carlo(\n",
    "    MC=10,\n",
    "    hidden_sizes=[64,32,16],\n",
    "    epochs=200,\n",
    "    inner=20,\n",
    "    dca_lr=5e-4,\n",
    "    sgd_lr=1e-2,\n",
    "    batch_size=20,\n",
    "    ws=0.1, bs=0.1, us=0.1,\n",
    "    dev=None,\n",
    "    out_dir='boston_mc_outputs'\n",
    "):\n",
    "    \"\"\"Compare block-wise DCA training with SGD over ``MC`` seeded repetitions.\n",
    "\n",
    "    Every repetition seeds the RNGs with ``123 + mc``, builds fresh loaders and\n",
    "    trains two models that start from the same initial state. Per-epoch\n",
    "    mean/std curves and per-trial best test MSE are written to ``out_dir`` as\n",
    "    two CSV files and two PNG plots.\n",
    "\n",
    "    Args:\n",
    "        MC: number of repetitions (must be at least 1).\n",
    "        hidden_sizes: hidden layer widths; not mutated here.\n",
    "        epochs: epochs per method.\n",
    "        inner: inner gradient steps per DCA batch update.\n",
    "        dca_lr, sgd_lr: step sizes of the two methods.\n",
    "        batch_size: mini-batch size.\n",
    "        ws, bs, us: initialisation scales forwarded to DeepReLU_Scalar.\n",
    "        dev: device string; when falsy, 'cuda' if available, else 'cpu'.\n",
    "        out_dir: output directory, created if missing.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``MC`` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if MC < 1:\n",
    "        # Without any trial the statistics below are undefined.\n",
    "        raise ValueError('MC must be at least 1')\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    device = dev or ('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    def mean_std(curves):\n",
    "        # curves: one list of per-epoch values per trial\n",
    "        arr = np.array(curves)\n",
    "        return arr.mean(axis=0), arr.std(axis=0)\n",
    "\n",
    "    def save_curve_plot(axis, dca_stats, sgd_stats, split, title, path):\n",
    "        # Mean curve with a +/- one standard deviation band for both methods.\n",
    "        plt.figure()\n",
    "        for name, (mean, std) in (('DCA', dca_stats), ('SGD', sgd_stats)):\n",
    "            plt.plot(axis, mean, label=f'{name} {split} (mean)')\n",
    "            plt.fill_between(axis, mean - std, mean + std, alpha=0.2)\n",
    "        plt.xlabel('epoch'); plt.ylabel('MSE'); plt.title(title); plt.legend(); plt.tight_layout()\n",
    "        plt.savefig(path)\n",
    "\n",
    "    dca_tr_all, dca_te_all = [], []\n",
    "    sgd_tr_all, sgd_te_all = [], []\n",
    "\n",
    "    dca_best_list, dca_best_ep = [], []\n",
    "    sgd_best_list, sgd_best_ep = [], []\n",
    "\n",
    "    for mc in range(MC):\n",
    "        set_global_seed(123 + mc)\n",
    "\n",
    "        tr_loader, te_loader, d = make_loaders(batch_size=batch_size, device=device)\n",
    "\n",
    "        init_model = DeepReLU_Scalar(input_dim=d, hidden_sizes=hidden_sizes, ws=ws, bs=bs, us=us)\n",
    "        init_state = deepcopy(init_model.state_dict())\n",
    "\n",
    "        # ---- DCA ----\n",
    "        dca_model = DeepReLU_Scalar(input_dim=d, hidden_sizes=hidden_sizes, ws=ws, bs=bs, us=us)\n",
    "        dca_model.load_state_dict(deepcopy(init_state))\n",
    "        dca_tr, dca_te, _ = train_rand_blocks(dca_model, tr_loader, te_loader, epochs=epochs, inner=inner, lr=dca_lr, dev=device)\n",
    "\n",
    "        # ---- SGD ---- (same initial state as the DCA model)\n",
    "        sgd_model = DeepReLU_Scalar(input_dim=d, hidden_sizes=hidden_sizes, ws=ws, bs=bs, us=us)\n",
    "        sgd_model.load_state_dict(deepcopy(init_state))\n",
    "        sgd_tr, sgd_te = train_sgd(sgd_model, tr_loader, te_loader, epochs=epochs, lr=sgd_lr, weight_decay=0.0, dev=device)\n",
    "\n",
    "        # store curves\n",
    "        dca_tr_all.append(dca_tr); dca_te_all.append(dca_te)\n",
    "        sgd_tr_all.append(sgd_tr); sgd_te_all.append(sgd_te)\n",
    "\n",
    "        # best test per trial (epochs reported 1-based)\n",
    "        dca_best = float(np.min(dca_te)); dca_ep = int(np.argmin(dca_te)) + 1\n",
    "        sgd_best = float(np.min(sgd_te)); sgd_ep = int(np.argmin(sgd_te)) + 1\n",
    "        dca_best_list.append(dca_best); dca_best_ep.append(dca_ep)\n",
    "        sgd_best_list.append(sgd_best); sgd_best_ep.append(sgd_ep)\n",
    "\n",
    "        print(f\"[MC {mc+1}/{MC}] DCA best test {dca_best:.4f} @ epoch {dca_ep} | SGD best test {sgd_best:.4f} @ epoch {sgd_ep}\")\n",
    "\n",
    "    dca_tr_mean, dca_tr_std = mean_std(dca_tr_all)\n",
    "    dca_te_mean, dca_te_std = mean_std(dca_te_all)\n",
    "    sgd_tr_mean, sgd_tr_std = mean_std(sgd_tr_all)\n",
    "    sgd_te_mean, sgd_te_std = mean_std(sgd_te_all)\n",
    "\n",
    "    csv_path = os.path.join(out_dir, 'boston_mc_epoch_stats.csv')\n",
    "    with open(csv_path, 'w', newline='') as f:\n",
    "        w = csv.writer(f)\n",
    "        w.writerow(['epoch',\n",
    "                    'dca_train_mean','dca_train_std','dca_test_mean','dca_test_std',\n",
    "                    'sgd_train_mean','sgd_train_std','sgd_test_mean','sgd_test_std'])\n",
    "        for e in range(1, dca_tr_mean.shape[0]+1):\n",
    "            w.writerow([e,\n",
    "                        dca_tr_mean[e-1], dca_tr_std[e-1], dca_te_mean[e-1], dca_te_std[e-1],\n",
    "                        sgd_tr_mean[e-1], sgd_tr_std[e-1], sgd_te_mean[e-1], sgd_te_std[e-1]])\n",
    "    print(f\"Saved epoch stats: {csv_path}\")\n",
    "\n",
    "    best_csv = os.path.join(out_dir, 'boston_mc_best_per_trial.csv')\n",
    "    with open(best_csv, 'w', newline='') as f:\n",
    "        w = csv.writer(f)\n",
    "        w.writerow(['trial','dca_best_test','dca_best_epoch','sgd_best_test','sgd_best_epoch'])\n",
    "        for i in range(MC):\n",
    "            w.writerow([i+1, dca_best_list[i], dca_best_ep[i], sgd_best_list[i], sgd_best_ep[i]])\n",
    "    print(f\"Saved best-per-trial stats: {best_csv}\")\n",
    "\n",
    "    # Titles report the actual number of repetitions instead of a fixed count.\n",
    "    epochs_axis = np.arange(1, dca_tr_mean.shape[0]+1)\n",
    "    train_png = os.path.join(out_dir, 'train_mse_mc_mean_std.png')\n",
    "    save_curve_plot(epochs_axis, (dca_tr_mean, dca_tr_std), (sgd_tr_mean, sgd_tr_std),\n",
    "                    'train', f'Train MSE ({MC} Monte Carlo)', train_png)\n",
    "\n",
    "    test_png = os.path.join(out_dir, 'test_mse_mc_mean_std.png')\n",
    "    save_curve_plot(epochs_axis, (dca_te_mean, dca_te_std), (sgd_te_mean, sgd_te_std),\n",
    "                    'test', f'Test MSE ({MC} Monte Carlo)', test_png)\n",
    "\n",
    "    print(f\"DCA best test (per-trial): mean={np.mean(dca_best_list):.4f} ± {np.std(dca_best_list):.4f}\")\n",
    "    print(f\"SGD best test (per-trial): mean={np.mean(sgd_best_list):.4f} ± {np.std(sgd_best_list):.4f}\")\n",
    "    print(f\"Saved plots:\\n  {train_png}\\n  {test_png}\")\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # Settings for one Boston Housing run (a single repetition on CPU).\n",
    "    MC = 1\n",
    "    hidden_sizes = [64, 32, 16]\n",
    "    epochs, inner = 60, 50\n",
    "    dca_lr, sgd_lr = 1e-3, 1e-2\n",
    "    batch_size = 20\n",
    "    ws = bs = us = 0.1\n",
    "\n",
    "    run_monte_carlo(\n",
    "        MC=MC, hidden_sizes=hidden_sizes, epochs=epochs, inner=inner,\n",
    "        dca_lr=dca_lr, sgd_lr=sgd_lr, batch_size=batch_size,\n",
    "        ws=ws, bs=bs, us=us,\n",
    "        dev='cpu', out_dir='boston_mc_outputs',\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02aa72d9-a771-4a4a-bacf-58a7fd44eddd",
   "metadata": {},
   "source": [
    "# Cifar10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "203839a9-e5b3-41b4-a28b-d896a2e8f7d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, math, random, csv\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "def set_global_seed(seed: int):\n",
    "    \"\"\"Make a run repeatable: seed random, NumPy and PyTorch, and pin cuDNN behaviour.\n",
    "\n",
    "    Args:\n",
    "        seed: integer used for every generator.\n",
    "    \"\"\"\n",
    "    seeders = [random.seed, np.random.seed, torch.manual_seed]\n",
    "    if torch.cuda.is_available():\n",
    "        seeders.append(torch.cuda.manual_seed_all)\n",
    "    for seed_fn in seeders:\n",
    "        seed_fn(seed)\n",
    "    # Deterministic kernels on, autotuner off.\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "relu = torch.relu\n",
    "\n",
    "\n",
    "def pos(x):\n",
    "    \"\"\"Positive part: relu(x).\"\"\"\n",
    "    return relu(x)\n",
    "\n",
    "\n",
    "def neg(x):\n",
    "    \"\"\"Negative part: relu(-x); together x == pos(x) - neg(x).\"\"\"\n",
    "    return relu(-x)\n",
    "\n",
    "\n",
    "class DeepReLU_Classifier(nn.Module):\n",
    "    \"\"\"Fully connected ReLU network with a linear read-out producing class logits.\n",
    "\n",
    "    Hidden weights and biases are initialised from normal distributions with\n",
    "    mean 0 and standard deviations ``ws`` and ``bs``. ``U`` (last width x\n",
    "    num_classes) and ``beta`` (num_classes) are standard-normal draws scaled\n",
    "    by ``us`` (``ws`` is used when ``us`` is None).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_sizes, num_classes, ws=0.05, bs=0.05, us=None):\n",
    "        super().__init__()\n",
    "        widths = [input_dim, *hidden_sizes]\n",
    "        self.L = len(hidden_sizes)\n",
    "        self.layers = nn.ModuleList(\n",
    "            [nn.Linear(n_in, n_out, bias=True) for n_in, n_out in zip(widths[:-1], widths[1:])]\n",
    "        )\n",
    "        for layer in self.layers:\n",
    "            nn.init.normal_(layer.weight, 0, ws)\n",
    "            nn.init.normal_(layer.bias, 0, bs)\n",
    "        readout_scale = ws if us is None else us\n",
    "        # widths[-1] is input_dim when hidden_sizes is empty.\n",
    "        self.U = nn.Parameter(torch.randn(widths[-1], num_classes) * readout_scale)  # (h,C)\n",
    "        self.beta = nn.Parameter(torch.randn(num_classes) * readout_scale)           # (C,)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the logits for every row of ``x``.\"\"\"\n",
    "        return self.activations_until(x, self.L) @ self.U + self.beta\n",
    "\n",
    "    def activations_until(self, x, k):\n",
    "        \"\"\"Apply the first ``k`` hidden layers (each followed by ReLU); k=0 returns ``x``.\"\"\"\n",
    "        out = x\n",
    "        for idx in range(k):\n",
    "            out = relu(self.layers[idx](out))\n",
    "        return out\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def z_pair_from_aprev(model, a_prev, l):\n",
    "    \"\"\"Carry layer ``l``'s output through the later layers as a pair (Zp, Zm).\n",
    "\n",
    "    The pair starts as (relu(s_l), 0). In every later layer the weight is split\n",
    "    into positive and negative parts, while the whole signed bias is added to p\n",
    "    and none to q, so p - q is still the usual pre-activation and\n",
    "    (max(p, q), q) keeps Zp - Zm equal to the ordinary activation.\n",
    "\n",
    "    NOTE(review): this redefines the z_pair_from_aprev of the Boston Housing\n",
    "    cell, which splits the bias into pos/neg parts. After this cell has run,\n",
    "    that cell's AB_Wb picks up this version.\n",
    "\n",
    "    Returns:\n",
    "        (Zp, Zm, s_l) where s_l is the pre-activation of layer ``l``.\n",
    "    \"\"\"\n",
    "    first = model.layers[l - 1]\n",
    "    s_l = a_prev @ first.weight.T + first.bias\n",
    "    z_plus = relu(s_l)\n",
    "    z_minus = torch.zeros_like(s_l)\n",
    "    for k in range(l, model.L):\n",
    "        layer = model.layers[k]\n",
    "        w_pos, w_neg = pos(layer.weight), neg(layer.weight)\n",
    "        p = z_plus @ w_pos.T + z_minus @ w_neg.T + layer.bias\n",
    "        q = z_minus @ w_pos.T + z_plus @ w_neg.T\n",
    "        z_plus, z_minus = torch.maximum(p, q), q\n",
    "    return z_plus, z_minus, s_l\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def AB_Wb_class(model: DeepReLU_Classifier, x, l, a_prev_cached=None):\n",
    "    \"\"\"Compute A and B (logits = A - B) for the hidden block of layer ``l``.\n",
    "\n",
    "    Args:\n",
    "        model: the classifier.\n",
    "        x: flattened input batch.\n",
    "        l: 1-based layer index of the block.\n",
    "        a_prev_cached: optional precomputed input to layer ``l``; ignored when l == 1.\n",
    "\n",
    "    Returns:\n",
    "        (A, B, (a_prev, s_l)); A and B have one column per class.\n",
    "    \"\"\"\n",
    "    if l == 1:\n",
    "        a_prev = x\n",
    "    elif a_prev_cached is not None:\n",
    "        a_prev = a_prev_cached\n",
    "    else:\n",
    "        a_prev = model.activations_until(x, l - 1)\n",
    "    Zp, Zm, s_l = z_pair_from_aprev(model, a_prev, l)\n",
    "    U_pos, U_neg = pos(model.U), neg(model.U)\n",
    "    # Positive read-out weights keep the (Zp, Zm) roles; negative ones swap them.\n",
    "    A = Zp @ U_pos + Zm @ U_neg + pos(model.beta)\n",
    "    B = Zm @ U_pos + Zp @ U_neg + neg(model.beta)\n",
    "    return A, B, (a_prev, s_l)\n",
    "\n",
    "@torch.no_grad()\n",
    "def AB_u_class(model: DeepReLU_Classifier, x):\n",
    "    \"\"\"Compute A and B (logits = A - B) for the read-out block (U, beta).\n",
    "\n",
    "    Returns:\n",
    "        (A, B, aL) where aL is the output of the last hidden layer.\n",
    "    \"\"\"\n",
    "    aL = model.activations_until(x, model.L)\n",
    "    A = aL @ pos(model.U) + pos(model.beta)\n",
    "    B = aL @ neg(model.U) + neg(model.beta)\n",
    "    return A, B, aL\n",
    "\n",
    "def AB_class(model, x, block, a_prev_cached=None):\n",
    "    \"\"\"Dispatch to AB_u_class or AB_Wb_class according to ``block``.\n",
    "\n",
    "    Returns:\n",
    "        (A, B, info) where info holds 'aL' for the read-out block ('u', _), or\n",
    "        'a_prev', 's_l' and 'l' for a hidden-layer block (name, l).\n",
    "    \"\"\"\n",
    "    if block[0] != 'u':\n",
    "        l = block[1]\n",
    "        A, B, (a_prev, s_l) = AB_Wb_class(model, x, l, a_prev_cached)\n",
    "        return A, B, {'a_prev': a_prev, 's_l': s_l, 'l': l}\n",
    "    A, B, aL = AB_u_class(model, x)\n",
    "    return A, B, {'aL': aL}\n",
    "\n",
    "\n",
    "def ce_loss_bdcsplit(A, B, y):\n",
    "    \"\"\"Write the per-sample cross-entropy of logits ``A - B`` as ``G - H``.\n",
    "\n",
    "    G = logsumexp(A - B) + sum_j B_j + B_y and H = A_y + sum_j B_j, hence\n",
    "    G - H = logsumexp(A - B) - (A_y - B_y).\n",
    "\n",
    "    Args:\n",
    "        A, B: (batch, classes) tensors.\n",
    "        y: integer class labels, one per row.\n",
    "\n",
    "    Returns:\n",
    "        (G, H, G - H), each with one entry per sample.\n",
    "    \"\"\"\n",
    "    rows = torch.arange(y.size(0), device=y.device)\n",
    "    B_total = B.sum(dim=1)\n",
    "    G = torch.logsumexp(A - B, dim=1) + B_total + B[rows, y]\n",
    "    H = A[rows, y] + B_total\n",
    "    return G, H, G - H\n",
    "\n",
    "\n",
    "def dAB_for_ce(which, A, B, y, reduction='mean'):\n",
    "    \"\"\"Partial derivatives w.r.t. A and B of one piece of the cross-entropy split.\n",
    "\n",
    "    With p = softmax(A - B):\n",
    "      * 'f': dA = p and dB = 1 - p, plus 1 at the label column\n",
    "        (derivatives of logsumexp(A - B) + sum_j B_j + B_y),\n",
    "      * 'g': dA is the one-hot label and dB is all ones\n",
    "        (derivatives of A_y + sum_j B_j).\n",
    "\n",
    "    Args:\n",
    "        which: 'f' or 'g'.\n",
    "        A, B: (batch, classes) tensors.\n",
    "        y: integer class labels, one per row.\n",
    "        reduction: 'mean' divides both results by the batch size; any other\n",
    "            value leaves them unscaled.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``which`` is neither 'f' nor 'g'.\n",
    "    \"\"\"\n",
    "    p = torch.softmax(A - B, dim=1)\n",
    "    Bsz, C = A.shape\n",
    "    if which == 'f':\n",
    "        dA = p\n",
    "        dB = -p + 1.0\n",
    "        dB[torch.arange(Bsz, device=y.device), y] += 1.0\n",
    "    elif which == 'g':\n",
    "        dA = torch.zeros_like(A)\n",
    "        dA[torch.arange(Bsz, device=y.device), y] = 1.0\n",
    "        dB = torch.ones_like(B)\n",
    "    else:\n",
    "        raise ValueError(\"which must be 'f' or 'g'\")\n",
    "    if reduction != 'mean':\n",
    "        return dA, dB\n",
    "    return dA / Bsz, dB / Bsz\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def manual_grad_block_class(model: DeepReLU_Classifier, x, block, dA, dB, a_prev_cached=None):\n",
    "    \"\"\"Chain derivatives dA, dB (batch x classes) back to the parameters of ``block``.\n",
    "\n",
    "    Args:\n",
    "        model: the classifier.\n",
    "        x: flattened input batch.\n",
    "        block: ('u', _) for the read-out block, otherwise (name, l) for layer ``l``.\n",
    "        dA, dB: derivatives w.r.t. A and B.\n",
    "        a_prev_cached: optional precomputed input to layer ``l``; ignored when l == 1.\n",
    "\n",
    "    Returns:\n",
    "        [grad_U, grad_beta] for the read-out block, else [gradW, gradb]; same\n",
    "        order as blk_params_class.\n",
    "    \"\"\"\n",
    "    if block[0] == 'u':\n",
    "        _, _, aL = AB_u_class(model, x)\n",
    "        U, beta = model.U, model.beta\n",
    "        # A depends on pos(U) and B on neg(U): entries with U > 0 receive the dA\n",
    "        # term, entries with U < 0 the negated dB term, and U == 0 gets zero.\n",
    "        U_is_pos = (U > 0).to(aL.dtype)\n",
    "        U_is_neg = (U < 0).to(aL.dtype)\n",
    "        grad_U = (aL.T @ dA) * U_is_pos + (-(aL.T @ dB)) * U_is_neg\n",
    "\n",
    "        beta_is_pos = (beta > 0).to(dA.dtype)\n",
    "        beta_is_neg = (beta < 0).to(dA.dtype)\n",
    "        grad_beta = dA.sum(dim=0) * beta_is_pos + (-dB.sum(dim=0)) * beta_is_neg\n",
    "        return [grad_U, grad_beta]\n",
    "\n",
    "    l = block[1]\n",
    "    if l == 1:\n",
    "        a_prev = x\n",
    "    elif a_prev_cached is not None:\n",
    "        a_prev = a_prev_cached\n",
    "    else:\n",
    "        a_prev = model.activations_until(x, l - 1)\n",
    "    first = model.layers[l - 1]\n",
    "    s_l = a_prev @ first.weight.T + first.bias\n",
    "    Zp = relu(s_l)\n",
    "    Zm = torch.zeros_like(Zp)\n",
    "\n",
    "    # Forward sweep: for every later layer remember where p >= q (the branch\n",
    "    # selected by the max) together with the split weights.\n",
    "    tape = []\n",
    "    for k in range(l, model.L):\n",
    "        layer = model.layers[k]\n",
    "        Wp, Wm = pos(layer.weight), neg(layer.weight)\n",
    "        p = Zp @ Wp.T + Zm @ Wm.T + layer.bias\n",
    "        q = Zm @ Wp.T + Zp @ Wm.T\n",
    "        tape.append(((p >= q).to(Zp.dtype), Wp, Wm))\n",
    "        Zp, Zm = torch.maximum(p, q), q\n",
    "\n",
    "    Up, Um = pos(model.U), neg(model.U)\n",
    "    dZp = dA @ Up.T + dB @ Um.T\n",
    "    dZm = dA @ Um.T + dB @ Up.T\n",
    "\n",
    "    # Backward sweep through (Zp, Zm) <- (max(p, q), q).\n",
    "    for M, Wp, Wm in reversed(tape):\n",
    "        dp = M * dZp\n",
    "        dq = dZm + (1 - M) * dZp\n",
    "        dZp, dZm = dp @ Wp + dq @ Wm, dp @ Wm + dq @ Wp\n",
    "\n",
    "    # Zp started as relu(s_l) and Zm as zeros, so only dZp reaches s_l.\n",
    "    ds_l = (s_l > 0).to(s_l.dtype) * dZp\n",
    "    return [ds_l.T @ a_prev, ds_l.sum(dim=0)]\n",
    "\n",
    "\n",
    "def blk_params_class(model, block):\n",
    "    \"\"\"Return the two parameters updated for ``block``.\n",
    "\n",
    "    ('u', _) selects [model.U, model.beta]; any other block (name, l) selects\n",
    "    [weight, bias] of the 1-based layer ``l``.\n",
    "    \"\"\"\n",
    "    if block[0] != 'u':\n",
    "        layer = model.layers[block[1] - 1]\n",
    "        return [layer.weight, layer.bias]\n",
    "    return [model.U, model.beta]\n",
    "\n",
    "@torch.no_grad()\n",
    "def prefix_aprev(model, x, l):\n",
    "    \"\"\"Input seen by layer ``l``: ``x`` for the first layer, otherwise the output of layer l-1.\"\"\"\n",
    "    if l != 1:\n",
    "        return model.activations_until(x, l - 1)\n",
    "    return x\n",
    "\n",
    "def dca_step_batch_class(model, xb, yb, block, inner=3, lr=5e-3, dev='cpu'):\n",
    "    \"\"\"Run one DCA update of a single parameter block on one classification batch.\n",
    "\n",
    "    The gradient of the 'g' piece is evaluated once at the current parameters\n",
    "    (``v_list``). Then ``inner`` gradient steps of size ``lr`` are taken on the\n",
    "    'f' piece minus <v, theta>, updating the block's parameters in place.\n",
    "\n",
    "    Args:\n",
    "        model: the classifier (switched to train mode).\n",
    "        xb, yb: batch inputs and labels; inputs are flattened, labels cast to long.\n",
    "        block: ('Wb', l) for hidden layer ``l`` or ('u', None) for the read-out.\n",
    "        inner: number of inner gradient steps.\n",
    "        lr: step size.\n",
    "        dev: device the batch is moved to.\n",
    "\n",
    "    Returns:\n",
    "        List of ``inner`` floats: the inner objective evaluated before each step.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    params = blk_params_class(model, block)\n",
    "    xb = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "    yb = yb.to(dev, non_blocking=True).long()\n",
    "\n",
    "    # The layers before the block are not updated in this call, so their\n",
    "    # output is computed once and reused.\n",
    "    a_prev_cached = None\n",
    "    if block[0] == 'Wb':\n",
    "        a_prev_cached = prefix_aprev(model, xb, block[1])\n",
    "\n",
    "    with torch.no_grad():\n",
    "        A, B, _ = AB_class(model, xb, block, a_prev_cached)  # (B,C)\n",
    "    dA_g, dB_g = dAB_for_ce('g', A, B, yb, reduction='mean')\n",
    "    v_list = manual_grad_block_class(model, xb, block, dA_g, dB_g, a_prev_cached)\n",
    "\n",
    "    hist = []\n",
    "    with torch.no_grad():\n",
    "        for _ in range(inner):\n",
    "            A, B, _info = AB_class(model, xb, block, a_prev_cached)\n",
    "            dA_f, dB_f = dAB_for_ce('f', A, B, yb, reduction='mean')\n",
    "            grad_list = manual_grad_block_class(model, xb, block, dA_f, dB_f, a_prev_cached)\n",
    "\n",
    "            rows = torch.arange(yb.size(0), device=yb.device)\n",
    "            G_val = (torch.logsumexp(A - B, dim=1) + B.sum(dim=1) + B[rows, yb]).mean()\n",
    "            linear_term = sum((p * v).sum() for p, v in zip(params, v_list))\n",
    "            hist.append((G_val - linear_term).item())\n",
    "\n",
    "            # gradient step\n",
    "            for p, g, v in zip(params, grad_list, v_list):\n",
    "                p -= lr * (g - v)\n",
    "    return hist\n",
    "\n",
    "def train_rand_blocks_class(model, tr_loader, te_loader, epochs=20, inner=3, lr=5e-3, dev='cpu'):\n",
    "    \"\"\"Train the classifier with DCA steps on a randomly chosen block per mini-batch.\n",
    "\n",
    "    After every epoch the cross-entropy and accuracy on both loaders are\n",
    "    recorded and printed.\n",
    "\n",
    "    Args:\n",
    "        model: classifier; moved to ``dev``.\n",
    "        tr_loader, te_loader: train and test loaders.\n",
    "        epochs: number of passes over ``tr_loader``.\n",
    "        inner: inner gradient steps per batch update.\n",
    "        lr: step size of the inner gradient steps.\n",
    "        dev: device used for training and evaluation.\n",
    "\n",
    "    Returns:\n",
    "        (tr_ce, te_ce, tr_acc, te_acc, inner_curves), one entry per epoch each;\n",
    "        inner_curves holds the inner objective averaged over batches.\n",
    "    \"\"\"\n",
    "    model.to(dev)\n",
    "    blocks = [('Wb', l) for l in range(1, model.L+1)] + [('u', None)]\n",
    "    tr_ce, te_ce, tr_acc, te_acc, inner_curves = [], [], [], [], []\n",
    "    for ep in range(epochs):\n",
    "        sums = [0.0] * inner\n",
    "        nb = 0\n",
    "        for xb, yb in tr_loader:\n",
    "            block = random.choice(blocks)\n",
    "            hist = dca_step_batch_class(model, xb, yb, block, inner=inner, lr=lr, dev=dev)\n",
    "            for k in range(inner):\n",
    "                sums[k] += hist[k]\n",
    "            nb += 1\n",
    "        # Same guard as ce_on_loader / acc_on_loader: an empty training loader\n",
    "        # gives zeros instead of a ZeroDivisionError.\n",
    "        inner_curves.append([s / max(1, nb) for s in sums])\n",
    "        tr_ce.append(ce_on_loader(model, tr_loader, dev))\n",
    "        te_ce.append(ce_on_loader(model, te_loader, dev))\n",
    "        tr_acc.append(acc_on_loader(model, tr_loader, dev))\n",
    "        te_acc.append(acc_on_loader(model, te_loader, dev))\n",
    "        print(f\"[ep {ep+1}] CE train {tr_ce[-1]:.4f} acc {tr_acc[-1]:.3f} | CE test {te_ce[-1]:.4f} acc {te_acc[-1]:.3f}\")\n",
    "    return tr_ce, te_ce, tr_acc, te_acc, inner_curves\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def ce_on_loader(model, loader, dev='cpu'):\n",
    "    \"\"\"Average cross-entropy of ``model`` over all samples yielded by ``loader``.\n",
    "\n",
    "    Puts the model in eval mode. Returns 0.0 when the loader yields nothing.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    criterion = nn.CrossEntropyLoss(reduction='sum')\n",
    "    loss_total, count = 0.0, 0\n",
    "    for xb, yb in loader:\n",
    "        inputs = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "        labels = yb.to(dev, non_blocking=True).long()\n",
    "        loss_total += criterion(model(inputs), labels).item()\n",
    "        count += labels.numel()\n",
    "    return loss_total / max(1, count)\n",
    "\n",
    "@torch.no_grad()\n",
    "def acc_on_loader(model, loader, dev='cpu'):\n",
    "    \"\"\"Fraction of samples from ``loader`` whose arg-max logit equals the label.\n",
    "\n",
    "    Puts the model in eval mode. Returns 0.0 when the loader yields nothing.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    hits, total = 0, 0\n",
    "    for xb, yb in loader:\n",
    "        inputs = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "        labels = yb.to(dev, non_blocking=True).long()\n",
    "        predicted = model(inputs).argmax(dim=1)\n",
    "        hits += (predicted == labels).sum().item()\n",
    "        total += labels.numel()\n",
    "    return hits / max(1, total)\n",
    "\n",
    "def train_sgd_class(model, tr_loader, te_loader, epochs=20, lr=1e-2, dev='cpu'):\n",
    "    model.to(dev)\n",
    "    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    tr_ce, te_ce, tr_acc, te_acc = [], [], [], []\n",
    "    for ep in range(epochs):\n",
    "        model.train()\n",
    "        for xb, yb in tr_loader:\n",
    "            xb = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "            yb = yb.to(dev, non_blocking=True).long()\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            logits = model(xb)\n",
    "            loss = loss_fn(logits, yb)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        tr_ce.append(ce_on_loader(model, tr_loader, dev))\n",
    "        te_ce.append(ce_on_loader(model, te_loader, dev))\n",
    "        tr_acc.append(acc_on_loader(model, tr_loader, dev))\n",
    "        te_acc.append(acc_on_loader(model, te_loader, dev))\n",
    "        print(f\"[ep {ep+1}] CE train {tr_ce[-1]:.4f} acc {tr_acc[-1]:.3f} | CE test {te_ce[-1]:.4f} acc {te_acc[-1]:.3f}\")\n",
    "    return tr_ce, te_ce, tr_acc, te_acc\n",
    "\n",
    "\n",
    "def make_loaders_class(dataset='fashion', batch_size=128, device='cpu'):\n",
    "    if dataset == 'fashion':\n",
    "        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])\n",
    "        trainset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)\n",
    "        testset  = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)\n",
    "        input_dim, num_classes = 28*28, 10\n",
    "    elif dataset == 'cifar10':\n",
    "        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])\n",
    "        trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "        testset  = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "        input_dim, num_classes = 32*32*3, 10\n",
    "    else:\n",
    "        raise ValueError(\"dataset must be 'fashion' or 'cifar10'\")\n",
    "    tr_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,  num_workers=2)\n",
    "    te_loader = DataLoader(testset,  batch_size=batch_size, shuffle=False, num_workers=2)\n",
    "    return tr_loader, te_loader, input_dim, num_classes\n",
    "\n",
    "\n",
    "def mean_ci_normal(x, ci=0.90, axis=0):\n",
    "\n",
    "    x = np.asarray(x)\n",
    "    n = x.shape[axis]\n",
    "    m = np.mean(x, axis=axis)\n",
    "    s = np.std(x, axis=axis, ddof=1) if n > 1 else np.zeros_like(m)\n",
    "    # z for two-sided CI:\n",
    "    # 90% -> z ≈ 1.6448536269514722\n",
    "    z = 1.6448536269514722 if abs(ci - 0.90) < 1e-12 else \\\n",
    "        float(torch.distributions.normal.Normal(0,1).icdf(torch.tensor((1+ci)/2)).item())\n",
    "    hw = z * (s / max(1, np.sqrt(n)))\n",
    "    return m, m - hw, m + hw\n",
    "\n",
    "def run_single_training(seed, dataset, trainer, hidden_sizes, epochs,\n",
    "                        batch_size, dev, lr_sgd, inner, lr_dca, ws, bs, us):\n",
    "    set_global_seed(seed)\n",
    "    tr_loader, te_loader, d, C = make_loaders_class(dataset=dataset, batch_size=batch_size, device=dev)\n",
    "    model = DeepReLU_Classifier(input_dim=d, hidden_sizes=hidden_sizes, num_classes=C, ws=ws, bs=bs, us=us).to(dev)\n",
    "\n",
    "    if trainer == 'sgd':\n",
    "        tr_ce, te_ce, tr_acc, te_acc = train_sgd_class(model, tr_loader, te_loader, epochs=epochs, lr=lr_sgd, dev=dev)\n",
    "    else:\n",
    "        tr_ce, te_ce, tr_acc, te_acc, _ = train_rand_blocks_class(model, tr_loader, te_loader, epochs=epochs, inner=inner, lr=lr_dca, dev=dev)\n",
    "\n",
    "    return np.array(tr_ce), np.array(te_ce), np.array(tr_acc), np.array(te_acc)\n",
    "\n",
    "def run_monte_carlo(mc_runs=10, base_seed=1234, out_csv='mc_results_ce_acc.csv',\n",
    "                    dataset='cifar10', trainer='dca',\n",
    "                    hidden_sizes=[512,256,128,64], epochs=50, batch_size=128,\n",
    "                    lr_sgd=1e-2, inner=100, lr_dca=1e-5,\n",
    "                    ws=0.2, bs=0.2, us=0.2, dev=None, ci=0.90):\n",
    "\n",
    "    if dev is None:\n",
    "        dev = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "    all_tr_ce = np.zeros((mc_runs, epochs), dtype=np.float64)\n",
    "    all_te_ce = np.zeros((mc_runs, epochs), dtype=np.float64)\n",
    "    all_tr_acc = np.zeros((mc_runs, epochs), dtype=np.float64)\n",
    "    all_te_acc = np.zeros((mc_runs, epochs), dtype=np.float64)\n",
    "\n",
    "    for r in range(mc_runs):\n",
    "        seed = base_seed + r\n",
    "        print(f\"\\n=== Monte Carlo run {r+1}/{mc_runs} (seed={seed}) ===\")\n",
    "        tr_ce, te_ce, tr_acc, te_acc = run_single_training(\n",
    "            seed, dataset, trainer, hidden_sizes, epochs,\n",
    "            batch_size, dev, lr_sgd, inner, lr_dca, ws, bs, us\n",
    "        )\n",
    "        all_tr_ce[r] = tr_ce\n",
    "        all_te_ce[r] = te_ce\n",
    "        all_tr_acc[r] = tr_acc\n",
    "        all_te_acc[r] = te_acc\n",
    "\n",
    "\n",
    "    tr_ce_mean, tr_ce_lo, tr_ce_hi = mean_ci_normal(all_tr_ce, ci=ci, axis=0)\n",
    "    te_ce_mean, te_ce_lo, te_ce_hi = mean_ci_normal(all_te_ce, ci=ci, axis=0)\n",
    "    tr_acc_mean, tr_acc_lo, tr_acc_hi = mean_ci_normal(all_tr_acc, ci=ci, axis=0)\n",
    "    te_acc_mean, te_acc_lo, te_acc_hi = mean_ci_normal(all_te_acc, ci=ci, axis=0)\n",
    "\n",
    "\n",
    "    os.makedirs(os.path.dirname(out_csv) or '.', exist_ok=True)\n",
    "    with open(out_csv, 'w', newline='') as f:\n",
    "        w = csv.writer(f)\n",
    "        w.writerow(['epoch','split','metric','mean','ci_lower','ci_upper','n_runs',\n",
    "                    'dataset','trainer'])\n",
    "        for ep in range(epochs):\n",
    "            # CE\n",
    "            w.writerow([ep+1,'train','ce', float(tr_ce_mean[ep]), float(tr_ce_lo[ep]), float(tr_ce_hi[ep]),\n",
    "                        mc_runs, dataset, trainer])\n",
    "            w.writerow([ep+1,'test','ce',  float(te_ce_mean[ep]), float(te_ce_lo[ep]), float(te_ce_hi[ep]),\n",
    "                        mc_runs, dataset, trainer])\n",
    "            # ACC\n",
    "            w.writerow([ep+1,'train','acc', float(tr_acc_mean[ep]), float(tr_acc_lo[ep]), float(tr_acc_hi[ep]),\n",
    "                        mc_runs, dataset, trainer])\n",
    "            w.writerow([ep+1,'test','acc',  float(te_acc_mean[ep]), float(te_acc_lo[ep]), float(te_acc_hi[ep]),\n",
    "                        mc_runs, dataset, trainer])\n",
    "\n",
    "    print(f\"\\nSaved Monte Carlo summary (mean & {int(ci*100)}% CI) to: {out_csv}\")\n",
    "    return {\n",
    "        'out_csv': out_csv,\n",
    "        'tr_ce': all_tr_ce, 'te_ce': all_te_ce, 'tr_acc': all_tr_acc, 'te_acc': all_te_acc\n",
    "    }\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    dataset = 'cifar10'     # 'fashion' or 'cifar10'\n",
    "    trainer = 'sgd'         # 'sgd' or 'dca'\n",
    "\n",
    "\n",
    "    dev = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    batch_size = 128\n",
    "    hidden_sizes = [ 256, 128, 64]\n",
    "    epochs = 100\n",
    "    lr_sgd = 1e-2\n",
    "    inner = 50        # for DCA\n",
    "    lr_dca = 1e-3      # for DCA\n",
    "    ws = bs = us = 0.05\n",
    "\n",
    "\n",
    "    mc_runs = 1\n",
    "    base_seed = 1234\n",
    "    ci_level = 0.90\n",
    "    out_csv = dataset + '_mc_results_ce_acc_' + trainer + '.csv'\n",
    "\n",
    "    run_monte_carlo(mc_runs=mc_runs, base_seed=base_seed, out_csv=out_csv,\n",
    "                    dataset=dataset, trainer=trainer,\n",
    "                    hidden_sizes=hidden_sizes, epochs=epochs, batch_size=batch_size,\n",
    "                    lr_sgd=lr_sgd, inner=inner, lr_dca=lr_dca,\n",
    "                    ws=ws, bs=bs, us=us, dev=dev, ci=ci_level)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06b851f2-5725-4ad5-952e-09b420ad1cdc",
   "metadata": {},
   "source": [
    "# FashionMNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bd23f7a-51fb-4d56-bce4-b71707126cf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os, math, random, csv\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "def set_global_seed(seed: int):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "relu = torch.relu\n",
    "def pos(x): return relu(x)\n",
    "def neg(x): return relu(-x)\n",
    "\n",
    "\n",
    "class DeepReLU_Classifier(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_sizes, num_classes, ws=0.05, bs=0.05, us=None):\n",
    "        super().__init__()\n",
    "        dims = [input_dim] + list(hidden_sizes)\n",
    "        self.L = len(hidden_sizes)\n",
    "        self.layers = nn.ModuleList([nn.Linear(dims[i], dims[i+1], bias=True) for i in range(self.L)])\n",
    "        for lin in self.layers:\n",
    "            nn.init.normal_(lin.weight, 0, ws)\n",
    "            nn.init.normal_(lin.bias,   0, bs)\n",
    "        last_h = dims[-1] if self.L > 0 else input_dim\n",
    "        self.U = nn.Parameter(torch.randn(last_h, num_classes) * (ws if us is None else us))  \n",
    "        self.beta = nn.Parameter(torch.randn(num_classes) * (ws if us is None else us))       \n",
    "    def forward(self, x):\n",
    "        for lin in self.layers:\n",
    "            x = relu(lin(x))\n",
    "        return x @ self.U + self.beta  \n",
    "    def activations_until(self, x, k):  \n",
    "        a = x\n",
    "        for i in range(k):\n",
    "            a = relu(self.layers[i](a))\n",
    "        return a  \n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def z_pair_from_aprev(model, a_prev, l):\n",
    "\n",
    "    lin = model.layers[l-1]\n",
    "    s_l = a_prev @ lin.weight.T + lin.bias     \n",
    "    Zp, Zm = relu(s_l), torch.zeros_like(s_l)\n",
    "    for k in range(l, model.L):\n",
    "        W, b = model.layers[k].weight, model.layers[k].bias\n",
    "        p = Zp @ pos(W).T + Zm @ neg(W).T + b\n",
    "        q = Zm @ pos(W).T + Zp @ neg(W).T\n",
    "        Zp, Zm = torch.maximum(p, q), q\n",
    "    return Zp, Zm, s_l\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def AB_Wb_class(model: DeepReLU_Classifier, x, l, a_prev_cached=None):\n",
    "\n",
    "    a_prev = x if l == 1 else (a_prev_cached if a_prev_cached is not None else model.activations_until(x, l-1))\n",
    "    Zp, Zm, s_l = z_pair_from_aprev(model, a_prev, l)  \n",
    "    U, beta = model.U, model.beta                      \n",
    "    Up, Um = pos(U), neg(U)                            \n",
    "    betap, betam = pos(beta), neg(beta)                \n",
    "    A = Zp @ Up + Zm @ Um + betap                      \n",
    "    B = Zm @ Up + Zp @ Um + betam                      \n",
    "    return A, B, (a_prev, s_l)\n",
    "\n",
    "@torch.no_grad()\n",
    "def AB_u_class(model: DeepReLU_Classifier, x):\n",
    "\n",
    "    aL = model.activations_until(x, model.L)           \n",
    "    U, beta = model.U, model.beta                     \n",
    "    Up, Um = pos(U), neg(U)                            \n",
    "    betap, betam = pos(beta), neg(beta)                \n",
    "    A = aL @ Up + betap                               \n",
    "    B = aL @ Um + betam                                \n",
    "    return A, B, aL\n",
    "\n",
    "def AB_class(model, x, block, a_prev_cached=None):\n",
    "    \n",
    "    if block[0] == 'u':\n",
    "        A, B, aL = AB_u_class(model, x)\n",
    "        return A, B, {'aL': aL}\n",
    "    l = block[1]\n",
    "    A, B, (a_prev, s_l) = AB_Wb_class(model, x, l, a_prev_cached)\n",
    "    return A, B, {'a_prev': a_prev, 's_l': s_l, 'l': l}\n",
    "\n",
    "\n",
    "def ce_loss_bdcsplit(A, B, y):\n",
    "\n",
    "    logits = A - B                                 \n",
    "    lse = torch.logsumexp(logits, dim=1)             \n",
    "    Bsum = B.sum(dim=1)                             \n",
    "    idx = torch.arange(y.size(0), device=y.device)\n",
    "    Ay = A[idx, y]                                   \n",
    "    By = B[idx, y]                                 \n",
    "    G = lse + Bsum + By\n",
    "    H = Ay + Bsum\n",
    "    return G, H, G - H\n",
    "\n",
    "\n",
    "def dAB_for_ce(which, A, B, y, reduction='mean'):\n",
    "\n",
    "    logits = A - B\n",
    "    p = torch.softmax(logits, dim=1)                 \n",
    "    Bsz, C = A.shape\n",
    "    if which == 'g':\n",
    "        dA = torch.zeros_like(A)\n",
    "        dA[torch.arange(Bsz, device=y.device), y] = 1.0\n",
    "        dB = torch.ones_like(B)\n",
    "    elif which == 'f':\n",
    "        dA = p\n",
    "        dB = -p + 1.0\n",
    "        dB[torch.arange(Bsz, device=y.device), y] += 1.0\n",
    "    else:\n",
    "        raise ValueError(\"which must be 'f' or 'g'\")\n",
    "    if reduction == 'mean':\n",
    "        m = Bsz\n",
    "        dA = dA / m\n",
    "        dB = dB / m\n",
    "    return dA, dB\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def manual_grad_block_class(model: DeepReLU_Classifier, x, block, dA, dB, a_prev_cached=None):\n",
    "\n",
    "    if block[0] == 'u':\n",
    "        _, _, aL = AB_u_class(model, x)              \n",
    "        U = model.U\n",
    "        mask_pos_U = (U > 0).to(aL.dtype)            \n",
    "        mask_neg_U = (U < 0).to(aL.dtype)\n",
    "\n",
    "\n",
    "        grad_Up = aL.T @ dA                          \n",
    "        grad_Um = aL.T @ dB                         \n",
    "        grad_U  = grad_Up * mask_pos_U + (-grad_Um) * mask_neg_U\n",
    "\n",
    "\n",
    "        beta = model.beta\n",
    "        mask_pos_b = (beta > 0).to(dA.dtype)         \n",
    "        mask_neg_b = (beta < 0).to(dA.dtype)\n",
    "        grad_betap = dA.sum(dim=0)                   \n",
    "        grad_betam = dB.sum(dim=0)                   \n",
    "        grad_beta  = grad_betap * mask_pos_b + (-grad_betam) * mask_neg_b\n",
    "        return [grad_U, grad_beta]\n",
    "\n",
    "\n",
    "    l = block[1]\n",
    "    a_prev = x if l == 1 else (a_prev_cached if a_prev_cached is not None else model.activations_until(x, l-1))\n",
    "    lin = model.layers[l-1]\n",
    "    s_l = a_prev @ lin.weight.T + lin.bias          \n",
    "    Zp = relu(s_l); Zm = torch.zeros_like(Zp)\n",
    "\n",
    "    masks, Wpos, Wneg = [], [], []\n",
    "    for k in range(l, model.L):\n",
    "        Wk1, bk1 = model.layers[k].weight, model.layers[k].bias\n",
    "        Wp, Wm = pos(Wk1), neg(Wk1)\n",
    "        p = Zp @ Wp.T + Zm @ Wm.T + bk1\n",
    "        q = Zm @ Wp.T + Zp @ Wm.T\n",
    "        M = (p >= q).to(Zp.dtype)\n",
    "        Zp, Zm = torch.maximum(p, q), q\n",
    "        masks.append(M); Wpos.append(Wp); Wneg.append(Wm)\n",
    "\n",
    "\n",
    "    U = model.U\n",
    "    Up, Um = pos(U), neg(U)          \n",
    "    dZp = dA @ Up.T + dB @ Um.T     \n",
    "    dZm = dA @ Um.T + dB @ Up.T     \n",
    "\n",
    "\n",
    "    for idx in reversed(range(len(masks))):\n",
    "        M, Wp, Wm = masks[idx], Wpos[idx], Wneg[idx]\n",
    "        dp = M * dZp\n",
    "        dq = dZm + (1 - M) * dZp\n",
    "        dZp, dZm = dp @ Wp + dq @ Wm, dp @ Wm + dq @ Wp   \n",
    "\n",
    "\n",
    "    ds_l = (s_l > 0).to(s_l.dtype) * dZp\n",
    "    gradW = ds_l.T @ a_prev\n",
    "    gradb = ds_l.sum(dim=0)\n",
    "    return [gradW, gradb]\n",
    "\n",
    "\n",
    "def blk_params_class(model, block):\n",
    "    if block[0] == 'u':\n",
    "        return [model.U, model.beta]\n",
    "    lin = model.layers[block[1]-1]\n",
    "    return [lin.weight, lin.bias]\n",
    "\n",
    "@torch.no_grad()\n",
    "def prefix_aprev(model, x, l):\n",
    "    return x if l == 1 else model.activations_until(x, l-1)\n",
    "\n",
    "def dca_step_batch_class(model, xb, yb, block, inner=3, lr=5e-3, dev='cpu'):\n",
    "\n",
    "    model.train()\n",
    "    params = blk_params_class(model, block)\n",
    "    xb = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "    yb = yb.to(dev, non_blocking=True).long()\n",
    "\n",
    "    a_prev_cached = prefix_aprev(model, xb, block[1]) if block[0] == 'Wb' else None\n",
    "\n",
    "\n",
    "    with torch.no_grad():\n",
    "        A, B, _ = AB_class(model, xb, block, a_prev_cached)  # (B,C)\n",
    "    dA_g, dB_g = dAB_for_ce('g', A, B, yb, reduction='mean')\n",
    "    v_list = manual_grad_block_class(model, xb, block, dA_g, dB_g, a_prev_cached)\n",
    "\n",
    "\n",
    "    hist = []\n",
    "    for _ in range(inner):\n",
    "        with torch.no_grad():\n",
    "            A, B, _ = AB_class(model, xb, block, a_prev_cached)\n",
    "            dA_f, dB_f = dAB_for_ce('f', A, B, yb, reduction='mean')\n",
    "            grad_list = manual_grad_block_class(model, xb, block, dA_f, dB_f, a_prev_cached)\n",
    "\n",
    "            # log objective value = mean(G) - <v,θ>\n",
    "            logits = A - B\n",
    "            lse = torch.logsumexp(logits, dim=1)\n",
    "            Bsum = B.sum(dim=1)\n",
    "            By = B[torch.arange(yb.size(0), device=yb.device), yb]\n",
    "            G_val = (lse + Bsum + By).mean()\n",
    "            lin = sum((p * v).sum() for p, v in zip(params, v_list))\n",
    "            hist.append((G_val - lin).item())\n",
    "\n",
    "            # gradient step\n",
    "            for p, g, v in zip(params, grad_list, v_list):\n",
    "                p -= lr * (g - v)\n",
    "    return hist\n",
    "\n",
    "def train_rand_blocks_class(model, tr_loader, te_loader, epochs=20, inner=3, lr=5e-3, dev='cpu'):\n",
    "    model.to(dev)\n",
    "    blocks = [('Wb', l) for l in range(1, model.L+1)] + [('u', None)]\n",
    "    tr_ce, te_ce, tr_acc, te_acc, inner_curves = [], [], [], [], []\n",
    "    for ep in range(epochs):\n",
    "        sums = [0.0]*inner; nb=0\n",
    "        for xb, yb in tr_loader:\n",
    "            b = random.choice(blocks)\n",
    "            hist = dca_step_batch_class(model, xb, yb, b, inner=inner, lr=lr, dev=dev)\n",
    "            for k in range(inner): sums[k] += hist[k]\n",
    "            nb += 1\n",
    "        avg_inner = [s/nb for s in sums]\n",
    "        inner_curves.append(avg_inner)\n",
    "        tr_ce.append(ce_on_loader(model, tr_loader, dev))\n",
    "        te_ce.append(ce_on_loader(model, te_loader, dev))\n",
    "        tr_acc.append(acc_on_loader(model, tr_loader, dev))\n",
    "        te_acc.append(acc_on_loader(model, te_loader, dev))\n",
    "        print(f\"[ep {ep+1}] CE train {tr_ce[-1]:.4f} acc {tr_acc[-1]:.3f} | CE test {te_ce[-1]:.4f} acc {te_acc[-1]:.3f}\")\n",
    "    return tr_ce, te_ce, tr_acc, te_acc, inner_curves\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def ce_on_loader(model, loader, dev='cpu'):\n",
    "    model.eval(); total_loss=0.0; n=0\n",
    "    loss_fn = nn.CrossEntropyLoss(reduction='sum')\n",
    "    for xb, yb in loader:\n",
    "        xb = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "        yb = yb.to(dev, non_blocking=True).long()\n",
    "        logits = model(xb)\n",
    "        total_loss += loss_fn(logits, yb).item()\n",
    "        n += yb.numel()\n",
    "    return total_loss / max(1, n)\n",
    "\n",
    "@torch.no_grad()\n",
    "def acc_on_loader(model, loader, dev='cpu'):\n",
    "    model.eval(); correct=0; n=0\n",
    "    for xb, yb in loader:\n",
    "        xb = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "        yb = yb.to(dev, non_blocking=True).long()\n",
    "        logits = model(xb)\n",
    "        pred = logits.argmax(dim=1)\n",
    "        correct += (pred == yb).sum().item()\n",
    "        n += yb.numel()\n",
    "    return correct / max(1, n)\n",
    "\n",
    "def train_sgd_class(model, tr_loader, te_loader, epochs=20, lr=1e-2, dev='cpu'):\n",
    "    model.to(dev)\n",
    "    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    tr_ce, te_ce, tr_acc, te_acc = [], [], [], []\n",
    "    for ep in range(epochs):\n",
    "        model.train()\n",
    "        for xb, yb in tr_loader:\n",
    "            xb = xb.to(dev, non_blocking=True).view(xb.size(0), -1)\n",
    "            yb = yb.to(dev, non_blocking=True).long()\n",
    "            opt.zero_grad(set_to_none=True)\n",
    "            logits = model(xb)\n",
    "            loss = loss_fn(logits, yb)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        tr_ce.append(ce_on_loader(model, tr_loader, dev))\n",
    "        te_ce.append(ce_on_loader(model, te_loader, dev))\n",
    "        tr_acc.append(acc_on_loader(model, tr_loader, dev))\n",
    "        te_acc.append(acc_on_loader(model, te_loader, dev))\n",
    "        print(f\"[ep {ep+1}] CE train {tr_ce[-1]:.4f} acc {tr_acc[-1]:.3f} | CE test {te_ce[-1]:.4f} acc {te_acc[-1]:.3f}\")\n",
    "    return tr_ce, te_ce, tr_acc, te_acc\n",
    "\n",
    "\n",
    "def make_loaders_class(dataset='fashion', batch_size=128, device='cpu'):\n",
    "    if dataset == 'fashion':\n",
    "        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])\n",
    "        trainset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)\n",
    "        testset  = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)\n",
    "        input_dim, num_classes = 28*28, 10\n",
    "    elif dataset == 'cifar10':\n",
    "        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])\n",
    "        trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "        testset  = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "        input_dim, num_classes = 32*32*3, 10\n",
    "    else:\n",
    "        raise ValueError(\"dataset must be 'fashion' or 'cifar10'\")\n",
    "    tr_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,  num_workers=2)\n",
    "    te_loader = DataLoader(testset,  batch_size=batch_size, shuffle=False, num_workers=2)\n",
    "    return tr_loader, te_loader, input_dim, num_classes\n",
    "\n",
    "\n",
    "def mean_ci_normal(x, ci=0.90, axis=0):\n",
    "\n",
    "    x = np.asarray(x)\n",
    "    n = x.shape[axis]\n",
    "    m = np.mean(x, axis=axis)\n",
    "    s = np.std(x, axis=axis, ddof=1) if n > 1 else np.zeros_like(m)\n",
    "    # z for two-sided CI:\n",
    "    # 90% -> z ≈ 1.6448536269514722\n",
    "    z = 1.6448536269514722 if abs(ci - 0.90) < 1e-12 else \\\n",
    "        float(torch.distributions.normal.Normal(0,1).icdf(torch.tensor((1+ci)/2)).item())\n",
    "    hw = z * (s / max(1, np.sqrt(n)))\n",
    "    return m, m - hw, m + hw\n",
    "\n",
    "def run_single_training(seed, dataset, trainer, hidden_sizes, epochs,\n",
    "                        batch_size, dev, lr_sgd, inner, lr_dca, ws, bs, us):\n",
    "    set_global_seed(seed)\n",
    "    tr_loader, te_loader, d, C = make_loaders_class(dataset=dataset, batch_size=batch_size, device=dev)\n",
    "    model = DeepReLU_Classifier(input_dim=d, hidden_sizes=hidden_sizes, num_classes=C, ws=ws, bs=bs, us=us).to(dev)\n",
    "\n",
    "    if trainer == 'sgd':\n",
    "        tr_ce, te_ce, tr_acc, te_acc = train_sgd_class(model, tr_loader, te_loader, epochs=epochs, lr=lr_sgd, dev=dev)\n",
    "    else:\n",
    "        tr_ce, te_ce, tr_acc, te_acc, _ = train_rand_blocks_class(model, tr_loader, te_loader, epochs=epochs, inner=inner, lr=lr_dca, dev=dev)\n",
    "\n",
    "    return np.array(tr_ce), np.array(te_ce), np.array(tr_acc), np.array(te_acc)\n",
    "\n",
    "def run_monte_carlo(mc_runs=10, base_seed=1234, out_csv='mc_results_ce_acc.csv',\n",
    "                    dataset='cifar10', trainer='dca',\n",
    "                    hidden_sizes=[512,256,128,64], epochs=50, batch_size=128,\n",
    "                    lr_sgd=1e-2, inner=100, lr_dca=1e-5,\n",
    "                    ws=0.2, bs=0.2, us=0.2, dev=None, ci=0.90):\n",
    "\n",
    "    if dev is None:\n",
    "        dev = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "\n",
    "    all_tr_ce = np.zeros((mc_runs, epochs), dtype=np.float64)\n",
    "    all_te_ce = np.zeros((mc_runs, epochs), dtype=np.float64)\n",
    "    all_tr_acc = np.zeros((mc_runs, epochs), dtype=np.float64)\n",
    "    all_te_acc = np.zeros((mc_runs, epochs), dtype=np.float64)\n",
    "\n",
    "    for r in range(mc_runs):\n",
    "        seed = base_seed + r\n",
    "        print(f\"\\n=== Monte Carlo run {r+1}/{mc_runs} (seed={seed}) ===\")\n",
    "        tr_ce, te_ce, tr_acc, te_acc = run_single_training(\n",
    "            seed, dataset, trainer, hidden_sizes, epochs,\n",
    "            batch_size, dev, lr_sgd, inner, lr_dca, ws, bs, us\n",
    "        )\n",
    "        all_tr_ce[r] = tr_ce\n",
    "        all_te_ce[r] = te_ce\n",
    "        all_tr_acc[r] = tr_acc\n",
    "        all_te_acc[r] = te_acc\n",
    "\n",
    "\n",
    "    tr_ce_mean, tr_ce_lo, tr_ce_hi = mean_ci_normal(all_tr_ce, ci=ci, axis=0)\n",
    "    te_ce_mean, te_ce_lo, te_ce_hi = mean_ci_normal(all_te_ce, ci=ci, axis=0)\n",
    "    tr_acc_mean, tr_acc_lo, tr_acc_hi = mean_ci_normal(all_tr_acc, ci=ci, axis=0)\n",
    "    te_acc_mean, te_acc_lo, te_acc_hi = mean_ci_normal(all_te_acc, ci=ci, axis=0)\n",
    "\n",
    "\n",
    "    os.makedirs(os.path.dirname(out_csv) or '.', exist_ok=True)\n",
    "    with open(out_csv, 'w', newline='') as f:\n",
    "        w = csv.writer(f)\n",
    "        w.writerow(['epoch','split','metric','mean','ci_lower','ci_upper','n_runs',\n",
    "                    'dataset','trainer'])\n",
    "        for ep in range(epochs):\n",
    "            # CE\n",
    "            w.writerow([ep+1,'train','ce', float(tr_ce_mean[ep]), float(tr_ce_lo[ep]), float(tr_ce_hi[ep]),\n",
    "                        mc_runs, dataset, trainer])\n",
    "            w.writerow([ep+1,'test','ce',  float(te_ce_mean[ep]), float(te_ce_lo[ep]), float(te_ce_hi[ep]),\n",
    "                        mc_runs, dataset, trainer])\n",
    "            # ACC\n",
    "            w.writerow([ep+1,'train','acc', float(tr_acc_mean[ep]), float(tr_acc_lo[ep]), float(tr_acc_hi[ep]),\n",
    "                        mc_runs, dataset, trainer])\n",
    "            w.writerow([ep+1,'test','acc',  float(te_acc_mean[ep]), float(te_acc_lo[ep]), float(te_acc_hi[ep]),\n",
    "                        mc_runs, dataset, trainer])\n",
    "\n",
    "    print(f\"\\nSaved Monte Carlo summary (mean & {int(ci*100)}% CI) to: {out_csv}\")\n",
    "    return {\n",
    "        'out_csv': out_csv,\n",
    "        'tr_ce': all_tr_ce, 'te_ce': all_te_ce, 'tr_acc': all_tr_acc, 'te_acc': all_te_acc\n",
    "    }\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "\n",
    "    dataset = 'fashion'     # 'fashion' or 'cifar10'\n",
    "    trainer = 'sgd'         # 'sgd' or 'dca'\n",
    "\n",
    "    # Hyperparameters\n",
    "    dev = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    batch_size = 256\n",
    "    hidden_sizes = [ 512, 64]\n",
    "    epochs = 100\n",
    "    lr_sgd = 1e-2\n",
    "    inner = 20        # for DCA\n",
    "    lr_dca = 1e-3      # for DCA\n",
    "    ws = bs = us = 0.05\n",
    "\n",
    "    mc_runs = 1\n",
    "    base_seed = 1234\n",
    "    ci_level = 0.90\n",
    "    out_csv = dataset + '_mc_results_ce_acc_' + trainer + '.csv'\n",
    "\n",
    "\n",
    "    run_monte_carlo(mc_runs=mc_runs, base_seed=base_seed, out_csv=out_csv,\n",
    "                    dataset=dataset, trainer=trainer,\n",
    "                    hidden_sizes=hidden_sizes, epochs=epochs, batch_size=batch_size,\n",
    "                    lr_sgd=lr_sgd, inner=inner, lr_dca=lr_dca,\n",
    "                    ws=ws, bs=bs, us=us, dev=dev, ci=ci_level)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9ba187e-84d3-4183-8564-997da35f8ec2",
   "metadata": {},
   "source": [
    "# Small Generalized Smoothness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4799fbb2-ca50-414a-a0a0-d320fa4fb3c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math, random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import matplotlib.patheffects as pe\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "relu = torch.relu\n",
    "def pos(x): return relu(x)\n",
    "def neg(x): return relu(-x)\n",
    "\n",
    "def vecnorm(ts):\n",
    "    return torch.sqrt(sum((t**2).sum() for t in ts) + 1e-20)\n",
    "\n",
    "class DeepReLU2Scalar(nn.Module):\n",
    "    def __init__(self, d, h=(128,64), ws=0.05, bs=0.05, us=0.05):\n",
    "        super().__init__()\n",
    "        n1, n2 = h\n",
    "        self.layers = nn.ModuleList([\n",
    "            nn.Linear(d,  n1, bias=True),\n",
    "            nn.Linear(n1, n2, bias=True),\n",
    "        ])\n",
    "        for lin in self.layers:\n",
    "            nn.init.normal_(lin.weight, 0, ws)\n",
    "            nn.init.normal_(lin.bias,   0, bs)\n",
    "        self.u = nn.Parameter(torch.randn(n2) * us)  # matches last hidden width\n",
    "        self.L = 2\n",
    "    def forward(self, x):\n",
    "        for lin in self.layers: x = relu(lin(x))\n",
    "        return x @ self.u\n",
    "    def activations_until(self, x, k):\n",
    "        a = x\n",
    "        for i in range(k): a = relu(self.layers[i](a))\n",
    "        return a\n",
    "\n",
    "def z_pair_from_aprev(model, a_prev, l):\n",
    "    s = a_prev @ model.layers[l-1].weight.T + model.layers[l-1].bias\n",
    "    Zp, Zm = relu(s), torch.zeros_like(s)\n",
    "    for k in range(l, model.L):\n",
    "        W, b = model.layers[k].weight, model.layers[k].bias\n",
    "        p = Zp @ pos(W).T + Zm @ neg(W).T + pos(b)\n",
    "        q = Zm @ pos(W).T + Zp @ neg(W).T + neg(b)\n",
    "        Zp, Zm = torch.maximum(p, q), q\n",
    "    return Zp, Zm\n",
    "\n",
    "def AB_block(model, x, block):\n",
    "    \"\"\"Compute the pair ``(A, B)`` of per-sample sums for the given block.\n",
    "\n",
    "    For the ``'u'`` block the last-layer activations are weighted by ``pos(u)`` and\n",
    "    ``neg(u)`` respectively. For a layer block ``(_, l)`` the pair ``(Zp, Zm)`` from\n",
    "    ``z_pair_from_aprev`` is combined with ``pos(u)``/``neg(u)`` crosswise.\n",
    "\n",
    "    Returns:\n",
    "        ``(A, B)``, each summed over dimension 1.\n",
    "    \"\"\"\n",
    "    u_pos, u_neg = pos(model.u), neg(model.u)\n",
    "    if block[0] == 'u':\n",
    "        a_last = model.activations_until(x, model.L)\n",
    "        return (a_last * u_pos).sum(1), (a_last * u_neg).sum(1)\n",
    "\n",
    "    layer_no = block[1]\n",
    "    if layer_no == 1:\n",
    "        a_prev = x\n",
    "    else:\n",
    "        a_prev = model.activations_until(x, layer_no-1)\n",
    "    Zp, Zm = z_pair_from_aprev(model, a_prev, layer_no)\n",
    "    A = (Zp * u_pos).sum(1) + (Zm * u_neg).sum(1)\n",
    "    B = (Zm * u_pos).sum(1) + (Zp * u_neg).sum(1)\n",
    "    return A, B\n",
    "\n",
    "def f_terms(A, B, y):\n",
    "    \"\"\"Batch mean of ``2*A^2 + 2*(B + y + c)^2 + 2*c*A`` with ``c = pos(-y)``.\n",
    "\n",
    "    Expanding the algebra with ``g_terms`` as defined in this cell gives\n",
    "    ``f_terms - g_terms == mean((A - B - y)^2)``.\n",
    "    A 0-dim ``y`` is expanded to the shape of ``A`` first.\n",
    "    \"\"\"\n",
    "    if y.dim()==0:\n",
    "        y = y.expand_as(A)\n",
    "    c = pos(-y)\n",
    "    shifted = B + y + c\n",
    "    per_sample = 2*A*A + 2*shifted**2 + 2*c*A\n",
    "    return per_sample.mean()\n",
    "\n",
    "def g_terms(A, B, y):\n",
    "    \"\"\"Batch mean of ``(A + B + y + c)^2 + 2*c*(B + y) + c^2`` with ``c = pos(-y)``.\n",
    "\n",
    "    This is the term subtracted from ``f_terms``: expanding the algebra gives\n",
    "    ``f_terms - g_terms == mean((A - B - y)^2)``.\n",
    "    A 0-dim ``y`` is expanded to the shape of ``A`` first.\n",
    "    \"\"\"\n",
    "    if y.dim()==0:\n",
    "        y = y.expand_as(A)\n",
    "    c = pos(-y)\n",
    "    total = A + B + y + c\n",
    "    per_sample = total**2 + 2*c*(B + y) + c*c\n",
    "    return per_sample.mean()\n",
    "\n",
    "def block_params(model, block):\n",
    "    \"\"\"Return the list of parameter tensors that make up ``block``.\n",
    "\n",
    "    A block whose first entry is ``'u'`` selects ``model.u``; any other block is read\n",
    "    as ``(_, l)`` and selects the weight and bias of the 1-indexed layer ``l``.\n",
    "    \"\"\"\n",
    "    if block[0]=='u':\n",
    "        return [model.u]\n",
    "    layer = model.layers[block[1]-1]\n",
    "    return [layer.weight, layer.bias]\n",
    "\n",
    "def set_requires_grad_only_block(model, block):\n",
    "    \"\"\"Disable gradients on every model parameter, then re-enable them on ``block``'s.\"\"\"\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad_(False)\n",
    "    for param in block_params(model, block):\n",
    "        param.requires_grad_(True)\n",
    "\n",
    "def clone_params(params):\n",
    "    \"\"\"Return a list of detached copies of the tensors in ``params``.\"\"\"\n",
    "    snapshot = []\n",
    "    for tensor in params:\n",
    "        snapshot.append(tensor.detach().clone())\n",
    "    return snapshot\n",
    "\n",
    "def set_params_(params, values):\n",
    "    \"\"\"Copy ``values[i]`` into ``params[i]`` in place, without recording autograd history.\n",
    "\n",
    "    Pairs are formed with ``zip``, so extra entries in the longer list are ignored.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        for target, source in zip(params, values):\n",
    "            target.copy_(source)\n",
    "\n",
    "def add_inplace_(params, offsets, scale=1.0):\n",
    "    \"\"\"In place, add ``scale * offsets[i]`` to ``params[i]`` without autograd tracking.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        for target, offset in zip(params, offsets):\n",
    "            increment = scale * offset\n",
    "            target.add_(increment)\n",
    "\n",
    "def diff_params(new_params, old_params):\n",
    "    \"\"\"Return the list of detached elementwise differences ``new - old``.\"\"\"\n",
    "    deltas = []\n",
    "    for new, old in zip(new_params, old_params):\n",
    "        deltas.append(new.detach() - old.detach())\n",
    "    return deltas\n",
    "\n",
    "def grad_f_block(model, xb, yb, block, create_graph=False):\n",
    "    \"\"\"Gradient of ``f_terms`` on the batch with respect to ``block``'s parameters.\n",
    "\n",
    "    Side effect: afterwards only ``block``'s parameters have ``requires_grad`` set.\n",
    "\n",
    "    Returns:\n",
    "        ``(grads, f)``: the tuple from ``torch.autograd.grad`` and the value of ``f``.\n",
    "    \"\"\"\n",
    "    set_requires_grad_only_block(model, block)\n",
    "    A, B = AB_block(model, xb, block)\n",
    "    f_value = f_terms(A, B, yb)\n",
    "    wrt = block_params(model, block)\n",
    "    block_grads = torch.autograd.grad(\n",
    "        f_value, wrt, retain_graph=False, create_graph=create_graph\n",
    "    )\n",
    "    return block_grads, f_value\n",
    "\n",
    "def estimate_Lhat_f_along_direction(model, xb, yb, block, base_vals, direction, delta=0.25):\n",
    "    \"\"\"Estimate a local smoothness constant of ``f`` along ``direction`` from ``base_vals``.\n",
    "\n",
    "    The block gradient of ``f`` is evaluated at ``base_vals`` and at\n",
    "    ``base_vals + gamma * direction`` for ``gamma = delta, 2*delta, ...``\n",
    "    (``max(1, int(1/delta))`` points). The result is the largest ratio\n",
    "    ``||grad(gamma) - grad(0)|| / (gamma * ||direction||)``.\n",
    "\n",
    "    Side effect: the block parameters are left at the last probed point.\n",
    "\n",
    "    Returns:\n",
    "        The estimate as a Python float; ``0.0`` if ``vecnorm(direction) < 1e-20``.\n",
    "    \"\"\"\n",
    "    params = block_params(model, block)\n",
    "    step_norm = float(vecnorm(direction))\n",
    "    if step_norm < 1e-20:\n",
    "        return 0.0\n",
    "\n",
    "    # Reference gradient at the base point.\n",
    "    set_params_(params, base_vals)\n",
    "    grad_base = grad_f_block(model, xb, yb, block, create_graph=False)[0]\n",
    "\n",
    "    n_probes = max(1, int(1.0/delta + 1e-9))\n",
    "    ratios = [0.0]\n",
    "    for i in range(1, n_probes+1):\n",
    "        gamma = i * delta\n",
    "        # Always restart from the base point so each probe is base + gamma * direction.\n",
    "        set_params_(params, base_vals)\n",
    "        add_inplace_(params, direction, scale=gamma)\n",
    "        grad_probe = grad_f_block(model, xb, yb, block, create_graph=False)[0]\n",
    "        grad_change = float(vecnorm([g - gg for g, gg in zip(grad_probe, grad_base)]))\n",
    "        ratios.append(grad_change / (gamma * step_norm + 1e-20))\n",
    "    return max(ratios)\n",
    "\n",
    "def dca_step_block(model, xb, yb, block, inner_steps=1, lr=1e-4):\n",
    "    \"\"\"One DCA update of ``block`` on the batch ``(xb, yb)``.\n",
    "\n",
    "    The gradient ``v`` of ``g_terms`` is taken once at the current parameters and held\n",
    "    fixed; then ``inner_steps`` steps ``p <- p - lr * (grad f(p) - v)`` are applied in\n",
    "    place to the block's parameters.\n",
    "    \"\"\"\n",
    "    # Linearisation of g at the current iterate.\n",
    "    set_requires_grad_only_block(model, block)\n",
    "    A, B = AB_block(model, xb, block)\n",
    "    g_value = g_terms(A, B, yb)\n",
    "    raw_v = torch.autograd.grad(g_value, block_params(model, block),\n",
    "                                retain_graph=False, create_graph=False)\n",
    "    v_list = [v.detach() for v in raw_v]\n",
    "\n",
    "    for _ in range(inner_steps):\n",
    "        # grad_f_block re-selects the block, rebuilds (A, B) and differentiates f.\n",
    "        grads_f = grad_f_block(model, xb, yb, block, create_graph=False)[0]\n",
    "        with torch.no_grad():\n",
    "            for p, gf, v in zip(block_params(model, block), grads_f, v_list):\n",
    "                p.add_( -lr * (gf - v) )\n",
    "\n",
    "def make_california(batch_size=512, seed=0):\n",
    "    \"\"\"Build standardized train/validation DataLoaders for California Housing.\n",
    "\n",
    "    Tries ``sklearn.datasets.fetch_california_housing``. If that raises, a\n",
    "    ``RuntimeWarning`` is issued and a synthetic 8-feature ``make_regression``\n",
    "    problem is used instead, so results are never silently attributed to the\n",
    "    real dataset.\n",
    "\n",
    "    Args:\n",
    "        batch_size: batch size of both loaders.\n",
    "        seed: seeds torch and numpy, the train/validation split and the fallback data.\n",
    "\n",
    "    Returns:\n",
    "        ``(tr_loader, va_loader, d)``: shuffled training loader, unshuffled validation\n",
    "        loader (20% hold-out) and the number of input features. Features and targets\n",
    "        are standardized using training-split statistics only.\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "    import numpy as np\n",
    "    from sklearn.model_selection import train_test_split\n",
    "    torch.manual_seed(seed); np.random.seed(seed)\n",
    "    try:\n",
    "        from sklearn.datasets import fetch_california_housing\n",
    "        data = fetch_california_housing()\n",
    "        X = data.data.astype('float32')   # (20640, 8)\n",
    "        y = data.target.astype('float32')\n",
    "    except Exception as exc:\n",
    "        # Best-effort fallback is kept, but it must be visible in the notebook output.\n",
    "        warnings.warn(\n",
    "            f\"California Housing could not be loaded ({exc!r}); \"\n",
    "            \"falling back to synthetic make_regression data.\",\n",
    "            RuntimeWarning,\n",
    "        )\n",
    "        from sklearn.datasets import make_regression\n",
    "        X, y = make_regression(n_samples=20000, n_features=8, n_informative=8,\n",
    "                               noise=15.0, bias=0.0, random_state=seed)\n",
    "        X = X.astype('float32'); y = y.astype('float32')\n",
    "\n",
    "    Xtr, Xva, Ytr, Yva = train_test_split(X, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    # Standardize with training statistics; the 1e-8 guards against zero variance.\n",
    "    X_mean, X_std = Xtr.mean(axis=0), Xtr.std(axis=0) + 1e-8\n",
    "    Y_mean, Y_std = Ytr.mean(), Ytr.std() + 1e-8\n",
    "    Xtr = (Xtr - X_mean) / X_std\n",
    "    Xva = (Xva - X_mean) / X_std\n",
    "    Ytr = (Ytr - Y_mean) / Y_std\n",
    "    Yva = (Yva - Y_mean) / Y_std\n",
    "\n",
    "    Xtr = torch.from_numpy(Xtr); Xva = torch.from_numpy(Xva)\n",
    "    Ytr = torch.from_numpy(Ytr); Yva = torch.from_numpy(Yva)\n",
    "\n",
    "    tr_ds = TensorDataset(Xtr, Ytr)\n",
    "    va_ds = TensorDataset(Xva, Yva)\n",
    "    # Pinned host memory only helps host-to-GPU copies; requesting it without CUDA\n",
    "    # just triggers a warning.\n",
    "    pin = torch.cuda.is_available()\n",
    "    tr_loader = DataLoader(tr_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin)\n",
    "    va_loader = DataLoader(va_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin)\n",
    "    d = Xtr.shape[1]\n",
    "    return tr_loader, va_loader, d\n",
    "\n",
    "def cycle(loader):\n",
    "    \"\"\"Yield items from ``loader`` forever, starting a new iteration after each pass.\n",
    "\n",
    "    Nothing is cached: ``loader`` is iterated afresh on every pass.\n",
    "    \"\"\"\n",
    "    while True:\n",
    "        yield from loader\n",
    "\n",
    "# Training driver: cyclic block-wise DCA updates over the parameter blocks, logging for\n",
    "# each block update log||grad f|| and the log of the probed smoothness estimate.\n",
    "if __name__ == \"__main__\":\n",
    "    tr_loader, va_loader, d = make_california(batch_size=1024, seed=0)\n",
    "    # NOTE(review): `device` is not defined in the part of the cell visible here;\n",
    "    # presumably it is set earlier -- confirm.\n",
    "    model = DeepReLU2Scalar(d, (64,32)).to(device)\n",
    "\n",
    "    # Block ids as read by block_params: ('Wb', l) = weight and bias of layer l (1-indexed),\n",
    "    # ('u', None) = readout vector.\n",
    "    blocks = [('Wb',1), ('Wb',2), ('u',None)]\n",
    "    # NOTE(review): `markers` is not read anywhere in the code visible here.\n",
    "    markers = {('Wb',1):'o', ('Wb',2):'^', ('u',None):'s'}  \n",
    "    # Per-block series: log gradient norm, log smoothness estimate, global log index.\n",
    "    logs = {b:{'logG':[], 'logLhat':[], 't':[]} for b in blocks}\n",
    "\n",
    "    # outer_epochs: passes over tr_loader; inner_steps / lr_block: number and step size of\n",
    "    # the inner gradient steps in dca_step_block; delta_probe: spacing of the probe points\n",
    "    # gamma = i * delta used by estimate_Lhat_f_along_direction.\n",
    "    outer_epochs   = 30\n",
    "    inner_steps    = 10\n",
    "    lr_block       = 5e-3\n",
    "    delta_probe    = 0.25\n",
    "\n",
    "    tr_iter = cycle(tr_loader)\n",
    "    total_steps = outer_epochs * len(tr_loader)\n",
    "    # Counter shared by all blocks, so every logged point gets a distinct index.\n",
    "    global_log_idx = 0  \n",
    "\n",
    "    model.train()\n",
    "    for step in range(1, total_steps+1):\n",
    "        xb, yb = next(tr_iter)\n",
    "        xb, yb = xb.to(device), yb.to(device)\n",
    "\n",
    "        # The same mini-batch is used to update every block in turn.\n",
    "        for blk in blocks:\n",
    "            params_blk = block_params(model, blk)\n",
    "            # Snapshot of the block before the update (theta_k).\n",
    "            theta_k = clone_params(params_blk)\n",
    "\n",
    "            # Gradient of f at theta_k; its norm is the x-coordinate of the logged point.\n",
    "            gf_k, _ = grad_f_block(model, xb, yb, blk, create_graph=False)\n",
    "            gnorm = float(vecnorm(gf_k))\n",
    "\n",
    "            dca_step_block(model, xb, yb, blk, inner_steps=inner_steps, lr=lr_block)\n",
    "\n",
    "            # Updated block (theta_{k+1}) and the step actually taken.\n",
    "            theta_k1 = [p.detach().clone() for p in params_blk]\n",
    "            d_dir = diff_params(theta_k1, theta_k)\n",
    "\n",
    "            # NOTE: vecnorm adds 1e-20 under the square root, so gnorm is at least 1e-10\n",
    "            # and this guard always passes with vecnorm as defined in this cell.\n",
    "            Lhat = 0.0\n",
    "            if gnorm > 1e-20:\n",
    "                Lhat = estimate_Lhat_f_along_direction(model, xb, yb, blk,\n",
    "                                                       base_vals=theta_k,\n",
    "                                                       direction=d_dir,\n",
    "                                                       delta=delta_probe)\n",
    "\n",
    "            # The probe overwrites the block parameters; put the updated values back.\n",
    "            set_params_(params_blk, theta_k1)\n",
    "\n",
    "            # Only strictly positive values can be logged on a log scale.\n",
    "            if gnorm > 1e-20 and Lhat > 0:\n",
    "                logs[blk]['logG'].append(math.log(gnorm))\n",
    "                logs[blk]['logLhat'].append(math.log(Lhat))\n",
    "                logs[blk]['t'].append(global_log_idx)\n",
    "                global_log_idx += 1\n",
    "\n",
    "        # Periodic validation: mean squared error over the whole validation loader\n",
    "        # (targets were standardized in make_california).\n",
    "        if step % 100 == 0:\n",
    "            with torch.no_grad():\n",
    "                model.eval()\n",
    "                se=0.0; n=0\n",
    "                for Xv, Yv in va_loader:\n",
    "                    Xv, Yv = Xv.to(device), Yv.to(device)\n",
    "                    pred = model(Xv)\n",
    "                    se += ((pred - Yv)**2).sum().item(); n += Yv.numel()\n",
    "                print(f\"[step {step}/{total_steps}] valid MSE={se/max(1,n):.4f}\")\n",
    "                model.train()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ec5adab-8c2a-40bf-b264-488d29ff0085",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb89b99f-d207-4bfc-8fa1-9b50a3c154c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Scatter of log(estimated smoothness) against log(gradient norm of f) for every block,\n",
    "# coloured by the global log index recorded during training.\n",
    "# NOTE(review): relies on `logs` and `blocks` from the training cell, and on `mpl`, `pe`\n",
    "# and `Line2D`, whose imports are not visible here -- confirm they exist in an earlier cell.\n",
    "cmap = plt.cm.plasma\n",
    "# Largest logged index over all blocks; falls back to 1 when nothing was logged.\n",
    "tmax = max((max(v['t']) for v in logs.values() if len(v['t']) > 0), default=1)\n",
    "# One normalisation shared by all scatters so colours are comparable across blocks.\n",
    "norm = mpl.colors.Normalize(vmin=0, vmax=tmax)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "\n",
    "# Per-block marker shape and sizes; 'stroke' and 'lw' configure the path-effect outline.\n",
    "styles = {\n",
    "    ('Wb',1): dict(marker='o', s_base=16, s_latest=28, edge='k', stroke='white', lw=1.2),\n",
    "    ('Wb',2): dict(marker='^', s_base=22, s_latest=34, edge='k', stroke='white', lw=1.2),\n",
    "    ('u',None): dict(marker='X', s_base=28, s_latest=42, edge='k', stroke='white', lw=1.2),\n",
    "}\n",
    "K_latest = 10   # overlay last K points per block without downsampling\n",
    "stride   = 10    # downsample the base cloud for readability\n",
    "\n",
    "mappable_for_cbar = None\n",
    "legend_proxies = []\n",
    "\n",
    "for blk in blocks:\n",
    "    xs, ys, ts = logs[blk]['logG'], logs[blk]['logLhat'], logs[blk]['t']\n",
    "    # Skip blocks for which nothing was logged.\n",
    "    if not xs:\n",
    "        continue\n",
    "\n",
    "    st = styles[blk]\n",
    "\n",
    "    # Base cloud: every `stride`-th point of the block.\n",
    "    sc = ax.scatter(xs[::stride], ys[::stride],\n",
    "                    s=st['s_base'], c=ts[::stride], cmap=cmap, norm=norm,\n",
    "                    marker=st['marker'], alpha=0.8, linewidths=0.6, edgecolors=st['edge'], zorder=2,\n",
    "                    label=str(blk))\n",
    "    sc.set_path_effects([pe.withStroke(linewidth=st['lw'], foreground=st['stroke'])])\n",
    "    # Any scatter can drive the colorbar because they all share cmap and norm.\n",
    "    if mappable_for_cbar is None:\n",
    "        mappable_for_cbar = sc\n",
    "\n",
    "    # Most recent points, drawn larger and on top (zorder=3).\n",
    "    xs_tail, ys_tail, ts_tail = xs[-K_latest:], ys[-K_latest:], ts[-K_latest:]\n",
    "    if xs_tail:\n",
    "        sc_latest = ax.scatter(xs_tail, ys_tail,\n",
    "                               s=st['s_latest'], c=ts_tail, cmap=cmap, norm=norm,\n",
    "                               marker=st['marker'], alpha=0.95, linewidths=0.9, edgecolors=st['edge'],\n",
    "                               zorder=3)\n",
    "        sc_latest.set_path_effects([pe.withStroke(linewidth=st['lw']+0.3, foreground=st['stroke'])])\n",
    "\n",
    "    # Grey proxy handle: the legend shows the marker shape only, not a time colour.\n",
    "    legend_proxies.append(\n",
    "        Line2D([0],[0], marker=st['marker'], linestyle='None',\n",
    "               markerfacecolor='#888', markeredgecolor=st['edge'],\n",
    "               markeredgewidth=1.0, markersize=8, label=str(blk))\n",
    "    )\n",
    "\n",
    "if mappable_for_cbar is not None:\n",
    "    cbar = fig.colorbar(mappable_for_cbar, ax=ax)\n",
    "    cbar.set_label('log index (early → late)')\n",
    "\n",
    "ax.set_xlabel(r\"$\\log(||\\nabla_{\\boldsymbol{\\theta}_{i_k}}f(\\boldsymbol{\\Theta}^k)||)$\")\n",
    "ax.set_ylabel(r\"$\\log(\\text{estimated smoothness per block})$\")\n",
    "# The legend uses the proxy handles rather than the scatter artists.\n",
    "ax.legend(handles=legend_proxies, title='Blocks', loc='best')\n",
    "ax.grid(True) \n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('bdc_dca_logLhat_vs_logGrad_f_california_2layers_timecolored.png', dpi=150)\n",
    "print(\"Saved: bdc_dca_logLhat_vs_logGrad_f_california_2layers_timecolored.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8259386-97c9-438f-9593-42ec65e07b95",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
