{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "<>:121: SyntaxWarning: invalid escape sequence '\\h'\n",
            "<>:121: SyntaxWarning: invalid escape sequence '\\h'\n",
            "/tmp/ipykernel_3442850/2708816350.py:121: SyntaxWarning: invalid escape sequence '\\h'\n",
            "  Return per-degree mass: mass[n] = sum_{|S|=n} \\hat f(S)^2 with orthonormal coeffs.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[ep     1] loss=1.183224e+01 | anchor-avg max|pred-true|≈6.969 | mean≈3.000\n",
            "[ep    50] loss=1.082400e+01 | anchor-avg max|pred-true|≈6.046 | mean≈3.000\n",
            "[ep   100] loss=1.073393e+01 | anchor-avg max|pred-true|≈6.115 | mean≈3.000\n",
            "[ep   150] loss=1.066834e+01 | anchor-avg max|pred-true|≈6.125 | mean≈3.000\n",
            "[ep   200] loss=1.062777e+01 | anchor-avg max|pred-true|≈5.683 | mean≈3.000\n",
            "[ep   250] loss=1.071896e+01 | anchor-avg max|pred-true|≈5.395 | mean≈3.000\n",
            "[ep   300] loss=1.067020e+01 | anchor-avg max|pred-true|≈5.260 | mean≈3.000\n",
            "[ep   350] loss=1.057290e+01 | anchor-avg max|pred-true|≈5.397 | mean≈3.000\n",
            "[ep   400] loss=1.049733e+01 | anchor-avg max|pred-true|≈5.247 | mean≈2.999\n",
            "[ep   450] loss=9.905169e+00 | anchor-avg max|pred-true|≈4.980 | mean≈2.889\n",
            "[ep   500] loss=7.252274e+00 | anchor-avg max|pred-true|≈4.348 | mean≈2.422\n",
            "[ep   550] loss=4.243666e+00 | anchor-avg max|pred-true|≈5.911 | mean≈1.634\n",
            "[ep   600] loss=2.985983e+00 | anchor-avg max|pred-true|≈5.429 | mean≈1.255\n",
            "[ep   650] loss=2.531389e+00 | anchor-avg max|pred-true|≈6.675 | mean≈1.126\n",
            "[ep   700] loss=1.923994e+00 | anchor-avg max|pred-true|≈5.421 | mean≈0.892\n",
            "[ep   750] loss=1.422819e+00 | anchor-avg max|pred-true|≈4.207 | mean≈0.765\n",
            "[ep   800] loss=1.225252e+00 | anchor-avg max|pred-true|≈4.700 | mean≈0.644\n",
            "[ep   850] loss=9.762559e-01 | anchor-avg max|pred-true|≈5.612 | mean≈0.590\n",
            "[ep   900] loss=1.172634e+00 | anchor-avg max|pred-true|≈5.831 | mean≈0.599\n",
            "[ep   950] loss=1.023571e+00 | anchor-avg max|pred-true|≈4.871 | mean≈0.595\n",
            "[ep  1000] loss=7.633966e-01 | anchor-avg max|pred-true|≈4.678 | mean≈0.567\n",
            "[ep  1050] loss=8.424720e-01 | anchor-avg max|pred-true|≈4.693 | mean≈0.522\n",
            "[ep  1100] loss=6.582520e-01 | anchor-avg max|pred-true|≈5.791 | mean≈0.453\n",
            "[ep  1150] loss=6.941900e-01 | anchor-avg max|pred-true|≈5.429 | mean≈0.410\n",
            "[ep  1200] loss=7.309475e-01 | anchor-avg max|pred-true|≈5.642 | mean≈0.428\n",
            "[ep  1250] loss=6.376904e-01 | anchor-avg max|pred-true|≈5.061 | mean≈0.327\n",
            "[ep  1300] loss=5.846848e-01 | anchor-avg max|pred-true|≈5.105 | mean≈0.409\n",
            "[ep  1350] loss=5.908748e-01 | anchor-avg max|pred-true|≈4.766 | mean≈0.386\n",
            "[ep  1400] loss=5.085770e-01 | anchor-avg max|pred-true|≈5.348 | mean≈0.455\n",
            "[ep  1450] loss=5.860483e-01 | anchor-avg max|pred-true|≈6.420 | mean≈0.495\n",
            "[ep  1500] loss=3.953664e-01 | anchor-avg max|pred-true|≈4.508 | mean≈0.307\n",
            "[ep  1550] loss=5.841720e-01 | anchor-avg max|pred-true|≈5.041 | mean≈0.427\n",
            "[ep  1600] loss=4.159739e-01 | anchor-avg max|pred-true|≈4.197 | mean≈0.248\n",
            "[ep  1650] loss=4.934322e-01 | anchor-avg max|pred-true|≈4.779 | mean≈0.416\n",
            "[ep  1700] loss=2.658701e-01 | anchor-avg max|pred-true|≈4.855 | mean≈0.419\n",
            "[ep  1750] loss=6.367639e-01 | anchor-avg max|pred-true|≈4.247 | mean≈0.394\n",
            "[ep  1800] loss=2.822230e-01 | anchor-avg max|pred-true|≈2.898 | mean≈0.226\n",
            "[ep  1850] loss=2.972431e-01 | anchor-avg max|pred-true|≈2.791 | mean≈0.591\n",
            "[ep  1900] loss=1.727956e-01 | anchor-avg max|pred-true|≈3.445 | mean≈0.427\n",
            "[ep  1950] loss=1.813970e-01 | anchor-avg max|pred-true|≈2.565 | mean≈0.157\n",
            "[ep  2000] loss=1.515616e-01 | anchor-avg max|pred-true|≈2.591 | mean≈0.180\n",
            "[ep  2050] loss=9.787881e-02 | anchor-avg max|pred-true|≈1.559 | mean≈0.356\n",
            "[ep  2100] loss=1.206464e-01 | anchor-avg max|pred-true|≈1.874 | mean≈0.202\n",
            "[ep  2150] loss=5.036328e-01 | anchor-avg max|pred-true|≈2.501 | mean≈0.646\n",
            "[ep  2200] loss=7.509083e-02 | anchor-avg max|pred-true|≈1.541 | mean≈0.183\n",
            "[ep  2250] loss=1.025786e-01 | anchor-avg max|pred-true|≈1.425 | mean≈0.323\n",
            "[ep  2300] loss=4.213593e-02 | anchor-avg max|pred-true|≈0.985 | mean≈0.162\n",
            "[ep  2350] loss=1.643740e-01 | anchor-avg max|pred-true|≈1.845 | mean≈0.313\n",
            "[ep  2400] loss=1.301717e-01 | anchor-avg max|pred-true|≈1.279 | mean≈0.338\n",
            "[ep  2450] loss=2.619160e-01 | anchor-avg max|pred-true|≈0.714 | mean≈0.174\n",
            "[ep  2500] loss=9.829255e-02 | anchor-avg max|pred-true|≈0.659 | mean≈0.126\n",
            "[ep  2550] loss=9.226254e-02 | anchor-avg max|pred-true|≈1.293 | mean≈0.334\n",
            "[ep  2600] loss=2.874161e-01 | anchor-avg max|pred-true|≈1.089 | mean≈0.222\n",
            "[ep  2650] loss=1.025562e-01 | anchor-avg max|pred-true|≈1.560 | mean≈0.275\n",
            "[ep  2700] loss=8.754379e-02 | anchor-avg max|pred-true|≈0.726 | mean≈0.198\n",
            "[ep  2750] loss=3.633285e-02 | anchor-avg max|pred-true|≈1.251 | mean≈0.310\n",
            "[ep  2800] loss=3.558464e-02 | anchor-avg max|pred-true|≈1.016 | mean≈0.258\n",
            "[ep  2850] loss=2.396625e-02 | anchor-avg max|pred-true|≈0.970 | mean≈0.303\n",
            "[ep  2900] loss=3.389960e-02 | anchor-avg max|pred-true|≈0.492 | mean≈0.137\n",
            "[ep  2950] loss=7.714929e-02 | anchor-avg max|pred-true|≈1.192 | mean≈0.181\n",
            "[ep  3000] loss=5.110230e-02 | anchor-avg max|pred-true|≈0.878 | mean≈0.144\n",
            "[ep  3050] loss=3.133843e-02 | anchor-avg max|pred-true|≈0.649 | mean≈0.091\n",
            "[ep  3100] loss=1.117510e-01 | anchor-avg max|pred-true|≈0.958 | mean≈0.164\n",
            "[ep  3150] loss=1.810782e-01 | anchor-avg max|pred-true|≈0.894 | mean≈0.208\n",
            "[ep  3200] loss=5.865879e-02 | anchor-avg max|pred-true|≈0.820 | mean≈0.156\n",
            "[ep  3250] loss=1.861518e-01 | anchor-avg max|pred-true|≈1.191 | mean≈0.318\n",
            "[ep  3300] loss=4.163872e-01 | anchor-avg max|pred-true|≈1.708 | mean≈0.557\n",
            "[ep  3350] loss=6.042300e-02 | anchor-avg max|pred-true|≈0.526 | mean≈0.156\n",
            "[ep  3400] loss=8.324265e-02 | anchor-avg max|pred-true|≈0.539 | mean≈0.192\n",
            "[ep  3450] loss=5.567418e-02 | anchor-avg max|pred-true|≈1.030 | mean≈0.193\n",
            "[ep  3500] loss=5.051050e-02 | anchor-avg max|pred-true|≈0.879 | mean≈0.224\n",
            "[ep  3550] loss=6.934343e-02 | anchor-avg max|pred-true|≈0.621 | mean≈0.188\n",
            "[ep  3600] loss=1.434383e-01 | anchor-avg max|pred-true|≈1.233 | mean≈0.280\n",
            "[ep  3650] loss=1.547844e-01 | anchor-avg max|pred-true|≈1.089 | mean≈0.319\n",
            "[ep  3700] loss=3.197739e-01 | anchor-avg max|pred-true|≈1.779 | mean≈0.313\n",
            "[ep  3750] loss=5.293944e-02 | anchor-avg max|pred-true|≈0.681 | mean≈0.239\n",
            "[ep  3800] loss=6.011964e-02 | anchor-avg max|pred-true|≈1.533 | mean≈0.282\n",
            "[ep  3850] loss=8.677322e-02 | anchor-avg max|pred-true|≈0.820 | mean≈0.181\n",
            "[ep  3900] loss=4.837423e-02 | anchor-avg max|pred-true|≈0.687 | mean≈0.159\n",
            "[ep  3950] loss=1.242396e-01 | anchor-avg max|pred-true|≈1.166 | mean≈0.191\n",
            "[ep  4000] loss=3.073769e-02 | anchor-avg max|pred-true|≈0.838 | mean≈0.193\n",
            "[ep  4050] loss=1.475731e-01 | anchor-avg max|pred-true|≈0.930 | mean≈0.234\n",
            "[ep  4100] loss=4.940917e-02 | anchor-avg max|pred-true|≈0.567 | mean≈0.093\n",
            "[ep  4150] loss=1.155812e-01 | anchor-avg max|pred-true|≈0.679 | mean≈0.160\n",
            "[ep  4200] loss=8.057915e-02 | anchor-avg max|pred-true|≈0.519 | mean≈0.124\n",
            "[ep  4250] loss=1.638443e-01 | anchor-avg max|pred-true|≈0.736 | mean≈0.194\n",
            "[ep  4300] loss=9.624937e-02 | anchor-avg max|pred-true|≈0.525 | mean≈0.159\n",
            "[ep  4350] loss=1.126690e-01 | anchor-avg max|pred-true|≈0.555 | mean≈0.148\n",
            "[ep  4400] loss=8.779263e-02 | anchor-avg max|pred-true|≈0.773 | mean≈0.153\n",
            "[ep  4450] loss=7.547469e-02 | anchor-avg max|pred-true|≈0.561 | mean≈0.154\n",
            "[ep  4500] loss=1.451792e-01 | anchor-avg max|pred-true|≈1.279 | mean≈0.397\n",
            "[ep  4550] loss=1.633658e-01 | anchor-avg max|pred-true|≈0.791 | mean≈0.196\n",
            "[ep  4600] loss=1.679205e-01 | anchor-avg max|pred-true|≈0.647 | mean≈0.183\n",
            "[ep  4650] loss=1.246643e-01 | anchor-avg max|pred-true|≈0.558 | mean≈0.189\n",
            "[ep  4700] loss=1.282504e-01 | anchor-avg max|pred-true|≈0.675 | mean≈0.211\n",
            "[ep  4750] loss=1.510728e-01 | anchor-avg max|pred-true|≈1.313 | mean≈0.245\n",
            "[ep  4800] loss=3.443837e-01 | anchor-avg max|pred-true|≈1.043 | mean≈0.334\n",
            "[ep  4850] loss=7.405949e-02 | anchor-avg max|pred-true|≈0.971 | mean≈0.198\n",
            "[ep  4900] loss=2.160002e-01 | anchor-avg max|pred-true|≈1.444 | mean≈0.345\n",
            "[ep  4950] loss=6.511987e-01 | anchor-avg max|pred-true|≈1.388 | mean≈0.251\n",
            "[ep  5000] loss=1.051266e-01 | anchor-avg max|pred-true|≈0.765 | mean≈0.296\n",
            "[ep  5050] loss=2.384086e-01 | anchor-avg max|pred-true|≈1.349 | mean≈0.301\n",
            "[ep  5100] loss=5.855742e-02 | anchor-avg max|pred-true|≈1.036 | mean≈0.302\n",
            "[ep  5150] loss=1.307203e-01 | anchor-avg max|pred-true|≈1.712 | mean≈0.226\n",
            "[ep  5200] loss=1.617823e-01 | anchor-avg max|pred-true|≈0.920 | mean≈0.165\n",
            "[ep  5250] loss=1.338779e-01 | anchor-avg max|pred-true|≈0.954 | mean≈0.208\n",
            "[ep  5300] loss=1.507800e-01 | anchor-avg max|pred-true|≈0.853 | mean≈0.225\n",
            "[ep  5350] loss=7.725601e-02 | anchor-avg max|pred-true|≈1.423 | mean≈0.487\n",
            "[ep  5400] loss=2.335519e-01 | anchor-avg max|pred-true|≈0.801 | mean≈0.246\n",
            "[ep  5450] loss=2.018427e-01 | anchor-avg max|pred-true|≈1.018 | mean≈0.230\n",
            "[ep  5500] loss=9.827755e-02 | anchor-avg max|pred-true|≈0.766 | mean≈0.262\n",
            "[ep  5550] loss=2.300494e-01 | anchor-avg max|pred-true|≈0.980 | mean≈0.236\n",
            "[ep  5600] loss=3.433090e-02 | anchor-avg max|pred-true|≈1.028 | mean≈0.377\n",
            "[ep  5650] loss=1.844742e-01 | anchor-avg max|pred-true|≈1.183 | mean≈0.348\n",
            "[ep  5700] loss=6.361525e-02 | anchor-avg max|pred-true|≈0.546 | mean≈0.159\n",
            "[ep  5750] loss=1.397226e-01 | anchor-avg max|pred-true|≈1.116 | mean≈0.313\n",
            "[ep  5800] loss=1.630608e-01 | anchor-avg max|pred-true|≈1.062 | mean≈0.214\n",
            "[ep  5850] loss=5.537772e-02 | anchor-avg max|pred-true|≈0.825 | mean≈0.329\n",
            "[ep  5900] loss=2.114967e-01 | anchor-avg max|pred-true|≈1.521 | mean≈0.392\n",
            "[ep  5950] loss=2.286547e-01 | anchor-avg max|pred-true|≈0.883 | mean≈0.295\n",
            "[ep  6000] loss=1.396985e-01 | anchor-avg max|pred-true|≈1.688 | mean≈0.287\n"
          ]
        }
      ],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import os, math\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "from math import comb\n",
        "\n",
        "# ---------------------------\n",
        "# Config (tune these)\n",
        "# ---------------------------\n",
        "# Reproducibility: seed both global RNGs (torch, numpy) before anything random runs.\n",
        "seed = 0\n",
        "torch.manual_seed(seed)\n",
        "np.random.seed(seed)\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Directory that receives a checkpoint after EVERY epoch (e.g. \"ckpts_poly8\"); \"\" disables saving.\n",
        "# The checkpoint-based sampling cells further down read from this directory.\n",
        "checkpoint_dir = \"\"\n",
        "\n",
        "# Input layout: the first d_indicator coordinates carry the signal (the Fourier analysis\n",
        "# is done over them); the remaining d_noise coordinates have NO effect on the target.\n",
        "d_indicator = 8\n",
        "d_noise = 500                    # set to 0 to test without distractor bits\n",
        "d_total = d_indicator + d_noise\n",
        "N_indicator = 2 ** d_indicator   # number of anchors = vertices of the indicator cube (256)\n",
        "\n",
        "# Training hyper-parameters\n",
        "epochs = 6000\n",
        "print_every = 50\n",
        "lr = 2e-3\n",
        "weight_decay = 0.0   # pure MSE objective, no L2 penalty\n",
        "hidden = 128         # width of both hidden layers\n",
        "\n",
        "# Monte Carlo coverage of the noise coordinates\n",
        "train_R = 64    # noise samples per anchor per training step\n",
        "eval_R = 256    # noise samples per anchor for evaluation / FWHT snapshots\n",
        "\n",
        "# Coefficients of the target polynomial over the first 8 bits in {-1, +1}\n",
        "coeff_deg1 = 0.1\n",
        "coeff_deg2 = 0.2\n",
        "coeff_deg4 = 0.4\n",
        "coeff_deg6 = 0.8\n",
        "coeff_deg8 = 3.0\n",
        "\n",
        "# ---------------------------\n",
        "# Utilities\n",
        "# ---------------------------\n",
        "def full_cube_pm1(dim, device):\n",
        "    \"\"\"All 2**dim vertices of {-1, +1}^dim as a float32 tensor of shape [2**dim, dim].\n",
        "\n",
        "    Row k is the binary expansion of k: column j is +1 when bit j of k is set, else -1.\n",
        "    \"\"\"\n",
        "    codes = torch.arange(1 << dim, device=device, dtype=torch.long).unsqueeze(-1)   # [2**dim, 1]\n",
        "    positions = torch.arange(dim, device=device, dtype=torch.long).unsqueeze(0)      # [1, dim]\n",
        "    is_set = (codes >> positions) & 1\n",
        "    return is_set.float() * 2 - 1\n",
        "\n",
        "def rand_pm1(shape, device):\n",
        "    \"\"\"I.i.d. uniform {-1, +1} float32 tensor of `shape` (no RNG draw when the last dim is 0).\"\"\"\n",
        "    if not shape[-1]:\n",
        "        return torch.empty(shape, device=device, dtype=torch.float32)\n",
        "    coin_flips = torch.randint(0, 2, shape, device=device, dtype=torch.int8)\n",
        "    return coin_flips.to(torch.float32).mul(2).sub(1)\n",
        "\n",
        "def paired_noise(R, d_noise, device):\n",
        "    \"\"\"Antithetic {-1, +1} noise of shape [R', d_noise] (float32).\n",
        "\n",
        "    Row i of the second half is the negation of row i of the first half, so every noise\n",
        "    coordinate averages to exactly 0 over the batch. R' is R rounded up to the next even\n",
        "    number; when d_noise == 0 an empty [R, 0] tensor is returned (R is NOT rounded).\n",
        "    \"\"\"\n",
        "    if d_noise == 0:\n",
        "        return torch.empty((R, 0), device=device, dtype=torch.float32)\n",
        "    half = (R + 1) // 2   # ceil(R / 2)\n",
        "    base = rand_pm1((half, d_noise), device)\n",
        "    return torch.cat((base, base.neg()), dim=0)\n",
        "\n",
        "# Target polynomial on the first 8 variables (x1..x8 = columns 0..7):\n",
        "# f(x) = c1 * sum xi\n",
        "#      + c2 * (x1x2 + x3x4 + x5x6 + x7x8)\n",
        "#      + c4 * (x1x2x3x4 + x5x6x7x8)\n",
        "#      + c6 * (x1x2x3x4x5x6 + x3x4x5x6x7x8)\n",
        "#      + c8 * (x1x2x3x4x5x6x7x8)\n",
        "def poly_val_first8(x, c1, c2, c4, c6, c8):\n",
        "    \"\"\"Evaluate f above; x is [..., 8] with entries in {-1, +1}, the result is [...].\"\"\"\n",
        "    v = x.unbind(dim=-1)                                          # v[k] == x[..., k]\n",
        "    linear = x.sum(dim=-1)\n",
        "    pairs = v[0] * v[1] + v[2] * v[3] + v[4] * v[5] + v[6] * v[7]\n",
        "    quads = x[..., 0:4].prod(dim=-1) + x[..., 4:8].prod(dim=-1)\n",
        "    sixes = x[..., 0:6].prod(dim=-1) + x[..., 2:8].prod(dim=-1)   # first 6, last 6\n",
        "    full = x.prod(dim=-1)\n",
        "    return c1 * linear + c2 * pairs + c4 * quads + c6 * sixes + c8 * full\n",
        "\n",
        "def weighted_mse(pred, target, w):\n",
        "    \"\"\"Weighted MSE: sum(w * (pred - target)^2) / sum(w); the 1e-12 keeps an all-zero w finite.\"\"\"\n",
        "    squared_error = (pred - target).pow(2)\n",
        "    total_weight = w.sum() + 1e-12\n",
        "    return (w * squared_error).sum() / total_weight\n",
        "\n",
        "# ---------------------------\n",
        "# FWHT helpers for 8-D cube\n",
        "# ---------------------------\n",
        "N = N_indicator\n",
        "d = d_indicator\n",
        "\n",
        "# deg[S] = popcount of subset mask S (0..N-1) = degree |S| of the Walsh character; int16, CPU\n",
        "_subset_masks = torch.arange(N, dtype=torch.int64)\n",
        "deg = torch.zeros(N, dtype=torch.int16)\n",
        "for j in range(d):\n",
        "    deg += ((_subset_masks >> j) & 1).to(torch.int16)\n",
        "del _subset_masks\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def fwht_inplace(a: torch.Tensor):\n",
        "    \"\"\"\n",
        "    In-place Fast Walsh–Hadamard Transform of a 1-D tensor whose length N is a power of two.\n",
        "\n",
        "    Afterwards a[k] = sum_x a_old[x] * (-1)**popcount(k & x)  (natural / Hadamard order).\n",
        "    Normalization: unnormalized; divide by N afterwards for orthonormal coefficients.\n",
        "    Returns None; the input is overwritten (it must be viewable, i.e. contiguous).\n",
        "    \"\"\"\n",
        "    length = a.shape[0]\n",
        "    span = 1\n",
        "    while span < length:\n",
        "        # Rows of width 2*span: left half x, right half y; butterfly (x, y) -> (x + y, x - y).\n",
        "        rows = a.view(-1, 2 * span)\n",
        "        left, right = rows[:, :span], rows[:, span:]\n",
        "        left_before = left.clone()\n",
        "        left.add_(right)\n",
        "        right.copy_(left_before - right)\n",
        "        span *= 2\n",
        "\n",
        "@torch.no_grad()\n",
        "def masses_by_degree_from_f_fwht(f_vec: torch.Tensor) -> torch.Tensor:\n",
        "    r\"\"\"\n",
        "    Return per-degree mass: mass[n] = sum_{|S|=n} \\hat f(S)^2 with orthonormal coeffs.\n",
        "\n",
        "    f_vec: one value of f per subset bitmask; its length must equal len(deg) (= 2**d).\n",
        "        It is not modified. Works on CPU or CUDA.\n",
        "    Returns a float64 CPU tensor of length d + 1 (index n = degree).\n",
        "    The docstring is a raw string so the TeX-style backslash is not an (invalid) escape.\n",
        "    \"\"\"\n",
        "    n_points = f_vec.shape[0]\n",
        "    spectrum = f_vec.clone()          # fwht_inplace overwrites its argument\n",
        "    fwht_inplace(spectrum)\n",
        "    sq_coeffs = (spectrum / n_points).to(torch.float64) ** 2          # [N] on f_vec.device\n",
        "    degree_idx = deg.to(device=sq_coeffs.device, dtype=torch.int64)   # bincount needs matching devices\n",
        "    mass = torch.bincount(degree_idx, weights=sq_coeffs, minlength=d + 1)\n",
        "    return mass.cpu()\n",
        "\n",
        "\n",
        "def theoretical_fwht_coeffs_for_poly_first8(c1, c2, c4, c6, c8):\n",
        "    \"\"\"\n",
        "    Exact orthonormal Fourier–Walsh coefficients of poly_val_first8 as a float64 vector of\n",
        "    length N, indexed by subset mask (bit j set <=> variable j in S); all other entries are 0.\n",
        "    \"\"\"\n",
        "    def subset_mask(variables):\n",
        "        mask = 0\n",
        "        for v in variables:\n",
        "            mask |= 1 << v\n",
        "        return mask\n",
        "\n",
        "    # (variables in the monomial, coefficient) -- mirrors the terms of poly_val_first8\n",
        "    monomials = (\n",
        "        [((j,), c1) for j in range(8)]                          # degree 1: singletons\n",
        "        + [((i, i + 1), c2) for i in (0, 2, 4, 6)]              # degree 2: adjacent pairs\n",
        "        + [((0, 1, 2, 3), c4), ((4, 5, 6, 7), c4)]              # degree 4: two blocks\n",
        "        + [((0, 1, 2, 3, 4, 5), c6), ((2, 3, 4, 5, 6, 7), c6)]  # degree 6: two windows\n",
        "        + [(tuple(range(8)), c8)]                               # degree 8: all variables\n",
        "    )\n",
        "    coeff = torch.zeros(N, dtype=torch.float64)\n",
        "    for variables, c in monomials:\n",
        "        coeff[subset_mask(variables)] += c\n",
        "    return coeff\n",
        "\n",
        "theory_coeff = theoretical_fwht_coeffs_for_poly_first8(\n",
        "    coeff_deg1, coeff_deg2, coeff_deg4, coeff_deg6, coeff_deg8,\n",
        ")\n",
        "theory_sign = theory_coeff.sign().to(torch.int8)   # values in {-1, 0, +1}\n",
        "# Degrees |S| that carry at least one nonzero exact coefficient (sorted ascending)\n",
        "degrees_present = sorted({int(k) for k in deg[theory_sign != 0].tolist()})\n",
        "\n",
        "# Per degree n: first logged epoch at which every nonzero-theory coefficient of degree n\n",
        "# has the correct sign in the model's estimate; None until then.\n",
        "first_align_epoch = dict.fromkeys(degrees_present)\n",
        "\n",
        "def theoretical_mass_by_degree():\n",
        "    \"\"\"Per-degree mass sum_{|S|=n} theory_coeff[S]^2 for n = 0..d, as a list of Python floats.\"\"\"\n",
        "    totals = [0.0] * (d + 1)\n",
        "    for degree, c in zip(deg.tolist(), theory_coeff.tolist()):\n",
        "        totals[degree] += c * c\n",
        "    return totals\n",
        "\n",
        "# ---------------------------\n",
        "# Data on anchors (first 8 dims) + GT values\n",
        "# ---------------------------\n",
        "# Every vertex of the indicator cube is an anchor; its target is the exact polynomial value.\n",
        "anchors = full_cube_pm1(d_indicator, device)   # [256, 8]\n",
        "true_vals = poly_val_first8(\n",
        "    anchors, coeff_deg1, coeff_deg2, coeff_deg4, coeff_deg6, coeff_deg8\n",
        ")                                               # [256]\n",
        "anchor_weights = torch.ones(N_indicator, dtype=torch.float32, device=device)   # uniform loss weights\n",
        "\n",
        "# ---------------------------\n",
        "# Model (single FCNN)\n",
        "# ---------------------------\n",
        "class FCNN(nn.Module):\n",
        "    \"\"\"Fully connected ReLU network in_dim -> hidden -> hidden -> 1; output shape [...].\n",
        "\n",
        "    Linear weights are Xavier-uniform initialised and biases start at zero. The Linear\n",
        "    layers sit in `self.net` at indices 0/2/4 -- the state_dict keys stored in checkpoints.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, in_dim=8+500, hidden=128):\n",
        "        super().__init__()\n",
        "        layers = [\n",
        "            nn.Linear(in_dim, hidden), nn.ReLU(),\n",
        "            nn.Linear(hidden, hidden), nn.ReLU(),\n",
        "            nn.Linear(hidden, 1),\n",
        "        ]\n",
        "        self.net = nn.Sequential(*layers)\n",
        "        linear_layers = [layer for layer in self.net if isinstance(layer, nn.Linear)]\n",
        "        for layer in linear_layers:\n",
        "            nn.init.xavier_uniform_(layer.weight)\n",
        "            nn.init.zeros_(layer.bias)\n",
        "\n",
        "    def forward(self, x):\n",
        "        # [..., in_dim] -> [..., 1] -> [...]\n",
        "        return self.net(x).squeeze(-1)\n",
        "\n",
        "# count linear layers (for plotting title/info)\n",
        "# NOTE: constructing this throwaway FCNN runs its weight initialisation and so advances\n",
        "# the torch RNG before `model` is built; moving or removing this line changes `model`'s\n",
        "# initial weights (and all recorded results) for the same seed.\n",
        "_linear_count = sum(1 for m in FCNN(d_total, hidden).net if isinstance(m, nn.Linear))\n",
        "\n",
        "model = FCNN(d_total, hidden).to(device)\n",
        "# Adam; weight_decay comes from the config (0.0 => plain MSE objective, no L2 penalty)\n",
        "opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "\n",
        "# ---------------------------\n",
        "# Training: MSE on anchor-averaged predictions (noise marginalized)\n",
        "# ---------------------------\n",
        "history = []\n",
        "if checkpoint_dir:\n",
        "    os.makedirs(checkpoint_dir, exist_ok=True)\n",
        "\n",
        "# Parity correction for the FWHT sign convention: full_cube_pm1 maps bit=1 -> +1 and\n",
        "# bit=0 -> -1, so the Hadamard kernel (-1)^popcount(S & b) equals (-1)^|S| * prod_{j in S} x_j.\n",
        "# Multiplying raw FWHT coefficients by (-1)^|S| yields hat f(S) in the same convention as\n",
        "# theory_coeff; without it every odd-degree sign comes out flipped.\n",
        "walsh_parity = 1.0 - 2.0 * (deg.to(torch.float64) % 2)   # [N], CPU, values in {+1, -1}\n",
        "\n",
        "for ep in range(1, epochs + 1):\n",
        "    model.train()\n",
        "\n",
        "    # Fresh paired noise tails for this step. paired_noise rounds an odd R up by one row,\n",
        "    # so everything below is sized from the actual row count rather than train_R.\n",
        "    noise = paired_noise(train_R, d_noise, device)             # [R, d_noise]\n",
        "    R_train = noise.shape[0]\n",
        "    X1 = anchors.repeat_interleave(R_train, dim=0)             # [256*R, 8]  anchor a -> rows a*R .. a*R+R-1\n",
        "    X2 = noise.repeat(N_indicator, 1)                          # [256*R, d_noise]  same R tails for every anchor\n",
        "    X  = torch.cat([X1, X2], dim=1)                            # [256*R, 8+d_noise]\n",
        "\n",
        "    pred_all = model(X).view(N_indicator, R_train)             # [256, R]\n",
        "    pred_avg = pred_all.mean(dim=1)                            # [256] noise-marginalized prediction\n",
        "\n",
        "    loss = weighted_mse(pred_avg, true_vals, anchor_weights)\n",
        "\n",
        "    opt.zero_grad(set_to_none=True)\n",
        "    loss.backward()\n",
        "    nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
        "    opt.step()\n",
        "\n",
        "    # ---- Save every epoch\n",
        "    if checkpoint_dir:\n",
        "        ckpt_path = os.path.join(checkpoint_dir, f\"ckpt_epoch_{ep}.pt\")\n",
        "        torch.save({\n",
        "            \"epoch\": ep,\n",
        "            \"in_dim\": d_total,\n",
        "            \"hidden\": hidden,\n",
        "            \"model_state_dict\": model.state_dict(),\n",
        "            \"coeffs\": dict(c1=coeff_deg1, c2=coeff_deg2, c4=coeff_deg4, c6=coeff_deg6, c8=coeff_deg8),\n",
        "            \"d_noise\": d_noise\n",
        "        }, ckpt_path)\n",
        "\n",
        "    # ---- Periodic diagnostics: Fourier mass + degree alignment\n",
        "    if ep % print_every == 0 or ep == 1 or ep == epochs:\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            # Anchor-averaged predictions using larger eval_R for stabler FWHT\n",
        "            noise_eval = paired_noise(eval_R, d_noise, device)\n",
        "            R_eval = noise_eval.shape[0]                   # may be eval_R + 1 (odd eval_R)\n",
        "            X1e = anchors.repeat_interleave(R_eval, dim=0)\n",
        "            X2e = noise_eval.repeat(N_indicator, 1)\n",
        "            Xe  = torch.cat([X1e, X2e], dim=1)\n",
        "            preds_avg = model(Xe).view(N_indicator, R_eval).mean(dim=1)   # [256]\n",
        "\n",
        "            # Errors vs true per anchor\n",
        "            diffs = (preds_avg - true_vals).abs()\n",
        "            max_abs_err = diffs.max().item()\n",
        "            mean_abs_err = diffs.mean().item()\n",
        "\n",
        "            # Fourier masses by degree (FWHT over 8-dim cube); squares, so unaffected by parity\n",
        "            mass = masses_by_degree_from_f_fwht(preds_avg)\n",
        "\n",
        "            # Coefficient signs and first-alignment epochs\n",
        "            g = preds_avg.clone()\n",
        "            fwht_inplace(g)                 # unnormalized FWHT\n",
        "            coeffs_est = (g / N).to(torch.float64).cpu() * walsh_parity   # hat f(S), theory convention\n",
        "            est_sign = torch.sign(coeffs_est).to(torch.int8)\n",
        "\n",
        "            for n in degrees_present:\n",
        "                if first_align_epoch[n] is not None:\n",
        "                    continue\n",
        "                # only consider nonzero-theory coeffs at degree n (all CPU tensors)\n",
        "                consider = (deg == n) & (theory_sign != 0)\n",
        "                if consider.any():\n",
        "                    # Exact sign match required: an estimated coefficient of exactly 0\n",
        "                    # counts as misaligned.\n",
        "                    wrong = (est_sign[consider] != theory_sign[consider])\n",
        "                    if not wrong.any().item():\n",
        "                        first_align_epoch[n] = ep\n",
        "\n",
        "            # Log\n",
        "            row = {\"epoch\": ep, \"loss\": float(loss.item()),\n",
        "                   \"anchor_avg_max_abs_err\": float(max_abs_err),\n",
        "                   \"anchor_avg_mean_abs_err\": float(mean_abs_err)}\n",
        "            for n in range(d + 1):\n",
        "                row[f\"mass_deg{n}\"] = float(mass[n].item())\n",
        "            history.append(row)\n",
        "\n",
        "        print(f\"[ep {ep:5d}] loss={loss.item():.6e} | \"\n",
        "              f\"anchor-avg max|pred-true|≈{max_abs_err:.3f} | mean≈{mean_abs_err:.3f}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "SSd3FGCuOX4m"
      },
      "outputs": [],
      "source": [
        "# Full input dimension (indicator + noise bits), kept as a notebook-level alias.\n",
        "# NOTE(review): the visible later cells use their own local D / d_total; confirm whether\n",
        "# anything further down relies on this global.\n",
        "D = d_total"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "f0ing7qe4VH1"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import os, glob\n",
        "import tqdm\n",
        "import math\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "class GWGSampler:\n",
        "    \"\"\"Single-bit-flip Metropolis–Hastings sampler with informed (Gibbs-with-Gradients style) proposals.\n",
        "\n",
        "    Target on {-1, +1}^D: pi(x) ∝ exp(-beta * E(x)) with E(x) = -model(x), so a higher model\n",
        "    output means lower energy and higher probability. A step proposes flipping coordinate i\n",
        "    with q(i | x) ∝ exp(-beta * Δ_i / 2), Δ_i = E(x with bit i flipped) - E(x), and accepts\n",
        "    with the MH ratio pi(x') q(i | x') / (pi(x) q(i | x)).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model, beta=1.0):\n",
        "        self.model = model\n",
        "        self.beta = float(beta)\n",
        "\n",
        "    def _energy(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        # NEGATIVE sign: lower energy = higher model output\n",
        "        y = self.model(x.view(1, -1)).view(())\n",
        "        return -y\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _deltas_exact(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Exact Δ_i = E(x^i) - E(x) for every single-bit flip x^i of the 1-D state x.\"\"\"\n",
        "        device = x.device\n",
        "        D = x.numel()\n",
        "        y = self._energy(x)  # scalar E(x)\n",
        "\n",
        "        # vectorized single-bit flips: row i of X is x with coordinate i negated\n",
        "        X = x.unsqueeze(0).repeat(D, 1)\n",
        "        idx = torch.arange(D, device=device)\n",
        "        X[idx, idx] = -X[idx, idx]\n",
        "        y_flips = torch.vmap(self._energy)(X)  # or: torch.stack([self._energy(X[i]) for i in range(D)])\n",
        "        return y_flips - y  # Δ_i = E(x^i) - E(x)\n",
        "\n",
        "    def _deltas_grad(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"First-order approximation of Δ from a single forward/backward pass.\"\"\"\n",
        "        # GWG approx: Δ_i ≈ -2 x_i ∂_i E(x) = 2 x_i ∂_i model(x)\n",
        "        x = x.detach().clone().requires_grad_(True)\n",
        "        y = self.model(x.view(1, -1)).view(())\n",
        "        (g,) = torch.autograd.grad(y, x, create_graph=False, retain_graph=False)\n",
        "        return (2.0 * x * g).detach()\n",
        "\n",
        "    def single_step(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"One MH step from the 1-D state x; returns a new detached tensor (x is not modified).\"\"\"\n",
        "        x = x.detach().clone()\n",
        "        deltas = self._deltas_exact(x)                # Δ_i\n",
        "        #deltas = self._deltas_grad(x)\n",
        "\n",
        "        # coordinate proposal q(i | x) ∝ exp(-β Δ_i / 2)\n",
        "        logits_fwd = -self.beta * deltas / 2.0\n",
        "        probs = torch.softmax(logits_fwd, dim=0)\n",
        "        i = torch.multinomial(probs, 1).item()\n",
        "\n",
        "        # candidate flip\n",
        "        x_new = x.clone(); x_new[i] = -x_new[i]\n",
        "\n",
        "        # MH correction (exact reverse proposal), evaluated entirely in log space.\n",
        "        # In linear space, for large β and a strongly favourable flip, exp(-β Δ_i) overflows\n",
        "        # to inf while q_rev underflows to 0; inf * 0 = nan then rejects exactly the best moves.\n",
        "        deltas_p = self._deltas_exact(x_new)\n",
        "        #deltas_p = self._deltas_grad(x_new)\n",
        "        log_q_fwd = torch.log_softmax(logits_fwd, dim=0)[i]\n",
        "        log_q_rev = torch.log_softmax(-self.beta * deltas_p / 2.0, dim=0)[i]\n",
        "        log_accept = -self.beta * deltas[i] + log_q_rev - log_q_fwd\n",
        "\n",
        "        # u < min(1, exp(log_accept))  <=>  log(u) < log_accept\n",
        "        u = torch.rand((), device=x.device)\n",
        "        if torch.log(u) < log_accept:\n",
        "            return x_new.detach()\n",
        "        return x.detach()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "5-qyZqMa65dP"
      },
      "outputs": [],
      "source": [
        "def sampling_via_checkpoints(\n",
        "    checkpoint_dir: str,\n",
        "    epochs: list[int],\n",
        "    FCNNClass,\n",
        "    GWGSamplerClass,\n",
        "    num_particles: int = 200,\n",
        "    mcmc_steps: int = 15,\n",
        "    resample_thresh: float = 0.5,\n",
        "    device: str = \"cuda\",\n",
        "    beta: float = 1.0\n",
        "):\n",
        "    \"\"\"Anneal particles through a sequence of checkpointed energies with GWG MCMC.\n",
        "\n",
        "    Particles start uniformly random on {-1, +1}^D (D = d_total). For each requested\n",
        "    checkpoint epoch, in ascending order, the saved FCNN is loaded from\n",
        "    `checkpoint_dir/ckpt_epoch_{e}.pt` and every particle is advanced by `mcmc_steps`\n",
        "    GWG steps targeting that checkpoint's energy.\n",
        "\n",
        "    Returns:\n",
        "        particles: numpy array [num_particles, D] of final states.\n",
        "        hit: 0-d bool tensor, True iff EVERY particle has all of its first d_indicator\n",
        "            coordinates equal to +1.\n",
        "\n",
        "    Notes:\n",
        "        `resample_thresh` is accepted for interface compatibility but unused (no SMC resampling).\n",
        "        Reads the notebook globals d_total, hidden and d_indicator.\n",
        "    \"\"\"\n",
        "    epochs = sorted(epochs)\n",
        "    ckpts = [os.path.join(checkpoint_dir, f\"ckpt_epoch_{e}.pt\") for e in epochs]\n",
        "\n",
        "    D = d_total\n",
        "    particles = (torch.randint(0, 2, (num_particles, D), device=device) * 2 - 1).float()\n",
        "\n",
        "    for ckpt in ckpts:\n",
        "        # load model\n",
        "        model = FCNNClass(D, hidden).to(device).eval()\n",
        "        sd = torch.load(ckpt, map_location=device)\n",
        "        model.load_state_dict(sd['model_state_dict'])\n",
        "\n",
        "        # GWG rejuvenation targeting current energy\n",
        "        sampler = GWGSamplerClass(model, beta=beta)\n",
        "        for i in range(num_particles):\n",
        "            x = particles[i]\n",
        "            for _ in range(mcmc_steps):\n",
        "                x = sampler.single_step(x)\n",
        "            particles[i] = x\n",
        "\n",
        "    # Hit: every particle sits on the target x* = (+1, ..., +1) in the indicator coordinates.\n",
        "    # Stays on `device` (no hard-coded .cuda()) and works for any value of d_indicator.\n",
        "    hit = (particles[:, :d_indicator] == 1).all()\n",
        "    return particles.cpu().numpy(), hit\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [],
      "source": [
        "# 200 independent single-particle runs through the epoch-25 and epoch-30000 checkpoints;\n",
        "# a run counts as a hit when sampling_via_checkpoints reports its particle on the target.\n",
        "hit_count = 0\n",
        "for trial in range(200):\n",
        "    particles, hit_or_not = sampling_via_checkpoints(\n",
        "        checkpoint_dir,\n",
        "        [25, 30000],\n",
        "        FCNN,\n",
        "        GWGSampler,\n",
        "        num_particles=1,\n",
        "        mcmc_steps=20,\n",
        "        beta=100.0,\n",
        "    )\n",
        "    hit_count += hit_or_not.item()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "104"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Number of successful single-particle runs accumulated by the sampling loop.\n",
        "hit_count"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Hit count: 104/200\n",
            "Hit fraction: 0.5200\n",
            "2 SD CI: [0.4493, 0.5907]  (SE ≈ 0.0353)\n"
          ]
        }
      ],
      "source": [
        "import math\n",
        "\n",
        "# Must equal the number of iterations of the sampling loop that produced hit_count.\n",
        "n_trials = 200\n",
        "\n",
        "# Guard the division itself: computing hit_count / n_trials first would raise\n",
        "# ZeroDivisionError for n_trials == 0 before any NaN fallback could apply.\n",
        "if n_trials > 0:\n",
        "    p = hit_count / float(n_trials)  # hit fraction\n",
        "    se = math.sqrt(p * (1.0 - p) / n_trials)  # binomial (Wald) standard error\n",
        "    # p ± 2·SE clipped to [0, 1]; note the Wald SE is 0 when p is exactly 0 or 1.\n",
        "    lo = max(0.0, p - 2 * se)\n",
        "    hi = min(1.0, p + 2 * se)\n",
        "else:\n",
        "    p = se = lo = hi = float('nan')\n",
        "\n",
        "print(f\"Hit count: {hit_count}/{n_trials}\")\n",
        "print(f\"Hit fraction: {p:.4f}\")\n",
        "print(f\"2 SD CI: [{lo:.4f}, {hi:.4f}]  (SE ≈ {se:.4f})\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "F8C_eM1iU-KV"
      },
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def first_hit_steps(sampler, target10, d_indicator, d_total, device, max_steps=10000):\n",
        "    \"\"\"\n",
        "    First-hit time of a fixed-beta GWG chain started from a random ±1 state.\n",
        "\n",
        "    The initial state is ``rand_pm1((d_total,), device)`` cast to float32. After each\n",
        "    ``sampler.single_step`` call, the first ``d_indicator`` entries are compared\n",
        "    element-wise with ``target10``.\n",
        "\n",
        "    Returns:\n",
        "        0 if the initial state already matches; otherwise the 1-based index of the\n",
        "        first matching step; None if there is no match within ``max_steps`` steps.\n",
        "    \"\"\"\n",
        "    # NOTE(review): runs under torch.no_grad(); presumably the sampler enables gradients\n",
        "    # internally for its proposals — confirm against the GWGSampler implementation.\n",
        "    def on_target(state):\n",
        "        return (state[:d_indicator] == target10).all()\n",
        "\n",
        "    state = rand_pm1((d_total,), device).to(torch.float32)\n",
        "    if on_target(state):\n",
        "        return 0\n",
        "\n",
        "    for step in range(1, max_steps + 1):\n",
        "        state = sampler.single_step(state)\n",
        "        # NOTE(review): if the sampler can emit values other than ±1, snap them here;\n",
        "        # torch.sign(0) == 0, so sign(...).clamp(-1, 1) alone would not enforce ±1.\n",
        "        if on_target(state):\n",
        "            return step\n",
        "    return None"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "@torch.no_grad()\n",
        "def first_hit_steps_annealed(\n",
        "    sampler,\n",
        "    target10: torch.Tensor,\n",
        "    d_indicator: int,\n",
        "    d_total: int,\n",
        "    device,\n",
        "    betas: torch.Tensor | None = None,\n",
        "    steps_per_beta: int = 1,\n",
        "    max_total_steps: int = 10000,\n",
        "    enforce_pm1: bool = False,\n",
        "    beta_max: float | None = None,\n",
        "):\n",
        "    \"\"\"\n",
        "    Simulated annealing / tempered transitions version of the first-hit experiment.\n",
        "    No AIS weights; just marches beta and does steps_per_beta GWG steps at each beta.\n",
        "    Once the schedule is exhausted the chain keeps running at the final beta, so the\n",
        "    whole max_total_steps budget is used (the default schedule has at most 1000 levels,\n",
        "    so without this hold phase a larger budget would be silently truncated).\n",
        "\n",
        "    Args:\n",
        "        sampler: object with ``single_step(x)`` and a writable ``beta`` attribute;\n",
        "            its ``beta`` is restored before returning.\n",
        "        target10: target values for the first ``d_indicator`` coordinates.\n",
        "        d_indicator: number of leading coordinates compared with ``target10``.\n",
        "        d_total: length of the chain state.\n",
        "        device: device for the initial state and the default schedule.\n",
        "        betas: explicit beta schedule; if None, ``linspace(0, beta_max, n_levels)``.\n",
        "        steps_per_beta: GWG steps per schedule level.\n",
        "        max_total_steps: cap on the total number of GWG steps.\n",
        "        enforce_pm1: snap the state to ±1 after each step (x >= 0 -> +1, else -1).\n",
        "        beta_max: final beta of the default schedule. Defaults to the sampler's own\n",
        "            ``beta`` (1.0 if it has none), so a sampler built with beta=100 is annealed\n",
        "            up to 100 instead of being capped at 1.\n",
        "\n",
        "    Returns:\n",
        "        steps_to_hit: int (0 if the initial state already matches) or None if the\n",
        "        target is not hit within max_total_steps.\n",
        "    \"\"\"\n",
        "    import itertools\n",
        "\n",
        "    beta_orig = getattr(sampler, \"beta\", 1.0)\n",
        "\n",
        "    # beta schedule\n",
        "    if betas is None:\n",
        "        beta_top = float(beta_orig if beta_max is None else beta_max)\n",
        "        # Choose length consistent with max_total_steps and steps_per_beta (<= 1000 levels)\n",
        "        n_levels = max(2, min(max_total_steps // max(1, steps_per_beta), 1000))\n",
        "        betas = torch.linspace(0.0, beta_top, steps=n_levels, device=device)\n",
        "    else:\n",
        "        betas = betas.to(device)\n",
        "    # One host transfer for the whole schedule instead of a .item() sync per level.\n",
        "    schedule = [float(b) for b in betas.reshape(-1).tolist()]\n",
        "\n",
        "    # init state ±1\n",
        "    x = rand_pm1((d_total,), device).to(torch.float32)\n",
        "\n",
        "    def on_target(state):\n",
        "        return bool((state[:d_indicator] == target10).all())\n",
        "\n",
        "    if on_target(x):\n",
        "        return 0\n",
        "    if not schedule:\n",
        "        return None\n",
        "\n",
        "    # Annealing phase (steps_per_beta steps per level), then hold at the final beta\n",
        "    # until the total budget is spent.\n",
        "    step_betas = itertools.islice(\n",
        "        itertools.chain(\n",
        "            (b for b in schedule for _ in range(steps_per_beta)),\n",
        "            itertools.repeat(schedule[-1]),\n",
        "        ),\n",
        "        max(0, int(max_total_steps)),\n",
        "    )\n",
        "\n",
        "    current_beta = None\n",
        "    try:\n",
        "        for steps, b in enumerate(step_betas, start=1):\n",
        "            if b != current_beta:\n",
        "                sampler.beta = current_beta = b\n",
        "            x = sampler.single_step(x)\n",
        "            if enforce_pm1:\n",
        "                # torch.sign maps 0 -> 0, so map x >= 0 -> +1 and x < 0 -> -1 explicitly.\n",
        "                x = (x >= 0).to(x.dtype) * 2 - 1\n",
        "            if on_target(x):\n",
        "                return steps\n",
        "        return None\n",
        "    finally:\n",
        "        # Don't leak the last annealing level into later users of the same sampler.\n",
        "        sampler.beta = beta_orig\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "PZpQUR6DU6XX"
      },
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def run_gwg_trials(model, target10, d_indicator, d_total, device,\n",
        "                   n_trials=200, max_steps=10000, beta=1.0, verbose_every=50,\n",
        "                   bootstrap_B=2000, rng_seed=0, use_annealing=True):\n",
        "    \"\"\"\n",
        "    Run GWG first-hit experiments and report statistics with 2 SD confidence intervals.\n",
        "\n",
        "    Every trial calls ``first_hit_steps_annealed`` (or ``first_hit_steps`` when\n",
        "    ``use_annealing=False``) with one ``GWGSampler(model, beta=beta)`` shared by all\n",
        "    trials. Results are printed; the function returns None.\n",
        "\n",
        "    - Unsuccessful trials are counted as max_steps for 'ALL trials' aggregates\n",
        "      (they are right-censored at max_steps).\n",
        "    - Medians use bootstrap to estimate the standard error, then ± 2*SE for the CI.\n",
        "    - 'Hit probability within budgets' counts only actual hits, so a miss is never\n",
        "      reported as a hit within a budget >= max_steps.\n",
        "\n",
        "    Args:\n",
        "        model: model handed to ``GWGSampler``; switched to eval mode.\n",
        "        target10: target values for the first ``d_indicator`` coordinates.\n",
        "        d_indicator: number of leading coordinates compared with ``target10``.\n",
        "        d_total: length of each chain state.\n",
        "        device: device passed to the first-hit routine.\n",
        "        n_trials: number of independent trials.\n",
        "        max_steps: per-trial step budget.\n",
        "        beta: ``beta`` passed to ``GWGSampler``.\n",
        "        verbose_every: print a progress line every this many trials (0/None = silent).\n",
        "        bootstrap_B: bootstrap resamples for the median CIs.\n",
        "        rng_seed: bootstrap seed (the HITS median uses rng_seed + 1).\n",
        "        use_annealing: annealed (default) or fixed-beta first-hit routine.\n",
        "    \"\"\"\n",
        "    import collections\n",
        "    import math\n",
        "\n",
        "    import numpy as np\n",
        "\n",
        "    model.eval()\n",
        "    sampler = GWGSampler(model, beta=beta)\n",
        "    nan = float(\"nan\")\n",
        "\n",
        "    # ---------------------------\n",
        "    # Helpers\n",
        "    # ---------------------------\n",
        "    def ci_mean_2sd(x):\n",
        "        \"\"\"Mean ± 2 SD/√n CI (all NaN for empty input).\"\"\"\n",
        "        x = np.asarray(x, dtype=np.float64)\n",
        "        n = len(x)\n",
        "        if n == 0:\n",
        "            return nan, nan, (nan, nan)\n",
        "        mu = float(x.mean())\n",
        "        sd = float(x.std(ddof=1)) if n > 1 else 0.0\n",
        "        se = sd / math.sqrt(n)\n",
        "        return mu, sd, (mu - 2*se, mu + 2*se)\n",
        "\n",
        "    def ci_prop_2sd(k, n):\n",
        "        \"\"\"Proportion ± 2 SD (binomial SE), clipped to [0, 1]; all NaN when n <= 0.\"\"\"\n",
        "        if n <= 0:\n",
        "            return nan, nan, (nan, nan)\n",
        "        p = k / n\n",
        "        se = math.sqrt(p * (1 - p) / n)\n",
        "        return p, se, (max(0.0, p - 2*se), min(1.0, p + 2*se))\n",
        "\n",
        "    def ci_median_bootstrap_2sd(x, B=2000, seed=0):\n",
        "        \"\"\"Median and ± 2 SD bootstrap CI.\"\"\"\n",
        "        x = np.asarray(x, dtype=np.float64)\n",
        "        n = len(x)\n",
        "        if n == 0:\n",
        "            return nan, nan, (nan, nan)\n",
        "        if n == 1:\n",
        "            med = float(x[0])\n",
        "            return med, 0.0, (med, med)\n",
        "        rng = np.random.default_rng(seed)\n",
        "        med = float(np.median(x))\n",
        "        meds = np.empty(B, dtype=np.float64)\n",
        "        idx = np.arange(n)\n",
        "        for b in range(B):\n",
        "            resample = x[rng.choice(idx, size=n, replace=True)]\n",
        "            meds[b] = np.median(resample)\n",
        "        # ddof=1 needs at least two resamples.\n",
        "        sd = float(meds.std(ddof=1)) if B > 1 else 0.0\n",
        "        return med, sd, (med - 2*sd, med + 2*sd)\n",
        "\n",
        "    def robust_mad(x):\n",
        "        \"\"\"Median Absolute Deviation (MAD), unscaled.\"\"\"\n",
        "        x = np.asarray(x, dtype=np.float64)\n",
        "        if len(x) == 0:\n",
        "            return nan\n",
        "        med = np.median(x)\n",
        "        return float(np.median(np.abs(x - med)))\n",
        "\n",
        "    def q(arr, pct):\n",
        "        \"\"\"pct-th percentile of arr (NaN for empty input).\"\"\"\n",
        "        return float(np.percentile(arr, pct)) if len(arr) else nan\n",
        "\n",
        "    # ---------------------------\n",
        "    # Trials\n",
        "    # ---------------------------\n",
        "    hits_only = []    # steps for successful trials\n",
        "    all_steps = []    # steps for all trials (misses counted as max_steps)\n",
        "    misses = 0\n",
        "\n",
        "    for i in range(1, n_trials + 1):\n",
        "        if use_annealing:\n",
        "            steps = first_hit_steps_annealed(sampler, target10, d_indicator, d_total, device, max_total_steps=max_steps)\n",
        "        else:\n",
        "            steps = first_hit_steps(sampler, target10, d_indicator, d_total, device, max_steps)\n",
        "\n",
        "        if steps is None:\n",
        "            misses += 1\n",
        "            all_steps.append(max_steps)\n",
        "            last_str = \"miss\"\n",
        "        else:\n",
        "            s = int(steps)\n",
        "            hits_only.append(s)\n",
        "            all_steps.append(s)\n",
        "            last_str = str(s)\n",
        "\n",
        "        if verbose_every and (i % verbose_every == 0 or i == n_trials):\n",
        "            hit_rate = (i - misses) / i\n",
        "            print(f\"[trial {i:4d}] last={last_str} | hits={i - misses} | misses={misses} | hit_rate={hit_rate:.3f}\")\n",
        "\n",
        "    # Convert to numpy\n",
        "    arr_all  = np.array(all_steps, dtype=np.int64)\n",
        "    arr_hits = np.array(hits_only, dtype=np.int64)\n",
        "\n",
        "    # ---------------------------\n",
        "    # Core statistics with 2 SD CIs\n",
        "    # ---------------------------\n",
        "    print(\"\\n=== GWG First-Hit Statistics with 2 SD Confidence Intervals ===\")\n",
        "    print(f\"trials={n_trials} | hits={n_trials - misses} | misses={misses} | miss_penalty=max_steps({max_steps})\")\n",
        "\n",
        "    # Hit rate CI (binomial)\n",
        "    p, p_se, (p_lo, p_hi) = ci_prop_2sd(n_trials - misses, n_trials)\n",
        "    print(f\"Hit rate              : {p:.4f}  (±2SD CI: [{p_lo:.4f}, {p_hi:.4f}])  | SE≈{p_se:.4f}\")\n",
        "\n",
        "    # Mean (ALL trials)\n",
        "    mean_all, sd_all, (lo_all, hi_all) = ci_mean_2sd(arr_all)\n",
        "    print(f\"Mean steps (ALL)      : {mean_all:.2f}  (±2SD CI: [{lo_all:.2f}, {hi_all:.2f}])  | SD={sd_all:.2f}\")\n",
        "\n",
        "    # Median (ALL trials, misses=max_steps)\n",
        "    med_all, med_all_se_boot, (med_all_lo, med_all_hi) = ci_median_bootstrap_2sd(\n",
        "        arr_all, B=bootstrap_B, seed=rng_seed\n",
        "    )\n",
        "    print(f\"Median steps (ALL)    : {med_all:.2f}  (±2SD boot CI: [{med_all_lo:.2f}, {med_all_hi:.2f}])  | boot SD≈{med_all_se_boot:.2f}\")\n",
        "\n",
        "    # Median (HITS only)\n",
        "    if len(arr_hits) > 0:\n",
        "        med_hits, med_hits_se_boot, (med_hits_lo, med_hits_hi) = ci_median_bootstrap_2sd(\n",
        "            arr_hits, B=bootstrap_B, seed=rng_seed + 1\n",
        "        )\n",
        "        print(f\"Median steps (HITS)   : {med_hits:.2f}  (±2SD boot CI: [{med_hits_lo:.2f}, {med_hits_hi:.2f}])  | boot SD≈{med_hits_se_boot:.2f}\")\n",
        "    else:\n",
        "        print(\"Median steps (HITS)   : n/a (no successful trials)\")\n",
        "\n",
        "    # ---------------------------\n",
        "    # Additional useful stats\n",
        "    # ---------------------------\n",
        "    if len(arr_all):\n",
        "        print(\"\\n-- Distribution (ALL trials) --\")\n",
        "        print(f\"min / p25 / p50 / p75 / max : {arr_all.min():.0f} / {q(arr_all,25):.0f} / {q(arr_all,50):.0f} / {q(arr_all,75):.0f} / {arr_all.max():.0f}\")\n",
        "        print(f\"IQR (p75-p25)         : {q(arr_all,75) - q(arr_all,25):.2f}\")\n",
        "        print(f\"MAD (about median)    : {robust_mad(arr_all):.2f}\")\n",
        "        for pct in (90, 95, 99):\n",
        "            print(f\"p{pct:02d}                 : {q(arr_all, pct):.2f}\")\n",
        "\n",
        "        # Probability of a hit within selected step budgets. Only real hits count:\n",
        "        # misses are stored as max_steps in arr_all, so (arr_all <= T) would report\n",
        "        # every miss as a hit for T >= max_steps. Budgets above max_steps are dropped\n",
        "        # because nothing was observed beyond the per-trial budget.\n",
        "        budgets = sorted({T for T in (100, 500, 1000, 5000) if T < max_steps} | {max_steps})\n",
        "        print(\"\\nHit probability within budgets (ALL trials):\")\n",
        "        for T in budgets:\n",
        "            pr = float((arr_hits <= T).sum()) / n_trials\n",
        "            print(f\"  ≤ {T:6d} steps : {pr:.4f}\")\n",
        "\n",
        "    if len(arr_hits):\n",
        "        print(\"\\n-- Distribution (successful trials ONLY) --\")\n",
        "        print(f\"min / p25 / p50 / p75 / max : {arr_hits.min():.0f} / {q(arr_hits,25):.0f} / {q(arr_hits,50):.0f} / {q(arr_hits,75):.0f} / {arr_hits.max():.0f}\")\n",
        "        print(f\"IQR (p75-p25)         : {q(arr_hits,75) - q(arr_hits,25):.2f}\")\n",
        "        print(f\"MAD (about median)    : {robust_mad(arr_hits):.2f}\")\n",
        "        for pct in (90, 95, 99):\n",
        "            print(f\"p{pct:02d}                 : {q(arr_hits, pct):.2f}\")\n",
        "\n",
        "    # Compact histogram (ALL trials). Shows the censoring spike at max_steps if many misses.\n",
        "    # No try/except: nothing here is expected to fail, and real errors should surface.\n",
        "    hist = collections.Counter(arr_all.tolist())\n",
        "    most_common = sorted(hist.items(), key=lambda kv: (-kv[1], kv[0]))[:20]\n",
        "    print(\"\\nTop (step,count) bins (ALL trials, 20 most common):\")\n",
        "    for step, cnt in most_common:\n",
        "        print(f\"  {step:7d} : {cnt}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-Rd2cBBmRjeF",
        "outputId": "7574a178-e8e4-4dcb-c4f4-102d71eec07e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trial    5] last=miss | hits=0 | misses=5 | hit_rate=0.000\n",
            "[trial   10] last=miss | hits=0 | misses=10 | hit_rate=0.000\n",
            "[trial   15] last=212 | hits=1 | misses=14 | hit_rate=0.067\n",
            "[trial   20] last=miss | hits=1 | misses=19 | hit_rate=0.050\n",
            "[trial   25] last=miss | hits=1 | misses=24 | hit_rate=0.040\n",
            "[trial   30] last=miss | hits=1 | misses=29 | hit_rate=0.033\n",
            "[trial   35] last=miss | hits=1 | misses=34 | hit_rate=0.029\n",
            "[trial   40] last=miss | hits=2 | misses=38 | hit_rate=0.050\n",
            "[trial   45] last=miss | hits=2 | misses=43 | hit_rate=0.044\n",
            "[trial   50] last=miss | hits=2 | misses=48 | hit_rate=0.040\n",
            "[trial   55] last=miss | hits=2 | misses=53 | hit_rate=0.036\n",
            "[trial   60] last=miss | hits=2 | misses=58 | hit_rate=0.033\n",
            "[trial   65] last=miss | hits=2 | misses=63 | hit_rate=0.031\n",
            "[trial   70] last=miss | hits=2 | misses=68 | hit_rate=0.029\n",
            "[trial   75] last=miss | hits=4 | misses=71 | hit_rate=0.053\n",
            "[trial   80] last=miss | hits=4 | misses=76 | hit_rate=0.050\n",
            "[trial   85] last=miss | hits=4 | misses=81 | hit_rate=0.047\n",
            "[trial   90] last=269 | hits=6 | misses=84 | hit_rate=0.067\n",
            "[trial   95] last=miss | hits=6 | misses=89 | hit_rate=0.063\n",
            "[trial  100] last=miss | hits=6 | misses=94 | hit_rate=0.060\n",
            "[trial  105] last=miss | hits=6 | misses=99 | hit_rate=0.057\n",
            "[trial  110] last=miss | hits=6 | misses=104 | hit_rate=0.055\n",
            "[trial  115] last=miss | hits=6 | misses=109 | hit_rate=0.052\n",
            "[trial  120] last=miss | hits=6 | misses=114 | hit_rate=0.050\n",
            "[trial  125] last=63 | hits=7 | misses=118 | hit_rate=0.056\n",
            "[trial  130] last=miss | hits=7 | misses=123 | hit_rate=0.054\n",
            "[trial  135] last=miss | hits=7 | misses=128 | hit_rate=0.052\n",
            "[trial  140] last=miss | hits=7 | misses=133 | hit_rate=0.050\n",
            "[trial  145] last=miss | hits=7 | misses=138 | hit_rate=0.048\n",
            "[trial  150] last=miss | hits=7 | misses=143 | hit_rate=0.047\n",
            "[trial  155] last=miss | hits=7 | misses=148 | hit_rate=0.045\n",
            "[trial  160] last=miss | hits=7 | misses=153 | hit_rate=0.044\n",
            "[trial  165] last=miss | hits=7 | misses=158 | hit_rate=0.042\n",
            "[trial  170] last=miss | hits=7 | misses=163 | hit_rate=0.041\n",
            "[trial  175] last=miss | hits=7 | misses=168 | hit_rate=0.040\n",
            "[trial  180] last=miss | hits=7 | misses=173 | hit_rate=0.039\n",
            "[trial  185] last=miss | hits=7 | misses=178 | hit_rate=0.038\n",
            "[trial  190] last=miss | hits=7 | misses=183 | hit_rate=0.037\n",
            "[trial  195] last=miss | hits=7 | misses=188 | hit_rate=0.036\n",
            "[trial  200] last=miss | hits=7 | misses=193 | hit_rate=0.035\n",
            "\n",
            "=== GWG First-Hit Statistics with 2 SD Confidence Intervals ===\n",
            "trials=200 | hits=7 | misses=193 | miss_penalty=max_steps(2000)\n",
            "Hit rate              : 0.0350  (±2SD CI: [0.0090, 0.0610])  | SE≈0.0130\n",
            "Mean steps (ALL)      : 1936.22  (±2SD CI: [1888.57, 1983.87])  | SD=336.91\n",
            "Median steps (ALL)    : 2000.00  (±2SD boot CI: [2000.00, 2000.00])  | boot SD≈0.00\n",
            "Median steps (HITS)   : 183.00  (±2SD boot CI: [19.45, 346.55])  | boot SD≈81.78\n",
            "\n",
            "-- Distribution (ALL trials) --\n",
            "min / p25 / p50 / p75 / max : 0 / 2000 / 2000 / 2000 / 2000\n",
            "IQR (p75-p25)         : 0.00\n",
            "MAD (about median)    : 0.00\n",
            "p90                 : 2000.00\n",
            "p95                 : 2000.00\n",
            "p99                 : 2000.00\n",
            "\n",
            "Hit probability within budgets (ALL trials):\n",
            "  ≤    100 steps : 0.0150\n",
            "  ≤    500 steps : 0.0350\n",
            "  ≤   1000 steps : 0.0350\n",
            "  ≤   2000 steps : 1.0000\n",
            "  ≤   5000 steps : 1.0000\n",
            "\n",
            "-- Distribution (successful trials ONLY) --\n",
            "min / p25 / p50 / p75 / max : 0 / 55 / 183 / 240 / 470\n",
            "IQR (p75-p25)         : 185.50\n",
            "MAD (about median)    : 120.00\n",
            "p90                 : 349.40\n",
            "p95                 : 409.70\n",
            "p99                 : 457.94\n",
            "\n",
            "Top (step,count) bins (ALL trials, 20 most common):\n",
            "     2000 : 193\n",
            "        0 : 1\n",
            "       47 : 1\n",
            "       63 : 1\n",
            "      183 : 1\n",
            "      212 : 1\n",
            "      269 : 1\n",
            "      470 : 1\n"
          ]
        }
      ],
      "source": [
        "# Target: all +1 on the d_indicator leading coordinates (sized from d_indicator, not a\n",
        "# literal 8, and created on the same device object that is passed to run_gwg_trials).\n",
        "gwg_device = torch.device(\"cuda\")\n",
        "run_gwg_trials(\n",
        "    model=model,\n",
        "    target10=torch.ones(d_indicator, device=gwg_device),\n",
        "    d_indicator=d_indicator,\n",
        "    d_total=d_total,\n",
        "    device=gwg_device,\n",
        "    n_trials=200,\n",
        "    max_steps=2000,\n",
        "    beta=100.0,\n",
        "    verbose_every=5\n",
        ")\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "torch-gpu-env",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.13.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
