{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "target10 (fixed to all +1s): [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
            "Train size: 67584 | base reps/anchor: 64 | target extra reps: 2048\n",
            "[ep     1] loss=4.416933e+01 | mean(pred on target rows)≈-0.330 | max|pred| on non-target rows≈3.089\n",
            "[ep    50] loss=7.900964e-01 | mean(pred on target rows)≈9.022 | max|pred| on non-target rows≈3.788\n",
            "[ep   100] loss=6.922771e-02 | mean(pred on target rows)≈8.791 | max|pred| on non-target rows≈1.476\n",
            "[ep   150] loss=3.400241e-02 | mean(pred on target rows)≈9.085 | max|pred| on non-target rows≈1.424\n",
            "[ep   200] loss=1.232355e-02 | mean(pred on target rows)≈8.985 | max|pred| on non-target rows≈1.327\n",
            "[ep   250] loss=1.372893e-02 | mean(pred on target rows)≈9.078 | max|pred| on non-target rows≈1.198\n",
            "[ep   300] loss=8.899807e-03 | mean(pred on target rows)≈8.937 | max|pred| on non-target rows≈1.139\n",
            "[ep   350] loss=6.513441e-03 | mean(pred on target rows)≈9.055 | max|pred| on non-target rows≈1.083\n",
            "[ep   400] loss=8.706742e-02 | mean(pred on target rows)≈8.962 | max|pred| on non-target rows≈1.101\n",
            "[ep   450] loss=3.303090e-03 | mean(pred on target rows)≈8.991 | max|pred| on non-target rows≈1.106\n",
            "[ep   500] loss=2.015171e-03 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈1.077\n",
            "[ep   550] loss=1.886434e-03 | mean(pred on target rows)≈8.971 | max|pred| on non-target rows≈1.067\n",
            "[ep   600] loss=2.176045e-03 | mean(pred on target rows)≈9.043 | max|pred| on non-target rows≈1.027\n",
            "[ep   650] loss=1.471363e-01 | mean(pred on target rows)≈8.942 | max|pred| on non-target rows≈1.037\n",
            "[ep   700] loss=1.543127e-03 | mean(pred on target rows)≈8.991 | max|pred| on non-target rows≈1.014\n",
            "[ep   750] loss=7.914665e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.975\n",
            "[ep   800] loss=1.204887e-03 | mean(pred on target rows)≈9.031 | max|pred| on non-target rows≈0.962\n",
            "[ep   850] loss=1.934342e-03 | mean(pred on target rows)≈9.051 | max|pred| on non-target rows≈0.962\n",
            "[ep   900] loss=1.122819e-03 | mean(pred on target rows)≈9.035 | max|pred| on non-target rows≈0.966\n",
            "[ep   950] loss=5.506509e-03 | mean(pred on target rows)≈9.098 | max|pred| on non-target rows≈0.959\n",
            "[ep  1000] loss=3.480601e-03 | mean(pred on target rows)≈9.019 | max|pred| on non-target rows≈0.961\n",
            "[ep  1050] loss=3.756034e-04 | mean(pred on target rows)≈9.002 | max|pred| on non-target rows≈0.962\n",
            "[ep  1100] loss=1.003567e-03 | mean(pred on target rows)≈9.036 | max|pred| on non-target rows≈0.954\n",
            "[ep  1150] loss=1.703260e-03 | mean(pred on target rows)≈9.052 | max|pred| on non-target rows≈0.945\n",
            "[ep  1200] loss=1.452987e-01 | mean(pred on target rows)≈8.963 | max|pred| on non-target rows≈1.125\n",
            "[ep  1250] loss=1.617175e-02 | mean(pred on target rows)≈8.840 | max|pred| on non-target rows≈1.080\n",
            "[ep  1300] loss=1.507910e-02 | mean(pred on target rows)≈9.166 | max|pred| on non-target rows≈0.926\n",
            "[ep  1350] loss=4.801654e-04 | mean(pred on target rows)≈8.983 | max|pred| on non-target rows≈0.970\n",
            "[ep  1400] loss=2.177609e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.953\n",
            "[ep  1450] loss=2.073559e-04 | mean(pred on target rows)≈8.993 | max|pred| on non-target rows≈0.950\n",
            "[ep  1500] loss=2.724839e-04 | mean(pred on target rows)≈8.985 | max|pred| on non-target rows≈0.943\n",
            "[ep  1550] loss=9.779178e-04 | mean(pred on target rows)≈9.038 | max|pred| on non-target rows≈0.927\n",
            "[ep  1600] loss=3.619821e-04 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.954\n",
            "[ep  1650] loss=1.914377e-03 | mean(pred on target rows)≈8.944 | max|pred| on non-target rows≈0.974\n",
            "[ep  1700] loss=2.525964e-03 | mean(pred on target rows)≈8.939 | max|pred| on non-target rows≈0.972\n",
            "[ep  1750] loss=5.486442e-04 | mean(pred on target rows)≈8.983 | max|pred| on non-target rows≈0.942\n",
            "[ep  1800] loss=5.171351e-04 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.940\n",
            "[ep  1850] loss=7.997073e-03 | mean(pred on target rows)≈9.065 | max|pred| on non-target rows≈0.960\n",
            "[ep  1900] loss=5.280577e-04 | mean(pred on target rows)≈8.975 | max|pred| on non-target rows≈0.953\n",
            "[ep  1950] loss=6.431603e-03 | mean(pred on target rows)≈8.985 | max|pred| on non-target rows≈1.012\n",
            "[ep  2000] loss=1.456296e-03 | mean(pred on target rows)≈9.048 | max|pred| on non-target rows≈0.916\n",
            "[ep  2050] loss=1.208424e-03 | mean(pred on target rows)≈9.024 | max|pred| on non-target rows≈0.949\n",
            "[ep  2100] loss=1.186362e-03 | mean(pred on target rows)≈9.040 | max|pred| on non-target rows≈0.919\n",
            "[ep  2150] loss=2.712248e-04 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.951\n",
            "[ep  2200] loss=2.608391e-03 | mean(pred on target rows)≈8.981 | max|pred| on non-target rows≈0.976\n",
            "[ep  2250] loss=1.984302e-03 | mean(pred on target rows)≈8.997 | max|pred| on non-target rows≈0.982\n",
            "[ep  2300] loss=1.602049e-03 | mean(pred on target rows)≈8.947 | max|pred| on non-target rows≈0.984\n",
            "[ep  2350] loss=1.526327e-03 | mean(pred on target rows)≈8.983 | max|pred| on non-target rows≈0.971\n",
            "[ep  2400] loss=3.088474e-03 | mean(pred on target rows)≈9.068 | max|pred| on non-target rows≈0.912\n",
            "[ep  2450] loss=2.677664e-03 | mean(pred on target rows)≈8.987 | max|pred| on non-target rows≈0.979\n",
            "[ep  2500] loss=3.395333e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.950\n",
            "[ep  2550] loss=3.833847e-03 | mean(pred on target rows)≈9.011 | max|pred| on non-target rows≈0.972\n",
            "[ep  2600] loss=9.171930e-04 | mean(pred on target rows)≈9.023 | max|pred| on non-target rows≈0.911\n",
            "[ep  2650] loss=1.865930e-02 | mean(pred on target rows)≈8.970 | max|pred| on non-target rows≈1.095\n",
            "[ep  2700] loss=2.882018e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈1.018\n",
            "[ep  2750] loss=1.775996e-04 | mean(pred on target rows)≈9.010 | max|pred| on non-target rows≈0.952\n",
            "[ep  2800] loss=1.140738e-04 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.954\n",
            "[ep  2850] loss=2.000526e-04 | mean(pred on target rows)≈9.014 | max|pred| on non-target rows≈0.950\n",
            "[ep  2900] loss=9.164588e-04 | mean(pred on target rows)≈8.959 | max|pred| on non-target rows≈0.971\n",
            "[ep  2950] loss=4.898996e-04 | mean(pred on target rows)≈8.971 | max|pred| on non-target rows≈0.966\n",
            "[ep  3000] loss=1.579821e-04 | mean(pred on target rows)≈8.986 | max|pred| on non-target rows≈0.968\n",
            "[ep  3050] loss=8.959741e-04 | mean(pred on target rows)≈8.959 | max|pred| on non-target rows≈0.978\n",
            "[ep  3100] loss=7.573331e-05 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.970\n",
            "[ep  3150] loss=1.004277e-03 | mean(pred on target rows)≈9.042 | max|pred| on non-target rows≈0.946\n",
            "[ep  3200] loss=3.942654e-04 | mean(pred on target rows)≈8.991 | max|pred| on non-target rows≈0.967\n",
            "[ep  3250] loss=3.424034e-04 | mean(pred on target rows)≈9.019 | max|pred| on non-target rows≈0.946\n",
            "[ep  3300] loss=1.421624e-03 | mean(pred on target rows)≈9.044 | max|pred| on non-target rows≈0.943\n",
            "[ep  3350] loss=8.062765e-04 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.953\n",
            "[ep  3400] loss=6.395763e-04 | mean(pred on target rows)≈8.973 | max|pred| on non-target rows≈0.976\n",
            "[ep  3450] loss=3.617220e-03 | mean(pred on target rows)≈8.950 | max|pred| on non-target rows≈1.005\n",
            "[ep  3500] loss=1.333031e-03 | mean(pred on target rows)≈9.042 | max|pred| on non-target rows≈0.928\n",
            "[ep  3550] loss=7.296182e-04 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.953\n",
            "[ep  3600] loss=5.011030e-03 | mean(pred on target rows)≈8.983 | max|pred| on non-target rows≈0.995\n",
            "[ep  3650] loss=1.659220e-03 | mean(pred on target rows)≈9.054 | max|pred| on non-target rows≈0.935\n",
            "[ep  3700] loss=4.294365e-04 | mean(pred on target rows)≈9.006 | max|pred| on non-target rows≈0.951\n",
            "[ep  3750] loss=2.023071e-03 | mean(pred on target rows)≈8.954 | max|pred| on non-target rows≈0.986\n",
            "[ep  3800] loss=2.566106e-03 | mean(pred on target rows)≈8.959 | max|pred| on non-target rows≈0.984\n",
            "[ep  3850] loss=1.031285e-03 | mean(pred on target rows)≈9.035 | max|pred| on non-target rows≈0.944\n",
            "[ep  3900] loss=8.698078e-04 | mean(pred on target rows)≈9.037 | max|pred| on non-target rows≈0.926\n",
            "[ep  3950] loss=7.876786e-04 | mean(pred on target rows)≈9.032 | max|pred| on non-target rows≈0.924\n",
            "[ep  4000] loss=3.166648e-04 | mean(pred on target rows)≈8.985 | max|pred| on non-target rows≈0.947\n",
            "[ep  4050] loss=5.025191e-03 | mean(pred on target rows)≈9.045 | max|pred| on non-target rows≈0.964\n",
            "[ep  4100] loss=5.858783e-04 | mean(pred on target rows)≈9.028 | max|pred| on non-target rows≈0.922\n",
            "[ep  4150] loss=1.851752e-02 | mean(pred on target rows)≈9.024 | max|pred| on non-target rows≈0.974\n",
            "[ep  4200] loss=2.311000e-02 | mean(pred on target rows)≈8.927 | max|pred| on non-target rows≈1.057\n",
            "[ep  4250] loss=2.178975e-04 | mean(pred on target rows)≈9.007 | max|pred| on non-target rows≈0.968\n",
            "[ep  4300] loss=1.056896e-04 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.963\n",
            "[ep  4350] loss=1.147998e-04 | mean(pred on target rows)≈9.011 | max|pred| on non-target rows≈0.961\n",
            "[ep  4400] loss=4.948626e-04 | mean(pred on target rows)≈8.970 | max|pred| on non-target rows≈0.967\n",
            "[ep  4450] loss=2.376310e-03 | mean(pred on target rows)≈9.067 | max|pred| on non-target rows≈0.951\n",
            "[ep  4500] loss=5.534476e-04 | mean(pred on target rows)≈9.031 | max|pred| on non-target rows≈0.947\n",
            "[ep  4550] loss=9.840734e-04 | mean(pred on target rows)≈8.957 | max|pred| on non-target rows≈0.971\n",
            "[ep  4600] loss=1.233132e-04 | mean(pred on target rows)≈8.986 | max|pred| on non-target rows≈0.963\n",
            "[ep  4650] loss=9.788412e-03 | mean(pred on target rows)≈9.133 | max|pred| on non-target rows≈0.924\n",
            "[ep  4700] loss=1.371340e-02 | mean(pred on target rows)≈9.141 | max|pred| on non-target rows≈0.974\n",
            "[ep  4750] loss=1.044183e-02 | mean(pred on target rows)≈8.857 | max|pred| on non-target rows≈0.992\n",
            "[ep  4800] loss=6.098608e-03 | mean(pred on target rows)≈8.891 | max|pred| on non-target rows≈0.973\n",
            "[ep  4850] loss=4.975115e-05 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.955\n",
            "[ep  4900] loss=1.111053e-04 | mean(pred on target rows)≈9.011 | max|pred| on non-target rows≈0.943\n",
            "[ep  4950] loss=4.967031e-05 | mean(pred on target rows)≈9.002 | max|pred| on non-target rows≈0.956\n",
            "[ep  5000] loss=7.031816e-05 | mean(pred on target rows)≈8.991 | max|pred| on non-target rows≈0.950\n",
            "[ep  5050] loss=9.618975e-04 | mean(pred on target rows)≈8.957 | max|pred| on non-target rows≈0.944\n",
            "[ep  5100] loss=9.001981e-04 | mean(pred on target rows)≈8.959 | max|pred| on non-target rows≈0.948\n",
            "[ep  5150] loss=2.674343e-02 | mean(pred on target rows)≈8.983 | max|pred| on non-target rows≈0.962\n",
            "[ep  5200] loss=2.735257e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.960\n",
            "[ep  5250] loss=3.194688e-05 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.954\n",
            "[ep  5300] loss=9.611073e-05 | mean(pred on target rows)≈8.988 | max|pred| on non-target rows≈0.953\n",
            "[ep  5350] loss=1.322961e-03 | mean(pred on target rows)≈9.050 | max|pred| on non-target rows≈0.941\n",
            "[ep  5400] loss=1.752504e-03 | mean(pred on target rows)≈9.057 | max|pred| on non-target rows≈0.937\n",
            "[ep  5450] loss=1.408918e-03 | mean(pred on target rows)≈8.948 | max|pred| on non-target rows≈0.957\n",
            "[ep  5500] loss=6.774802e-04 | mean(pred on target rows)≈9.035 | max|pred| on non-target rows≈0.936\n",
            "[ep  5550] loss=3.622251e-04 | mean(pred on target rows)≈8.974 | max|pred| on non-target rows≈0.947\n",
            "[ep  5600] loss=2.290581e-03 | mean(pred on target rows)≈9.067 | max|pred| on non-target rows≈0.935\n",
            "[ep  5650] loss=3.415723e-04 | mean(pred on target rows)≈8.983 | max|pred| on non-target rows≈0.936\n",
            "[ep  5700] loss=2.655341e-03 | mean(pred on target rows)≈9.060 | max|pred| on non-target rows≈0.932\n",
            "[ep  5750] loss=7.394962e-04 | mean(pred on target rows)≈8.964 | max|pred| on non-target rows≈0.948\n",
            "[ep  5800] loss=1.141440e-03 | mean(pred on target rows)≈9.042 | max|pred| on non-target rows≈0.934\n",
            "[ep  5850] loss=2.398955e-03 | mean(pred on target rows)≈8.980 | max|pred| on non-target rows≈0.947\n",
            "[ep  5900] loss=1.091088e-04 | mean(pred on target rows)≈8.992 | max|pred| on non-target rows≈0.940\n",
            "[ep  5950] loss=5.008347e-03 | mean(pred on target rows)≈8.939 | max|pred| on non-target rows≈0.948\n",
            "[ep  6000] loss=4.167325e-04 | mean(pred on target rows)≈9.019 | max|pred| on non-target rows≈0.930\n",
            "[ep  6050] loss=6.659036e-04 | mean(pred on target rows)≈8.983 | max|pred| on non-target rows≈0.940\n",
            "[ep  6100] loss=4.994103e-04 | mean(pred on target rows)≈8.986 | max|pred| on non-target rows≈0.948\n",
            "[ep  6150] loss=4.984710e-04 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.937\n",
            "[ep  6200] loss=1.292838e-04 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.938\n",
            "[ep  6250] loss=5.809196e-03 | mean(pred on target rows)≈9.079 | max|pred| on non-target rows≈0.891\n",
            "[ep  6300] loss=3.677041e-04 | mean(pred on target rows)≈8.975 | max|pred| on non-target rows≈0.944\n",
            "[ep  6350] loss=2.773752e-03 | mean(pred on target rows)≈8.926 | max|pred| on non-target rows≈0.940\n",
            "[ep  6400] loss=6.634459e-04 | mean(pred on target rows)≈8.973 | max|pred| on non-target rows≈0.937\n",
            "[ep  6450] loss=2.318917e-03 | mean(pred on target rows)≈8.944 | max|pred| on non-target rows≈0.946\n",
            "[ep  6500] loss=8.711651e-04 | mean(pred on target rows)≈8.968 | max|pred| on non-target rows≈0.938\n",
            "[ep  6550] loss=2.159866e-04 | mean(pred on target rows)≈9.014 | max|pred| on non-target rows≈0.931\n",
            "[ep  6600] loss=1.130465e-03 | mean(pred on target rows)≈9.028 | max|pred| on non-target rows≈0.929\n",
            "[ep  6650] loss=2.617855e-03 | mean(pred on target rows)≈9.058 | max|pred| on non-target rows≈0.932\n",
            "[ep  6700] loss=1.282995e-03 | mean(pred on target rows)≈9.039 | max|pred| on non-target rows≈0.921\n",
            "[ep  6750] loss=3.408135e-03 | mean(pred on target rows)≈8.921 | max|pred| on non-target rows≈0.935\n",
            "[ep  6800] loss=1.980252e-03 | mean(pred on target rows)≈8.946 | max|pred| on non-target rows≈0.943\n",
            "[ep  6850] loss=9.720554e-04 | mean(pred on target rows)≈9.019 | max|pred| on non-target rows≈0.931\n",
            "[ep  6900] loss=1.670072e-04 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.935\n",
            "[ep  6950] loss=1.258824e-03 | mean(pred on target rows)≈9.023 | max|pred| on non-target rows≈0.929\n",
            "[ep  7000] loss=1.724470e-03 | mean(pred on target rows)≈8.959 | max|pred| on non-target rows≈0.937\n",
            "[ep  7050] loss=1.214680e-03 | mean(pred on target rows)≈8.968 | max|pred| on non-target rows≈0.932\n",
            "[ep  7100] loss=1.223418e-03 | mean(pred on target rows)≈9.037 | max|pred| on non-target rows≈0.905\n",
            "[ep  7150] loss=1.341923e-03 | mean(pred on target rows)≈8.952 | max|pred| on non-target rows≈0.939\n",
            "[ep  7200] loss=1.033606e-03 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.923\n",
            "[ep  7250] loss=1.334500e-04 | mean(pred on target rows)≈9.007 | max|pred| on non-target rows≈0.919\n",
            "[ep  7300] loss=1.675419e-03 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.924\n",
            "[ep  7350] loss=1.234000e-03 | mean(pred on target rows)≈9.044 | max|pred| on non-target rows≈0.912\n",
            "[ep  7400] loss=3.097934e-04 | mean(pred on target rows)≈8.987 | max|pred| on non-target rows≈0.925\n",
            "[ep  7450] loss=7.678512e-04 | mean(pred on target rows)≈9.016 | max|pred| on non-target rows≈0.925\n",
            "[ep  7500] loss=2.021505e-03 | mean(pred on target rows)≈8.958 | max|pred| on non-target rows≈0.926\n",
            "[ep  7550] loss=1.179197e-03 | mean(pred on target rows)≈8.953 | max|pred| on non-target rows≈0.927\n",
            "[ep  7600] loss=3.006910e-03 | mean(pred on target rows)≈8.926 | max|pred| on non-target rows≈0.942\n",
            "[ep  7650] loss=4.085775e-03 | mean(pred on target rows)≈9.051 | max|pred| on non-target rows≈0.922\n",
            "[ep  7700] loss=1.435501e-04 | mean(pred on target rows)≈9.010 | max|pred| on non-target rows≈0.919\n",
            "[ep  7750] loss=1.726116e-03 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.926\n",
            "[ep  7800] loss=1.248382e-03 | mean(pred on target rows)≈9.013 | max|pred| on non-target rows≈0.920\n",
            "[ep  7850] loss=1.947120e-03 | mean(pred on target rows)≈8.944 | max|pred| on non-target rows≈0.934\n",
            "[ep  7900] loss=2.329007e-04 | mean(pred on target rows)≈9.004 | max|pred| on non-target rows≈0.914\n",
            "[ep  7950] loss=2.465242e-03 | mean(pred on target rows)≈8.967 | max|pred| on non-target rows≈0.917\n",
            "[ep  8000] loss=7.277077e-04 | mean(pred on target rows)≈8.964 | max|pred| on non-target rows≈0.930\n",
            "[ep  8050] loss=1.643037e-03 | mean(pred on target rows)≈9.005 | max|pred| on non-target rows≈0.922\n",
            "[ep  8100] loss=3.264337e-03 | mean(pred on target rows)≈9.071 | max|pred| on non-target rows≈0.905\n",
            "[ep  8150] loss=9.721139e-04 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.921\n",
            "[ep  8200] loss=2.442542e-04 | mean(pred on target rows)≈8.986 | max|pred| on non-target rows≈0.926\n",
            "[ep  8250] loss=2.562888e-03 | mean(pred on target rows)≈9.050 | max|pred| on non-target rows≈0.924\n",
            "[ep  8300] loss=6.474353e-04 | mean(pred on target rows)≈9.033 | max|pred| on non-target rows≈0.914\n",
            "[ep  8350] loss=2.080455e-03 | mean(pred on target rows)≈9.014 | max|pred| on non-target rows≈0.919\n",
            "[ep  8400] loss=2.452423e-03 | mean(pred on target rows)≈8.934 | max|pred| on non-target rows≈0.920\n",
            "[ep  8450] loss=1.358269e-03 | mean(pred on target rows)≈8.955 | max|pred| on non-target rows≈0.927\n",
            "[ep  8500] loss=1.306850e-03 | mean(pred on target rows)≈8.990 | max|pred| on non-target rows≈0.915\n",
            "[ep  8550] loss=2.228647e-03 | mean(pred on target rows)≈8.947 | max|pred| on non-target rows≈0.913\n",
            "[ep  8600] loss=1.078084e-03 | mean(pred on target rows)≈8.980 | max|pred| on non-target rows≈0.926\n",
            "[ep  8650] loss=1.364660e-03 | mean(pred on target rows)≈9.035 | max|pred| on non-target rows≈0.911\n",
            "[ep  8700] loss=4.409665e-04 | mean(pred on target rows)≈9.024 | max|pred| on non-target rows≈0.913\n",
            "[ep  8750] loss=2.362544e-03 | mean(pred on target rows)≈9.054 | max|pred| on non-target rows≈0.912\n",
            "[ep  8800] loss=1.006058e-03 | mean(pred on target rows)≈9.041 | max|pred| on non-target rows≈0.896\n",
            "[ep  8850] loss=4.260140e-03 | mean(pred on target rows)≈9.066 | max|pred| on non-target rows≈0.912\n",
            "[ep  8900] loss=5.720886e-04 | mean(pred on target rows)≈8.973 | max|pred| on non-target rows≈0.929\n",
            "[ep  8950] loss=1.890459e-02 | mean(pred on target rows)≈9.182 | max|pred| on non-target rows≈0.919\n",
            "[ep  9000] loss=1.804452e-03 | mean(pred on target rows)≈9.017 | max|pred| on non-target rows≈0.934\n",
            "[ep  9050] loss=2.835563e-05 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.925\n",
            "[ep  9100] loss=1.578462e-05 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.924\n",
            "[ep  9150] loss=1.737771e-05 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.926\n",
            "[ep  9200] loss=1.409187e-04 | mean(pred on target rows)≈9.016 | max|pred| on non-target rows≈0.915\n",
            "[ep  9250] loss=6.380696e-04 | mean(pred on target rows)≈8.965 | max|pred| on non-target rows≈0.916\n",
            "[ep  9300] loss=5.880576e-04 | mean(pred on target rows)≈9.032 | max|pred| on non-target rows≈0.904\n",
            "[ep  9350] loss=4.301565e-04 | mean(pred on target rows)≈9.027 | max|pred| on non-target rows≈0.910\n",
            "[ep  9400] loss=2.737942e-03 | mean(pred on target rows)≈9.072 | max|pred| on non-target rows≈0.924\n",
            "[ep  9450] loss=3.265098e-03 | mean(pred on target rows)≈9.013 | max|pred| on non-target rows≈0.921\n",
            "[ep  9500] loss=8.521393e-04 | mean(pred on target rows)≈9.038 | max|pred| on non-target rows≈0.925\n",
            "[ep  9550] loss=5.804035e-04 | mean(pred on target rows)≈8.981 | max|pred| on non-target rows≈0.919\n",
            "[ep  9600] loss=3.795269e-03 | mean(pred on target rows)≈9.056 | max|pred| on non-target rows≈0.922\n",
            "[ep  9650] loss=8.560790e-04 | mean(pred on target rows)≈9.031 | max|pred| on non-target rows≈0.924\n",
            "[ep  9700] loss=4.790685e-04 | mean(pred on target rows)≈8.981 | max|pred| on non-target rows≈0.917\n",
            "[ep  9750] loss=2.816758e-03 | mean(pred on target rows)≈8.976 | max|pred| on non-target rows≈0.919\n",
            "[ep  9800] loss=1.367937e-04 | mean(pred on target rows)≈8.988 | max|pred| on non-target rows≈0.919\n",
            "[ep  9850] loss=1.689831e-03 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.924\n",
            "[ep  9900] loss=2.326878e-03 | mean(pred on target rows)≈8.933 | max|pred| on non-target rows≈0.929\n",
            "[ep  9950] loss=8.522625e-04 | mean(pred on target rows)≈9.015 | max|pred| on non-target rows≈0.911\n",
            "[ep 10000] loss=3.005002e-04 | mean(pred on target rows)≈9.013 | max|pred| on non-target rows≈0.920\n",
            "[ep 10050] loss=3.886229e-05 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.926\n",
            "[ep 10100] loss=2.999654e-03 | mean(pred on target rows)≈9.060 | max|pred| on non-target rows≈0.918\n",
            "[ep 10150] loss=2.945838e-04 | mean(pred on target rows)≈9.022 | max|pred| on non-target rows≈0.920\n",
            "[ep 10200] loss=7.863001e-04 | mean(pred on target rows)≈8.985 | max|pred| on non-target rows≈0.921\n",
            "[ep 10250] loss=2.637446e-04 | mean(pred on target rows)≈9.006 | max|pred| on non-target rows≈0.919\n",
            "[ep 10300] loss=3.123604e-04 | mean(pred on target rows)≈9.020 | max|pred| on non-target rows≈0.895\n",
            "[ep 10350] loss=1.695999e-03 | mean(pred on target rows)≈8.961 | max|pred| on non-target rows≈0.920\n",
            "[ep 10400] loss=6.197824e-04 | mean(pred on target rows)≈9.032 | max|pred| on non-target rows≈0.916\n",
            "[ep 10450] loss=1.799577e-03 | mean(pred on target rows)≈9.055 | max|pred| on non-target rows≈0.916\n",
            "[ep 10500] loss=6.759157e-04 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.923\n",
            "[ep 10550] loss=8.728477e-04 | mean(pred on target rows)≈8.997 | max|pred| on non-target rows≈0.919\n",
            "[ep 10600] loss=3.243408e-03 | mean(pred on target rows)≈9.056 | max|pred| on non-target rows≈0.914\n",
            "[ep 10650] loss=4.370862e-04 | mean(pred on target rows)≈9.023 | max|pred| on non-target rows≈0.921\n",
            "[ep 10700] loss=4.502111e-04 | mean(pred on target rows)≈9.020 | max|pred| on non-target rows≈0.914\n",
            "[ep 10750] loss=6.624645e-04 | mean(pred on target rows)≈8.966 | max|pred| on non-target rows≈0.912\n",
            "[ep 10800] loss=1.902552e-04 | mean(pred on target rows)≈9.010 | max|pred| on non-target rows≈0.914\n",
            "[ep 10850] loss=1.374830e-03 | mean(pred on target rows)≈8.977 | max|pred| on non-target rows≈0.926\n",
            "[ep 10900] loss=1.572071e-03 | mean(pred on target rows)≈8.984 | max|pred| on non-target rows≈0.905\n",
            "[ep 10950] loss=5.284284e-04 | mean(pred on target rows)≈9.019 | max|pred| on non-target rows≈0.913\n",
            "[ep 11000] loss=6.302224e-04 | mean(pred on target rows)≈8.977 | max|pred| on non-target rows≈0.909\n",
            "[ep 11050] loss=3.419738e-03 | mean(pred on target rows)≈9.069 | max|pred| on non-target rows≈0.915\n",
            "[ep 11100] loss=1.415245e-03 | mean(pred on target rows)≈8.956 | max|pred| on non-target rows≈0.916\n",
            "[ep 11150] loss=3.507924e-03 | mean(pred on target rows)≈8.924 | max|pred| on non-target rows≈0.903\n",
            "[ep 11200] loss=6.107884e-04 | mean(pred on target rows)≈9.003 | max|pred| on non-target rows≈0.915\n",
            "[ep 11250] loss=2.624346e-04 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.913\n",
            "[ep 11300] loss=1.021198e-03 | mean(pred on target rows)≈9.042 | max|pred| on non-target rows≈0.918\n",
            "[ep 11350] loss=1.921019e-03 | mean(pred on target rows)≈8.950 | max|pred| on non-target rows≈0.910\n",
            "[ep 11400] loss=3.007203e-04 | mean(pred on target rows)≈9.015 | max|pred| on non-target rows≈0.914\n",
            "[ep 11450] loss=2.868534e-03 | mean(pred on target rows)≈9.042 | max|pred| on non-target rows≈0.920\n",
            "[ep 11500] loss=2.727680e-03 | mean(pred on target rows)≈8.930 | max|pred| on non-target rows≈0.923\n",
            "[ep 11550] loss=2.155588e-03 | mean(pred on target rows)≈9.047 | max|pred| on non-target rows≈0.892\n",
            "[ep 11600] loss=5.227803e-04 | mean(pred on target rows)≈9.025 | max|pred| on non-target rows≈0.916\n",
            "[ep 11650] loss=3.527683e-03 | mean(pred on target rows)≈9.045 | max|pred| on non-target rows≈0.916\n",
            "[ep 11700] loss=2.929885e-04 | mean(pred on target rows)≈9.018 | max|pred| on non-target rows≈0.921\n",
            "[ep 11750] loss=8.910505e-04 | mean(pred on target rows)≈8.979 | max|pred| on non-target rows≈0.929\n",
            "[ep 11800] loss=1.707004e-03 | mean(pred on target rows)≈9.017 | max|pred| on non-target rows≈0.917\n",
            "[ep 11850] loss=6.061649e-05 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.922\n",
            "[ep 11900] loss=1.008391e-03 | mean(pred on target rows)≈9.017 | max|pred| on non-target rows≈0.914\n",
            "[ep 11950] loss=3.314988e-04 | mean(pred on target rows)≈9.002 | max|pred| on non-target rows≈0.926\n",
            "[ep 12000] loss=9.098861e-04 | mean(pred on target rows)≈9.042 | max|pred| on non-target rows≈0.910\n",
            "[ep 12050] loss=5.484219e-04 | mean(pred on target rows)≈9.031 | max|pred| on non-target rows≈0.915\n",
            "[ep 12100] loss=2.132702e-03 | mean(pred on target rows)≈9.040 | max|pred| on non-target rows≈0.866\n",
            "[ep 12150] loss=5.131357e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.918\n",
            "[ep 12200] loss=2.616014e-04 | mean(pred on target rows)≈9.020 | max|pred| on non-target rows≈0.908\n",
            "[ep 12250] loss=1.117310e-03 | mean(pred on target rows)≈8.986 | max|pred| on non-target rows≈0.919\n",
            "[ep 12300] loss=1.072930e-04 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.919\n",
            "[ep 12350] loss=2.183047e-04 | mean(pred on target rows)≈8.993 | max|pred| on non-target rows≈0.917\n",
            "[ep 12400] loss=6.056508e-04 | mean(pred on target rows)≈9.027 | max|pred| on non-target rows≈0.915\n",
            "[ep 12450] loss=1.779560e-03 | mean(pred on target rows)≈9.026 | max|pred| on non-target rows≈0.911\n",
            "[ep 12500] loss=7.385428e-04 | mean(pred on target rows)≈9.035 | max|pred| on non-target rows≈0.914\n",
            "[ep 12550] loss=2.593266e-04 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.910\n",
            "[ep 12600] loss=2.762186e-03 | mean(pred on target rows)≈8.934 | max|pred| on non-target rows≈0.927\n",
            "[ep 12650] loss=5.488197e-04 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.918\n",
            "[ep 12700] loss=1.621298e-04 | mean(pred on target rows)≈8.996 | max|pred| on non-target rows≈0.913\n",
            "[ep 12750] loss=1.959149e-03 | mean(pred on target rows)≈9.053 | max|pred| on non-target rows≈0.921\n",
            "[ep 12800] loss=2.230512e-04 | mean(pred on target rows)≈8.997 | max|pred| on non-target rows≈0.918\n",
            "[ep 12850] loss=1.776013e-03 | mean(pred on target rows)≈8.990 | max|pred| on non-target rows≈0.913\n",
            "[ep 12900] loss=2.879730e-03 | mean(pred on target rows)≈9.070 | max|pred| on non-target rows≈0.886\n",
            "[ep 12950] loss=2.398524e-03 | mean(pred on target rows)≈9.046 | max|pred| on non-target rows≈0.911\n",
            "[ep 13000] loss=2.286255e-03 | mean(pred on target rows)≈8.935 | max|pred| on non-target rows≈0.920\n",
            "[ep 13050] loss=1.014034e-03 | mean(pred on target rows)≈9.018 | max|pred| on non-target rows≈0.926\n",
            "[ep 13100] loss=1.242122e-03 | mean(pred on target rows)≈9.041 | max|pred| on non-target rows≈0.907\n",
            "[ep 13150] loss=4.881790e-04 | mean(pred on target rows)≈9.024 | max|pred| on non-target rows≈0.916\n",
            "[ep 13200] loss=1.386550e-03 | mean(pred on target rows)≈8.959 | max|pred| on non-target rows≈0.925\n",
            "[ep 13250] loss=2.493688e-03 | mean(pred on target rows)≈9.023 | max|pred| on non-target rows≈0.914\n",
            "[ep 13300] loss=3.702824e-03 | mean(pred on target rows)≈8.992 | max|pred| on non-target rows≈0.934\n",
            "[ep 13350] loss=4.759451e-05 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.913\n",
            "[ep 13400] loss=1.429202e-05 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.916\n",
            "[ep 13450] loss=2.331691e-05 | mean(pred on target rows)≈9.004 | max|pred| on non-target rows≈0.916\n",
            "[ep 13500] loss=1.376106e-05 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.917\n",
            "[ep 13550] loss=2.038562e-04 | mean(pred on target rows)≈8.980 | max|pred| on non-target rows≈0.917\n",
            "[ep 13600] loss=3.092026e-04 | mean(pred on target rows)≈9.022 | max|pred| on non-target rows≈0.914\n",
            "[ep 13650] loss=5.202953e-04 | mean(pred on target rows)≈9.032 | max|pred| on non-target rows≈0.908\n",
            "[ep 13700] loss=1.803976e-04 | mean(pred on target rows)≈8.981 | max|pred| on non-target rows≈0.912\n",
            "[ep 13750] loss=9.409736e-04 | mean(pred on target rows)≈9.043 | max|pred| on non-target rows≈0.913\n",
            "[ep 13800] loss=1.981315e-02 | mean(pred on target rows)≈8.968 | max|pred| on non-target rows≈0.950\n",
            "[ep 13850] loss=2.011504e-04 | mean(pred on target rows)≈9.006 | max|pred| on non-target rows≈0.917\n",
            "[ep 13900] loss=5.208672e-05 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.913\n",
            "[ep 13950] loss=1.501157e-05 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.918\n",
            "[ep 14000] loss=2.724471e-05 | mean(pred on target rows)≈9.005 | max|pred| on non-target rows≈0.915\n",
            "[ep 14050] loss=5.136336e-05 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.913\n",
            "[ep 14100] loss=4.692603e-04 | mean(pred on target rows)≈9.029 | max|pred| on non-target rows≈0.913\n",
            "[ep 14150] loss=1.113387e-04 | mean(pred on target rows)≈8.986 | max|pred| on non-target rows≈0.906\n",
            "[ep 14200] loss=2.189357e-04 | mean(pred on target rows)≈9.020 | max|pred| on non-target rows≈0.916\n",
            "[ep 14250] loss=2.008564e-04 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.919\n",
            "[ep 14300] loss=4.025203e-04 | mean(pred on target rows)≈9.025 | max|pred| on non-target rows≈0.910\n",
            "[ep 14350] loss=1.637124e-04 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.910\n",
            "[ep 14400] loss=4.000585e-04 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.909\n",
            "[ep 14450] loss=3.676877e-03 | mean(pred on target rows)≈8.947 | max|pred| on non-target rows≈0.914\n",
            "[ep 14500] loss=3.067917e-04 | mean(pred on target rows)≈8.981 | max|pred| on non-target rows≈0.918\n",
            "[ep 14550] loss=5.980956e-04 | mean(pred on target rows)≈9.020 | max|pred| on non-target rows≈0.912\n",
            "[ep 14600] loss=7.817245e-04 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.918\n",
            "[ep 14650] loss=5.128122e-04 | mean(pred on target rows)≈8.969 | max|pred| on non-target rows≈0.917\n",
            "[ep 14700] loss=2.272710e-04 | mean(pred on target rows)≈8.980 | max|pred| on non-target rows≈0.915\n",
            "[ep 14750] loss=3.384815e-01 | mean(pred on target rows)≈9.365 | max|pred| on non-target rows≈1.952\n",
            "[ep 14800] loss=2.022925e-02 | mean(pred on target rows)≈8.845 | max|pred| on non-target rows≈0.870\n",
            "[ep 14850] loss=4.094461e-03 | mean(pred on target rows)≈9.088 | max|pred| on non-target rows≈0.912\n",
            "[ep 14900] loss=4.171502e-03 | mean(pred on target rows)≈8.909 | max|pred| on non-target rows≈0.930\n",
            "[ep 14950] loss=2.873375e-03 | mean(pred on target rows)≈9.075 | max|pred| on non-target rows≈0.916\n",
            "[ep 15000] loss=2.221169e-04 | mean(pred on target rows)≈9.020 | max|pred| on non-target rows≈0.911\n",
            "[ep 15050] loss=1.544702e-05 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.912\n",
            "[ep 15100] loss=1.415646e-05 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.912\n",
            "[ep 15150] loss=4.688461e-05 | mean(pred on target rows)≈9.007 | max|pred| on non-target rows≈0.911\n",
            "[ep 15200] loss=1.285152e-05 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.915\n",
            "[ep 15250] loss=2.015041e-05 | mean(pred on target rows)≈8.997 | max|pred| on non-target rows≈0.912\n",
            "[ep 15300] loss=2.732886e-05 | mean(pred on target rows)≈9.003 | max|pred| on non-target rows≈0.911\n",
            "[ep 15350] loss=1.865833e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.906\n",
            "[ep 15400] loss=5.247920e-05 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.912\n",
            "[ep 15450] loss=5.920939e-05 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.913\n",
            "[ep 15500] loss=4.387485e-04 | mean(pred on target rows)≈9.007 | max|pred| on non-target rows≈0.911\n",
            "[ep 15550] loss=1.018134e-03 | mean(pred on target rows)≈8.996 | max|pred| on non-target rows≈0.909\n",
            "[ep 15600] loss=6.222499e-04 | mean(pred on target rows)≈9.020 | max|pred| on non-target rows≈0.909\n",
            "[ep 15650] loss=3.267920e-04 | mean(pred on target rows)≈8.981 | max|pred| on non-target rows≈0.912\n",
            "[ep 15700] loss=3.331267e-03 | mean(pred on target rows)≈9.054 | max|pred| on non-target rows≈0.908\n",
            "[ep 15750] loss=3.096591e-03 | mean(pred on target rows)≈9.069 | max|pred| on non-target rows≈0.911\n",
            "[ep 15800] loss=5.849490e-04 | mean(pred on target rows)≈9.026 | max|pred| on non-target rows≈0.910\n",
            "[ep 15850] loss=1.238240e-03 | mean(pred on target rows)≈9.036 | max|pred| on non-target rows≈0.913\n",
            "[ep 15900] loss=1.727812e-03 | mean(pred on target rows)≈8.990 | max|pred| on non-target rows≈0.910\n",
            "[ep 15950] loss=1.960741e-03 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.912\n",
            "[ep 16000] loss=7.623526e-05 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.915\n",
            "[ep 16050] loss=2.345572e-03 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.885\n",
            "[ep 16100] loss=2.402594e-02 | mean(pred on target rows)≈9.210 | max|pred| on non-target rows≈0.875\n",
            "[ep 16150] loss=6.521598e-03 | mean(pred on target rows)≈9.104 | max|pred| on non-target rows≈0.917\n",
            "[ep 16200] loss=8.199354e-03 | mean(pred on target rows)≈8.873 | max|pred| on non-target rows≈0.929\n",
            "[ep 16250] loss=4.111303e-03 | mean(pred on target rows)≈9.089 | max|pred| on non-target rows≈0.907\n",
            "[ep 16300] loss=1.404431e-04 | mean(pred on target rows)≈8.984 | max|pred| on non-target rows≈0.923\n",
            "[ep 16350] loss=1.170951e-05 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.918\n",
            "[ep 16400] loss=2.036140e-04 | mean(pred on target rows)≈9.017 | max|pred| on non-target rows≈0.907\n",
            "[ep 16450] loss=3.602093e-05 | mean(pred on target rows)≈9.003 | max|pred| on non-target rows≈0.916\n",
            "[ep 16500] loss=4.341359e-05 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.913\n",
            "[ep 16550] loss=1.183512e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.909\n",
            "[ep 16600] loss=7.309388e-05 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.916\n",
            "[ep 16650] loss=8.753951e-05 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.918\n",
            "[ep 16700] loss=1.376261e-04 | mean(pred on target rows)≈9.003 | max|pred| on non-target rows≈0.915\n",
            "[ep 16750] loss=2.333340e-04 | mean(pred on target rows)≈9.003 | max|pred| on non-target rows≈0.911\n",
            "[ep 16800] loss=1.144572e-03 | mean(pred on target rows)≈8.964 | max|pred| on non-target rows≈0.911\n",
            "[ep 16850] loss=4.836339e-04 | mean(pred on target rows)≈8.986 | max|pred| on non-target rows≈0.917\n",
            "[ep 16900] loss=9.252016e-04 | mean(pred on target rows)≈8.977 | max|pred| on non-target rows≈0.917\n",
            "[ep 16950] loss=1.456332e-03 | mean(pred on target rows)≈8.976 | max|pred| on non-target rows≈0.912\n",
            "[ep 17000] loss=3.138998e-04 | mean(pred on target rows)≈8.984 | max|pred| on non-target rows≈0.906\n",
            "[ep 17050] loss=1.319009e-03 | mean(pred on target rows)≈9.032 | max|pred| on non-target rows≈0.913\n",
            "[ep 17100] loss=7.867605e-04 | mean(pred on target rows)≈8.970 | max|pred| on non-target rows≈0.882\n",
            "[ep 17150] loss=5.695067e-04 | mean(pred on target rows)≈9.027 | max|pred| on non-target rows≈0.911\n",
            "[ep 17200] loss=3.134538e-03 | mean(pred on target rows)≈9.055 | max|pred| on non-target rows≈0.905\n",
            "[ep 17250] loss=2.902698e-04 | mean(pred on target rows)≈8.987 | max|pred| on non-target rows≈0.915\n",
            "[ep 17300] loss=7.237757e-04 | mean(pred on target rows)≈9.027 | max|pred| on non-target rows≈0.913\n",
            "[ep 17350] loss=2.778160e-03 | mean(pred on target rows)≈8.936 | max|pred| on non-target rows≈0.917\n",
            "[ep 17400] loss=8.349752e-04 | mean(pred on target rows)≈8.965 | max|pred| on non-target rows≈0.913\n",
            "[ep 17450] loss=2.824711e-03 | mean(pred on target rows)≈8.935 | max|pred| on non-target rows≈0.918\n",
            "[ep 17500] loss=7.537333e-05 | mean(pred on target rows)≈9.007 | max|pred| on non-target rows≈0.911\n",
            "[ep 17550] loss=1.146509e-03 | mean(pred on target rows)≈8.963 | max|pred| on non-target rows≈0.921\n",
            "[ep 17600] loss=3.742768e-04 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.910\n",
            "[ep 17650] loss=3.329544e-04 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.922\n",
            "[ep 17700] loss=1.700161e-03 | mean(pred on target rows)≈9.054 | max|pred| on non-target rows≈0.861\n",
            "[ep 17750] loss=5.384940e-04 | mean(pred on target rows)≈9.010 | max|pred| on non-target rows≈0.914\n",
            "[ep 17800] loss=1.020685e-03 | mean(pred on target rows)≈9.023 | max|pred| on non-target rows≈0.901\n",
            "[ep 17850] loss=2.029103e-04 | mean(pred on target rows)≈9.011 | max|pred| on non-target rows≈0.905\n",
            "[ep 17900] loss=2.743459e-04 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.912\n",
            "[ep 17950] loss=9.356156e-04 | mean(pred on target rows)≈9.036 | max|pred| on non-target rows≈0.913\n",
            "[ep 18000] loss=8.244633e-05 | mean(pred on target rows)≈8.993 | max|pred| on non-target rows≈0.917\n",
            "[ep 18050] loss=2.237462e-04 | mean(pred on target rows)≈8.993 | max|pred| on non-target rows≈0.916\n",
            "[ep 18100] loss=1.022977e-03 | mean(pred on target rows)≈9.033 | max|pred| on non-target rows≈0.909\n",
            "[ep 18150] loss=9.609159e-04 | mean(pred on target rows)≈9.028 | max|pred| on non-target rows≈0.908\n",
            "[ep 18200] loss=1.139581e-04 | mean(pred on target rows)≈9.006 | max|pred| on non-target rows≈0.905\n",
            "[ep 18250] loss=6.168354e-04 | mean(pred on target rows)≈9.028 | max|pred| on non-target rows≈0.901\n",
            "[ep 18300] loss=3.275134e-04 | mean(pred on target rows)≈8.996 | max|pred| on non-target rows≈0.913\n",
            "[ep 18350] loss=4.849462e-04 | mean(pred on target rows)≈9.018 | max|pred| on non-target rows≈0.910\n",
            "[ep 18400] loss=8.077881e-04 | mean(pred on target rows)≈8.969 | max|pred| on non-target rows≈0.910\n",
            "[ep 18450] loss=7.043143e-04 | mean(pred on target rows)≈8.976 | max|pred| on non-target rows≈0.922\n",
            "[ep 18500] loss=2.948078e-04 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.906\n",
            "[ep 18550] loss=2.140075e-03 | mean(pred on target rows)≈9.061 | max|pred| on non-target rows≈0.904\n",
            "[ep 18600] loss=4.387345e-04 | mean(pred on target rows)≈8.989 | max|pred| on non-target rows≈0.919\n",
            "[ep 18650] loss=1.457225e-03 | mean(pred on target rows)≈8.962 | max|pred| on non-target rows≈0.919\n",
            "[ep 18700] loss=9.133335e-04 | mean(pred on target rows)≈8.972 | max|pred| on non-target rows≈0.911\n",
            "[ep 18750] loss=1.746997e-03 | mean(pred on target rows)≈8.948 | max|pred| on non-target rows≈0.911\n",
            "[ep 18800] loss=2.757268e-04 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.907\n",
            "[ep 18850] loss=3.446669e-04 | mean(pred on target rows)≈9.015 | max|pred| on non-target rows≈0.910\n",
            "[ep 18900] loss=2.320901e-02 | mean(pred on target rows)≈8.928 | max|pred| on non-target rows≈0.968\n",
            "[ep 18950] loss=4.059050e-03 | mean(pred on target rows)≈9.025 | max|pred| on non-target rows≈0.913\n",
            "[ep 19000] loss=4.345309e-05 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.913\n",
            "[ep 19050] loss=2.792081e-05 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.915\n",
            "[ep 19100] loss=1.374137e-05 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.916\n",
            "[ep 19150] loss=3.605364e-05 | mean(pred on target rows)≈9.006 | max|pred| on non-target rows≈0.912\n",
            "[ep 19200] loss=7.188626e-05 | mean(pred on target rows)≈8.989 | max|pred| on non-target rows≈0.913\n",
            "[ep 19250] loss=2.622407e-04 | mean(pred on target rows)≈8.978 | max|pred| on non-target rows≈0.906\n",
            "[ep 19300] loss=2.241977e-03 | mean(pred on target rows)≈9.066 | max|pred| on non-target rows≈0.912\n",
            "[ep 19350] loss=3.472537e-03 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.913\n",
            "[ep 19400] loss=3.657653e-05 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.915\n",
            "[ep 19450] loss=1.107979e-05 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.917\n",
            "[ep 19500] loss=8.795178e-06 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.913\n",
            "[ep 19550] loss=2.109972e-05 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.913\n",
            "[ep 19600] loss=6.576893e-03 | mean(pred on target rows)≈8.886 | max|pred| on non-target rows≈0.906\n",
            "[ep 19650] loss=1.324337e-03 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.907\n",
            "[ep 19700] loss=1.761922e-05 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.912\n",
            "[ep 19750] loss=5.117911e-05 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.911\n",
            "[ep 19800] loss=1.625453e-05 | mean(pred on target rows)≈9.002 | max|pred| on non-target rows≈0.918\n",
            "[ep 19850] loss=4.627160e-04 | mean(pred on target rows)≈8.970 | max|pred| on non-target rows≈0.915\n",
            "[ep 19900] loss=4.657070e-04 | mean(pred on target rows)≈9.027 | max|pred| on non-target rows≈0.901\n",
            "[ep 19950] loss=3.654919e-05 | mean(pred on target rows)≈8.992 | max|pred| on non-target rows≈0.914\n",
            "[ep 20000] loss=2.455024e-04 | mean(pred on target rows)≈8.978 | max|pred| on non-target rows≈0.915\n",
            "[ep 20050] loss=2.151877e-04 | mean(pred on target rows)≈8.980 | max|pred| on non-target rows≈0.905\n",
            "[ep 20100] loss=4.135574e-03 | mean(pred on target rows)≈8.976 | max|pred| on non-target rows≈0.913\n",
            "[ep 20150] loss=2.514309e-05 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.910\n",
            "[ep 20200] loss=4.312252e-04 | mean(pred on target rows)≈8.971 | max|pred| on non-target rows≈0.914\n",
            "[ep 20250] loss=1.887418e-02 | mean(pred on target rows)≈9.021 | max|pred| on non-target rows≈0.932\n",
            "[ep 20300] loss=6.091361e-05 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.915\n",
            "[ep 20350] loss=9.820976e-05 | mean(pred on target rows)≈9.013 | max|pred| on non-target rows≈0.909\n",
            "[ep 20400] loss=8.646659e-05 | mean(pred on target rows)≈9.012 | max|pred| on non-target rows≈0.912\n",
            "[ep 20450] loss=2.233953e-04 | mean(pred on target rows)≈9.020 | max|pred| on non-target rows≈0.913\n",
            "[ep 20500] loss=2.034712e-04 | mean(pred on target rows)≈8.980 | max|pred| on non-target rows≈0.915\n",
            "[ep 20550] loss=2.413982e-05 | mean(pred on target rows)≈9.006 | max|pred| on non-target rows≈0.912\n",
            "[ep 20600] loss=2.403508e-03 | mean(pred on target rows)≈9.066 | max|pred| on non-target rows≈0.907\n",
            "[ep 20650] loss=1.282838e-03 | mean(pred on target rows)≈8.969 | max|pred| on non-target rows≈0.911\n",
            "[ep 20700] loss=3.629831e-04 | mean(pred on target rows)≈9.015 | max|pred| on non-target rows≈0.911\n",
            "[ep 20750] loss=2.917958e-03 | mean(pred on target rows)≈8.956 | max|pred| on non-target rows≈0.893\n",
            "[ep 20800] loss=2.673718e-04 | mean(pred on target rows)≈9.021 | max|pred| on non-target rows≈0.909\n",
            "[ep 20850] loss=1.176140e-03 | mean(pred on target rows)≈8.975 | max|pred| on non-target rows≈0.908\n",
            "[ep 20900] loss=5.279992e-04 | mean(pred on target rows)≈9.025 | max|pred| on non-target rows≈0.911\n",
            "[ep 20950] loss=3.185195e-01 | mean(pred on target rows)≈8.543 | max|pred| on non-target rows≈0.855\n",
            "[ep 21000] loss=7.777311e-04 | mean(pred on target rows)≈8.996 | max|pred| on non-target rows≈0.924\n",
            "[ep 21050] loss=1.558195e-05 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.915\n",
            "[ep 21100] loss=1.266748e-05 | mean(pred on target rows)≈9.002 | max|pred| on non-target rows≈0.911\n",
            "[ep 21150] loss=9.418617e-06 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.913\n",
            "[ep 21200] loss=9.856316e-06 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.910\n",
            "[ep 21250] loss=2.424434e-04 | mean(pred on target rows)≈8.978 | max|pred| on non-target rows≈0.910\n",
            "[ep 21300] loss=2.131297e-04 | mean(pred on target rows)≈8.980 | max|pred| on non-target rows≈0.916\n",
            "[ep 21350] loss=1.554637e-04 | mean(pred on target rows)≈9.017 | max|pred| on non-target rows≈0.910\n",
            "[ep 21400] loss=2.684031e-04 | mean(pred on target rows)≈9.022 | max|pred| on non-target rows≈0.912\n",
            "[ep 21450] loss=4.296403e-04 | mean(pred on target rows)≈8.971 | max|pred| on non-target rows≈0.912\n",
            "[ep 21500] loss=1.605262e-04 | mean(pred on target rows)≈9.016 | max|pred| on non-target rows≈0.910\n",
            "[ep 21550] loss=1.035910e-03 | mean(pred on target rows)≈9.042 | max|pred| on non-target rows≈0.910\n",
            "[ep 21600] loss=2.752505e-04 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.911\n",
            "[ep 21650] loss=2.266241e-04 | mean(pred on target rows)≈8.984 | max|pred| on non-target rows≈0.899\n",
            "[ep 21700] loss=3.009818e-03 | mean(pred on target rows)≈8.943 | max|pred| on non-target rows≈0.912\n",
            "[ep 21750] loss=6.857366e-05 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.911\n",
            "[ep 21800] loss=5.580264e-04 | mean(pred on target rows)≈8.983 | max|pred| on non-target rows≈0.910\n",
            "[ep 21850] loss=4.846758e-04 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.914\n",
            "[ep 21900] loss=8.229984e-04 | mean(pred on target rows)≈9.031 | max|pred| on non-target rows≈0.907\n",
            "[ep 21950] loss=1.084803e-04 | mean(pred on target rows)≈8.992 | max|pred| on non-target rows≈0.911\n",
            "[ep 22000] loss=7.600697e-04 | mean(pred on target rows)≈8.973 | max|pred| on non-target rows≈0.908\n",
            "[ep 22050] loss=2.884827e-04 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.912\n",
            "[ep 22100] loss=3.912079e-04 | mean(pred on target rows)≈8.980 | max|pred| on non-target rows≈0.912\n",
            "[ep 22150] loss=1.116132e-03 | mean(pred on target rows)≈9.030 | max|pred| on non-target rows≈0.906\n",
            "[ep 22200] loss=9.430156e-04 | mean(pred on target rows)≈9.014 | max|pred| on non-target rows≈0.907\n",
            "[ep 22250] loss=6.125937e-04 | mean(pred on target rows)≈9.006 | max|pred| on non-target rows≈0.917\n",
            "[ep 22300] loss=1.629057e-05 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.909\n",
            "[ep 22350] loss=6.632041e-04 | mean(pred on target rows)≈9.035 | max|pred| on non-target rows≈0.896\n",
            "[ep 22400] loss=3.822394e-03 | mean(pred on target rows)≈8.913 | max|pred| on non-target rows≈0.921\n",
            "[ep 22450] loss=3.806033e-04 | mean(pred on target rows)≈9.027 | max|pred| on non-target rows≈0.910\n",
            "[ep 22500] loss=3.971633e-03 | mean(pred on target rows)≈8.931 | max|pred| on non-target rows≈0.916\n",
            "[ep 22550] loss=8.195461e-04 | mean(pred on target rows)≈8.961 | max|pred| on non-target rows≈0.915\n",
            "[ep 22600] loss=2.414979e-04 | mean(pred on target rows)≈8.991 | max|pred| on non-target rows≈0.912\n",
            "[ep 22650] loss=1.485450e-04 | mean(pred on target rows)≈8.988 | max|pred| on non-target rows≈0.914\n",
            "[ep 22700] loss=1.890709e-04 | mean(pred on target rows)≈9.005 | max|pred| on non-target rows≈0.905\n",
            "[ep 22750] loss=3.946186e-04 | mean(pred on target rows)≈8.977 | max|pred| on non-target rows≈0.911\n",
            "[ep 22800] loss=8.367269e-05 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.912\n",
            "[ep 22850] loss=4.787347e-04 | mean(pred on target rows)≈9.024 | max|pred| on non-target rows≈0.912\n",
            "[ep 22900] loss=7.671115e-04 | mean(pred on target rows)≈8.976 | max|pred| on non-target rows≈0.910\n",
            "[ep 22950] loss=8.407457e-04 | mean(pred on target rows)≈8.962 | max|pred| on non-target rows≈0.912\n",
            "[ep 23000] loss=1.175608e-03 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.907\n",
            "[ep 23050] loss=1.595260e-03 | mean(pred on target rows)≈8.948 | max|pred| on non-target rows≈0.912\n",
            "[ep 23100] loss=1.270763e-03 | mean(pred on target rows)≈9.035 | max|pred| on non-target rows≈0.907\n",
            "[ep 23150] loss=1.135045e-03 | mean(pred on target rows)≈9.015 | max|pred| on non-target rows≈0.907\n",
            "[ep 23200] loss=2.189544e-03 | mean(pred on target rows)≈8.990 | max|pred| on non-target rows≈0.909\n",
            "[ep 23250] loss=3.909765e-04 | mean(pred on target rows)≈8.974 | max|pred| on non-target rows≈0.914\n",
            "[ep 23300] loss=1.253136e-03 | mean(pred on target rows)≈8.978 | max|pred| on non-target rows≈0.907\n",
            "[ep 23350] loss=7.644860e-04 | mean(pred on target rows)≈8.977 | max|pred| on non-target rows≈0.910\n",
            "[ep 23400] loss=9.313646e-04 | mean(pred on target rows)≈9.028 | max|pred| on non-target rows≈0.903\n",
            "[ep 23450] loss=2.042932e-04 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.906\n",
            "[ep 23500] loss=5.679604e-04 | mean(pred on target rows)≈8.976 | max|pred| on non-target rows≈0.908\n",
            "[ep 23550] loss=9.457691e-04 | mean(pred on target rows)≈9.017 | max|pred| on non-target rows≈0.911\n",
            "[ep 23600] loss=3.573849e-04 | mean(pred on target rows)≈9.026 | max|pred| on non-target rows≈0.909\n",
            "[ep 23650] loss=2.908357e-04 | mean(pred on target rows)≈8.985 | max|pred| on non-target rows≈0.901\n",
            "[ep 23700] loss=4.885611e-04 | mean(pred on target rows)≈9.000 | max|pred| on non-target rows≈0.912\n",
            "[ep 23750] loss=2.196345e-04 | mean(pred on target rows)≈8.981 | max|pred| on non-target rows≈0.907\n",
            "[ep 23800] loss=1.696085e-04 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.911\n",
            "[ep 23850] loss=2.056783e-03 | mean(pred on target rows)≈9.038 | max|pred| on non-target rows≈0.907\n",
            "[ep 23900] loss=7.781680e-04 | mean(pred on target rows)≈9.033 | max|pred| on non-target rows≈0.904\n",
            "[ep 23950] loss=4.670243e-04 | mean(pred on target rows)≈9.003 | max|pred| on non-target rows≈0.909\n",
            "[ep 24000] loss=5.870466e-04 | mean(pred on target rows)≈8.971 | max|pred| on non-target rows≈0.914\n",
            "[ep 24050] loss=1.080701e-03 | mean(pred on target rows)≈8.955 | max|pred| on non-target rows≈0.915\n",
            "[ep 24100] loss=7.989535e-04 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.911\n",
            "[ep 24150] loss=2.364204e-03 | mean(pred on target rows)≈8.941 | max|pred| on non-target rows≈0.882\n",
            "[ep 24200] loss=3.296341e-04 | mean(pred on target rows)≈8.977 | max|pred| on non-target rows≈0.915\n",
            "[ep 24250] loss=1.067462e-03 | mean(pred on target rows)≈8.968 | max|pred| on non-target rows≈0.913\n",
            "[ep 24300] loss=1.918241e-04 | mean(pred on target rows)≈9.006 | max|pred| on non-target rows≈0.907\n",
            "[ep 24350] loss=2.681496e-04 | mean(pred on target rows)≈9.010 | max|pred| on non-target rows≈0.909\n",
            "[ep 24400] loss=5.624463e-04 | mean(pred on target rows)≈8.982 | max|pred| on non-target rows≈0.910\n",
            "[ep 24450] loss=9.050777e-04 | mean(pred on target rows)≈9.007 | max|pred| on non-target rows≈0.911\n",
            "[ep 24500] loss=6.705438e-04 | mean(pred on target rows)≈9.007 | max|pred| on non-target rows≈0.913\n",
            "[ep 24550] loss=8.468159e-04 | mean(pred on target rows)≈9.036 | max|pred| on non-target rows≈0.901\n",
            "[ep 24600] loss=6.570933e-05 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.911\n",
            "[ep 24650] loss=2.643480e-04 | mean(pred on target rows)≈9.003 | max|pred| on non-target rows≈0.912\n",
            "[ep 24700] loss=1.176208e-03 | mean(pred on target rows)≈8.989 | max|pred| on non-target rows≈0.911\n",
            "[ep 24750] loss=3.993301e-04 | mean(pred on target rows)≈9.021 | max|pred| on non-target rows≈0.897\n",
            "[ep 24800] loss=1.570974e-03 | mean(pred on target rows)≈8.973 | max|pred| on non-target rows≈0.913\n",
            "[ep 24850] loss=1.887605e-04 | mean(pred on target rows)≈8.983 | max|pred| on non-target rows≈0.913\n",
            "[ep 24900] loss=1.335908e-03 | mean(pred on target rows)≈8.965 | max|pred| on non-target rows≈0.918\n",
            "[ep 24950] loss=5.086678e-04 | mean(pred on target rows)≈9.019 | max|pred| on non-target rows≈0.910\n",
            "[ep 25000] loss=5.955189e-04 | mean(pred on target rows)≈8.973 | max|pred| on non-target rows≈0.911\n",
            "[ep 25050] loss=1.846657e-03 | mean(pred on target rows)≈8.969 | max|pred| on non-target rows≈0.907\n",
            "[ep 25100] loss=2.695753e-04 | mean(pred on target rows)≈9.021 | max|pred| on non-target rows≈0.907\n",
            "[ep 25150] loss=2.958992e-04 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.910\n",
            "[ep 25200] loss=1.805736e-04 | mean(pred on target rows)≈9.007 | max|pred| on non-target rows≈0.911\n",
            "[ep 25250] loss=1.687259e-03 | mean(pred on target rows)≈8.944 | max|pred| on non-target rows≈0.886\n",
            "[ep 25300] loss=4.142710e-04 | mean(pred on target rows)≈8.997 | max|pred| on non-target rows≈0.913\n",
            "[ep 25350] loss=4.138176e-04 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.912\n",
            "[ep 25400] loss=2.722384e-02 | mean(pred on target rows)≈9.187 | max|pred| on non-target rows≈0.954\n",
            "[ep 25450] loss=9.495806e-04 | mean(pred on target rows)≈9.012 | max|pred| on non-target rows≈0.920\n",
            "[ep 25500] loss=2.359806e-05 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.910\n",
            "[ep 25550] loss=1.698530e-05 | mean(pred on target rows)≈9.002 | max|pred| on non-target rows≈0.910\n",
            "[ep 25600] loss=1.578374e-05 | mean(pred on target rows)≈8.997 | max|pred| on non-target rows≈0.917\n",
            "[ep 25650] loss=1.082634e-04 | mean(pred on target rows)≈9.014 | max|pred| on non-target rows≈0.911\n",
            "[ep 25700] loss=2.315744e-04 | mean(pred on target rows)≈8.979 | max|pred| on non-target rows≈0.914\n",
            "[ep 25750] loss=2.668747e-05 | mean(pred on target rows)≈9.006 | max|pred| on non-target rows≈0.910\n",
            "[ep 25800] loss=2.423457e-05 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.911\n",
            "[ep 25850] loss=5.029662e-04 | mean(pred on target rows)≈9.031 | max|pred| on non-target rows≈0.902\n",
            "[ep 25900] loss=2.799429e-05 | mean(pred on target rows)≈9.004 | max|pred| on non-target rows≈0.904\n",
            "[ep 25950] loss=2.688972e-03 | mean(pred on target rows)≈9.057 | max|pred| on non-target rows≈0.895\n",
            "[ep 26000] loss=6.712584e-04 | mean(pred on target rows)≈9.031 | max|pred| on non-target rows≈0.903\n",
            "[ep 26050] loss=2.382026e-04 | mean(pred on target rows)≈9.010 | max|pred| on non-target rows≈0.911\n",
            "[ep 26100] loss=1.282404e-03 | mean(pred on target rows)≈8.968 | max|pred| on non-target rows≈0.911\n",
            "[ep 26150] loss=1.175204e-04 | mean(pred on target rows)≈8.997 | max|pred| on non-target rows≈0.909\n",
            "[ep 26200] loss=1.785815e-04 | mean(pred on target rows)≈9.004 | max|pred| on non-target rows≈0.911\n",
            "[ep 26250] loss=2.671267e-04 | mean(pred on target rows)≈8.977 | max|pred| on non-target rows≈0.915\n",
            "[ep 26300] loss=8.671010e-04 | mean(pred on target rows)≈9.040 | max|pred| on non-target rows≈0.909\n",
            "[ep 26350] loss=2.316416e-04 | mean(pred on target rows)≈8.980 | max|pred| on non-target rows≈0.910\n",
            "[ep 26400] loss=4.160937e-04 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.911\n",
            "[ep 26450] loss=5.011366e-04 | mean(pred on target rows)≈9.003 | max|pred| on non-target rows≈0.899\n",
            "[ep 26500] loss=2.785400e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.909\n",
            "[ep 26550] loss=8.210186e-04 | mean(pred on target rows)≈9.032 | max|pred| on non-target rows≈0.911\n",
            "[ep 26600] loss=3.086184e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.912\n",
            "[ep 26650] loss=5.741948e-04 | mean(pred on target rows)≈9.023 | max|pred| on non-target rows≈0.911\n",
            "[ep 26700] loss=1.560178e-03 | mean(pred on target rows)≈9.013 | max|pred| on non-target rows≈0.908\n",
            "[ep 26750] loss=1.917129e-04 | mean(pred on target rows)≈8.984 | max|pred| on non-target rows≈0.910\n",
            "[ep 26800] loss=9.548318e-04 | mean(pred on target rows)≈9.014 | max|pred| on non-target rows≈0.912\n",
            "[ep 26850] loss=1.099399e-03 | mean(pred on target rows)≈9.044 | max|pred| on non-target rows≈0.896\n",
            "[ep 26900] loss=1.716647e-03 | mean(pred on target rows)≈9.043 | max|pred| on non-target rows≈0.907\n",
            "[ep 26950] loss=1.713752e-03 | mean(pred on target rows)≈9.033 | max|pred| on non-target rows≈0.906\n",
            "[ep 27000] loss=7.650091e-04 | mean(pred on target rows)≈9.036 | max|pred| on non-target rows≈0.903\n",
            "[ep 27050] loss=1.307711e-04 | mean(pred on target rows)≈8.993 | max|pred| on non-target rows≈0.912\n",
            "[ep 27100] loss=5.781760e-04 | mean(pred on target rows)≈9.013 | max|pred| on non-target rows≈0.906\n",
            "[ep 27150] loss=1.297805e-03 | mean(pred on target rows)≈9.043 | max|pred| on non-target rows≈0.906\n",
            "[ep 27200] loss=1.734305e-04 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.909\n",
            "[ep 27250] loss=4.639466e-04 | mean(pred on target rows)≈9.014 | max|pred| on non-target rows≈0.911\n",
            "[ep 27300] loss=5.644392e-04 | mean(pred on target rows)≈8.995 | max|pred| on non-target rows≈0.908\n",
            "[ep 27350] loss=1.522587e-04 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.909\n",
            "[ep 27400] loss=1.481782e-03 | mean(pred on target rows)≈8.983 | max|pred| on non-target rows≈0.913\n",
            "[ep 27450] loss=4.751695e-04 | mean(pred on target rows)≈9.029 | max|pred| on non-target rows≈0.906\n",
            "[ep 27500] loss=2.427225e-04 | mean(pred on target rows)≈8.986 | max|pred| on non-target rows≈0.910\n",
            "[ep 27550] loss=7.077501e-04 | mean(pred on target rows)≈8.987 | max|pred| on non-target rows≈0.908\n",
            "[ep 27600] loss=3.549227e-04 | mean(pred on target rows)≈9.025 | max|pred| on non-target rows≈0.908\n",
            "[ep 27650] loss=9.347398e-04 | mean(pred on target rows)≈9.023 | max|pred| on non-target rows≈0.904\n",
            "[ep 27700] loss=8.126711e-05 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.912\n",
            "[ep 27750] loss=2.975444e-04 | mean(pred on target rows)≈8.984 | max|pred| on non-target rows≈0.900\n",
            "[ep 27800] loss=5.761930e-04 | mean(pred on target rows)≈8.982 | max|pred| on non-target rows≈0.911\n",
            "[ep 27850] loss=3.563912e-04 | mean(pred on target rows)≈8.978 | max|pred| on non-target rows≈0.910\n",
            "[ep 27900] loss=5.532949e-04 | mean(pred on target rows)≈9.002 | max|pred| on non-target rows≈0.909\n",
            "[ep 27950] loss=1.235254e-03 | mean(pred on target rows)≈9.044 | max|pred| on non-target rows≈0.906\n",
            "[ep 28000] loss=5.602698e-04 | mean(pred on target rows)≈9.018 | max|pred| on non-target rows≈0.910\n",
            "[ep 28050] loss=3.395537e-04 | mean(pred on target rows)≈9.019 | max|pred| on non-target rows≈0.905\n",
            "[ep 28100] loss=6.703028e-04 | mean(pred on target rows)≈9.028 | max|pred| on non-target rows≈0.884\n",
            "[ep 28150] loss=4.495778e-04 | mean(pred on target rows)≈8.989 | max|pred| on non-target rows≈0.915\n",
            "[ep 28200] loss=1.329035e-03 | mean(pred on target rows)≈8.956 | max|pred| on non-target rows≈0.915\n",
            "[ep 28250] loss=2.570663e-03 | mean(pred on target rows)≈8.936 | max|pred| on non-target rows≈0.905\n",
            "[ep 28300] loss=4.485696e-04 | mean(pred on target rows)≈8.982 | max|pred| on non-target rows≈0.911\n",
            "[ep 28350] loss=1.367075e-03 | mean(pred on target rows)≈9.012 | max|pred| on non-target rows≈0.909\n",
            "[ep 28400] loss=5.200321e-04 | mean(pred on target rows)≈9.004 | max|pred| on non-target rows≈0.910\n",
            "[ep 28450] loss=1.223202e-04 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.909\n",
            "[ep 28500] loss=1.245163e-04 | mean(pred on target rows)≈9.001 | max|pred| on non-target rows≈0.912\n",
            "[ep 28550] loss=5.415339e-04 | mean(pred on target rows)≈9.008 | max|pred| on non-target rows≈0.903\n",
            "[ep 28600] loss=4.384699e-04 | mean(pred on target rows)≈8.979 | max|pred| on non-target rows≈0.911\n",
            "[ep 28650] loss=8.078145e-04 | mean(pred on target rows)≈9.030 | max|pred| on non-target rows≈0.910\n",
            "[ep 28700] loss=1.671602e-04 | mean(pred on target rows)≈8.993 | max|pred| on non-target rows≈0.909\n",
            "[ep 28750] loss=3.614617e-04 | mean(pred on target rows)≈8.985 | max|pred| on non-target rows≈0.907\n",
            "[ep 28800] loss=2.393352e-03 | mean(pred on target rows)≈9.046 | max|pred| on non-target rows≈0.904\n",
            "[ep 28850] loss=5.651401e-05 | mean(pred on target rows)≈9.003 | max|pred| on non-target rows≈0.906\n",
            "[ep 28900] loss=1.874621e-04 | mean(pred on target rows)≈9.002 | max|pred| on non-target rows≈0.911\n",
            "[ep 28950] loss=2.747256e-04 | mean(pred on target rows)≈9.014 | max|pred| on non-target rows≈0.905\n",
            "[ep 29000] loss=1.921753e-03 | mean(pred on target rows)≈8.955 | max|pred| on non-target rows≈0.912\n",
            "[ep 29050] loss=2.185225e-04 | mean(pred on target rows)≈8.993 | max|pred| on non-target rows≈0.908\n",
            "[ep 29100] loss=2.865409e-03 | mean(pred on target rows)≈9.051 | max|pred| on non-target rows≈0.905\n",
            "[ep 29150] loss=1.016471e-04 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.906\n",
            "[ep 29200] loss=4.027022e-04 | mean(pred on target rows)≈8.981 | max|pred| on non-target rows≈0.919\n",
            "[ep 29250] loss=8.873083e-04 | mean(pred on target rows)≈8.997 | max|pred| on non-target rows≈0.913\n",
            "[ep 29300] loss=1.221855e-04 | mean(pred on target rows)≈8.997 | max|pred| on non-target rows≈0.906\n",
            "[ep 29350] loss=4.500705e-03 | mean(pred on target rows)≈8.943 | max|pred| on non-target rows≈0.923\n",
            "[ep 29400] loss=2.126244e-05 | mean(pred on target rows)≈8.998 | max|pred| on non-target rows≈0.908\n",
            "[ep 29450] loss=1.556330e-05 | mean(pred on target rows)≈8.999 | max|pred| on non-target rows≈0.908\n",
            "[ep 29500] loss=7.016941e-04 | mean(pred on target rows)≈9.037 | max|pred| on non-target rows≈0.902\n",
            "[ep 29550] loss=8.834422e-05 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.912\n",
            "[ep 29600] loss=3.168192e-04 | mean(pred on target rows)≈8.994 | max|pred| on non-target rows≈0.909\n",
            "[ep 29650] loss=1.367310e-04 | mean(pred on target rows)≈9.010 | max|pred| on non-target rows≈0.905\n",
            "[ep 29700] loss=3.367692e-04 | mean(pred on target rows)≈8.985 | max|pred| on non-target rows≈0.916\n",
            "[ep 29750] loss=1.979505e-04 | mean(pred on target rows)≈8.986 | max|pred| on non-target rows≈0.911\n",
            "[ep 29800] loss=1.767347e-04 | mean(pred on target rows)≈9.009 | max|pred| on non-target rows≈0.907\n",
            "[ep 29850] loss=1.222831e-03 | mean(pred on target rows)≈9.041 | max|pred| on non-target rows≈0.906\n",
            "[ep 29900] loss=5.781264e-04 | mean(pred on target rows)≈8.989 | max|pred| on non-target rows≈0.912\n",
            "[ep 29950] loss=1.205447e-03 | mean(pred on target rows)≈9.039 | max|pred| on non-target rows≈0.904\n",
            "[ep 30000] loss=2.837818e-04 | mean(pred on target rows)≈8.982 | max|pred| on non-target rows≈0.910\n",
            "\n",
            "--- INDICATOR+LINEAR CHECK (anchor-averaged) ---\n",
            "avg_pred(target) = 6.811  (true = 9.000)  -> within tol? False\n",
            "max |avg_pred - true|(all anchors) = 4.541 -> within tol? False\n",
            "mean |avg_pred - true|(all anchors) = 0.096\n",
            "\n",
            "--- Monte Carlo FULL-FUNCTION sanity check ---\n",
            "MC MSE over anchors×64: 0.259005\n",
            "(scalar=10.0, linear_coeff=0.1, d_noise=500)\n"
          ]
        }
      ],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import os, math\n",
        "\n",
        "# ---------------------------\n",
        "# Config (tune these)\n",
        "# ---------------------------\n",
        "seed = 0\n",
        "torch.manual_seed(seed); np.random.seed(seed)\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "checkpoint_dir = \"\"         # set a dir to save, else \"\"\n",
        "\n",
        "d_indicator = 10\n",
        "d_noise     = 500           # 500 additional variables with NO effect\n",
        "d_total     = d_indicator + d_noise\n",
        "N_indicator = 1 << d_indicator   # 1024\n",
        "\n",
        "# Train\n",
        "epochs       = 30000\n",
        "print_every  = 50\n",
        "lr           = 2e-3\n",
        "weight_decay = 5e-4\n",
        "hidden       = 128\n",
        "\n",
        "# Task\n",
        "scalar                = 10.0    # indicator spike value when first-10 == target (kept same)\n",
        "linear_coeff          = 0.1    # NEW: penalty per +1 among the first 10 bits (negative effect)\n",
        "min_points            = 65536   # total >=2048; higher => more reps per anchor\n",
        "target_extra_reps     = 2048    # extra rows for the target anchor (oversampling)\n",
        "pos_weight            = 16.0    # upweight target-anchor rows in MSE (pure MSE, just weighting)\n",
        "\n",
        "# Eval\n",
        "eval_R = 256  # noise samples per anchor for evaluation (anchor-averaged)\n",
        "\n",
        "# ---------------------------\n",
        "# Utilities\n",
        "# ---------------------------\n",
        "def full_cube_pm1(dim, device):\n",
        "    \"\"\"Enumerate every vertex of the {-1,+1}^dim hypercube.\n",
        "\n",
        "    Row r encodes the integer r in binary, least-significant bit in column 0,\n",
        "    with bit 0 written as -1 and bit 1 as +1. Result is float32, shape [2**dim, dim].\n",
        "    \"\"\"\n",
        "    row_ids = torch.arange(2 ** dim, device=device, dtype=torch.long)\n",
        "    bit_pos = torch.arange(dim, device=device, dtype=torch.long)\n",
        "    bits01 = (row_ids[:, None] >> bit_pos[None, :]) & 1\n",
        "    return bits01.to(torch.float32) * 2.0 - 1.0\n",
        "\n",
        "def rand_pm1(shape, device):\n",
        "    \"\"\"Draw a float32 tensor of the requested shape with entries sampled from {-1, +1}.\"\"\"\n",
        "    coin = torch.randint(0, 2, shape, device=device, dtype=torch.int8)  # 0/1 draws\n",
        "    signs = coin.to(torch.float32)\n",
        "    return signs * 2 - 1\n",
        "\n",
        "def anchor_id_from_pm1(x10):  # x10: [..., k] in {-1,+1}\n",
        "    \"\"\"Convert ±1 bit patterns into integer ids (long tensor).\n",
        "\n",
        "    Column j carries place value 2**j and +1 counts as a set bit, which matches the\n",
        "    row order produced by full_cube_pm1. The pattern width is taken from x10 itself\n",
        "    (not from the global d_indicator), so any width works; for d_indicator-wide\n",
        "    input the result is the same as before.\n",
        "    \"\"\"\n",
        "    bits01 = ((x10 + 1.0) * 0.5).long()  # -1 -> 0, +1 -> 1\n",
        "    width = x10.shape[-1]\n",
        "    place_values = (2 ** torch.arange(width, device=x10.device, dtype=torch.long)).view(1, -1)\n",
        "    return (bits01 * place_values).sum(dim=-1)  # long\n",
        "\n",
        "def unitation_pm1(x10):  # NEW: count of +1s in the first 10 bits\n",
        "    \"\"\"Count, along the last axis, how many entries of x10 are exactly +1.\"\"\"\n",
        "    is_plus_one = torch.eq(x10, 1.0)\n",
        "    return is_plus_one.sum(dim=-1)\n",
        "\n",
        "# ---------------------------\n",
        "# Data\n",
        "# ---------------------------\n",
        "# Target pattern on the indicator bits: fixed to all +1s.\n",
        "target10 = torch.ones(d_indicator, device=device, dtype=torch.float32)\n",
        "print(\"target10 (fixed to all +1s):\", target10.cpu().numpy())\n",
        "\n",
        "anchors = full_cube_pm1(d_indicator, device)           # [1024,10]\n",
        "reps = int(math.ceil(min_points / N_indicator))        # base repeats per anchor\n",
        "reps += reps % 2                                       # round up to even so rows split into ± pairs\n",
        "target_idx = anchor_id_from_pm1(target10.unsqueeze(0)).item()\n",
        "\n",
        "# Linear term per anchor (a function of the first 10 bits only).\n",
        "u_per_anchor = unitation_pm1(anchors)                  # [1024], in {0..10}\n",
        "linear_val_per_anchor = (-linear_coeff * u_per_anchor).to(torch.float32)\n",
        "linear_vals = linear_val_per_anchor.tolist()           # Python floats, one per anchor\n",
        "\n",
        "\n",
        "def _paired_noise_rows(anchor, n_rows):\n",
        "    \"\"\"Return [n_rows, 10+500] rows: `anchor` repeated, followed by noise tails drawn as (z, -z) pairs.\"\"\"\n",
        "    z = rand_pm1((n_rows // 2, d_noise), device)       # [n_rows/2, 500]\n",
        "    tails = torch.cat([z, -z], dim=0)                  # [n_rows, 500]\n",
        "    return torch.cat([anchor.expand(n_rows, -1), tails], dim=1)\n",
        "\n",
        "\n",
        "rows, ys, weights = [], [], []\n",
        "\n",
        "# Base coverage: every anchor gets `reps` rows; label = indicator spike (target only) + linear penalty.\n",
        "for anchor_pos, anchor in enumerate(anchors):\n",
        "    on_target = (anchor_pos == target_idx)\n",
        "    label = (scalar if on_target else 0.0) + linear_vals[anchor_pos]\n",
        "    rows.append(_paired_noise_rows(anchor, reps))\n",
        "    ys.append(torch.full((reps,), label, device=device, dtype=torch.float32))\n",
        "    weights.append(torch.full((reps,), pos_weight if on_target else 1.0, device=device))\n",
        "\n",
        "# Oversample the target anchor (still pure MSE); rows stay paired, label includes the linear penalty.\n",
        "if target_extra_reps > 0:\n",
        "    target_extra_reps += target_extra_reps % 2\n",
        "    target_label = scalar + linear_vals[target_idx]\n",
        "    rows.append(_paired_noise_rows(anchors[target_idx], target_extra_reps))\n",
        "    ys.append(torch.full((target_extra_reps,), target_label, device=device, dtype=torch.float32))\n",
        "    weights.append(torch.full((target_extra_reps,), pos_weight, device=device))\n",
        "\n",
        "X_train = torch.cat(rows, dim=0)        # [M, 10+500]\n",
        "y_train = torch.cat(ys, dim=0)          # [M]\n",
        "w_train = torch.cat(weights, dim=0)     # [M]\n",
        "M = X_train.size(0)\n",
        "print(f\"Train size: {M} | base reps/anchor: {reps} | target extra reps: {target_extra_reps}\")\n",
        "\n",
        "# Row mask used by the training-time logs: True where the first 10 bits equal the target.\n",
        "is_target_row = (X_train[:, :d_indicator] == target10).all(dim=1)\n",
        "\n",
        "# ---------------------------\n",
        "# Model (single FCNN)\n",
        "# ---------------------------\n",
        "class FCNN(nn.Module):\n",
        "    \"\"\"Two-hidden-layer ReLU MLP mapping an in_dim-wide row to one scalar per row.\n",
        "\n",
        "    Every Linear layer is re-initialised with Xavier-uniform weights and zero biases.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, in_dim=10+500, hidden=128):\n",
        "        super().__init__()\n",
        "        layers = [\n",
        "            nn.Linear(in_dim, hidden),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(hidden, hidden),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(hidden, 1),\n",
        "        ]\n",
        "        # The attribute name `net` determines the state_dict keys stored in checkpoints.\n",
        "        self.net = nn.Sequential(*layers)\n",
        "        for layer in layers:\n",
        "            if isinstance(layer, nn.Linear):\n",
        "                nn.init.xavier_uniform_(layer.weight)\n",
        "                nn.init.zeros_(layer.bias)\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = self.net(x)\n",
        "        return out.squeeze(-1)  # drop the trailing size-1 output axis\n",
        "\n",
        "model = FCNN(d_total, hidden).to(device)\n",
        "opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "\n",
        "def weighted_mse(pred, target, w):\n",
        "    \"\"\"Mean over rows of w * (pred - target)**2 (divides by the row count, not by sum(w)).\"\"\"\n",
        "    sq_err = (pred - target).pow(2)\n",
        "    return (w * sq_err).mean()\n",
        "\n",
        "# ---------------------------\n",
        "# Train (full-batch, pure MSE)\n",
        "# ---------------------------\n",
        "# Create the checkpoint directory up front so torch.save cannot fail on a missing folder.\n",
        "# An empty checkpoint_dir means \"write into the current working directory\".\n",
        "if checkpoint_dir:\n",
        "    os.makedirs(checkpoint_dir, exist_ok=True)\n",
        "\n",
        "for ep in range(1, epochs + 1):\n",
        "    model.train()\n",
        "    pred = model(X_train)\n",
        "    loss = weighted_mse(pred, y_train, w_train)\n",
        "\n",
        "    opt.zero_grad(set_to_none=True)\n",
        "    loss.backward()\n",
        "    nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
        "    opt.step()\n",
        "\n",
        "    if ep % print_every == 0 or ep == 1 or ep == epochs:\n",
        "        # Stats are taken from `pred`, i.e. the forward pass made before this epoch's update.\n",
        "        with torch.no_grad():\n",
        "            p_target_rows = pred[is_target_row].mean().item()\n",
        "            p_non_rows    = pred[~is_target_row].abs().max().item()\n",
        "        print(f\"[ep {ep:5d}] loss={loss.item():.6e} | \"\n",
        "              f\"mean(pred on target rows)≈{p_target_rows:.3f} | \"\n",
        "              f\"max|pred| on non-target rows≈{p_non_rows:.3f}\")\n",
        "\n",
        "    # Checkpoints are always written (every 5 epochs): the sampling cells further down\n",
        "    # load ckpt_epoch_<N>.pt files from checkpoint_dir.\n",
        "    if ep % 5 == 0:\n",
        "        ckpt_path = os.path.join(checkpoint_dir, f\"ckpt_epoch_{ep}.pt\")\n",
        "        torch.save({\"epoch\": ep, \"model_state_dict\": model.state_dict()}, ckpt_path)\n",
        "\n",
        "# ---------------------------\n",
        "# Final evaluation (anchor check over all 2^10; average out noise tail)\n",
        "# ---------------------------\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "    # One batch of eval_R noise tails, shared by all anchors, then averaged out per anchor.\n",
        "    x_noise_eval = rand_pm1((eval_R, d_noise), device)\n",
        "    X_eval = torch.cat(\n",
        "        [\n",
        "            anchors.repeat_interleave(eval_R, dim=0),   # [1024*R,10]\n",
        "            x_noise_eval.repeat(N_indicator, 1),        # [1024*R,500]\n",
        "        ],\n",
        "        dim=1,\n",
        "    )                                                   # [1024*R,510]\n",
        "    preds_avg = model(X_eval).view(N_indicator, eval_R).mean(dim=1)\n",
        "\n",
        "    # Ground truth per anchor: linear penalty everywhere, plus the spike at the target anchor.\n",
        "    true_vals = linear_val_per_anchor.clone()\n",
        "    true_vals[target_idx] += scalar\n",
        "\n",
        "    target_avg = preds_avg[target_idx].item()\n",
        "    non_avg    = preds_avg[torch.arange(N_indicator, device=device) != target_idx]\n",
        "\n",
        "    # Absolute deviation of the anchor-averaged prediction from the truth.\n",
        "    abs_errs = (preds_avg - true_vals).abs()\n",
        "    max_abs_err = abs_errs.max().item()\n",
        "    mean_abs_err = abs_errs.mean().item()\n",
        "\n",
        "    # Row-level Monte Carlo MSE with a second, smaller batch of noise tails.\n",
        "    mc_R = 64\n",
        "    X1_mc = anchors.repeat_interleave(mc_R, dim=0)\n",
        "    noise_mc = rand_pm1((mc_R, d_noise), device)\n",
        "    X_mc = torch.cat([X1_mc, noise_mc.repeat(N_indicator, 1)], dim=1)\n",
        "    pred_mc = model(X_mc)\n",
        "\n",
        "    # True y for MC rows (depends only on the first 10 bits).\n",
        "    head_mc = X1_mc[:, :d_indicator]\n",
        "    u_mc = unitation_pm1(head_mc)\n",
        "    ind_mc = (anchor_id_from_pm1(head_mc) == target_idx).float()\n",
        "    y_mc = (-linear_coeff * u_mc.to(torch.float32)) + ind_mc * scalar\n",
        "    mse_mc = (pred_mc - y_mc).pow(2).mean().item()\n",
        "\n",
        "tol = 0.25\n",
        "true_target = true_vals[target_idx].item()\n",
        "target_ok = abs(target_avg - true_target) <= tol\n",
        "print(\"\\n--- INDICATOR+LINEAR CHECK (anchor-averaged) ---\")\n",
        "print(f\"avg_pred(target) = {target_avg:.3f}  (true = {true_target:.3f})  -> within tol? {target_ok}\")\n",
        "print(f\"max |avg_pred - true|(all anchors) = {max_abs_err:.3f} -> within tol? {max_abs_err <= tol}\")\n",
        "print(f\"mean |avg_pred - true|(all anchors) = {mean_abs_err:.3f}\")\n",
        "\n",
        "print(\"\\n--- Monte Carlo FULL-FUNCTION sanity check ---\")\n",
        "print(f\"MC MSE over anchors×{mc_R}: {mse_mc:.6f}\")\n",
        "print(f\"(scalar={scalar}, linear_coeff={linear_coeff}, d_noise={d_noise})\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "--- LINEAR EFFECT CHECK (EXCLUDING TARGET) ---\n",
            "u | mean_pred    (±std)   | count | true=-linear_coeff*u\n",
            " 0 |    -0.002 (±0.000) |     1 |    -0.000\n",
            " 1 |    -0.102 (±0.000) |    10 |    -0.100\n",
            " 2 |    -0.201 (±0.000) |    45 |    -0.200\n",
            " 3 |    -0.300 (±0.000) |   120 |    -0.300\n",
            " 4 |    -0.400 (±0.000) |   210 |    -0.400\n",
            " 5 |    -0.499 (±0.000) |   252 |    -0.500\n",
            " 6 |    -0.599 (±0.000) |   210 |    -0.600\n",
            " 7 |    -0.700 (±0.000) |   120 |    -0.700\n",
            " 8 |     0.313 (±0.046) |    45 |    -0.800\n",
            " 9 |     3.572 (±0.036) |    10 |    -0.900\n",
            "10 |       nan (±nan) |     0 |    -1.000\n",
            "\n",
            "Successive diffs mean_pred[u] - mean_pred[u+1] (should be ≥ 0):\n",
            "0.099, 0.099, 0.099, 0.099, 0.099, 0.099, 0.101, -1.013, -3.259, nan\n",
            "Monotone decreasing check: violations=2, worst_diff=-3.259 (should be ≥ -tol, tol=0.010)\n",
            "\n",
            "Estimated slope d(pred)/du ≈ 0.191 (target ≈ -0.100)\n",
            "Correlation(pred, u) ≈ 0.462 (should be strongly negative)\n",
            "MSE vs true linear (non-target anchors): 0.250151\n"
          ]
        }
      ],
      "source": [
        "# ---------------------------\n",
        "# Linear-effect diagnostic (exclude the target anchor)\n",
        "# ---------------------------\n",
        "import numpy as np  # NaN filter and the slope/correlation fit below\n",
        "\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "    # Fresh noise; anchor-averaged preds per anchor\n",
        "    x_noise_eval = rand_pm1((eval_R, d_noise), device)\n",
        "    X1 = anchors.repeat_interleave(eval_R, dim=0)    # [1024*R, 10]\n",
        "    X2 = x_noise_eval.repeat(N_indicator, 1)         # [1024*R, 500]\n",
        "    X_eval = torch.cat([X1, X2], dim=1)              # [1024*R, 510]\n",
        "    preds_avg = model(X_eval).view(N_indicator, eval_R).mean(dim=1)  # [1024]\n",
        "\n",
        "    u_all = (anchors == 1.0).sum(dim=1)             # unitation per anchor: 0..10\n",
        "    mask_non_target = torch.arange(N_indicator, device=device) != target_idx\n",
        "\n",
        "    preds_non = preds_avg[mask_non_target]\n",
        "    u_non = u_all[mask_non_target]\n",
        "\n",
        "    # Per-u stats (exclude target). A unitation level with no anchors left gets NaN placeholders.\n",
        "    means, stds, counts = [], [], []\n",
        "    for uval in range(d_indicator + 1):\n",
        "        m = (u_non == uval)\n",
        "        if m.any():\n",
        "            vals = preds_non[m]\n",
        "            means.append(vals.mean().item())\n",
        "            stds.append(vals.std(unbiased=False).item())\n",
        "            counts.append(int(m.sum().item()))\n",
        "        else:\n",
        "            means.append(float('nan')); stds.append(float('nan')); counts.append(0)\n",
        "\n",
        "    # True per-u values for the linear part (no indicator since target excluded)\n",
        "    true_per_u = [-linear_coeff * u for u in range(d_indicator + 1)]\n",
        "\n",
        "    # Monotonicity: should decrease with u (so mean[u] >= mean[u+1])\n",
        "    diffs = [means[u] - means[u+1] for u in range(d_indicator)]\n",
        "    tol = 0.10 * abs(linear_coeff) + 1e-9  # tolerance for small noise/fit jitter\n",
        "    violations = [i for i, d in enumerate(diffs) if d < -tol]\n",
        "    # min() over a list containing NaN depends on where the NaN sits, so drop NaN diffs\n",
        "    # (they come from empty unitation levels) before taking the worst case.\n",
        "    finite_diffs = [d for d in diffs if not np.isnan(d)]\n",
        "    worst_diff = min(finite_diffs) if finite_diffs else float('nan')\n",
        "\n",
        "    # Simple slope / correlation vs u (should be ~ -linear_coeff and strongly negative).\n",
        "    # This is an unweighted least-squares fit through the per-u means of the non-empty levels.\n",
        "    us = np.array([u for u in range(d_indicator + 1) if counts[u] > 0], dtype=np.float32)\n",
        "    ms = np.array([means[u] for u in range(d_indicator + 1) if counts[u] > 0], dtype=np.float32)\n",
        "    if len(us) > 1:\n",
        "        slope = float(((us - us.mean()) * (ms - ms.mean())).sum() /\n",
        "                      ((us - us.mean()) ** 2).sum())\n",
        "        corr = float(((us - us.mean()) * (ms - ms.mean())).sum() /\n",
        "                     (np.sqrt(((us - us.mean()) ** 2).sum() * ((ms - ms.mean()) ** 2).sum()) + 1e-12))\n",
        "    else:\n",
        "        slope = float('nan'); corr = float('nan')\n",
        "\n",
        "    # MSE vs true linear for non-target anchors\n",
        "    true_non = (-linear_coeff * u_non.to(torch.float32))\n",
        "    mse_non = (preds_non - true_non).pow(2).mean().item()\n",
        "\n",
        "print(\"\\n--- LINEAR EFFECT CHECK (EXCLUDING TARGET) ---\")\n",
        "print(\"u | mean_pred    (±std)   | count | true=-linear_coeff*u\")\n",
        "for u in range(d_indicator + 1):\n",
        "    print(f\"{u:2d} | {means[u]:9.3f} (±{stds[u]:.3f}) | {counts[u]:5d} | {true_per_u[u]:9.3f}\")\n",
        "\n",
        "print(\"\\nSuccessive diffs mean_pred[u] - mean_pred[u+1] (should be ≥ 0):\")\n",
        "print(\", \".join([f\"{d:.3f}\" for d in diffs]))\n",
        "print(f\"Monotone decreasing check: violations={len(violations)}, worst_diff={worst_diff:.3f} (should be ≥ -tol, tol={tol:.3f})\")\n",
        "\n",
        "print(f\"\\nEstimated slope d(pred)/du ≈ {slope:.3f} (target ≈ {-linear_coeff:.3f})\")\n",
        "print(f\"Correlation(pred, u) ≈ {corr:.3f} (should be strongly negative)\")\n",
        "print(f\"MSE vs true linear (non-target anchors): {mse_non:.6f}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "SSd3FGCuOX4m"
      },
      "outputs": [],
      "source": [
        "D = d_total"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "f0ing7qe4VH1"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import os, glob\n",
        "import tqdm\n",
        "import math\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "class GWGSampler:\n",
        "    \"\"\"Single-site Metropolis-Hastings sampler over ±1 vectors with a locally informed proposal.\n",
        "\n",
        "    The target is p(x) ∝ exp(-beta * E(x)) with E(x) = -model(x). Each step proposes\n",
        "    flipping coordinate i with probability ∝ exp(-beta * Δ_i / 2), where\n",
        "    Δ_i = E(x with bit i flipped) - E(x), and accepts with the exact MH ratio.\n",
        "\n",
        "    Args:\n",
        "        model: callable mapping a [1, D] tensor to a single value.\n",
        "        beta: inverse temperature (stored as float).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model, beta=1.0):\n",
        "        self.model = model\n",
        "        self.beta = float(beta)\n",
        "\n",
        "    def _energy(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Energy of one state x (shape [D]) as a 0-dim tensor.\"\"\"\n",
        "        # NEGATIVE sign: lower energy = higher model output\n",
        "        y = self.model(x.view(1, -1)).view(())\n",
        "        return -y\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _deltas_exact(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Exact energy change for every single-bit flip: Δ_i = E(x^i) - E(x), shape [D].\"\"\"\n",
        "        device = x.device\n",
        "        D = x.numel()\n",
        "        y = self._energy(x)  # scalar E(x)\n",
        "\n",
        "        # vectorized single-bit flips: row i of X is x with coordinate i negated\n",
        "        X = x.unsqueeze(0).repeat(D, 1)\n",
        "        idx = torch.arange(D, device=device)\n",
        "        X[idx, idx] = -X[idx, idx]\n",
        "        y_flips = torch.vmap(self._energy)(X)  # or: torch.stack([self._energy(X[i]) for i in range(D)])\n",
        "        return y_flips - y  # Δ_i = E(x^i) - E(x)\n",
        "\n",
        "\n",
        "    def _deltas_grad(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"First-order estimate of the flip deltas from one gradient of the model (needs autograd enabled).\"\"\"\n",
        "        # GWG approx: Δ_i ≈ -2 x_i ∂_i E(x) = 2 x_i ∂_i model(x)\n",
        "        x = x.detach().clone().requires_grad_(True)\n",
        "        y = self.model(x.view(1, -1)).view(())\n",
        "        (g,) = torch.autograd.grad(y, x, create_graph=False, retain_graph=False)\n",
        "        return (2.0 * x * g).detach()\n",
        "\n",
        "    #@torch.no_grad()\n",
        "    def single_step(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Run one propose/accept step from state x and return the next state (a detached tensor).\"\"\"\n",
        "        x = x.detach().clone()\n",
        "        deltas = self._deltas_exact(x)                # Δ_i\n",
        "        #deltas = self._deltas_grad(x)\n",
        "\n",
        "        # coordinate proposal p(i) ∝ exp(-β Δ_i / 2)\n",
        "        logits = -self.beta * deltas / 2.0\n",
        "        probs  = torch.softmax(logits, dim=0)\n",
        "        i = torch.multinomial(probs, 1).item()\n",
        "\n",
        "        # candidate flip\n",
        "        x_new = x.clone(); x_new[i] = -x_new[i]\n",
        "\n",
        "        # MH correction (exact reverse proposal)\n",
        "        deltas_p = self._deltas_exact(x_new)\n",
        "        #deltas_p = self._deltas_grad(x_new)\n",
        "\n",
        "        # Work in log space: exp(-β Δ_i) can overflow to inf while q_rev underflows to 0,\n",
        "        # and inf * 0 = NaN would silently reject a move that should be accepted.\n",
        "        log_q_fwd = torch.log_softmax(logits, dim=0)[i]\n",
        "        log_q_rev = torch.log_softmax(-self.beta * deltas_p / 2.0, dim=0)[i]\n",
        "        log_accept = -self.beta * deltas[i] + log_q_rev - log_q_fwd\n",
        "\n",
        "        # u < min(1, a)  <=>  log(u) < log(a), because log(u) <= 0 for u in [0, 1)\n",
        "        if torch.log(torch.rand((), device=x.device)) < log_accept:\n",
        "            return x_new.detach()\n",
        "        return x.detach()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "5-qyZqMa65dP"
      },
      "outputs": [],
      "source": [
        "def sampling_via_checkpoints(\n",
        "    checkpoint_dir: str,\n",
        "    epochs: list[int],\n",
        "    FCNNClass,\n",
        "    GWGSamplerClass,\n",
        "    num_particles: int = 200,\n",
        "    mcmc_steps: int = 15,\n",
        "    resample_thresh: float = 0.5,\n",
        "    device: str = \"cuda\",\n",
        "    beta: float = 1.0\n",
        "):\n",
        "    \"\"\"Anneal random ±1 particles through a sequence of training checkpoints.\n",
        "\n",
        "    For each epoch (ascending) the file ckpt_epoch_<epoch>.pt is loaded from\n",
        "    checkpoint_dir and every particle is advanced by mcmc_steps sampler steps\n",
        "    under that model. Reads the notebook globals d_total, hidden, d_indicator\n",
        "    and target10.\n",
        "\n",
        "    Args:\n",
        "        checkpoint_dir: folder holding the checkpoint files.\n",
        "        epochs: checkpoint epochs to visit.\n",
        "        FCNNClass: model class, constructed as FCNNClass(d_total, hidden).\n",
        "        GWGSamplerClass: sampler class, constructed as GWGSamplerClass(model, beta=beta).\n",
        "        num_particles: number of independent chains.\n",
        "        mcmc_steps: sampler steps per particle per checkpoint.\n",
        "        resample_thresh: accepted for interface compatibility; not used here.\n",
        "        device: torch device; falls back to CPU if CUDA is requested but unavailable.\n",
        "        beta: inverse temperature handed to the sampler.\n",
        "\n",
        "    Returns:\n",
        "        (particles as a numpy array [num_particles, d_total],\n",
        "         0-dim bool tensor that is True iff every particle's first d_indicator bits equal target10).\n",
        "    \"\"\"\n",
        "    # The default is \"cuda\"; degrade to CPU instead of crashing on machines without a GPU.\n",
        "    if str(device).startswith(\"cuda\") and not torch.cuda.is_available():\n",
        "        device = \"cpu\"\n",
        "\n",
        "    epochs = sorted(epochs)\n",
        "    ckpts = [os.path.join(checkpoint_dir, f\"ckpt_epoch_{e}.pt\") for e in epochs]\n",
        "\n",
        "    D = d_total\n",
        "    target = target10.to(device)  # keep the comparison on the same device as the particles\n",
        "    particles = (torch.randint(0, 2, (num_particles, D), device=device) * 2 - 1).float()\n",
        "\n",
        "    for ckpt in ckpts:\n",
        "        # load model\n",
        "        model = FCNNClass(D, hidden).to(device).eval()\n",
        "        sd = torch.load(ckpt, map_location=device)\n",
        "        model.load_state_dict(sd['model_state_dict'])\n",
        "\n",
        "        # GWG rejuvenation targeting current energy\n",
        "        sampler = GWGSamplerClass(model, beta=beta)\n",
        "        for i in range(num_particles):\n",
        "            x = particles[i]\n",
        "            for _ in range(mcmc_steps):\n",
        "                x = sampler.single_step(x)\n",
        "            particles[i] = x\n",
        "\n",
        "    # Direct equality test: the earlier `(particles - target).min() == 0` form only worked for an\n",
        "    # all-+1 target, because a mismatch against a -1 target bit gives +2 and slips past min().\n",
        "    all_hit = (particles[:, :d_indicator] == target.view(1, -1)).all()\n",
        "    return particles.cpu().numpy(), all_hit\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 217
        },
        "id": "VfxJdoLS-JDF",
        "outputId": "7bf2af20-ef6a-4e90-9954-b7ec8adc9e12"
      },
      "outputs": [],
      "source": [
        "hit_count = 0\n",
        "for _trial in range(200):\n",
        "  # Pass the notebook-level device explicitly: the function's own default is \"cuda\",\n",
        "  # which fails on CPU-only machines and can differ from the device target10 lives on.\n",
        "  particles, hit_or_not = sampling_via_checkpoints(checkpoint_dir,[25, 3000], FCNN, GWGSampler,num_particles = 1, mcmc_steps=20, device=device, beta=10.0)\n",
        "  hit_count += int(hit_or_not.item())\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "200"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "hit_count"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Hit count: 200/200\n",
            "Hit fraction: 1.0000\n",
            "2 SD CI: [1.0000, 1.0000]  (SE ≈ 0.0000)\n"
          ]
        }
      ],
      "source": [
        "import math\n",
        "\n",
        "n_trials = 200  # must match the number of iterations of the trial loop above\n",
        "p = hit_count / float(n_trials)  # hit fraction\n",
        "se = math.sqrt(p * (1.0 - p) / n_trials) if n_trials > 0 else float('nan')\n",
        "\n",
        "# Normal-approximation (Wald) interval: p ± 2*SE, clipped to [0, 1].\n",
        "lo = max(0.0, p - 2 * se)\n",
        "hi = min(1.0, p + 2 * se)\n",
        "\n",
        "# The Wald interval has zero width when p is exactly 0 or 1, so also report the\n",
        "# Wilson score interval with the same z = 2, which stays informative at the boundary.\n",
        "z = 2.0\n",
        "wilson_denom = 1.0 + z * z / n_trials\n",
        "wilson_center = (p + z * z / (2.0 * n_trials)) / wilson_denom\n",
        "wilson_half = z * math.sqrt(p * (1.0 - p) / n_trials + z * z / (4.0 * n_trials ** 2)) / wilson_denom\n",
        "wilson_lo = max(0.0, wilson_center - wilson_half)\n",
        "wilson_hi = min(1.0, wilson_center + wilson_half)\n",
        "\n",
        "print(f\"Hit count: {hit_count}/{n_trials}\")\n",
        "print(f\"Hit fraction: {p:.4f}\")\n",
        "print(f\"2 SD CI: [{lo:.4f}, {hi:.4f}]  (SE ≈ {se:.4f})\")\n",
        "print(f\"Wilson interval (z=2): [{wilson_lo:.4f}, {wilson_hi:.4f}]\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "F8C_eM1iU-KV"
      },
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def first_hit_steps(sampler, target10, d_indicator, d_total, device, max_steps=10000):\n",
        "    \"\"\"\n",
        "    Count sampler steps until a random start first matches the target pattern.\n",
        "\n",
        "    A random ±1 vector of length d_total is drawn, then sampler.single_step is applied\n",
        "    repeatedly until the leading d_indicator entries equal target10.\n",
        "    Returns 0 if the start already matches, the step index of the first match otherwise,\n",
        "    and None when no match occurs within max_steps.\n",
        "    \"\"\"\n",
        "    def on_target(state):\n",
        "        # True when every indicator bit agrees with the target pattern\n",
        "        return bool((state[:d_indicator] == target10).all())\n",
        "\n",
        "    x = rand_pm1((d_total,), device).to(torch.float32)\n",
        "    if on_target(x):\n",
        "        return 0\n",
        "\n",
        "    steps_taken = 0\n",
        "    while steps_taken < max_steps:\n",
        "        # single_step is expected to hand back a ±1 state; if a sampler returned\n",
        "        # logits or probabilities instead, x would have to be re-binarised here.\n",
        "        x = sampler.single_step(x)\n",
        "        steps_taken += 1\n",
        "        if on_target(x):\n",
        "            return steps_taken\n",
        "    return None"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "PZpQUR6DU6XX"
      },
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def run_gwg_trials(model, target10, d_indicator, d_total, device,\n",
        "                   n_trials=200, max_steps=10000, beta=1.0, verbose_every=50,\n",
        "                   bootstrap_B=2000, rng_seed=0):\n",
        "    \"\"\"\n",
        "    Run GWG first-hit experiments and report statistics with 2 SD confidence intervals.\n",
        "\n",
        "    - Unsuccessful trials are counted as max_steps for 'ALL trials' aggregates.\n",
        "    - Medians use bootstrap to estimate the standard error, then ± 2*SE for the CI.\n",
        "    \"\"\"\n",
        "    import numpy as np\n",
        "    import math\n",
        "    model.eval()\n",
        "    sampler = GWGSampler(model, beta=beta)\n",
        "\n",
        "    # ---------------------------\n",
        "    # Helpers\n",
        "    # ---------------------------\n",
        "    def ci_mean_2sd(x):\n",
        "        \"\"\"Mean ± 2 SD/√n CI.\"\"\"\n",
        "        x = np.asarray(x, dtype=np.float64)\n",
        "        n = len(x)\n",
        "        mu = float(x.mean())\n",
        "        sd = float(x.std(ddof=1)) if n > 1 else 0.0\n",
        "        se = sd / math.sqrt(n) if n > 0 else float(\"nan\")\n",
        "        return mu, sd, (mu - 2*se, mu + 2*se)\n",
        "\n",
        "    def ci_prop_2sd(k, n):\n",
        "        \"\"\"Proportion ± 2 SD (binomial SE).\"\"\"\n",
        "        p = (k / n) if n > 0 else float(\"nan\")\n",
        "        se = math.sqrt(p * (1 - p) / n) if n > 0 else float(\"nan\")\n",
        "        lo, hi = max(0.0, p - 2*se), min(1.0, p + 2*se)\n",
        "        return p, se, (lo, hi)\n",
        "\n",
        "    def ci_median_bootstrap_2sd(x, B=2000, seed=0):\n",
        "        \"\"\"Median and ± 2 SD bootstrap CI.\"\"\"\n",
        "        x = np.asarray(x, dtype=np.float64)\n",
        "        n = len(x)\n",
        "        if n == 0:\n",
        "            return float(\"nan\"), float(\"nan\"), (float(\"nan\"), float(\"nan\"))\n",
        "        if n == 1:\n",
        "            med = float(x[0])\n",
        "            return med, 0.0, (med, med)\n",
        "        rng = np.random.default_rng(seed)\n",
        "        med = float(np.median(x))\n",
        "        meds = np.empty(B, dtype=np.float64)\n",
        "        idx = np.arange(n)\n",
        "        for b in range(B):\n",
        "            resample = x[rng.choice(idx, size=n, replace=True)]\n",
        "            meds[b] = np.median(resample)\n",
        "        sd = float(meds.std(ddof=1))\n",
        "        return med, sd, (med - 2*sd, med + 2*sd)\n",
        "\n",
        "    def robust_mad(x):\n",
        "        \"\"\"Median Absolute Deviation (MAD).\"\"\"\n",
        "        x = np.asarray(x, dtype=np.float64)\n",
        "        if len(x) == 0:\n",
        "            return float(\"nan\")\n",
        "        med = np.median(x)\n",
        "        return float(np.median(np.abs(x - med)))\n",
        "\n",
        "    # ---------------------------\n",
        "    # Trials\n",
        "    # ---------------------------\n",
        "    hits_only = []    # steps for successful trials\n",
        "    all_steps = []    # steps for all trials (misses counted as max_steps)\n",
        "    misses = 0\n",
        "\n",
        "    for i in range(1, n_trials + 1):\n",
        "        steps = first_hit_steps(sampler, target10, d_indicator, d_total, device, max_steps)\n",
        "        if steps is None:\n",
        "            misses += 1\n",
        "            all_steps.append(max_steps)\n",
        "            last_str = \"miss\"\n",
        "        else:\n",
        "            s = int(steps)\n",
        "            hits_only.append(s)\n",
        "            all_steps.append(s)\n",
        "            last_str = str(s)\n",
        "\n",
        "        if verbose_every and (i % verbose_every == 0 or i == n_trials):\n",
        "            hit_rate = (i - misses) / i\n",
        "            print(f\"[trial {i:4d}] last={last_str} | hits={i - misses} | misses={misses} | hit_rate={hit_rate:.3f}\")\n",
        "\n",
        "    # Convert to numpy\n",
        "    arr_all  = np.array(all_steps, dtype=np.int64)\n",
        "    arr_hits = np.array(hits_only, dtype=np.int64)\n",
        "\n",
        "    # ---------------------------\n",
        "    # Core statistics with 2 SD CIs\n",
        "    # ---------------------------\n",
        "    print(\"\\n=== GWG First-Hit Statistics with 2 SD Confidence Intervals ===\")\n",
        "    print(f\"trials={n_trials} | hits={n_trials - misses} | misses={misses} | miss_penalty=max_steps({max_steps})\")\n",
        "\n",
        "    # Hit rate CI (binomial)\n",
        "    p, p_se, (p_lo, p_hi) = ci_prop_2sd(n_trials - misses, n_trials)\n",
        "    print(f\"Hit rate              : {p:.4f}  (±2SD CI: [{p_lo:.4f}, {p_hi:.4f}])  | SE≈{p_se:.4f}\")\n",
        "\n",
        "    # Mean (ALL trials)\n",
        "    mean_all, sd_all, (lo_all, hi_all) = ci_mean_2sd(arr_all)\n",
        "    print(f\"Mean steps (ALL)      : {mean_all:.2f}  (±2SD CI: [{lo_all:.2f}, {hi_all:.2f}])  | SD={sd_all:.2f}\")\n",
        "\n",
        "    # Median (ALL trials, misses=max_steps)\n",
        "    med_all, med_all_se_boot, (med_all_lo, med_all_hi) = ci_median_bootstrap_2sd(\n",
        "        arr_all, B=bootstrap_B, seed=rng_seed\n",
        "    )\n",
        "    print(f\"Median steps (ALL)    : {med_all:.2f}  (±2SD boot CI: [{med_all_lo:.2f}, {med_all_hi:.2f}])  | boot SD≈{med_all_se_boot:.2f}\")\n",
        "\n",
        "    # Median (HITS only)\n",
        "    if len(arr_hits) > 0:\n",
        "        med_hits, med_hits_se_boot, (med_hits_lo, med_hits_hi) = ci_median_bootstrap_2sd(\n",
        "            arr_hits, B=bootstrap_B, seed=rng_seed + 1\n",
        "        )\n",
        "        print(f\"Median steps (HITS)   : {med_hits:.2f}  (±2SD boot CI: [{med_hits_lo:.2f}, {med_hits_hi:.2f}])  | boot SD≈{med_hits_se_boot:.2f}\")\n",
        "    else:\n",
        "        print(\"Median steps (HITS)   : n/a (no successful trials)\")\n",
        "\n",
        "    # ---------------------------\n",
        "    # Additional useful stats\n",
        "    # ---------------------------\n",
        "    def q(arr, p): return float(np.percentile(arr, p)) if len(arr) else float(\"nan\")\n",
        "\n",
        "    if len(arr_all):\n",
        "        print(\"\\n-- Distribution (ALL trials) --\")\n",
        "        print(f\"min / p25 / p50 / p75 / max : {arr_all.min():.0f} / {q(arr_all,25):.0f} / {q(arr_all,50):.0f} / {q(arr_all,75):.0f} / {arr_all.max():.0f}\")\n",
        "        print(f\"IQR (p75-p25)         : {q(arr_all,75) - q(arr_all,25):.2f}\")\n",
        "        print(f\"MAD (about median)    : {robust_mad(arr_all):.2f}\")\n",
        "        for pct in (90, 95, 99):\n",
        "            print(f\"p{pct:02d}                 : {q(arr_all, pct):.2f}\")\n",
        "\n",
        "        # Probability of hitting within selected step budgets\n",
        "        budgets = sorted(set([100, 500, 1000, 5000, max_steps]))\n",
        "        probs_within = []\n",
        "        for T in budgets:\n",
        "            probs_within.append((T, float((arr_all <= T).mean())))\n",
        "        print(\"\\nHit probability within budgets (ALL trials):\")\n",
        "        for T, pr in probs_within:\n",
        "            print(f\"  ≤ {T:6d} steps : {pr:.4f}\")\n",
        "\n",
        "    if len(arr_hits):\n",
        "        print(\"\\n-- Distribution (successful trials ONLY) --\")\n",
        "        print(f\"min / p25 / p50 / p75 / max : {arr_hits.min():.0f} / {q(arr_hits,25):.0f} / {q(arr_hits,50):.0f} / {q(arr_hits,75):.0f} / {arr_hits.max():.0f}\")\n",
        "        print(f\"IQR (p75-p25)         : {q(arr_hits,75) - q(arr_hits,25):.2f}\")\n",
        "        print(f\"MAD (about median)    : {robust_mad(arr_hits):.2f}\")\n",
        "        for pct in (90, 95, 99):\n",
        "            print(f\"p{pct:02d}                 : {q(arr_hits, pct):.2f}\")\n",
        "\n",
        "    # Optional compact histogram (ALL trials). Shows censoring spike at max_steps if many misses.\n",
        "    try:\n",
        "        import collections\n",
        "        hist = collections.Counter(arr_all.tolist())\n",
        "        most_common = sorted(hist.items(), key=lambda kv: (-kv[1], kv[0]))[:20]\n",
        "        print(\"\\nTop (step,count) bins (ALL trials, 20 most common):\")\n",
        "        for step, cnt in most_common:\n",
        "            print(f\"  {step:7d} : {cnt}\")\n",
        "    except Exception:\n",
        "        pass\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-Rd2cBBmRjeF",
        "outputId": "7574a178-e8e4-4dcb-c4f4-102d71eec07e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trial    5] last=miss | hits=0 | misses=5 | hit_rate=0.000\n",
            "[trial   10] last=miss | hits=0 | misses=10 | hit_rate=0.000\n",
            "[trial   15] last=miss | hits=1 | misses=14 | hit_rate=0.067\n",
            "[trial   20] last=miss | hits=1 | misses=19 | hit_rate=0.050\n",
            "[trial   25] last=miss | hits=2 | misses=23 | hit_rate=0.080\n",
            "[trial   30] last=miss | hits=4 | misses=26 | hit_rate=0.133\n",
            "[trial   35] last=miss | hits=5 | misses=30 | hit_rate=0.143\n",
            "[trial   40] last=miss | hits=5 | misses=35 | hit_rate=0.125\n",
            "[trial   45] last=miss | hits=5 | misses=40 | hit_rate=0.111\n",
            "[trial   50] last=miss | hits=6 | misses=44 | hit_rate=0.120\n",
            "[trial   55] last=miss | hits=7 | misses=48 | hit_rate=0.127\n",
            "[trial   60] last=miss | hits=7 | misses=53 | hit_rate=0.117\n",
            "[trial   65] last=miss | hits=8 | misses=57 | hit_rate=0.123\n",
            "[trial   70] last=miss | hits=8 | misses=62 | hit_rate=0.114\n",
            "[trial   75] last=2 | hits=9 | misses=66 | hit_rate=0.120\n",
            "[trial   80] last=miss | hits=9 | misses=71 | hit_rate=0.113\n",
            "[trial   85] last=miss | hits=9 | misses=76 | hit_rate=0.106\n",
            "[trial   90] last=miss | hits=10 | misses=80 | hit_rate=0.111\n",
            "[trial   95] last=miss | hits=10 | misses=85 | hit_rate=0.105\n",
            "[trial  100] last=miss | hits=11 | misses=89 | hit_rate=0.110\n",
            "[trial  105] last=miss | hits=11 | misses=94 | hit_rate=0.105\n",
            "[trial  110] last=miss | hits=11 | misses=99 | hit_rate=0.100\n",
            "[trial  115] last=miss | hits=12 | misses=103 | hit_rate=0.104\n",
            "[trial  120] last=miss | hits=13 | misses=107 | hit_rate=0.108\n",
            "[trial  125] last=miss | hits=13 | misses=112 | hit_rate=0.104\n",
            "[trial  130] last=miss | hits=13 | misses=117 | hit_rate=0.100\n",
            "[trial  135] last=miss | hits=13 | misses=122 | hit_rate=0.096\n",
            "[trial  140] last=miss | hits=13 | misses=127 | hit_rate=0.093\n",
            "[trial  145] last=miss | hits=13 | misses=132 | hit_rate=0.090\n",
            "[trial  150] last=miss | hits=13 | misses=137 | hit_rate=0.087\n",
            "[trial  155] last=miss | hits=13 | misses=142 | hit_rate=0.084\n",
            "[trial  160] last=2 | hits=14 | misses=146 | hit_rate=0.087\n",
            "[trial  165] last=miss | hits=14 | misses=151 | hit_rate=0.085\n",
            "[trial  170] last=miss | hits=15 | misses=155 | hit_rate=0.088\n",
            "[trial  175] last=miss | hits=16 | misses=159 | hit_rate=0.091\n",
            "[trial  180] last=miss | hits=16 | misses=164 | hit_rate=0.089\n",
            "[trial  185] last=miss | hits=16 | misses=169 | hit_rate=0.086\n",
            "[trial  190] last=miss | hits=16 | misses=174 | hit_rate=0.084\n",
            "[trial  195] last=miss | hits=16 | misses=179 | hit_rate=0.082\n",
            "[trial  200] last=miss | hits=16 | misses=184 | hit_rate=0.080\n",
            "\n",
            "=== GWG First-Hit Statistics with 2 SD Confidence Intervals ===\n",
            "trials=200 | hits=16 | misses=184 | miss_penalty=max_steps(2000)\n",
            "Hit rate              : 0.0800  (±2SD CI: [0.0416, 0.1184])  | SE≈0.0192\n",
            "Mean steps (ALL)      : 1844.30  (±2SD CI: [1769.34, 1919.25])  | SD=529.99\n",
            "Median steps (ALL)    : 2000.00  (±2SD boot CI: [2000.00, 2000.00])  | boot SD≈0.00\n",
            "Median steps (HITS)   : 2.00  (±2SD boot CI: [-35.28, 39.28])  | boot SD≈18.64\n",
            "\n",
            "-- Distribution (ALL trials) --\n",
            "min / p25 / p50 / p75 / max : 1 / 2000 / 2000 / 2000 / 2000\n",
            "IQR (p75-p25)         : 0.00\n",
            "MAD (about median)    : 0.00\n",
            "p90                 : 2000.00\n",
            "p95                 : 2000.00\n",
            "p99                 : 2000.00\n",
            "\n",
            "Hit probability within budgets (ALL trials):\n",
            "  ≤    100 steps : 0.0600\n",
            "  ≤    500 steps : 0.0800\n",
            "  ≤   1000 steps : 0.0800\n",
            "  ≤   2000 steps : 1.0000\n",
            "  ≤   5000 steps : 1.0000\n",
            "\n",
            "-- Distribution (successful trials ONLY) --\n",
            "min / p25 / p50 / p75 / max : 1 / 2 / 2 / 49 / 292\n",
            "IQR (p75-p25)         : 47.00\n",
            "MAD (about median)    : 0.50\n",
            "p90                 : 190.00\n",
            "p95                 : 226.75\n",
            "p99                 : 278.95\n",
            "\n",
            "Top (step,count) bins (ALL trials, 20 most common):\n",
            "     2000 : 184\n",
            "        2 : 8\n",
            "        1 : 3\n",
            "       14 : 1\n",
            "      154 : 1\n",
            "      175 : 1\n",
            "      205 : 1\n",
            "      292 : 1\n"
          ]
        }
      ],
      "source": [
        "run_gwg_trials(\n",
        "    model=model,\n",
        "    target10=target10,\n",
        "    d_indicator=d_indicator,\n",
        "    d_total=d_total,      # = 10 + 500 in the current script\n",
        "    device=torch.device(\"cuda\"),\n",
        "    n_trials=200,\n",
        "    max_steps=2000,\n",
        "    beta=10.0,\n",
        "    verbose_every=5\n",
        ")\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "torch-gpu-env",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.13.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
