{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cpu\n",
      "Start Training...\n",
      "Epoch 01 | Tr: 1.0594 0.660 | Va: 1.0758 0.667 | Te: 0.700\n",
      "Epoch 02 | Tr: 1.0460 0.660 | Va: 1.0627 0.667 | Te: 0.700\n",
      "Epoch 03 | Tr: 1.0333 0.660 | Va: 1.0494 0.667 | Te: 0.700\n",
      "Epoch 04 | Tr: 1.0230 0.660 | Va: 1.0394 0.667 | Te: 0.700\n",
      "Epoch 05 | Tr: 1.0175 0.660 | Va: 1.0340 0.667 | Te: 0.700\n",
      "Epoch 06 | Tr: 1.0137 0.660 | Va: 1.0306 0.667 | Te: 0.700\n",
      "Epoch 07 | Tr: 1.0080 0.660 | Va: 1.0276 0.667 | Te: 0.700\n",
      "Epoch 08 | Tr: 1.0016 0.660 | Va: 1.0247 0.667 | Te: 0.700\n",
      "Epoch 09 | Tr: 0.9951 0.660 | Va: 1.0207 0.667 | Te: 0.700\n",
      "Epoch 10 | Tr: 0.9878 0.660 | Va: 1.0146 0.667 | Te: 0.700\n",
      "Epoch 11 | Tr: 0.9792 0.660 | Va: 1.0072 0.667 | Te: 0.700\n",
      "Epoch 12 | Tr: 0.9696 0.660 | Va: 0.9994 0.667 | Te: 0.700\n",
      "Epoch 13 | Tr: 0.9584 0.660 | Va: 0.9912 0.667 | Te: 0.700\n",
      "Epoch 14 | Tr: 0.9452 0.660 | Va: 0.9817 0.667 | Te: 0.750\n",
      "Epoch 15 | Tr: 0.9293 0.667 | Va: 0.9707 0.722 | Te: 0.800\n",
      "Epoch 16 | Tr: 0.9113 0.700 | Va: 0.9586 0.833 | Te: 0.850\n",
      "Epoch 17 | Tr: 0.8919 0.747 | Va: 0.9457 0.778 | Te: 0.800\n",
      "Epoch 18 | Tr: 0.8720 0.780 | Va: 0.9328 0.778 | Te: 0.750\n",
      "Epoch 19 | Tr: 0.8527 0.800 | Va: 0.9207 0.722 | Te: 0.700\n",
      "Epoch 20 | Tr: 0.8345 0.827 | Va: 0.9106 0.778 | Te: 0.700\n",
      "Epoch 21 | Tr: 0.8180 0.840 | Va: 0.9024 0.778 | Te: 0.650\n",
      "Epoch 22 | Tr: 0.8027 0.833 | Va: 0.8955 0.778 | Te: 0.600\n",
      "Epoch 23 | Tr: 0.7888 0.847 | Va: 0.8892 0.778 | Te: 0.550\n",
      "Epoch 24 | Tr: 0.7764 0.853 | Va: 0.8849 0.778 | Te: 0.550\n",
      "Epoch 25 | Tr: 0.7652 0.860 | Va: 0.8838 0.778 | Te: 0.550\n",
      "Epoch 26 | Tr: 0.7550 0.853 | Va: 0.8852 0.778 | Te: 0.550\n",
      "Epoch 27 | Tr: 0.7463 0.853 | Va: 0.8870 0.778 | Te: 0.550\n",
      "Epoch 28 | Tr: 0.7378 0.860 | Va: 0.8903 0.778 | Te: 0.500\n",
      "Epoch 29 | Tr: 0.7302 0.867 | Va: 0.8900 0.778 | Te: 0.500\n",
      "Epoch 30 | Tr: 0.7220 0.873 | Va: 0.8944 0.778 | Te: 0.500\n",
      "Epoch 31 | Tr: 0.7140 0.873 | Va: 0.9032 0.667 | Te: 0.500\n",
      "Epoch 32 | Tr: 0.7076 0.873 | Va: 0.9058 0.667 | Te: 0.500\n",
      "Epoch 33 | Tr: 0.7008 0.880 | Va: 0.9095 0.667 | Te: 0.500\n",
      "Epoch 34 | Tr: 0.6936 0.880 | Va: 0.9181 0.667 | Te: 0.500\n",
      "Epoch 35 | Tr: 0.6873 0.880 | Va: 0.9238 0.667 | Te: 0.500\n",
      "Epoch 36 | Tr: 0.6806 0.880 | Va: 0.9366 0.667 | Te: 0.450\n",
      "Epoch 37 | Tr: 0.6751 0.880 | Va: 0.9406 0.667 | Te: 0.450\n",
      "Epoch 38 | Tr: 0.6688 0.873 | Va: 0.9507 0.722 | Te: 0.450\n",
      "Epoch 39 | Tr: 0.6630 0.880 | Va: 0.9604 0.722 | Te: 0.450\n",
      "Epoch 40 | Tr: 0.6573 0.880 | Va: 0.9753 0.667 | Te: 0.450\n",
      "Epoch 41 | Tr: 0.6522 0.893 | Va: 0.9842 0.667 | Te: 0.450\n",
      "Epoch 42 | Tr: 0.6468 0.893 | Va: 0.9953 0.667 | Te: 0.450\n",
      "Epoch 43 | Tr: 0.6413 0.893 | Va: 1.0092 0.667 | Te: 0.450\n",
      "Epoch 44 | Tr: 0.6362 0.893 | Va: 1.0236 0.611 | Te: 0.450\n",
      "Epoch 45 | Tr: 0.6313 0.893 | Va: 1.0318 0.611 | Te: 0.450\n",
      "Epoch 46 | Tr: 0.6259 0.893 | Va: 1.0503 0.611 | Te: 0.450\n",
      "Epoch 47 | Tr: 0.6208 0.900 | Va: 1.0619 0.611 | Te: 0.450\n",
      "Epoch 48 | Tr: 0.6160 0.900 | Va: 1.0710 0.611 | Te: 0.450\n",
      "Epoch 49 | Tr: 0.6106 0.900 | Va: 1.0977 0.667 | Te: 0.450\n",
      "Epoch 50 | Tr: 0.6060 0.907 | Va: 1.0955 0.611 | Te: 0.450\n",
      "Epoch 51 | Tr: 0.6005 0.907 | Va: 1.1213 0.667 | Te: 0.450\n",
      "Epoch 52 | Tr: 0.5955 0.907 | Va: 1.1388 0.667 | Te: 0.450\n",
      "Epoch 53 | Tr: 0.5908 0.907 | Va: 1.1463 0.667 | Te: 0.450\n",
      "Epoch 54 | Tr: 0.5856 0.907 | Va: 1.1727 0.667 | Te: 0.450\n",
      "Epoch 55 | Tr: 0.5808 0.913 | Va: 1.1792 0.667 | Te: 0.450\n",
      "Epoch 56 | Tr: 0.5756 0.920 | Va: 1.2058 0.667 | Te: 0.450\n",
      "Epoch 57 | Tr: 0.5709 0.920 | Va: 1.2152 0.667 | Te: 0.450\n",
      "Epoch 58 | Tr: 0.5659 0.927 | Va: 1.2355 0.667 | Te: 0.450\n",
      "Epoch 59 | Tr: 0.5612 0.940 | Va: 1.2461 0.667 | Te: 0.450\n",
      "Epoch 60 | Tr: 0.5564 0.953 | Va: 1.2618 0.667 | Te: 0.450\n",
      "Epoch 61 | Tr: 0.5518 0.953 | Va: 1.2829 0.667 | Te: 0.450\n",
      "Epoch 62 | Tr: 0.5476 0.953 | Va: 1.2866 0.667 | Te: 0.450\n",
      "Epoch 63 | Tr: 0.5432 0.953 | Va: 1.3186 0.667 | Te: 0.450\n",
      "Epoch 64 | Tr: 0.5391 0.960 | Va: 1.3171 0.667 | Te: 0.450\n",
      "Epoch 65 | Tr: 0.5348 0.960 | Va: 1.3352 0.667 | Te: 0.450\n",
      "Epoch 66 | Tr: 0.5307 0.967 | Va: 1.3587 0.667 | Te: 0.450\n",
      "Epoch 67 | Tr: 0.5268 0.973 | Va: 1.3597 0.667 | Te: 0.450\n",
      "Epoch 68 | Tr: 0.5227 0.973 | Va: 1.3700 0.667 | Te: 0.450\n",
      "Epoch 69 | Tr: 0.5187 0.973 | Va: 1.3906 0.667 | Te: 0.450\n",
      "Epoch 70 | Tr: 0.5153 0.973 | Va: 1.4023 0.667 | Te: 0.450\n",
      "Epoch 71 | Tr: 0.5117 0.973 | Va: 1.4092 0.667 | Te: 0.450\n",
      "Epoch 72 | Tr: 0.5080 0.973 | Va: 1.4125 0.667 | Te: 0.450\n",
      "Epoch 73 | Tr: 0.5042 0.973 | Va: 1.4283 0.667 | Te: 0.450\n",
      "Epoch 74 | Tr: 0.5009 0.973 | Va: 1.4445 0.667 | Te: 0.450\n",
      "Epoch 75 | Tr: 0.4975 0.973 | Va: 1.4417 0.667 | Te: 0.450\n",
      "Epoch 76 | Tr: 0.4943 0.973 | Va: 1.4621 0.667 | Te: 0.450\n",
      "Epoch 77 | Tr: 0.4908 0.973 | Va: 1.4524 0.667 | Te: 0.450\n",
      "Epoch 78 | Tr: 0.4874 0.973 | Va: 1.4620 0.667 | Te: 0.450\n",
      "Epoch 79 | Tr: 0.4844 0.973 | Va: 1.4796 0.667 | Te: 0.450\n",
      "Epoch 80 | Tr: 0.4814 0.973 | Va: 1.4672 0.667 | Te: 0.450\n",
      "Epoch 81 | Tr: 0.4782 0.980 | Va: 1.4890 0.667 | Te: 0.450\n",
      "Epoch 82 | Tr: 0.4753 0.987 | Va: 1.4909 0.667 | Te: 0.450\n",
      "Epoch 83 | Tr: 0.4722 0.987 | Va: 1.4799 0.667 | Te: 0.450\n",
      "Epoch 84 | Tr: 0.4692 0.987 | Va: 1.5020 0.667 | Te: 0.450\n",
      "Epoch 85 | Tr: 0.4665 0.987 | Va: 1.5079 0.667 | Te: 0.450\n",
      "Epoch 86 | Tr: 0.4638 0.987 | Va: 1.4969 0.667 | Te: 0.450\n",
      "Epoch 87 | Tr: 0.4609 0.987 | Va: 1.4981 0.667 | Te: 0.450\n",
      "Epoch 88 | Tr: 0.4583 0.987 | Va: 1.5156 0.667 | Te: 0.450\n",
      "Epoch 89 | Tr: 0.4559 0.993 | Va: 1.5031 0.667 | Te: 0.450\n",
      "Epoch 90 | Tr: 0.4534 0.993 | Va: 1.5054 0.667 | Te: 0.450\n",
      "Epoch 91 | Tr: 0.4511 0.993 | Va: 1.5193 0.667 | Te: 0.450\n",
      "Epoch 92 | Tr: 0.4488 0.993 | Va: 1.5046 0.667 | Te: 0.450\n",
      "Epoch 93 | Tr: 0.4467 0.993 | Va: 1.5207 0.667 | Te: 0.450\n",
      "Epoch 94 | Tr: 0.4446 0.993 | Va: 1.5240 0.667 | Te: 0.450\n",
      "Epoch 95 | Tr: 0.4426 0.993 | Va: 1.5173 0.667 | Te: 0.450\n",
      "Epoch 96 | Tr: 0.4405 0.993 | Va: 1.5243 0.667 | Te: 0.450\n",
      "Epoch 97 | Tr: 0.4387 0.993 | Va: 1.5415 0.667 | Te: 0.450\n",
      "Epoch 98 | Tr: 0.4369 0.993 | Va: 1.5379 0.667 | Te: 0.450\n",
      "Epoch 99 | Tr: 0.4351 0.993 | Va: 1.5403 0.667 | Te: 0.450\n",
      "Epoch 100 | Tr: 0.4333 0.993 | Va: 1.5496 0.667 | Te: 0.450\n",
      "Epoch 101 | Tr: 0.4317 0.993 | Va: 1.5647 0.667 | Te: 0.450\n",
      "Epoch 102 | Tr: 0.4301 0.993 | Va: 1.5664 0.667 | Te: 0.450\n",
      "Epoch 103 | Tr: 0.4286 0.993 | Va: 1.5701 0.667 | Te: 0.450\n",
      "Epoch 104 | Tr: 0.4270 0.993 | Va: 1.5795 0.667 | Te: 0.450\n",
      "Epoch 105 | Tr: 0.4256 0.993 | Va: 1.5899 0.667 | Te: 0.450\n",
      "Epoch 106 | Tr: 0.4243 0.993 | Va: 1.5961 0.667 | Te: 0.450\n",
      "Epoch 107 | Tr: 0.4230 1.000 | Va: 1.6048 0.667 | Te: 0.450\n",
      "Epoch 108 | Tr: 0.4217 1.000 | Va: 1.6140 0.667 | Te: 0.450\n",
      "Epoch 109 | Tr: 0.4205 1.000 | Va: 1.6197 0.667 | Te: 0.450\n",
      "Epoch 110 | Tr: 0.4194 1.000 | Va: 1.6312 0.667 | Te: 0.450\n",
      "Epoch 111 | Tr: 0.4183 1.000 | Va: 1.6375 0.667 | Te: 0.450\n",
      "Epoch 112 | Tr: 0.4172 1.000 | Va: 1.6457 0.667 | Te: 0.450\n",
      "Epoch 113 | Tr: 0.4162 1.000 | Va: 1.6533 0.667 | Te: 0.450\n",
      "Epoch 114 | Tr: 0.4153 1.000 | Va: 1.6598 0.667 | Te: 0.450\n",
      "Epoch 115 | Tr: 0.4144 1.000 | Va: 1.6703 0.667 | Te: 0.450\n",
      "Epoch 116 | Tr: 0.4135 1.000 | Va: 1.6763 0.667 | Te: 0.450\n",
      "Epoch 117 | Tr: 0.4127 1.000 | Va: 1.6828 0.667 | Te: 0.450\n",
      "Epoch 118 | Tr: 0.4119 1.000 | Va: 1.6890 0.667 | Te: 0.450\n",
      "Epoch 119 | Tr: 0.4111 1.000 | Va: 1.6976 0.667 | Te: 0.450\n",
      "Epoch 120 | Tr: 0.4104 1.000 | Va: 1.7062 0.667 | Te: 0.450\n",
      "Epoch 121 | Tr: 0.4098 1.000 | Va: 1.7089 0.667 | Te: 0.450\n",
      "Epoch 122 | Tr: 0.4091 1.000 | Va: 1.7142 0.667 | Te: 0.450\n",
      "Epoch 123 | Tr: 0.4085 1.000 | Va: 1.7226 0.667 | Te: 0.450\n",
      "Epoch 124 | Tr: 0.4079 1.000 | Va: 1.7281 0.667 | Te: 0.450\n",
      "Epoch 125 | Tr: 0.4074 1.000 | Va: 1.7325 0.667 | Te: 0.450\n",
      "Epoch 126 | Tr: 0.4068 1.000 | Va: 1.7395 0.667 | Te: 0.450\n",
      "Epoch 127 | Tr: 0.4063 1.000 | Va: 1.7472 0.667 | Te: 0.450\n",
      "Epoch 128 | Tr: 0.4058 1.000 | Va: 1.7524 0.667 | Te: 0.450\n",
      "Epoch 129 | Tr: 0.4053 1.000 | Va: 1.7593 0.667 | Te: 0.450\n",
      "Epoch 130 | Tr: 0.4048 1.000 | Va: 1.7676 0.667 | Te: 0.450\n",
      "Epoch 131 | Tr: 0.4044 1.000 | Va: 1.7727 0.667 | Te: 0.450\n",
      "Epoch 132 | Tr: 0.4040 1.000 | Va: 1.7774 0.667 | Te: 0.450\n",
      "Epoch 133 | Tr: 0.4036 1.000 | Va: 1.7835 0.667 | Te: 0.450\n",
      "Epoch 134 | Tr: 0.4032 1.000 | Va: 1.7887 0.667 | Te: 0.450\n",
      "Epoch 135 | Tr: 0.4029 1.000 | Va: 1.7937 0.667 | Te: 0.450\n",
      "Epoch 136 | Tr: 0.4025 1.000 | Va: 1.8005 0.667 | Te: 0.450\n",
      "Epoch 137 | Tr: 0.4022 1.000 | Va: 1.8068 0.667 | Te: 0.450\n",
      "Epoch 138 | Tr: 0.4019 1.000 | Va: 1.8119 0.667 | Te: 0.450\n",
      "Epoch 139 | Tr: 0.4016 1.000 | Va: 1.8169 0.667 | Te: 0.450\n",
      "Epoch 140 | Tr: 0.4013 1.000 | Va: 1.8215 0.667 | Te: 0.450\n",
      "Epoch 141 | Tr: 0.4011 1.000 | Va: 1.8273 0.667 | Te: 0.450\n",
      "Epoch 142 | Tr: 0.4008 1.000 | Va: 1.8334 0.667 | Te: 0.450\n",
      "Epoch 143 | Tr: 0.4006 1.000 | Va: 1.8378 0.667 | Te: 0.450\n",
      "Epoch 144 | Tr: 0.4003 1.000 | Va: 1.8434 0.667 | Te: 0.450\n",
      "Epoch 145 | Tr: 0.4001 1.000 | Va: 1.8497 0.667 | Te: 0.450\n",
      "Epoch 146 | Tr: 0.3999 1.000 | Va: 1.8538 0.667 | Te: 0.450\n",
      "Epoch 147 | Tr: 0.3997 1.000 | Va: 1.8581 0.667 | Te: 0.450\n",
      "Epoch 148 | Tr: 0.3995 1.000 | Va: 1.8635 0.667 | Te: 0.450\n",
      "Epoch 149 | Tr: 0.3993 1.000 | Va: 1.8695 0.667 | Te: 0.450\n",
      "Epoch 150 | Tr: 0.3991 1.000 | Va: 1.8748 0.667 | Te: 0.450\n",
      "Epoch 151 | Tr: 0.3989 1.000 | Va: 1.8785 0.667 | Te: 0.450\n",
      "Epoch 152 | Tr: 0.3987 1.000 | Va: 1.8847 0.667 | Te: 0.450\n",
      "Epoch 153 | Tr: 0.3986 1.000 | Va: 1.8911 0.667 | Te: 0.450\n",
      "Epoch 154 | Tr: 0.3984 1.000 | Va: 1.8956 0.667 | Te: 0.450\n",
      "Epoch 155 | Tr: 0.3983 1.000 | Va: 1.9011 0.667 | Te: 0.450\n",
      "Epoch 156 | Tr: 0.3981 1.000 | Va: 1.9057 0.667 | Te: 0.450\n",
      "Epoch 157 | Tr: 0.3980 1.000 | Va: 1.9100 0.667 | Te: 0.450\n",
      "Epoch 158 | Tr: 0.3978 1.000 | Va: 1.9156 0.667 | Te: 0.450\n",
      "Epoch 159 | Tr: 0.3977 1.000 | Va: 1.9201 0.667 | Te: 0.450\n",
      "Epoch 160 | Tr: 0.3976 1.000 | Va: 1.9238 0.667 | Te: 0.450\n",
      "Epoch 161 | Tr: 0.3975 1.000 | Va: 1.9279 0.667 | Te: 0.450\n",
      "Epoch 162 | Tr: 0.3973 1.000 | Va: 1.9318 0.667 | Te: 0.450\n",
      "Epoch 163 | Tr: 0.3972 1.000 | Va: 1.9358 0.667 | Te: 0.450\n",
      "Epoch 164 | Tr: 0.3971 1.000 | Va: 1.9404 0.667 | Te: 0.450\n",
      "Epoch 165 | Tr: 0.3970 1.000 | Va: 1.9444 0.667 | Te: 0.450\n",
      "Epoch 166 | Tr: 0.3969 1.000 | Va: 1.9473 0.667 | Te: 0.450\n",
      "Epoch 167 | Tr: 0.3968 1.000 | Va: 1.9510 0.667 | Te: 0.450\n",
      "Epoch 168 | Tr: 0.3967 1.000 | Va: 1.9560 0.667 | Te: 0.450\n",
      "Epoch 169 | Tr: 0.3966 1.000 | Va: 1.9605 0.667 | Te: 0.450\n",
      "Epoch 170 | Tr: 0.3965 1.000 | Va: 1.9633 0.667 | Te: 0.450\n",
      "Epoch 171 | Tr: 0.3964 1.000 | Va: 1.9659 0.667 | Te: 0.450\n",
      "Epoch 172 | Tr: 0.3963 1.000 | Va: 1.9706 0.667 | Te: 0.450\n",
      "Epoch 173 | Tr: 0.3963 1.000 | Va: 1.9753 0.667 | Te: 0.450\n",
      "Epoch 174 | Tr: 0.3962 1.000 | Va: 1.9776 0.667 | Te: 0.450\n",
      "Epoch 175 | Tr: 0.3961 1.000 | Va: 1.9799 0.667 | Te: 0.450\n",
      "Epoch 176 | Tr: 0.3960 1.000 | Va: 1.9846 0.667 | Te: 0.450\n",
      "Epoch 177 | Tr: 0.3960 1.000 | Va: 1.9892 0.667 | Te: 0.450\n",
      "Epoch 178 | Tr: 0.3959 1.000 | Va: 1.9912 0.667 | Te: 0.450\n",
      "Epoch 179 | Tr: 0.3958 1.000 | Va: 1.9933 0.667 | Te: 0.450\n",
      "Epoch 180 | Tr: 0.3957 1.000 | Va: 1.9974 0.667 | Te: 0.450\n",
      "Epoch 181 | Tr: 0.3957 1.000 | Va: 2.0014 0.667 | Te: 0.450\n",
      "Epoch 182 | Tr: 0.3956 1.000 | Va: 2.0047 0.667 | Te: 0.450\n",
      "Epoch 183 | Tr: 0.3956 1.000 | Va: 2.0071 0.667 | Te: 0.450\n",
      "Epoch 184 | Tr: 0.3955 1.000 | Va: 2.0093 0.667 | Te: 0.450\n",
      "Epoch 185 | Tr: 0.3954 1.000 | Va: 2.0121 0.667 | Te: 0.450\n",
      "Epoch 186 | Tr: 0.3954 1.000 | Va: 2.0154 0.667 | Te: 0.450\n",
      "Epoch 187 | Tr: 0.3953 1.000 | Va: 2.0170 0.667 | Te: 0.450\n",
      "Epoch 188 | Tr: 0.3953 1.000 | Va: 2.0176 0.667 | Te: 0.450\n",
      "Epoch 189 | Tr: 0.3952 1.000 | Va: 2.0196 0.667 | Te: 0.450\n",
      "Epoch 190 | Tr: 0.3952 1.000 | Va: 2.0229 0.667 | Te: 0.450\n",
      "Epoch 191 | Tr: 0.3951 1.000 | Va: 2.0258 0.667 | Te: 0.450\n",
      "Epoch 192 | Tr: 0.3951 1.000 | Va: 2.0280 0.667 | Te: 0.450\n",
      "Epoch 193 | Tr: 0.3950 1.000 | Va: 2.0300 0.667 | Te: 0.450\n",
      "Epoch 194 | Tr: 0.3950 1.000 | Va: 2.0325 0.667 | Te: 0.450\n",
      "Epoch 195 | Tr: 0.3949 1.000 | Va: 2.0349 0.667 | Te: 0.450\n",
      "Epoch 196 | Tr: 0.3949 1.000 | Va: 2.0372 0.667 | Te: 0.450\n",
      "Epoch 197 | Tr: 0.3948 1.000 | Va: 2.0390 0.667 | Te: 0.450\n",
      "Epoch 198 | Tr: 0.3948 1.000 | Va: 2.0415 0.667 | Te: 0.450\n",
      "Epoch 199 | Tr: 0.3948 1.000 | Va: 2.0448 0.667 | Te: 0.450\n",
      "Epoch 200 | Tr: 0.3947 1.000 | Va: 2.0471 0.667 | Te: 0.450\n",
      "Epoch 201 | Tr: 0.3947 1.000 | Va: 2.0498 0.667 | Te: 0.450\n",
      "Epoch 202 | Tr: 0.3946 1.000 | Va: 2.0526 0.667 | Te: 0.450\n",
      "Epoch 203 | Tr: 0.3946 1.000 | Va: 2.0549 0.667 | Te: 0.450\n",
      "Epoch 204 | Tr: 0.3946 1.000 | Va: 2.0575 0.667 | Te: 0.450\n",
      "Epoch 205 | Tr: 0.3945 1.000 | Va: 2.0598 0.667 | Te: 0.450\n",
      "Epoch 206 | Tr: 0.3945 1.000 | Va: 2.0626 0.667 | Te: 0.450\n",
      "Epoch 207 | Tr: 0.3945 1.000 | Va: 2.0650 0.611 | Te: 0.450\n",
      "Epoch 208 | Tr: 0.3944 1.000 | Va: 2.0671 0.611 | Te: 0.450\n",
      "Epoch 209 | Tr: 0.3944 1.000 | Va: 2.0696 0.611 | Te: 0.450\n",
      "Epoch 210 | Tr: 0.3944 1.000 | Va: 2.0713 0.611 | Te: 0.450\n",
      "Epoch 211 | Tr: 0.3944 1.000 | Va: 2.0734 0.611 | Te: 0.450\n",
      "Epoch 212 | Tr: 0.3943 1.000 | Va: 2.0763 0.611 | Te: 0.450\n",
      "Epoch 213 | Tr: 0.3943 1.000 | Va: 2.0783 0.611 | Te: 0.450\n",
      "Epoch 214 | Tr: 0.3943 1.000 | Va: 2.0797 0.611 | Te: 0.450\n",
      "Epoch 215 | Tr: 0.3942 1.000 | Va: 2.0820 0.611 | Te: 0.450\n",
      "Epoch 216 | Tr: 0.3942 1.000 | Va: 2.0838 0.611 | Te: 0.450\n",
      "Epoch 217 | Tr: 0.3942 1.000 | Va: 2.0857 0.611 | Te: 0.450\n",
      "Epoch 218 | Tr: 0.3942 1.000 | Va: 2.0880 0.611 | Te: 0.450\n",
      "Epoch 219 | Tr: 0.3941 1.000 | Va: 2.0893 0.611 | Te: 0.450\n",
      "Epoch 220 | Tr: 0.3941 1.000 | Va: 2.0905 0.611 | Te: 0.450\n",
      "Epoch 221 | Tr: 0.3941 1.000 | Va: 2.0928 0.611 | Te: 0.450\n",
      "Epoch 222 | Tr: 0.3941 1.000 | Va: 2.0955 0.611 | Te: 0.450\n",
      "Epoch 223 | Tr: 0.3940 1.000 | Va: 2.0969 0.611 | Te: 0.450\n",
      "Epoch 224 | Tr: 0.3940 1.000 | Va: 2.0974 0.611 | Te: 0.450\n",
      "Epoch 225 | Tr: 0.3940 1.000 | Va: 2.0984 0.611 | Te: 0.400\n",
      "Epoch 226 | Tr: 0.3940 1.000 | Va: 2.1006 0.611 | Te: 0.400\n",
      "Epoch 227 | Tr: 0.3940 1.000 | Va: 2.1019 0.611 | Te: 0.400\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 776\u001b[0m\n\u001b[1;32m    774\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mStart Training...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    775\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, epochs \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m):\n\u001b[0;32m--> 776\u001b[0m     tr_loss, tr_acc \u001b[38;5;241m=\u001b[39m \u001b[43mrun_one_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m    777\u001b[0m     va_loss, va_acc \u001b[38;5;241m=\u001b[39m run_one_epoch(val_loader,   train\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m    778\u001b[0m     te_loss, te_acc \u001b[38;5;241m=\u001b[39m run_one_epoch(test_loader,  train\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
      "Cell \u001b[0;32mIn[6], line 729\u001b[0m, in \u001b[0;36mrun_one_epoch\u001b[0;34m(loader, train)\u001b[0m\n\u001b[1;32m    719\u001b[0m X_list, edge_index_list, y_list \u001b[38;5;241m=\u001b[39m split_batch_to_graphs(batch)\n\u001b[1;32m    721\u001b[0m \u001b[38;5;66;03m# --- Inductive Step: Build Trees with GNN Encoders ---\u001b[39;00m\n\u001b[1;32m    722\u001b[0m \u001b[38;5;66;03m# Encoders are on GPU; Make_tree uses them, but logic handles CPU/GPU\u001b[39;00m\n\u001b[1;32m    723\u001b[0m \u001b[38;5;66;03m# Note: Make_tree_HMH iterates encoders. \u001b[39;00m\n\u001b[1;32m    724\u001b[0m \u001b[38;5;66;03m# Sinkhorn works best on same device as encoder output (GPU).\u001b[39;00m\n\u001b[1;32m    725\u001b[0m \u001b[38;5;66;03m# Tree construction logic (adjacency) mostly typically CPU via SciPy, \u001b[39;00m\n\u001b[1;32m    726\u001b[0m \u001b[38;5;66;03m# but GNN forward pass is GPU.\u001b[39;00m\n\u001b[1;32m    728\u001b[0m (U_batch, eidx_batch, n_nodes_batch, n_edges_batch,\n\u001b[0;32m--> 729\u001b[0m  feats_batch, tree_batch, S_batch) \u001b[38;5;241m=\u001b[39m \u001b[43mUext_batch_from_tree_lists_HMH\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    730\u001b[0m \u001b[43m    \u001b[49m\u001b[43mX_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_index_list\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    731\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlevels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlevels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    732\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgnn_encoders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgnn_encoders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    733\u001b[0m \u001b[43m    \u001b[49m\u001b[43mratio\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.5\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    734\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat64\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    735\u001b[0m \u001b[43m    \u001b[49m\u001b[43massign_method\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msinkhorn\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtau\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.9\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msinkhorn_iters\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    736\u001b[0m \u001b[43m    \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m42\u001b[39;49m\n\u001b[1;32m    737\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    739\u001b[0m \u001b[38;5;66;03m# --- Forward Pass ---\u001b[39;00m\n\u001b[1;32m    740\u001b[0m loss_batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.0\u001b[39m\n",
      "Cell \u001b[0;32mIn[6], line 482\u001b[0m, in \u001b[0;36mUext_batch_from_tree_lists_HMH\u001b[0;34m(X_list, edge_index_list, levels, gnn_encoders, ratio, device, dtype, assign_method, tau, sinkhorn_iters, seed)\u001b[0m\n\u001b[1;32m    479\u001b[0m A_i \u001b[38;5;241m=\u001b[39m to_scipy_sparse_matrix(ei_i, num_nodes\u001b[38;5;241m=\u001b[39mX_i\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m    481\u001b[0m \u001b[38;5;66;03m# Call GNN-based Tree Builder\u001b[39;00m\n\u001b[0;32m--> 482\u001b[0m treeG_i, S_assign_list \u001b[38;5;241m=\u001b[39m \u001b[43mMake_tree_HMH\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    483\u001b[0m \u001b[43m    \u001b[49m\u001b[43mX\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mX_i\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mA\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mA_i\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    484\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlevels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlevels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m    485\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgnn_encoders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgnn_encoders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m    486\u001b[0m \u001b[43m    \u001b[49m\u001b[43mratio\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mratio\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    487\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    488\u001b[0m \u001b[43m    \u001b[49m\u001b[43massign_method\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43massign_method\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtau\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtau\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msinkhorn_iters\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msinkhorn_iters\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    489\u001b[0m \u001b[43m    \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mseed\u001b[49m\n\u001b[1;32m    490\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    492\u001b[0m \u001b[38;5;66;03m# Compute Haar Basis\u001b[39;00m\n\u001b[1;32m    493\u001b[0m treeG_i \u001b[38;5;241m=\u001b[39m HaarGOB_with_Sassign_degree_norm(treeG_i, S_assign_list)\n",
      "Cell \u001b[0;32mIn[6], line 236\u001b[0m, in \u001b[0;36mMake_tree_HMH\u001b[0;34m(X, A, levels, gnn_encoders, ratio, device, dtype, assign_method, tau, sinkhorn_iters, seed)\u001b[0m\n\u001b[1;32m    224\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mMake_tree_HMH\u001b[39m(\n\u001b[1;32m    225\u001b[0m     X, A,\n\u001b[1;32m    226\u001b[0m     levels: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    234\u001b[0m     seed: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m    235\u001b[0m ):\n\u001b[0;32m--> 236\u001b[0m     \u001b[43mset_seed\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    238\u001b[0m     \u001b[38;5;66;03m# Ensure numpy features\u001b[39;00m\n\u001b[1;32m    239\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(X, torch\u001b[38;5;241m.\u001b[39mTensor):\n",
      "Cell \u001b[0;32mIn[6], line 17\u001b[0m, in \u001b[0;36mset_seed\u001b[0;34m(seed)\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mset_seed\u001b[39m(seed\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m):\n\u001b[1;32m     16\u001b[0m     np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mseed(seed)\n\u001b[0;32m---> 17\u001b[0m     \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmanual_seed\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     18\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available():\n\u001b[1;32m     19\u001b[0m         torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mmanual_seed_all(seed)\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/_compile.py:32\u001b[0m, in \u001b[0;36m_disable_dynamo.<locals>.inner\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     29\u001b[0m     disable_fn \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mdisable(fn, recursive)\n\u001b[1;32m     30\u001b[0m     fn\u001b[38;5;241m.\u001b[39m__dynamo_disable \u001b[38;5;241m=\u001b[39m disable_fn\n\u001b[0;32m---> 32\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mdisable_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/_dynamo/eval_frame.py:745\u001b[0m, in \u001b[0;36mDisableContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    741\u001b[0m prior_skip_guard_eval_unsafe \u001b[38;5;241m=\u001b[39m set_skip_guard_eval_unsafe(\n\u001b[1;32m    742\u001b[0m     _is_skip_guard_eval_unsafe_stance()\n\u001b[1;32m    743\u001b[0m )\n\u001b[1;32m    744\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 745\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    746\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    747\u001b[0m     _maybe_set_eval_frame(prior)\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/random.py:46\u001b[0m, in \u001b[0;36mmanual_seed\u001b[0;34m(seed)\u001b[0m\n\u001b[1;32m     43\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcuda\u001b[39;00m\n\u001b[1;32m     45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39m_is_in_bad_fork():\n\u001b[0;32m---> 46\u001b[0m     \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmanual_seed_all\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     48\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmps\u001b[39;00m\n\u001b[1;32m     50\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mmps\u001b[38;5;241m.\u001b[39m_is_in_bad_fork():\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/cuda/random.py:127\u001b[0m, in \u001b[0;36mmanual_seed_all\u001b[0;34m(seed)\u001b[0m\n\u001b[1;32m    124\u001b[0m         default_generator \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mdefault_generators[i]\n\u001b[1;32m    125\u001b[0m         default_generator\u001b[38;5;241m.\u001b[39mmanual_seed(seed)\n\u001b[0;32m--> 127\u001b[0m \u001b[43m_lazy_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseed_all\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/cuda/__init__.py:256\u001b[0m, in \u001b[0;36m_lazy_call\u001b[0;34m(callable, **kwargs)\u001b[0m\n\u001b[1;32m    254\u001b[0m \u001b[38;5;28;01mglobal\u001b[39;00m _lazy_seed_tracker\n\u001b[1;32m    255\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseed_all\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[0;32m--> 256\u001b[0m     _lazy_seed_tracker\u001b[38;5;241m.\u001b[39mqueue_seed_all(\u001b[38;5;28mcallable\u001b[39m, \u001b[43mtraceback\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mformat_stack\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m    257\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseed\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m    258\u001b[0m     _lazy_seed_tracker\u001b[38;5;241m.\u001b[39mqueue_seed(\u001b[38;5;28mcallable\u001b[39m, traceback\u001b[38;5;241m.\u001b[39mformat_stack())\n",
      "File \u001b[0;32m/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/traceback.py:197\u001b[0m, in \u001b[0;36mformat_stack\u001b[0;34m(f, limit)\u001b[0m\n\u001b[1;32m    195\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m f \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    196\u001b[0m     f \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39m_getframe()\u001b[38;5;241m.\u001b[39mf_back\n\u001b[0;32m--> 197\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m format_list(\u001b[43mextract_stack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlimit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlimit\u001b[49m\u001b[43m)\u001b[49m)\n",
      "File \u001b[0;32m/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/traceback.py:211\u001b[0m, in \u001b[0;36mextract_stack\u001b[0;34m(f, limit)\u001b[0m\n\u001b[1;32m    209\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m f \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    210\u001b[0m     f \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39m_getframe()\u001b[38;5;241m.\u001b[39mf_back\n\u001b[0;32m--> 211\u001b[0m stack \u001b[38;5;241m=\u001b[39m \u001b[43mStackSummary\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mextract\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwalk_stack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlimit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlimit\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    212\u001b[0m stack\u001b[38;5;241m.\u001b[39mreverse()\n\u001b[1;32m    213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m stack\n",
      "File \u001b[0;32m/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/traceback.py:359\u001b[0m, in \u001b[0;36mStackSummary.extract\u001b[0;34m(klass, frame_gen, limit, lookup_lines, capture_locals)\u001b[0m\n\u001b[1;32m    357\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    358\u001b[0m         f_locals \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 359\u001b[0m     \u001b[43mresult\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mappend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mFrameSummary\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    360\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlineno\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlookup_line\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlocals\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mf_locals\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    361\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m filename \u001b[38;5;129;01min\u001b[39;00m fnames:\n\u001b[1;32m    362\u001b[0m     linecache\u001b[38;5;241m.\u001b[39mcheckcache(filename)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.datasets import TUDataset\n",
    "\n",
    "# ==============================================================================\n",
    "# 1. Utilities & Helpers\n",
    "# ==============================================================================\n",
    "\n",
    "def set_seed(seed=0):\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "def decide_K(N_cur, ratio, last_level):\n",
    "    if N_cur <= 2 or last_level:\n",
    "        return 1\n",
    "    K = int(N_cur * ratio) + 1\n",
    "    return max(1, min(K, N_cur))\n",
    "\n",
    "def symmetrize_simple(W: sp.csr_matrix):\n",
    "    W = W.maximum(W.T)\n",
    "    W.setdiag(0)\n",
    "    W.eliminate_zeros()\n",
    "    return W\n",
    "\n",
    "def coarsen_adj_hard(A: sp.csr_matrix, hard_labels: np.ndarray, K: int):\n",
    "    rr, cc, vv = sp.find(A)\n",
    "    nrr, ncc = hard_labels[rr], hard_labels[cc]\n",
    "    A_coarse = sp.csr_matrix((vv, (nrr, ncc)), shape=(K, K))\n",
    "    A_coarse = symmetrize_simple(A_coarse)\n",
    "    return A_coarse\n",
    "\n",
    "def _unique_seeds_for_all_clusters(S: np.ndarray) -> np.ndarray:\n",
    "    N, K = S.shape\n",
    "    order = np.argsort(-S, axis=0)\n",
    "    seeds = -np.ones(K, dtype=int)\n",
    "    used = np.zeros(N, dtype=bool)\n",
    "    ptr = np.zeros(K, dtype=int)\n",
    "    remaining = list(range(K))\n",
    "    guard = 0\n",
    "    while remaining and guard < K * N:\n",
    "        k = remaining.pop(0)\n",
    "        while ptr[k] < N and used[order[ptr[k], k]]:\n",
    "            ptr[k] += 1\n",
    "        if ptr[k] < N:\n",
    "            i = order[ptr[k], k]\n",
    "            seeds[k] = i\n",
    "            used[i] = True\n",
    "        else:\n",
    "            i = int(np.argmin(used))\n",
    "            seeds[k] = i\n",
    "            used[i] = True\n",
    "        guard += 1\n",
    "    return seeds\n",
    "\n",
    "def hard_labels_cover_all(S: np.ndarray) -> np.ndarray:\n",
    "    N, K = S.shape\n",
    "    y = S.argmax(axis=1)\n",
    "    counts = np.bincount(y, minlength=K)\n",
    "    if (counts == 0).any():\n",
    "        seeds = _unique_seeds_for_all_clusters(S)\n",
    "        y[seeds] = np.arange(K, dtype=int)\n",
    "    return y\n",
    "\n",
    "def to_scipy_sparse_matrix(edge_index, num_nodes):\n",
    "    if isinstance(edge_index, torch.Tensor):\n",
    "        ei = edge_index.detach().cpu().numpy()\n",
    "    else:\n",
    "        ei = np.asarray(edge_index)\n",
    "    r, c = ei\n",
    "    data = np.ones(r.size, dtype=np.float64)\n",
    "    A = sp.coo_matrix((data, (r, c)), shape=(num_nodes, num_nodes)).tocsr()\n",
    "    A = A.maximum(A.T)\n",
    "    A.setdiag(0)\n",
    "    A.eliminate_zeros()\n",
    "    return A\n",
    "\n",
    "def adj2edge(A: sp.coo_matrix):\n",
    "    A = A.tocoo()\n",
    "    row = torch.as_tensor(A.row, dtype=torch.long)\n",
    "    col = torch.as_tensor(A.col, dtype=torch.long)\n",
    "    edge_index = torch.stack([row, col], dim=0)\n",
    "    edge_weight = torch.as_tensor(A.data, dtype=torch.float32)\n",
    "    return edge_index, edge_weight\n",
    "\n",
    "def _deg_vec(A):\n",
    "    d = np.asarray(A.sum(axis=1)).ravel().astype(float)\n",
    "    d[d <= 0.0] = 1e-12\n",
    "    return d\n",
    "\n",
    "def _dnorm(u, d):\n",
    "    return np.sqrt(float((d * (u * u)).sum()) + 1e-24)\n",
    "\n",
    "# ==============================================================================\n",
    "# 2. Assignment & Sinkhorn Layers\n",
    "# ==============================================================================\n",
    "\n",
    "class LinearHeadK(nn.Module):\n",
    "    def __init__(self, in_dim, K, hidden=32, dtype=torch.float64):\n",
    "        super().__init__()\n",
    "        # Input dim matches the GNN encoder output\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(in_dim, hidden, dtype=dtype), \n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden, K, dtype=dtype)\n",
    "        )\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Linear):\n",
    "                nn.init.xavier_uniform_(m.weight, gain=0.1)\n",
    "                nn.init.zeros_(m.bias)\n",
    "\n",
    "    def forward(self, U):\n",
    "        return self.net(U)\n",
    "\n",
    "def sinkhorn_balanced(logits: torch.Tensor, n_iters=7, tau=1.0):\n",
    "    n, K = logits.shape\n",
    "    P = torch.exp(logits / tau) + 1e-9\n",
    "    col_tgt = (n / K) * torch.ones(K, device=logits.device, dtype=logits.dtype)\n",
    "    for _ in range(n_iters):\n",
    "        P = P / (P.sum(dim=1, keepdim=True) + 1e-12)\n",
    "        col_sum = P.sum(dim=0) + 1e-12\n",
    "        P = P * (col_tgt / col_sum)\n",
    "    P = P / (P.sum(dim=1, keepdim=True) + 1e-12)\n",
    "    return P\n",
    "\n",
    "@torch.no_grad()\n",
    "def assignments_with_margin(U: torch.Tensor, hidden=32, zeta=1e-2,\n",
    "                            method=\"sinkhorn\", tau=1.0, sinkhorn_iters=12):\n",
    "    # Determine K from the input U shape or separate logic if U isn't [N, K]\n",
    "    # In this pipeline, the GNN output U has shape [N, embed_dim]\n",
    "    # We project it to [N, K] inside this function via LinearHeadK.\n",
    "    \n",
    "    # NOTE: Since this function was originally designed to take spectral U [N, K],\n",
    "    # but now takes GNN U [N, embed], we need to pass K explicitly or fix the head.\n",
    "    # To keep the signature clean, we assume the LinearHeadK matches dimensions.\n",
    "    # HOWEVER: Ideally, the projection to K logits happens *before* this helper \n",
    "    # if K changes dynamically per graph.\n",
    "    \n",
    "    # FIX: We will do the projection locally. Since K changes per graph, \n",
    "    # we can't use a fixed `LinearHeadK` unless we rebuild it or use a stateless projection.\n",
    "    # For a learnable pipeline, the projection usually happens *before* calling the assignment.\n",
    "    # But since you requested specific order: \n",
    "    # We will assume U is ALREADY the projected logits or embeddings to be projected.\n",
    "    \n",
    "    # SIMPLIFICATION for GNN version: \n",
    "    # We will assume U is the node embedding. We use a simple heuristic or a \n",
    "    # small randomized projection if we want \"spectral-like\" behavior, \n",
    "    # OR we assume the GNN *already* outputted dimension K.\n",
    "    \n",
    "    # Given the previous code, let's use a randomized projection for the \"Head\" \n",
    "    # to convert embeddings to K logits, as K varies per graph.\n",
    "    \n",
    "    N, dim = U.shape\n",
    "    # We need to project dim -> K. Since K varies, we can't have a fixed weight matrix.\n",
    "    # We will use the raw embeddings U directly if dim==K, otherwise we slice/pad.\n",
    "    # Alternatively, for the ICLR paper, often the GNN output dimension is fixed \n",
    "    # and we project to a fixed max_K, then slice.\n",
    "    \n",
    "    # Strategy: Use U as logits directly (assuming GNN output dim >= K).\n",
    "    # If GNN output < K, we repeat.\n",
    "    \n",
    "    # For this implementation, we will assume the GNN output dimension is `embed_dim`\n",
    "    # and we just take the first K channels as logits (similar to spectral clustering).\n",
    "    \n",
    "    if dim < 1: \n",
    "        logits = torch.ones(N, 1, device=U.device, dtype=U.dtype)\n",
    "    else:\n",
    "        # If U has more dims than K, slice. If fewer, repeat.\n",
    "        # But 'K' is not passed here! \n",
    "        # FIX: We will modify the caller to do the projection. \n",
    "        # Here we assume U contains the logits for K clusters.\n",
    "        logits = U \n",
    "\n",
    "    logits = logits - logits.mean(dim=1, keepdim=True)\n",
    "    \n",
    "    # Margin calculation\n",
    "    if logits.shape[1] > 1:\n",
    "        top2 = torch.topk(logits, k=2, dim=1).values\n",
    "        num = (top2[:, 0] - top2[:, 1]).clamp_min(0.0)\n",
    "        den = top2.abs().sum(dim=1) + zeta\n",
    "        mu = (num / den).unsqueeze(1)\n",
    "    else:\n",
    "        mu = torch.ones(N, 1, device=U.device, dtype=U.dtype)\n",
    "\n",
    "    logits_scaled = mu * logits\n",
    "    if method.lower() == \"sinkhorn\":\n",
    "        S = sinkhorn_balanced(logits_scaled, n_iters=sinkhorn_iters, tau=tau)\n",
    "    else:\n",
    "        S = torch.softmax(logits_scaled, dim=1)\n",
    "    return S, logits, mu.squeeze(1)\n",
    "\n",
    "\n",
    "# ==============================================================================\n",
    "# 3. GNN Encoder (Replacement for Diffusion)\n",
    "# ==============================================================================\n",
    "\n",
    "class GNNPoolEncoder(nn.Module):\n",
    "    def __init__(self, in_dim, hidden_dim, out_dim):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(in_dim, hidden_dim)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.conv2 = GCNConv(hidden_dim, out_dim)\n",
    "\n",
    "    def forward(self, x, edge_index, edge_weight=None):\n",
    "        # x: [N, F], edge_index: [2, E]\n",
    "        x = self.conv1(x, edge_index, edge_weight)\n",
    "        x = self.relu(x)\n",
    "        x = self.conv2(x, edge_index, edge_weight)\n",
    "        # Tanh to keep embeddings bounded like spectral coordinates [-1, 1]\n",
    "        return torch.tanh(x)\n",
    "\n",
    "\n",
    "# ==============================================================================\n",
    "# 4. Hierarchical Tree Builder (GNN-based)\n",
    "# ==============================================================================\n",
    "\n",
    "def Make_tree_HMH(\n",
    "    X, A,\n",
    "    levels: int,\n",
    "    gnn_encoders: nn.ModuleList, # List of GNNPoolEncoder\n",
    "    ratio: float = 0.2,\n",
    "    device: str = \"cpu\",\n",
    "    dtype = torch.float64,\n",
    "    assign_method: str = \"sinkhorn\",\n",
    "    tau: float = 1.0,\n",
    "    sinkhorn_iters: int = 7,\n",
    "    seed: int = 0\n",
    "):\n",
    "    set_seed(seed)\n",
    "\n",
    "    # Ensure numpy features\n",
    "    if isinstance(X, torch.Tensor):\n",
    "        X = X.detach().cpu().numpy()\n",
    "    A = A.tocsr()\n",
    "    N_start = A.shape[0]\n",
    "\n",
    "    adj_list = [A]\n",
    "    features_list = [X]\n",
    "    parents = []\n",
    "    S_assign_list = []\n",
    "\n",
    "    for level in range(levels - 1):\n",
    "        N_cur = A.shape[0]\n",
    "        last_level = (level == levels - 2)\n",
    "        K = decide_K(N_cur, ratio, last_level)\n",
    "        \n",
    "        if K == 1:\n",
    "            # Trivial case: collapse all to 1\n",
    "            S_triv = np.ones((N_cur, 1), dtype=np.float64)\n",
    "            S_assign_list.append(S_triv)\n",
    "            parents.append(np.zeros(N_cur, dtype=int))\n",
    "            # Coarsen\n",
    "            X = S_triv.T @ X\n",
    "            A = coarsen_adj_hard(A, np.zeros(N_cur, dtype=int), K=1)\n",
    "            adj_list.append(A)\n",
    "            features_list.append(X)\n",
    "            break\n",
    "\n",
    "        # ======================================================================\n",
    "        # REPLACED BLOCK: GNN Encoder instead of Diffusion/Spectral\n",
    "        # ======================================================================\n",
    "        \n",
    "        # 1) Prepare Inputs (Scipy/Numpy -> Torch)\n",
    "        X_t = torch.from_numpy(X).to(device=device, dtype=torch.float32)\n",
    "        \n",
    "        A_coo = A.tocoo()\n",
    "        row = torch.from_numpy(A_coo.row).long().to(device)\n",
    "        col = torch.from_numpy(A_coo.col).long().to(device)\n",
    "        edge_index = torch.stack([row, col], dim=0)\n",
    "        edge_weight = torch.from_numpy(A_coo.data).to(device=device, dtype=torch.float32)\n",
    "\n",
    "        # 2) Run GNN Encoder for this level\n",
    "        # Output U_t: [N_cur, embed_dim]\n",
    "        current_encoder = gnn_encoders[level]\n",
    "        U_embed = current_encoder(X_t, edge_index, edge_weight).to(dtype=dtype)\n",
    "        \n",
    "        # 3) Project Embeddings to K Logits\n",
    "        # Since K changes dynamically, we take the first K dimensions of the embedding.\n",
    "        # If embed_dim < K, we pad. This acts as a \"dynamic linear head\".\n",
    "        embed_dim = U_embed.shape[1]\n",
    "        if embed_dim >= K:\n",
    "            U_logits = U_embed[:, :K]\n",
    "        else:\n",
    "            # Pad with zeros or repeat\n",
    "            padding = torch.zeros(N_cur, K - embed_dim, device=device, dtype=dtype)\n",
    "            U_logits = torch.cat([U_embed, padding], dim=1)\n",
    "\n",
    "        # ======================================================================\n",
    "\n",
    "        # 8) Assignments (margin-scaled Sinkhorn)\n",
    "        S_t, logits_t, mu_t = assignments_with_margin(\n",
    "            U_logits, hidden=32, zeta=1e-3, method=assign_method, \n",
    "            tau=tau, sinkhorn_iters=sinkhorn_iters\n",
    "        )\n",
    "        S_np = S_t.detach().cpu().numpy()\n",
    "        \n",
    "        # 9) Hard labels\n",
    "        hard_labels = hard_labels_cover_all(S_np) \n",
    "        \n",
    "        # 10) Coarsen adjacency\n",
    "        A_next = coarsen_adj_hard(A, hard_labels, K)\n",
    "        adj_list.append(A_next)\n",
    "\n",
    "        # 11) Coarsen features\n",
    "        X_next = S_np.T @ X\n",
    "        features_list.append(X_next)\n",
    "\n",
    "        # 12) Record\n",
    "        parents.append(hard_labels)\n",
    "        S_assign_list.append(S_np)\n",
    "\n",
    "        A, X = A_next, X_next\n",
    "\n",
    "    # 13) Build Tree Structure\n",
    "    L_eff = len(adj_list)\n",
    "    treeG = [None] * L_eff\n",
    "    for lvl in range(L_eff):\n",
    "        if lvl == 0:\n",
    "            idxs = np.arange(N_start)\n",
    "            clusters = [np.array([i], dtype=int) for i in idxs]\n",
    "            IDX_vec = np.arange(N_start)\n",
    "        else:\n",
    "            pid = parents[lvl - 1]\n",
    "            K_lvl = S_assign_list[lvl - 1].shape[1]\n",
    "            clusters = [np.flatnonzero(pid == k) for k in range(K_lvl)]\n",
    "            IDX_vec = pid\n",
    "\n",
    "        treeG[lvl] = {\n",
    "            'IDX': IDX_vec,\n",
    "            'clusters': clusters,\n",
    "            'adj': adj_list[lvl],\n",
    "            'features': features_list[lvl]\n",
    "        }\n",
    "\n",
    "    return treeG, S_assign_list\n",
    "\n",
    "# ==============================================================================\n",
    "# 5. Haar Basis Construction\n",
    "# ==============================================================================\n",
    "\n",
    "def HaarGOB_with_Sassign_degree_norm(treeG, S_assign_list):\n",
    "    Ntr = len(treeG)\n",
    "\n",
    "    # Step 1: Top level basis\n",
    "    clusterJ0 = treeG[Ntr-1]['clusters']\n",
    "    N0 = len(clusterJ0)\n",
    "\n",
    "    chic = np.identity(N0)\n",
    "    uc = [None] * N0\n",
    "    uc[0] = (1.0 / np.sqrt(N0)) * np.ones(N0, dtype=float)\n",
    "    for l in range(1, N0):\n",
    "        uc[l] = np.sqrt((N0 - l) / (N0 - l + 1.0)) * (\n",
    "            chic[l-1, :] - (1.0 / (N0 - l)) * np.sum(chic[l:, :], axis=0)\n",
    "        )\n",
    "\n",
    "    A_top = treeG[Ntr-1]['adj'].tocsr()\n",
    "    d_top = _deg_vec(A_top)\n",
    "    for l in range(N0):\n",
    "        nrm = _dnorm(uc[l], d_top)\n",
    "        uc[l] = uc[l] / nrm\n",
    "\n",
    "    treeG[Ntr-1]['u'] = uc\n",
    "\n",
    "    # Step 2: Propagate Down\n",
    "    for j_tr in np.arange(Ntr-2, -1, -1):\n",
    "        N1 = len(treeG[j_tr]['clusters'])\n",
    "        S_assign = np.asarray(S_assign_list[j_tr], dtype=float)\n",
    "        \n",
    "        A_lvl = treeG[j_tr]['adj'].tocsr()\n",
    "        d_lvl = _deg_vec(A_lvl)\n",
    "\n",
    "        u = [None] * N1\n",
    "        i = N0 \n",
    "\n",
    "        # Inter-cluster vectors\n",
    "        for l in range(N0):\n",
    "            cluster_l = np.asarray(treeG[j_tr+1]['clusters'][l], dtype=int)\n",
    "            ul1 = np.zeros(N1, dtype=float)\n",
    "            \n",
    "            # Vectorized accumulation\n",
    "            # For each parent node j, add contribution to its children\n",
    "            for j in range(N0):\n",
    "                idxj = np.asarray(treeG[j_tr+1]['clusters'][j], dtype=int)\n",
    "                if idxj.size == 0: continue\n",
    "                w = S_assign[idxj, l]\n",
    "                ul1[idxj] += uc[l][j] * w\n",
    "\n",
    "            nrm = _dnorm(ul1, d_lvl)\n",
    "            if nrm > 0: ul1 = ul1 / nrm\n",
    "            u[l] = ul1\n",
    "\n",
    "            # Intra-cluster details\n",
    "            kl = int(cluster_l.size)\n",
    "            if kl > 1:\n",
    "                chil = np.zeros((kl, N1), dtype=float)\n",
    "                for k in range(kl):\n",
    "                    chil[k, cluster_l[k]] = 1.0\n",
    "\n",
    "                for k in range(1, kl):\n",
    "                    i += 1\n",
    "                    ulk = np.sqrt((kl - k) / (kl - k + 1.0)) * (\n",
    "                        chil[k-1, :] - (1.0 / (kl - k)) * np.sum(chil[k:, :], axis=0)\n",
    "                    )\n",
    "                    nrmk = _dnorm(ulk, d_lvl)\n",
    "                    if nrmk > 0: ulk = ulk / nrmk\n",
    "                    u[i-1] = ulk\n",
    "\n",
    "        treeG[j_tr]['u'] = u\n",
    "        uc = u\n",
    "        N0 = N1\n",
    "\n",
    "    return treeG\n",
    "\n",
    "def extract_haar_basis_and_graph_info(tree_real):\n",
    "    Tree_length = len(tree_real)\n",
    "    num_nodes_tree = np.zeros(Tree_length, dtype=int)\n",
    "    num_edges_tree = np.zeros(Tree_length, dtype=int)\n",
    "    edge_index_list = [None] * Tree_length\n",
    "    U = []\n",
    "    features_list = []\n",
    "    for j in range(Tree_length):\n",
    "        u = tree_real[j]['u']\n",
    "        N = len(u)\n",
    "        # N1 is N_nodes of parent level.\n",
    "        N1 = len(tree_real[j+1]['u']) if j < Tree_length - 1 else 1\n",
    "        \n",
    "        # HaarBasis matrix [N, N1] is NOT square necessarily in this extraction logic?\n",
    "        # Actually, u is length N (the basis size).\n",
    "        # We want to extract the basis vectors as columns of a matrix?\n",
    "        # The code snippet you provided previously puts them into [N, N1]??\n",
    "        # Usually basis U is [N, N]. \n",
    "        # FIX: The basis U at level j has N vectors of length N. \n",
    "        # But previous code used N1 in loop range? \n",
    "        # Let's assume we want the full basis U [N, N].\n",
    "        \n",
    "        HaarBases = np.zeros((N, N), dtype=np.float64)\n",
    "        for k in range(N):\n",
    "            HaarBases[:, k] = u[k]\n",
    "        U.append(HaarBases)\n",
    "        \n",
    "        num_nodes_tree[j] = N\n",
    "        edge_index, _ = adj2edge(tree_real[j]['adj'])\n",
    "        edge_index_list[j] = edge_index\n",
    "        num_edges_tree[j] = edge_index.size(1)\n",
    "        features_list.append(tree_real[j]['features'])\n",
    "        \n",
    "    num_nodes_tree[-1] = 1\n",
    "    num_edges_tree[-1] = 1\n",
    "    return U, num_nodes_tree, num_edges_tree, edge_index_list, features_list\n",
    "\n",
    "# ==============================================================================\n",
    "# 6. Batch Processing Wrapper\n",
    "# ==============================================================================\n",
    "\n",
    "def Uext_batch_from_tree_lists_HMH(\n",
    "    X_list, edge_index_list,\n",
    "    levels, gnn_encoders, # GNN Encoders passed here\n",
    "    ratio=0.3,\n",
    "    device=\"cpu\", dtype=torch.float64,\n",
    "    assign_method=\"sinkhorn\", tau=0.9, sinkhorn_iters=10,\n",
    "    seed=42\n",
    "):\n",
    "    U_batch = []\n",
    "    edge_index_list_batch = []\n",
    "    num_nodes_tree_batch  = []\n",
    "    num_edges_tree_batch  = []\n",
    "    features_list_batch   = []\n",
    "    treeG_batch           = []\n",
    "    S_assign_List         = []\n",
    "\n",
    "    for X_i, ei_i in zip(X_list, edge_index_list):\n",
    "        A_i = to_scipy_sparse_matrix(ei_i, num_nodes=X_i.shape[0])\n",
    "        \n",
    "        # Call GNN-based Tree Builder\n",
    "        treeG_i, S_assign_list = Make_tree_HMH(\n",
    "            X=X_i, A=A_i,\n",
    "            levels=levels, \n",
    "            gnn_encoders=gnn_encoders, \n",
    "            ratio=ratio,\n",
    "            device=device, dtype=dtype,\n",
    "            assign_method=assign_method, tau=tau, sinkhorn_iters=sinkhorn_iters,\n",
    "            seed=seed\n",
    "        )\n",
    "        \n",
    "        # Compute Haar Basis\n",
    "        treeG_i = HaarGOB_with_Sassign_degree_norm(treeG_i, S_assign_list)\n",
    "        \n",
    "        # Extract Info\n",
    "        U_i, n_nodes_i, n_edges_i, eidx_i, feats_i = extract_haar_basis_and_graph_info(treeG_i)\n",
    "\n",
    "        U_batch.append(U_i)\n",
    "        edge_index_list_batch.append(eidx_i)\n",
    "        num_nodes_tree_batch.append(n_nodes_i)\n",
    "        num_edges_tree_batch.append(n_edges_i)\n",
    "        features_list_batch.append(feats_i)\n",
    "        treeG_batch.append(treeG_i)\n",
    "        S_assign_List.append(S_assign_list)\n",
    "\n",
    "    return (U_batch, edge_index_list_batch, num_nodes_tree_batch,\n",
    "            num_edges_tree_batch, features_list_batch, treeG_batch, S_assign_List)\n",
    "\n",
    "# ==============================================================================\n",
    "# 7. Classification Model\n",
    "# ==============================================================================\n",
    "\n",
    "def _to_dense_torch(mat, device):\n",
    "    \"\"\"\n",
    "    Converts numpy/scipy matrix to dense torch tensor on the specified device.\n",
    "    \"\"\"\n",
    "    if isinstance(mat, np.ndarray):\n",
    "        arr = mat\n",
    "    elif sp.issparse(mat):\n",
    "        arr = mat.toarray()\n",
    "    else:\n",
    "        arr = np.asarray(mat)\n",
    "    \n",
    "    # Ensure float32 for neural networks\n",
    "    return torch.as_tensor(arr, dtype=torch.float32, device=device)\n",
    "\n",
    "def unpool_one_level(H_coarse, clusters, N_fine):\n",
    "    device = H_coarse.device\n",
    "    D = H_coarse.size(1)\n",
    "    H_fine = torch.zeros(N_fine, D, device=device)\n",
    "    for i, child_idx in enumerate(clusters):\n",
    "        if len(child_idx) == 0: continue\n",
    "        idx = torch.as_tensor(child_idx, dtype=torch.long, device=device)\n",
    "        H_fine.index_add_(0, idx, H_coarse[i].expand(idx.numel(), D))\n",
    "    return H_fine\n",
    "\n",
    "def unpool_to_level0(H_l, level_l, treeG):\n",
    "    H = H_l\n",
    "    for m in range(level_l, 0, -1):\n",
    "        clusters_m = treeG[m]['clusters']\n",
    "        N_fine     = treeG[m-1]['adj'].shape[0]\n",
    "        H = unpool_one_level(H, clusters_m, N_fine)\n",
    "    return H\n",
    "\n",
    "class HaarSpectralBlock(nn.Module):\n",
    "    def __init__(self, max_K: int):\n",
    "        super().__init__()\n",
    "        self.lambda_vec = nn.Parameter(torch.randn(max_K))\n",
    "\n",
    "    def forward(self, U: torch.Tensor, X: torch.Tensor):\n",
    "        # U: [N_l, N_l], X: [N_l, F] \n",
    "        # NOTE: U from extract function is [N, N] (full basis)\n",
    "        # We only filter with top K components usually?\n",
    "        # Or we use full U?\n",
    "        # Implementation implies we slice U to K_cap\n",
    "        \n",
    "        K_l   = U.size(1)\n",
    "        K_cap = min(K_l, self.lambda_vec.size(0))\n",
    "\n",
    "        Uc    = U[:, :K_cap]                        # [N_l, K_cap]\n",
    "        X_hat = Uc.transpose(0, 1) @ X              # [K_cap, F]\n",
    "        lam   = self.lambda_vec[:K_cap].unsqueeze(1) \n",
    "        X_hat = X_hat * lam                         # [K_cap, F]\n",
    "        H     = Uc @ X_hat                          # [N_l, F]\n",
    "        return F.relu(H)\n",
    "\n",
    "class NodeHaarUnpoolClassifier(nn.Module):\n",
    "    def __init__(self, in_dim, hid_dim, num_classes, max_K, num_levels):\n",
    "        super().__init__()\n",
    "        self.num_levels = num_levels \n",
    "        self.pre = nn.Sequential(\n",
    "            nn.Linear(in_dim, hid_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hid_dim, hid_dim)\n",
    "        )\n",
    "        self.block = HaarSpectralBlock(max_K=max_K)\n",
    "        self.classifier = nn.Linear(hid_dim * num_levels, num_classes)\n",
    "        self.dropout = nn.Dropout(p=0.3)\n",
    "\n",
    "    def forward(self, U_list, features_list, treeG):\n",
    "        # CORRECTED: Get device dynamically from the model parameters\n",
    "        device = next(self.parameters()).device\n",
    "\n",
    "        L_eff = min(self.num_levels, len(U_list))\n",
    "\n",
    "        H_per_level = []\n",
    "        for l in range(L_eff):\n",
    "            # Pass the correct device explicitly\n",
    "            X_l = _to_dense_torch(features_list[l], device)\n",
    "            X_l = self.dropout(self.pre(X_l))\n",
    "            \n",
    "            U_l = _to_dense_torch(U_list[l], device)\n",
    "            H_l = self.block(U_l, X_l)\n",
    "            \n",
    "            H0_l = unpool_to_level0(H_l, level_l=l, treeG=treeG)\n",
    "            H_per_level.append(H0_l)\n",
    "\n",
    "        H0_cat = torch.cat(H_per_level, dim=1)\n",
    "        H0_cat = self.dropout(H0_cat)\n",
    "        logits = self.classifier(H0_cat)\n",
    "        return logits\n",
    "\n",
    "def loss_diversity_from_S(S_assign_list, device=None, eps=1e-9):\n",
    "    L_div = 0.0\n",
    "    for S in S_assign_list:\n",
    "        if isinstance(S, np.ndarray): S_t = torch.from_numpy(S)\n",
    "        else: S_t = S\n",
    "        if device is not None: S_t = S_t.to(device)\n",
    "        S_t = S_t.clamp_min(eps)\n",
    "        row_entropy = -(S_t * S_t.log()).sum(dim=1)\n",
    "        L_div = L_div + row_entropy.mean()\n",
    "    return L_div\n",
    "\n",
    "def loss_reconstruction_from_treeG(treeG, device=None):\n",
    "    L_rec = 0.0\n",
    "    for lvl in range(len(treeG)):\n",
    "        if 'u' not in treeG[lvl]: continue\n",
    "        u_list = treeG[lvl]['u']\n",
    "        if u_list is None or any(v is None for v in u_list): continue\n",
    "\n",
    "        U_np = np.stack(u_list, axis=0) \n",
    "        H_np = treeG[lvl]['features']\n",
    "\n",
    "        U = torch.from_numpy(U_np.astype(np.float32))\n",
    "        H = torch.from_numpy(H_np.astype(np.float32))\n",
    "        if device is not None:\n",
    "            U = U.to(device)\n",
    "            H = H.to(device)\n",
    "\n",
    "        H_hat = U.t() @ (U @ H)\n",
    "        L_rec = L_rec + F.mse_loss(H_hat, H, reduction='sum')\n",
    "    return L_rec\n",
    "\n",
    "# ==============================================================================\n",
    "# 8. Main Training Script (MUTAG)\n",
    "# ==============================================================================\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(\"Device:\", device)\n",
    "\n",
    "    # Load Data\n",
    "    import tempfile\n",
    "    root = os.path.join(tempfile.gettempdir(), 'data', 'MUTAG')\n",
    "    dataset = TUDataset(root, name='MUTAG').shuffle()\n",
    "\n",
    "    num_training = int(0.8 * len(dataset))\n",
    "    num_val      = int(0.1 * len(dataset))\n",
    "    num_test     = len(dataset) - (num_training + num_val)\n",
    "    train_set, val_set, test_set = torch.utils.data.random_split(\n",
    "        dataset, [num_training, num_val, num_test],\n",
    "        generator=torch.Generator().manual_seed(42)\n",
    "    )\n",
    "\n",
    "    batch_size = 32\n",
    "    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "    val_loader   = DataLoader(val_set,   batch_size=batch_size, shuffle=False)\n",
    "    test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    def split_batch_to_graphs(batch):\n",
    "        data_list = batch.to_data_list()\n",
    "        X_list, edge_index_list, y_list = [], [], []\n",
    "        for data in data_list:\n",
    "            x = data.x\n",
    "            if x is None or x.numel() == 0:\n",
    "                x = torch.ones(data.num_nodes, 1, dtype=torch.float32)\n",
    "            X_list.append(x)\n",
    "            edge_index_list.append(data.edge_index)\n",
    "            y = data.y.view(-1)[0].long()\n",
    "            y_list.append(y)\n",
    "        return X_list, edge_index_list, y_list\n",
    "\n",
    "    # Configuration\n",
    "    input_dim = dataset.num_features if dataset.num_features > 0 else 1\n",
    "    num_classes = dataset.num_classes\n",
    "    hid_dim = 32\n",
    "    max_K = 32\n",
    "    levels = 4\n",
    "    embed_dim = 16 # Dimension of GNN output (before Sinkhorn)\n",
    "\n",
    "    # Initialize GNN Encoders (Trainable)\n",
    "    # We need encoders for levels 0, 1, ..., levels-2 (coarsening steps)\n",
    "    gnn_encoders = nn.ModuleList([\n",
    "        GNNPoolEncoder(in_dim=input_dim,    # <--- ALWAYS input_dim\n",
    "                       hidden_dim=hid_dim, \n",
    "                       out_dim=embed_dim) \n",
    "        for l in range(levels - 1)\n",
    "    ]).to(device)\n",
    "\n",
    "    # Initialize Classifier\n",
    "    model2 = NodeHaarUnpoolClassifier(\n",
    "        in_dim=input_dim,\n",
    "        hid_dim=hid_dim,\n",
    "        num_classes=num_classes,\n",
    "        max_K=max_K,\n",
    "        num_levels=levels-1\n",
    "    ).to(device)\n",
    "\n",
    "    # Optimizer includes both classifier AND GNN encoders\n",
    "    params = list(model2.parameters()) + list(gnn_encoders.parameters())\n",
    "    opt = torch.optim.Adam(params, lr=3e-3, weight_decay=1e-4)\n",
    "\n",
    "    lambda_div = 0.1\n",
    "    lambda_rec = 0.05\n",
    "\n",
    "    def run_one_epoch(loader, train: bool):\n",
    "        if train:\n",
    "            model2.train()\n",
    "            gnn_encoders.train()\n",
    "        else:\n",
    "            model2.eval()\n",
    "            gnn_encoders.eval()\n",
    "\n",
    "        total_loss = 0.0\n",
    "        total_correct = 0\n",
    "        total_graphs = 0\n",
    "\n",
    "        for batch in loader:\n",
    "            X_list, edge_index_list, y_list = split_batch_to_graphs(batch)\n",
    "            \n",
    "            # --- Inductive Step: Build Trees with GNN Encoders ---\n",
    "            # Encoders are on GPU; Make_tree uses them, but logic handles CPU/GPU\n",
    "            # Note: Make_tree_HMH iterates encoders. \n",
    "            # Sinkhorn works best on same device as encoder output (GPU).\n",
    "            # Tree construction logic (adjacency) mostly typically CPU via SciPy, \n",
    "            # but GNN forward pass is GPU.\n",
    "            \n",
    "            (U_batch, eidx_batch, n_nodes_batch, n_edges_batch,\n",
    "             feats_batch, tree_batch, S_batch) = Uext_batch_from_tree_lists_HMH(\n",
    "                X_list, edge_index_list,\n",
    "                levels=levels,\n",
    "                gnn_encoders=gnn_encoders,\n",
    "                ratio=0.5,\n",
    "                device=device, dtype=torch.float64,\n",
    "                assign_method=\"sinkhorn\", tau=0.9, sinkhorn_iters=10,\n",
    "                seed=42\n",
    "            )\n",
    "\n",
    "            # --- Forward Pass ---\n",
    "            loss_batch = 0.0\n",
    "            for i in range(len(U_batch)):\n",
    "                logits_nodes = model2(U_batch[i], feats_batch[i], tree_batch[i]) \n",
    "                logits_graph = logits_nodes.mean(dim=0, keepdim=True)\n",
    "                \n",
    "                yi = torch.as_tensor([y_list[i].item()], dtype=torch.long, device=device)\n",
    "                L_ce = F.cross_entropy(logits_graph, yi)\n",
    "                \n",
    "                L_div = loss_diversity_from_S(S_batch[i], device=device)\n",
    "                L_rec = loss_reconstruction_from_treeG(tree_batch[i], device=device)\n",
    "                \n",
    "                L_total = L_ce + lambda_div * L_div \n",
    "                loss_batch += L_total\n",
    "\n",
    "                # Acc\n",
    "                pred = logits_graph.argmax(dim=1)\n",
    "                total_correct += int((pred == yi).sum().item())\n",
    "                total_graphs += 1\n",
    "            \n",
    "            loss_batch = loss_batch / len(U_batch)\n",
    "\n",
    "            if train:\n",
    "                opt.zero_grad()\n",
    "                loss_batch.backward()\n",
    "                opt.step()\n",
    "\n",
    "            total_loss += loss_batch.item() * len(U_batch)\n",
    "\n",
    "        avg_loss = total_loss / max(total_graphs, 1)\n",
    "        acc = total_correct / max(total_graphs, 1)\n",
    "        return avg_loss, acc\n",
    "\n",
    "    # Loop\n",
    "    epochs = 300\n",
    "    print(\"Start Training...\")\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        tr_loss, tr_acc = run_one_epoch(train_loader, train=True)\n",
    "        va_loss, va_acc = run_one_epoch(val_loader,   train=False)\n",
    "        te_loss, te_acc = run_one_epoch(test_loader,  train=False)\n",
    "\n",
    "        print(f\"Epoch {epoch:02d} | Tr: {tr_loss:.4f} {tr_acc:.3f} | Va: {va_loss:.4f} {va_acc:.3f} | Te: {te_acc:.3f}\")\n",
    "\n",
    "    print(\"Done.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "import scipy.sparse.linalg as spla\n",
    "from scipy.special import iv as bessel_I\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.datasets import TUDataset\n",
    "\n",
    "# Try importing FAISS\n",
    "try:\n",
    "    import faiss\n",
    "    _HAS_FAISS = True\n",
    "except ImportError:\n",
    "    raise ImportError(\"FAISS is required for this implementation. Please install `faiss-cpu` or `faiss-gpu`.\")\n",
    "\n",
    "# ==============================================================================\n",
    "# 1. Utilities & Math Helpers\n",
    "# ==============================================================================\n",
    "\n",
    "def set_seed(seed=0):\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "def decide_K(N_cur, ratio, last_level):\n",
    "    if N_cur <= 2 or last_level:\n",
    "        return 1\n",
    "    K = int(N_cur * ratio) + 1\n",
    "    return max(1, min(K, N_cur))\n",
    "\n",
    "def symmetrize_simple(W: sp.csr_matrix):\n",
    "    W = W.maximum(W.T)\n",
    "    W.setdiag(0)\n",
    "    W.eliminate_zeros()\n",
    "    return W\n",
    "\n",
    "def coarsen_adj_hard(A: sp.csr_matrix, hard_labels: np.ndarray, K: int):\n",
    "    rr, cc, vv = sp.find(A)\n",
    "    nrr, ncc = hard_labels[rr], hard_labels[cc]\n",
    "    A_coarse = sp.csr_matrix((vv, (nrr, ncc)), shape=(K, K))\n",
    "    A_coarse = symmetrize_simple(A_coarse)\n",
    "    return A_coarse\n",
    "\n",
    "def _deg_vec(A):\n",
    "    d = np.asarray(A.sum(axis=1)).ravel().astype(float)\n",
    "    d[d <= 0.0] = 1e-12\n",
    "    return d\n",
    "\n",
    "def _dnorm(u, d):\n",
    "    return np.sqrt(float((d * (u * u)).sum()) + 1e-24)\n",
    "\n",
    "def to_scipy_sparse_matrix(edge_index, num_nodes):\n",
    "    if isinstance(edge_index, torch.Tensor):\n",
    "        ei = edge_index.detach().cpu().numpy()\n",
    "    else:\n",
    "        ei = np.asarray(edge_index)\n",
    "    r, c = ei\n",
    "    data = np.ones(r.size, dtype=np.float64)\n",
    "    A = sp.coo_matrix((data, (r, c)), shape=(num_nodes, num_nodes)).tocsr()\n",
    "    A = A.maximum(A.T)\n",
    "    A.setdiag(0)\n",
    "    A.eliminate_zeros()\n",
    "    return A\n",
    "\n",
    "def adj2edge(A: sp.coo_matrix):\n",
    "    A = A.tocoo()\n",
    "    row = torch.as_tensor(A.row, dtype=torch.long)\n",
    "    col = torch.as_tensor(A.col, dtype=torch.long)\n",
    "    edge_index = torch.stack([row, col], dim=0)\n",
    "    edge_weight = torch.as_tensor(A.data, dtype=torch.float32)\n",
    "    return edge_index, edge_weight\n",
    "\n",
    "def scipy_to_torch_sparse(A: sp.csr_matrix, device=\"cpu\", dtype=torch.float64):\n",
    "    A = A.tocoo()\n",
    "    indices = np.vstack([A.row, A.col])\n",
    "    i = torch.from_numpy(indices).long().to(device)\n",
    "    v = torch.from_numpy(A.data).to(device=device, dtype=dtype)\n",
    "    return torch.sparse_coo_tensor(i, v, size=A.shape, device=device, dtype=dtype).coalesce()\n",
    "\n",
    "def degree_from_csr(W: sp.csr_matrix):\n",
    "    return np.asarray(W.sum(axis=1)).ravel()\n",
    "\n",
    "# ==============================================================================\n",
    "# 2. Diffusion & Spectral Helpers (The \"Good\" Code)\n",
    "# ==============================================================================\n",
    "\n",
    "def normalized_laplacian(W: sp.csr_matrix):\n",
    "    W = W.tocsr()\n",
    "    d = np.asarray(W.sum(axis=1)).ravel()\n",
    "    d_safe = np.maximum(d, 1e-12)\n",
    "    Dinv_sqrt = sp.diags(1.0 / np.sqrt(d_safe))\n",
    "    N = W.shape[0]\n",
    "    L = sp.eye(N, format='csr') - Dinv_sqrt @ W @ Dinv_sqrt\n",
    "    return L, d\n",
    "\n",
    "def mix_laplacian(L_top: sp.csr_matrix, L_feat: sp.csr_matrix, alpha=(0.5, 0.5)):\n",
    "    a = np.asarray(alpha, dtype=float)\n",
    "    a = a / (a.sum() + 1e-12)\n",
    "    return a[0] * L_top + a[1] * L_feat\n",
    "\n",
    "def heat_kernel_apply(L: sp.csr_matrix, t=0.6, Omega=None, order=25):\n",
    "    n = L.shape[0]\n",
    "    if Omega is None:\n",
    "        Omega = np.random.randn(n, 32)\n",
    "    A = (L - sp.eye(n, format='csr')) # Shifted L (spectrum ~ [-1, 1])\n",
    "    a = t\n",
    "    Y = np.zeros_like(Omega, dtype=float)\n",
    "    T0 = Omega.copy()\n",
    "    T1 = A @ Omega\n",
    "    Y += bessel_I(0, a) * T0\n",
    "    Tk_minus_1, Tk = T0, T1\n",
    "    for k in range(1, order + 1):\n",
    "        ck = 2.0 * ((-1)**k) * bessel_I(k, a)\n",
    "        Y += ck * Tk\n",
    "        Tk_plus_1 = 2.0 * (A @ Tk) - Tk_minus_1\n",
    "        Tk_minus_1, Tk = Tk, Tk_plus_1\n",
    "    Y *= np.exp(-t)\n",
    "    nrm = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-12\n",
    "    return Y / nrm\n",
    "\n",
    "def _safe_k(n: int, k: int) -> int:\n",
    "    return max(0, min(int(k), max(0, n - 1)))\n",
    "\n",
    "def faiss_knn_dense_strict(X: np.ndarray, k: int):\n",
    "    Xf = np.ascontiguousarray(X.astype(np.float32))\n",
    "    n = Xf.shape[0]\n",
    "    if n == 0: return np.empty((0, 0), dtype=int), np.empty((0, 0), dtype=np.float32)\n",
    "    k_eff = _safe_k(n, k)\n",
    "    if k_eff == 0: return np.empty((n, 0), dtype=int), np.empty((n, 0), dtype=np.float32)\n",
    "\n",
    "    kq = min(n, k_eff + 1)\n",
    "    index = faiss.IndexFlatL2(Xf.shape[1])\n",
    "    index.add(Xf)\n",
    "    D_full, I_full = index.search(Xf, kq)\n",
    "\n",
    "    I_out = np.empty((n, k_eff), dtype=int)\n",
    "    D_out = np.empty((n, k_eff), dtype=np.float32)\n",
    "\n",
    "    for i in range(n):\n",
    "        rowI, rowD = I_full[i], D_full[i]\n",
    "        mask = rowI != i\n",
    "        rowI = rowI[mask]\n",
    "        rowD = rowD[mask]\n",
    "        if rowI.size >= k_eff:\n",
    "            I_out[i] = rowI[:k_eff]\n",
    "            D_out[i] = rowD[:k_eff]\n",
    "        else:\n",
    "            need = k_eff - rowI.size\n",
    "            I_out[i] = np.pad(rowI, (0, need), mode='edge')\n",
    "            D_out[i] = np.pad(rowD, (0, need), mode='edge')\n",
    "    return I_out, D_out\n",
    "\n",
    "def feature_knn_graph_faiss_safe(X: np.ndarray, k=15, sigma=None):\n",
    "    n = X.shape[0]\n",
    "    k_eff = _safe_k(n, k)\n",
    "    if n == 0 or k_eff == 0: return sp.csr_matrix((n, n))\n",
    "    I, D2 = faiss_knn_dense_strict(X, k_eff)\n",
    "    if sigma is None:\n",
    "        flat = D2.ravel()\n",
    "        pos = flat[np.isfinite(flat) & (flat > 0)]\n",
    "        sigma = float(np.median(pos)) if pos.size > 0 else 1.0\n",
    "        if sigma <= 1e-12: sigma = 1.0\n",
    "    rows = np.repeat(np.arange(n), k_eff)\n",
    "    cols = I.ravel()\n",
    "    weights = np.exp(-D2.ravel() / (2.0 * (sigma ** 2)))\n",
    "    W = sp.csr_matrix((weights, (rows, cols)), shape=(n, n))\n",
    "    W = W.maximum(W.T)\n",
    "    W.setdiag(0); W.eliminate_zeros()\n",
    "    return W\n",
    "\n",
    "def incompatibility_from_dense_knn(I_knn: np.ndarray, degree: np.ndarray):\n",
    "    n = I_knn.shape[0]\n",
    "    d = np.maximum(degree, 1e-12)\n",
    "    rows, cols, vals = [], [], []\n",
    "    in_ball = [set(I_knn[i].tolist() + [i]) for i in range(n)]\n",
    "    for i in range(n):\n",
    "        comp = [j for j in range(n) if j not in in_ball[i]]\n",
    "        if not comp: continue\n",
    "        wij = (1.0 / d[i]) * (1.0 / d[np.asarray(comp)])\n",
    "        rows.extend([i] * len(comp))\n",
    "        cols.extend(comp)\n",
    "        vals.extend(wij.tolist())\n",
    "    M = sp.csr_matrix((vals, (rows, cols)), shape=(n, n))\n",
    "    M = M.maximum(M.T); M.setdiag(0); M.eliminate_zeros()\n",
    "    return M\n",
    "\n",
    "# ==============================================================================\n",
    "# 3. Spectral Solver & Assignment Logic\n",
    "# ==============================================================================\n",
    "\n",
    "@torch.no_grad()\n",
    "def estimate_lmax_power_sparse(L_sp: torch.Tensor, iters: int = 20):\n",
    "    n = L_sp.size(0)\n",
    "    x = torch.randn(n, 1, device=L_sp.device, dtype=L_sp.dtype)\n",
    "    x = x / (x.norm() + 1e-12)\n",
    "    lam = None\n",
    "    for _ in range(iters):\n",
    "        y = torch.sparse.mm(L_sp, x)\n",
    "        ny = y.norm()\n",
    "        if ny.item() == 0.0: return torch.tensor(0.0, device=L_sp.device, dtype=L_sp.dtype)\n",
    "        x = y / ny\n",
    "        lam = (x.t() @ torch.sparse.mm(L_sp, x)).item()\n",
    "    return torch.tensor(lam if lam is not None else 0.0, device=L_sp.device, dtype=L_sp.dtype)\n",
    "\n",
    "import torch\n",
    "from typing import Optional  # <--- Import Optional\n",
    "\n",
    "# ... previous code ...\n",
    "\n",
    "@torch.no_grad()\n",
    "def block_power_smallest(\n",
    "    L_sp: torch.Tensor, \n",
    "    K: int, \n",
    "    iters: int = 10, \n",
    "    deg_C: Optional[torch.Tensor] = None  # <--- CHANGED HERE\n",
    "):\n",
    "    device, dtype = L_sp.device, L_sp.dtype\n",
    "    nC = L_sp.size(0)\n",
    "    \n",
    "    # ... rest of the function remains the same ...\n",
    "    lmax = estimate_lmax_power_sparse(L_sp, iters=12).clamp_min(1e-12)\n",
    "    def apply_B(X): return X - (1.0 / lmax) * torch.sparse.mm(L_sp, X)\n",
    "    if deg_C is None: deg_C = torch.ones(nC, device=device, dtype=dtype)\n",
    "    s = deg_C.clamp_min(1e-12).sqrt().unsqueeze(1)\n",
    "    s = s / (s.norm() + 1e-12)\n",
    "    Q = torch.randn(nC, K, device=device, dtype=dtype)\n",
    "    Q, _ = torch.linalg.qr(Q, mode='reduced')\n",
    "    for _ in range(iters):\n",
    "        Q = apply_B(Q)\n",
    "        Q = Q - s * (s.t() @ Q)\n",
    "        Q, _ = torch.linalg.qr(Q, mode='reduced')\n",
    "    return Q\n",
    "\n",
    "class LinearHeadK(nn.Module):\n",
    "    def __init__(self, in_dim, K, hidden=32, dtype=torch.float64):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(in_dim, hidden, dtype=dtype), nn.ReLU(),\n",
    "            nn.Linear(hidden, K, dtype=dtype)\n",
    "        )\n",
    "    def forward(self, U): return self.net(U)\n",
    "\n",
    "def sinkhorn_balanced(logits: torch.Tensor, n_iters=7, tau=1.0):\n",
    "    n, K = logits.shape\n",
    "    P = torch.exp(logits / tau) + 1e-9\n",
    "    col_tgt = (n / K) * torch.ones(K, device=logits.device, dtype=logits.dtype)\n",
    "    for _ in range(n_iters):\n",
    "        P = P / (P.sum(dim=1, keepdim=True) + 1e-12)\n",
    "        col_sum = P.sum(dim=0) + 1e-12\n",
    "        P = P * (col_tgt / col_sum)\n",
    "    P = P / (P.sum(dim=1, keepdim=True) + 1e-12)\n",
    "    return P\n",
    "\n",
    "@torch.no_grad()\n",
    "def assignments_with_margin(U: torch.Tensor, hidden=32, zeta=1e-2, method=\"sinkhorn\", tau=1.0, sinkhorn_iters=12):\n",
    "    K = U.size(1)\n",
    "    # Project spectral coordinates to logits\n",
    "    head = LinearHeadK(in_dim=K, K=K, hidden=hidden, dtype=U.dtype).to(U.device)\n",
    "    logits = head(U)\n",
    "    logits = logits - logits.mean(dim=1, keepdim=True)\n",
    "    \n",
    "    top2 = torch.topk(logits, k=2, dim=1).values\n",
    "    num = (top2[:, 0] - top2[:, 1]).clamp_min(0.0)\n",
    "    den = top2.abs().sum(dim=1) + zeta\n",
    "    mu = (num / den).unsqueeze(1)\n",
    "    logits_scaled = mu * logits\n",
    "    \n",
    "    if method.lower() == \"sinkhorn\":\n",
    "        S = sinkhorn_balanced(logits_scaled, n_iters=sinkhorn_iters, tau=tau)\n",
    "    else:\n",
    "        S = torch.softmax(logits_scaled, dim=1)\n",
    "    return S, logits, mu.squeeze(1)\n",
    "\n",
    "# ==============================================================================\n",
    "# 4. Tree Builder (Spectral / Diffusion Implementation)\n",
    "# ==============================================================================\n",
    "\n",
    "def _unique_seeds_for_all_clusters(S: np.ndarray) -> np.ndarray:\n",
    "    N, K = S.shape\n",
    "    order = np.argsort(-S, axis=0)\n",
    "    seeds = -np.ones(K, dtype=int)\n",
    "    used = np.zeros(N, dtype=bool)\n",
    "    ptr = np.zeros(K, dtype=int)\n",
    "    remaining = list(range(K))\n",
    "    guard = 0\n",
    "    while remaining and guard < K * N:\n",
    "        k = remaining.pop(0)\n",
    "        while ptr[k] < N and used[order[ptr[k], k]]:\n",
    "            ptr[k] += 1\n",
    "        if ptr[k] < N:\n",
    "            i = order[ptr[k], k]; seeds[k] = i; used[i] = True\n",
    "        else:\n",
    "            i = int(np.argmin(used)); seeds[k] = i; used[i] = True\n",
    "        guard += 1\n",
    "    return seeds\n",
    "\n",
    "def hard_labels_cover_all(S: np.ndarray) -> np.ndarray:\n",
    "    N, K = S.shape\n",
    "    y = S.argmax(axis=1)\n",
    "    counts = np.bincount(y, minlength=K)\n",
    "    if (counts == 0).any():\n",
    "        seeds = _unique_seeds_for_all_clusters(S)\n",
    "        y[seeds] = np.arange(K, dtype=int)\n",
    "    return y\n",
    "\n",
    "def Make_tree_HMH(\n",
    "    X, A,\n",
    "    levels: int,\n",
    "    ratio: float = 0.2,\n",
    "    lam: float = 0.1,\n",
    "    k_feat: int = 15,\n",
    "    k_diff: int = 15,\n",
    "    t_heat: float = 0.6,\n",
    "    cheb_order: int = 25,\n",
    "    alpha=(0.5, 0.5),\n",
    "    device: str = \"cpu\",\n",
    "    dtype = torch.float64,\n",
    "    assign_method: str = \"sinkhorn\",\n",
    "    tau: float = 1.0,\n",
    "    sinkhorn_iters: int = 7,\n",
    "    seed: int = 0,\n",
    "):\n",
    "    \"\"\"\n",
    "    Constructs the hierarchy using Heat Kernel Diffusion and Spectral Clustering.\n",
    "    \"\"\"\n",
    "    set_seed(seed)\n",
    "    if isinstance(X, torch.Tensor): X = X.detach().cpu().numpy()\n",
    "    A = A.tocsr()\n",
    "    N_start = A.shape[0]\n",
    "\n",
    "    adj_list = [A]; features_list = [X]; parents = []; S_assign_list = []\n",
    "\n",
    "    for level in range(levels - 1):\n",
    "        N_cur = A.shape[0]\n",
    "        last_level = (level == levels - 2)\n",
    "        K = decide_K(N_cur, ratio, last_level)\n",
    "        \n",
    "        if K == 1:\n",
    "            S_triv = np.ones((N_cur, 1), dtype=np.float64)\n",
    "            S_assign_list.append(S_triv)\n",
    "            parents.append(np.zeros(N_cur, dtype=int))\n",
    "            X = S_triv.T @ X\n",
    "            A = coarsen_adj_hard(A, np.zeros(N_cur, dtype=int), K=1)\n",
    "            adj_list.append(A); features_list.append(X)\n",
    "            break\n",
    "\n",
    "        # --- SPECTRAL DIFFUSION BLOCK (Replacing GNN Encoder) ---\n",
    "        # 1) L_top\n",
    "        L_top, deg_top = normalized_laplacian(A)\n",
    "        # 2) L_feat\n",
    "        W_feat = feature_knn_graph_faiss_safe(X, k=k_feat, sigma=None)\n",
    "        L_feat, _ = normalized_laplacian(W_feat)\n",
    "        # 3) L_mix\n",
    "        L_mix = mix_laplacian(L_top, L_feat, alpha=alpha)\n",
    "        # 4) Heat Diffusion\n",
    "        Z = heat_kernel_apply(L_mix, t=t_heat, order=cheb_order)\n",
    "        I_diff, _ = faiss_knn_dense_strict(Z, k=k_diff)\n",
    "        # 5) Incompatibility Matrix\n",
    "        M_C = incompatibility_from_dense_knn(I_diff, degree=deg_top)\n",
    "        \n",
    "        # 6) L_aug (Torch)\n",
    "        Lmix_sp_t = scipy_to_torch_sparse(L_mix, device=device, dtype=dtype)\n",
    "        MC_sp_t   = scipy_to_torch_sparse(M_C,   device=device, dtype=dtype)\n",
    "        L_aug_t   = (Lmix_sp_t + lam * MC_sp_t).coalesce()\n",
    "\n",
    "        # 7) Bottom-K Eigenvectors\n",
    "        deg_t = torch.from_numpy(deg_top).to(device=device, dtype=dtype)\n",
    "        U_t = block_power_smallest(L_aug_t, K=K, iters=12, deg_C=deg_t) # [N, K]\n",
    "\n",
    "        # 8) Assignments\n",
    "        S_t, logits_t, mu_t = assignments_with_margin(\n",
    "            U_t, hidden=32, zeta=1e-3, method=assign_method, tau=tau, sinkhorn_iters=sinkhorn_iters\n",
    "        )\n",
    "        S_np = S_t.detach().cpu().numpy()\n",
    "        \n",
    "        # 9) Coarsen\n",
    "        hard_labels = hard_labels_cover_all(S_np)\n",
    "        A_next = coarsen_adj_hard(A, hard_labels, K)\n",
    "        X_next = S_np.T @ X\n",
    "        \n",
    "        adj_list.append(A_next)\n",
    "        features_list.append(X_next)\n",
    "        parents.append(hard_labels)\n",
    "        S_assign_list.append(S_np)\n",
    "        A, X = A_next, X_next\n",
    "\n",
    "    # Build Tree\n",
    "    L_eff = len(adj_list)\n",
    "    treeG = [None] * L_eff\n",
    "    for lvl in range(L_eff):\n",
    "        if lvl == 0:\n",
    "            idxs = np.arange(N_start)\n",
    "            clusters = [np.array([i], dtype=int) for i in idxs]\n",
    "            IDX_vec = np.arange(N_start)\n",
    "        else:\n",
    "            pid = parents[lvl - 1]\n",
    "            K_lvl = S_assign_list[lvl - 1].shape[1]\n",
    "            clusters = [np.flatnonzero(pid == k) for k in range(K_lvl)]\n",
    "            IDX_vec = pid\n",
    "\n",
    "        treeG[lvl] = {\n",
    "            'IDX': IDX_vec, 'clusters': clusters,\n",
    "            'adj': adj_list[lvl], 'features': features_list[lvl]\n",
    "        }\n",
    "\n",
    "    return treeG, S_assign_list\n",
    "\n",
    "# ==============================================================================\n",
    "# 5. Haar Basis & Batching\n",
    "# ==============================================================================\n",
    "\n",
    "def HaarGOB_with_Sassign_degree_norm(treeG, S_assign_list):\n",
    "    Ntr = len(treeG)\n",
    "    clusterJ0 = treeG[Ntr-1]['clusters']\n",
    "    N0 = len(clusterJ0)\n",
    "\n",
    "    chic = np.identity(N0)\n",
    "    uc = [None] * N0\n",
    "    uc[0] = (1.0 / np.sqrt(N0)) * np.ones(N0, dtype=float)\n",
    "    for l in range(1, N0):\n",
    "        uc[l] = np.sqrt((N0 - l) / (N0 - l + 1.0)) * (\n",
    "            chic[l-1, :] - (1.0 / (N0 - l)) * np.sum(chic[l:, :], axis=0)\n",
    "        )\n",
    "\n",
    "    A_top = treeG[Ntr-1]['adj'].tocsr()\n",
    "    d_top = _deg_vec(A_top)\n",
    "    for l in range(N0):\n",
    "        nrm = _dnorm(uc[l], d_top)\n",
    "        uc[l] = uc[l] / nrm\n",
    "    treeG[Ntr-1]['u'] = uc\n",
    "\n",
    "    for j_tr in np.arange(Ntr-2, -1, -1):\n",
    "        N1 = len(treeG[j_tr]['clusters'])\n",
    "        S_assign = np.asarray(S_assign_list[j_tr], dtype=float)\n",
    "        A_lvl = treeG[j_tr]['adj'].tocsr()\n",
    "        d_lvl = _deg_vec(A_lvl)\n",
    "\n",
    "        u = [None] * N1\n",
    "        i = N0\n",
    "        for l in range(N0):\n",
    "            cluster_l = np.asarray(treeG[j_tr+1]['clusters'][l], dtype=int)\n",
    "            ul1 = np.zeros(N1, dtype=float)\n",
    "            for j in range(N0):\n",
    "                idxj = np.asarray(treeG[j_tr+1]['clusters'][j], dtype=int)\n",
    "                if idxj.size == 0: continue\n",
    "                w = S_assign[idxj, l]\n",
    "                ul1[idxj] += uc[l][j] * w\n",
    "            nrm = _dnorm(ul1, d_lvl)\n",
    "            if nrm > 0: ul1 = ul1 / nrm\n",
    "            u[l] = ul1\n",
    "\n",
    "            kl = int(cluster_l.size)\n",
    "            if kl > 1:\n",
    "                chil = np.zeros((kl, N1), dtype=float)\n",
    "                for k in range(kl):\n",
    "                    chil[k, cluster_l[k]] = 1.0\n",
    "                for k in range(1, kl):\n",
    "                    i += 1\n",
    "                    ulk = np.sqrt((kl - k) / (kl - k + 1.0)) * (\n",
    "                        chil[k-1, :] - (1.0 / (kl - k)) * np.sum(chil[k:, :], axis=0)\n",
    "                    )\n",
    "                    nrmk = _dnorm(ulk, d_lvl)\n",
    "                    if nrmk > 0: ulk = ulk / nrmk\n",
    "                    u[i-1] = ulk\n",
    "        treeG[j_tr]['u'] = u\n",
    "        uc = u\n",
    "        N0 = N1\n",
    "    return treeG\n",
    "\n",
    "def extract_haar_basis_and_graph_info(tree_real):\n",
    "    Tree_length = len(tree_real)\n",
    "    edge_index_list = [None] * Tree_length\n",
    "    U = []\n",
    "    features_list = []\n",
    "    for j in range(Tree_length):\n",
    "        u = tree_real[j]['u']\n",
    "        N = len(u)\n",
    "        HaarBases = np.zeros((N, N), dtype=np.float64)\n",
    "        for k in range(N):\n",
    "            HaarBases[:, k] = u[k]\n",
    "        U.append(HaarBases)\n",
    "        edge_index, _ = adj2edge(tree_real[j]['adj'])\n",
    "        edge_index_list[j] = edge_index\n",
    "        features_list.append(tree_real[j]['features'])\n",
    "    return U, edge_index_list, features_list\n",
    "\n",
    "def Uext_batch_from_tree_lists_HMH(X_list, edge_index_list, levels, ratio=0.3,\n",
    "                                   lam=0.1, k_feat=15, k_diff=15, t_heat=0.6, cheb_order=25, alpha=(0.5,0.5),\n",
    "                                   device=\"cpu\", dtype=torch.float64, assign_method=\"sinkhorn\", \n",
    "                                   tau=0.9, sinkhorn_iters=10, seed=42):\n",
    "    U_batch, feats_batch, tree_batch, S_batch = [], [], [], []\n",
    "    for X_i, ei_i in zip(X_list, edge_index_list):\n",
    "        A_i = to_scipy_sparse_matrix(ei_i, num_nodes=X_i.shape[0])\n",
    "        treeG_i, S_assign_list = Make_tree_HMH(\n",
    "            X=X_i, A=A_i, levels=levels, ratio=ratio, lam=lam, k_feat=k_feat, \n",
    "            k_diff=k_diff, t_heat=t_heat, cheb_order=cheb_order, alpha=alpha, \n",
    "            device=device, dtype=dtype, assign_method=assign_method, tau=tau, \n",
    "            sinkhorn_iters=sinkhorn_iters, seed=seed\n",
    "        )\n",
    "        treeG_i = HaarGOB_with_Sassign_degree_norm(treeG_i, S_assign_list)\n",
    "        U_i, _, feats_i = extract_haar_basis_and_graph_info(treeG_i)\n",
    "        U_batch.append(U_i)\n",
    "        feats_batch.append(feats_i)\n",
    "        tree_batch.append(treeG_i)\n",
    "        S_batch.append(S_assign_list)\n",
    "    return U_batch, feats_batch, tree_batch, S_batch\n",
    "\n",
    "# ==============================================================================\n",
    "# 6. Classifier & Main Execution\n",
    "# ==============================================================================\n",
    "\n",
    "def _to_dense_torch(mat, device):\n",
    "    if isinstance(mat, np.ndarray): arr = mat\n",
    "    elif sp.issparse(mat): arr = mat.toarray()\n",
    "    else: arr = np.asarray(mat)\n",
    "    return torch.as_tensor(arr, dtype=torch.float32, device=device)\n",
    "\n",
    "def unpool_one_level(H_coarse, clusters, N_fine):\n",
    "    device = H_coarse.device\n",
    "    D = H_coarse.size(1)\n",
    "    H_fine = torch.zeros(N_fine, D, device=device)\n",
    "    for i, child_idx in enumerate(clusters):\n",
    "        if len(child_idx) == 0: continue\n",
    "        idx = torch.as_tensor(child_idx, dtype=torch.long, device=device)\n",
    "        H_fine.index_add_(0, idx, H_coarse[i].expand(idx.numel(), D))\n",
    "    return H_fine\n",
    "\n",
    "def unpool_to_level0(H_l, level_l, treeG):\n",
    "    H = H_l\n",
    "    for m in range(level_l, 0, -1):\n",
    "        clusters_m = treeG[m]['clusters']\n",
    "        N_fine     = treeG[m-1]['adj'].shape[0]\n",
    "        H = unpool_one_level(H, clusters_m, N_fine)\n",
    "    return H\n",
    "\n",
    "class HaarSpectralBlock(nn.Module):\n",
    "    def __init__(self, max_K: int):\n",
    "        super().__init__()\n",
    "        self.lambda_vec = nn.Parameter(torch.randn(max_K))\n",
    "    def forward(self, U: torch.Tensor, X: torch.Tensor):\n",
    "        K_l = U.size(1)\n",
    "        K_cap = min(K_l, self.lambda_vec.size(0))\n",
    "        Uc = U[:, :K_cap]\n",
    "        X_hat = Uc.transpose(0, 1) @ X\n",
    "        lam = self.lambda_vec[:K_cap].unsqueeze(1)\n",
    "        X_hat = X_hat * lam\n",
    "        H = Uc @ X_hat\n",
    "        return F.relu(H)\n",
    "\n",
    "class NodeHaarUnpoolClassifier(nn.Module):\n",
    "    def __init__(self, in_dim, hid_dim, num_classes, max_K, num_levels):\n",
    "        super().__init__()\n",
    "        self.num_levels = num_levels\n",
    "        self.pre = nn.Sequential(\n",
    "            nn.Linear(in_dim, hid_dim), nn.ReLU(),\n",
    "            nn.Linear(hid_dim, hid_dim)\n",
    "        )\n",
    "        self.block = HaarSpectralBlock(max_K=max_K)\n",
    "        self.classifier = nn.Linear(hid_dim * num_levels, num_classes)\n",
    "        self.dropout = nn.Dropout(p=0.3)\n",
    "    def forward(self, U_list, features_list, treeG):\n",
    "        device = next(self.parameters()).device\n",
    "        L_eff = min(self.num_levels, len(U_list))\n",
    "        H_per_level = []\n",
    "        for l in range(L_eff):\n",
    "            X_l = _to_dense_torch(features_list[l], device)\n",
    "            X_l = self.dropout(self.pre(X_l))\n",
    "            U_l = _to_dense_torch(U_list[l], device)\n",
    "            H_l = self.block(U_l, X_l)\n",
    "            H0_l = unpool_to_level0(H_l, level_l=l, treeG=treeG)\n",
    "            H_per_level.append(H0_l)\n",
    "        H0_cat = torch.cat(H_per_level, dim=1)\n",
    "        H0_cat = self.dropout(H0_cat)\n",
    "        logits = self.classifier(H0_cat)\n",
    "        return logits\n",
    "\n",
    "# ---- Loss helpers ----\n",
    "def loss_diversity_from_S(S_assign_list, device=None, eps=1e-9):\n",
    "    L_div = 0.0\n",
    "    for S in S_assign_list:\n",
    "        if isinstance(S, np.ndarray): S_t = torch.from_numpy(S)\n",
    "        else: S_t = S\n",
    "        if device is not None: S_t = S_t.to(device)\n",
    "        S_t = S_t.clamp_min(eps)\n",
    "        row_entropy = -(S_t * S_t.log()).sum(dim=1)\n",
    "        L_div = L_div + row_entropy.mean()\n",
    "    return L_div\n",
    "\n",
    "def loss_reconstruction_from_treeG(treeG, device=None):\n",
    "    L_rec = 0.0\n",
    "    for lvl in range(len(treeG)):\n",
    "        if 'u' not in treeG[lvl]: continue\n",
    "        u_list = treeG[lvl]['u']\n",
    "        if u_list is None or any(v is None for v in u_list): continue\n",
    "        U_np = np.stack(u_list, axis=0)\n",
    "        H_np = treeG[lvl]['features']\n",
    "        U = torch.from_numpy(U_np.astype(np.float32))\n",
    "        H = torch.from_numpy(H_np.astype(np.float32))\n",
    "        if device is not None: U = U.to(device); H = H.to(device)\n",
    "        H_hat = U.t() @ (U @ H)\n",
    "        L_rec = L_rec + F.mse_loss(H_hat, H, reduction='sum')\n",
    "    return L_rec\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(\"Device:\", device)\n",
    "\n",
    "    # 1. Load Data\n",
    "    # Use home directory to avoid read-only file system errors if running from root\n",
    "    # Load Data\n",
    "    import tempfile\n",
    "    root = os.path.join(tempfile.gettempdir(), 'data', 'MUTAG')\n",
    "    dataset = TUDataset(root, name='MUTAG').shuffle()\n",
    "    \n",
    "    num_training = int(0.8 * len(dataset))\n",
    "    num_val      = int(0.1 * len(dataset))\n",
    "    num_test     = len(dataset) - (num_training + num_val)\n",
    "    train_set, val_set, test_set = torch.utils.data.random_split(\n",
    "        dataset, [num_training, num_val, num_test],\n",
    "        generator=torch.Generator().manual_seed(42)\n",
    "    )\n",
    "\n",
    "    batch_size = 32\n",
    "    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "    val_loader   = DataLoader(val_set,   batch_size=batch_size, shuffle=False)\n",
    "    test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    def split_batch_to_graphs(batch):\n",
    "        data_list = batch.to_data_list()\n",
    "        X_list, edge_index_list, y_list = [], [], []\n",
    "        for data in data_list:\n",
    "            x = data.x\n",
    "            if x is None or x.numel() == 0: x = torch.ones(data.num_nodes, 1, dtype=torch.float32)\n",
    "            X_list.append(x)\n",
    "            edge_index_list.append(data.edge_index)\n",
    "            y = data.y.view(-1)[0].long()\n",
    "            y_list.append(y)\n",
    "        return X_list, edge_index_list, y_list\n",
    "\n",
    "    # 2. Model Init (No Encoder, just Classifier)\n",
    "    input_dim = dataset.num_features if dataset.num_features > 0 else 1\n",
    "    num_classes = dataset.num_classes\n",
    "    hid_dim = 32\n",
    "    max_K = 32\n",
    "    levels = 4\n",
    "\n",
    "    model2 = NodeHaarUnpoolClassifier(\n",
    "        in_dim=input_dim, hid_dim=hid_dim, num_classes=num_classes,\n",
    "        max_K=max_K, num_levels=levels-1\n",
    "    ).to(device)\n",
    "\n",
    "    opt = torch.optim.Adam(model2.parameters(), lr=3e-3, weight_decay=1e-4)\n",
    "\n",
    "    # 3. Train Loop\n",
    "    lambda_div = 0.1\n",
    "    lambda_rec = 0.05\n",
    "\n",
    "    def run_one_epoch(loader, train: bool):\n",
    "        if train: model2.train()\n",
    "        else: model2.eval()\n",
    "\n",
    "        total_loss = 0.0; total_correct = 0; total_graphs = 0\n",
    "\n",
    "        for batch in loader:\n",
    "            X_list, edge_index_list, y_list = split_batch_to_graphs(batch)\n",
    "            \n",
    "            # --- Non-Parametric Heterophilous Tree Construction (CPU/Mixed) ---\n",
    "            # This is slow per batch but mathematically robust.\n",
    "            U_batch, feats_batch, tree_batch, S_batch = Uext_batch_from_tree_lists_HMH(\n",
    "                X_list, edge_index_list, levels=levels, ratio=0.5,\n",
    "                lam=0.1, k_feat=4, k_diff=4, t_heat=0.6, cheb_order=25, alpha=(0.5,0.5),\n",
    "                device=device, dtype=torch.float64, assign_method=\"sinkhorn\",\n",
    "                tau=0.9, sinkhorn_iters=10, seed=42\n",
    "            )\n",
    "\n",
    "            loss_batch = 0.0\n",
    "            for i in range(len(U_batch)):\n",
    "                # Forward (Classifier only)\n",
    "                logits_nodes = model2(U_batch[i], feats_batch[i], tree_batch[i]) \n",
    "                logits_graph = logits_nodes.mean(dim=0, keepdim=True)\n",
    "                \n",
    "                yi = torch.as_tensor([y_list[i].item()], dtype=torch.long, device=device)\n",
    "                L_ce = F.cross_entropy(logits_graph, yi)\n",
    "                \n",
    "                L_div = loss_diversity_from_S(S_batch[i], device=device)\n",
    "                L_rec = loss_reconstruction_from_treeG(tree_batch[i], device=device)\n",
    "                \n",
    "                loss_batch += L_ce + lambda_div * L_div + lambda_rec * L_rec\n",
    "                \n",
    "                pred = logits_graph.argmax(dim=1)\n",
    "                total_correct += int((pred == yi).sum().item())\n",
    "                total_graphs += 1\n",
    "            \n",
    "            loss_batch = loss_batch / len(U_batch)\n",
    "\n",
    "            if train:\n",
    "                opt.zero_grad()\n",
    "                loss_batch.backward()\n",
    "                opt.step()\n",
    "\n",
    "            total_loss += loss_batch.item() * len(U_batch)\n",
    "\n",
    "        avg_loss = total_loss / max(total_graphs, 1)\n",
    "        acc = total_correct / max(total_graphs, 1)\n",
    "        return avg_loss, acc\n",
    "\n",
    "    print(\"Start Training (Heterophilous Spectral Encoder)...\")\n",
    "    for epoch in range(1, 51):\n",
    "        tr_loss, tr_acc = run_one_epoch(train_loader, train=True)\n",
    "        va_loss, va_acc = run_one_epoch(val_loader,   train=False)\n",
    "        te_loss, te_acc = run_one_epoch(test_loader,  train=False)\n",
    "        print(f\"Epoch {epoch:02d} | Tr: {tr_loss:.4f} {tr_acc:.3f} | Va: {va_loss:.4f} {va_acc:.3f} | Te: {te_acc:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cpu\n"
     ]
    },
    {
     "ename": "OSError",
     "evalue": "[Errno 30] Read-only file system: '/data'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOSError\u001b[0m                                   Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[2], line 609\u001b[0m\n\u001b[1;32m    605\u001b[0m \u001b[38;5;66;03m# Load\u001b[39;00m\n\u001b[1;32m    606\u001b[0m \u001b[38;5;66;03m# Use a writable local data directory (os.path.abspath('') can be '/' in some environments)\u001b[39;00m\n\u001b[1;32m    607\u001b[0m \u001b[38;5;66;03m# Prefer the current working directory so we don't attempt to write to root ('/data').\u001b[39;00m\n\u001b[1;32m    608\u001b[0m path \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(os\u001b[38;5;241m.\u001b[39mgetcwd(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdata\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMUTAG\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m--> 609\u001b[0m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmakedirs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexist_ok\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m    610\u001b[0m dataset \u001b[38;5;241m=\u001b[39m TUDataset(path, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMUTAG\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39mshuffle()\n\u001b[1;32m    612\u001b[0m loader \u001b[38;5;241m=\u001b[39m DataLoader(dataset, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m32\u001b[39m, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "File \u001b[0;32m/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/os.py:215\u001b[0m, in \u001b[0;36mmakedirs\u001b[0;34m(name, mode, exist_ok)\u001b[0m\n\u001b[1;32m    213\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m head \u001b[38;5;129;01mand\u001b[39;00m tail \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m path\u001b[38;5;241m.\u001b[39mexists(head):\n\u001b[1;32m    214\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 215\u001b[0m         \u001b[43mmakedirs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhead\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexist_ok\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexist_ok\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    216\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mFileExistsError\u001b[39;00m:\n\u001b[1;32m    217\u001b[0m         \u001b[38;5;66;03m# Defeats race condition when another thread created the path\u001b[39;00m\n\u001b[1;32m    218\u001b[0m         \u001b[38;5;28;01mpass\u001b[39;00m\n",
      "File \u001b[0;32m/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/os.py:225\u001b[0m, in \u001b[0;36mmakedirs\u001b[0;34m(name, mode, exist_ok)\u001b[0m\n\u001b[1;32m    223\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m    224\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 225\u001b[0m     \u001b[43mmkdir\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    226\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m:\n\u001b[1;32m    227\u001b[0m     \u001b[38;5;66;03m# Cannot rely on checking for EEXIST, since the operating system\u001b[39;00m\n\u001b[1;32m    228\u001b[0m     \u001b[38;5;66;03m# could give priority to other errors like EACCES or EROFS\u001b[39;00m\n\u001b[1;32m    229\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exist_ok \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m path\u001b[38;5;241m.\u001b[39misdir(name):\n",
      "\u001b[0;31mOSError\u001b[0m: [Errno 30] Read-only file system: '/data'"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "import scipy.sparse.linalg as spla\n",
    "from scipy.special import iv as bessel_I\n",
    "from typing import Optional\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.datasets import TUDataset\n",
    "\n",
    "# ==============================================================================\n",
    "# 1. Stable Euclidean Distance (No FAISS)\n",
    "# ==============================================================================\n",
    "\n",
    "def get_knn_indices_numpy(X: np.ndarray, k: int):\n",
    "    \"\"\"\n",
    "    Exact k-NN search using pure NumPy (Stable & Crash-proof).\n",
    "    Returns: indices (N, k), distances_squared (N, k)\n",
    "    \"\"\"\n",
    "    # X: [N, D]\n",
    "    # Compute pairwise squared Euclidean distance: ||A-B||^2 = ||A||^2 + ||B||^2 - 2A.B\n",
    "    # This is O(N^2) which is perfectly fine for MUTAG (N ~ 20-100).\n",
    "    \n",
    "    # 1. Dot product\n",
    "    X_sq = np.sum(X**2, axis=1, keepdims=True) # [N, 1]\n",
    "    dist_sq = X_sq + X_sq.T - 2 * np.dot(X, X.T)\n",
    "    \n",
    "    # 2. Fix numerical tiny negatives\n",
    "    dist_sq = np.maximum(dist_sq, 0.0)\n",
    "    \n",
    "    # 3. Mask self-loops (dist to self is 0) by setting diagonal to Inf\n",
    "    np.fill_diagonal(dist_sq, np.inf)\n",
    "    \n",
    "    # 4. Partition to get top-k (faster than full sort)\n",
    "    n = X.shape[0]\n",
    "    k_eff = min(k, n - 1)\n",
    "    \n",
    "    if k_eff <= 0:\n",
    "        return np.empty((n, 0), dtype=int), np.empty((n, 0), dtype=float)\n",
    "    \n",
    "    # argpartition puts the k smallest elements at indices 0..k-1\n",
    "    indices = np.argpartition(dist_sq, k_eff, axis=1)[:, :k_eff]\n",
    "    \n",
    "    # Get values and sort them (optional but good for consistency)\n",
    "    row_idx = np.arange(n)[:, None]\n",
    "    vals = dist_sq[row_idx, indices]\n",
    "    \n",
    "    # Sort locally\n",
    "    sort_order = np.argsort(vals, axis=1)\n",
    "    indices = indices[row_idx, sort_order]\n",
    "    vals = vals[row_idx, sort_order]\n",
    "    \n",
    "    return indices, vals\n",
    "\n",
    "def feature_knn_graph_numpy(X: np.ndarray, k=15, sigma=None):\n",
    "    n = X.shape[0]\n",
    "    k_eff = min(k, n - 1)\n",
    "    if n <= 1 or k_eff == 0:\n",
    "        return sp.csr_matrix((n, n), dtype=float)\n",
    "\n",
    "    I, D2 = get_knn_indices_numpy(X, k_eff)\n",
    "    \n",
    "    # Heuristic sigma\n",
    "    if sigma is None:\n",
    "        flat = D2.ravel()\n",
    "        # Filter out 0 or inf\n",
    "        valid = flat[np.isfinite(flat) & (flat > 1e-9)]\n",
    "        sigma = float(np.median(valid)) if valid.size > 0 else 1.0\n",
    "        sigma = max(sigma, 1e-6)\n",
    "\n",
    "    # Gaussian kernel\n",
    "    rows = np.repeat(np.arange(n), k_eff)\n",
    "    cols = I.ravel()\n",
    "    weights = np.exp(-D2.ravel() / (2.0 * sigma**2))\n",
    "\n",
    "    W = sp.csr_matrix((weights, (rows, cols)), shape=(n, n))\n",
    "    W = W.maximum(W.T) # Symmetrize\n",
    "    W.setdiag(0)\n",
    "    W.eliminate_zeros()\n",
    "    return W\n",
    "\n",
    "def incompatibility_numpy_dense(I_knn: np.ndarray, degree: np.ndarray):\n",
    "    \"\"\"\n",
    "    Build M_C using dense sets (Stable).\n",
    "    \"\"\"\n",
    "    n = I_knn.shape[0]\n",
    "    # Ensure degree is safe\n",
    "    d = np.maximum(degree, 1e-6)\n",
    "    \n",
    "    # Create sets for fast lookup\n",
    "    # each row i contains neighbors of i.\n",
    "    in_ball = [set(I_knn[i].tolist()) | {i} for i in range(n)]\n",
    "    \n",
    "    rows, cols, vals = [], [], []\n",
    "    \n",
    "    for i in range(n):\n",
    "        # Find all nodes j NOT in the ball of i\n",
    "        # Since MUTAG N is small, a loop over N is instant.\n",
    "        curr_ball = in_ball[i]\n",
    "        \n",
    "        # Identify incompatibles\n",
    "        out_nodes = [j for j in range(n) if j not in curr_ball]\n",
    "        \n",
    "        if not out_nodes:\n",
    "            continue\n",
    "            \n",
    "        out_arr = np.array(out_nodes)\n",
    "        \n",
    "        # Formula: 1 / (d_i * d_j)\n",
    "        wij = 1.0 / (d[i] * d[out_arr])\n",
    "        \n",
    "        rows.extend([i] * len(out_arr))\n",
    "        cols.extend(out_arr)\n",
    "        vals.extend(wij)\n",
    "\n",
    "    M = sp.csr_matrix((vals, (rows, cols)), shape=(n, n))\n",
    "    M = M.maximum(M.T)\n",
    "    M.setdiag(0)\n",
    "    M.eliminate_zeros()\n",
    "    return M\n",
    "\n",
    "# ==============================================================================\n",
    "# 2. Diffusion & Tree Building (Corrected Logic)\n",
    "# ==============================================================================\n",
    "\n",
    "def normalized_laplacian(W: sp.csr_matrix):\n",
    "    W = W.tocsr()\n",
    "    d = np.asarray(W.sum(axis=1)).ravel()\n",
    "    d_safe = np.maximum(d, 1e-6)\n",
    "    Dinv_sqrt = sp.diags(1.0 / np.sqrt(d_safe))\n",
    "    N = W.shape[0]\n",
    "    # L = I - D^-0.5 W D^-0.5\n",
    "    L = sp.eye(N, format='csr') - Dinv_sqrt @ W @ Dinv_sqrt\n",
    "    return L, d\n",
    "\n",
    "def mix_laplacian(L_top: sp.csr_matrix, L_feat: sp.csr_matrix, alpha=(0.5, 0.5)):\n",
    "    # Normalize alpha\n",
    "    a = np.array(alpha, dtype=float)\n",
    "    a /= a.sum()\n",
    "    return a[0] * L_top + a[1] * L_feat\n",
    "\n",
    "def heat_kernel_cheby(L: sp.csr_matrix, t=0.6, order=20):\n",
    "    \"\"\"\n",
    "    Approximate exp(-tL) * Omega using Chebyshev polynomials.\n",
    "    \"\"\"\n",
    "    n = L.shape[0]\n",
    "    # Random probe vectors\n",
    "    r_probes = 32\n",
    "    Omega = np.random.randn(n, r_probes)\n",
    "    \n",
    "    # Rescale L to [-1, 1] for Chebyshev stability\n",
    "    # Since L is normalized Laplacian, eigs in [0, 2].\n",
    "    # A = L - I maps [0, 2] -> [-1, 1].\n",
    "    A = L - sp.eye(n, format='csr')\n",
    "    \n",
    "    # Chebyshev recurrence\n",
    "    # term 0\n",
    "    Y = bessel_I(0, t) * Omega\n",
    "    \n",
    "    T0 = Omega\n",
    "    T1 = A @ Omega\n",
    "    \n",
    "    # term k=1..order\n",
    "    for k in range(1, order + 1):\n",
    "        # ck = 2 * (-1)^k * Ik(t)\n",
    "        ck = 2.0 * ((-1)**k) * bessel_I(k, t)\n",
    "        Y += ck * T1\n",
    "        \n",
    "        # Recurrence: T_{k+1} = 2 A T_k - T_{k-1}\n",
    "        T_next = 2.0 * (A @ T1) - T0\n",
    "        T0, T1 = T1, T_next\n",
    "        \n",
    "    Y *= np.exp(-t)\n",
    "    \n",
    "    # Row normalize embedding\n",
    "    norms = np.linalg.norm(Y, axis=1, keepdims=True)\n",
    "    Y = Y / np.maximum(norms, 1e-9)\n",
    "    return Y\n",
    "\n",
    "# ---- Torch Helpers ----\n",
    "\n",
    "def scipy_to_torch_sparse(A: sp.csr_matrix, device=\"cpu\", dtype=torch.float64):\n",
    "    A = A.tocoo()\n",
    "    indices = np.vstack([A.row, A.col])\n",
    "    i = torch.from_numpy(indices).long().to(device)\n",
    "    v = torch.from_numpy(A.data).to(device=device, dtype=dtype)\n",
    "    return torch.sparse_coo_tensor(i, v, size=A.shape, device=device, dtype=dtype).coalesce()\n",
    "\n",
    "@torch.no_grad()\n",
    "def block_power_smallest(\n",
    "    L_sp: torch.Tensor, \n",
    "    K: int, \n",
    "    iters: int = 10, \n",
    "    deg_C: Optional[torch.Tensor] = None\n",
    "):\n",
    "    \"\"\"\n",
    "    Solves for Bottom-K eigenvectors using inverse-free power method.\n",
    "    Using shift B = I - L / lambda_max.\n",
    "    \"\"\"\n",
    "    device, dtype = L_sp.device, L_sp.dtype\n",
    "    nC = L_sp.size(0)\n",
    "    \n",
    "    # 1. Estimate Lambda Max (Power Iteration)\n",
    "    x = torch.randn(nC, 1, device=device, dtype=dtype)\n",
    "    x = x / (x.norm() + 1e-9)\n",
    "    lmax = 2.0 # Default guess for Normalized Lap\n",
    "    for _ in range(10):\n",
    "        y = torch.sparse.mm(L_sp, x)\n",
    "        val = y.norm()\n",
    "        if val < 1e-9: break\n",
    "        x = y / val\n",
    "        lmax = (x.t() @ torch.sparse.mm(L_sp, x)).item()\n",
    "    \n",
    "    lmax = max(lmax, 1e-6)\n",
    "\n",
    "    # 2. Block Power Method\n",
    "    def apply_B(X):\n",
    "        return X - (1.0 / lmax) * torch.sparse.mm(L_sp, X)\n",
    "        \n",
    "    # Degree constraints\n",
    "    if deg_C is None:\n",
    "        deg_C = torch.ones(nC, device=device, dtype=dtype)\n",
    "    s = deg_C.clamp_min(1e-9).sqrt().unsqueeze(1)\n",
    "    s = s / s.norm()\n",
    "    \n",
    "    # Init Q\n",
    "    Q = torch.randn(nC, K, device=device, dtype=dtype)\n",
    "    Q, _ = torch.linalg.qr(Q, mode='reduced')\n",
    "    \n",
    "    for _ in range(iters):\n",
    "        Q = apply_B(Q)\n",
    "        # Orthogonalize against degree vector (balance constraint)\n",
    "        Q = Q - s * (s.t() @ Q)\n",
    "        # Re-orthonormalize\n",
    "        Q, _ = torch.linalg.qr(Q, mode='reduced')\n",
    "        \n",
    "    return Q\n",
    "\n",
    "# ---- Assignments ----\n",
    "\n",
    "class LinearHeadK(nn.Module):\n",
    "    def __init__(self, in_dim, K, hidden=32, dtype=torch.float64):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(in_dim, hidden, dtype=dtype), nn.ReLU(),\n",
    "            nn.Linear(hidden, K, dtype=dtype)\n",
    "        )\n",
    "    def forward(self, U): return self.net(U)\n",
    "\n",
    "def assignments_with_margin(U: torch.Tensor, K: int, hidden=32):\n",
    "    # U is [N, K_in] -> Project to [N, K_out]\n",
    "    # To keep it simple and static-less, we create the head on the fly or assumes U is sufficient.\n",
    "    # For stability in non-parametric, we skip the MLP head here and just use U directly \n",
    "    # if dimensions match, or pad/slice.\n",
    "    \n",
    "    N, d_in = U.shape\n",
    "    \n",
    "    if d_in >= K:\n",
    "        logits = U[:, :K]\n",
    "    else:\n",
    "        # Pad with noise/zeros\n",
    "        pad = torch.zeros(N, K - d_in, device=U.device, dtype=U.dtype)\n",
    "        logits = torch.cat([U, pad], dim=1)\n",
    "        \n",
    "    logits = logits - logits.mean(dim=1, keepdim=True)\n",
    "    \n",
    "    # Margin scaling\n",
    "    if K > 1:\n",
    "        top2 = torch.topk(logits, k=2, dim=1).values\n",
    "        margin = (top2[:, 0] - top2[:, 1]).clamp_min(0.0)\n",
    "        den = top2.abs().sum(dim=1) + 1e-3\n",
    "        mu = (margin / den).unsqueeze(1)\n",
    "    else:\n",
    "        mu = torch.ones(N, 1, device=U.device, dtype=U.dtype)\n",
    "        \n",
    "    # Sinkhorn\n",
    "    S = torch.exp(mu * logits) + 1e-9\n",
    "    col_tgt = (N/K) * torch.ones(K, device=U.device, dtype=U.dtype)\n",
    "    \n",
    "    for _ in range(10):\n",
    "        S = S / (S.sum(dim=1, keepdim=True) + 1e-9)\n",
    "        S = S * (col_tgt / (S.sum(dim=0) + 1e-9))\n",
    "    \n",
    "    S = S / (S.sum(dim=1, keepdim=True) + 1e-9)\n",
    "    return S\n",
    "\n",
    "# ==============================================================================\n",
    "# 3. Tree Construction (Main Driver)\n",
    "# ==============================================================================\n",
    "\n",
    "def coarsen_adj_hard(A: sp.csr_matrix, hard_labels: np.ndarray, K: int):\n",
    "    # A is N x N, hard_labels is N\n",
    "    # We want K x K\n",
    "    rows, cols = A.nonzero()\n",
    "    data = A.data\n",
    "    \n",
    "    new_rows = hard_labels[rows]\n",
    "    new_cols = hard_labels[cols]\n",
    "    \n",
    "    A_next = sp.csr_matrix((data, (new_rows, new_cols)), shape=(K, K))\n",
    "    A_next = A_next.maximum(A_next.T) # Symmetrize\n",
    "    A_next.setdiag(0)\n",
    "    A_next.eliminate_zeros()\n",
    "    return A_next\n",
    "\n",
    "def hard_labels_from_S(S: np.ndarray):\n",
    "    # Deterministic assignment ensuring cover\n",
    "    N, K = S.shape\n",
    "    y = S.argmax(axis=1)\n",
    "    \n",
    "    # Ensure every cluster has at least one node\n",
    "    counts = np.bincount(y, minlength=K)\n",
    "    missing = np.where(counts == 0)[0]\n",
    "    \n",
    "    if len(missing) > 0:\n",
    "        # Assign best candidates for missing clusters\n",
    "        # S[:, k] is score for cluster k\n",
    "        for k in missing:\n",
    "            best_node = np.argmax(S[:, k])\n",
    "            y[best_node] = k\n",
    "    return y\n",
    "\n",
    "def Make_tree_HMH(X, A, levels, ratio=0.5, lam=0.1, device=\"cpu\"):\n",
    "    if isinstance(X, torch.Tensor): X = X.detach().cpu().numpy()\n",
    "    A = A.tocsr()\n",
    "    \n",
    "    adj_list = [A]\n",
    "    feat_list = [X]\n",
    "    parents = []\n",
    "    S_list = []\n",
    "    \n",
    "    for lvl in range(levels - 1):\n",
    "        N = A.shape[0]\n",
    "        K = max(1, int(N * ratio))\n",
    "        \n",
    "        if K < 2 or N <= 2:\n",
    "            # Collapse rest\n",
    "            S_triv = np.ones((N, 1))\n",
    "            S_list.append(S_triv)\n",
    "            parents.append(np.zeros(N, dtype=int))\n",
    "            adj_list.append(sp.csr_matrix((1, 1)))\n",
    "            feat_list.append(S_triv.T @ X)\n",
    "            break\n",
    "            \n",
    "        # 1. Laplacian Mix\n",
    "        L_top, deg = normalized_laplacian(A)\n",
    "        W_feat = feature_knn_graph_numpy(X, k=5)\n",
    "        L_feat, _ = normalized_laplacian(W_feat)\n",
    "        L_mix = mix_laplacian(L_top, L_feat)\n",
    "        \n",
    "        # 2. Diffusion Embedding & Incompatibility\n",
    "        Z = heat_kernel_cheby(L_mix, t=0.6)\n",
    "        I_knn, _ = get_knn_indices_numpy(Z, k=5)\n",
    "        M_C = incompatibility_numpy_dense(I_knn, deg)\n",
    "        \n",
    "        # 3. Solve Eigenvectors (On CPU for stability in loop)\n",
    "        # Combine: L_aug = L_mix + lam * M_C\n",
    "        L_aug_sp = scipy_to_torch_sparse(L_mix + lam * M_C, device=device)\n",
    "        deg_t = torch.from_numpy(deg).to(device=device, dtype=torch.float64)\n",
    "        \n",
    "        U = block_power_smallest(L_aug_sp, K=K, iters=15, deg_C=deg_t)\n",
    "        \n",
    "        # 4. Assignment\n",
    "        S = assignments_with_margin(U, K)\n",
    "        S_np = S.detach().cpu().numpy()\n",
    "        \n",
    "        # 5. Coarsen\n",
    "        y = hard_labels_from_S(S_np)\n",
    "        A_next = coarsen_adj_hard(A, y, K)\n",
    "        X_next = S_np.T @ X\n",
    "        \n",
    "        adj_list.append(A_next)\n",
    "        feat_list.append(X_next)\n",
    "        parents.append(y)\n",
    "        S_list.append(S_np)\n",
    "        \n",
    "        A, X = A_next, X_next\n",
    "        \n",
    "    # Reformat to TreeG dictionary structure\n",
    "    L_eff = len(adj_list)\n",
    "    treeG = []\n",
    "    for l in range(L_eff):\n",
    "        if l == 0:\n",
    "            idxs = np.arange(adj_list[0].shape[0])\n",
    "            clusters = [np.array([i]) for i in idxs]\n",
    "            idx_vec = idxs\n",
    "        else:\n",
    "            pid = parents[l-1]\n",
    "            K_prev = S_list[l-1].shape[1]\n",
    "            clusters = [np.flatnonzero(pid == k) for k in range(K_prev)]\n",
    "            idx_vec = pid\n",
    "            \n",
    "        treeG.append({\n",
    "            'adj': adj_list[l],\n",
    "            'features': feat_list[l],\n",
    "            'clusters': clusters,\n",
    "            'IDX': idx_vec\n",
    "        })\n",
    "        \n",
    "    return treeG, S_list\n",
    "\n",
    "# ==============================================================================\n",
    "# 4. Haar Basis\n",
    "# ==============================================================================\n",
    "\n",
    "def HaarGOB_degree_norm(treeG, S_list):\n",
    "    # Precompute basis for top level\n",
    "    top_lvl = len(treeG) - 1\n",
    "    N_top = treeG[top_lvl]['adj'].shape[0]\n",
    "    \n",
    "    # 1. Init top basis (Identity / Simple Haar)\n",
    "    # Ideally standard basis is Identity for graph nodes?\n",
    "    # Using the recursive construction logic from snippet:\n",
    "    uc = [None] * N_top\n",
    "    chic = np.eye(N_top)\n",
    "    uc[0] = np.ones(N_top) / np.sqrt(N_top)\n",
    "    for l in range(1, N_top):\n",
    "        vec = chic[l-1] - (np.sum(chic[l:], axis=0) / (N_top - l))\n",
    "        uc[l] = vec * np.sqrt((N_top - l)/(N_top - l + 1))\n",
    "        \n",
    "    # Degree norm top\n",
    "    d_top = np.array(treeG[top_lvl]['adj'].sum(1)).ravel()\n",
    "    d_top[d_top==0] = 1\n",
    "    \n",
    "    for l in range(N_top):\n",
    "        nrm = np.sqrt(np.sum(d_top * uc[l]**2))\n",
    "        if nrm > 1e-9: uc[l] /= nrm\n",
    "        \n",
    "    treeG[top_lvl]['u'] = uc\n",
    "    \n",
    "    # 2. Propagate Down\n",
    "    # Loop from top-1 down to 0\n",
    "    for l_idx in range(top_lvl - 1, -1, -1):\n",
    "        parent_idx = l_idx + 1\n",
    "        N_curr = treeG[l_idx]['adj'].shape[0]\n",
    "        N_parent = treeG[parent_idx]['adj'].shape[0]\n",
    "        S = S_list[l_idx] # [N_curr, N_parent]\n",
    "        \n",
    "        d_curr = np.array(treeG[l_idx]['adj'].sum(1)).ravel()\n",
    "        d_curr[d_curr==0] = 1\n",
    "        \n",
    "        u_curr = [None] * N_curr\n",
    "        filled = 0\n",
    "        \n",
    "        # A. Low Pass (Inter-cluster)\n",
    "        for p in range(N_parent):\n",
    "            # Propagate parent vector p down\n",
    "            # vec_child = S[:, p] * parent_vec[p] ?? No, scalar weight * parent vector logic\n",
    "            # Correct logic: Child vector i inherits from parent vector p weighted by S[i, p]\n",
    "            \n",
    "            # The previous code logic:\n",
    "            vec = np.zeros(N_curr)\n",
    "            # For each parent node j, add contribution\n",
    "            for j in range(N_parent):\n",
    "                # nodes in current graph belonging to parent j\n",
    "                cluster_j = treeG[parent_idx]['clusters'][j]\n",
    "                w = S[cluster_j, p] \n",
    "                vec[cluster_j] += uc[p][j] * w\n",
    "                \n",
    "            nrm = np.sqrt(np.sum(d_curr * vec**2))\n",
    "            if nrm > 1e-9: vec /= nrm\n",
    "            u_curr[p] = vec\n",
    "            filled += 1\n",
    "            \n",
    "        # B. High Pass (Intra-cluster)\n",
    "        # For each cluster, generate N_cluster - 1 detail coefficients\n",
    "        offset = N_parent\n",
    "        for p in range(N_parent):\n",
    "            cluster = treeG[parent_idx]['clusters'][p]\n",
    "            k = len(cluster)\n",
    "            if k > 1:\n",
    "                # Local Haar on these k nodes\n",
    "                local_eye = np.eye(k)\n",
    "                for i in range(1, k):\n",
    "                    # basis vector i within cluster\n",
    "                    vec_local = local_eye[i-1] - (np.sum(local_eye[i:], axis=0) / (k - i))\n",
    "                    vec_local *= np.sqrt((k-i)/(k-i+1))\n",
    "                    \n",
    "                    vec_full = np.zeros(N_curr)\n",
    "                    vec_full[cluster] = vec_local\n",
    "                    \n",
    "                    nrm = np.sqrt(np.sum(d_curr * vec_full**2))\n",
    "                    if nrm > 1e-9: vec_full /= nrm\n",
    "                    \n",
    "                    if offset < N_curr:\n",
    "                        u_curr[offset] = vec_full\n",
    "                        offset += 1\n",
    "                        \n",
    "        treeG[l_idx]['u'] = u_curr\n",
    "        uc = u_curr\n",
    "        \n",
    "    return treeG\n",
    "\n",
    "def extract_batch(X_list, ei_list, device):\n",
    "    U_batch, tree_batch, feats_batch = [], [], []\n",
    "    \n",
    "    # Process sequentially (safe)\n",
    "    for i, (X, ei) in enumerate(zip(X_list, ei_list)):\n",
    "        A = to_scipy_sparse_matrix(ei, X.shape[0])\n",
    "        # Force CPU for tree build\n",
    "        treeG, S_list = Make_tree_HMH(X, A, levels=4, ratio=0.5, lam=0.1, device=\"cpu\")\n",
    "        treeG = HaarGOB_degree_norm(treeG, S_list)\n",
    "        \n",
    "        # Extract U matrices\n",
    "        # We need U for each level to pass to network\n",
    "        # Network expects U_list, features_list, treeG\n",
    "        U_list = []\n",
    "        features_list = []\n",
    "        \n",
    "        for lvl in range(len(treeG)):\n",
    "            # U at level l is N_l x N_l matrix (basis)\n",
    "            # Code expects dense\n",
    "            if 'u' in treeG[lvl] and treeG[lvl]['u'][0] is not None:\n",
    "                # Stack vectors\n",
    "                u_vecs = treeG[lvl]['u']\n",
    "                # Filter Nones if any\n",
    "                u_vecs = [u if u is not None else np.zeros_like(u_vecs[0]) for u in u_vecs]\n",
    "                U_mat = np.stack(u_vecs, axis=1) # [N, N]\n",
    "                U_list.append(U_mat)\n",
    "            features_list.append(treeG[lvl]['features'])\n",
    "            \n",
    "        U_batch.append(U_list)\n",
    "        feats_batch.append(features_list)\n",
    "        tree_batch.append(treeG)\n",
    "        \n",
    "    return U_batch, feats_batch, tree_batch\n",
    "\n",
    "# ==============================================================================\n",
    "# 5. Classifier (Same as before)\n",
    "# ==============================================================================\n",
    "\n",
    "class HaarSpectralBlock(nn.Module):\n",
    "    def __init__(self, max_K: int):\n",
    "        super().__init__()\n",
    "        self.lambda_vec = nn.Parameter(torch.randn(max_K))\n",
    "    def forward(self, U: torch.Tensor, X: torch.Tensor):\n",
    "        # U: [N, N], X: [N, F]\n",
    "        # Slice U to max_K\n",
    "        K = U.shape[1]\n",
    "        K_eff = min(K, self.lambda_vec.shape[0])\n",
    "        \n",
    "        U_trunc = U[:, :K_eff] # [N, K_eff]\n",
    "        \n",
    "        # Spectral Filter: U Lambda U^T X\n",
    "        X_hat = U_trunc.t() @ X # [K_eff, F]\n",
    "        lam = self.lambda_vec[:K_eff].unsqueeze(1)\n",
    "        X_filt = X_hat * lam\n",
    "        return F.relu(U_trunc @ X_filt)\n",
    "\n",
    "class NodeHaarUnpoolClassifier(nn.Module):\n",
    "    def __init__(self, in_dim, hid_dim, num_classes, max_K, num_levels):\n",
    "        super().__init__()\n",
    "        self.num_levels = num_levels\n",
    "        self.pre = nn.Sequential(nn.Linear(in_dim, hid_dim), nn.ReLU())\n",
    "        self.block = HaarSpectralBlock(max_K)\n",
    "        self.classifier = nn.Linear(hid_dim * num_levels, num_classes)\n",
    "        self.dropout = nn.Dropout(0.3)\n",
    "        \n",
    "    def forward(self, U_list, features_list, treeG):\n",
    "        device = next(self.parameters()).device\n",
    "        \n",
    "        H_levels = []\n",
    "        L = min(len(U_list), self.num_levels)\n",
    "        \n",
    "        for l in range(L):\n",
    "            X = torch.as_tensor(features_list[l], dtype=torch.float32, device=device)\n",
    "            U = torch.as_tensor(U_list[l], dtype=torch.float32, device=device)\n",
    "            \n",
    "            H = self.pre(X)\n",
    "            H = self.block(U, H)\n",
    "            \n",
    "            # Unpool to level 0\n",
    "            curr_H = H\n",
    "            for m in range(l, 0, -1):\n",
    "                # Unpool from m to m-1\n",
    "                clusters = treeG[m]['clusters']\n",
    "                N_child = treeG[m-1]['adj'].shape[0]\n",
    "                D = curr_H.shape[1]\n",
    "                next_H = torch.zeros(N_child, D, device=device)\n",
    "                \n",
    "                for p_idx, cluster in enumerate(clusters):\n",
    "                    if len(cluster) > 0:\n",
    "                        idx = torch.as_tensor(cluster, dtype=torch.long, device=device)\n",
    "                        next_H.index_add_(0, idx, curr_H[p_idx].expand(len(idx), D))\n",
    "                curr_H = next_H\n",
    "            \n",
    "            H_levels.append(curr_H)\n",
    "            \n",
    "        H_cat = torch.cat(H_levels, dim=1)\n",
    "        return self.classifier(H_cat)\n",
    "\n",
    "# ==============================================================================\n",
    "# 6. Run\n",
    "# ==============================================================================\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Force CPU if needed, but GPU is fine for the Network part\n",
    "    # Tree build happens on CPU now.\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(\"Device:\", device)\n",
    "    \n",
    "    # Load\n",
    "    # Use a writable local data directory (os.path.abspath('') can be '/' in some environments)\n",
    "    # Prefer the current working directory so we don't attempt to write to root ('/data').\n",
    "    path = os.path.join(os.getcwd(), 'data', 'MUTAG')\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    dataset = TUDataset(path, name='MUTAG').shuffle()\n",
    "    \n",
    "    loader = DataLoader(dataset, batch_size=32, shuffle=True)\n",
    "    \n",
    "    model = NodeHaarUnpoolClassifier(dataset.num_features, 32, dataset.num_classes, 32, 3).to(device)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=0.005)\n",
    "    \n",
    "    print(\"Training...\")\n",
    "    model.train()\n",
    "    \n",
    "    for epoch in range(1, 11):\n",
    "        total_loss = 0\n",
    "        total_acc = 0\n",
    "        count = 0\n",
    "        \n",
    "        for batch in loader:\n",
    "            X_list = [d.x if d.x is not None else torch.ones(d.num_nodes, 1) for d in batch.to_data_list()]\n",
    "            ei_list = [d.edge_index for d in batch.to_data_list()]\n",
    "            y_list = [d.y for d in batch.to_data_list()]\n",
    "            \n",
    "            # 1. Build Trees (CPU, Stable)\n",
    "            U_batch, _, tree_batch = extract_batch(X_list, ei_list, device)\n",
    "            \n",
    "            # 2. Forward (GPU/CPU)\n",
    "            loss_batch = 0\n",
    "            correct = 0\n",
    "            \n",
    "            for i in range(len(U_batch)):\n",
    "                logits = model(U_batch[i], _, tree_batch[i]) # features inside treeG\n",
    "                # Graph pool\n",
    "                g_logits = logits.mean(0, keepdim=True)\n",
    "                y = y_list[i].view(-1).to(device)\n",
    "                \n",
    "                loss_batch += F.cross_entropy(g_logits, y)\n",
    "                if g_logits.argmax() == y: correct += 1\n",
    "                \n",
    "            loss_batch /= len(U_batch)\n",
    "            \n",
    "            opt.zero_grad()\n",
    "            loss_batch.backward()\n",
    "            opt.step()\n",
    "            \n",
    "            total_loss += loss_batch.item() * len(U_batch)\n",
    "            total_acc += correct\n",
    "            count += len(U_batch)\n",
    "            \n",
    "        print(f\"Epoch {epoch}: Loss {total_loss/count:.4f}, Acc {total_acc/count:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def maxLength(arr):\n",
    "        # Start with one empty string (base case)\n",
    "        valid_combinations = [\"\"]\n",
    "        max_len = 0\n",
    "        \n",
    "        for word in arr:\n",
    "            # Optimization: If the word itself has duplicates (e.g. \"aba\"), \n",
    "            # it can never be used. Skip it immediately.\n",
    "            if len(set(word)) != len(word):\n",
    "                continue\n",
    "            \n",
    "            # Create a temporary list to hold new valid combos formed with this word\n",
    "            new_combos = []\n",
    "            \n",
    "            # Try adding this word to every existing valid combination\n",
    "            for candidate in valid_combinations:\n",
    "                # check that concatenating candidate and word keeps all chars unique\n",
    "                if len(set(candidate + word)) == len(candidate) + len(word):\n",
    "                    new_combination = candidate + word\n",
    "                    new_combos.append(new_combination)\n",
    "                    max_len = max(max_len, len(new_combination))\n",
    "            \n",
    "            # Add all the newly discovered combos to our main list\n",
    "            valid_combinations.extend(new_combos)\n",
    "            \n",
    "        return max_len\n",
    "arr=[\"un\",\"iq\",\"ue\"]\n",
    "maxLength(arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
