{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nJQENWsLNOKc",
        "outputId": "ac189548-9a41-45c0-96c7-efd446653953"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "target10: [ 1.  1.  1. -1.  1. -1.  1.  1. -1.  1.]\n",
            "Train size: 67584 | base reps/anchor: 64 | target extra reps: 2048\n",
            "[ep     1] loss=5.539004e+01 | mean(pred on target rows)≈-0.463 | max|pred| on non-target rows≈3.056\n",
            "[ep    50] loss=9.467255e-01 | mean(pred on target rows)≈9.967 | max|pred| on non-target rows≈4.584\n",
            "[ep   100] loss=9.505625e-02 | mean(pred on target rows)≈9.862 | max|pred| on non-target rows≈1.225\n",
            "[ep   150] loss=2.531620e-02 | mean(pred on target rows)≈10.110 | max|pred| on non-target rows≈0.748\n",
            "[ep   200] loss=8.838945e-03 | mean(pred on target rows)≈10.018 | max|pred| on non-target rows≈0.561\n",
            "[ep   250] loss=9.103661e-03 | mean(pred on target rows)≈9.904 | max|pred| on non-target rows≈0.365\n",
            "[ep   300] loss=4.797988e-03 | mean(pred on target rows)≈10.055 | max|pred| on non-target rows≈0.319\n",
            "[ep   350] loss=2.836674e-03 | mean(pred on target rows)≈9.961 | max|pred| on non-target rows≈0.266\n",
            "[ep   400] loss=2.698639e-03 | mean(pred on target rows)≈9.955 | max|pred| on non-target rows≈0.232\n",
            "[ep   450] loss=2.465950e-03 | mean(pred on target rows)≈10.047 | max|pred| on non-target rows≈0.228\n",
            "[ep   500] loss=1.389286e-03 | mean(pred on target rows)≈9.995 | max|pred| on non-target rows≈0.201\n",
            "[ep   550] loss=4.142287e-03 | mean(pred on target rows)≈9.997 | max|pred| on non-target rows≈0.238\n",
            "[ep   600] loss=1.396496e-03 | mean(pred on target rows)≈10.025 | max|pred| on non-target rows≈0.174\n",
            "[ep   650] loss=1.639965e-03 | mean(pred on target rows)≈10.000 | max|pred| on non-target rows≈0.173\n",
            "[ep   700] loss=1.400023e-03 | mean(pred on target rows)≈10.021 | max|pred| on non-target rows≈0.175\n",
            "[ep   750] loss=3.458595e-03 | mean(pred on target rows)≈9.929 | max|pred| on non-target rows≈0.139\n",
            "[ep   800] loss=2.899001e-03 | mean(pred on target rows)≈10.004 | max|pred| on non-target rows≈0.187\n",
            "[ep   850] loss=2.907818e-03 | mean(pred on target rows)≈10.053 | max|pred| on non-target rows≈0.183\n",
            "[ep   900] loss=2.101236e-03 | mean(pred on target rows)≈9.994 | max|pred| on non-target rows≈0.202\n",
            "[ep   950] loss=3.473091e-03 | mean(pred on target rows)≈9.963 | max|pred| on non-target rows≈0.173\n",
            "[ep  1000] loss=1.734573e-03 | mean(pred on target rows)≈9.946 | max|pred| on non-target rows≈0.126\n",
            "[ep  1050] loss=1.306462e-03 | mean(pred on target rows)≈9.958 | max|pred| on non-target rows≈0.119\n",
            "[ep  1100] loss=1.577294e-02 | mean(pred on target rows)≈10.129 | max|pred| on non-target rows≈0.387\n",
            "[ep  1150] loss=1.896417e-02 | mean(pred on target rows)≈9.879 | max|pred| on non-target rows≈0.353\n",
            "[ep  1200] loss=3.934354e-02 | mean(pred on target rows)≈10.033 | max|pred| on non-target rows≈0.605\n",
            "[ep  1250] loss=6.625613e-04 | mean(pred on target rows)≈10.001 | max|pred| on non-target rows≈0.151\n",
            "[ep  1300] loss=1.825305e-04 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.118\n",
            "[ep  1350] loss=1.564335e-04 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.114\n",
            "[ep  1400] loss=1.516020e-03 | mean(pred on target rows)≈9.948 | max|pred| on non-target rows≈0.100\n",
            "[ep  1450] loss=1.995447e-03 | mean(pred on target rows)≈10.059 | max|pred| on non-target rows≈0.156\n",
            "[ep  1500] loss=6.121180e-03 | mean(pred on target rows)≈10.025 | max|pred| on non-target rows≈0.215\n",
            "[ep  1550] loss=4.941496e-04 | mean(pred on target rows)≈10.021 | max|pred| on non-target rows≈0.139\n",
            "[ep  1600] loss=9.549431e-04 | mean(pred on target rows)≈9.962 | max|pred| on non-target rows≈0.109\n",
            "[ep  1650] loss=7.254765e-04 | mean(pred on target rows)≈10.017 | max|pred| on non-target rows≈0.189\n",
            "[ep  1700] loss=3.274898e-03 | mean(pred on target rows)≈9.934 | max|pred| on non-target rows≈0.108\n",
            "[ep  1750] loss=1.102248e-03 | mean(pred on target rows)≈10.029 | max|pred| on non-target rows≈0.135\n",
            "[ep  1800] loss=1.387952e-03 | mean(pred on target rows)≈9.957 | max|pred| on non-target rows≈0.103\n",
            "[ep  1850] loss=3.182516e-03 | mean(pred on target rows)≈9.960 | max|pred| on non-target rows≈0.183\n",
            "[ep  1900] loss=2.053801e-03 | mean(pred on target rows)≈10.059 | max|pred| on non-target rows≈0.157\n",
            "[ep  1950] loss=1.884722e-03 | mean(pred on target rows)≈9.949 | max|pred| on non-target rows≈0.105\n",
            "[ep  2000] loss=8.547355e-04 | mean(pred on target rows)≈9.990 | max|pred| on non-target rows≈0.125\n",
            "[ep  2050] loss=2.611101e-03 | mean(pred on target rows)≈10.010 | max|pred| on non-target rows≈0.208\n",
            "[ep  2100] loss=7.486991e-04 | mean(pred on target rows)≈9.965 | max|pred| on non-target rows≈0.088\n",
            "[ep  2150] loss=2.438562e-03 | mean(pred on target rows)≈9.938 | max|pred| on non-target rows≈0.098\n",
            "[ep  2200] loss=1.100001e-03 | mean(pred on target rows)≈10.029 | max|pred| on non-target rows≈0.134\n",
            "[ep  2250] loss=2.518595e-03 | mean(pred on target rows)≈9.947 | max|pred| on non-target rows≈0.099\n",
            "[ep  2300] loss=1.877477e-03 | mean(pred on target rows)≈10.050 | max|pred| on non-target rows≈0.161\n",
            "[ep  2350] loss=3.289000e-03 | mean(pred on target rows)≈9.958 | max|pred| on non-target rows≈0.131\n",
            "[ep  2400] loss=1.461622e-03 | mean(pred on target rows)≈10.041 | max|pred| on non-target rows≈0.135\n",
            "[ep  2450] loss=1.779854e-03 | mean(pred on target rows)≈10.051 | max|pred| on non-target rows≈0.169\n",
            "[ep  2500] loss=1.025761e-03 | mean(pred on target rows)≈9.959 | max|pred| on non-target rows≈0.090\n",
            "[ep  2550] loss=7.424811e-04 | mean(pred on target rows)≈9.971 | max|pred| on non-target rows≈0.077\n",
            "[ep  2600] loss=1.993626e-03 | mean(pred on target rows)≈10.059 | max|pred| on non-target rows≈0.149\n",
            "[ep  2650] loss=1.516035e-03 | mean(pred on target rows)≈9.949 | max|pred| on non-target rows≈0.077\n",
            "[ep  2700] loss=8.372334e-04 | mean(pred on target rows)≈10.038 | max|pred| on non-target rows≈0.130\n",
            "[ep  2750] loss=1.006679e-03 | mean(pred on target rows)≈9.967 | max|pred| on non-target rows≈0.086\n",
            "[ep  2800] loss=3.076141e-03 | mean(pred on target rows)≈10.053 | max|pred| on non-target rows≈0.225\n",
            "[ep  2850] loss=2.475203e-04 | mean(pred on target rows)≈9.994 | max|pred| on non-target rows≈0.097\n",
            "[ep  2900] loss=7.001314e-03 | mean(pred on target rows)≈10.067 | max|pred| on non-target rows≈0.277\n",
            "[ep  2950] loss=1.249531e-03 | mean(pred on target rows)≈9.991 | max|pred| on non-target rows≈0.126\n",
            "[ep  3000] loss=8.133631e-04 | mean(pred on target rows)≈10.000 | max|pred| on non-target rows≈0.121\n",
            "[ep  3050] loss=2.028641e-03 | mean(pred on target rows)≈10.006 | max|pred| on non-target rows≈0.161\n",
            "[ep  3100] loss=4.112442e-03 | mean(pred on target rows)≈9.919 | max|pred| on non-target rows≈0.077\n",
            "[ep  3150] loss=2.267498e-03 | mean(pred on target rows)≈10.058 | max|pred| on non-target rows≈0.164\n",
            "[ep  3200] loss=3.236069e-03 | mean(pred on target rows)≈9.931 | max|pred| on non-target rows≈0.072\n",
            "[ep  3250] loss=2.152522e-03 | mean(pred on target rows)≈10.059 | max|pred| on non-target rows≈0.144\n",
            "[ep  3300] loss=2.229802e-03 | mean(pred on target rows)≈9.958 | max|pred| on non-target rows≈0.097\n",
            "[ep  3350] loss=1.689473e-03 | mean(pred on target rows)≈9.956 | max|pred| on non-target rows≈0.090\n",
            "[ep  3400] loss=2.575221e-03 | mean(pred on target rows)≈9.957 | max|pred| on non-target rows≈0.123\n",
            "[ep  3450] loss=2.771057e-03 | mean(pred on target rows)≈9.959 | max|pred| on non-target rows≈0.109\n",
            "[ep  3500] loss=3.631147e-03 | mean(pred on target rows)≈9.937 | max|pred| on non-target rows≈0.114\n",
            "[ep  3550] loss=3.361436e-03 | mean(pred on target rows)≈10.063 | max|pred| on non-target rows≈0.177\n",
            "[ep  3600] loss=3.718732e-03 | mean(pred on target rows)≈9.926 | max|pred| on non-target rows≈0.070\n",
            "[ep  3650] loss=1.867793e-03 | mean(pred on target rows)≈10.052 | max|pred| on non-target rows≈0.158\n",
            "[ep  3700] loss=2.338722e-03 | mean(pred on target rows)≈9.954 | max|pred| on non-target rows≈0.098\n",
            "[ep  3750] loss=5.480850e-04 | mean(pred on target rows)≈9.975 | max|pred| on non-target rows≈0.070\n",
            "[ep  3800] loss=7.378323e-03 | mean(pred on target rows)≈9.985 | max|pred| on non-target rows≈0.207\n",
            "[ep  3850] loss=3.739742e-04 | mean(pred on target rows)≈10.020 | max|pred| on non-target rows≈0.105\n",
            "[ep  3900] loss=3.265237e-03 | mean(pred on target rows)≈10.054 | max|pred| on non-target rows≈0.190\n",
            "[ep  3950] loss=6.693765e-03 | mean(pred on target rows)≈10.051 | max|pred| on non-target rows≈0.289\n",
            "[ep  4000] loss=3.178984e-03 | mean(pred on target rows)≈9.922 | max|pred| on non-target rows≈0.047\n",
            "[ep  4050] loss=8.888513e-04 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.156\n",
            "[ep  4100] loss=1.335417e-03 | mean(pred on target rows)≈9.950 | max|pred| on non-target rows≈0.049\n",
            "[ep  4150] loss=2.403568e-03 | mean(pred on target rows)≈10.060 | max|pred| on non-target rows≈0.170\n",
            "[ep  4200] loss=1.776189e-03 | mean(pred on target rows)≈10.025 | max|pred| on non-target rows≈0.160\n",
            "[ep  4250] loss=1.099059e-03 | mean(pred on target rows)≈9.958 | max|pred| on non-target rows≈0.070\n",
            "[ep  4300] loss=2.107138e-03 | mean(pred on target rows)≈10.050 | max|pred| on non-target rows≈0.185\n",
            "[ep  4350] loss=2.659930e-03 | mean(pred on target rows)≈10.068 | max|pred| on non-target rows≈0.168\n",
            "[ep  4400] loss=6.442905e-04 | mean(pred on target rows)≈10.018 | max|pred| on non-target rows≈0.124\n",
            "[ep  4450] loss=2.789184e-04 | mean(pred on target rows)≈9.996 | max|pred| on non-target rows≈0.101\n",
            "[ep  4500] loss=2.378703e-04 | mean(pred on target rows)≈10.007 | max|pred| on non-target rows≈0.107\n",
            "[ep  4550] loss=1.230386e-03 | mean(pred on target rows)≈10.048 | max|pred| on non-target rows≈0.113\n",
            "[ep  4600] loss=5.364797e-03 | mean(pred on target rows)≈10.067 | max|pred| on non-target rows≈0.262\n",
            "[ep  4650] loss=3.040079e-03 | mean(pred on target rows)≈10.071 | max|pred| on non-target rows≈0.184\n",
            "[ep  4700] loss=7.772403e-03 | mean(pred on target rows)≈10.060 | max|pred| on non-target rows≈0.321\n",
            "[ep  4750] loss=2.726018e-03 | mean(pred on target rows)≈9.928 | max|pred| on non-target rows≈0.031\n",
            "[ep  4800] loss=3.050992e-03 | mean(pred on target rows)≈9.967 | max|pred| on non-target rows≈0.166\n",
            "[ep  4850] loss=2.825100e-03 | mean(pred on target rows)≈9.930 | max|pred| on non-target rows≈0.042\n",
            "[ep  4900] loss=2.631697e-03 | mean(pred on target rows)≈10.056 | max|pred| on non-target rows≈0.207\n",
            "[ep  4950] loss=2.400655e-03 | mean(pred on target rows)≈9.933 | max|pred| on non-target rows≈0.037\n",
            "[ep  5000] loss=4.194923e-04 | mean(pred on target rows)≈10.021 | max|pred| on non-target rows≈0.092\n",
            "[ep  5050] loss=6.370968e-04 | mean(pred on target rows)≈10.021 | max|pred| on non-target rows≈0.109\n",
            "[ep  5100] loss=2.471251e-03 | mean(pred on target rows)≈10.047 | max|pred| on non-target rows≈0.170\n",
            "[ep  5150] loss=7.203747e-04 | mean(pred on target rows)≈9.975 | max|pred| on non-target rows≈0.056\n",
            "[ep  5200] loss=3.604016e-04 | mean(pred on target rows)≈9.979 | max|pred| on non-target rows≈0.056\n",
            "[ep  5250] loss=6.720969e-03 | mean(pred on target rows)≈10.010 | max|pred| on non-target rows≈0.312\n",
            "[ep  5300] loss=1.053112e-03 | mean(pred on target rows)≈10.040 | max|pred| on non-target rows≈0.138\n",
            "[ep  5350] loss=4.184541e-05 | mean(pred on target rows)≈9.996 | max|pred| on non-target rows≈0.065\n",
            "[ep  5400] loss=5.382512e-02 | mean(pred on target rows)≈10.245 | max|pred| on non-target rows≈0.968\n",
            "[ep  5450] loss=1.075905e-02 | mean(pred on target rows)≈10.117 | max|pred| on non-target rows≈0.456\n",
            "[ep  5500] loss=1.260886e-02 | mean(pred on target rows)≈9.843 | max|pred| on non-target rows≈0.041\n",
            "[ep  5550] loss=3.222115e-03 | mean(pred on target rows)≈10.076 | max|pred| on non-target rows≈0.186\n",
            "[ep  5600] loss=1.855604e-03 | mean(pred on target rows)≈9.942 | max|pred| on non-target rows≈0.042\n",
            "[ep  5650] loss=4.211563e-02 | mean(pred on target rows)≈10.032 | max|pred| on non-target rows≈1.220\n",
            "[ep  5700] loss=3.199679e-04 | mean(pred on target rows)≈9.998 | max|pred| on non-target rows≈0.106\n",
            "[ep  5750] loss=1.448878e-04 | mean(pred on target rows)≈10.014 | max|pred| on non-target rows≈0.078\n",
            "[ep  5800] loss=2.676699e-05 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.058\n",
            "[ep  5850] loss=5.285436e-05 | mean(pred on target rows)≈10.007 | max|pred| on non-target rows≈0.065\n",
            "[ep  5900] loss=4.186278e-05 | mean(pred on target rows)≈10.006 | max|pred| on non-target rows≈0.063\n",
            "[ep  5950] loss=1.411564e-04 | mean(pred on target rows)≈10.014 | max|pred| on non-target rows≈0.079\n",
            "[ep  6000] loss=2.920524e-04 | mean(pred on target rows)≈9.977 | max|pred| on non-target rows≈0.035\n",
            "[ep  6050] loss=3.637921e-04 | mean(pred on target rows)≈9.974 | max|pred| on non-target rows≈0.034\n",
            "[ep  6100] loss=2.009757e-03 | mean(pred on target rows)≈9.953 | max|pred| on non-target rows≈0.069\n",
            "[ep  6150] loss=5.216819e-04 | mean(pred on target rows)≈10.014 | max|pred| on non-target rows≈0.109\n",
            "[ep  6200] loss=3.512535e-03 | mean(pred on target rows)≈9.953 | max|pred| on non-target rows≈0.076\n",
            "[ep  6250] loss=9.471303e-04 | mean(pred on target rows)≈10.035 | max|pred| on non-target rows≈0.112\n",
            "[ep  6300] loss=1.362074e-03 | mean(pred on target rows)≈10.016 | max|pred| on non-target rows≈0.146\n",
            "[ep  6350] loss=2.226099e-03 | mean(pred on target rows)≈9.945 | max|pred| on non-target rows≈0.041\n",
            "[ep  6400] loss=2.026201e-03 | mean(pred on target rows)≈10.053 | max|pred| on non-target rows≈0.148\n",
            "[ep  6450] loss=7.793593e-04 | mean(pred on target rows)≈10.001 | max|pred| on non-target rows≈0.117\n",
            "[ep  6500] loss=1.207794e-03 | mean(pred on target rows)≈10.011 | max|pred| on non-target rows≈0.111\n",
            "[ep  6550] loss=2.998270e-03 | mean(pred on target rows)≈9.928 | max|pred| on non-target rows≈0.037\n",
            "[ep  6600] loss=4.547710e-03 | mean(pred on target rows)≈9.995 | max|pred| on non-target rows≈0.223\n",
            "[ep  6650] loss=6.399424e-04 | mean(pred on target rows)≈10.020 | max|pred| on non-target rows≈0.131\n",
            "[ep  6700] loss=1.050369e-03 | mean(pred on target rows)≈10.044 | max|pred| on non-target rows≈0.099\n",
            "[ep  6750] loss=1.837556e-02 | mean(pred on target rows)≈10.145 | max|pred| on non-target rows≈0.534\n",
            "[ep  6800] loss=1.221838e-03 | mean(pred on target rows)≈10.044 | max|pred| on non-target rows≈0.135\n",
            "[ep  6850] loss=2.488239e-05 | mean(pred on target rows)≈9.997 | max|pred| on non-target rows≈0.049\n",
            "[ep  6900] loss=7.103172e-04 | mean(pred on target rows)≈10.023 | max|pred| on non-target rows≈0.118\n",
            "[ep  6950] loss=6.946260e-04 | mean(pred on target rows)≈10.018 | max|pred| on non-target rows≈0.119\n",
            "[ep  7000] loss=1.005953e-03 | mean(pred on target rows)≈10.043 | max|pred| on non-target rows≈0.119\n",
            "[ep  7050] loss=3.703846e-03 | mean(pred on target rows)≈10.083 | max|pred| on non-target rows≈0.171\n",
            "[ep  7100] loss=2.848974e-03 | mean(pred on target rows)≈9.943 | max|pred| on non-target rows≈0.074\n",
            "[ep  7150] loss=1.527817e-03 | mean(pred on target rows)≈10.047 | max|pred| on non-target rows≈0.112\n",
            "[ep  7200] loss=2.881819e-03 | mean(pred on target rows)≈10.040 | max|pred| on non-target rows≈0.192\n",
            "[ep  7250] loss=6.355557e-04 | mean(pred on target rows)≈10.028 | max|pred| on non-target rows≈0.119\n",
            "[ep  7300] loss=2.661532e-03 | mean(pred on target rows)≈10.029 | max|pred| on non-target rows≈0.201\n",
            "[ep  7350] loss=7.702826e-04 | mean(pred on target rows)≈10.000 | max|pred| on non-target rows≈0.086\n",
            "[ep  7400] loss=2.529717e-03 | mean(pred on target rows)≈9.994 | max|pred| on non-target rows≈0.193\n",
            "[ep  7450] loss=4.369486e-04 | mean(pred on target rows)≈10.021 | max|pred| on non-target rows≈0.097\n",
            "[ep  7500] loss=1.005281e-03 | mean(pred on target rows)≈9.994 | max|pred| on non-target rows≈0.120\n",
            "[ep  7550] loss=3.669210e-03 | mean(pred on target rows)≈9.926 | max|pred| on non-target rows≈0.051\n",
            "[ep  7600] loss=3.908965e-04 | mean(pred on target rows)≈9.992 | max|pred| on non-target rows≈0.056\n",
            "[ep  7650] loss=1.363410e-03 | mean(pred on target rows)≈9.971 | max|pred| on non-target rows≈0.073\n",
            "[ep  7700] loss=3.988840e-03 | mean(pred on target rows)≈10.032 | max|pred| on non-target rows≈0.244\n",
            "[ep  7750] loss=1.135158e-03 | mean(pred on target rows)≈10.042 | max|pred| on non-target rows≈0.140\n",
            "[ep  7800] loss=4.616607e-04 | mean(pred on target rows)≈10.018 | max|pred| on non-target rows≈0.114\n",
            "[ep  7850] loss=4.003546e-03 | mean(pred on target rows)≈9.965 | max|pred| on non-target rows≈0.161\n",
            "[ep  7900] loss=4.717053e-04 | mean(pred on target rows)≈10.023 | max|pred| on non-target rows≈0.085\n",
            "[ep  7950] loss=2.599010e-03 | mean(pred on target rows)≈10.005 | max|pred| on non-target rows≈0.153\n",
            "[ep  8000] loss=2.536368e-03 | mean(pred on target rows)≈10.051 | max|pred| on non-target rows≈0.166\n",
            "[ep  8050] loss=1.362372e-03 | mean(pred on target rows)≈10.023 | max|pred| on non-target rows≈0.195\n",
            "[ep  8100] loss=7.244086e-04 | mean(pred on target rows)≈9.992 | max|pred| on non-target rows≈0.068\n",
            "[ep  8150] loss=1.251585e-03 | mean(pred on target rows)≈9.971 | max|pred| on non-target rows≈0.075\n",
            "[ep  8200] loss=4.827048e-04 | mean(pred on target rows)≈10.019 | max|pred| on non-target rows≈0.096\n",
            "[ep  8250] loss=4.947979e-04 | mean(pred on target rows)≈10.015 | max|pred| on non-target rows≈0.103\n",
            "[ep  8300] loss=1.791782e-03 | mean(pred on target rows)≈9.942 | max|pred| on non-target rows≈0.021\n",
            "[ep  8350] loss=5.782811e-04 | mean(pred on target rows)≈10.020 | max|pred| on non-target rows≈0.106\n",
            "[ep  8400] loss=3.997923e-03 | mean(pred on target rows)≈9.939 | max|pred| on non-target rows≈0.070\n",
            "[ep  8450] loss=3.538260e-04 | mean(pred on target rows)≈10.023 | max|pred| on non-target rows≈0.072\n",
            "[ep  8500] loss=3.930313e-04 | mean(pred on target rows)≈9.978 | max|pred| on non-target rows≈0.039\n",
            "[ep  8550] loss=4.779944e-03 | mean(pred on target rows)≈10.042 | max|pred| on non-target rows≈0.232\n",
            "[ep  8600] loss=4.946347e-04 | mean(pred on target rows)≈10.028 | max|pred| on non-target rows≈0.078\n",
            "[ep  8650] loss=8.639721e-04 | mean(pred on target rows)≈10.017 | max|pred| on non-target rows≈0.155\n",
            "[ep  8700] loss=2.305228e-04 | mean(pred on target rows)≈10.014 | max|pred| on non-target rows≈0.079\n",
            "[ep  8750] loss=2.892380e-03 | mean(pred on target rows)≈10.023 | max|pred| on non-target rows≈0.194\n",
            "[ep  8800] loss=1.045958e-03 | mean(pred on target rows)≈10.042 | max|pred| on non-target rows≈0.113\n",
            "[ep  8850] loss=5.457575e-04 | mean(pred on target rows)≈10.022 | max|pred| on non-target rows≈0.136\n",
            "[ep  8900] loss=9.917640e-04 | mean(pred on target rows)≈9.991 | max|pred| on non-target rows≈0.084\n",
            "[ep  8950] loss=4.580455e-04 | mean(pred on target rows)≈10.007 | max|pred| on non-target rows≈0.076\n",
            "[ep  9000] loss=3.626632e-03 | mean(pred on target rows)≈9.963 | max|pred| on non-target rows≈0.169\n",
            "[ep  9050] loss=8.909444e-04 | mean(pred on target rows)≈9.962 | max|pred| on non-target rows≈0.040\n",
            "[ep  9100] loss=2.531089e-03 | mean(pred on target rows)≈9.970 | max|pred| on non-target rows≈0.125\n",
            "[ep  9150] loss=1.175633e-03 | mean(pred on target rows)≈9.971 | max|pred| on non-target rows≈0.082\n",
            "[ep  9200] loss=4.929597e-04 | mean(pred on target rows)≈10.014 | max|pred| on non-target rows≈0.113\n",
            "[ep  9250] loss=2.154963e-04 | mean(pred on target rows)≈10.003 | max|pred| on non-target rows≈0.060\n",
            "[ep  9300] loss=1.991092e-03 | mean(pred on target rows)≈9.960 | max|pred| on non-target rows≈0.073\n",
            "[ep  9350] loss=1.130621e-03 | mean(pred on target rows)≈10.042 | max|pred| on non-target rows≈0.126\n",
            "[ep  9400] loss=3.018698e-03 | mean(pred on target rows)≈10.014 | max|pred| on non-target rows≈0.179\n",
            "[ep  9450] loss=2.780968e-03 | mean(pred on target rows)≈9.927 | max|pred| on non-target rows≈0.010\n",
            "[ep  9500] loss=3.648282e-03 | mean(pred on target rows)≈9.980 | max|pred| on non-target rows≈0.167\n",
            "[ep  9550] loss=4.230640e-04 | mean(pred on target rows)≈10.013 | max|pred| on non-target rows≈0.093\n",
            "[ep  9600] loss=3.852628e-04 | mean(pred on target rows)≈10.027 | max|pred| on non-target rows≈0.066\n",
            "[ep  9650] loss=1.938261e-03 | mean(pred on target rows)≈10.061 | max|pred| on non-target rows≈0.116\n",
            "[ep  9700] loss=6.808475e-04 | mean(pred on target rows)≈9.965 | max|pred| on non-target rows≈0.028\n",
            "[ep  9750] loss=5.125069e-03 | mean(pred on target rows)≈9.996 | max|pred| on non-target rows≈0.197\n",
            "[ep  9800] loss=4.110621e-05 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.057\n",
            "[ep  9850] loss=1.996353e-03 | mean(pred on target rows)≈9.956 | max|pred| on non-target rows≈0.048\n",
            "[ep  9900] loss=1.628851e-03 | mean(pred on target rows)≈10.040 | max|pred| on non-target rows≈0.136\n",
            "[ep  9950] loss=5.102675e-03 | mean(pred on target rows)≈10.037 | max|pred| on non-target rows≈0.220\n",
            "[ep 10000] loss=2.214344e-04 | mean(pred on target rows)≈9.995 | max|pred| on non-target rows≈0.062\n",
            "[ep 10050] loss=1.817387e-03 | mean(pred on target rows)≈10.049 | max|pred| on non-target rows≈0.130\n",
            "[ep 10100] loss=4.577732e-04 | mean(pred on target rows)≈9.970 | max|pred| on non-target rows≈0.015\n",
            "[ep 10150] loss=4.989385e-04 | mean(pred on target rows)≈9.996 | max|pred| on non-target rows≈0.064\n",
            "[ep 10200] loss=1.330416e-03 | mean(pred on target rows)≈9.951 | max|pred| on non-target rows≈0.019\n",
            "[ep 10250] loss=2.367829e-03 | mean(pred on target rows)≈9.938 | max|pred| on non-target rows≈0.036\n",
            "[ep 10300] loss=2.533477e-04 | mean(pred on target rows)≈10.010 | max|pred| on non-target rows≈0.099\n",
            "[ep 10350] loss=3.049092e-04 | mean(pred on target rows)≈9.976 | max|pred| on non-target rows≈0.013\n",
            "[ep 10400] loss=4.934987e-04 | mean(pred on target rows)≈10.000 | max|pred| on non-target rows≈0.102\n",
            "[ep 10450] loss=7.369391e-03 | mean(pred on target rows)≈10.064 | max|pred| on non-target rows≈0.262\n",
            "[ep 10500] loss=4.506719e-04 | mean(pred on target rows)≈9.973 | max|pred| on non-target rows≈0.019\n",
            "[ep 10550] loss=4.495571e-03 | mean(pred on target rows)≈9.955 | max|pred| on non-target rows≈0.128\n",
            "[ep 10600] loss=1.512524e-03 | mean(pred on target rows)≈9.950 | max|pred| on non-target rows≈0.022\n",
            "[ep 10650] loss=9.955666e-04 | mean(pred on target rows)≈9.992 | max|pred| on non-target rows≈0.091\n",
            "[ep 10700] loss=2.054382e-03 | mean(pred on target rows)≈9.950 | max|pred| on non-target rows≈0.032\n",
            "[ep 10750] loss=1.010433e-03 | mean(pred on target rows)≈10.003 | max|pred| on non-target rows≈0.101\n",
            "[ep 10800] loss=2.202519e-03 | mean(pred on target rows)≈9.950 | max|pred| on non-target rows≈0.066\n",
            "[ep 10850] loss=1.683908e-03 | mean(pred on target rows)≈9.951 | max|pred| on non-target rows≈0.054\n",
            "[ep 10900] loss=2.680736e-03 | mean(pred on target rows)≈10.069 | max|pred| on non-target rows≈0.132\n",
            "[ep 10950] loss=6.161733e-04 | mean(pred on target rows)≈9.985 | max|pred| on non-target rows≈0.072\n",
            "[ep 11000] loss=8.571200e-04 | mean(pred on target rows)≈9.959 | max|pred| on non-target rows≈0.016\n",
            "[ep 11050] loss=3.594641e-03 | mean(pred on target rows)≈9.939 | max|pred| on non-target rows≈0.075\n",
            "[ep 11100] loss=1.524580e-03 | mean(pred on target rows)≈10.052 | max|pred| on non-target rows≈0.132\n",
            "[ep 11150] loss=4.888294e-04 | mean(pred on target rows)≈9.998 | max|pred| on non-target rows≈0.078\n",
            "[ep 11200] loss=1.963975e-03 | mean(pred on target rows)≈10.052 | max|pred| on non-target rows≈0.168\n",
            "[ep 11250] loss=1.306542e-03 | mean(pred on target rows)≈10.045 | max|pred| on non-target rows≈0.132\n",
            "[ep 11300] loss=1.243244e-03 | mean(pred on target rows)≈10.015 | max|pred| on non-target rows≈0.150\n",
            "[ep 11350] loss=1.704794e-03 | mean(pred on target rows)≈10.025 | max|pred| on non-target rows≈0.177\n",
            "[ep 11400] loss=6.416025e-04 | mean(pred on target rows)≈10.029 | max|pred| on non-target rows≈0.060\n",
            "[ep 11450] loss=5.727677e-04 | mean(pred on target rows)≈9.986 | max|pred| on non-target rows≈0.077\n",
            "[ep 11500] loss=2.437716e-03 | mean(pred on target rows)≈9.952 | max|pred| on non-target rows≈0.062\n",
            "[ep 11550] loss=4.946723e-04 | mean(pred on target rows)≈9.990 | max|pred| on non-target rows≈0.087\n",
            "[ep 11600] loss=2.770473e-03 | mean(pred on target rows)≈10.073 | max|pred| on non-target rows≈0.118\n",
            "[ep 11650] loss=2.408633e-03 | mean(pred on target rows)≈9.931 | max|pred| on non-target rows≈0.008\n",
            "[ep 11700] loss=7.479786e-04 | mean(pred on target rows)≈10.024 | max|pred| on non-target rows≈0.132\n",
            "[ep 11750] loss=1.344015e-03 | mean(pred on target rows)≈10.050 | max|pred| on non-target rows≈0.108\n",
            "[ep 11800] loss=3.161918e-03 | mean(pred on target rows)≈9.931 | max|pred| on non-target rows≈0.042\n",
            "[ep 11850] loss=4.527349e-04 | mean(pred on target rows)≈9.972 | max|pred| on non-target rows≈0.018\n",
            "[ep 11900] loss=5.232113e-03 | mean(pred on target rows)≈9.948 | max|pred| on non-target rows≈0.087\n",
            "[ep 11950] loss=1.742319e-03 | mean(pred on target rows)≈9.943 | max|pred| on non-target rows≈0.013\n",
            "[ep 12000] loss=3.955598e-03 | mean(pred on target rows)≈9.946 | max|pred| on non-target rows≈0.102\n",
            "[ep 12050] loss=1.938280e-04 | mean(pred on target rows)≈9.989 | max|pred| on non-target rows≈0.037\n",
            "[ep 12100] loss=1.054408e-02 | mean(pred on target rows)≈9.876 | max|pred| on non-target rows≈0.167\n",
            "[ep 12150] loss=1.018013e-02 | mean(pred on target rows)≈10.091 | max|pred| on non-target rows≈0.440\n",
            "[ep 12200] loss=4.953890e-05 | mean(pred on target rows)≈10.005 | max|pred| on non-target rows≈0.063\n",
            "[ep 12250] loss=7.976468e-06 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.035\n",
            "[ep 12300] loss=7.474862e-06 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.036\n",
            "[ep 12350] loss=1.978562e-04 | mean(pred on target rows)≈10.019 | max|pred| on non-target rows≈0.051\n",
            "[ep 12400] loss=2.291649e-05 | mean(pred on target rows)≈9.994 | max|pred| on non-target rows≈0.030\n",
            "[ep 12450] loss=7.919354e-04 | mean(pred on target rows)≈10.039 | max|pred| on non-target rows≈0.063\n",
            "[ep 12500] loss=1.643913e-03 | mean(pred on target rows)≈9.943 | max|pred| on non-target rows≈0.006\n",
            "[ep 12550] loss=3.298381e-04 | mean(pred on target rows)≈9.974 | max|pred| on non-target rows≈0.012\n",
            "[ep 12600] loss=1.070256e-04 | mean(pred on target rows)≈10.014 | max|pred| on non-target rows≈0.049\n",
            "[ep 12650] loss=1.729928e-03 | mean(pred on target rows)≈10.056 | max|pred| on non-target rows≈0.125\n",
            "[ep 12700] loss=1.560538e-03 | mean(pred on target rows)≈10.035 | max|pred| on non-target rows≈0.135\n",
            "[ep 12750] loss=1.546506e-03 | mean(pred on target rows)≈10.051 | max|pred| on non-target rows≈0.137\n",
            "[ep 12800] loss=3.106181e-04 | mean(pred on target rows)≈10.022 | max|pred| on non-target rows≈0.102\n",
            "[ep 12850] loss=1.968118e-03 | mean(pred on target rows)≈10.026 | max|pred| on non-target rows≈0.175\n",
            "[ep 12900] loss=1.635019e-02 | mean(pred on target rows)≈10.154 | max|pred| on non-target rows≈0.444\n",
            "[ep 12950] loss=7.552522e-03 | mean(pred on target rows)≈9.985 | max|pred| on non-target rows≈0.366\n",
            "[ep 13000] loss=1.578682e-02 | mean(pred on target rows)≈9.856 | max|pred| on non-target rows≈0.188\n",
            "[ep 13050] loss=5.965678e-03 | mean(pred on target rows)≈10.107 | max|pred| on non-target rows≈0.192\n",
            "[ep 13100] loss=8.529726e-03 | mean(pred on target rows)≈9.870 | max|pred| on non-target rows≈0.005\n",
            "[ep 13150] loss=5.111821e-03 | mean(pred on target rows)≈10.100 | max|pred| on non-target rows≈0.140\n",
            "[ep 13200] loss=6.839008e-03 | mean(pred on target rows)≈9.884 | max|pred| on non-target rows≈0.004\n",
            "[ep 13250] loss=4.002991e-03 | mean(pred on target rows)≈10.088 | max|pred| on non-target rows≈0.149\n",
            "[ep 13300] loss=5.895966e-03 | mean(pred on target rows)≈9.893 | max|pred| on non-target rows≈0.005\n",
            "[ep 13350] loss=2.053182e-05 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.052\n",
            "[ep 13400] loss=9.945027e-06 | mean(pred on target rows)≈10.000 | max|pred| on non-target rows≈0.039\n",
            "[ep 13450] loss=7.498790e-06 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.039\n",
            "[ep 13500] loss=2.090844e-04 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.066\n",
            "[ep 13550] loss=5.664236e-04 | mean(pred on target rows)≈10.001 | max|pred| on non-target rows≈0.096\n",
            "[ep 13600] loss=2.088141e-04 | mean(pred on target rows)≈10.000 | max|pred| on non-target rows≈0.087\n",
            "[ep 13650] loss=1.003266e-04 | mean(pred on target rows)≈10.003 | max|pred| on non-target rows≈0.057\n",
            "[ep 13700] loss=1.825175e-04 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.075\n",
            "[ep 13750] loss=8.555744e-05 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.059\n",
            "[ep 13800] loss=1.103836e-03 | mean(pred on target rows)≈10.008 | max|pred| on non-target rows≈0.111\n",
            "[ep 13850] loss=9.013142e-05 | mean(pred on target rows)≈9.988 | max|pred| on non-target rows≈0.030\n",
            "[ep 13900] loss=1.818385e-04 | mean(pred on target rows)≈9.994 | max|pred| on non-target rows≈0.075\n",
            "[ep 13950] loss=1.620761e-03 | mean(pred on target rows)≈10.042 | max|pred| on non-target rows≈0.160\n",
            "[ep 14000] loss=8.051446e-04 | mean(pred on target rows)≈10.035 | max|pred| on non-target rows≈0.111\n",
            "[ep 14050] loss=1.206046e-03 | mean(pred on target rows)≈10.038 | max|pred| on non-target rows≈0.115\n",
            "[ep 14100] loss=2.550408e-03 | mean(pred on target rows)≈9.931 | max|pred| on non-target rows≈0.008\n",
            "[ep 14150] loss=2.247763e-04 | mean(pred on target rows)≈10.004 | max|pred| on non-target rows≈0.081\n",
            "[ep 14200] loss=3.845170e-05 | mean(pred on target rows)≈10.004 | max|pred| on non-target rows≈0.042\n",
            "[ep 14250] loss=8.386715e-04 | mean(pred on target rows)≈10.004 | max|pred| on non-target rows≈0.105\n",
            "[ep 14300] loss=1.362089e-03 | mean(pred on target rows)≈10.042 | max|pred| on non-target rows≈0.113\n",
            "[ep 14350] loss=3.807815e-04 | mean(pred on target rows)≈10.018 | max|pred| on non-target rows≈0.094\n",
            "[ep 14400] loss=7.193548e-04 | mean(pred on target rows)≈10.020 | max|pred| on non-target rows≈0.143\n",
            "[ep 14450] loss=1.437803e-03 | mean(pred on target rows)≈10.041 | max|pred| on non-target rows≈0.141\n",
            "[ep 14500] loss=4.766163e-03 | mean(pred on target rows)≈9.941 | max|pred| on non-target rows≈0.149\n",
            "[ep 14550] loss=3.526384e-04 | mean(pred on target rows)≈10.024 | max|pred| on non-target rows≈0.098\n",
            "[ep 14600] loss=2.252926e-03 | mean(pred on target rows)≈9.973 | max|pred| on non-target rows≈0.106\n",
            "[ep 14650] loss=1.118754e-04 | mean(pred on target rows)≈10.006 | max|pred| on non-target rows≈0.037\n",
            "[ep 14700] loss=5.589056e-03 | mean(pred on target rows)≈9.956 | max|pred| on non-target rows≈0.138\n",
            "[ep 14750] loss=1.893792e-04 | mean(pred on target rows)≈9.986 | max|pred| on non-target rows≈0.039\n",
            "[ep 14800] loss=1.018048e-03 | mean(pred on target rows)≈10.021 | max|pred| on non-target rows≈0.124\n",
            "[ep 14850] loss=2.596526e-03 | mean(pred on target rows)≈9.942 | max|pred| on non-target rows≈0.032\n",
            "[ep 14900] loss=1.502677e-03 | mean(pred on target rows)≈9.957 | max|pred| on non-target rows≈0.021\n",
            "[ep 14950] loss=2.520571e-03 | mean(pred on target rows)≈9.934 | max|pred| on non-target rows≈0.014\n",
            "[ep 15000] loss=3.406514e-03 | mean(pred on target rows)≈10.072 | max|pred| on non-target rows≈0.194\n",
            "[ep 15050] loss=1.244978e-04 | mean(pred on target rows)≈10.012 | max|pred| on non-target rows≈0.021\n",
            "[ep 15100] loss=4.828594e-03 | mean(pred on target rows)≈10.057 | max|pred| on non-target rows≈0.251\n",
            "[ep 15150] loss=3.116530e-03 | mean(pred on target rows)≈9.937 | max|pred| on non-target rows≈0.024\n",
            "[ep 15200] loss=2.813141e-05 | mean(pred on target rows)≈10.005 | max|pred| on non-target rows≈0.046\n",
            "[ep 15250] loss=6.661519e-04 | mean(pred on target rows)≈10.035 | max|pred| on non-target rows≈0.087\n",
            "[ep 15300] loss=1.819174e-03 | mean(pred on target rows)≈9.985 | max|pred| on non-target rows≈0.121\n",
            "[ep 15350] loss=2.975954e-03 | mean(pred on target rows)≈9.943 | max|pred| on non-target rows≈0.051\n",
            "[ep 15400] loss=1.086765e-03 | mean(pred on target rows)≈9.967 | max|pred| on non-target rows≈0.045\n",
            "[ep 15450] loss=2.116445e-02 | mean(pred on target rows)≈9.943 | max|pred| on non-target rows≈0.290\n",
            "[ep 15500] loss=1.769567e-03 | mean(pred on target rows)≈10.028 | max|pred| on non-target rows≈0.138\n",
            "[ep 15550] loss=1.591513e-05 | mean(pred on target rows)≈10.003 | max|pred| on non-target rows≈0.043\n",
            "[ep 15600] loss=1.064673e-03 | mean(pred on target rows)≈9.954 | max|pred| on non-target rows≈0.005\n",
            "[ep 15650] loss=3.827265e-04 | mean(pred on target rows)≈10.026 | max|pred| on non-target rows≈0.084\n",
            "[ep 15700] loss=1.504530e-03 | mean(pred on target rows)≈9.946 | max|pred| on non-target rows≈0.004\n",
            "[ep 15750] loss=2.167047e-04 | mean(pred on target rows)≈10.013 | max|pred| on non-target rows≈0.052\n",
            "[ep 15800] loss=2.514026e-03 | mean(pred on target rows)≈9.931 | max|pred| on non-target rows≈0.003\n",
            "[ep 15850] loss=1.496678e-03 | mean(pred on target rows)≈9.947 | max|pred| on non-target rows≈0.006\n",
            "[ep 15900] loss=2.391486e-03 | mean(pred on target rows)≈9.980 | max|pred| on non-target rows≈0.128\n",
            "[ep 15950] loss=4.743289e-04 | mean(pred on target rows)≈10.029 | max|pred| on non-target rows≈0.079\n",
            "[ep 16000] loss=1.821404e-03 | mean(pred on target rows)≈9.944 | max|pred| on non-target rows≈0.011\n",
            "[ep 16050] loss=3.578731e-04 | mean(pred on target rows)≈10.008 | max|pred| on non-target rows≈0.101\n",
            "[ep 16100] loss=3.962349e-04 | mean(pred on target rows)≈10.025 | max|pred| on non-target rows≈0.080\n",
            "[ep 16150] loss=9.457410e-04 | mean(pred on target rows)≈9.993 | max|pred| on non-target rows≈0.112\n",
            "[ep 16200] loss=3.651709e-03 | mean(pred on target rows)≈10.051 | max|pred| on non-target rows≈0.226\n",
            "[ep 16250] loss=7.423825e-04 | mean(pred on target rows)≈9.976 | max|pred| on non-target rows≈0.032\n",
            "[ep 16300] loss=2.575681e-03 | mean(pred on target rows)≈9.942 | max|pred| on non-target rows≈0.029\n",
            "[ep 16350] loss=7.363426e-04 | mean(pred on target rows)≈9.971 | max|pred| on non-target rows≈0.058\n",
            "[ep 16400] loss=6.378130e-04 | mean(pred on target rows)≈10.023 | max|pred| on non-target rows≈0.115\n",
            "[ep 16450] loss=9.788034e-04 | mean(pred on target rows)≈10.008 | max|pred| on non-target rows≈0.143\n",
            "[ep 16500] loss=9.341369e-04 | mean(pred on target rows)≈10.035 | max|pred| on non-target rows≈0.102\n",
            "[ep 16550] loss=7.156179e-04 | mean(pred on target rows)≈10.028 | max|pred| on non-target rows≈0.084\n",
            "[ep 16600] loss=5.046889e-03 | mean(pred on target rows)≈10.061 | max|pred| on non-target rows≈0.213\n",
            "[ep 16650] loss=1.606077e-03 | mean(pred on target rows)≈9.946 | max|pred| on non-target rows≈0.004\n",
            "[ep 16700] loss=1.931929e-03 | mean(pred on target rows)≈9.952 | max|pred| on non-target rows≈0.047\n",
            "[ep 16750] loss=1.101148e-03 | mean(pred on target rows)≈10.044 | max|pred| on non-target rows≈0.101\n",
            "[ep 16800] loss=3.214432e-04 | mean(pred on target rows)≈10.006 | max|pred| on non-target rows≈0.077\n",
            "[ep 16850] loss=6.323776e-04 | mean(pred on target rows)≈10.030 | max|pred| on non-target rows≈0.085\n",
            "[ep 16900] loss=3.011563e-03 | mean(pred on target rows)≈10.005 | max|pred| on non-target rows≈0.179\n",
            "[ep 16950] loss=2.130501e-04 | mean(pred on target rows)≈10.003 | max|pred| on non-target rows≈0.085\n",
            "[ep 17000] loss=3.349819e-05 | mean(pred on target rows)≈10.007 | max|pred| on non-target rows≈0.019\n",
            "[ep 17050] loss=1.331704e-03 | mean(pred on target rows)≈9.949 | max|pred| on non-target rows≈0.003\n",
            "[ep 17100] loss=2.625103e-04 | mean(pred on target rows)≈10.006 | max|pred| on non-target rows≈0.094\n",
            "[ep 17150] loss=1.220756e-03 | mean(pred on target rows)≈9.952 | max|pred| on non-target rows≈0.008\n",
            "[ep 17200] loss=7.419596e-03 | mean(pred on target rows)≈9.992 | max|pred| on non-target rows≈0.249\n",
            "[ep 17250] loss=1.226351e-04 | mean(pred on target rows)≈9.990 | max|pred| on non-target rows≈0.045\n",
            "[ep 17300] loss=1.475623e-04 | mean(pred on target rows)≈10.017 | max|pred| on non-target rows≈0.066\n",
            "[ep 17350] loss=1.821429e-05 | mean(pred on target rows)≈9.995 | max|pred| on non-target rows≈0.045\n",
            "[ep 17400] loss=4.482604e-03 | mean(pred on target rows)≈9.954 | max|pred| on non-target rows≈0.142\n",
            "[ep 17450] loss=1.510157e-04 | mean(pred on target rows)≈9.989 | max|pred| on non-target rows≈0.046\n",
            "[ep 17500] loss=1.718379e-03 | mean(pred on target rows)≈9.982 | max|pred| on non-target rows≈0.083\n",
            "[ep 17550] loss=2.433083e-04 | mean(pred on target rows)≈9.982 | max|pred| on non-target rows≈0.009\n",
            "[ep 17600] loss=8.375993e-04 | mean(pred on target rows)≈10.010 | max|pred| on non-target rows≈0.125\n",
            "[ep 17650] loss=3.578842e-04 | mean(pred on target rows)≈10.023 | max|pred| on non-target rows≈0.094\n",
            "[ep 17700] loss=9.073364e-04 | mean(pred on target rows)≈9.984 | max|pred| on non-target rows≈0.092\n",
            "[ep 17750] loss=4.608989e-04 | mean(pred on target rows)≈10.017 | max|pred| on non-target rows≈0.140\n",
            "[ep 17800] loss=4.553969e-05 | mean(pred on target rows)≈10.001 | max|pred| on non-target rows≈0.060\n",
            "[ep 17850] loss=2.171835e-03 | mean(pred on target rows)≈9.949 | max|pred| on non-target rows≈0.036\n",
            "[ep 17900] loss=2.621697e-04 | mean(pred on target rows)≈9.993 | max|pred| on non-target rows≈0.073\n",
            "[ep 17950] loss=1.620503e-03 | mean(pred on target rows)≈10.005 | max|pred| on non-target rows≈0.119\n",
            "[ep 18000] loss=1.866940e-03 | mean(pred on target rows)≈9.957 | max|pred| on non-target rows≈0.043\n",
            "[ep 18050] loss=2.046422e-04 | mean(pred on target rows)≈9.995 | max|pred| on non-target rows≈0.031\n",
            "[ep 18100] loss=1.930862e-03 | mean(pred on target rows)≈9.960 | max|pred| on non-target rows≈0.045\n",
            "[ep 18150] loss=1.534485e-04 | mean(pred on target rows)≈9.992 | max|pred| on non-target rows≈0.020\n",
            "[ep 18200] loss=1.401550e-03 | mean(pred on target rows)≈10.038 | max|pred| on non-target rows≈0.145\n",
            "[ep 18250] loss=4.384035e-04 | mean(pred on target rows)≈9.984 | max|pred| on non-target rows≈0.052\n",
            "[ep 18300] loss=1.600426e-03 | mean(pred on target rows)≈9.957 | max|pred| on non-target rows≈0.053\n",
            "[ep 18350] loss=3.453603e-03 | mean(pred on target rows)≈9.946 | max|pred| on non-target rows≈0.092\n",
            "[ep 18400] loss=1.228048e-03 | mean(pred on target rows)≈9.953 | max|pred| on non-target rows≈0.003\n",
            "[ep 18450] loss=1.406597e-03 | mean(pred on target rows)≈10.011 | max|pred| on non-target rows≈0.169\n",
            "[ep 18500] loss=1.227102e-03 | mean(pred on target rows)≈10.003 | max|pred| on non-target rows≈0.137\n",
            "[ep 18550] loss=9.458062e-04 | mean(pred on target rows)≈10.042 | max|pred| on non-target rows≈0.093\n",
            "[ep 18600] loss=7.503490e-04 | mean(pred on target rows)≈9.965 | max|pred| on non-target rows≈0.008\n",
            "[ep 18650] loss=1.317413e-03 | mean(pred on target rows)≈9.960 | max|pred| on non-target rows≈0.071\n",
            "[ep 18700] loss=1.900656e-03 | mean(pred on target rows)≈10.037 | max|pred| on non-target rows≈0.136\n",
            "[ep 18750] loss=3.618847e-04 | mean(pred on target rows)≈9.988 | max|pred| on non-target rows≈0.058\n",
            "[ep 18800] loss=3.240655e-03 | mean(pred on target rows)≈10.036 | max|pred| on non-target rows≈0.265\n",
            "[ep 18850] loss=2.714833e-03 | mean(pred on target rows)≈9.964 | max|pred| on non-target rows≈0.139\n",
            "[ep 18900] loss=3.905531e-03 | mean(pred on target rows)≈10.009 | max|pred| on non-target rows≈0.271\n",
            "[ep 18950] loss=2.908996e-05 | mean(pred on target rows)≈10.004 | max|pred| on non-target rows≈0.053\n",
            "[ep 19000] loss=4.976907e-06 | mean(pred on target rows)≈10.000 | max|pred| on non-target rows≈0.035\n",
            "[ep 19050] loss=4.636754e-06 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.034\n",
            "[ep 19100] loss=6.901329e-05 | mean(pred on target rows)≈10.011 | max|pred| on non-target rows≈0.051\n",
            "[ep 19150] loss=4.433846e-04 | mean(pred on target rows)≈9.970 | max|pred| on non-target rows≈0.016\n",
            "[ep 19200] loss=9.070806e-04 | mean(pred on target rows)≈9.958 | max|pred| on non-target rows≈0.012\n",
            "[ep 19250] loss=7.107253e-05 | mean(pred on target rows)≈10.012 | max|pred| on non-target rows≈0.048\n",
            "[ep 19300] loss=6.353689e-04 | mean(pred on target rows)≈9.965 | max|pred| on non-target rows≈0.001\n",
            "[ep 19350] loss=2.163301e-04 | mean(pred on target rows)≈10.020 | max|pred| on non-target rows≈0.051\n",
            "[ep 19400] loss=2.081632e-04 | mean(pred on target rows)≈9.980 | max|pred| on non-target rows≈0.012\n",
            "[ep 19450] loss=1.112546e-03 | mean(pred on target rows)≈10.045 | max|pred| on non-target rows≈0.090\n",
            "[ep 19500] loss=3.902923e-04 | mean(pred on target rows)≈9.976 | max|pred| on non-target rows≈0.011\n",
            "[ep 19550] loss=2.751301e-03 | mean(pred on target rows)≈9.937 | max|pred| on non-target rows≈0.014\n",
            "[ep 19600] loss=6.654839e-04 | mean(pred on target rows)≈9.997 | max|pred| on non-target rows≈0.071\n",
            "[ep 19650] loss=1.599242e-03 | mean(pred on target rows)≈10.019 | max|pred| on non-target rows≈0.181\n",
            "[ep 19700] loss=5.429559e-04 | mean(pred on target rows)≈10.025 | max|pred| on non-target rows≈0.111\n",
            "[ep 19750] loss=1.142919e-03 | mean(pred on target rows)≈9.971 | max|pred| on non-target rows≈0.076\n",
            "[ep 19800] loss=1.166201e-03 | mean(pred on target rows)≈10.040 | max|pred| on non-target rows≈0.075\n",
            "[ep 19850] loss=2.070007e-03 | mean(pred on target rows)≈9.948 | max|pred| on non-target rows≈0.032\n",
            "[ep 19900] loss=1.263923e-03 | mean(pred on target rows)≈10.031 | max|pred| on non-target rows≈0.111\n",
            "[ep 19950] loss=6.525596e-04 | mean(pred on target rows)≈9.972 | max|pred| on non-target rows≈0.019\n",
            "[ep 20000] loss=2.329668e-03 | mean(pred on target rows)≈10.055 | max|pred| on non-target rows≈0.176\n",
            "[ep 20050] loss=8.387026e-04 | mean(pred on target rows)≈10.037 | max|pred| on non-target rows≈0.070\n",
            "[ep 20100] loss=1.967130e-03 | mean(pred on target rows)≈9.945 | max|pred| on non-target rows≈0.005\n",
            "[ep 20150] loss=1.116287e-03 | mean(pred on target rows)≈9.962 | max|pred| on non-target rows≈0.031\n",
            "[ep 20200] loss=1.580178e-03 | mean(pred on target rows)≈9.973 | max|pred| on non-target rows≈0.080\n",
            "[ep 20250] loss=2.176713e-04 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.075\n",
            "[ep 20300] loss=5.734975e-03 | mean(pred on target rows)≈10.044 | max|pred| on non-target rows≈0.272\n",
            "[ep 20350] loss=1.810233e-02 | mean(pred on target rows)≈10.164 | max|pred| on non-target rows≈0.540\n",
            "[ep 20400] loss=8.409982e-03 | mean(pred on target rows)≈9.871 | max|pred| on non-target rows≈0.003\n",
            "[ep 20450] loss=3.645502e-03 | mean(pred on target rows)≈10.084 | max|pred| on non-target rows≈0.144\n",
            "[ep 20500] loss=6.396419e-03 | mean(pred on target rows)≈9.893 | max|pred| on non-target rows≈0.014\n",
            "[ep 20550] loss=6.781522e-03 | mean(pred on target rows)≈10.083 | max|pred| on non-target rows≈0.200\n",
            "[ep 20600] loss=1.660190e-03 | mean(pred on target rows)≈10.012 | max|pred| on non-target rows≈0.143\n",
            "[ep 20650] loss=1.493432e-05 | mean(pred on target rows)≈10.001 | max|pred| on non-target rows≈0.042\n",
            "[ep 20700] loss=4.826613e-06 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.034\n",
            "[ep 20750] loss=4.744366e-06 | mean(pred on target rows)≈10.000 | max|pred| on non-target rows≈0.035\n",
            "[ep 20800] loss=4.689842e-06 | mean(pred on target rows)≈10.000 | max|pred| on non-target rows≈0.034\n",
            "[ep 20850] loss=1.322866e-05 | mean(pred on target rows)≈9.996 | max|pred| on non-target rows≈0.028\n",
            "[ep 20900] loss=4.835349e-04 | mean(pred on target rows)≈9.989 | max|pred| on non-target rows≈0.063\n",
            "[ep 20950] loss=3.044574e-04 | mean(pred on target rows)≈9.989 | max|pred| on non-target rows≈0.047\n",
            "[ep 21000] loss=3.046404e-04 | mean(pred on target rows)≈10.022 | max|pred| on non-target rows≈0.083\n",
            "[ep 21050] loss=5.414896e-04 | mean(pred on target rows)≈10.017 | max|pred| on non-target rows≈0.077\n",
            "[ep 21100] loss=9.115183e-04 | mean(pred on target rows)≈9.960 | max|pred| on non-target rows≈0.011\n",
            "[ep 21150] loss=6.172438e-04 | mean(pred on target rows)≈9.978 | max|pred| on non-target rows≈0.044\n",
            "[ep 21200] loss=1.706645e-04 | mean(pred on target rows)≈9.988 | max|pred| on non-target rows≈0.012\n",
            "[ep 21250] loss=1.110141e-03 | mean(pred on target rows)≈10.037 | max|pred| on non-target rows≈0.110\n",
            "[ep 21300] loss=4.750895e-04 | mean(pred on target rows)≈9.973 | max|pred| on non-target rows≈0.014\n",
            "[ep 21350] loss=1.369256e-03 | mean(pred on target rows)≈9.970 | max|pred| on non-target rows≈0.106\n",
            "[ep 21400] loss=9.194948e-04 | mean(pred on target rows)≈10.038 | max|pred| on non-target rows≈0.097\n",
            "[ep 21450] loss=5.392549e-03 | mean(pred on target rows)≈10.038 | max|pred| on non-target rows≈0.228\n",
            "[ep 21500] loss=1.464885e-04 | mean(pred on target rows)≈9.986 | max|pred| on non-target rows≈0.020\n",
            "[ep 21550] loss=2.694071e-03 | mean(pred on target rows)≈10.059 | max|pred| on non-target rows≈0.165\n",
            "[ep 21600] loss=4.094487e-04 | mean(pred on target rows)≈10.020 | max|pred| on non-target rows≈0.066\n",
            "[ep 21650] loss=4.976120e-04 | mean(pred on target rows)≈10.007 | max|pred| on non-target rows≈0.082\n",
            "[ep 21700] loss=1.823509e-03 | mean(pred on target rows)≈10.056 | max|pred| on non-target rows≈0.125\n",
            "[ep 21750] loss=8.587549e-04 | mean(pred on target rows)≈9.993 | max|pred| on non-target rows≈0.117\n",
            "[ep 21800] loss=1.251284e-03 | mean(pred on target rows)≈9.979 | max|pred| on non-target rows≈0.049\n",
            "[ep 21850] loss=2.016495e-03 | mean(pred on target rows)≈10.028 | max|pred| on non-target rows≈0.167\n",
            "[ep 21900] loss=1.346852e-03 | mean(pred on target rows)≈9.951 | max|pred| on non-target rows≈0.003\n",
            "[ep 21950] loss=8.721384e-04 | mean(pred on target rows)≈9.994 | max|pred| on non-target rows≈0.107\n",
            "[ep 22000] loss=3.461750e-05 | mean(pred on target rows)≈9.998 | max|pred| on non-target rows≈0.041\n",
            "[ep 22050] loss=7.046587e-05 | mean(pred on target rows)≈10.005 | max|pred| on non-target rows≈0.021\n",
            "[ep 22100] loss=1.568459e-03 | mean(pred on target rows)≈10.041 | max|pred| on non-target rows≈0.139\n",
            "[ep 22150] loss=1.617975e-03 | mean(pred on target rows)≈9.965 | max|pred| on non-target rows≈0.060\n",
            "[ep 22200] loss=4.605673e-04 | mean(pred on target rows)≈10.028 | max|pred| on non-target rows≈0.088\n",
            "[ep 22250] loss=1.230846e-03 | mean(pred on target rows)≈10.042 | max|pred| on non-target rows≈0.144\n",
            "[ep 22300] loss=4.520513e-02 | mean(pred on target rows)≈10.058 | max|pred| on non-target rows≈1.125\n",
            "[ep 22350] loss=2.435930e-04 | mean(pred on target rows)≈9.994 | max|pred| on non-target rows≈0.112\n",
            "[ep 22400] loss=5.154865e-06 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.036\n",
            "[ep 22450] loss=7.921898e-06 | mean(pred on target rows)≈10.002 | max|pred| on non-target rows≈0.037\n",
            "[ep 22500] loss=1.819341e-05 | mean(pred on target rows)≈10.005 | max|pred| on non-target rows≈0.040\n",
            "[ep 22550] loss=1.792061e-03 | mean(pred on target rows)≈10.059 | max|pred| on non-target rows≈0.096\n",
            "[ep 22600] loss=2.576523e-04 | mean(pred on target rows)≈10.022 | max|pred| on non-target rows≈0.077\n",
            "[ep 22650] loss=6.960236e-04 | mean(pred on target rows)≈9.963 | max|pred| on non-target rows≈0.002\n",
            "[ep 22700] loss=5.634968e-04 | mean(pred on target rows)≈10.033 | max|pred| on non-target rows≈0.076\n",
            "[ep 22750] loss=6.944325e-04 | mean(pred on target rows)≈10.036 | max|pred| on non-target rows≈0.093\n",
            "[ep 22800] loss=7.097328e-04 | mean(pred on target rows)≈10.035 | max|pred| on non-target rows≈0.109\n",
            "[ep 22850] loss=1.615875e-04 | mean(pred on target rows)≈10.015 | max|pred| on non-target rows≈0.025\n",
            "[ep 22900] loss=1.709529e-03 | mean(pred on target rows)≈9.975 | max|pred| on non-target rows≈0.089\n",
            "[ep 22950] loss=4.920618e-04 | mean(pred on target rows)≈9.982 | max|pred| on non-target rows≈0.077\n",
            "[ep 23000] loss=3.214074e-04 | mean(pred on target rows)≈10.020 | max|pred| on non-target rows≈0.110\n",
            "[ep 23050] loss=4.302189e-04 | mean(pred on target rows)≈10.019 | max|pred| on non-target rows≈0.121\n",
            "[ep 23100] loss=1.077983e-04 | mean(pred on target rows)≈9.986 | max|pred| on non-target rows≈0.001\n",
            "[ep 23150] loss=2.573793e-03 | mean(pred on target rows)≈9.939 | max|pred| on non-target rows≈0.022\n",
            "[ep 23200] loss=1.250475e-04 | mean(pred on target rows)≈10.002 | max|pred| on non-target rows≈0.068\n",
            "[ep 23250] loss=1.744786e-03 | mean(pred on target rows)≈9.951 | max|pred| on non-target rows≈0.024\n",
            "[ep 23300] loss=2.326344e-03 | mean(pred on target rows)≈10.055 | max|pred| on non-target rows≈0.182\n",
            "[ep 23350] loss=1.161800e-03 | mean(pred on target rows)≈10.038 | max|pred| on non-target rows≈0.139\n",
            "[ep 23400] loss=4.356246e-03 | mean(pred on target rows)≈9.941 | max|pred| on non-target rows≈0.097\n",
            "[ep 23450] loss=2.284875e-03 | mean(pred on target rows)≈10.062 | max|pred| on non-target rows≈0.132\n",
            "[ep 23500] loss=1.185150e-03 | mean(pred on target rows)≈9.974 | max|pred| on non-target rows≈0.071\n",
            "[ep 23550] loss=4.675828e-04 | mean(pred on target rows)≈9.997 | max|pred| on non-target rows≈0.062\n",
            "[ep 23600] loss=2.102619e-04 | mean(pred on target rows)≈9.993 | max|pred| on non-target rows≈0.068\n",
            "[ep 23650] loss=7.493074e-04 | mean(pred on target rows)≈9.983 | max|pred| on non-target rows≈0.098\n",
            "[ep 23700] loss=1.007676e-03 | mean(pred on target rows)≈10.040 | max|pred| on non-target rows≈0.127\n",
            "[ep 23750] loss=5.928011e-04 | mean(pred on target rows)≈10.024 | max|pred| on non-target rows≈0.128\n",
            "[ep 23800] loss=1.510162e-03 | mean(pred on target rows)≈9.963 | max|pred| on non-target rows≈0.140\n",
            "[ep 23850] loss=2.785140e-03 | mean(pred on target rows)≈9.929 | max|pred| on non-target rows≈0.002\n",
            "[ep 23900] loss=6.145143e-03 | mean(pred on target rows)≈10.086 | max|pred| on non-target rows≈0.210\n",
            "[ep 23950] loss=2.246421e-02 | mean(pred on target rows)≈9.942 | max|pred| on non-target rows≈0.551\n",
            "[ep 24000] loss=7.205443e-05 | mean(pred on target rows)≈9.998 | max|pred| on non-target rows≈0.080\n",
            "[ep 24050] loss=6.025058e-06 | mean(pred on target rows)≈9.998 | max|pred| on non-target rows≈0.031\n",
            "[ep 24100] loss=3.747220e-06 | mean(pred on target rows)≈9.999 | max|pred| on non-target rows≈0.031\n",
            "[ep 24150] loss=1.321873e-04 | mean(pred on target rows)≈9.984 | max|pred| on non-target rows≈0.016\n",
            "[ep 24200] loss=2.720715e-04 | mean(pred on target rows)≈10.023 | max|pred| on non-target rows≈0.057\n",
            "[ep 24250] loss=8.072616e-05 | mean(pred on target rows)≈10.012 | max|pred| on non-target rows≈0.050\n",
            "[ep 24300] loss=2.706824e-05 | mean(pred on target rows)≈9.993 | max|pred| on non-target rows≈0.042\n",
            "[ep 24350] loss=3.274045e-04 | mean(pred on target rows)≈9.975 | max|pred| on non-target rows≈0.019\n",
            "[ep 24400] loss=9.609141e-04 | mean(pred on target rows)≈9.956 | max|pred| on non-target rows≈0.008\n",
            "[ep 24450] loss=1.675029e-04 | mean(pred on target rows)≈10.018 | max|pred| on non-target rows≈0.060\n",
            "[ep 24500] loss=1.089256e-03 | mean(pred on target rows)≈9.954 | max|pred| on non-target rows≈0.000\n",
            "[ep 24550] loss=1.111934e-04 | mean(pred on target rows)≈9.986 | max|pred| on non-target rows≈0.026\n",
            "[ep 24600] loss=1.569066e-03 | mean(pred on target rows)≈10.046 | max|pred| on non-target rows≈0.131\n",
            "[ep 24650] loss=2.908899e-04 | mean(pred on target rows)≈9.980 | max|pred| on non-target rows≈0.010\n",
            "[ep 24700] loss=1.804823e-03 | mean(pred on target rows)≈9.953 | max|pred| on non-target rows≈0.020\n",
            "[ep 24750] loss=2.668332e-04 | mean(pred on target rows)≈10.019 | max|pred| on non-target rows≈0.053\n",
            "[ep 24800] loss=3.357695e-03 | mean(pred on target rows)≈9.947 | max|pred| on non-target rows≈0.058\n",
            "[ep 24850] loss=1.661449e-03 | mean(pred on target rows)≈10.056 | max|pred| on non-target rows≈0.097\n",
            "[ep 24900] loss=2.013518e-03 | mean(pred on target rows)≈10.039 | max|pred| on non-target rows≈0.147\n",
            "[ep 24950] loss=4.745255e-04 | mean(pred on target rows)≈9.997 | max|pred| on non-target rows≈0.044\n",
            "[ep 25000] loss=4.196181e-04 | mean(pred on target rows)≈9.982 | max|pred| on non-target rows≈0.050\n",
            "[ep 25050] loss=1.201132e-03 | mean(pred on target rows)≈9.971 | max|pred| on non-target rows≈0.079\n",
            "[ep 25100] loss=5.085549e-04 | mean(pred on target rows)≈10.007 | max|pred| on non-target rows≈0.110\n",
            "[ep 25150] loss=6.811261e-04 | mean(pred on target rows)≈9.998 | max|pred| on non-target rows≈0.110\n",
            "[ep 25200] loss=4.720661e-04 | mean(pred on target rows)≈10.019 | max|pred| on non-target rows≈0.110\n",
            "[ep 25250] loss=6.071999e-04 | mean(pred on target rows)≈10.026 | max|pred| on non-target rows≈0.097\n",
            "[ep 25300] loss=1.314392e-03 | mean(pred on target rows)≈10.038 | max|pred| on non-target rows≈0.105\n",
            "[ep 25350] loss=4.112842e-04 | mean(pred on target rows)≈10.021 | max|pred| on non-target rows≈0.087\n",
            "[ep 25400] loss=6.220703e-04 | mean(pred on target rows)≈10.010 | max|pred| on non-target rows≈0.107\n",
            "[ep 25450] loss=8.775141e-04 | mean(pred on target rows)≈9.967 | max|pred| on non-target rows≈0.038\n",
            "[ep 25500] loss=7.583856e-04 | mean(pred on target rows)≈10.017 | max|pred| on non-target rows≈0.084\n",
            "[ep 25550] loss=2.656616e-04 | mean(pred on target rows)≈9.993 | max|pred| on non-target rows≈0.058\n",
            "[ep 25600] loss=1.942855e-03 | mean(pred on target rows)≈9.975 | max|pred| on non-target rows≈0.096\n",
            "[ep 25650] loss=2.122502e-04 | mean(pred on target rows)≈10.011 | max|pred| on non-target rows≈0.081\n",
            "[ep 25700] loss=2.109015e-03 | mean(pred on target rows)≈9.936 | max|pred| on non-target rows≈0.000\n"
          ]
        },
        {
          "ename": "KeyboardInterrupt",
          "evalue": "",
          "output_type": "error",
          "traceback": [
            "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
            "\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
            "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 149\u001b[39m\n\u001b[32m    146\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m ep % \u001b[32m5\u001b[39m == \u001b[32m0\u001b[39m:\n\u001b[32m    147\u001b[39m         \u001b[38;5;66;03m#if checkpoint_dir:\u001b[39;00m\n\u001b[32m    148\u001b[39m         ckpt_path = os.path.join(checkpoint_dir, \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mckpt_epoch_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mep\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.pt\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m149\u001b[39m         \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mepoch\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mep\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mmodel_state_dict\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    151\u001b[39m \u001b[38;5;66;03m# ---------------------------\u001b[39;00m\n\u001b[32m    152\u001b[39m \u001b[38;5;66;03m# Final evaluation (anchor check over all 2^10; average out noise tail)\u001b[39;00m\n\u001b[32m    153\u001b[39m \u001b[38;5;66;03m# ---------------------------\u001b[39;00m\n\u001b[32m    154\u001b[39m model.eval()\n",
            "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/torch-gpu-env/lib/python3.13/site-packages/torch/serialization.py:964\u001b[39m, in \u001b[36msave\u001b[39m\u001b[34m(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization, _disable_byteorder_record)\u001b[39m\n\u001b[32m    961\u001b[39m     f = os.fspath(f)\n\u001b[32m    963\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m _use_new_zipfile_serialization:\n\u001b[32m--> \u001b[39m\u001b[32m964\u001b[39m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_zipfile_writer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_zipfile:\n\u001b[32m    965\u001b[39m         _save(\n\u001b[32m    966\u001b[39m             obj,\n\u001b[32m    967\u001b[39m             opened_zipfile,\n\u001b[32m   (...)\u001b[39m\u001b[32m    970\u001b[39m             _disable_byteorder_record,\n\u001b[32m    971\u001b[39m         )\n\u001b[32m    972\u001b[39m         \u001b[38;5;28;01mreturn\u001b[39;00m\n",
            "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/torch-gpu-env/lib/python3.13/site-packages/torch/serialization.py:828\u001b[39m, in \u001b[36m_open_zipfile_writer\u001b[39m\u001b[34m(name_or_buffer)\u001b[39m\n\u001b[32m    826\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m    827\u001b[39m     container = _open_zipfile_writer_buffer\n\u001b[32m--> \u001b[39m\u001b[32m828\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcontainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m)\u001b[49m\n",
            "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/torch-gpu-env/lib/python3.13/site-packages/torch/serialization.py:792\u001b[39m, in \u001b[36m_open_zipfile_writer_file.__init__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m    785\u001b[39m     \u001b[38;5;28msuper\u001b[39m().\u001b[34m__init__\u001b[39m(\n\u001b[32m    786\u001b[39m         torch._C.PyTorchFileWriter(\n\u001b[32m    787\u001b[39m             \u001b[38;5;28mself\u001b[39m.file_stream, get_crc32_options(), _get_storage_alignment()\n\u001b[32m    788\u001b[39m         )\n\u001b[32m    789\u001b[39m     )\n\u001b[32m    790\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m    791\u001b[39m     \u001b[38;5;28msuper\u001b[39m().\u001b[34m__init__\u001b[39m(\n\u001b[32m--> \u001b[39m\u001b[32m792\u001b[39m         \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_C\u001b[49m\u001b[43m.\u001b[49m\u001b[43mPyTorchFileWriter\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    793\u001b[39m \u001b[43m            \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mget_crc32_options\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_get_storage_alignment\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    794\u001b[39m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    795\u001b[39m     )\n",
            "\u001b[31mKeyboardInterrupt\u001b[39m: "
          ]
        }
      ],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import os, math\n",
        "\n",
        "# ---------------------------\n",
        "# Config: every knob a reader might tune lives here\n",
        "# ---------------------------\n",
        "seed = 0\n",
        "torch.manual_seed(seed)\n",
        "np.random.seed(seed)\n",
        "if torch.cuda.is_available():\n",
        "    device = \"cuda\"\n",
        "else:\n",
        "    device = \"cpu\"\n",
        "# Checkpoint directory; \"\" means the current working directory.\n",
        "checkpoint_dir = \"\"\n",
        "\n",
        "# Problem geometry: y depends only on the first d_indicator bits;\n",
        "# the remaining d_noise bits are distractors with NO effect on y.\n",
        "d_indicator = 10\n",
        "d_noise = 500\n",
        "d_total = d_indicator + d_noise\n",
        "N_indicator = 2 ** d_indicator          # number of distinct indicator patterns (1024)\n",
        "\n",
        "# Optimisation (full-batch Adam)\n",
        "epochs = 30000\n",
        "print_every = 50\n",
        "lr = 2e-3\n",
        "weight_decay = 5e-4\n",
        "hidden = 128                            # width of both hidden layers\n",
        "\n",
        "# Task: y = scalar on the target pattern, 0 everywhere else (no linear tail)\n",
        "scalar = 10.0                           # spike height when first-10 bits == target\n",
        "min_points = 65536                      # base rows; reps/anchor = ceil(min_points / N_indicator), rounded up to even\n",
        "target_extra_reps = 2048                # extra (oversampled) rows for the target anchor\n",
        "pos_weight = 16.0                       # MSE weight on target-anchor rows (all other rows weigh 1)\n",
        "\n",
        "# Evaluation\n",
        "eval_R = 256                            # noise tails averaged per anchor at eval time\n",
        "\n",
        "# ---------------------------\n",
        "# Utilities\n",
        "# ---------------------------\n",
        "def full_cube_pm1(dim, device):\n",
        "    \"\"\"Enumerate all 2**dim vertices of the {-1,+1}^dim hypercube.\n",
        "\n",
        "    Row r is the binary expansion of r (least-significant bit in column 0),\n",
        "    mapped 0 -> -1 and 1 -> +1. Returns a float32 tensor of shape [2**dim, dim].\n",
        "    \"\"\"\n",
        "    row_ids = torch.arange(2 ** dim, device=device, dtype=torch.long)[:, None]\n",
        "    bit_pos = torch.arange(dim, device=device, dtype=torch.long)[None, :]\n",
        "    bits01 = (row_ids >> bit_pos) & 1\n",
        "    return 2.0 * bits01.float() - 1.0\n",
        "\n",
        "def rand_pm1(shape, device):\n",
        "    \"\"\"Draw a float32 tensor of the given shape with i.i.d. uniform entries in {-1, +1}.\"\"\"\n",
        "    coin = torch.randint(0, 2, shape, device=device, dtype=torch.int8)\n",
        "    return coin.to(torch.float32).mul(2).sub(1)\n",
        "\n",
        "def anchor_id_from_pm1(x10):  # x10: [..., d] in {-1,+1}\n",
        "    \"\"\"Map ±1 patterns in the last dimension to integer ids (inverse of full_cube_pm1).\n",
        "\n",
        "    Column k contributes 2**k when it is +1, matching the row order of full_cube_pm1.\n",
        "    The width d is taken from x10.shape[-1] instead of the global d_indicator, so the\n",
        "    helper works for any pattern width (identical result when d == d_indicator).\n",
        "    Returns a long tensor; for a 1-D input the result keeps shape [1].\n",
        "    \"\"\"\n",
        "    width = x10.shape[-1]\n",
        "    # round() before the cast: .long() truncates, so 0.99999 would become bit 0.\n",
        "    bits01 = ((x10 + 1.0) * 0.5).round().long()\n",
        "    place = (2 ** torch.arange(width, device=x10.device, dtype=torch.long)).view(1, -1)\n",
        "    return (bits01 * place).sum(dim=-1)  # long\n",
        "\n",
        "# ---------------------------\n",
        "# Data\n",
        "# ---------------------------\n",
        "def _paired_rows(anchor, n_rows, y_value, w_value):\n",
        "    \"\"\"Return (X, y, w) for n_rows training rows that all share `anchor`.\n",
        "\n",
        "    The first d_indicator columns repeat `anchor`; the d_noise tail is antithetic\n",
        "    (second half = negated first half), so every noise column sums to exactly zero\n",
        "    inside the block. n_rows is assumed even (callers round it up).\n",
        "    \"\"\"\n",
        "    half = n_rows // 2\n",
        "    noise_half = rand_pm1((half, d_noise), device)                  # [half, d_noise]\n",
        "    noise = torch.cat([noise_half, -noise_half], dim=0)             # [n_rows, d_noise]\n",
        "    X_block = torch.cat([anchor.expand(n_rows, -1), noise], dim=1)  # [n_rows, d_total]\n",
        "    y_block = torch.full((n_rows,), y_value, device=device, dtype=torch.float32)\n",
        "    w_block = torch.full((n_rows,), w_value, device=device)\n",
        "    return X_block, y_block, w_block\n",
        "\n",
        "target10 = rand_pm1((d_indicator,), device)\n",
        "print(\"target10:\", target10.cpu().numpy())\n",
        "\n",
        "anchors = full_cube_pm1(d_indicator, device)           # [N_indicator, d_indicator]\n",
        "reps = int(math.ceil(min_points / N_indicator))        # base repeats per anchor\n",
        "if reps % 2 == 1:\n",
        "    reps += 1  # need even to make ± pairs\n",
        "target_idx = anchor_id_from_pm1(target10.unsqueeze(0)).item()\n",
        "\n",
        "rows, ys, weights = [], [], []\n",
        "\n",
        "# Base coverage with paired ± noise tails for every anchor\n",
        "for i in range(N_indicator):\n",
        "    on_target = (i == target_idx)\n",
        "    Xb, yb, wb = _paired_rows(anchors[i], reps,\n",
        "                              scalar if on_target else 0.0,\n",
        "                              pos_weight if on_target else 1.0)\n",
        "    rows.append(Xb); ys.append(yb); weights.append(wb)\n",
        "\n",
        "# Extra target rows (oversampling; still pure MSE), paired the same way.\n",
        "if target_extra_reps > 0:\n",
        "    if target_extra_reps % 2 == 1:\n",
        "        target_extra_reps += 1\n",
        "    Xb, yb, wb = _paired_rows(anchors[target_idx], target_extra_reps, scalar, pos_weight)\n",
        "    rows.append(Xb); ys.append(yb); weights.append(wb)\n",
        "\n",
        "X_train = torch.cat(rows, dim=0)        # [M, d_total]\n",
        "y_train = torch.cat(ys, dim=0)          # [M]\n",
        "w_train = torch.cat(weights, dim=0)     # [M]\n",
        "M = X_train.size(0)\n",
        "print(f\"Train size: {M} | base reps/anchor: {reps} | target extra reps: {target_extra_reps}\")\n",
        "\n",
        "# For quick training-time logs\n",
        "is_target_row = (X_train[:, :d_indicator] == target10).all(dim=1)\n",
        "\n",
        "# Cheap invariants: full input width, and target rows = base reps + oversampled reps.\n",
        "assert X_train.shape == (M, d_total)\n",
        "assert int(is_target_row.sum()) == reps + max(target_extra_reps, 0)\n",
        "\n",
        "# ---------------------------\n",
        "# Model (single FCNN)\n",
        "# ---------------------------\n",
        "class FCNN(nn.Module):\n",
        "    \"\"\"Two-hidden-layer ReLU MLP mapping [B, in_dim] -> [B] (one scalar per row).\n",
        "\n",
        "    Every Linear layer gets Xavier-uniform weights and zero bias.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, in_dim=10+500, hidden=128):\n",
        "        super().__init__()\n",
        "        layers = [\n",
        "            nn.Linear(in_dim, hidden), nn.ReLU(),\n",
        "            nn.Linear(hidden, hidden), nn.ReLU(),\n",
        "            nn.Linear(hidden, 1),\n",
        "        ]\n",
        "        self.net = nn.Sequential(*layers)\n",
        "        for lin in (m for m in self.modules() if isinstance(m, nn.Linear)):\n",
        "            nn.init.xavier_uniform_(lin.weight)\n",
        "            nn.init.zeros_(lin.bias)\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = self.net(x)        # [B, 1]\n",
        "        return out.squeeze(-1)   # [B]\n",
        "\n",
        "model = FCNN(d_total, hidden)\n",
        "model = model.to(device)\n",
        "opt = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "\n",
        "def weighted_mse(pred, target, w):\n",
        "    \"\"\"Mean over all elements of w * (pred - target)**2 (weights are not renormalised).\"\"\"\n",
        "    sq_err = (pred - target) ** 2\n",
        "    return torch.mean(w * sq_err)\n",
        "\n",
        "# ---------------------------\n",
        "# Train (full-batch, pure MSE)\n",
        "# ---------------------------\n",
        "# Checkpoints are consumed by the sampling cells below (ckpt_epoch_{ep}.pt), so\n",
        "# they are always written. An empty checkpoint_dir means the current working\n",
        "# directory; a non-empty one is created up front because torch.save does not\n",
        "# create missing parent directories.\n",
        "ckpt_every = 5\n",
        "if checkpoint_dir:\n",
        "    os.makedirs(checkpoint_dir, exist_ok=True)\n",
        "\n",
        "for ep in range(1, epochs + 1):\n",
        "    model.train()\n",
        "    pred = model(X_train)\n",
        "    loss = weighted_mse(pred, y_train, w_train)\n",
        "\n",
        "    opt.zero_grad(set_to_none=True)\n",
        "    loss.backward()\n",
        "    nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
        "    opt.step()\n",
        "\n",
        "    if ep % print_every == 0 or ep == 1 or ep == epochs:\n",
        "        # `pred` comes from before this step's update; reusing it keeps logging free.\n",
        "        with torch.no_grad():\n",
        "            p_target_rows = pred[is_target_row].mean().item()\n",
        "            p_non_rows    = pred[~is_target_row].abs().max().item()\n",
        "        print(f\"[ep {ep:5d}] loss={loss.item():.6e} | \"\n",
        "              f\"mean(pred on target rows)≈{p_target_rows:.3f} | \"\n",
        "              f\"max|pred| on non-target rows≈{p_non_rows:.3f}\")\n",
        "\n",
        "    if ep % ckpt_every == 0:\n",
        "        ckpt_path = os.path.join(checkpoint_dir, f\"ckpt_epoch_{ep}.pt\")\n",
        "        torch.save({\"epoch\": ep, \"model_state_dict\": model.state_dict()}, ckpt_path)\n",
        "\n",
        "# ---------------------------\n",
        "# Final evaluation (anchor check over all 2^10; average out noise tail)\n",
        "# ---------------------------\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "    # Every anchor is paired with the same eval_R noise tails; the per-anchor mean\n",
        "    # over those tails estimates the noise-averaged prediction for that anchor.\n",
        "    x_noise_eval = rand_pm1((eval_R, d_noise), device)\n",
        "    X1 = torch.repeat_interleave(anchors, eval_R, dim=0)   # [N_indicator*R, d_indicator]\n",
        "    X2 = x_noise_eval.repeat(N_indicator, 1)               # [N_indicator*R, d_noise]\n",
        "    X_eval = torch.cat((X1, X2), dim=1)                    # [N_indicator*R, d_total]\n",
        "    preds_avg = model(X_eval).reshape(N_indicator, eval_R).mean(dim=1)\n",
        "\n",
        "    target_avg = preds_avg[target_idx].item()\n",
        "    # all anchors except the target, in their original order\n",
        "    non_avg = torch.cat((preds_avg[:target_idx], preds_avg[target_idx + 1:]))\n",
        "    max_non = non_avg.abs().amax().item()\n",
        "    mean_non = non_avg.abs().mean().item()\n",
        "\n",
        "    # Full-function Monte Carlo MSE against the true indicator (noise has NO effect)\n",
        "    mc_R = 64\n",
        "    X1_mc = torch.repeat_interleave(anchors, mc_R, dim=0)\n",
        "    X2_mc = rand_pm1((mc_R, d_noise), device).repeat(N_indicator, 1)\n",
        "    X_mc = torch.cat((X1_mc, X2_mc), dim=1)\n",
        "    pred_mc = model(X_mc)\n",
        "\n",
        "    ind_mc = (anchor_id_from_pm1(X1_mc) == target_idx).float()\n",
        "    y_mc = ind_mc * scalar\n",
        "    mse_mc = torch.mean((pred_mc - y_mc) ** 2).item()\n",
        "\n",
        "tol = 0.25\n",
        "print(\"\\n--- INDICATOR CHECK (anchor-averaged) ---\")\n",
        "print(f\"avg_pred(target) = {target_avg:.3f}  (should ≈ {scalar:.1f})  -> within tol? {abs(target_avg - scalar) <= tol}\")\n",
        "print(f\"max |avg_pred|(non-targets) = {max_non:.3f} (should ≈ 0) -> within tol? {abs(max_non) <= tol}\")\n",
        "print(f\"mean |avg_pred|(non-targets) = {mean_non:.3f}\")\n",
        "\n",
        "print(\"\\n--- Monte Carlo FULL-FUNCTION sanity check ---\")\n",
        "print(f\"MC MSE over anchors×{mc_R}: {mse_mc:.6f}\")\n",
        "print(f\"(scalar={scalar}, d_noise={d_noise})\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "SSd3FGCuOX4m"
      },
      "outputs": [],
      "source": [
        "D = d_indicator + d_noise  # input width: indicator bits followed by noise bits (== d_total)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "f0ing7qe4VH1"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import os, glob\n",
        "import tqdm\n",
        "import math\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "class GWGSampler:\n",
        "    \"\"\"Metropolis-Hastings sampler on {-1,+1}^D with informed single-bit-flip proposals.\n",
        "\n",
        "    Target: pi(x) ∝ exp(-beta * E(x)) with E(x) = -model(x), so larger model outputs\n",
        "    are more likely. A step proposes flipping bit i with probability\n",
        "    softmax(-beta * Δ / 2)_i, where Δ_i = E(x^i) - E(x) and x^i is x with bit i\n",
        "    flipped, then applies the MH test with the exact reverse-proposal probability.\n",
        "\n",
        "    model: callable taking a [1, D] tensor and returning a single-element tensor.\n",
        "    beta:  inverse temperature (larger = greedier).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model, beta=1.0):\n",
        "        self.model = model\n",
        "        self.beta = float(beta)\n",
        "\n",
        "    def _energy(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        # NEGATIVE sign: lower energy = higher model output\n",
        "        y = self.model(x.view(1, -1)).view(())\n",
        "        return -y\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _deltas_exact(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Exact Δ_i = E(x^i) - E(x) for every single-bit flip i of the 1-D state x.\"\"\"\n",
        "        device = x.device\n",
        "        D = x.numel()\n",
        "        y = self._energy(x)  # scalar E(x)\n",
        "\n",
        "        # vectorized single-bit flips: row i of X is x with bit i negated\n",
        "        X = x.unsqueeze(0).repeat(D, 1)\n",
        "        idx = torch.arange(D, device=device)\n",
        "        X[idx, idx] = -X[idx, idx]\n",
        "        y_flips = torch.vmap(self._energy)(X)  # or: torch.stack([self._energy(X[i]) for i in range(D)])\n",
        "        return y_flips - y  # Δ_i = E(x^i) - E(x)\n",
        "\n",
        "    def _deltas_grad(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"First-order (Gibbs-with-Gradients) estimate of Δ; not used by single_step.\"\"\"\n",
        "        # GWG approx: Δ_i ≈ -2 x_i ∂_i E(x) = 2 x_i ∂_i model(x)\n",
        "        x = x.detach().clone().requires_grad_(True)\n",
        "        y = self.model(x.view(1, -1)).view(())\n",
        "        (g,) = torch.autograd.grad(y, x, create_graph=False, retain_graph=False)\n",
        "        return (2.0 * x * g).detach()\n",
        "\n",
        "    def single_step(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"One MH step from the 1-D ±1 state x; returns the next state as a new tensor.\"\"\"\n",
        "        x = x.detach().clone()\n",
        "        deltas = self._deltas_exact(x)                # Δ_i\n",
        "\n",
        "        # coordinate proposal q(i | x) ∝ exp(-β Δ_i / 2)\n",
        "        logits = -self.beta * deltas / 2.0\n",
        "        probs = torch.softmax(logits, dim=0)\n",
        "        i = torch.multinomial(probs, 1).item()\n",
        "\n",
        "        # candidate flip\n",
        "        x_new = x.clone()\n",
        "        x_new[i] = -x_new[i]\n",
        "\n",
        "        # MH correction with the exact reverse proposal q(i | x_new), in log space:\n",
        "        # for large β the linear ratio exp(-βΔ_i) * q_rev / q_fwd overflows to inf\n",
        "        # while q_rev underflows to 0, giving inf * 0 = nan and a spurious rejection.\n",
        "        deltas_p = self._deltas_exact(x_new)\n",
        "        log_q_fwd = torch.log_softmax(logits, dim=0)[i]\n",
        "        log_q_rev = torch.log_softmax(-self.beta * deltas_p / 2.0, dim=0)[i]\n",
        "        log_accept = -self.beta * deltas[i] + log_q_rev - log_q_fwd\n",
        "\n",
        "        # u < min(1, a)  <=>  log(u) < log(a)   for u in [0, 1)\n",
        "        u = torch.rand((), device=x.device)\n",
        "        if torch.log(u) < log_accept:\n",
        "            return x_new.detach()\n",
        "        return x.detach()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "5-qyZqMa65dP"
      },
      "outputs": [],
      "source": [
        "def sampling_via_checkpoints(\n",
        "    checkpoint_dir: str,\n",
        "    epochs: list[int],\n",
        "    FCNNClass,\n",
        "    GWGSamplerClass,\n",
        "    num_particles: int = 200,\n",
        "    mcmc_steps: int = 15,\n",
        "    resample_thresh: float = 0.5,\n",
        "    device: str = \"cuda\",\n",
        "    beta: float = 1.0,\n",
        "    verbose: bool = False,\n",
        "):\n",
        "    \"\"\"Anneal particles through a sequence of training checkpoints with GWG moves.\n",
        "\n",
        "    Particles start uniform on {-1,+1}^d_total. For each checkpoint (sorted by epoch)\n",
        "    the FCNN weights are loaded and every particle takes `mcmc_steps` sampler steps.\n",
        "\n",
        "    Reads notebook globals: d_total, hidden, d_indicator, target10.\n",
        "    `resample_thresh` is accepted for compatibility but unused (no resampling step).\n",
        "    `verbose=True` prints the Hamming distance to the target after each checkpoint.\n",
        "\n",
        "    Returns (particles as a numpy array [num_particles, d_total],\n",
        "             0-d bool tensor: True iff some particle's first d_indicator bits\n",
        "             equal target10 exactly).\n",
        "    \"\"\"\n",
        "    epochs = sorted(epochs)\n",
        "    ckpts = [os.path.join(checkpoint_dir, f\"ckpt_epoch_{e}.pt\") for e in epochs]\n",
        "\n",
        "    D = d_total\n",
        "    particles = (torch.randint(0, 2, (num_particles, D), device=device) * 2 - 1).float()\n",
        "    target_row = target10.to(device).view(1, -1)\n",
        "\n",
        "    for ckpt in ckpts:\n",
        "        # load model; weights_only=True because the file is a plain dict of an int\n",
        "        # and a state_dict, and full unpickling can execute arbitrary code\n",
        "        model = FCNNClass(D, hidden).to(device).eval()\n",
        "        sd = torch.load(ckpt, map_location=device, weights_only=True)\n",
        "        model.load_state_dict(sd['model_state_dict'])\n",
        "\n",
        "        # GWG rejuvenation targeting current energy\n",
        "        sampler = GWGSamplerClass(model, beta=beta)\n",
        "        for i in range(num_particles):\n",
        "            x = particles[i]\n",
        "            for _ in range(mcmc_steps):\n",
        "                x = sampler.single_step(x)\n",
        "            particles[i] = x\n",
        "\n",
        "        if verbose:\n",
        "            with torch.no_grad():\n",
        "                hamming = (particles[:, :d_indicator] != target_row).sum(dim=1).float()\n",
        "            print(f\"{os.path.basename(ckpt)}: Hamming to target min/median/max = \"\n",
        "                  f\"{hamming.min().item():.0f}/{hamming.median().item():.0f}/{hamming.max().item():.0f}\")\n",
        "\n",
        "    # Hit = some particle equals target10 on EVERY indicator bit. Taking .min() of the\n",
        "    # raw difference (values in {-2, 0, 2}) is not equivalent: it is 0 for a particle\n",
        "    # that is +1 wherever the target is +1 (e.g. all ones) even if other bits differ.\n",
        "    hit = (particles[:, :d_indicator] == target_row).all(dim=1).any()\n",
        "    return particles.cpu().numpy(), hit\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 217
        },
        "id": "VfxJdoLS-JDF",
        "outputId": "7bf2af20-ef6a-4e90-9954-b7ec8adc9e12"
      },
      "outputs": [],
      "source": [
        "# Repeat the single-particle annealing run n_trials times and count successes.\n",
        "n_trials = 200\n",
        "anneal_epochs = [25, 3000]   # checkpoint epochs visited, in order (must be multiples of the save interval)\n",
        "hit_count = 0\n",
        "for _ in range(n_trials):\n",
        "    particles, hit_or_not = sampling_via_checkpoints(\n",
        "        checkpoint_dir, anneal_epochs, FCNN, GWGSampler,\n",
        "        num_particles=1, mcmc_steps=20, beta=10.0,\n",
        "        device=device,  # the function defaults to \"cuda\"; follow the notebook's device\n",
        "    )\n",
        "    hit_count += int(hit_or_not.item())\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Hit count: 200/200\n",
            "Hit fraction: 1.0000\n",
            "2 SD CI: [1.0000, 1.0000]  (SE ≈ 0.0000)\n"
          ]
        }
      ],
      "source": [
        "import math\n",
        "\n",
        "n_trials = 200  # must equal the number of trials run in the loop above\n",
        "z = 2.0         # \"2 SD\" multiplier\n",
        "\n",
        "if n_trials > 0:\n",
        "    p = hit_count / float(n_trials)  # hit fraction\n",
        "    se = math.sqrt(p * (1.0 - p) / n_trials)  # Wald SE\n",
        "    # Wilson score interval: unlike p ± z*SE it does not collapse to zero width at\n",
        "    # p = 0 or p = 1 (e.g. 200/200 hits) and it always stays inside [0, 1].\n",
        "    denom = 1.0 + z * z / n_trials\n",
        "    center = (p + z * z / (2.0 * n_trials)) / denom\n",
        "    half_width = (z / denom) * math.sqrt(p * (1.0 - p) / n_trials + z * z / (4.0 * n_trials * n_trials))\n",
        "    lo = max(0.0, center - half_width)\n",
        "    hi = min(1.0, center + half_width)\n",
        "else:\n",
        "    # no trials: avoid ZeroDivisionError, report everything as undefined\n",
        "    p = se = lo = hi = float('nan')\n",
        "\n",
        "print(f\"Hit count: {hit_count}/{n_trials}\")\n",
        "print(f\"Hit fraction: {p:.4f}\")\n",
        "print(f\"2 SD Wilson CI: [{lo:.4f}, {hi:.4f}]  (Wald SE ≈ {se:.4f})\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "F8C_eM1iU-KV"
      },
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def first_hit_steps(sampler, target10, d_indicator, d_total, device, max_steps=10000):\n",
        "    \"\"\"\n",
        "    Count sampler steps until the indicator bits first equal target10.\n",
        "\n",
        "    A fresh ±1 particle of length d_total is drawn with rand_pm1 and advanced with\n",
        "    sampler.single_step until x[:d_indicator] == target10 elementwise.\n",
        "    Returns 0 if the random start already matches, otherwise the 1-based index of\n",
        "    the first matching step, or None if there is no match within max_steps steps.\n",
        "    \"\"\"\n",
        "    def on_target(state):\n",
        "        return bool((state[:d_indicator] == target10).all())\n",
        "\n",
        "    x = rand_pm1((d_total,), device).to(torch.float32)\n",
        "    if on_target(x):\n",
        "        return 0\n",
        "\n",
        "    for step in range(1, max_steps + 1):\n",
        "        x = sampler.single_step(x)\n",
        "        if on_target(x):\n",
        "            return step\n",
        "    return None"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "PZpQUR6DU6XX"
      },
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def run_gwg_trials(model, target10, d_indicator, d_total, device,\n",
        "                   n_trials=200, max_steps=10000, beta=1.0, verbose_every=50,\n",
        "                   bootstrap_B=2000, rng_seed=0):\n",
        "    \"\"\"\n",
        "    Run GWG first-hit experiments and report statistics with 2 SD confidence intervals.\n",
        "\n",
        "    Each of the n_trials trials calls first_hit_steps with one shared\n",
        "    GWGSampler(model, beta). Prints a report and returns None.\n",
        "\n",
        "    - Unsuccessful trials are counted as max_steps for 'ALL trials' aggregates\n",
        "      (right-censored), so ALL-trial means/medians understate the true values.\n",
        "    - Medians use bootstrap to estimate the standard error, then ± 2*SE for the CI,\n",
        "      clipped to the observed data range (a sample median cannot leave it).\n",
        "    - 'Hit probability within budgets' counts actual hits only; budgets above\n",
        "      max_steps are skipped because a censored miss says nothing about them.\n",
        "    \"\"\"\n",
        "    import numpy as np\n",
        "    import math\n",
        "    import collections\n",
        "    model.eval()\n",
        "    sampler = GWGSampler(model, beta=beta)\n",
        "\n",
        "    # ---------------------------\n",
        "    # Helpers\n",
        "    # ---------------------------\n",
        "    def ci_mean_2sd(x):\n",
        "        \"\"\"Mean ± 2 SD/√n CI (all NaN for empty input).\"\"\"\n",
        "        x = np.asarray(x, dtype=np.float64)\n",
        "        n = len(x)\n",
        "        if n == 0:\n",
        "            nan = float(\"nan\")\n",
        "            return nan, nan, (nan, nan)\n",
        "        mu = float(x.mean())\n",
        "        sd = float(x.std(ddof=1)) if n > 1 else 0.0\n",
        "        se = sd / math.sqrt(n)\n",
        "        return mu, sd, (mu - 2*se, mu + 2*se)\n",
        "\n",
        "    def ci_prop_2sd(k, n):\n",
        "        \"\"\"Proportion ± 2 SD (binomial SE).\"\"\"\n",
        "        p = (k / n) if n > 0 else float(\"nan\")\n",
        "        se = math.sqrt(p * (1 - p) / n) if n > 0 else float(\"nan\")\n",
        "        lo, hi = max(0.0, p - 2*se), min(1.0, p + 2*se)\n",
        "        return p, se, (lo, hi)\n",
        "\n",
        "    def ci_median_bootstrap_2sd(x, B=2000, seed=0):\n",
        "        \"\"\"Median and ± 2 SD bootstrap CI, clipped to [min(x), max(x)].\"\"\"\n",
        "        x = np.asarray(x, dtype=np.float64)\n",
        "        n = len(x)\n",
        "        if n == 0:\n",
        "            return float(\"nan\"), float(\"nan\"), (float(\"nan\"), float(\"nan\"))\n",
        "        if n == 1:\n",
        "            med = float(x[0])\n",
        "            return med, 0.0, (med, med)\n",
        "        rng = np.random.default_rng(seed)\n",
        "        med = float(np.median(x))\n",
        "        meds = np.empty(B, dtype=np.float64)\n",
        "        idx = np.arange(n)\n",
        "        for b in range(B):\n",
        "            resample = x[rng.choice(idx, size=n, replace=True)]\n",
        "            meds[b] = np.median(resample)\n",
        "        sd = float(meds.std(ddof=1))\n",
        "        # med ± 2 SD can overshoot the data (e.g. above max_steps when many runs are\n",
        "        # censored), but a median of this data can never leave [min(x), max(x)].\n",
        "        lo = max(float(x.min()), med - 2*sd)\n",
        "        hi = min(float(x.max()), med + 2*sd)\n",
        "        return med, sd, (lo, hi)\n",
        "\n",
        "    def robust_mad(x):\n",
        "        \"\"\"Median Absolute Deviation (MAD).\"\"\"\n",
        "        x = np.asarray(x, dtype=np.float64)\n",
        "        if len(x) == 0:\n",
        "            return float(\"nan\")\n",
        "        med = np.median(x)\n",
        "        return float(np.median(np.abs(x - med)))\n",
        "\n",
        "    # ---------------------------\n",
        "    # Trials\n",
        "    # ---------------------------\n",
        "    hits_only = []    # steps for successful trials\n",
        "    all_steps = []    # steps for all trials (misses counted as max_steps)\n",
        "    misses = 0\n",
        "\n",
        "    for i in range(1, n_trials + 1):\n",
        "        steps = first_hit_steps(sampler, target10, d_indicator, d_total, device, max_steps)\n",
        "        if steps is None:\n",
        "            misses += 1\n",
        "            all_steps.append(max_steps)\n",
        "            last_str = \"miss\"\n",
        "        else:\n",
        "            s = int(steps)\n",
        "            hits_only.append(s)\n",
        "            all_steps.append(s)\n",
        "            last_str = str(s)\n",
        "\n",
        "        if verbose_every and (i % verbose_every == 0 or i == n_trials):\n",
        "            hit_rate = (i - misses) / i\n",
        "            print(f\"[trial {i:4d}] last={last_str} | hits={i - misses} | misses={misses} | hit_rate={hit_rate:.3f}\")\n",
        "\n",
        "    # Convert to numpy\n",
        "    arr_all  = np.array(all_steps, dtype=np.int64)\n",
        "    arr_hits = np.array(hits_only, dtype=np.int64)\n",
        "\n",
        "    # ---------------------------\n",
        "    # Core statistics with 2 SD CIs\n",
        "    # ---------------------------\n",
        "    print(\"\\n=== GWG First-Hit Statistics with 2 SD Confidence Intervals ===\")\n",
        "    print(f\"trials={n_trials} | hits={n_trials - misses} | misses={misses} | miss_penalty=max_steps({max_steps})\")\n",
        "\n",
        "    # Hit rate CI (binomial)\n",
        "    p, p_se, (p_lo, p_hi) = ci_prop_2sd(n_trials - misses, n_trials)\n",
        "    print(f\"Hit rate              : {p:.4f}  (±2SD CI: [{p_lo:.4f}, {p_hi:.4f}])  | SE≈{p_se:.4f}\")\n",
        "\n",
        "    # Mean (ALL trials)\n",
        "    mean_all, sd_all, (lo_all, hi_all) = ci_mean_2sd(arr_all)\n",
        "    print(f\"Mean steps (ALL)      : {mean_all:.2f}  (±2SD CI: [{lo_all:.2f}, {hi_all:.2f}])  | SD={sd_all:.2f}\")\n",
        "\n",
        "    # Median (ALL trials, misses=max_steps)\n",
        "    med_all, med_all_se_boot, (med_all_lo, med_all_hi) = ci_median_bootstrap_2sd(\n",
        "        arr_all, B=bootstrap_B, seed=rng_seed\n",
        "    )\n",
        "    print(f\"Median steps (ALL)    : {med_all:.2f}  (±2SD boot CI: [{med_all_lo:.2f}, {med_all_hi:.2f}])  | boot SD≈{med_all_se_boot:.2f}\")\n",
        "\n",
        "    # Median (HITS only)\n",
        "    if len(arr_hits) > 0:\n",
        "        med_hits, med_hits_se_boot, (med_hits_lo, med_hits_hi) = ci_median_bootstrap_2sd(\n",
        "            arr_hits, B=bootstrap_B, seed=rng_seed + 1\n",
        "        )\n",
        "        print(f\"Median steps (HITS)   : {med_hits:.2f}  (±2SD boot CI: [{med_hits_lo:.2f}, {med_hits_hi:.2f}])  | boot SD≈{med_hits_se_boot:.2f}\")\n",
        "    else:\n",
        "        print(\"Median steps (HITS)   : n/a (no successful trials)\")\n",
        "\n",
        "    # ---------------------------\n",
        "    # Additional useful stats\n",
        "    # ---------------------------\n",
        "    def q(arr, p): return float(np.percentile(arr, p)) if len(arr) else float(\"nan\")\n",
        "\n",
        "    if len(arr_all):\n",
        "        print(\"\\n-- Distribution (ALL trials) --\")\n",
        "        print(f\"min / p25 / p50 / p75 / max : {arr_all.min():.0f} / {q(arr_all,25):.0f} / {q(arr_all,50):.0f} / {q(arr_all,75):.0f} / {arr_all.max():.0f}\")\n",
        "        print(f\"IQR (p75-p25)         : {q(arr_all,75) - q(arr_all,25):.2f}\")\n",
        "        print(f\"MAD (about median)    : {robust_mad(arr_all):.2f}\")\n",
        "        for pct in (90, 95, 99):\n",
        "            print(f\"p{pct:02d}                 : {q(arr_all, pct):.2f}\")\n",
        "\n",
        "        # Probability of hitting within selected step budgets.\n",
        "        # Count real hits only: misses are stored as max_steps in arr_all, so\n",
        "        # (arr_all <= max_steps) would report every miss as a hit.\n",
        "        budgets = sorted({T for T in (100, 500, 1000, 5000, max_steps) if T <= max_steps})\n",
        "        probs_within = []\n",
        "        for T in budgets:\n",
        "            probs_within.append((T, float((arr_hits <= T).sum()) / len(arr_all)))\n",
        "        print(\"\\nHit probability within budgets (ALL trials):\")\n",
        "        for T, pr in probs_within:\n",
        "            print(f\"  ≤ {T:6d} steps : {pr:.4f}\")\n",
        "\n",
        "    if len(arr_hits):\n",
        "        print(\"\\n-- Distribution (successful trials ONLY) --\")\n",
        "        print(f\"min / p25 / p50 / p75 / max : {arr_hits.min():.0f} / {q(arr_hits,25):.0f} / {q(arr_hits,50):.0f} / {q(arr_hits,75):.0f} / {arr_hits.max():.0f}\")\n",
        "        print(f\"IQR (p75-p25)         : {q(arr_hits,75) - q(arr_hits,25):.2f}\")\n",
        "        print(f\"MAD (about median)    : {robust_mad(arr_hits):.2f}\")\n",
        "        for pct in (90, 95, 99):\n",
        "            print(f\"p{pct:02d}                 : {q(arr_hits, pct):.2f}\")\n",
        "\n",
        "    # Compact histogram (ALL trials). Shows censoring spike at max_steps if many misses.\n",
        "    hist = collections.Counter(arr_all.tolist())\n",
        "    most_common = sorted(hist.items(), key=lambda kv: (-kv[1], kv[0]))[:20]\n",
        "    print(\"\\nTop (step,count) bins (ALL trials, 20 most common):\")\n",
        "    for step, cnt in most_common:\n",
        "        print(f\"  {step:7d} : {cnt}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-Rd2cBBmRjeF",
        "outputId": "7574a178-e8e4-4dcb-c4f4-102d71eec07e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[trial    5] last=miss | hits=3 | misses=2 | hit_rate=0.600\n",
            "[trial   10] last=miss | hits=4 | misses=6 | hit_rate=0.400\n",
            "[trial   15] last=1 | hits=9 | misses=6 | hit_rate=0.600\n",
            "[trial   20] last=miss | hits=10 | misses=10 | hit_rate=0.500\n",
            "[trial   25] last=0 | hits=13 | misses=12 | hit_rate=0.520\n",
            "[trial   30] last=miss | hits=16 | misses=14 | hit_rate=0.533\n",
            "[trial   35] last=1 | hits=19 | misses=16 | hit_rate=0.543\n",
            "[trial   40] last=1 | hits=21 | misses=19 | hit_rate=0.525\n",
            "[trial   45] last=miss | hits=23 | misses=22 | hit_rate=0.511\n",
            "[trial   50] last=0 | hits=25 | misses=25 | hit_rate=0.500\n",
            "[trial   55] last=0 | hits=29 | misses=26 | hit_rate=0.527\n",
            "[trial   60] last=miss | hits=32 | misses=28 | hit_rate=0.533\n",
            "[trial   65] last=miss | hits=35 | misses=30 | hit_rate=0.538\n",
            "[trial   70] last=miss | hits=36 | misses=34 | hit_rate=0.514\n",
            "[trial   75] last=1 | hits=38 | misses=37 | hit_rate=0.507\n",
            "[trial   80] last=1 | hits=41 | misses=39 | hit_rate=0.512\n",
            "[trial   85] last=0 | hits=44 | misses=41 | hit_rate=0.518\n",
            "[trial   90] last=1 | hits=47 | misses=43 | hit_rate=0.522\n",
            "[trial   95] last=miss | hits=48 | misses=47 | hit_rate=0.505\n",
            "[trial  100] last=1 | hits=50 | misses=50 | hit_rate=0.500\n",
            "[trial  105] last=miss | hits=51 | misses=54 | hit_rate=0.486\n",
            "[trial  110] last=miss | hits=53 | misses=57 | hit_rate=0.482\n",
            "[trial  115] last=miss | hits=55 | misses=60 | hit_rate=0.478\n",
            "[trial  120] last=miss | hits=57 | misses=63 | hit_rate=0.475\n",
            "[trial  125] last=1 | hits=59 | misses=66 | hit_rate=0.472\n",
            "[trial  130] last=miss | hits=62 | misses=68 | hit_rate=0.477\n",
            "[trial  135] last=miss | hits=64 | misses=71 | hit_rate=0.474\n",
            "[trial  140] last=0 | hits=66 | misses=74 | hit_rate=0.471\n",
            "[trial  145] last=miss | hits=68 | misses=77 | hit_rate=0.469\n",
            "[trial  150] last=miss | hits=70 | misses=80 | hit_rate=0.467\n",
            "[trial  155] last=1 | hits=72 | misses=83 | hit_rate=0.465\n",
            "[trial  160] last=miss | hits=73 | misses=87 | hit_rate=0.456\n",
            "[trial  165] last=miss | hits=76 | misses=89 | hit_rate=0.461\n",
            "[trial  170] last=miss | hits=77 | misses=93 | hit_rate=0.453\n",
            "[trial  175] last=1 | hits=80 | misses=95 | hit_rate=0.457\n",
            "[trial  180] last=1 | hits=83 | misses=97 | hit_rate=0.461\n",
            "[trial  185] last=miss | hits=84 | misses=101 | hit_rate=0.454\n",
            "[trial  190] last=1 | hits=86 | misses=104 | hit_rate=0.453\n",
            "[trial  195] last=1 | hits=89 | misses=106 | hit_rate=0.456\n",
            "[trial  200] last=1 | hits=94 | misses=106 | hit_rate=0.470\n",
            "\n",
            "=== GWG First-Hit Statistics with 2 SD Confidence Intervals ===\n",
            "trials=200 | hits=94 | misses=106 | miss_penalty=max_steps(2000)\n",
            "Hit rate              : 0.4700  (±2SD CI: [0.3994, 0.5406])  | SE≈0.0353\n",
            "Mean steps (ALL)      : 1060.35  (±2SD CI: [918.88, 1201.81])  | SD=1000.34\n",
            "Median steps (ALL)    : 2000.00  (±2SD boot CI: [439.78, 3560.22])  | boot SD≈780.11\n",
            "Median steps (HITS)   : 1.00  (±2SD boot CI: [1.00, 1.00])  | boot SD≈0.00\n",
            "\n",
            "-- Distribution (ALL trials) --\n",
            "min / p25 / p50 / p75 / max : 0 / 1 / 2000 / 2000 / 2000\n",
            "IQR (p75-p25)         : 1999.00\n",
            "MAD (about median)    : 0.00\n",
            "p90                 : 2000.00\n",
            "p95                 : 2000.00\n",
            "p99                 : 2000.00\n",
            "\n",
            "Hit probability within budgets (ALL trials):\n",
            "  ≤    100 steps : 0.4700\n",
            "  ≤    500 steps : 0.4700\n",
            "  ≤   1000 steps : 0.4700\n",
            "  ≤   2000 steps : 1.0000\n",
            "  ≤   5000 steps : 1.0000\n",
            "\n",
            "-- Distribution (successful trials ONLY) --\n",
            "min / p25 / p50 / p75 / max : 0 / 0 / 1 / 1 / 1\n",
            "IQR (p75-p25)         : 1.00\n",
            "MAD (about median)    : 0.00\n",
            "p90                 : 1.00\n",
            "p95                 : 1.00\n",
            "p99                 : 1.00\n",
            "\n",
            "Top (step,count) bins (ALL trials, 20 most common):\n",
            "     2000 : 106\n",
            "        1 : 69\n",
            "        0 : 25\n"
          ]
        }
      ],
      "source": [
        "# Run repeated GWG trials and print first-hit statistics.\n",
        "# Fall back to CPU when CUDA is unavailable so the cell also runs on CPU-only kernels.\n",
        "# NOTE(review): misses are recorded as max_steps in the summary (miss_penalty=max_steps),\n",
        "# so the \"<= max_steps\" budget line reports 1.0 even when the hit rate is lower;\n",
        "# rely on hit_rate instead. FIXME in run_gwg_trials' reporting.\n",
        "run_gwg_trials(\n",
        "    model=model,\n",
        "    target10=target10,\n",
        "    d_indicator=d_indicator,\n",
        "    d_total=d_total,      # total dim; = 10 + 500 in the current setup (keep in sync)\n",
        "    device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
        "    n_trials=200,         # number of trials to run\n",
        "    max_steps=2000,       # per-trial step budget; also the value recorded for a miss\n",
        "    beta=10.0,\n",
        "    verbose_every=5       # print a progress line every 5 trials\n",
        ")\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "torch-gpu-env",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.13.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
