{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c2e4078e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "======================================================================\n",
      "THREE-DOMAIN UNIVERSAL TRANSFER LEARNING\n",
      "Domains: MNIST + USPS + SVHN (digits 0-9, semantically meaningful)\n",
      "Architecture: Shared encoder + Domain-Specific BatchNorm\n",
      "Same Z space for all domains — universal manifold by construction\n",
      "Nine transfer evaluations (all source→target combinations)\n",
      "======================================================================\n",
      "\n",
      "Using device: cuda\n",
      "\n",
      "Loading datasets...\n",
      "✓ All datasets loaded\n",
      "\n",
      "MNIST-Train Sparse Dataset — Size: 54000, Sparsity: 15% visible\n",
      "MNIST-Val Sparse Dataset — Size: 3000, Sparsity: 15% visible\n",
      "MNIST-Test Sparse Dataset — Size: 3000, Sparsity: 15% visible\n",
      "USPS-Train Sparse Dataset — Size: 6561, Sparsity: 15% visible\n",
      "USPS-Val Sparse Dataset — Size: 364, Sparsity: 15% visible\n",
      "USPS-Test Sparse Dataset — Size: 366, Sparsity: 15% visible\n",
      "SVHN-Train Sparse Dataset — Size: 65931, Sparsity: 15% visible\n",
      "SVHN-Val Sparse Dataset — Size: 3662, Sparsity: 15% visible\n",
      "SVHN-Test Sparse Dataset — Size: 3664, Sparsity: 15% visible\n",
      "Computing SVHN sample weights...\n",
      "✓ Done\n",
      "\n",
      "\n",
      "======================================================================\n",
      "STEP 1: JOINT PRE-TRAINING — SHARED ENCODER + DOMAIN-SPECIFIC BN\n",
      "======================================================================\n",
      "\n",
      "======================================================================\n",
      "STEP 1: JOINT PRE-TRAINING — SHARED ENCODER + DOMAIN-SPECIFIC BN\n",
      "Shared conv weights learn universal features across all 3 domains\n",
      "Domain-specific BN normalizes each domain's statistics separately\n",
      "ALL domains map to the SAME Z space — universal manifold\n",
      "======================================================================\n",
      "\n",
      "Epoch 1, Batch 0: Total=1.1749 (MNIST=0.7405, USPS=0.2221, SVHN=0.2123)\n",
      "Epoch 1, Batch 100: Total=0.1998 (MNIST=0.1120, USPS=0.0397, SVHN=0.0481)\n",
      "\n",
      "Epoch 1/100:\n",
      "  TRAIN Total=0.3169 (MNIST=0.1793, USPS=0.0626, SVHN=0.0750)\n",
      "  VAL   Total=0.2680 (MNIST=0.1320, USPS=0.0448, SVHN=0.0912)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 2, Batch 0: Total=0.2019 (MNIST=0.1200, USPS=0.0334, SVHN=0.0485)\n",
      "Epoch 2, Batch 100: Total=0.1324 (MNIST=0.0793, USPS=0.0231, SVHN=0.0299)\n",
      "\n",
      "Epoch 2/100:\n",
      "  TRAIN Total=0.1625 (MNIST=0.0918, USPS=0.0292, SVHN=0.0415)\n",
      "  VAL   Total=0.1435 (MNIST=0.0813, USPS=0.0263, SVHN=0.0359)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 3, Batch 0: Total=0.1415 (MNIST=0.0828, USPS=0.0224, SVHN=0.0362)\n",
      "Epoch 3, Batch 100: Total=0.1123 (MNIST=0.0703, USPS=0.0207, SVHN=0.0213)\n",
      "\n",
      "Epoch 3/100:\n",
      "  TRAIN Total=0.1253 (MNIST=0.0737, USPS=0.0195, SVHN=0.0321)\n",
      "  VAL   Total=0.1147 (MNIST=0.0669, USPS=0.0185, SVHN=0.0293)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 4, Batch 0: Total=0.1135 (MNIST=0.0674, USPS=0.0171, SVHN=0.0290)\n",
      "Epoch 4, Batch 100: Total=0.1004 (MNIST=0.0612, USPS=0.0179, SVHN=0.0213)\n",
      "\n",
      "Epoch 4/100:\n",
      "  TRAIN Total=0.1065 (MNIST=0.0639, USPS=0.0158, SVHN=0.0268)\n",
      "  VAL   Total=0.1190 (MNIST=0.0646, USPS=0.0189, SVHN=0.0355)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 5, Batch 0: Total=0.0984 (MNIST=0.0605, USPS=0.0134, SVHN=0.0245)\n",
      "Epoch 5, Batch 100: Total=0.1015 (MNIST=0.0578, USPS=0.0140, SVHN=0.0297)\n",
      "\n",
      "Epoch 5/100:\n",
      "  TRAIN Total=0.0970 (MNIST=0.0593, USPS=0.0136, SVHN=0.0241)\n",
      "  VAL   Total=0.1057 (MNIST=0.0628, USPS=0.0187, SVHN=0.0242)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 6, Batch 0: Total=0.0880 (MNIST=0.0560, USPS=0.0130, SVHN=0.0190)\n",
      "Epoch 6, Batch 100: Total=0.0918 (MNIST=0.0560, USPS=0.0150, SVHN=0.0208)\n",
      "\n",
      "Epoch 6/100:\n",
      "  TRAIN Total=0.0930 (MNIST=0.0571, USPS=0.0126, SVHN=0.0233)\n",
      "  VAL   Total=0.1001 (MNIST=0.0589, USPS=0.0145, SVHN=0.0267)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 7, Batch 0: Total=0.0879 (MNIST=0.0567, USPS=0.0132, SVHN=0.0180)\n",
      "Epoch 7, Batch 100: Total=0.0858 (MNIST=0.0561, USPS=0.0107, SVHN=0.0190)\n",
      "\n",
      "Epoch 7/100:\n",
      "  TRAIN Total=0.0886 (MNIST=0.0550, USPS=0.0116, SVHN=0.0221)\n",
      "  VAL   Total=0.0902 (MNIST=0.0536, USPS=0.0117, SVHN=0.0250)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 8, Batch 0: Total=0.0824 (MNIST=0.0516, USPS=0.0097, SVHN=0.0212)\n",
      "Epoch 8, Batch 100: Total=0.0843 (MNIST=0.0535, USPS=0.0090, SVHN=0.0218)\n",
      "\n",
      "Epoch 8/100:\n",
      "  TRAIN Total=0.0854 (MNIST=0.0530, USPS=0.0109, SVHN=0.0215)\n",
      "  VAL   Total=0.0999 (MNIST=0.0624, USPS=0.0140, SVHN=0.0236)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 9, Batch 0: Total=0.0856 (MNIST=0.0527, USPS=0.0104, SVHN=0.0226)\n",
      "Epoch 9, Batch 100: Total=0.0712 (MNIST=0.0462, USPS=0.0110, SVHN=0.0140)\n",
      "\n",
      "Epoch 9/100:\n",
      "  TRAIN Total=0.0825 (MNIST=0.0514, USPS=0.0103, SVHN=0.0207)\n",
      "  VAL   Total=0.0811 (MNIST=0.0510, USPS=0.0116, SVHN=0.0186)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 10, Batch 0: Total=0.0823 (MNIST=0.0530, USPS=0.0097, SVHN=0.0197)\n",
      "Epoch 10, Batch 100: Total=0.0843 (MNIST=0.0558, USPS=0.0094, SVHN=0.0192)\n",
      "\n",
      "Epoch 10/100:\n",
      "  TRAIN Total=0.0774 (MNIST=0.0497, USPS=0.0099, SVHN=0.0178)\n",
      "  VAL   Total=0.0792 (MNIST=0.0546, USPS=0.0104, SVHN=0.0142)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 11, Batch 0: Total=0.0796 (MNIST=0.0533, USPS=0.0088, SVHN=0.0175)\n",
      "Epoch 11, Batch 100: Total=0.0721 (MNIST=0.0469, USPS=0.0091, SVHN=0.0161)\n",
      "\n",
      "Epoch 11/100:\n",
      "  TRAIN Total=0.0738 (MNIST=0.0484, USPS=0.0090, SVHN=0.0163)\n",
      "  VAL   Total=0.0737 (MNIST=0.0480, USPS=0.0114, SVHN=0.0143)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 12, Batch 0: Total=0.0730 (MNIST=0.0507, USPS=0.0104, SVHN=0.0119)\n",
      "Epoch 12, Batch 100: Total=0.0695 (MNIST=0.0487, USPS=0.0073, SVHN=0.0135)\n",
      "\n",
      "Epoch 12/100:\n",
      "  TRAIN Total=0.0730 (MNIST=0.0476, USPS=0.0089, SVHN=0.0164)\n",
      "  VAL   Total=0.0734 (MNIST=0.0477, USPS=0.0092, SVHN=0.0165)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 13, Batch 0: Total=0.0694 (MNIST=0.0443, USPS=0.0065, SVHN=0.0186)\n",
      "Epoch 13, Batch 100: Total=0.0711 (MNIST=0.0470, USPS=0.0098, SVHN=0.0143)\n",
      "\n",
      "Epoch 13/100:\n",
      "  TRAIN Total=0.0703 (MNIST=0.0470, USPS=0.0085, SVHN=0.0149)\n",
      "  VAL   Total=0.0809 (MNIST=0.0544, USPS=0.0106, SVHN=0.0159)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 14, Batch 0: Total=0.0790 (MNIST=0.0512, USPS=0.0088, SVHN=0.0190)\n",
      "Epoch 14, Batch 100: Total=0.0663 (MNIST=0.0452, USPS=0.0080, SVHN=0.0132)\n",
      "\n",
      "Epoch 14/100:\n",
      "  TRAIN Total=0.0676 (MNIST=0.0458, USPS=0.0082, SVHN=0.0137)\n",
      "  VAL   Total=0.0664 (MNIST=0.0461, USPS=0.0083, SVHN=0.0120)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 15, Batch 0: Total=0.0663 (MNIST=0.0475, USPS=0.0078, SVHN=0.0111)\n",
      "Epoch 15, Batch 100: Total=0.0645 (MNIST=0.0445, USPS=0.0087, SVHN=0.0113)\n",
      "\n",
      "Epoch 15/100:\n",
      "  TRAIN Total=0.0659 (MNIST=0.0448, USPS=0.0078, SVHN=0.0133)\n",
      "  VAL   Total=0.0661 (MNIST=0.0459, USPS=0.0092, SVHN=0.0110)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 16, Batch 0: Total=0.0618 (MNIST=0.0411, USPS=0.0084, SVHN=0.0122)\n",
      "Epoch 16, Batch 100: Total=0.0708 (MNIST=0.0477, USPS=0.0094, SVHN=0.0137)\n",
      "\n",
      "Epoch 16/100:\n",
      "  TRAIN Total=0.0673 (MNIST=0.0450, USPS=0.0076, SVHN=0.0147)\n",
      "  VAL   Total=0.0659 (MNIST=0.0462, USPS=0.0084, SVHN=0.0113)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 17, Batch 0: Total=0.0615 (MNIST=0.0449, USPS=0.0072, SVHN=0.0094)\n",
      "Epoch 17, Batch 100: Total=0.0583 (MNIST=0.0400, USPS=0.0074, SVHN=0.0109)\n",
      "\n",
      "Epoch 17/100:\n",
      "  TRAIN Total=0.0648 (MNIST=0.0434, USPS=0.0073, SVHN=0.0140)\n",
      "  VAL   Total=0.0695 (MNIST=0.0481, USPS=0.0093, SVHN=0.0121)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 18, Batch 0: Total=0.0616 (MNIST=0.0427, USPS=0.0080, SVHN=0.0109)\n",
      "Epoch 18, Batch 100: Total=0.0650 (MNIST=0.0451, USPS=0.0075, SVHN=0.0125)\n",
      "\n",
      "Epoch 18/100:\n",
      "  TRAIN Total=0.0628 (MNIST=0.0436, USPS=0.0073, SVHN=0.0119)\n",
      "  VAL   Total=0.0633 (MNIST=0.0449, USPS=0.0079, SVHN=0.0105)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 19, Batch 0: Total=0.0627 (MNIST=0.0432, USPS=0.0086, SVHN=0.0108)\n",
      "Epoch 19, Batch 100: Total=0.0573 (MNIST=0.0390, USPS=0.0076, SVHN=0.0108)\n",
      "\n",
      "Epoch 19/100:\n",
      "  TRAIN Total=0.0615 (MNIST=0.0422, USPS=0.0069, SVHN=0.0124)\n",
      "  VAL   Total=0.0621 (MNIST=0.0438, USPS=0.0079, SVHN=0.0105)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 20, Batch 0: Total=0.0585 (MNIST=0.0412, USPS=0.0069, SVHN=0.0104)\n",
      "Epoch 20, Batch 100: Total=0.0613 (MNIST=0.0411, USPS=0.0086, SVHN=0.0115)\n",
      "\n",
      "Epoch 20/100:\n",
      "  TRAIN Total=0.0605 (MNIST=0.0417, USPS=0.0068, SVHN=0.0120)\n",
      "  VAL   Total=0.0597 (MNIST=0.0426, USPS=0.0072, SVHN=0.0099)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 21, Batch 0: Total=0.0534 (MNIST=0.0358, USPS=0.0078, SVHN=0.0099)\n",
      "Epoch 21, Batch 100: Total=0.0598 (MNIST=0.0403, USPS=0.0060, SVHN=0.0134)\n",
      "\n",
      "Epoch 21/100:\n",
      "  TRAIN Total=0.0606 (MNIST=0.0420, USPS=0.0069, SVHN=0.0116)\n",
      "  VAL   Total=0.0578 (MNIST=0.0416, USPS=0.0072, SVHN=0.0090)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 22, Batch 0: Total=0.0623 (MNIST=0.0446, USPS=0.0066, SVHN=0.0112)\n",
      "Epoch 22, Batch 100: Total=0.0693 (MNIST=0.0440, USPS=0.0064, SVHN=0.0189)\n",
      "\n",
      "Epoch 22/100:\n",
      "  TRAIN Total=0.0615 (MNIST=0.0421, USPS=0.0066, SVHN=0.0129)\n",
      "  VAL   Total=0.0601 (MNIST=0.0420, USPS=0.0075, SVHN=0.0106)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 23, Batch 0: Total=0.0764 (MNIST=0.0478, USPS=0.0066, SVHN=0.0220)\n",
      "Epoch 23, Batch 100: Total=0.0634 (MNIST=0.0424, USPS=0.0078, SVHN=0.0132)\n",
      "\n",
      "Epoch 23/100:\n",
      "  TRAIN Total=0.0589 (MNIST=0.0408, USPS=0.0064, SVHN=0.0118)\n",
      "  VAL   Total=0.0716 (MNIST=0.0530, USPS=0.0084, SVHN=0.0102)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 24, Batch 0: Total=0.0593 (MNIST=0.0429, USPS=0.0065, SVHN=0.0100)\n",
      "Epoch 24, Batch 100: Total=0.0526 (MNIST=0.0345, USPS=0.0056, SVHN=0.0125)\n",
      "\n",
      "Epoch 24/100:\n",
      "  TRAIN Total=0.0591 (MNIST=0.0410, USPS=0.0063, SVHN=0.0118)\n",
      "  VAL   Total=0.0582 (MNIST=0.0397, USPS=0.0068, SVHN=0.0117)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 25, Batch 0: Total=0.0563 (MNIST=0.0392, USPS=0.0058, SVHN=0.0114)\n",
      "Epoch 25, Batch 100: Total=0.0588 (MNIST=0.0442, USPS=0.0056, SVHN=0.0090)\n",
      "\n",
      "Epoch 25/100:\n",
      "  TRAIN Total=0.0578 (MNIST=0.0401, USPS=0.0061, SVHN=0.0116)\n",
      "  VAL   Total=0.0620 (MNIST=0.0453, USPS=0.0068, SVHN=0.0099)\n",
      "  No improvement for 4 epoch(s)\n",
      "Epoch 26, Batch 0: Total=0.0590 (MNIST=0.0385, USPS=0.0053, SVHN=0.0152)\n",
      "Epoch 26, Batch 100: Total=0.0602 (MNIST=0.0450, USPS=0.0061, SVHN=0.0091)\n",
      "\n",
      "Epoch 26/100:\n",
      "  TRAIN Total=0.0566 (MNIST=0.0399, USPS=0.0061, SVHN=0.0107)\n",
      "  VAL   Total=0.0569 (MNIST=0.0415, USPS=0.0065, SVHN=0.0089)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 27, Batch 0: Total=0.0569 (MNIST=0.0399, USPS=0.0070, SVHN=0.0101)\n",
      "Epoch 27, Batch 100: Total=0.0580 (MNIST=0.0401, USPS=0.0071, SVHN=0.0108)\n",
      "\n",
      "Epoch 27/100:\n",
      "  TRAIN Total=0.0569 (MNIST=0.0401, USPS=0.0061, SVHN=0.0108)\n",
      "  VAL   Total=0.0560 (MNIST=0.0410, USPS=0.0068, SVHN=0.0083)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 28, Batch 0: Total=0.0574 (MNIST=0.0390, USPS=0.0052, SVHN=0.0133)\n",
      "Epoch 28, Batch 100: Total=0.0559 (MNIST=0.0405, USPS=0.0057, SVHN=0.0097)\n",
      "\n",
      "Epoch 28/100:\n",
      "  TRAIN Total=0.0563 (MNIST=0.0395, USPS=0.0059, SVHN=0.0110)\n",
      "  VAL   Total=0.0549 (MNIST=0.0397, USPS=0.0067, SVHN=0.0085)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 29, Batch 0: Total=0.0583 (MNIST=0.0420, USPS=0.0060, SVHN=0.0103)\n",
      "Epoch 29, Batch 100: Total=0.0517 (MNIST=0.0366, USPS=0.0056, SVHN=0.0094)\n",
      "\n",
      "Epoch 29/100:\n",
      "  TRAIN Total=0.0568 (MNIST=0.0396, USPS=0.0057, SVHN=0.0115)\n",
      "  VAL   Total=0.0581 (MNIST=0.0409, USPS=0.0078, SVHN=0.0094)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 30, Batch 0: Total=0.0646 (MNIST=0.0416, USPS=0.0061, SVHN=0.0169)\n",
      "Epoch 30, Batch 100: Total=0.0506 (MNIST=0.0361, USPS=0.0041, SVHN=0.0103)\n",
      "\n",
      "Epoch 30/100:\n",
      "  TRAIN Total=0.0554 (MNIST=0.0388, USPS=0.0057, SVHN=0.0110)\n",
      "  VAL   Total=0.0558 (MNIST=0.0400, USPS=0.0062, SVHN=0.0095)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 31, Batch 0: Total=0.0482 (MNIST=0.0348, USPS=0.0048, SVHN=0.0086)\n",
      "Epoch 31, Batch 100: Total=0.0599 (MNIST=0.0412, USPS=0.0062, SVHN=0.0124)\n",
      "\n",
      "Epoch 31/100:\n",
      "  TRAIN Total=0.0552 (MNIST=0.0385, USPS=0.0057, SVHN=0.0110)\n",
      "  VAL   Total=0.0553 (MNIST=0.0395, USPS=0.0059, SVHN=0.0100)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 32, Batch 0: Total=0.0553 (MNIST=0.0402, USPS=0.0057, SVHN=0.0094)\n",
      "Epoch 32, Batch 100: Total=0.0577 (MNIST=0.0434, USPS=0.0057, SVHN=0.0086)\n",
      "\n",
      "Epoch 32/100:\n",
      "  TRAIN Total=0.0553 (MNIST=0.0388, USPS=0.0056, SVHN=0.0109)\n",
      "  VAL   Total=0.0574 (MNIST=0.0381, USPS=0.0071, SVHN=0.0123)\n",
      "  No improvement for 4 epoch(s)\n",
      "Epoch 33, Batch 0: Total=0.0556 (MNIST=0.0408, USPS=0.0054, SVHN=0.0094)\n",
      "Epoch 33, Batch 100: Total=0.0544 (MNIST=0.0385, USPS=0.0046, SVHN=0.0113)\n",
      "\n",
      "Epoch 33/100:\n",
      "  TRAIN Total=0.0541 (MNIST=0.0382, USPS=0.0055, SVHN=0.0104)\n",
      "  VAL   Total=0.0532 (MNIST=0.0382, USPS=0.0059, SVHN=0.0091)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 34, Batch 0: Total=0.0513 (MNIST=0.0372, USPS=0.0044, SVHN=0.0098)\n",
      "Epoch 34, Batch 100: Total=0.0587 (MNIST=0.0382, USPS=0.0054, SVHN=0.0151)\n",
      "\n",
      "Epoch 34/100:\n",
      "  TRAIN Total=0.0542 (MNIST=0.0383, USPS=0.0056, SVHN=0.0103)\n",
      "  VAL   Total=0.0631 (MNIST=0.0475, USPS=0.0072, SVHN=0.0085)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 35, Batch 0: Total=0.0545 (MNIST=0.0357, USPS=0.0061, SVHN=0.0127)\n",
      "Epoch 35, Batch 100: Total=0.0526 (MNIST=0.0357, USPS=0.0066, SVHN=0.0104)\n",
      "\n",
      "Epoch 35/100:\n",
      "  TRAIN Total=0.0550 (MNIST=0.0380, USPS=0.0053, SVHN=0.0117)\n",
      "  VAL   Total=0.0571 (MNIST=0.0387, USPS=0.0058, SVHN=0.0126)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 36, Batch 0: Total=0.0585 (MNIST=0.0394, USPS=0.0048, SVHN=0.0143)\n",
      "Epoch 36, Batch 100: Total=0.0523 (MNIST=0.0378, USPS=0.0057, SVHN=0.0088)\n",
      "\n",
      "Epoch 36/100:\n",
      "  TRAIN Total=0.0536 (MNIST=0.0377, USPS=0.0052, SVHN=0.0107)\n",
      "  VAL   Total=0.0568 (MNIST=0.0388, USPS=0.0059, SVHN=0.0121)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 37, Batch 0: Total=0.0575 (MNIST=0.0390, USPS=0.0068, SVHN=0.0118)\n",
      "Epoch 37, Batch 100: Total=0.0580 (MNIST=0.0420, USPS=0.0059, SVHN=0.0102)\n",
      "\n",
      "Epoch 37/100:\n",
      "  TRAIN Total=0.0526 (MNIST=0.0376, USPS=0.0053, SVHN=0.0098)\n",
      "  VAL   Total=0.0532 (MNIST=0.0367, USPS=0.0059, SVHN=0.0107)\n",
      "  No improvement for 4 epoch(s)\n",
      "Epoch 38, Batch 0: Total=0.0471 (MNIST=0.0344, USPS=0.0050, SVHN=0.0077)\n",
      "Epoch 38, Batch 100: Total=0.0517 (MNIST=0.0373, USPS=0.0044, SVHN=0.0100)\n",
      "\n",
      "Epoch 38/100:\n",
      "  TRAIN Total=0.0531 (MNIST=0.0371, USPS=0.0052, SVHN=0.0108)\n",
      "  VAL   Total=0.0526 (MNIST=0.0378, USPS=0.0061, SVHN=0.0087)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 39, Batch 0: Total=0.0511 (MNIST=0.0370, USPS=0.0048, SVHN=0.0092)\n",
      "Epoch 39, Batch 100: Total=0.0569 (MNIST=0.0392, USPS=0.0066, SVHN=0.0111)\n",
      "\n",
      "Epoch 39/100:\n",
      "  TRAIN Total=0.0525 (MNIST=0.0373, USPS=0.0051, SVHN=0.0102)\n",
      "  VAL   Total=0.0558 (MNIST=0.0387, USPS=0.0067, SVHN=0.0104)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 40, Batch 0: Total=0.0525 (MNIST=0.0387, USPS=0.0047, SVHN=0.0091)\n",
      "Epoch 40, Batch 100: Total=0.0488 (MNIST=0.0352, USPS=0.0047, SVHN=0.0088)\n",
      "\n",
      "Epoch 40/100:\n",
      "  TRAIN Total=0.0531 (MNIST=0.0375, USPS=0.0051, SVHN=0.0105)\n",
      "  VAL   Total=0.0541 (MNIST=0.0369, USPS=0.0061, SVHN=0.0111)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 41, Batch 0: Total=0.0511 (MNIST=0.0359, USPS=0.0055, SVHN=0.0097)\n",
      "Epoch 41, Batch 100: Total=0.0500 (MNIST=0.0349, USPS=0.0046, SVHN=0.0105)\n",
      "\n",
      "Epoch 41/100:\n",
      "  TRAIN Total=0.0520 (MNIST=0.0372, USPS=0.0050, SVHN=0.0098)\n",
      "  VAL   Total=0.0518 (MNIST=0.0375, USPS=0.0056, SVHN=0.0087)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 42, Batch 0: Total=0.0509 (MNIST=0.0377, USPS=0.0040, SVHN=0.0093)\n",
      "Epoch 42, Batch 100: Total=0.0494 (MNIST=0.0353, USPS=0.0048, SVHN=0.0093)\n",
      "\n",
      "Epoch 42/100:\n",
      "  TRAIN Total=0.0525 (MNIST=0.0371, USPS=0.0050, SVHN=0.0104)\n",
      "  VAL   Total=0.0526 (MNIST=0.0384, USPS=0.0056, SVHN=0.0086)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 43, Batch 0: Total=0.0522 (MNIST=0.0368, USPS=0.0054, SVHN=0.0100)\n",
      "Epoch 43, Batch 100: Total=0.0531 (MNIST=0.0387, USPS=0.0049, SVHN=0.0095)\n",
      "\n",
      "Epoch 43/100:\n",
      "  TRAIN Total=0.0509 (MNIST=0.0367, USPS=0.0050, SVHN=0.0093)\n",
      "  VAL   Total=0.0512 (MNIST=0.0379, USPS=0.0055, SVHN=0.0078)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 44, Batch 0: Total=0.0499 (MNIST=0.0352, USPS=0.0059, SVHN=0.0089)\n",
      "Epoch 44, Batch 100: Total=0.0474 (MNIST=0.0359, USPS=0.0042, SVHN=0.0073)\n",
      "\n",
      "Epoch 44/100:\n",
      "  TRAIN Total=0.0506 (MNIST=0.0366, USPS=0.0049, SVHN=0.0092)\n",
      "  VAL   Total=0.0614 (MNIST=0.0450, USPS=0.0085, SVHN=0.0080)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 45, Batch 0: Total=0.0519 (MNIST=0.0359, USPS=0.0056, SVHN=0.0104)\n",
      "Epoch 45, Batch 100: Total=0.0470 (MNIST=0.0342, USPS=0.0048, SVHN=0.0079)\n",
      "\n",
      "Epoch 45/100:\n",
      "  TRAIN Total=0.0506 (MNIST=0.0364, USPS=0.0050, SVHN=0.0092)\n",
      "  VAL   Total=0.0529 (MNIST=0.0381, USPS=0.0051, SVHN=0.0097)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 46, Batch 0: Total=0.0515 (MNIST=0.0375, USPS=0.0049, SVHN=0.0091)\n",
      "Epoch 46, Batch 100: Total=0.0448 (MNIST=0.0320, USPS=0.0047, SVHN=0.0081)\n",
      "\n",
      "Epoch 46/100:\n",
      "  TRAIN Total=0.0507 (MNIST=0.0365, USPS=0.0049, SVHN=0.0092)\n",
      "  VAL   Total=0.0512 (MNIST=0.0371, USPS=0.0052, SVHN=0.0089)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 47, Batch 0: Total=0.0493 (MNIST=0.0360, USPS=0.0044, SVHN=0.0089)\n",
      "Epoch 47, Batch 100: Total=0.0518 (MNIST=0.0380, USPS=0.0047, SVHN=0.0091)\n",
      "\n",
      "Epoch 47/100:\n",
      "  TRAIN Total=0.0497 (MNIST=0.0362, USPS=0.0048, SVHN=0.0087)\n",
      "  VAL   Total=0.0482 (MNIST=0.0353, USPS=0.0051, SVHN=0.0078)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 48, Batch 0: Total=0.0475 (MNIST=0.0342, USPS=0.0045, SVHN=0.0088)\n",
      "Epoch 48, Batch 100: Total=0.0503 (MNIST=0.0389, USPS=0.0045, SVHN=0.0070)\n",
      "\n",
      "Epoch 48/100:\n",
      "  TRAIN Total=0.0502 (MNIST=0.0361, USPS=0.0048, SVHN=0.0093)\n",
      "  VAL   Total=0.0538 (MNIST=0.0385, USPS=0.0062, SVHN=0.0091)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 49, Batch 0: Total=0.0486 (MNIST=0.0352, USPS=0.0044, SVHN=0.0089)\n",
      "Epoch 49, Batch 100: Total=0.0485 (MNIST=0.0372, USPS=0.0047, SVHN=0.0067)\n",
      "\n",
      "Epoch 49/100:\n",
      "  TRAIN Total=0.0505 (MNIST=0.0362, USPS=0.0047, SVHN=0.0095)\n",
      "  VAL   Total=0.0503 (MNIST=0.0369, USPS=0.0052, SVHN=0.0082)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 50, Batch 0: Total=0.0477 (MNIST=0.0345, USPS=0.0049, SVHN=0.0083)\n",
      "Epoch 50, Batch 100: Total=0.0500 (MNIST=0.0330, USPS=0.0038, SVHN=0.0132)\n",
      "\n",
      "Epoch 50/100:\n",
      "  TRAIN Total=0.0497 (MNIST=0.0358, USPS=0.0047, SVHN=0.0092)\n",
      "  VAL   Total=0.0494 (MNIST=0.0365, USPS=0.0052, SVHN=0.0077)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 51, Batch 0: Total=0.0559 (MNIST=0.0391, USPS=0.0051, SVHN=0.0117)\n",
      "Epoch 51, Batch 100: Total=0.0512 (MNIST=0.0369, USPS=0.0046, SVHN=0.0096)\n",
      "\n",
      "Epoch 51/100:\n",
      "  TRAIN Total=0.0500 (MNIST=0.0367, USPS=0.0047, SVHN=0.0086)\n",
      "  VAL   Total=0.0509 (MNIST=0.0373, USPS=0.0058, SVHN=0.0078)\n",
      "  No improvement for 4 epoch(s)\n",
      "Epoch 52, Batch 0: Total=0.0486 (MNIST=0.0362, USPS=0.0050, SVHN=0.0074)\n",
      "Epoch 52, Batch 100: Total=0.0468 (MNIST=0.0341, USPS=0.0041, SVHN=0.0086)\n",
      "\n",
      "Epoch 52/100:\n",
      "  TRAIN Total=0.0503 (MNIST=0.0360, USPS=0.0046, SVHN=0.0097)\n",
      "  VAL   Total=0.0497 (MNIST=0.0369, USPS=0.0053, SVHN=0.0074)\n",
      "  No improvement for 5 epoch(s)\n",
      "Epoch 53, Batch 0: Total=0.0513 (MNIST=0.0396, USPS=0.0047, SVHN=0.0070)\n",
      "Epoch 53, Batch 100: Total=0.0491 (MNIST=0.0347, USPS=0.0041, SVHN=0.0102)\n",
      "\n",
      "Epoch 53/100:\n",
      "  TRAIN Total=0.0485 (MNIST=0.0352, USPS=0.0047, SVHN=0.0086)\n",
      "  VAL   Total=0.0481 (MNIST=0.0356, USPS=0.0050, SVHN=0.0076)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 54, Batch 0: Total=0.0471 (MNIST=0.0332, USPS=0.0037, SVHN=0.0101)\n",
      "Epoch 54, Batch 100: Total=0.0471 (MNIST=0.0353, USPS=0.0043, SVHN=0.0076)\n",
      "\n",
      "Epoch 54/100:\n",
      "  TRAIN Total=0.0483 (MNIST=0.0355, USPS=0.0046, SVHN=0.0082)\n",
      "  VAL   Total=0.0485 (MNIST=0.0365, USPS=0.0052, SVHN=0.0069)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 55, Batch 0: Total=0.0457 (MNIST=0.0343, USPS=0.0040, SVHN=0.0074)\n",
      "Epoch 55, Batch 100: Total=0.0467 (MNIST=0.0339, USPS=0.0040, SVHN=0.0089)\n",
      "\n",
      "Epoch 55/100:\n",
      "  TRAIN Total=0.0484 (MNIST=0.0355, USPS=0.0046, SVHN=0.0083)\n",
      "  VAL   Total=0.0503 (MNIST=0.0359, USPS=0.0054, SVHN=0.0090)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 56, Batch 0: Total=0.0415 (MNIST=0.0300, USPS=0.0040, SVHN=0.0075)\n",
      "Epoch 56, Batch 100: Total=0.0541 (MNIST=0.0368, USPS=0.0055, SVHN=0.0119)\n",
      "\n",
      "Epoch 56/100:\n",
      "  TRAIN Total=0.0498 (MNIST=0.0358, USPS=0.0046, SVHN=0.0094)\n",
      "  VAL   Total=0.0573 (MNIST=0.0401, USPS=0.0078, SVHN=0.0094)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 57, Batch 0: Total=0.0516 (MNIST=0.0374, USPS=0.0050, SVHN=0.0093)\n",
      "Epoch 57, Batch 100: Total=0.0466 (MNIST=0.0352, USPS=0.0047, SVHN=0.0067)\n",
      "\n",
      "Epoch 57/100:\n",
      "  TRAIN Total=0.0482 (MNIST=0.0351, USPS=0.0044, SVHN=0.0087)\n",
      "  VAL   Total=0.0469 (MNIST=0.0346, USPS=0.0048, SVHN=0.0075)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 58, Batch 0: Total=0.0456 (MNIST=0.0332, USPS=0.0041, SVHN=0.0083)\n",
      "Epoch 58, Batch 100: Total=0.0470 (MNIST=0.0341, USPS=0.0050, SVHN=0.0079)\n",
      "\n",
      "Epoch 58/100:\n",
      "  TRAIN Total=0.0487 (MNIST=0.0356, USPS=0.0045, SVHN=0.0086)\n",
      "  VAL   Total=0.0486 (MNIST=0.0349, USPS=0.0051, SVHN=0.0085)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 59, Batch 0: Total=0.0510 (MNIST=0.0368, USPS=0.0048, SVHN=0.0094)\n",
      "Epoch 59, Batch 100: Total=0.0519 (MNIST=0.0363, USPS=0.0060, SVHN=0.0096)\n",
      "\n",
      "Epoch 59/100:\n",
      "  TRAIN Total=0.0489 (MNIST=0.0355, USPS=0.0046, SVHN=0.0088)\n",
      "  VAL   Total=0.0482 (MNIST=0.0347, USPS=0.0051, SVHN=0.0084)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 60, Batch 0: Total=0.0529 (MNIST=0.0375, USPS=0.0042, SVHN=0.0112)\n",
      "Epoch 60, Batch 100: Total=0.0490 (MNIST=0.0359, USPS=0.0051, SVHN=0.0080)\n",
      "\n",
      "Epoch 60/100:\n",
      "  TRAIN Total=0.0474 (MNIST=0.0349, USPS=0.0045, SVHN=0.0080)\n",
      "  VAL   Total=0.0625 (MNIST=0.0480, USPS=0.0060, SVHN=0.0085)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 61, Batch 0: Total=0.0504 (MNIST=0.0368, USPS=0.0044, SVHN=0.0093)\n",
      "Epoch 61, Batch 100: Total=0.0427 (MNIST=0.0294, USPS=0.0035, SVHN=0.0097)\n",
      "\n",
      "Epoch 61/100:\n",
      "  TRAIN Total=0.0480 (MNIST=0.0351, USPS=0.0043, SVHN=0.0086)\n",
      "  VAL   Total=0.0489 (MNIST=0.0361, USPS=0.0048, SVHN=0.0080)\n",
      "  No improvement for 4 epoch(s)\n",
      "Epoch 62, Batch 0: Total=0.0503 (MNIST=0.0322, USPS=0.0058, SVHN=0.0123)\n",
      "Epoch 62, Batch 100: Total=0.0484 (MNIST=0.0344, USPS=0.0047, SVHN=0.0092)\n",
      "\n",
      "Epoch 62/100:\n",
      "  TRAIN Total=0.0473 (MNIST=0.0343, USPS=0.0045, SVHN=0.0086)\n",
      "  VAL   Total=0.0490 (MNIST=0.0364, USPS=0.0052, SVHN=0.0074)\n",
      "  No improvement for 5 epoch(s)\n",
      "Epoch 63, Batch 0: Total=0.0490 (MNIST=0.0330, USPS=0.0054, SVHN=0.0106)\n",
      "Epoch 63, Batch 100: Total=0.0499 (MNIST=0.0382, USPS=0.0034, SVHN=0.0083)\n",
      "  → LR reduced: 1.00e-03 → 5.00e-04\n",
      "\n",
      "Epoch 63/100:\n",
      "  TRAIN Total=0.0481 (MNIST=0.0350, USPS=0.0044, SVHN=0.0087)\n",
      "  VAL   Total=0.0495 (MNIST=0.0369, USPS=0.0053, SVHN=0.0072)\n",
      "  No improvement for 6 epoch(s)\n",
      "Epoch 64, Batch 0: Total=0.0482 (MNIST=0.0353, USPS=0.0048, SVHN=0.0081)\n",
      "Epoch 64, Batch 100: Total=0.0447 (MNIST=0.0333, USPS=0.0038, SVHN=0.0076)\n",
      "\n",
      "Epoch 64/100:\n",
      "  TRAIN Total=0.0446 (MNIST=0.0333, USPS=0.0038, SVHN=0.0075)\n",
      "  VAL   Total=0.0428 (MNIST=0.0323, USPS=0.0041, SVHN=0.0063)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 65, Batch 0: Total=0.0401 (MNIST=0.0306, USPS=0.0034, SVHN=0.0062)\n",
      "Epoch 65, Batch 100: Total=0.0453 (MNIST=0.0345, USPS=0.0033, SVHN=0.0075)\n",
      "\n",
      "Epoch 65/100:\n",
      "  TRAIN Total=0.0432 (MNIST=0.0324, USPS=0.0036, SVHN=0.0072)\n",
      "  VAL   Total=0.0440 (MNIST=0.0332, USPS=0.0042, SVHN=0.0065)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 66, Batch 0: Total=0.0468 (MNIST=0.0364, USPS=0.0034, SVHN=0.0071)\n",
      "Epoch 66, Batch 100: Total=0.0431 (MNIST=0.0308, USPS=0.0038, SVHN=0.0085)\n",
      "\n",
      "Epoch 66/100:\n",
      "  TRAIN Total=0.0434 (MNIST=0.0324, USPS=0.0037, SVHN=0.0074)\n",
      "  VAL   Total=0.0430 (MNIST=0.0323, USPS=0.0040, SVHN=0.0068)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 67, Batch 0: Total=0.0535 (MNIST=0.0343, USPS=0.0036, SVHN=0.0156)\n",
      "Epoch 67, Batch 100: Total=0.0420 (MNIST=0.0323, USPS=0.0037, SVHN=0.0060)\n",
      "\n",
      "Epoch 67/100:\n",
      "  TRAIN Total=0.0439 (MNIST=0.0327, USPS=0.0036, SVHN=0.0076)\n",
      "  VAL   Total=0.0418 (MNIST=0.0315, USPS=0.0038, SVHN=0.0065)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 68, Batch 0: Total=0.0514 (MNIST=0.0388, USPS=0.0032, SVHN=0.0095)\n",
      "Epoch 68, Batch 100: Total=0.0438 (MNIST=0.0314, USPS=0.0036, SVHN=0.0088)\n",
      "\n",
      "Epoch 68/100:\n",
      "  TRAIN Total=0.0433 (MNIST=0.0322, USPS=0.0036, SVHN=0.0075)\n",
      "  VAL   Total=0.0432 (MNIST=0.0325, USPS=0.0043, SVHN=0.0064)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 69, Batch 0: Total=0.0411 (MNIST=0.0305, USPS=0.0032, SVHN=0.0074)\n",
      "Epoch 69, Batch 100: Total=0.0425 (MNIST=0.0310, USPS=0.0033, SVHN=0.0082)\n",
      "\n",
      "Epoch 69/100:\n",
      "  TRAIN Total=0.0430 (MNIST=0.0321, USPS=0.0035, SVHN=0.0075)\n",
      "  VAL   Total=0.0424 (MNIST=0.0324, USPS=0.0038, SVHN=0.0062)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 70, Batch 0: Total=0.0392 (MNIST=0.0285, USPS=0.0041, SVHN=0.0066)\n",
      "Epoch 70, Batch 100: Total=0.0403 (MNIST=0.0287, USPS=0.0039, SVHN=0.0077)\n",
      "\n",
      "Epoch 70/100:\n",
      "  TRAIN Total=0.0429 (MNIST=0.0320, USPS=0.0036, SVHN=0.0072)\n",
      "  VAL   Total=0.0430 (MNIST=0.0324, USPS=0.0041, SVHN=0.0064)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 71, Batch 0: Total=0.0412 (MNIST=0.0298, USPS=0.0038, SVHN=0.0075)\n",
      "Epoch 71, Batch 100: Total=0.0432 (MNIST=0.0324, USPS=0.0040, SVHN=0.0068)\n",
      "\n",
      "Epoch 71/100:\n",
      "  TRAIN Total=0.0426 (MNIST=0.0318, USPS=0.0036, SVHN=0.0071)\n",
      "  VAL   Total=0.0440 (MNIST=0.0334, USPS=0.0039, SVHN=0.0067)\n",
      "  No improvement for 4 epoch(s)\n",
      "Epoch 72, Batch 0: Total=0.0421 (MNIST=0.0308, USPS=0.0040, SVHN=0.0074)\n",
      "Epoch 72, Batch 100: Total=0.0450 (MNIST=0.0346, USPS=0.0040, SVHN=0.0064)\n",
      "\n",
      "Epoch 72/100:\n",
      "  TRAIN Total=0.0431 (MNIST=0.0322, USPS=0.0036, SVHN=0.0073)\n",
      "  VAL   Total=0.0415 (MNIST=0.0313, USPS=0.0041, SVHN=0.0061)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 73, Batch 0: Total=0.0400 (MNIST=0.0315, USPS=0.0030, SVHN=0.0054)\n",
      "Epoch 73, Batch 100: Total=0.0416 (MNIST=0.0302, USPS=0.0032, SVHN=0.0081)\n",
      "\n",
      "Epoch 73/100:\n",
      "  TRAIN Total=0.0431 (MNIST=0.0325, USPS=0.0035, SVHN=0.0071)\n",
      "  VAL   Total=0.0424 (MNIST=0.0315, USPS=0.0043, SVHN=0.0066)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 74, Batch 0: Total=0.0473 (MNIST=0.0362, USPS=0.0033, SVHN=0.0079)\n",
      "Epoch 74, Batch 100: Total=0.0429 (MNIST=0.0305, USPS=0.0055, SVHN=0.0069)\n",
      "\n",
      "Epoch 74/100:\n",
      "  TRAIN Total=0.0429 (MNIST=0.0320, USPS=0.0036, SVHN=0.0073)\n",
      "  VAL   Total=0.0443 (MNIST=0.0323, USPS=0.0054, SVHN=0.0066)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 75, Batch 0: Total=0.0426 (MNIST=0.0329, USPS=0.0027, SVHN=0.0069)\n",
      "Epoch 75, Batch 100: Total=0.0441 (MNIST=0.0326, USPS=0.0033, SVHN=0.0082)\n",
      "\n",
      "Epoch 75/100:\n",
      "  TRAIN Total=0.0430 (MNIST=0.0322, USPS=0.0037, SVHN=0.0072)\n",
      "  VAL   Total=0.0432 (MNIST=0.0324, USPS=0.0042, SVHN=0.0066)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 76, Batch 0: Total=0.0451 (MNIST=0.0319, USPS=0.0038, SVHN=0.0094)\n",
      "Epoch 76, Batch 100: Total=0.0416 (MNIST=0.0309, USPS=0.0031, SVHN=0.0075)\n",
      "\n",
      "Epoch 76/100:\n",
      "  TRAIN Total=0.0428 (MNIST=0.0319, USPS=0.0035, SVHN=0.0074)\n",
      "  VAL   Total=0.0420 (MNIST=0.0321, USPS=0.0037, SVHN=0.0061)\n",
      "  No improvement for 4 epoch(s)\n",
      "Epoch 77, Batch 0: Total=0.0456 (MNIST=0.0309, USPS=0.0033, SVHN=0.0114)\n",
      "Epoch 77, Batch 100: Total=0.0451 (MNIST=0.0342, USPS=0.0033, SVHN=0.0076)\n",
      "\n",
      "Epoch 77/100:\n",
      "  TRAIN Total=0.0435 (MNIST=0.0325, USPS=0.0036, SVHN=0.0074)\n",
      "  VAL   Total=0.0428 (MNIST=0.0326, USPS=0.0037, SVHN=0.0065)\n",
      "  No improvement for 5 epoch(s)\n",
      "Epoch 78, Batch 0: Total=0.0423 (MNIST=0.0337, USPS=0.0026, SVHN=0.0060)\n",
      "Epoch 78, Batch 100: Total=0.0434 (MNIST=0.0324, USPS=0.0030, SVHN=0.0080)\n",
      "  → LR reduced: 5.00e-04 → 2.50e-04\n",
      "\n",
      "Epoch 78/100:\n",
      "  TRAIN Total=0.0431 (MNIST=0.0323, USPS=0.0034, SVHN=0.0074)\n",
      "  VAL   Total=0.0428 (MNIST=0.0328, USPS=0.0036, SVHN=0.0063)\n",
      "  No improvement for 6 epoch(s)\n",
      "Epoch 79, Batch 0: Total=0.0455 (MNIST=0.0324, USPS=0.0034, SVHN=0.0098)\n",
      "Epoch 79, Batch 100: Total=0.0375 (MNIST=0.0277, USPS=0.0029, SVHN=0.0069)\n",
      "\n",
      "Epoch 79/100:\n",
      "  TRAIN Total=0.0410 (MNIST=0.0309, USPS=0.0032, SVHN=0.0069)\n",
      "  VAL   Total=0.0404 (MNIST=0.0312, USPS=0.0033, SVHN=0.0059)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 80, Batch 0: Total=0.0409 (MNIST=0.0284, USPS=0.0031, SVHN=0.0094)\n",
      "Epoch 80, Batch 100: Total=0.0389 (MNIST=0.0293, USPS=0.0027, SVHN=0.0068)\n",
      "\n",
      "Epoch 80/100:\n",
      "  TRAIN Total=0.0404 (MNIST=0.0305, USPS=0.0032, SVHN=0.0068)\n",
      "  VAL   Total=0.0401 (MNIST=0.0305, USPS=0.0035, SVHN=0.0061)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 81, Batch 0: Total=0.0380 (MNIST=0.0294, USPS=0.0032, SVHN=0.0053)\n",
      "Epoch 81, Batch 100: Total=0.0438 (MNIST=0.0334, USPS=0.0033, SVHN=0.0070)\n",
      "\n",
      "Epoch 81/100:\n",
      "  TRAIN Total=0.0403 (MNIST=0.0307, USPS=0.0031, SVHN=0.0066)\n",
      "  VAL   Total=0.0440 (MNIST=0.0339, USPS=0.0039, SVHN=0.0061)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 82, Batch 0: Total=0.0421 (MNIST=0.0295, USPS=0.0029, SVHN=0.0097)\n",
      "Epoch 82, Batch 100: Total=0.0411 (MNIST=0.0307, USPS=0.0032, SVHN=0.0072)\n",
      "\n",
      "Epoch 82/100:\n",
      "  TRAIN Total=0.0403 (MNIST=0.0306, USPS=0.0031, SVHN=0.0067)\n",
      "  VAL   Total=0.0399 (MNIST=0.0304, USPS=0.0036, SVHN=0.0060)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 83, Batch 0: Total=0.0390 (MNIST=0.0294, USPS=0.0030, SVHN=0.0066)\n",
      "Epoch 83, Batch 100: Total=0.0429 (MNIST=0.0337, USPS=0.0026, SVHN=0.0067)\n",
      "\n",
      "Epoch 83/100:\n",
      "  TRAIN Total=0.0401 (MNIST=0.0306, USPS=0.0030, SVHN=0.0065)\n",
      "  VAL   Total=0.0397 (MNIST=0.0303, USPS=0.0035, SVHN=0.0059)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 84, Batch 0: Total=0.0409 (MNIST=0.0296, USPS=0.0030, SVHN=0.0084)\n",
      "Epoch 84, Batch 100: Total=0.0449 (MNIST=0.0351, USPS=0.0032, SVHN=0.0067)\n",
      "\n",
      "Epoch 84/100:\n",
      "  TRAIN Total=0.0402 (MNIST=0.0304, USPS=0.0030, SVHN=0.0067)\n",
      "  VAL   Total=0.0394 (MNIST=0.0301, USPS=0.0032, SVHN=0.0060)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 85, Batch 0: Total=0.0365 (MNIST=0.0280, USPS=0.0027, SVHN=0.0059)\n",
      "Epoch 85, Batch 100: Total=0.0426 (MNIST=0.0336, USPS=0.0027, SVHN=0.0062)\n",
      "\n",
      "Epoch 85/100:\n",
      "  TRAIN Total=0.0403 (MNIST=0.0304, USPS=0.0031, SVHN=0.0068)\n",
      "  VAL   Total=0.0388 (MNIST=0.0294, USPS=0.0033, SVHN=0.0061)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 86, Batch 0: Total=0.0386 (MNIST=0.0294, USPS=0.0035, SVHN=0.0057)\n",
      "Epoch 86, Batch 100: Total=0.0393 (MNIST=0.0287, USPS=0.0028, SVHN=0.0078)\n",
      "\n",
      "Epoch 86/100:\n",
      "  TRAIN Total=0.0401 (MNIST=0.0304, USPS=0.0030, SVHN=0.0067)\n",
      "  VAL   Total=0.0392 (MNIST=0.0297, USPS=0.0034, SVHN=0.0061)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 87, Batch 0: Total=0.0410 (MNIST=0.0308, USPS=0.0028, SVHN=0.0073)\n",
      "Epoch 87, Batch 100: Total=0.0364 (MNIST=0.0264, USPS=0.0026, SVHN=0.0073)\n",
      "\n",
      "Epoch 87/100:\n",
      "  TRAIN Total=0.0400 (MNIST=0.0303, USPS=0.0031, SVHN=0.0066)\n",
      "  VAL   Total=0.0393 (MNIST=0.0301, USPS=0.0034, SVHN=0.0058)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 88, Batch 0: Total=0.0364 (MNIST=0.0290, USPS=0.0023, SVHN=0.0051)\n",
      "Epoch 88, Batch 100: Total=0.0381 (MNIST=0.0288, USPS=0.0029, SVHN=0.0064)\n",
      "\n",
      "Epoch 88/100:\n",
      "  TRAIN Total=0.0401 (MNIST=0.0305, USPS=0.0030, SVHN=0.0065)\n",
      "  VAL   Total=0.0400 (MNIST=0.0309, USPS=0.0036, SVHN=0.0055)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 89, Batch 0: Total=0.0443 (MNIST=0.0336, USPS=0.0027, SVHN=0.0080)\n",
      "Epoch 89, Batch 100: Total=0.0358 (MNIST=0.0274, USPS=0.0030, SVHN=0.0054)\n",
      "\n",
      "Epoch 89/100:\n",
      "  TRAIN Total=0.0397 (MNIST=0.0301, USPS=0.0031, SVHN=0.0066)\n",
      "  VAL   Total=0.0394 (MNIST=0.0300, USPS=0.0035, SVHN=0.0059)\n",
      "  No improvement for 4 epoch(s)\n",
      "Epoch 90, Batch 0: Total=0.0412 (MNIST=0.0310, USPS=0.0036, SVHN=0.0065)\n",
      "Epoch 90, Batch 100: Total=0.0382 (MNIST=0.0292, USPS=0.0030, SVHN=0.0059)\n",
      "\n",
      "Epoch 90/100:\n",
      "  TRAIN Total=0.0399 (MNIST=0.0304, USPS=0.0031, SVHN=0.0065)\n",
      "  VAL   Total=0.0397 (MNIST=0.0307, USPS=0.0032, SVHN=0.0058)\n",
      "  No improvement for 5 epoch(s)\n",
      "Epoch 91, Batch 0: Total=0.0418 (MNIST=0.0318, USPS=0.0028, SVHN=0.0071)\n",
      "Epoch 91, Batch 100: Total=0.0375 (MNIST=0.0275, USPS=0.0032, SVHN=0.0069)\n",
      "\n",
      "Epoch 91/100:\n",
      "  TRAIN Total=0.0396 (MNIST=0.0300, USPS=0.0030, SVHN=0.0065)\n",
      "  VAL   Total=0.0387 (MNIST=0.0294, USPS=0.0035, SVHN=0.0058)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 92, Batch 0: Total=0.0394 (MNIST=0.0286, USPS=0.0032, SVHN=0.0077)\n",
      "Epoch 92, Batch 100: Total=0.0383 (MNIST=0.0294, USPS=0.0028, SVHN=0.0061)\n",
      "\n",
      "Epoch 92/100:\n",
      "  TRAIN Total=0.0395 (MNIST=0.0298, USPS=0.0031, SVHN=0.0066)\n",
      "  VAL   Total=0.0386 (MNIST=0.0293, USPS=0.0034, SVHN=0.0060)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 93, Batch 0: Total=0.0385 (MNIST=0.0300, USPS=0.0030, SVHN=0.0055)\n",
      "Epoch 93, Batch 100: Total=0.0402 (MNIST=0.0317, USPS=0.0027, SVHN=0.0058)\n",
      "\n",
      "Epoch 93/100:\n",
      "  TRAIN Total=0.0399 (MNIST=0.0304, USPS=0.0030, SVHN=0.0065)\n",
      "  VAL   Total=0.0393 (MNIST=0.0300, USPS=0.0036, SVHN=0.0057)\n",
      "  No improvement for 1 epoch(s)\n",
      "Epoch 94, Batch 0: Total=0.0422 (MNIST=0.0342, USPS=0.0025, SVHN=0.0055)\n",
      "Epoch 94, Batch 100: Total=0.0377 (MNIST=0.0295, USPS=0.0025, SVHN=0.0057)\n",
      "\n",
      "Epoch 94/100:\n",
      "  TRAIN Total=0.0400 (MNIST=0.0305, USPS=0.0030, SVHN=0.0066)\n",
      "  VAL   Total=0.0388 (MNIST=0.0299, USPS=0.0034, SVHN=0.0054)\n",
      "  No improvement for 2 epoch(s)\n",
      "Epoch 95, Batch 0: Total=0.0432 (MNIST=0.0338, USPS=0.0030, SVHN=0.0064)\n",
      "Epoch 95, Batch 100: Total=0.0364 (MNIST=0.0274, USPS=0.0032, SVHN=0.0058)\n",
      "\n",
      "Epoch 95/100:\n",
      "  TRAIN Total=0.0400 (MNIST=0.0303, USPS=0.0030, SVHN=0.0066)\n",
      "  VAL   Total=0.0407 (MNIST=0.0315, USPS=0.0032, SVHN=0.0060)\n",
      "  No improvement for 3 epoch(s)\n",
      "Epoch 96, Batch 0: Total=0.0404 (MNIST=0.0301, USPS=0.0034, SVHN=0.0069)\n",
      "Epoch 96, Batch 100: Total=0.0365 (MNIST=0.0273, USPS=0.0031, SVHN=0.0062)\n",
      "\n",
      "Epoch 96/100:\n",
      "  TRAIN Total=0.0398 (MNIST=0.0303, USPS=0.0029, SVHN=0.0066)\n",
      "  VAL   Total=0.0389 (MNIST=0.0296, USPS=0.0036, SVHN=0.0057)\n",
      "  No improvement for 4 epoch(s)\n",
      "Epoch 97, Batch 0: Total=0.0396 (MNIST=0.0305, USPS=0.0031, SVHN=0.0060)\n",
      "Epoch 97, Batch 100: Total=0.0388 (MNIST=0.0298, USPS=0.0026, SVHN=0.0065)\n",
      "\n",
      "Epoch 97/100:\n",
      "  TRAIN Total=0.0389 (MNIST=0.0297, USPS=0.0029, SVHN=0.0064)\n",
      "  VAL   Total=0.0395 (MNIST=0.0301, USPS=0.0034, SVHN=0.0060)\n",
      "  No improvement for 5 epoch(s)\n",
      "Epoch 98, Batch 0: Total=0.0414 (MNIST=0.0307, USPS=0.0035, SVHN=0.0072)\n",
      "Epoch 98, Batch 100: Total=0.0375 (MNIST=0.0298, USPS=0.0025, SVHN=0.0052)\n",
      "  → LR reduced: 2.50e-04 → 1.25e-04\n",
      "\n",
      "Epoch 98/100:\n",
      "  TRAIN Total=0.0394 (MNIST=0.0300, USPS=0.0029, SVHN=0.0065)\n",
      "  VAL   Total=0.0403 (MNIST=0.0303, USPS=0.0039, SVHN=0.0061)\n",
      "  No improvement for 6 epoch(s)\n",
      "Epoch 99, Batch 0: Total=0.0401 (MNIST=0.0303, USPS=0.0030, SVHN=0.0069)\n",
      "Epoch 99, Batch 100: Total=0.0408 (MNIST=0.0313, USPS=0.0022, SVHN=0.0073)\n",
      "\n",
      "Epoch 99/100:\n",
      "  TRAIN Total=0.0386 (MNIST=0.0296, USPS=0.0028, SVHN=0.0063)\n",
      "  VAL   Total=0.0376 (MNIST=0.0292, USPS=0.0029, SVHN=0.0055)\n",
      "  ✓ NEW BEST!\n",
      "Epoch 100, Batch 0: Total=0.0354 (MNIST=0.0273, USPS=0.0025, SVHN=0.0056)\n",
      "Epoch 100, Batch 100: Total=0.0381 (MNIST=0.0291, USPS=0.0024, SVHN=0.0066)\n",
      "\n",
      "Epoch 100/100:\n",
      "  TRAIN Total=0.0382 (MNIST=0.0292, USPS=0.0027, SVHN=0.0063)\n",
      "  VAL   Total=0.0377 (MNIST=0.0294, USPS=0.0030, SVHN=0.0054)\n",
      "  No improvement for 1 epoch(s)\n",
      "✓ Restored best model from epoch 99\n",
      "\n",
      "\n",
      "======================================================================\n",
      "STEP 2: THREE-WAY PROJECTION + ALIGNMENT\n",
      "======================================================================\n",
      "\n",
      "======================================================================\n",
      "STEP 2: THREE-WAY PROJECTION + ALIGNMENT (NO CLASSIFIER)\n",
      "Frozen shared encoder — Z space preserved\n",
      "Pairs aligned: MNIST↔USPS, MNIST↔SVHN, USPS↔SVHN\n",
      "======================================================================\n",
      "\n",
      "Epoch 1, Batch 0: Loss=3.4311 (Recon=0.9794, Cont=4.0761, Cent=0.8272)\n",
      "Epoch 1, Batch 100: Loss=1.2325 (Recon=0.0677, Cont=2.1008, Cent=0.2288)\n",
      "\n",
      "======================================================================\n",
      "Epoch 1/100:\n",
      "  TRAIN Total=1.6925 (Recon=0.0943, Cont=2.8367, Cent=0.3599)\n",
      "  VAL   Total=1.2815 (Recon=0.0690, Cont=2.1292, Cent=0.2959)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 2, Batch 0: Loss=1.3553 (Recon=0.0723, Cont=2.3008, Cent=0.2651)\n",
      "Epoch 2, Batch 100: Loss=1.0475 (Recon=0.0598, Cont=1.7675, Cent=0.2079)\n",
      "\n",
      "======================================================================\n",
      "Epoch 2/100:\n",
      "  TRAIN Total=1.0805 (Recon=0.0649, Cont=1.8209, Cent=0.2102)\n",
      "  VAL   Total=0.9974 (Recon=0.0616, Cont=1.6637, Cent=0.2078)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 3, Batch 0: Loss=0.8697 (Recon=0.0600, Cont=1.4591, Cent=0.1602)\n",
      "Epoch 3, Batch 100: Loss=0.7281 (Recon=0.0601, Cont=1.2260, Cent=0.1100)\n",
      "\n",
      "======================================================================\n",
      "Epoch 3/100:\n",
      "  TRAIN Total=0.9085 (Recon=0.0609, Cont=1.5357, Cent=0.1596)\n",
      "  VAL   Total=0.8859 (Recon=0.0591, Cont=1.4929, Cent=0.1607)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 4, Batch 0: Loss=0.7014 (Recon=0.0614, Cont=1.1636, Cent=0.1164)\n",
      "Epoch 4, Batch 100: Loss=0.7167 (Recon=0.0635, Cont=1.2077, Cent=0.0988)\n",
      "\n",
      "======================================================================\n",
      "Epoch 4/100:\n",
      "  TRAIN Total=0.8215 (Recon=0.0589, Cont=1.3956, Cent=0.1297)\n",
      "  VAL   Total=0.8415 (Recon=0.0583, Cont=1.4196, Cent=0.1468)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 5, Batch 0: Loss=0.8560 (Recon=0.0656, Cont=1.4444, Cent=0.1362)\n",
      "Epoch 5, Batch 100: Loss=0.7847 (Recon=0.0579, Cont=1.3317, Cent=0.1220)\n",
      "\n",
      "======================================================================\n",
      "Epoch 5/100:\n",
      "  TRAIN Total=0.7491 (Recon=0.0580, Cont=1.2740, Cent=0.1081)\n",
      "  VAL   Total=0.7819 (Recon=0.0559, Cont=1.3396, Cent=0.1124)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 6, Batch 0: Loss=0.6141 (Recon=0.0577, Cont=1.0177, Cent=0.0950)\n",
      "Epoch 6, Batch 100: Loss=0.8556 (Recon=0.0607, Cont=1.4611, Cent=0.1286)\n",
      "\n",
      "======================================================================\n",
      "Epoch 6/100:\n",
      "  TRAIN Total=0.6943 (Recon=0.0564, Cont=1.1829, Cent=0.0928)\n",
      "  VAL   Total=0.7255 (Recon=0.0559, Cont=1.2379, Cent=0.1014)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 7, Batch 0: Loss=0.7010 (Recon=0.0582, Cont=1.1807, Cent=0.1050)\n",
      "Epoch 7, Batch 100: Loss=0.6428 (Recon=0.0542, Cont=1.1086, Cent=0.0686)\n",
      "\n",
      "======================================================================\n",
      "Epoch 7/100:\n",
      "  TRAIN Total=0.6698 (Recon=0.0558, Cont=1.1471, Cent=0.0809)\n",
      "  VAL   Total=0.7192 (Recon=0.0548, Cont=1.2367, Cent=0.0922)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 8, Batch 0: Loss=0.7063 (Recon=0.0566, Cont=1.2382, Cent=0.0612)\n",
      "Epoch 8, Batch 100: Loss=0.5981 (Recon=0.0572, Cont=1.0265, Cent=0.0554)\n",
      "\n",
      "======================================================================\n",
      "Epoch 8/100:\n",
      "  TRAIN Total=0.6467 (Recon=0.0553, Cont=1.1105, Cent=0.0723)\n",
      "  VAL   Total=0.6623 (Recon=0.0555, Cont=1.1372, Cent=0.0763)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 9, Batch 0: Loss=0.4654 (Recon=0.0527, Cont=0.7596, Cent=0.0658)\n",
      "Epoch 9, Batch 100: Loss=0.6399 (Recon=0.0571, Cont=1.1009, Cent=0.0646)\n",
      "\n",
      "======================================================================\n",
      "Epoch 9/100:\n",
      "  TRAIN Total=0.6190 (Recon=0.0558, Cont=1.0626, Cent=0.0636)\n",
      "  VAL   Total=0.7054 (Recon=0.0543, Cont=1.2302, Cent=0.0720)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 10, Batch 0: Loss=0.6744 (Recon=0.0510, Cont=1.1897, Cent=0.0573)\n",
      "Epoch 10, Batch 100: Loss=0.4568 (Recon=0.0546, Cont=0.7625, Cent=0.0419)\n",
      "\n",
      "======================================================================\n",
      "Epoch 10/100:\n",
      "  TRAIN Total=0.5919 (Recon=0.0538, Cont=1.0206, Cent=0.0556)\n",
      "  VAL   Total=0.6699 (Recon=0.0529, Cont=1.1701, Cent=0.0640)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 11, Batch 0: Loss=0.5176 (Recon=0.0528, Cont=0.8851, Cent=0.0445)\n",
      "Epoch 11, Batch 100: Loss=0.6382 (Recon=0.0550, Cont=1.1256, Cent=0.0409)\n",
      "\n",
      "======================================================================\n",
      "Epoch 11/100:\n",
      "  TRAIN Total=0.5758 (Recon=0.0537, Cont=0.9947, Cent=0.0495)\n",
      "  VAL   Total=0.6182 (Recon=0.0555, Cont=1.0695, Cent=0.0559)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 12, Batch 0: Loss=0.4551 (Recon=0.0657, Cont=0.7325, Cent=0.0463)\n",
      "Epoch 12, Batch 100: Loss=0.5274 (Recon=0.0481, Cont=0.9217, Cent=0.0367)\n",
      "\n",
      "======================================================================\n",
      "Epoch 12/100:\n",
      "  TRAIN Total=0.5468 (Recon=0.0541, Cont=0.9430, Cent=0.0422)\n",
      "  VAL   Total=0.5758 (Recon=0.0524, Cont=1.0020, Cent=0.0447)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 13, Batch 0: Loss=0.5096 (Recon=0.0538, Cont=0.8661, Cent=0.0454)\n",
      "Epoch 13, Batch 100: Loss=0.4714 (Recon=0.0506, Cont=0.8136, Cent=0.0280)\n",
      "\n",
      "======================================================================\n",
      "Epoch 13/100:\n",
      "  TRAIN Total=0.5450 (Recon=0.0531, Cont=0.9445, Cent=0.0394)\n",
      "  VAL   Total=0.6368 (Recon=0.0527, Cont=1.1281, Cent=0.0401)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 14, Batch 0: Loss=0.4569 (Recon=0.0533, Cont=0.7703, Cent=0.0368)\n",
      "Epoch 14, Batch 100: Loss=0.4758 (Recon=0.0466, Cont=0.8231, Cent=0.0353)\n",
      "\n",
      "======================================================================\n",
      "Epoch 14/100:\n",
      "  TRAIN Total=0.5334 (Recon=0.0532, Cont=0.9258, Cent=0.0345)\n",
      "  VAL   Total=0.5897 (Recon=0.0513, Cont=1.0398, Cent=0.0371)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 15, Batch 0: Loss=0.4950 (Recon=0.0526, Cont=0.8551, Cent=0.0297)\n",
      "Epoch 15, Batch 100: Loss=0.4930 (Recon=0.0479, Cont=0.8625, Cent=0.0278)\n",
      "\n",
      "======================================================================\n",
      "Epoch 15/100:\n",
      "  TRAIN Total=0.5082 (Recon=0.0526, Cont=0.8807, Cent=0.0305)\n",
      "  VAL   Total=0.6595 (Recon=0.0529, Cont=1.1776, Cent=0.0357)\n",
      "  No improvement for 3 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 16, Batch 0: Loss=0.4875 (Recon=0.0526, Cont=0.8448, Cent=0.0250)\n",
      "Epoch 16, Batch 100: Loss=0.5274 (Recon=0.0542, Cont=0.9181, Cent=0.0284)\n",
      "\n",
      "======================================================================\n",
      "Epoch 16/100:\n",
      "  TRAIN Total=0.5085 (Recon=0.0523, Cont=0.8837, Cent=0.0286)\n",
      "  VAL   Total=0.6257 (Recon=0.0510, Cont=1.1147, Cent=0.0346)\n",
      "  No improvement for 4 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 17, Batch 0: Loss=0.5386 (Recon=0.0503, Cont=0.9398, Cent=0.0368)\n",
      "Epoch 17, Batch 100: Loss=0.6269 (Recon=0.0510, Cont=1.1239, Cent=0.0279)\n",
      "\n",
      "======================================================================\n",
      "Epoch 17/100:\n",
      "  TRAIN Total=0.5142 (Recon=0.0526, Cont=0.8980, Cent=0.0253)\n",
      "  VAL   Total=0.5935 (Recon=0.0523, Cont=1.0532, Cent=0.0290)\n",
      "  No improvement for 5 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 18, Batch 0: Loss=0.5565 (Recon=0.0631, Cont=0.9650, Cent=0.0218)\n",
      "Epoch 18, Batch 100: Loss=0.4496 (Recon=0.0513, Cont=0.7740, Cent=0.0225)\n",
      "\n",
      "======================================================================\n",
      "Epoch 18/100:\n",
      "  TRAIN Total=0.4875 (Recon=0.0526, Cont=0.8475, Cent=0.0223)\n",
      "  VAL   Total=0.5543 (Recon=0.0513, Cont=0.9798, Cent=0.0262)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 19, Batch 0: Loss=0.4071 (Recon=0.0492, Cont=0.6947, Cent=0.0211)\n",
      "Epoch 19, Batch 100: Loss=0.4387 (Recon=0.0482, Cont=0.7627, Cent=0.0183)\n",
      "\n",
      "======================================================================\n",
      "Epoch 19/100:\n",
      "  TRAIN Total=0.4840 (Recon=0.0524, Cont=0.8423, Cent=0.0209)\n",
      "  VAL   Total=0.5780 (Recon=0.0507, Cont=1.0300, Cent=0.0246)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 20, Batch 0: Loss=0.4368 (Recon=0.0558, Cont=0.7419, Cent=0.0200)\n",
      "Epoch 20, Batch 100: Loss=0.5633 (Recon=0.0550, Cont=1.0011, Cent=0.0155)\n",
      "\n",
      "======================================================================\n",
      "Epoch 20/100:\n",
      "  TRAIN Total=0.4758 (Recon=0.0525, Cont=0.8276, Cent=0.0190)\n",
      "  VAL   Total=0.5831 (Recon=0.0513, Cont=1.0418, Cent=0.0219)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 21, Batch 0: Loss=0.4993 (Recon=0.0520, Cont=0.8765, Cent=0.0181)\n",
      "Epoch 21, Batch 100: Loss=0.4533 (Recon=0.0520, Cont=0.7824, Cent=0.0203)\n",
      "\n",
      "======================================================================\n",
      "Epoch 21/100:\n",
      "  TRAIN Total=0.4807 (Recon=0.0526, Cont=0.8386, Cent=0.0177)\n",
      "  VAL   Total=0.5102 (Recon=0.0501, Cont=0.9003, Cent=0.0199)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 22, Batch 0: Loss=0.4179 (Recon=0.0536, Cont=0.7123, Cent=0.0163)\n",
      "Epoch 22, Batch 100: Loss=0.5147 (Recon=0.0520, Cont=0.9118, Cent=0.0136)\n",
      "\n",
      "======================================================================\n",
      "Epoch 22/100:\n",
      "  TRAIN Total=0.4640 (Recon=0.0514, Cont=0.8087, Cent=0.0165)\n",
      "  VAL   Total=0.5736 (Recon=0.0501, Cont=1.0279, Cent=0.0191)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 23, Batch 0: Loss=0.3514 (Recon=0.0477, Cont=0.5956, Cent=0.0117)\n",
      "Epoch 23, Batch 100: Loss=0.7534 (Recon=0.0525, Cont=1.3878, Cent=0.0141)\n",
      "\n",
      "======================================================================\n",
      "Epoch 23/100:\n",
      "  TRAIN Total=0.4557 (Recon=0.0513, Cont=0.7939, Cent=0.0149)\n",
      "  VAL   Total=0.4921 (Recon=0.0515, Cont=0.8640, Cent=0.0171)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 24, Batch 0: Loss=0.5074 (Recon=0.0557, Cont=0.8905, Cent=0.0129)\n",
      "Epoch 24, Batch 100: Loss=0.4102 (Recon=0.0503, Cont=0.7069, Cent=0.0129)\n",
      "\n",
      "======================================================================\n",
      "Epoch 24/100:\n",
      "  TRAIN Total=0.4533 (Recon=0.0514, Cont=0.7902, Cent=0.0135)\n",
      "  VAL   Total=0.5226 (Recon=0.0505, Cont=0.9301, Cent=0.0140)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 25, Batch 0: Loss=0.3242 (Recon=0.0510, Cont=0.5343, Cent=0.0122)\n",
      "Epoch 25, Batch 100: Loss=0.3638 (Recon=0.0495, Cont=0.6179, Cent=0.0107)\n",
      "\n",
      "======================================================================\n",
      "Epoch 25/100:\n",
      "  TRAIN Total=0.4607 (Recon=0.0510, Cont=0.8069, Cent=0.0125)\n",
      "  VAL   Total=0.5114 (Recon=0.0519, Cont=0.9044, Cent=0.0146)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 26, Batch 0: Loss=0.5798 (Recon=0.0512, Cont=1.0468, Cent=0.0104)\n",
      "Epoch 26, Batch 100: Loss=0.3390 (Recon=0.0475, Cont=0.5737, Cent=0.0092)\n",
      "\n",
      "======================================================================\n",
      "Epoch 26/100:\n",
      "  TRAIN Total=0.4442 (Recon=0.0517, Cont=0.7736, Cent=0.0115)\n",
      "  VAL   Total=0.5294 (Recon=0.0504, Cont=0.9445, Cent=0.0135)\n",
      "  No improvement for 3 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 27, Batch 0: Loss=0.2487 (Recon=0.0492, Cont=0.3867, Cent=0.0121)\n",
      "Epoch 27, Batch 100: Loss=0.3872 (Recon=0.0540, Cont=0.6542, Cent=0.0123)\n",
      "\n",
      "======================================================================\n",
      "Epoch 27/100:\n",
      "  TRAIN Total=0.4476 (Recon=0.0516, Cont=0.7808, Cent=0.0112)\n",
      "  VAL   Total=0.5724 (Recon=0.0509, Cont=1.0307, Cent=0.0123)\n",
      "  No improvement for 4 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 28, Batch 0: Loss=0.3473 (Recon=0.0486, Cont=0.5884, Cent=0.0090)\n",
      "Epoch 28, Batch 100: Loss=0.3877 (Recon=0.0523, Cont=0.6612, Cent=0.0097)\n",
      "\n",
      "======================================================================\n",
      "Epoch 28/100:\n",
      "  TRAIN Total=0.4261 (Recon=0.0514, Cont=0.7388, Cent=0.0106)\n",
      "  VAL   Total=0.5223 (Recon=0.0495, Cont=0.9346, Cent=0.0110)\n",
      "  No improvement for 5 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 29, Batch 0: Loss=0.3768 (Recon=0.0477, Cont=0.6505, Cent=0.0078)\n",
      "Epoch 29, Batch 100: Loss=0.4330 (Recon=0.0481, Cont=0.7616, Cent=0.0082)\n",
      "  → LR reduced: 1.00e-03 → 5.00e-04\n",
      "\n",
      "======================================================================\n",
      "Epoch 29/100:\n",
      "  TRAIN Total=0.4357 (Recon=0.0517, Cont=0.7585, Cent=0.0096)\n",
      "  VAL   Total=0.5317 (Recon=0.0488, Cont=0.9545, Cent=0.0113)\n",
      "  No improvement for 6 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 30, Batch 0: Loss=0.3263 (Recon=0.0568, Cont=0.5298, Cent=0.0091)\n",
      "Epoch 30, Batch 100: Loss=0.2154 (Recon=0.0423, Cont=0.3402, Cent=0.0061)\n",
      "\n",
      "======================================================================\n",
      "Epoch 30/100:\n",
      "  TRAIN Total=0.3935 (Recon=0.0479, Cont=0.6832, Cent=0.0081)\n",
      "  VAL   Total=0.4828 (Recon=0.0467, Cont=0.8629, Cent=0.0093)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 31, Batch 0: Loss=0.3855 (Recon=0.0528, Cont=0.6571, Cent=0.0083)\n",
      "Epoch 31, Batch 100: Loss=0.2977 (Recon=0.0468, Cont=0.4953, Cent=0.0066)\n",
      "\n",
      "======================================================================\n",
      "Epoch 31/100:\n",
      "  TRAIN Total=0.3665 (Recon=0.0473, Cont=0.6311, Cent=0.0073)\n",
      "  VAL   Total=0.4911 (Recon=0.0455, Cont=0.8831, Cent=0.0082)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 32, Batch 0: Loss=0.3646 (Recon=0.0427, Cont=0.6377, Cent=0.0061)\n",
      "Epoch 32, Batch 100: Loss=0.4135 (Recon=0.0458, Cont=0.7290, Cent=0.0063)\n",
      "\n",
      "======================================================================\n",
      "Epoch 32/100:\n",
      "  TRAIN Total=0.3618 (Recon=0.0470, Cont=0.6226, Cent=0.0069)\n",
      "  VAL   Total=0.4836 (Recon=0.0455, Cont=0.8682, Cent=0.0080)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 33, Batch 0: Loss=0.3170 (Recon=0.0495, Cont=0.5292, Cent=0.0058)\n",
      "Epoch 33, Batch 100: Loss=0.4453 (Recon=0.0490, Cont=0.7857, Cent=0.0067)\n",
      "\n",
      "======================================================================\n",
      "Epoch 33/100:\n",
      "  TRAIN Total=0.3613 (Recon=0.0474, Cont=0.6206, Cent=0.0071)\n",
      "  VAL   Total=0.4650 (Recon=0.0457, Cont=0.8309, Cent=0.0076)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 34, Batch 0: Loss=0.4106 (Recon=0.0436, Cont=0.7277, Cent=0.0063)\n",
      "Epoch 34, Batch 100: Loss=0.4670 (Recon=0.0579, Cont=0.8116, Cent=0.0067)\n",
      "\n",
      "======================================================================\n",
      "Epoch 34/100:\n",
      "  TRAIN Total=0.3497 (Recon=0.0479, Cont=0.5966, Cent=0.0069)\n",
      "  VAL   Total=0.4687 (Recon=0.0464, Cont=0.8363, Cent=0.0083)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 35, Batch 0: Loss=0.2907 (Recon=0.0460, Cont=0.4749, Cent=0.0145)\n",
      "Epoch 35, Batch 100: Loss=0.2448 (Recon=0.0457, Cont=0.3908, Cent=0.0073)\n",
      "\n",
      "======================================================================\n",
      "Epoch 35/100:\n",
      "  TRAIN Total=0.3508 (Recon=0.0478, Cont=0.5992, Cent=0.0068)\n",
      "  VAL   Total=0.4743 (Recon=0.0458, Cont=0.8485, Cent=0.0086)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 36, Batch 0: Loss=0.3965 (Recon=0.0468, Cont=0.6929, Cent=0.0064)\n",
      "Epoch 36, Batch 100: Loss=0.2672 (Recon=0.0463, Cont=0.4311, Cent=0.0108)\n",
      "\n",
      "======================================================================\n",
      "Epoch 36/100:\n",
      "  TRAIN Total=0.3585 (Recon=0.0470, Cont=0.6160, Cent=0.0068)\n",
      "  VAL   Total=0.4872 (Recon=0.0458, Cont=0.8739, Cent=0.0088)\n",
      "  No improvement for 3 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 37, Batch 0: Loss=0.5067 (Recon=0.0435, Cont=0.9211, Cent=0.0055)\n",
      "Epoch 37, Batch 100: Loss=0.2673 (Recon=0.0470, Cont=0.4356, Cent=0.0051)\n",
      "\n",
      "======================================================================\n",
      "Epoch 37/100:\n",
      "  TRAIN Total=0.3391 (Recon=0.0471, Cont=0.5776, Cent=0.0065)\n",
      "  VAL   Total=0.4690 (Recon=0.0459, Cont=0.8382, Cent=0.0079)\n",
      "  No improvement for 4 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 38, Batch 0: Loss=0.3471 (Recon=0.0446, Cont=0.5981, Cent=0.0068)\n",
      "Epoch 38, Batch 100: Loss=0.3684 (Recon=0.0444, Cont=0.6410, Cent=0.0070)\n",
      "\n",
      "======================================================================\n",
      "Epoch 38/100:\n",
      "  TRAIN Total=0.3339 (Recon=0.0470, Cont=0.5675, Cent=0.0062)\n",
      "  VAL   Total=0.4638 (Recon=0.0471, Cont=0.8259, Cent=0.0076)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 39, Batch 0: Loss=0.3003 (Recon=0.0457, Cont=0.5041, Cent=0.0051)\n",
      "Epoch 39, Batch 100: Loss=0.2409 (Recon=0.0493, Cont=0.3784, Cent=0.0048)\n",
      "\n",
      "======================================================================\n",
      "Epoch 39/100:\n",
      "  TRAIN Total=0.3486 (Recon=0.0477, Cont=0.5956, Cent=0.0061)\n",
      "  VAL   Total=0.4730 (Recon=0.0466, Cont=0.8458, Cent=0.0070)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 40, Batch 0: Loss=0.5294 (Recon=0.0510, Cont=0.9498, Cent=0.0070)\n",
      "Epoch 40, Batch 100: Loss=0.2462 (Recon=0.0440, Cont=0.3986, Cent=0.0058)\n",
      "\n",
      "======================================================================\n",
      "Epoch 40/100:\n",
      "  TRAIN Total=0.3492 (Recon=0.0468, Cont=0.5986, Cent=0.0061)\n",
      "  VAL   Total=0.4805 (Recon=0.0464, Cont=0.8612, Cent=0.0072)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 41, Batch 0: Loss=0.3826 (Recon=0.0486, Cont=0.6622, Cent=0.0058)\n",
      "Epoch 41, Batch 100: Loss=0.3135 (Recon=0.0462, Cont=0.5298, Cent=0.0047)\n",
      "\n",
      "======================================================================\n",
      "Epoch 41/100:\n",
      "  TRAIN Total=0.3193 (Recon=0.0475, Cont=0.5379, Cent=0.0056)\n",
      "  VAL   Total=0.4880 (Recon=0.0462, Cont=0.8771, Cent=0.0064)\n",
      "  No improvement for 3 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 42, Batch 0: Loss=0.3344 (Recon=0.0541, Cont=0.5546, Cent=0.0060)\n",
      "Epoch 42, Batch 100: Loss=0.3669 (Recon=0.0538, Cont=0.6211, Cent=0.0053)\n",
      "\n",
      "======================================================================\n",
      "Epoch 42/100:\n",
      "  TRAIN Total=0.3359 (Recon=0.0463, Cont=0.5735, Cent=0.0058)\n",
      "  VAL   Total=0.4640 (Recon=0.0451, Cont=0.8311, Cent=0.0068)\n",
      "  No improvement for 4 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 43, Batch 0: Loss=0.2419 (Recon=0.0457, Cont=0.3872, Cent=0.0054)\n",
      "Epoch 43, Batch 100: Loss=0.3525 (Recon=0.0527, Cont=0.5940, Cent=0.0056)\n",
      "\n",
      "======================================================================\n",
      "Epoch 43/100:\n",
      "  TRAIN Total=0.3328 (Recon=0.0471, Cont=0.5658, Cent=0.0056)\n",
      "  VAL   Total=0.4985 (Recon=0.0461, Cont=0.8980, Cent=0.0067)\n",
      "  No improvement for 5 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 44, Batch 0: Loss=0.3275 (Recon=0.0483, Cont=0.5530, Cent=0.0055)\n",
      "Epoch 44, Batch 100: Loss=0.3637 (Recon=0.0477, Cont=0.6264, Cent=0.0056)\n",
      "  → LR reduced: 5.00e-04 → 2.50e-04\n",
      "\n",
      "======================================================================\n",
      "Epoch 44/100:\n",
      "  TRAIN Total=0.3347 (Recon=0.0470, Cont=0.5700, Cent=0.0054)\n",
      "  VAL   Total=0.4853 (Recon=0.0468, Cont=0.8702, Cent=0.0068)\n",
      "  No improvement for 6 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 45, Batch 0: Loss=0.2993 (Recon=0.0555, Cont=0.4811, Cent=0.0064)\n",
      "Epoch 45, Batch 100: Loss=0.2848 (Recon=0.0410, Cont=0.4825, Cent=0.0050)\n",
      "\n",
      "======================================================================\n",
      "Epoch 45/100:\n",
      "  TRAIN Total=0.3070 (Recon=0.0450, Cont=0.5191, Cent=0.0050)\n",
      "  VAL   Total=0.4776 (Recon=0.0432, Cont=0.8629, Cent=0.0060)\n",
      "  No improvement for 7 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 46, Batch 0: Loss=0.2603 (Recon=0.0460, Cont=0.4243, Cent=0.0043)\n",
      "Epoch 46, Batch 100: Loss=0.2605 (Recon=0.0433, Cont=0.4302, Cent=0.0041)\n",
      "\n",
      "======================================================================\n",
      "Epoch 46/100:\n",
      "  TRAIN Total=0.2831 (Recon=0.0454, Cont=0.4706, Cent=0.0050)\n",
      "  VAL   Total=0.5038 (Recon=0.0444, Cont=0.9125, Cent=0.0062)\n",
      "  No improvement for 8 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 47, Batch 0: Loss=0.2415 (Recon=0.0449, Cont=0.3879, Cent=0.0054)\n",
      "Epoch 47, Batch 100: Loss=0.1765 (Recon=0.0476, Cont=0.2535, Cent=0.0044)\n",
      "\n",
      "======================================================================\n",
      "Epoch 47/100:\n",
      "  TRAIN Total=0.3019 (Recon=0.0445, Cont=0.5098, Cent=0.0049)\n",
      "  VAL   Total=0.4510 (Recon=0.0435, Cont=0.8095, Cent=0.0054)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 48, Batch 0: Loss=0.2043 (Recon=0.0476, Cont=0.3085, Cent=0.0050)\n",
      "Epoch 48, Batch 100: Loss=0.3240 (Recon=0.0422, Cont=0.5582, Cent=0.0055)\n",
      "\n",
      "======================================================================\n",
      "Epoch 48/100:\n",
      "  TRAIN Total=0.2903 (Recon=0.0448, Cont=0.4861, Cent=0.0048)\n",
      "  VAL   Total=0.4891 (Recon=0.0436, Cont=0.8852, Cent=0.0057)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 49, Batch 0: Loss=0.1693 (Recon=0.0429, Cont=0.2480, Cent=0.0048)\n",
      "Epoch 49, Batch 100: Loss=0.2896 (Recon=0.0414, Cont=0.4878, Cent=0.0086)\n",
      "\n",
      "======================================================================\n",
      "Epoch 49/100:\n",
      "  TRAIN Total=0.2851 (Recon=0.0442, Cont=0.4768, Cent=0.0050)\n",
      "  VAL   Total=0.4570 (Recon=0.0440, Cont=0.8200, Cent=0.0060)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 50, Batch 0: Loss=0.3614 (Recon=0.0469, Cont=0.6241, Cent=0.0049)\n",
      "Epoch 50, Batch 100: Loss=0.3223 (Recon=0.0437, Cont=0.5531, Cent=0.0040)\n",
      "\n",
      "======================================================================\n",
      "Epoch 50/100:\n",
      "  TRAIN Total=0.2825 (Recon=0.0445, Cont=0.4714, Cent=0.0047)\n",
      "  VAL   Total=0.4661 (Recon=0.0426, Cont=0.8418, Cent=0.0053)\n",
      "  No improvement for 3 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 51, Batch 0: Loss=0.2770 (Recon=0.0423, Cont=0.4645, Cent=0.0049)\n",
      "Epoch 51, Batch 100: Loss=0.3583 (Recon=0.0439, Cont=0.6245, Cent=0.0042)\n",
      "\n",
      "======================================================================\n",
      "Epoch 51/100:\n",
      "  TRAIN Total=0.2752 (Recon=0.0440, Cont=0.4580, Cent=0.0044)\n",
      "  VAL   Total=0.4835 (Recon=0.0425, Cont=0.8767, Cent=0.0054)\n",
      "  No improvement for 4 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 52, Batch 0: Loss=0.3166 (Recon=0.0424, Cont=0.5437, Cent=0.0046)\n",
      "Epoch 52, Batch 100: Loss=0.3676 (Recon=0.0468, Cont=0.6353, Cent=0.0063)\n",
      "\n",
      "======================================================================\n",
      "Epoch 52/100:\n",
      "  TRAIN Total=0.2816 (Recon=0.0444, Cont=0.4700, Cent=0.0044)\n",
      "  VAL   Total=0.4296 (Recon=0.0450, Cont=0.7642, Cent=0.0049)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 53, Batch 0: Loss=0.2011 (Recon=0.0446, Cont=0.3089, Cent=0.0041)\n",
      "Epoch 53, Batch 100: Loss=0.2488 (Recon=0.0449, Cont=0.4024, Cent=0.0053)\n",
      "\n",
      "======================================================================\n",
      "Epoch 53/100:\n",
      "  TRAIN Total=0.2759 (Recon=0.0443, Cont=0.4587, Cent=0.0044)\n",
      "  VAL   Total=0.4377 (Recon=0.0431, Cont=0.7842, Cent=0.0051)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 54, Batch 0: Loss=0.2079 (Recon=0.0416, Cont=0.3281, Cent=0.0045)\n",
      "Epoch 54, Batch 100: Loss=0.2399 (Recon=0.0413, Cont=0.3930, Cent=0.0042)\n",
      "\n",
      "======================================================================\n",
      "Epoch 54/100:\n",
      "  TRAIN Total=0.2791 (Recon=0.0445, Cont=0.4648, Cent=0.0043)\n",
      "  VAL   Total=0.4631 (Recon=0.0436, Cont=0.8341, Cent=0.0050)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 55, Batch 0: Loss=0.2458 (Recon=0.0414, Cont=0.4060, Cent=0.0030)\n",
      "Epoch 55, Batch 100: Loss=0.3676 (Recon=0.0410, Cont=0.6483, Cent=0.0049)\n",
      "\n",
      "======================================================================\n",
      "Epoch 55/100:\n",
      "  TRAIN Total=0.2679 (Recon=0.0438, Cont=0.4440, Cent=0.0042)\n",
      "  VAL   Total=0.4398 (Recon=0.0445, Cont=0.7854, Cent=0.0052)\n",
      "  No improvement for 3 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 56, Batch 0: Loss=0.2327 (Recon=0.0454, Cont=0.3705, Cent=0.0040)\n",
      "Epoch 56, Batch 100: Loss=0.2380 (Recon=0.0478, Cont=0.3774, Cent=0.0030)\n",
      "\n",
      "======================================================================\n",
      "Epoch 56/100:\n",
      "  TRAIN Total=0.2828 (Recon=0.0449, Cont=0.4715, Cent=0.0043)\n",
      "  VAL   Total=0.4482 (Recon=0.0443, Cont=0.8028, Cent=0.0049)\n",
      "  No improvement for 4 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 57, Batch 0: Loss=0.2528 (Recon=0.0415, Cont=0.4188, Cent=0.0037)\n",
      "Epoch 57, Batch 100: Loss=0.2861 (Recon=0.0467, Cont=0.4749, Cent=0.0040)\n",
      "\n",
      "======================================================================\n",
      "Epoch 57/100:\n",
      "  TRAIN Total=0.2705 (Recon=0.0447, Cont=0.4475, Cent=0.0041)\n",
      "  VAL   Total=0.4405 (Recon=0.0439, Cont=0.7879, Cent=0.0054)\n",
      "  No improvement for 5 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 58, Batch 0: Loss=0.2460 (Recon=0.0488, Cont=0.3899, Cent=0.0044)\n",
      "Epoch 58, Batch 100: Loss=0.3452 (Recon=0.0448, Cont=0.5970, Cent=0.0037)\n",
      "  → LR reduced: 2.50e-04 → 1.25e-04\n",
      "\n",
      "======================================================================\n",
      "Epoch 58/100:\n",
      "  TRAIN Total=0.2763 (Recon=0.0440, Cont=0.4604, Cent=0.0042)\n",
      "  VAL   Total=0.4350 (Recon=0.0431, Cont=0.7787, Cent=0.0050)\n",
      "  No improvement for 6 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 59, Batch 0: Loss=0.2446 (Recon=0.0387, Cont=0.4087, Cent=0.0031)\n",
      "Epoch 59, Batch 100: Loss=0.2945 (Recon=0.0429, Cont=0.4984, Cent=0.0048)\n",
      "\n",
      "======================================================================\n",
      "Epoch 59/100:\n",
      "  TRAIN Total=0.2630 (Recon=0.0419, Cont=0.4381, Cent=0.0041)\n",
      "  VAL   Total=0.4402 (Recon=0.0413, Cont=0.7930, Cent=0.0048)\n",
      "  No improvement for 7 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 60, Batch 0: Loss=0.3901 (Recon=0.0393, Cont=0.6988, Cent=0.0028)\n",
      "Epoch 60, Batch 100: Loss=0.2934 (Recon=0.0435, Cont=0.4962, Cent=0.0035)\n",
      "\n",
      "======================================================================\n",
      "Epoch 60/100:\n",
      "  TRAIN Total=0.2494 (Recon=0.0426, Cont=0.4095, Cent=0.0041)\n",
      "  VAL   Total=0.4177 (Recon=0.0420, Cont=0.7465, Cent=0.0048)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 61, Batch 0: Loss=0.1874 (Recon=0.0427, Cont=0.2848, Cent=0.0046)\n",
      "Epoch 61, Batch 100: Loss=0.1806 (Recon=0.0396, Cont=0.2777, Cent=0.0043)\n",
      "\n",
      "======================================================================\n",
      "Epoch 61/100:\n",
      "  TRAIN Total=0.2390 (Recon=0.0423, Cont=0.3894, Cent=0.0040)\n",
      "  VAL   Total=0.4064 (Recon=0.0413, Cont=0.7255, Cent=0.0048)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 62, Batch 0: Loss=0.2708 (Recon=0.0422, Cont=0.4529, Cent=0.0043)\n",
      "Epoch 62, Batch 100: Loss=0.1560 (Recon=0.0403, Cont=0.2255, Cent=0.0059)\n",
      "\n",
      "======================================================================\n",
      "Epoch 62/100:\n",
      "  TRAIN Total=0.2670 (Recon=0.0424, Cont=0.4450, Cent=0.0042)\n",
      "  VAL   Total=0.3925 (Recon=0.0421, Cont=0.6960, Cent=0.0049)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 63, Batch 0: Loss=0.3088 (Recon=0.0424, Cont=0.5291, Cent=0.0038)\n",
      "Epoch 63, Batch 100: Loss=0.3395 (Recon=0.0437, Cont=0.5876, Cent=0.0040)\n",
      "\n",
      "======================================================================\n",
      "Epoch 63/100:\n",
      "  TRAIN Total=0.2488 (Recon=0.0418, Cont=0.4104, Cent=0.0038)\n",
      "  VAL   Total=0.4078 (Recon=0.0419, Cont=0.7269, Cent=0.0049)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 64, Batch 0: Loss=0.1825 (Recon=0.0423, Cont=0.2763, Cent=0.0041)\n",
      "Epoch 64, Batch 100: Loss=0.2010 (Recon=0.0425, Cont=0.3125, Cent=0.0045)\n",
      "\n",
      "======================================================================\n",
      "Epoch 64/100:\n",
      "  TRAIN Total=0.2499 (Recon=0.0420, Cont=0.4118, Cent=0.0039)\n",
      "  VAL   Total=0.4515 (Recon=0.0424, Cont=0.8135, Cent=0.0049)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 65, Batch 0: Loss=0.1990 (Recon=0.0384, Cont=0.3173, Cent=0.0040)\n",
      "Epoch 65, Batch 100: Loss=0.3250 (Recon=0.0475, Cont=0.5501, Cent=0.0050)\n",
      "\n",
      "======================================================================\n",
      "Epoch 65/100:\n",
      "  TRAIN Total=0.2428 (Recon=0.0420, Cont=0.3978, Cent=0.0040)\n",
      "  VAL   Total=0.4612 (Recon=0.0419, Cont=0.8335, Cent=0.0051)\n",
      "  No improvement for 3 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 66, Batch 0: Loss=0.2537 (Recon=0.0424, Cont=0.4165, Cent=0.0060)\n",
      "Epoch 66, Batch 100: Loss=0.2084 (Recon=0.0441, Cont=0.3251, Cent=0.0036)\n",
      "\n",
      "======================================================================\n",
      "Epoch 66/100:\n",
      "  TRAIN Total=0.2335 (Recon=0.0424, Cont=0.3783, Cent=0.0038)\n",
      "  VAL   Total=0.4233 (Recon=0.0414, Cont=0.7589, Cent=0.0050)\n",
      "  No improvement for 4 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 67, Batch 0: Loss=0.2622 (Recon=0.0416, Cont=0.4374, Cent=0.0038)\n",
      "Epoch 67, Batch 100: Loss=0.3635 (Recon=0.0424, Cont=0.6387, Cent=0.0035)\n",
      "\n",
      "======================================================================\n",
      "Epoch 67/100:\n",
      "  TRAIN Total=0.2383 (Recon=0.0425, Cont=0.3878, Cent=0.0039)\n",
      "  VAL   Total=0.4516 (Recon=0.0424, Cont=0.8137, Cent=0.0045)\n",
      "  No improvement for 5 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 68, Batch 0: Loss=0.2560 (Recon=0.0374, Cont=0.4340, Cent=0.0032)\n",
      "Epoch 68, Batch 100: Loss=0.1970 (Recon=0.0447, Cont=0.3010, Cent=0.0036)\n",
      "  → LR reduced: 1.25e-04 → 6.25e-05\n",
      "\n",
      "======================================================================\n",
      "Epoch 68/100:\n",
      "  TRAIN Total=0.2279 (Recon=0.0416, Cont=0.3689, Cent=0.0038)\n",
      "  VAL   Total=0.4202 (Recon=0.0423, Cont=0.7511, Cent=0.0047)\n",
      "  No improvement for 6 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 69, Batch 0: Loss=0.5631 (Recon=0.0397, Cont=1.0433, Cent=0.0036)\n",
      "Epoch 69, Batch 100: Loss=0.2612 (Recon=0.0436, Cont=0.4315, Cent=0.0037)\n",
      "\n",
      "======================================================================\n",
      "Epoch 69/100:\n",
      "  TRAIN Total=0.2378 (Recon=0.0416, Cont=0.3886, Cent=0.0037)\n",
      "  VAL   Total=0.4121 (Recon=0.0417, Cont=0.7361, Cent=0.0047)\n",
      "  No improvement for 7 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 70, Batch 0: Loss=0.1733 (Recon=0.0418, Cont=0.2597, Cent=0.0033)\n",
      "Epoch 70, Batch 100: Loss=0.2985 (Recon=0.0471, Cont=0.4982, Cent=0.0047)\n",
      "\n",
      "======================================================================\n",
      "Epoch 70/100:\n",
      "  TRAIN Total=0.2311 (Recon=0.0412, Cont=0.3762, Cent=0.0037)\n",
      "  VAL   Total=0.4078 (Recon=0.0405, Cont=0.7302, Cent=0.0043)\n",
      "  No improvement for 8 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 71, Batch 0: Loss=0.2134 (Recon=0.0390, Cont=0.3449, Cent=0.0039)\n",
      "Epoch 71, Batch 100: Loss=0.3792 (Recon=0.0418, Cont=0.6704, Cent=0.0044)\n",
      "\n",
      "======================================================================\n",
      "Epoch 71/100:\n",
      "  TRAIN Total=0.2443 (Recon=0.0413, Cont=0.4021, Cent=0.0038)\n",
      "  VAL   Total=0.4120 (Recon=0.0412, Cont=0.7367, Cent=0.0049)\n",
      "  No improvement for 9 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 72, Batch 0: Loss=0.2496 (Recon=0.0399, Cont=0.4165, Cent=0.0029)\n",
      "Epoch 72, Batch 100: Loss=0.1649 (Recon=0.0402, Cont=0.2452, Cent=0.0041)\n",
      "\n",
      "======================================================================\n",
      "Epoch 72/100:\n",
      "  TRAIN Total=0.2310 (Recon=0.0412, Cont=0.3758, Cent=0.0037)\n",
      "  VAL   Total=0.4432 (Recon=0.0403, Cont=0.8009, Cent=0.0049)\n",
      "  No improvement for 10 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 73, Batch 0: Loss=0.3373 (Recon=0.0433, Cont=0.5848, Cent=0.0033)\n",
      "Epoch 73, Batch 100: Loss=0.3804 (Recon=0.0431, Cont=0.6716, Cent=0.0030)\n",
      "\n",
      "======================================================================\n",
      "Epoch 73/100:\n",
      "  TRAIN Total=0.2341 (Recon=0.0415, Cont=0.3815, Cent=0.0037)\n",
      "  VAL   Total=0.4527 (Recon=0.0407, Cont=0.8196, Cent=0.0044)\n",
      "  No improvement for 11 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 74, Batch 0: Loss=0.1636 (Recon=0.0401, Cont=0.2445, Cent=0.0026)\n",
      "Epoch 74, Batch 100: Loss=0.3107 (Recon=0.0435, Cont=0.5301, Cent=0.0043)\n",
      "  → LR reduced: 6.25e-05 → 3.13e-05\n",
      "\n",
      "======================================================================\n",
      "Epoch 74/100:\n",
      "  TRAIN Total=0.2410 (Recon=0.0411, Cont=0.3961, Cent=0.0036)\n",
      "  VAL   Total=0.4049 (Recon=0.0405, Cont=0.7240, Cent=0.0048)\n",
      "  No improvement for 12 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 75, Batch 0: Loss=0.1579 (Recon=0.0393, Cont=0.2347, Cent=0.0025)\n",
      "Epoch 75, Batch 100: Loss=0.2211 (Recon=0.0456, Cont=0.3476, Cent=0.0033)\n",
      "\n",
      "======================================================================\n",
      "Epoch 75/100:\n",
      "  TRAIN Total=0.2287 (Recon=0.0406, Cont=0.3725, Cent=0.0036)\n",
      "  VAL   Total=0.4234 (Recon=0.0405, Cont=0.7609, Cent=0.0048)\n",
      "  No improvement for 13 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 76, Batch 0: Loss=0.1368 (Recon=0.0413, Cont=0.1875, Cent=0.0036)\n",
      "Epoch 76, Batch 100: Loss=0.2140 (Recon=0.0399, Cont=0.3449, Cent=0.0032)\n",
      "\n",
      "======================================================================\n",
      "Epoch 76/100:\n",
      "  TRAIN Total=0.2266 (Recon=0.0409, Cont=0.3677, Cent=0.0035)\n",
      "  VAL   Total=0.3787 (Recon=0.0410, Cont=0.6715, Cent=0.0040)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 77, Batch 0: Loss=0.2432 (Recon=0.0399, Cont=0.4011, Cent=0.0056)\n",
      "Epoch 77, Batch 100: Loss=0.1640 (Recon=0.0402, Cont=0.2450, Cent=0.0025)\n",
      "\n",
      "======================================================================\n",
      "Epoch 77/100:\n",
      "  TRAIN Total=0.2201 (Recon=0.0406, Cont=0.3555, Cent=0.0036)\n",
      "  VAL   Total=0.4022 (Recon=0.0420, Cont=0.7159, Cent=0.0046)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 78, Batch 0: Loss=0.1885 (Recon=0.0424, Cont=0.2887, Cent=0.0036)\n",
      "Epoch 78, Batch 100: Loss=0.2027 (Recon=0.0390, Cont=0.3248, Cent=0.0028)\n",
      "\n",
      "======================================================================\n",
      "Epoch 78/100:\n",
      "  TRAIN Total=0.2348 (Recon=0.0404, Cont=0.3853, Cent=0.0035)\n",
      "  VAL   Total=0.4047 (Recon=0.0405, Cont=0.7243, Cent=0.0042)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 79, Batch 0: Loss=0.1842 (Recon=0.0393, Cont=0.2861, Cent=0.0039)\n",
      "Epoch 79, Batch 100: Loss=0.2692 (Recon=0.0390, Cont=0.4571, Cent=0.0035)\n",
      "\n",
      "======================================================================\n",
      "Epoch 79/100:\n",
      "  TRAIN Total=0.2301 (Recon=0.0408, Cont=0.3749, Cent=0.0035)\n",
      "  VAL   Total=0.4321 (Recon=0.0409, Cont=0.7776, Cent=0.0047)\n",
      "  No improvement for 3 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 80, Batch 0: Loss=0.2208 (Recon=0.0405, Cont=0.3571, Cent=0.0037)\n",
      "Epoch 80, Batch 100: Loss=0.1820 (Recon=0.0419, Cont=0.2775, Cent=0.0028)\n",
      "\n",
      "======================================================================\n",
      "Epoch 80/100:\n",
      "  TRAIN Total=0.2252 (Recon=0.0404, Cont=0.3660, Cent=0.0035)\n",
      "  VAL   Total=0.3923 (Recon=0.0396, Cont=0.7012, Cent=0.0044)\n",
      "  No improvement for 4 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 81, Batch 0: Loss=0.3990 (Recon=0.0440, Cont=0.7060, Cent=0.0039)\n",
      "Epoch 81, Batch 100: Loss=0.1627 (Recon=0.0405, Cont=0.2388, Cent=0.0058)\n",
      "\n",
      "======================================================================\n",
      "Epoch 81/100:\n",
      "  TRAIN Total=0.2252 (Recon=0.0409, Cont=0.3651, Cent=0.0036)\n",
      "  VAL   Total=0.4011 (Recon=0.0415, Cont=0.7145, Cent=0.0046)\n",
      "  No improvement for 5 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 82, Batch 0: Loss=0.1263 (Recon=0.0382, Cont=0.1726, Cent=0.0035)\n",
      "Epoch 82, Batch 100: Loss=0.2029 (Recon=0.0398, Cont=0.3215, Cent=0.0046)\n",
      "  → LR reduced: 3.13e-05 → 1.56e-05\n",
      "\n",
      "======================================================================\n",
      "Epoch 82/100:\n",
      "  TRAIN Total=0.2108 (Recon=0.0402, Cont=0.3377, Cent=0.0035)\n",
      "  VAL   Total=0.4004 (Recon=0.0403, Cont=0.7157, Cent=0.0045)\n",
      "  No improvement for 6 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 83, Batch 0: Loss=0.1925 (Recon=0.0391, Cont=0.3034, Cent=0.0035)\n",
      "Epoch 83, Batch 100: Loss=0.2632 (Recon=0.0388, Cont=0.4454, Cent=0.0033)\n",
      "\n",
      "======================================================================\n",
      "Epoch 83/100:\n",
      "  TRAIN Total=0.2256 (Recon=0.0406, Cont=0.3666, Cent=0.0035)\n",
      "  VAL   Total=0.4761 (Recon=0.0404, Cont=0.8669, Cent=0.0045)\n",
      "  No improvement for 7 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 84, Batch 0: Loss=0.2045 (Recon=0.0406, Cont=0.3243, Cent=0.0035)\n",
      "Epoch 84, Batch 100: Loss=0.2070 (Recon=0.0399, Cont=0.3306, Cent=0.0035)\n",
      "\n",
      "======================================================================\n",
      "Epoch 84/100:\n",
      "  TRAIN Total=0.2284 (Recon=0.0405, Cont=0.3723, Cent=0.0035)\n",
      "  VAL   Total=0.4424 (Recon=0.0395, Cont=0.8012, Cent=0.0047)\n",
      "  No improvement for 8 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 85, Batch 0: Loss=0.2014 (Recon=0.0428, Cont=0.3137, Cent=0.0035)\n",
      "Epoch 85, Batch 100: Loss=0.1686 (Recon=0.0447, Cont=0.2448, Cent=0.0028)\n",
      "\n",
      "======================================================================\n",
      "Epoch 85/100:\n",
      "  TRAIN Total=0.2196 (Recon=0.0403, Cont=0.3551, Cent=0.0035)\n",
      "  VAL   Total=0.3775 (Recon=0.0401, Cont=0.6703, Cent=0.0044)\n",
      "  ✓ NEW BEST!\n",
      "======================================================================\n",
      "\n",
      "Epoch 86, Batch 0: Loss=0.2435 (Recon=0.0433, Cont=0.3969, Cent=0.0034)\n",
      "Epoch 86, Batch 100: Loss=0.2857 (Recon=0.0419, Cont=0.4842, Cent=0.0032)\n",
      "\n",
      "======================================================================\n",
      "Epoch 86/100:\n",
      "  TRAIN Total=0.2199 (Recon=0.0401, Cont=0.3561, Cent=0.0036)\n",
      "  VAL   Total=0.4335 (Recon=0.0402, Cont=0.7824, Cent=0.0042)\n",
      "  No improvement for 1 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 87, Batch 0: Loss=0.1380 (Recon=0.0400, Cont=0.1932, Cent=0.0028)\n",
      "Epoch 87, Batch 100: Loss=0.2555 (Recon=0.0409, Cont=0.4258, Cent=0.0034)\n",
      "\n",
      "======================================================================\n",
      "Epoch 87/100:\n",
      "  TRAIN Total=0.2306 (Recon=0.0407, Cont=0.3762, Cent=0.0036)\n",
      "  VAL   Total=0.4123 (Recon=0.0399, Cont=0.7404, Cent=0.0045)\n",
      "  No improvement for 2 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 88, Batch 0: Loss=0.2118 (Recon=0.0368, Cont=0.3474, Cent=0.0027)\n",
      "Epoch 88, Batch 100: Loss=0.1423 (Recon=0.0388, Cont=0.2033, Cent=0.0036)\n",
      "\n",
      "======================================================================\n",
      "Epoch 88/100:\n",
      "  TRAIN Total=0.2128 (Recon=0.0403, Cont=0.3415, Cent=0.0035)\n",
      "  VAL   Total=0.4405 (Recon=0.0398, Cont=0.7968, Cent=0.0046)\n",
      "  No improvement for 3 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 89, Batch 0: Loss=0.2434 (Recon=0.0423, Cont=0.3997, Cent=0.0026)\n",
      "Epoch 89, Batch 100: Loss=0.1208 (Recon=0.0355, Cont=0.1668, Cent=0.0037)\n",
      "\n",
      "======================================================================\n",
      "Epoch 89/100:\n",
      "  TRAIN Total=0.2184 (Recon=0.0401, Cont=0.3531, Cent=0.0036)\n",
      "  VAL   Total=0.4338 (Recon=0.0403, Cont=0.7825, Cent=0.0045)\n",
      "  No improvement for 4 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 90, Batch 0: Loss=0.2621 (Recon=0.0423, Cont=0.4362, Cent=0.0033)\n",
      "Epoch 90, Batch 100: Loss=0.1949 (Recon=0.0440, Cont=0.2984, Cent=0.0034)\n",
      "\n",
      "======================================================================\n",
      "Epoch 90/100:\n",
      "  TRAIN Total=0.2229 (Recon=0.0404, Cont=0.3615, Cent=0.0035)\n",
      "  VAL   Total=0.4285 (Recon=0.0402, Cont=0.7721, Cent=0.0045)\n",
      "  No improvement for 5 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 91, Batch 0: Loss=0.1815 (Recon=0.0370, Cont=0.2865, Cent=0.0026)\n",
      "Epoch 91, Batch 100: Loss=0.3183 (Recon=0.0408, Cont=0.5516, Cent=0.0034)\n",
      "  → LR reduced: 1.56e-05 → 7.81e-06\n",
      "\n",
      "======================================================================\n",
      "Epoch 91/100:\n",
      "  TRAIN Total=0.2218 (Recon=0.0403, Cont=0.3595, Cent=0.0035)\n",
      "  VAL   Total=0.3989 (Recon=0.0411, Cont=0.7110, Cent=0.0044)\n",
      "  No improvement for 6 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 92, Batch 0: Loss=0.1283 (Recon=0.0392, Cont=0.1740, Cent=0.0041)\n",
      "Epoch 92, Batch 100: Loss=0.1824 (Recon=0.0382, Cont=0.2853, Cent=0.0031)\n",
      "\n",
      "======================================================================\n",
      "Epoch 92/100:\n",
      "  TRAIN Total=0.2065 (Recon=0.0403, Cont=0.3291, Cent=0.0034)\n",
      "  VAL   Total=0.4280 (Recon=0.0390, Cont=0.7736, Cent=0.0045)\n",
      "  No improvement for 7 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 93, Batch 0: Loss=0.1268 (Recon=0.0399, Cont=0.1710, Cent=0.0029)\n",
      "Epoch 93, Batch 100: Loss=0.2704 (Recon=0.0416, Cont=0.4542, Cent=0.0032)\n",
      "\n",
      "======================================================================\n",
      "Epoch 93/100:\n",
      "  TRAIN Total=0.2094 (Recon=0.0399, Cont=0.3356, Cent=0.0034)\n",
      "  VAL   Total=0.4194 (Recon=0.0414, Cont=0.7513, Cent=0.0047)\n",
      "  No improvement for 8 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 94, Batch 0: Loss=0.2470 (Recon=0.0386, Cont=0.4128, Cent=0.0041)\n",
      "Epoch 94, Batch 100: Loss=0.3273 (Recon=0.0393, Cont=0.5722, Cent=0.0040)\n",
      "\n",
      "======================================================================\n",
      "Epoch 94/100:\n",
      "  TRAIN Total=0.2290 (Recon=0.0403, Cont=0.3738, Cent=0.0036)\n",
      "  VAL   Total=0.4606 (Recon=0.0400, Cont=0.8365, Cent=0.0046)\n",
      "  No improvement for 9 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 95, Batch 0: Loss=0.1997 (Recon=0.0400, Cont=0.3159, Cent=0.0035)\n",
      "Epoch 95, Batch 100: Loss=0.2982 (Recon=0.0411, Cont=0.5114, Cent=0.0029)\n",
      "\n",
      "======================================================================\n",
      "Epoch 95/100:\n",
      "  TRAIN Total=0.2198 (Recon=0.0398, Cont=0.3564, Cent=0.0035)\n",
      "  VAL   Total=0.4378 (Recon=0.0395, Cont=0.7921, Cent=0.0047)\n",
      "  No improvement for 10 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 96, Batch 0: Loss=0.2117 (Recon=0.0354, Cont=0.3496, Cent=0.0031)\n",
      "Epoch 96, Batch 100: Loss=0.2745 (Recon=0.0382, Cont=0.4693, Cent=0.0034)\n",
      "\n",
      "======================================================================\n",
      "Epoch 96/100:\n",
      "  TRAIN Total=0.2121 (Recon=0.0399, Cont=0.3410, Cent=0.0035)\n",
      "  VAL   Total=0.4530 (Recon=0.0404, Cont=0.8205, Cent=0.0047)\n",
      "  No improvement for 11 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 97, Batch 0: Loss=0.1949 (Recon=0.0406, Cont=0.3052, Cent=0.0035)\n",
      "Epoch 97, Batch 100: Loss=0.2125 (Recon=0.0412, Cont=0.3389, Cent=0.0037)\n",
      "  → LR reduced: 7.81e-06 → 3.91e-06\n",
      "\n",
      "======================================================================\n",
      "Epoch 97/100:\n",
      "  TRAIN Total=0.2226 (Recon=0.0401, Cont=0.3616, Cent=0.0034)\n",
      "  VAL   Total=0.4108 (Recon=0.0398, Cont=0.7376, Cent=0.0042)\n",
      "  No improvement for 12 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 98, Batch 0: Loss=0.0982 (Recon=0.0449, Cont=0.1032, Cent=0.0034)\n",
      "Epoch 98, Batch 100: Loss=0.1257 (Recon=0.0391, Cont=0.1702, Cent=0.0029)\n",
      "\n",
      "======================================================================\n",
      "Epoch 98/100:\n",
      "  TRAIN Total=0.2140 (Recon=0.0403, Cont=0.3440, Cent=0.0035)\n",
      "  VAL   Total=0.4307 (Recon=0.0409, Cont=0.7751, Cent=0.0045)\n",
      "  No improvement for 13 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 99, Batch 0: Loss=0.1391 (Recon=0.0408, Cont=0.1926, Cent=0.0041)\n",
      "Epoch 99, Batch 100: Loss=0.2119 (Recon=0.0445, Cont=0.3322, Cent=0.0026)\n",
      "\n",
      "======================================================================\n",
      "Epoch 99/100:\n",
      "  TRAIN Total=0.2212 (Recon=0.0402, Cont=0.3587, Cent=0.0035)\n",
      "  VAL   Total=0.3946 (Recon=0.0400, Cont=0.7043, Cent=0.0048)\n",
      "  No improvement for 14 epoch(s)\n",
      "======================================================================\n",
      "\n",
      "Epoch 100, Batch 0: Loss=0.2932 (Recon=0.0416, Cont=0.4995, Cent=0.0038)\n",
      "Epoch 100, Batch 100: Loss=0.2076 (Recon=0.0388, Cont=0.3339, Cent=0.0036)\n",
      "\n",
      "======================================================================\n",
      "Epoch 100/100:\n",
      "  TRAIN Total=0.2211 (Recon=0.0399, Cont=0.3588, Cent=0.0034)\n",
      "  VAL   Total=0.4345 (Recon=0.0408, Cont=0.7833, Cent=0.0042)\n",
      "  No improvement for 15 epoch(s)\n",
      "\n",
      "EARLY STOPPING at epoch 100\n",
      "✓ Restored best model from epoch 85\n",
      "\n",
      "\n",
      "======================================================================\n",
      "STEP 3A: CLASSIFIER ON MNIST ONLY\n",
      "======================================================================\n",
      "======================================================================\n",
      "STEP 3: TRAINING CLASSIFIER ON MNIST ONLY\n",
      "Classifier never sees other domains\n",
      "======================================================================\n",
      "\n",
      "Epoch 1/50: Train=96.50%, Val=96.23%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 2/50: Train=96.79%, Val=96.90%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 3/50: Train=96.91%, Val=96.97%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 4/50: Train=96.97%, Val=97.13%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 5/50: Train=96.89%, Val=97.07%\n",
      "Epoch 6/50: Train=96.94%, Val=96.53%\n",
      "Epoch 7/50: Train=96.90%, Val=96.70%\n",
      "  → LR reduced: 1.00e-03 → 5.00e-04\n",
      "Epoch 8/50: Train=96.88%, Val=96.77%\n",
      "Epoch 9/50: Train=97.05%, Val=97.07%\n",
      "Epoch 10/50: Train=97.12%, Val=97.00%\n",
      "Epoch 11/50: Train=97.04%, Val=96.70%\n",
      "  → LR reduced: 5.00e-04 → 2.50e-04\n",
      "Epoch 12/50: Train=97.04%, Val=96.57%\n",
      "Epoch 13/50: Train=97.17%, Val=97.10%\n",
      "Epoch 14/50: Train=97.15%, Val=97.00%\n",
      "\n",
      "EARLY STOPPING at epoch 14\n",
      "✓ Restored best classifier from epoch 4\n",
      "\n",
      "\n",
      "======================================================================\n",
      "STEP 3B: CLASSIFIER ON USPS ONLY\n",
      "======================================================================\n",
      "======================================================================\n",
      "STEP 3: TRAINING CLASSIFIER ON USPS ONLY\n",
      "Classifier never sees other domains\n",
      "======================================================================\n",
      "\n",
      "Epoch 1/50: Train=98.31%, Val=98.35%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 2/50: Train=98.95%, Val=96.15%\n",
      "Epoch 3/50: Train=98.99%, Val=98.63%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 4/50: Train=99.22%, Val=96.70%\n",
      "Epoch 5/50: Train=99.10%, Val=96.70%\n",
      "Epoch 6/50: Train=99.18%, Val=98.35%\n",
      "  → LR reduced: 1.00e-03 → 5.00e-04\n",
      "Epoch 7/50: Train=99.09%, Val=96.98%\n",
      "Epoch 8/50: Train=99.27%, Val=98.35%\n",
      "Epoch 9/50: Train=99.02%, Val=97.80%\n",
      "Epoch 10/50: Train=99.25%, Val=97.53%\n",
      "  → LR reduced: 5.00e-04 → 2.50e-04\n",
      "Epoch 11/50: Train=99.42%, Val=98.35%\n",
      "Epoch 12/50: Train=99.34%, Val=98.63%\n",
      "Epoch 13/50: Train=99.31%, Val=98.08%\n",
      "\n",
      "EARLY STOPPING at epoch 13\n",
      "✓ Restored best classifier from epoch 3\n",
      "\n",
      "\n",
      "======================================================================\n",
      "STEP 3C: CLASSIFIER ON SVHN ONLY\n",
      "======================================================================\n",
      "======================================================================\n",
      "STEP 3: TRAINING CLASSIFIER ON SVHN ONLY\n",
      "Classifier never sees other domains\n",
      "======================================================================\n",
      "\n",
      "Epoch 1/50: Train=89.70%, Val=84.08%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 2/50: Train=90.86%, Val=84.79%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 3/50: Train=91.40%, Val=85.01%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 4/50: Train=91.50%, Val=85.14%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 5/50: Train=91.69%, Val=84.93%\n",
      "Epoch 6/50: Train=91.68%, Val=84.41%\n",
      "Epoch 7/50: Train=91.87%, Val=85.09%\n",
      "  → LR reduced: 1.00e-03 → 5.00e-04\n",
      "Epoch 8/50: Train=91.82%, Val=84.76%\n",
      "Epoch 9/50: Train=92.22%, Val=85.42%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 10/50: Train=92.01%, Val=85.01%\n",
      "Epoch 11/50: Train=92.38%, Val=84.98%\n",
      "Epoch 12/50: Train=92.38%, Val=85.01%\n",
      "Epoch 13/50: Train=92.47%, Val=85.58%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 14/50: Train=92.32%, Val=85.99%\n",
      "  ✓ NEW BEST!\n",
      "Epoch 15/50: Train=92.56%, Val=85.39%\n",
      "Epoch 16/50: Train=92.44%, Val=85.77%\n",
      "Epoch 17/50: Train=92.33%, Val=85.64%\n",
      "  → LR reduced: 5.00e-04 → 2.50e-04\n",
      "Epoch 18/50: Train=92.46%, Val=85.14%\n",
      "Epoch 19/50: Train=92.60%, Val=85.55%\n",
      "Epoch 20/50: Train=92.56%, Val=84.76%\n",
      "Epoch 21/50: Train=92.74%, Val=85.45%\n",
      "  → LR reduced: 2.50e-04 → 1.25e-04\n",
      "Epoch 22/50: Train=92.74%, Val=85.45%\n",
      "Epoch 23/50: Train=92.76%, Val=85.53%\n",
      "Epoch 24/50: Train=92.78%, Val=84.93%\n",
      "\n",
      "EARLY STOPPING at epoch 24\n",
      "✓ Restored best classifier from epoch 14\n",
      "\n",
      "\n",
      "======================================================================\n",
      "STEP 4: ALL NINE TRANSFER EVALUATIONS\n",
      "======================================================================\n",
      "\n",
      "\n",
      "======================================================================\n",
      "TRANSFER: MNIST Clf → MNIST\n",
      "======================================================================\n",
      "Overall Accuracy: 96.37%\n",
      "Per-Class:\n",
      "  Class 0: 97.98% (291/297)\n",
      "  Class 1: 99.42% (341/343)\n",
      "  Class 2: 98.68% (299/303)\n",
      "  Class 3: 92.93% (263/283)\n",
      "  Class 4: 95.76% (271/283)\n",
      "  Class 5: 96.06% (244/254)\n",
      "  Class 6: 96.53% (306/317)\n",
      "  Class 7: 97.01% (292/301)\n",
      "  Class 8: 93.97% (296/315)\n",
      "  Class 9: 94.74% (288/304)\n",
      "\n",
      "======================================================================\n",
      "TRANSFER: MNIST Clf → USPS (TRANSFER)\n",
      "======================================================================\n",
      "Overall Accuracy: 98.36%\n",
      "Per-Class:\n",
      "  Class 0: 100.00% (29/29)\n",
      "  Class 1: 98.41% (62/63)\n",
      "  Class 2: 100.00% (52/52)\n",
      "  Class 3: 100.00% (37/37)\n",
      "  Class 4: 90.48% (19/21)\n",
      "  Class 5: 97.73% (43/44)\n",
      "  Class 6: 95.83% (23/24)\n",
      "  Class 7: 97.22% (35/36)\n",
      "  Class 8: 100.00% (31/31)\n",
      "  Class 9: 100.00% (29/29)\n",
      "\n",
      "======================================================================\n",
      "TRANSFER: MNIST Clf → SVHN (TRANSFER)\n",
      "======================================================================\n",
      "Overall Accuracy: 86.46%\n",
      "Per-Class:\n",
      "  Class 0: 81.50% (207/254)\n",
      "  Class 1: 90.27% (640/709)\n",
      "  Class 2: 91.78% (480/523)\n",
      "  Class 3: 81.51% (357/438)\n",
      "  Class 4: 90.06% (326/362)\n",
      "  Class 5: 87.07% (303/348)\n",
      "  Class 6: 80.56% (232/288)\n",
      "  Class 7: 90.04% (235/261)\n",
      "  Class 8: 77.78% (203/261)\n",
      "  Class 9: 84.09% (185/220)\n",
      "\n",
      "======================================================================\n",
      "TRANSFER: USPS Clf → MNIST (TRANSFER)\n",
      "======================================================================\n",
      "Overall Accuracy: 96.10%\n",
      "Per-Class:\n",
      "  Class 0: 98.32% (292/297)\n",
      "  Class 1: 99.71% (342/343)\n",
      "  Class 2: 98.35% (298/303)\n",
      "  Class 3: 93.64% (265/283)\n",
      "  Class 4: 94.70% (268/283)\n",
      "  Class 5: 95.67% (243/254)\n",
      "  Class 6: 96.85% (307/317)\n",
      "  Class 7: 97.34% (293/301)\n",
      "  Class 8: 94.92% (299/315)\n",
      "  Class 9: 90.79% (276/304)\n",
      "\n",
      "======================================================================\n",
      "TRANSFER: USPS Clf → USPS\n",
      "======================================================================\n",
      "Overall Accuracy: 98.36%\n",
      "Per-Class:\n",
      "  Class 0: 100.00% (29/29)\n",
      "  Class 1: 98.41% (62/63)\n",
      "  Class 2: 100.00% (52/52)\n",
      "  Class 3: 100.00% (37/37)\n",
      "  Class 4: 95.24% (20/21)\n",
      "  Class 5: 100.00% (44/44)\n",
      "  Class 6: 95.83% (23/24)\n",
      "  Class 7: 100.00% (36/36)\n",
      "  Class 8: 100.00% (31/31)\n",
      "  Class 9: 89.66% (26/29)\n",
      "\n",
      "======================================================================\n",
      "TRANSFER: USPS Clf → SVHN (TRANSFER)\n",
      "======================================================================\n",
      "Overall Accuracy: 83.79%\n",
      "Per-Class:\n",
      "  Class 0: 82.28% (209/254)\n",
      "  Class 1: 92.10% (653/709)\n",
      "  Class 2: 87.19% (456/523)\n",
      "  Class 3: 76.03% (333/438)\n",
      "  Class 4: 86.19% (312/362)\n",
      "  Class 5: 77.87% (271/348)\n",
      "  Class 6: 82.64% (238/288)\n",
      "  Class 7: 85.06% (222/261)\n",
      "  Class 8: 75.10% (196/261)\n",
      "  Class 9: 81.82% (180/220)\n",
      "\n",
      "======================================================================\n",
      "TRANSFER: SVHN Clf → MNIST (TRANSFER)\n",
      "======================================================================\n",
      "Overall Accuracy: 97.40%\n",
      "Per-Class:\n",
      "  Class 0: 98.32% (292/297)\n",
      "  Class 1: 98.54% (338/343)\n",
      "  Class 2: 99.67% (302/303)\n",
      "  Class 3: 96.47% (273/283)\n",
      "  Class 4: 98.59% (279/283)\n",
      "  Class 5: 96.06% (244/254)\n",
      "  Class 6: 99.05% (314/317)\n",
      "  Class 7: 98.01% (295/301)\n",
      "  Class 8: 94.60% (298/315)\n",
      "  Class 9: 94.41% (287/304)\n",
      "\n",
      "======================================================================\n",
      "TRANSFER: SVHN Clf → USPS (TRANSFER)\n",
      "======================================================================\n",
      "Overall Accuracy: 97.81%\n",
      "Per-Class:\n",
      "  Class 0: 96.55% (28/29)\n",
      "  Class 1: 98.41% (62/63)\n",
      "  Class 2: 100.00% (52/52)\n",
      "  Class 3: 100.00% (37/37)\n",
      "  Class 4: 95.24% (20/21)\n",
      "  Class 5: 97.73% (43/44)\n",
      "  Class 6: 95.83% (23/24)\n",
      "  Class 7: 97.22% (35/36)\n",
      "  Class 8: 100.00% (31/31)\n",
      "  Class 9: 93.10% (27/29)\n",
      "\n",
      "======================================================================\n",
      "TRANSFER: SVHN Clf → SVHN\n",
      "======================================================================\n",
      "Overall Accuracy: 86.87%\n",
      "Per-Class:\n",
      "  Class 0: 88.19% (224/254)\n",
      "  Class 1: 89.56% (635/709)\n",
      "  Class 2: 90.63% (474/523)\n",
      "  Class 3: 85.62% (375/438)\n",
      "  Class 4: 88.67% (321/362)\n",
      "  Class 5: 82.47% (287/348)\n",
      "  Class 6: 80.90% (233/288)\n",
      "  Class 7: 90.80% (237/261)\n",
      "  Class 8: 82.76% (216/261)\n",
      "  Class 9: 82.27% (181/220)\n",
      "\n",
      "======================================================================\n",
      "FINAL RESULTS — THREE-DOMAIN UNIVERSAL TRANSFER\n",
      "All domains share semantic meaning: digits 0-9\n",
      "Shared encoder + Domain-Specific BN → same Z space\n",
      "======================================================================\n",
      "\n",
      "┌──────────────────┬──────────┬──────────┬──────────┐\n",
      "│ Classifier →     │  MNIST   │   USPS   │   SVHN   │\n",
      "│ Target Domain ↓  │          │          │          │\n",
      "├──────────────────┼──────────┼──────────┼──────────┤\n",
      "│ MNIST            │  96.37%  │  96.10%  │  97.40%  │\n",
      "│ USPS             │  98.36%  │  98.36%  │  97.81%  │\n",
      "│ SVHN             │  86.46%  │  83.79%  │  86.87%  │\n",
      "└──────────────────┴──────────┴──────────┴──────────┘\n",
      "Diagonal  = same domain baseline\n",
      "Off-diag  = zero-shot cross-domain transfer\n",
      "    \n",
      "✓ All results and models saved!\n",
      "  - three_domain_domainBN_results.json\n",
      "  - three_domain_domainBN_manifold.pth\n",
      "  - three_domain_domainBN_clf_mnist/usps/svhn.pth\n",
      "======================================================================\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader, Dataset, Subset, WeightedRandomSampler\n",
    "import numpy as np\n",
    "import json\n",
    "import os\n",
    "import urllib.request\n",
    "import bz2\n",
    "\n",
    "# ============================================================================\n",
    "# THREE-DOMAIN UNIVERSAL TRANSFER LEARNING\n",
    "# Domains: MNIST + USPS + SVHN (all digits 0-9, semantically meaningful)\n",
    "#\n",
    "# KEY DESIGN: Domain-Specific Batch Normalization\n",
    "# - Shared conv weights  → learn universal features across all domains\n",
    "# - Domain-specific BN   → normalize each domain's statistics separately\n",
    "# - Same Z space         → one universal latent manifold by construction\n",
    "#\n",
    "# This directly proves ULHM's claim:\n",
    "# \"A single universal Z space exists that represents all domains\n",
    "#  and enables transfer because all domains share the same topology.\"\n",
    "# ============================================================================\n",
    "\n",
    "# Integer domain ids; every per-domain BatchNorm ModuleList in the encoder is indexed by one of these.\n",
    "DOMAIN_MNIST, DOMAIN_USPS, DOMAIN_SVHN = 0, 1, 2\n",
    "NUM_DOMAINS = 3  # number of BN layers kept at each per-domain BN position\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 1. USPS MANUAL DOWNLOAD\n",
    "# Fetch the LIBSVM USPS files directly instead of via torchvision.datasets.USPS\n",
    "# ============================================================================\n",
    "\n",
    "class USPSDataset(Dataset):\n",
    "    \"\"\"\n",
    "    USPS handwritten digits (16×16 grayscale) read from the LIBSVM files.\n",
    "\n",
    "    Downloads the bz2-compressed LIBSVM files directly rather than going\n",
    "    through torchvision.datasets.USPS.\n",
    "\n",
    "    File format: one sample per line, `<label> <idx>:<val> ...`, with\n",
    "    1-based feature indices and pixel values in [-1, 1]. File labels are\n",
    "    1..10 and file label k is digit k-1 (the same mapping torchvision's\n",
    "    USPS loader applies to these files).\n",
    "\n",
    "    Args:\n",
    "        root:      directory under which a 'usps_manual' cache folder is made.\n",
    "        train:     True → training split, False → test split.\n",
    "        transform: optional callable applied to the [1, 16, 16] tensor.\n",
    "        download:  fetch the file when it is not cached yet.\n",
    "\n",
    "    Items: (img, label) — img is a float tensor [1, 16, 16] in [-1, 1]\n",
    "    before `transform`; label is an int digit 0-9.\n",
    "    \"\"\"\n",
    "    URLS = {\n",
    "        'train': 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2',\n",
    "        'test':  'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2'\n",
    "    }\n",
    "\n",
    "    def __init__(self, root, train=True, transform=None, download=True):\n",
    "        self.root = os.path.join(root, 'usps_manual')\n",
    "        self.train = train\n",
    "        self.transform = transform\n",
    "        os.makedirs(self.root, exist_ok=True)\n",
    "\n",
    "        split = 'train' if train else 'test'\n",
    "        filepath = os.path.join(self.root, f'usps_{split}.bz2')\n",
    "\n",
    "        if download and not os.path.exists(filepath):\n",
    "            print(f'Downloading USPS {split} from working mirror...')\n",
    "            # Download under a temporary name and rename only on success, so an\n",
    "            # interrupted download never leaves a truncated file that later runs\n",
    "            # would mistake for a valid cache.\n",
    "            tmp_path = filepath + '.part'\n",
    "            try:\n",
    "                urllib.request.urlretrieve(self.URLS[split], tmp_path)\n",
    "                os.replace(tmp_path, filepath)\n",
    "            finally:\n",
    "                if os.path.exists(tmp_path):\n",
    "                    os.remove(tmp_path)\n",
    "            print(f'✓ Downloaded USPS {split}')\n",
    "\n",
    "        self.data, self.labels = self._load_libsvm(filepath)\n",
    "\n",
    "    def _load_libsvm(self, filepath):\n",
    "        \"\"\"Parse a bz2 LIBSVM file.\n",
    "\n",
    "        Returns (data, labels): float32 [N, 16, 16] with pixels rescaled to\n",
    "        [0, 1], and int64 [N] digit labels 0-9.\n",
    "        Raises ValueError on a label outside the expected 1..10 range.\n",
    "        \"\"\"\n",
    "        data, labels = [], []\n",
    "        with bz2.open(filepath, 'rt') as f:\n",
    "            for line_no, line in enumerate(f, start=1):\n",
    "                parts = line.split()\n",
    "                if not parts:\n",
    "                    continue\n",
    "                # File labels 1..10 encode digits 0..9. Do not use `% 10` here:\n",
    "                # it maps digit d to (d + 1) % 10, shifting every USPS class\n",
    "                # relative to the MNIST/SVHN labels.\n",
    "                label = int(parts[0]) - 1\n",
    "                if not 0 <= label <= 9:\n",
    "                    raise ValueError(\n",
    "                        f'{filepath}:{line_no}: unexpected USPS label {parts[0]!r} '\n",
    "                        f'(expected 1-10)')\n",
    "                features = np.zeros(256, dtype=np.float32)\n",
    "                for feat in parts[1:]:\n",
    "                    if ':' in feat:\n",
    "                        idx, val = feat.split(':')\n",
    "                        features[int(idx) - 1] = float(val)\n",
    "                labels.append(label)\n",
    "                data.append((features.reshape(16, 16) + 1.0) / 2.0)  # [-1, 1] → [0, 1]\n",
    "        return np.array(data, dtype=np.float32), np.array(labels, dtype=np.int64)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img = torch.from_numpy(self.data[idx]).unsqueeze(0)  # [1, 16, 16] in [0, 1]\n",
    "        img = img * 2.0 - 1.0  # back to [-1, 1]\n",
    "        label = int(self.labels[idx])\n",
    "        if self.transform:\n",
    "            img = self.transform(img)\n",
    "        return img, label\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 2. PREPROCESSING\n",
    "# All domains → 32×32 × 3 channels\n",
    "# MNIST/USPS: pad to 32×32 then repeat channel 3 times\n",
    "# SVHN: already 32×32×3\n",
    "# ============================================================================\n",
    "\n",
    "class PadAndRepeatChannels:\n",
    "    \"\"\"\n",
    "    Grayscale [1, H, W] → RGB-like [3, size, size] (size defaults to 32).\n",
    "    Step 1: center-pad to size×size with `fill`; an odd remainder goes to the\n",
    "            bottom/right edge.\n",
    "    Step 2: repeat the single channel 3 times.\n",
    "    No resampling is done. If H or W exceeds `size`, F.pad receives negative\n",
    "    padding and crops that dimension instead.\n",
    "\n",
    "    Args:\n",
    "        size: output height and width.\n",
    "        fill: value written into padded pixels (default 0). Use the input's\n",
    "              background value, e.g. -1.0 for [-1, 1]-scaled images with a -1\n",
    "              background; a 0 pad leaves a mid-gray border on such images.\n",
    "    Raises:\n",
    "        ValueError: if the input is not a single-channel [1, H, W] tensor.\n",
    "    \"\"\"\n",
    "    def __init__(self, size=32, fill=0.0):\n",
    "        self.size = size\n",
    "        self.fill = fill\n",
    "\n",
    "    def __call__(self, x):\n",
    "        if x.dim() != 3 or x.shape[0] != 1:\n",
    "            raise ValueError(f'expected a [1, H, W] tensor, got shape {tuple(x.shape)}')\n",
    "        _, h, w = x.shape\n",
    "        pad_top = (self.size - h) // 2\n",
    "        pad_left = (self.size - w) // 2\n",
    "        pad_bottom = self.size - h - pad_top\n",
    "        pad_right = self.size - w - pad_left\n",
    "        # F.pad order for the last two dims: [left, right, top, bottom]\n",
    "        x = F.pad(x, [pad_left, pad_right, pad_top, pad_bottom], value=self.fill)\n",
    "        return x.repeat(3, 1, 1)  # [1, size, size] → [3, size, size]\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 3. SPARSE MULTI-DOMAIN DATASET\n",
    "# ============================================================================\n",
    "\n",
    "class SparseMultiDomainDataset(Dataset):\n",
    "    \"\"\"\n",
    "    Sparse view of any domain's image dataset.\n",
    "\n",
    "    Each item is (x, s):\n",
    "        x = {'sparse_image': full_image * mask, 'mask': mask}\n",
    "        s = {'full_image': full_image, 'label': int(label) % 10,\n",
    "             'domain': domain_id}\n",
    "    One spatial mask (each pixel visible with probability `sparsity_level`)\n",
    "    is shared by all channels. The mask is sized from the image, so any\n",
    "    [C, H, W] input works; in this notebook every domain is 3×32×32.\n",
    "\n",
    "    Args:\n",
    "        base_dataset:   dataset yielding (image tensor [C, H, W], label).\n",
    "        sparsity_level: fraction of pixels kept visible (0.15 = 15%).\n",
    "        domain_id:      0=MNIST, 1=USPS, 2=SVHN.\n",
    "        domain_name:    name used in the size/sparsity log line.\n",
    "        mask_seed:      None (default) draws a fresh mask from torch's global\n",
    "                        RNG on every access, so an index gets a new mask each\n",
    "                        epoch. An int makes index i's mask a fixed function\n",
    "                        of mask_seed + i — use this for val/test splits so\n",
    "                        early stopping and reported accuracy are measured on\n",
    "                        the same masked inputs every time.\n",
    "    \"\"\"\n",
    "    def __init__(self, base_dataset, sparsity_level=0.15,\n",
    "                 domain_id=0, domain_name='Unknown', mask_seed=None):\n",
    "        self.base_dataset   = base_dataset\n",
    "        self.sparsity_level = sparsity_level\n",
    "        self.domain_id      = domain_id\n",
    "        self.domain_name    = domain_name\n",
    "        self.mask_seed      = mask_seed\n",
    "        print(f\"{domain_name} Sparse Dataset — Size: {len(base_dataset)}, \"\n",
    "              f\"Sparsity: {sparsity_level:.0%} visible\")\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.base_dataset)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        full_image, label = self.base_dataset[idx]  # [C, H, W]\n",
    "        channels, height, width = full_image.shape\n",
    "\n",
    "        # Optional per-index generator → identical mask for idx on every call,\n",
    "        # independent of epoch and DataLoader worker.\n",
    "        generator = None\n",
    "        if self.mask_seed is not None:\n",
    "            generator = torch.Generator().manual_seed(self.mask_seed + int(idx))\n",
    "\n",
    "        # Same spatial mask for every channel\n",
    "        spatial_mask = (torch.rand(1, height, width, generator=generator)\n",
    "                        < self.sparsity_level).float()\n",
    "        mask         = spatial_mask.repeat(channels, 1, 1)  # [C, H, W]\n",
    "        sparse_image = full_image * mask\n",
    "\n",
    "        label = int(label) % 10  # ensure 0-9\n",
    "\n",
    "        x = {'sparse_image': sparse_image, 'mask': mask}\n",
    "        s = {'full_image': full_image, 'label': label, 'domain': self.domain_id}\n",
    "        return x, s\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 4. DOMAIN-SPECIFIC BATCH NORMALIZATION RESIDUAL BLOCK\n",
    "#\n",
    "# Core design:\n",
    "#   - Shared conv weights  → same feature transformation for all domains\n",
    "#   - Domain-specific BN   → separate running stats per domain\n",
    "#\n",
    "# Why this works:\n",
    "#   MNIST, USPS, SVHN have very different pixel statistics.\n",
    "#   Shared BN would average them → unstable training.\n",
    "#   Domain-specific BN normalizes each domain independently\n",
    "#   while conv weights learn universal features.\n",
    "#\n",
    "# Related designs:\n",
    "#   - AdaBN: domain-specific BN statistics (Li et al., 2016)\n",
    "#   - Residual adapters with domain-specific BN (Rebuffi et al., 2017)\n",
    "#   - Domain-Specific Batch Normalization for UDA (Chang et al., CVPR 2019)\n",
    "# ============================================================================\n",
    "\n",
    "class ResidualBlockDomainBN(nn.Module):\n",
    "    \"\"\"\n",
    "    Pre-activation residual block with per-domain BatchNorm.\n",
    "\n",
    "    The two 3×3 convolutions (padding 1, no bias) are shared by every domain;\n",
    "    each normalization point keeps `num_domains` independent BatchNorm2d\n",
    "    layers, and `domain_id` picks which one the whole batch goes through:\n",
    "\n",
    "        out = x + conv2(relu(bn2[d](conv1(relu(bn1[d](x))))))   with d = domain_id\n",
    "\n",
    "    Shape is preserved: [B, channels, H, W] → [B, channels, H, W].\n",
    "    \"\"\"\n",
    "    def __init__(self, channels, num_domains=NUM_DOMAINS):\n",
    "        super().__init__()\n",
    "\n",
    "        def per_domain_bn():\n",
    "            return nn.ModuleList(nn.BatchNorm2d(channels) for _ in range(num_domains))\n",
    "\n",
    "        # Shared across domains\n",
    "        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)\n",
    "        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)\n",
    "        # One BN per domain at each normalization point\n",
    "        self.bn1 = per_domain_bn()\n",
    "        self.bn2 = per_domain_bn()\n",
    "\n",
    "    def forward(self, x, domain_id):\n",
    "        branch = self.conv1(F.relu(self.bn1[domain_id](x)))\n",
    "        branch = self.conv2(F.relu(self.bn2[domain_id](branch)))\n",
    "        return x + branch\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 5. SHARED UNIVERSAL ENCODER WITH DOMAIN-SPECIFIC BN\n",
    "#\n",
    "# Architecture:\n",
    "#   Input:  [B, 6, 32, 32]  (3ch sparse image + 3ch mask)\n",
    "#   Scale1: 32ch residual blocks  (full resolution 32×32)\n",
    "#   Scale2: 64ch residual blocks  (16×16 after stride-2 conv)\n",
    "#   Scale3: 128ch residual blocks (8×8)\n",
    "#   Scale4: 256ch residual blocks (4×4)\n",
    "#   Output: [B, latent_channels, 4, 4]  → same Z space for all domains\n",
    "#\n",
    "# Key property:\n",
    "#   ALL domains map to the SAME Z space.\n",
    "#   Z space IS the universal latent manifold M_z.\n",
    "#   No post-hoc alignment needed — homeomorphism by construction.\n",
    "# ============================================================================\n",
    "\n",
    "class SharedUniversalEncoder(nn.Module):\n",
    "    \"\"\"\n",
    "    One encoder shared by all three domains.\n",
    "\n",
    "    Every convolution is shared; every BatchNorm position holds one\n",
    "    BatchNorm2d per domain, chosen at call time by `domain_id`, so all\n",
    "    domains are encoded into the same latent space.\n",
    "\n",
    "    Stages (each = conv + domain BN + ReLU, then `num_res_blocks`\n",
    "    ResidualBlockDomainBN blocks):\n",
    "        stem        3×3 conv   6 → 32                @ 32×32\n",
    "        down1       3×3 s2     32 → 64               @ 16×16\n",
    "        down2       3×3 s2     64 → 128              @ 8×8\n",
    "        down3       3×3 s2     128 → 256             @ 4×4\n",
    "        bottleneck  1×1        256 → latent_channels @ 4×4\n",
    "    \"\"\"\n",
    "    def __init__(self, latent_channels=128, num_res_blocks=3,\n",
    "                 num_domains=NUM_DOMAINS):\n",
    "        super().__init__()\n",
    "        self.latent_channels = latent_channels\n",
    "        self.num_domains = num_domains\n",
    "\n",
    "        # Modules are registered in pipeline order; attribute names become\n",
    "        # state_dict keys, so keep them stable for saved .pth files.\n",
    "        def bn_bank(ch):\n",
    "            return nn.ModuleList([nn.BatchNorm2d(ch) for _ in range(num_domains)])\n",
    "\n",
    "        def res_stack(ch):\n",
    "            return nn.ModuleList([ResidualBlockDomainBN(ch, num_domains)\n",
    "                                  for _ in range(num_res_blocks)])\n",
    "\n",
    "        # Stem: 6 → 32 channels at 32×32\n",
    "        self.input_conv = nn.Conv2d(6, 32, 3, padding=1, bias=False)\n",
    "        self.input_bn   = bn_bank(32)\n",
    "        self.res1       = res_stack(32)\n",
    "\n",
    "        # 32×32 → 16×16, 32 → 64 channels\n",
    "        self.down1    = nn.Conv2d(32, 64, 3, stride=2, padding=1, bias=False)\n",
    "        self.down1_bn = bn_bank(64)\n",
    "        self.res2     = res_stack(64)\n",
    "\n",
    "        # 16×16 → 8×8, 64 → 128 channels\n",
    "        self.down2    = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)\n",
    "        self.down2_bn = bn_bank(128)\n",
    "        self.res3     = res_stack(128)\n",
    "\n",
    "        # 8×8 → 4×4, 128 → 256 channels\n",
    "        self.down3    = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)\n",
    "        self.down3_bn = bn_bank(256)\n",
    "        self.res4     = res_stack(256)\n",
    "\n",
    "        # 1×1 projection 256 → latent_channels, then residual refinement at 4×4\n",
    "        self.bottleneck     = nn.Conv2d(256, latent_channels, 1, bias=False)\n",
    "        self.bottleneck_bn  = bn_bank(latent_channels)\n",
    "        self.res_bottleneck = res_stack(latent_channels)\n",
    "\n",
    "    def forward(self, sparse_image, mask, domain_id):\n",
    "        \"\"\"\n",
    "        Encode a batch that comes entirely from one domain.\n",
    "\n",
    "        Args:\n",
    "            sparse_image: [B, 3, 32, 32]\n",
    "            mask:         [B, 3, 32, 32]\n",
    "            domain_id:    int — 0=MNIST, 1=USPS, 2=SVHN; selects the BN set\n",
    "                          used for the whole batch.\n",
    "        Returns:\n",
    "            z: [B, latent_channels, 4, 4]\n",
    "        \"\"\"\n",
    "        stages = (\n",
    "            (self.input_conv, self.input_bn,      self.res1),\n",
    "            (self.down1,      self.down1_bn,      self.res2),\n",
    "            (self.down2,      self.down2_bn,      self.res3),\n",
    "            (self.down3,      self.down3_bn,      self.res4),\n",
    "            (self.bottleneck, self.bottleneck_bn, self.res_bottleneck),\n",
    "        )\n",
    "        h = torch.cat([sparse_image, mask], dim=1)  # [B, 6, 32, 32]\n",
    "        for conv, domain_bns, blocks in stages:\n",
    "            h = F.relu(domain_bns[domain_id](conv(h)))\n",
    "            for block in blocks:\n",
    "                h = block(h, domain_id)\n",
    "        return h  # [B, latent_channels, 4, 4]\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 6. DOMAIN-SPECIFIC DECODER\n",
    "# Each domain reconstructs its own images from universal Z\n",
    "# Input:  [B, latent_channels, 4, 4]\n",
    "# Output: [B, 3, 32, 32]\n",
    "# ============================================================================\n",
    "\n",
    "class ResidualBlock(nn.Module):\n",
    "    \"\"\"\n",
    "    Pre-activation residual block with ordinary (single) BatchNorm layers.\n",
    "\n",
    "        out = x + conv2(relu(bn2(conv1(relu(bn1(x))))))\n",
    "\n",
    "    Both convs are 3×3, padding 1, bias-free, so [B, C, H, W] maps to the\n",
    "    same shape. Used by the domain-specific decoders.\n",
    "    \"\"\"\n",
    "    def __init__(self, channels):\n",
    "        super().__init__()\n",
    "        # Registration order (bn1, conv1, bn2, conv2) fixes the state_dict layout.\n",
    "        self.bn1   = nn.BatchNorm2d(channels)\n",
    "        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)\n",
    "        self.bn2   = nn.BatchNorm2d(channels)\n",
    "        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        branch = x\n",
    "        for norm, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2)):\n",
    "            branch = conv(F.relu(norm(branch)))\n",
    "        return x + branch\n",
    "\n",
    "\n",
    "class DomainSpecificDecoder(nn.Module):\n",
    "    \"\"\"\n",
    "    Decoder that reconstructs images of one domain from the shared latent code.\n",
    "\n",
    "        z [B, latent_channels, 4, 4]\n",
    "          → residual blocks\n",
    "          → ConvT(4, s2, p1) → 256ch @ 8×8   + BN + ReLU → residual blocks\n",
    "          → ConvT(4, s2, p1) → 128ch @ 16×16 + BN + ReLU → residual blocks\n",
    "          → ConvT(4, s2, p1) → 64ch  @ 32×32 + BN + ReLU → residual blocks\n",
    "          → conv3×3 64→32 + ReLU → conv3×3 32→3 + Tanh\n",
    "    Output: [B, 3, 32, 32] in [-1, 1].\n",
    "    Plain (non-domain) BatchNorm: one decoder instance is used per domain.\n",
    "    \"\"\"\n",
    "    def __init__(self, latent_channels=128, num_res_blocks=2):\n",
    "        super().__init__()\n",
    "\n",
    "        def res_stack(ch):\n",
    "            return nn.Sequential(*(ResidualBlock(ch) for _ in range(num_res_blocks)))\n",
    "\n",
    "        def upsample(c_in, c_out):\n",
    "            # kernel 4, stride 2, padding 1 exactly doubles H and W\n",
    "            return nn.Sequential(\n",
    "                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1, bias=False),\n",
    "                nn.BatchNorm2d(c_out),\n",
    "                nn.ReLU(),\n",
    "            )\n",
    "\n",
    "        self.res_bottleneck = res_stack(latent_channels)\n",
    "        self.up1  = upsample(latent_channels, 256)   # 4×4 → 8×8\n",
    "        self.res1 = res_stack(256)\n",
    "        self.up2  = upsample(256, 128)               # 8×8 → 16×16\n",
    "        self.res2 = res_stack(128)\n",
    "        self.up3  = upsample(128, 64)                # 16×16 → 32×32\n",
    "        self.res3 = res_stack(64)\n",
    "        self.output_conv = nn.Sequential(\n",
    "            nn.Conv2d(64, 32, 3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 3, 3, padding=1),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "\n",
    "    def forward(self, z):\n",
    "        out = z\n",
    "        for stage in (self.res_bottleneck, self.up1, self.res1, self.up2,\n",
    "                      self.res2, self.up3, self.res3, self.output_conv):\n",
    "            out = stage(out)\n",
    "        return out\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 7. DOMAIN-INVARIANT PROJECTION NETWORK\n",
    "# One per domain — fine-tunes Z for cross-domain alignment\n",
    "# ============================================================================\n",
    "\n",
    "class DomainInvariantProjection(nn.Module):\n",
    "    \"\"\"\n",
    "    Per-domain projection applied to the shared encoder's output Z.\n",
    "\n",
    "    `num_res_blocks` residual blocks, then a 1x1 conv and BatchNorm. Channel\n",
    "    count and spatial grid are preserved, so the projected Z has the same shape\n",
    "    as the input Z and can be fed to the same decoders / classifier.\n",
    "    \"\"\"\n",
    "    def __init__(self, latent_channels=128, num_res_blocks=2):\n",
    "        super().__init__()\n",
    "        layers = [ResidualBlock(latent_channels) for _ in range(num_res_blocks)]\n",
    "        layers.append(nn.Conv2d(latent_channels, latent_channels, kernel_size=1))\n",
    "        layers.append(nn.BatchNorm2d(latent_channels))\n",
    "        self.projection = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, z):\n",
    "        return self.projection(z)\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 8. SPATIAL CLASSIFIER\n",
    "# ============================================================================\n",
    "\n",
    "class SpatialClassifier(nn.Module):\n",
    "    \"\"\"\n",
    "    Convolutional classification head over the latent grid Z.\n",
    "\n",
    "    Two (3x3 conv -> BN -> ReLU -> Dropout2d(0.3)) stages widen Z to 512\n",
    "    channels, global average pooling collapses the spatial grid, and a linear\n",
    "    layer produces `num_classes` unnormalised logits.\n",
    "\n",
    "    Input:  [B, latent_channels, H, W]\n",
    "    Output: [B, num_classes]\n",
    "    \"\"\"\n",
    "    def __init__(self, latent_channels=128, num_classes=10):\n",
    "        super().__init__()\n",
    "        self.features = nn.Sequential(\n",
    "            nn.Conv2d(latent_channels, 256, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm2d(256),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout2d(0.3),\n",
    "            nn.Conv2d(256, 512, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm2d(512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout2d(0.3),\n",
    "            nn.AdaptiveAvgPool2d((1, 1)),\n",
    "        )\n",
    "        self.fc = nn.Linear(512, num_classes)\n",
    "\n",
    "    def forward(self, z):\n",
    "        pooled = self.features(z)  # [B, 512, 1, 1]\n",
    "        return self.fc(torch.flatten(pooled, start_dim=1))\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 9. ALIGNMENT LOSSES (three-way)\n",
    "# ============================================================================\n",
    "\n",
    "def contrastive_alignment_loss(z_a, z_b, labels_a, labels_b, temperature=0.1):\n",
    "    \"\"\"\n",
    "    Cross-domain supervised contrastive loss between two batches of latents.\n",
    "\n",
    "    Every anchor i in z_a is pulled toward the same-label samples of z_b and\n",
    "    pushed away from the different-label ones. With s = cos_sim / temperature,\n",
    "    the term for anchor i and positive p is\n",
    "\n",
    "        -log( exp(s_ip) / (exp(s_ip) + sum_{n: label differs} exp(s_in) + 1e-8) )\n",
    "\n",
    "    Terms are averaged over each anchor's positives, then over the anchors that\n",
    "    have at least one positive.\n",
    "\n",
    "    Vectorised over the whole batch: a per-anchor Python loop would issue a\n",
    "    host/device sync (`pos.sum() == 0`) plus several small kernels per sample.\n",
    "    The term is evaluated as log(exp(s_ip) + neg + 1e-8) - s_ip, which equals\n",
    "    the formula above but stays finite if exp(s_ip) underflows. Since the\n",
    "    vectors are L2-normalised, |s| <= 1 / temperature (10 for the default).\n",
    "\n",
    "    Args:\n",
    "        z_a, z_b: latents (at least 2-D) with a leading batch dimension; the\n",
    "            trailing dimensions are flattened into one vector per sample.\n",
    "        labels_a, labels_b: class labels of shape [B_a] and [B_b].\n",
    "        temperature: divisor applied to the cosine similarities.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor; 0.0 when no anchor in z_a has a positive in z_b.\n",
    "    \"\"\"\n",
    "    # flatten(1) (rather than view) also accepts non-contiguous or empty batches\n",
    "    z_a_flat = F.normalize(z_a.flatten(1), p=2, dim=1)\n",
    "    z_b_flat = F.normalize(z_b.flatten(1), p=2, dim=1)\n",
    "\n",
    "    sim = torch.matmul(z_a_flat, z_b_flat.t()) / temperature          # [B_a, B_b]\n",
    "    pos_mask = (labels_a.unsqueeze(1) == labels_b.unsqueeze(0)).float()\n",
    "\n",
    "    exp_sim = torch.exp(sim)\n",
    "    # Per-anchor sum over negatives, broadcast against every column\n",
    "    neg_exp = (exp_sim * (1.0 - pos_mask)).sum(dim=1, keepdim=True)  # [B_a, 1]\n",
    "    # -log(e / (e + neg + eps)) == log(e + neg + eps) - s\n",
    "    per_pair = torch.log(exp_sim + neg_exp + 1e-8) - sim              # [B_a, B_b]\n",
    "\n",
    "    pos_count = pos_mask.sum(dim=1)                                   # [B_a]\n",
    "    has_pos = (pos_count > 0).float()\n",
    "    # Mean over positives; masked-out (negative) entries contribute exactly 0\n",
    "    per_anchor = (per_pair * pos_mask).sum(dim=1) / pos_count.clamp(min=1.0)\n",
    "\n",
    "    # Mean over anchors that have a positive; evaluates to 0.0 when there are\n",
    "    # none, without a data-dependent Python branch (no host/device sync).\n",
    "    return (per_anchor * has_pos).sum() / has_pos.sum().clamp(min=1.0)\n",
    "\n",
    "\n",
    "def class_centroid_alignment_loss(z_a, z_b, labels_a, labels_b, num_classes=10):\n",
    "    \"\"\"\n",
    "    MSE between per-class mean latents of two domains.\n",
    "\n",
    "    For every class c in range(num_classes) that occurs in BOTH batches, the\n",
    "    mean flattened latent of class c in z_a is compared with that in z_b via\n",
    "    F.mse_loss. Returns the average over those shared classes, or a zero tensor\n",
    "    on z_a.device when the two batches share no class.\n",
    "    \"\"\"\n",
    "    flat_a = z_a.view(z_a.size(0), -1)\n",
    "    flat_b = z_b.view(z_b.size(0), -1)\n",
    "\n",
    "    per_class = []\n",
    "    for cls in range(num_classes):\n",
    "        in_a = (labels_a == cls)\n",
    "        in_b = (labels_b == cls)\n",
    "        # A centroid only exists when the class is present in the batch\n",
    "        if in_a.sum() > 0 and in_b.sum() > 0:\n",
    "            per_class.append(F.mse_loss(flat_a[in_a].mean(0), flat_b[in_b].mean(0)))\n",
    "\n",
    "    if not per_class:\n",
    "        return torch.tensor(0.0, device=z_a.device)\n",
    "    return sum(per_class) / len(per_class)\n",
    "\n",
    "\n",
    "def three_way_alignment_loss(z_mnist, z_usps, z_svhn,\n",
    "                              labels_mnist, labels_usps, labels_svhn,\n",
    "                              temperature=0.1):\n",
    "    \"\"\"\n",
    "    Average the pairwise alignment losses over all three domain pairs.\n",
    "\n",
    "    Pairs, in order: MNIST↔USPS, MNIST↔SVHN, USPS↔SVHN.\n",
    "    `temperature` is forwarded to the contrastive term only.\n",
    "\n",
    "    Returns:\n",
    "        (loss_contrastive, loss_centroid): each the mean of its three pair values.\n",
    "    \"\"\"\n",
    "    pairs = (\n",
    "        (z_mnist, z_usps, labels_mnist, labels_usps),   # MNIST↔USPS\n",
    "        (z_mnist, z_svhn, labels_mnist, labels_svhn),   # MNIST↔SVHN\n",
    "        (z_usps,  z_svhn, labels_usps,  labels_svhn),   # USPS↔SVHN\n",
    "    )\n",
    "\n",
    "    contrastive_terms = [\n",
    "        contrastive_alignment_loss(za, zb, la, lb, temperature)\n",
    "        for za, zb, la, lb in pairs\n",
    "    ]\n",
    "    centroid_terms = [\n",
    "        class_centroid_alignment_loss(za, zb, la, lb)\n",
    "        for za, zb, la, lb in pairs\n",
    "    ]\n",
    "\n",
    "    loss_contrastive = sum(contrastive_terms) / 3.0\n",
    "    loss_centroid    = sum(centroid_terms) / 3.0\n",
    "    return loss_contrastive, loss_centroid\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 10. STEP 1: PRE-TRAIN SHARED ENCODER (Domain-Specific BN)\n",
    "# Train shared encoder jointly on all three domains.\n",
    "# Each domain uses its own BN stats — shared conv weights learn universal features.\n",
    "# ============================================================================\n",
    "\n",
    "def pretrain_shared_encoder(shared_encoder,\n",
    "                             decoder_mnist, decoder_usps, decoder_svhn,\n",
    "                             train_loader_mnist, train_loader_usps, train_loader_svhn,\n",
    "                             val_loader_mnist,   val_loader_usps,   val_loader_svhn,\n",
    "                             max_epochs=100, patience=15, device='cpu'):\n",
    "    \"\"\"\n",
    "    Stage 1: Joint pre-training of shared encoder on all three domains.\n",
    "\n",
    "    Each domain passes through the shared encoder with its own domain_id,\n",
    "    which selects the appropriate domain-specific BN layers.\n",
    "    Conv weights are updated by gradients from ALL three domains simultaneously.\n",
    "\n",
    "    This enforces that the same conv weights must work for MNIST, USPS, and SVHN\n",
    "    — producing universal features in the same Z space.\n",
    "\n",
    "    Objective: sum of the three per-domain MSE losses between each decoder's\n",
    "    reconstruction and the corresponding 'full_image'. The same objective on\n",
    "    the validation loaders drives ReduceLROnPlateau, early stopping and\n",
    "    best-checkpoint selection. Each epoch consumes min(len(loader)) batches\n",
    "    from every training loader, so the smallest domain bounds the epoch.\n",
    "\n",
    "    Args:\n",
    "        shared_encoder: called as shared_encoder(sparse_image, mask, domain_id).\n",
    "        decoder_mnist, decoder_usps, decoder_svhn: per-domain decoders Z -> image.\n",
    "        train_loader_*, val_loader_*: yield (x, s) pairs where x provides\n",
    "            'sparse_image' and 'mask' and s provides 'full_image'.\n",
    "        max_epochs: maximum number of training epochs.\n",
    "        patience: epochs without validation improvement before stopping.\n",
    "        device: device that all modules and batches are moved to.\n",
    "\n",
    "    Returns:\n",
    "        (shared_encoder, decoder_mnist, decoder_usps, decoder_svhn), with the\n",
    "        weights from the best validation epoch restored.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the validation loaders yield no batches.\n",
    "    \"\"\"\n",
    "    def _snapshot(module):\n",
    "        # state_dict() tensors ALIAS the live parameters/buffers. Without\n",
    "        # clone() the \"best\" checkpoint would keep changing as training\n",
    "        # continues, and the final restore would silently reload the\n",
    "        # last epoch's weights instead of the best ones.\n",
    "        return {k: v.detach().clone() for k, v in module.state_dict().items()}\n",
    "\n",
    "    print(\"=\" * 70)\n",
    "    print(\"STEP 1: JOINT PRE-TRAINING — SHARED ENCODER + DOMAIN-SPECIFIC BN\")\n",
    "    print(\"Shared conv weights learn universal features across all 3 domains\")\n",
    "    print(\"Domain-specific BN normalizes each domain's statistics separately\")\n",
    "    print(\"ALL domains map to the SAME Z space — universal manifold\")\n",
    "    print(\"=\" * 70 + \"\\n\")\n",
    "\n",
    "    shared_encoder.to(device)\n",
    "    decoder_mnist.to(device)\n",
    "    decoder_usps.to(device)\n",
    "    decoder_svhn.to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(\n",
    "        list(shared_encoder.parameters()) +\n",
    "        list(decoder_mnist.parameters()) +\n",
    "        list(decoder_usps.parameters()) +\n",
    "        list(decoder_svhn.parameters()),\n",
    "        lr=1e-3, weight_decay=1e-5\n",
    "    )\n",
    "    # Halve the LR when the validation loss plateaus (patience=5 epochs)\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "        optimizer, mode='min', factor=0.5, patience=5\n",
    "    )\n",
    "\n",
    "    best_val_loss = float('inf')\n",
    "    best_epoch    = 0\n",
    "    no_improve    = 0\n",
    "    best_state    = None\n",
    "\n",
    "    for epoch in range(max_epochs):\n",
    "        shared_encoder.train()\n",
    "        decoder_mnist.train()\n",
    "        decoder_usps.train()\n",
    "        decoder_svhn.train()\n",
    "\n",
    "        total_loss   = 0\n",
    "        loss_mnist_  = 0\n",
    "        loss_usps_   = 0\n",
    "        loss_svhn_   = 0\n",
    "        seen_batches = 0\n",
    "\n",
    "        mnist_iter = iter(train_loader_mnist)\n",
    "        usps_iter  = iter(train_loader_usps)\n",
    "        svhn_iter  = iter(train_loader_svhn)\n",
    "\n",
    "        # The smallest domain bounds the number of joint steps per epoch\n",
    "        num_batches = min(len(train_loader_mnist),\n",
    "                          len(train_loader_usps),\n",
    "                          len(train_loader_svhn))\n",
    "\n",
    "        for batch_idx in range(num_batches):\n",
    "            try:\n",
    "                x_mnist, s_mnist = next(mnist_iter)\n",
    "                x_usps,  s_usps  = next(usps_iter)\n",
    "                x_svhn,  s_svhn  = next(svhn_iter)\n",
    "            except StopIteration:\n",
    "                break\n",
    "\n",
    "            # MNIST\n",
    "            sp_mnist   = x_mnist['sparse_image'].to(device)\n",
    "            mk_mnist   = x_mnist['mask'].to(device)\n",
    "            full_mnist = s_mnist['full_image'].to(device)\n",
    "\n",
    "            # USPS\n",
    "            sp_usps    = x_usps['sparse_image'].to(device)\n",
    "            mk_usps    = x_usps['mask'].to(device)\n",
    "            full_usps  = s_usps['full_image'].to(device)\n",
    "\n",
    "            # SVHN\n",
    "            sp_svhn    = x_svhn['sparse_image'].to(device)\n",
    "            mk_svhn    = x_svhn['mask'].to(device)\n",
    "            full_svhn  = s_svhn['full_image'].to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # Each domain uses its own domain_id → own BN stats\n",
    "            # Same conv weights process all three\n",
    "            z_mnist = shared_encoder(sp_mnist, mk_mnist, DOMAIN_MNIST)\n",
    "            z_usps  = shared_encoder(sp_usps,  mk_usps,  DOMAIN_USPS)\n",
    "            z_svhn  = shared_encoder(sp_svhn,  mk_svhn,  DOMAIN_SVHN)\n",
    "\n",
    "            # Domain-specific reconstruction\n",
    "            recon_mnist = decoder_mnist(z_mnist)\n",
    "            recon_usps  = decoder_usps(z_usps)\n",
    "            recon_svhn  = decoder_svhn(z_svhn)\n",
    "\n",
    "            lm = F.mse_loss(recon_mnist, full_mnist)\n",
    "            lu = F.mse_loss(recon_usps,  full_usps)\n",
    "            ls = F.mse_loss(recon_svhn,  full_svhn)\n",
    "\n",
    "            loss = lm + lu + ls\n",
    "            loss.backward()\n",
    "\n",
    "            # Each module's gradient norm is clipped independently\n",
    "            torch.nn.utils.clip_grad_norm_(shared_encoder.parameters(), 1.0)\n",
    "            torch.nn.utils.clip_grad_norm_(decoder_mnist.parameters(),  1.0)\n",
    "            torch.nn.utils.clip_grad_norm_(decoder_usps.parameters(),   1.0)\n",
    "            torch.nn.utils.clip_grad_norm_(decoder_svhn.parameters(),   1.0)\n",
    "            optimizer.step()\n",
    "\n",
    "            seen_batches += 1\n",
    "            total_loss   += loss.item()\n",
    "            loss_mnist_  += lm.item()\n",
    "            loss_usps_   += lu.item()\n",
    "            loss_svhn_   += ls.item()\n",
    "\n",
    "            if batch_idx % 100 == 0:\n",
    "                print(f'Epoch {epoch+1}, Batch {batch_idx}: '\n",
    "                      f'Total={loss.item():.4f} '\n",
    "                      f'(MNIST={lm.item():.4f}, '\n",
    "                      f'USPS={lu.item():.4f}, '\n",
    "                      f'SVHN={ls.item():.4f})')\n",
    "\n",
    "        # Validation\n",
    "        shared_encoder.eval()\n",
    "        decoder_mnist.eval()\n",
    "        decoder_usps.eval()\n",
    "        decoder_svhn.eval()\n",
    "\n",
    "        val_total   = 0\n",
    "        val_m       = 0\n",
    "        val_u       = 0\n",
    "        val_s       = 0\n",
    "        val_batches = 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            mv_iter = iter(val_loader_mnist)\n",
    "            uv_iter = iter(val_loader_usps)\n",
    "            sv_iter = iter(val_loader_svhn)\n",
    "            nv = min(len(val_loader_mnist),\n",
    "                     len(val_loader_usps),\n",
    "                     len(val_loader_svhn))\n",
    "\n",
    "            for _ in range(nv):\n",
    "                try:\n",
    "                    xm, sm = next(mv_iter)\n",
    "                    xu, su = next(uv_iter)\n",
    "                    xs, ss = next(sv_iter)\n",
    "                except StopIteration:\n",
    "                    break\n",
    "\n",
    "                zm = shared_encoder(xm['sparse_image'].to(device),\n",
    "                                    xm['mask'].to(device), DOMAIN_MNIST)\n",
    "                zu = shared_encoder(xu['sparse_image'].to(device),\n",
    "                                    xu['mask'].to(device), DOMAIN_USPS)\n",
    "                zs = shared_encoder(xs['sparse_image'].to(device),\n",
    "                                    xs['mask'].to(device), DOMAIN_SVHN)\n",
    "\n",
    "                lm = F.mse_loss(decoder_mnist(zm), sm['full_image'].to(device))\n",
    "                lu = F.mse_loss(decoder_usps(zu),  su['full_image'].to(device))\n",
    "                ls = F.mse_loss(decoder_svhn(zs),  ss['full_image'].to(device))\n",
    "\n",
    "                val_total   += (lm + lu + ls).item()\n",
    "                val_m       += lm.item()\n",
    "                val_u       += lu.item()\n",
    "                val_s       += ls.item()\n",
    "                val_batches += 1\n",
    "\n",
    "        if val_batches == 0:\n",
    "            raise ValueError('pretrain_shared_encoder: validation loaders yielded no batches')\n",
    "        avg_val = val_total / val_batches\n",
    "\n",
    "        old_lr = optimizer.param_groups[0]['lr']\n",
    "        scheduler.step(avg_val)\n",
    "        new_lr = optimizer.param_groups[0]['lr']\n",
    "        if new_lr != old_lr:\n",
    "            print(f'  → LR reduced: {old_lr:.2e} → {new_lr:.2e}')\n",
    "\n",
    "        # Average over batches actually processed; max() guards an empty epoch\n",
    "        nb = max(seen_batches, 1)\n",
    "        print(f'\\nEpoch {epoch+1}/{max_epochs}:')\n",
    "        print(f'  TRAIN Total={total_loss/nb:.4f} '\n",
    "              f'(MNIST={loss_mnist_/nb:.4f}, '\n",
    "              f'USPS={loss_usps_/nb:.4f}, '\n",
    "              f'SVHN={loss_svhn_/nb:.4f})')\n",
    "        print(f'  VAL   Total={avg_val:.4f} '\n",
    "              f'(MNIST={val_m/val_batches:.4f}, '\n",
    "              f'USPS={val_u/val_batches:.4f}, '\n",
    "              f'SVHN={val_s/val_batches:.4f})')\n",
    "\n",
    "        if avg_val < best_val_loss:\n",
    "            best_val_loss = avg_val\n",
    "            best_epoch    = epoch + 1\n",
    "            no_improve    = 0\n",
    "            best_state = {\n",
    "                'shared_encoder': _snapshot(shared_encoder),\n",
    "                'decoder_mnist':  _snapshot(decoder_mnist),\n",
    "                'decoder_usps':   _snapshot(decoder_usps),\n",
    "                'decoder_svhn':   _snapshot(decoder_svhn),\n",
    "            }\n",
    "            print(f'  ✓ NEW BEST!')\n",
    "        else:\n",
    "            no_improve += 1\n",
    "            print(f'  No improvement for {no_improve} epoch(s)')\n",
    "            if no_improve >= patience:\n",
    "                print(f'\\nEARLY STOPPING at epoch {epoch+1}')\n",
    "                break\n",
    "\n",
    "    if best_state:\n",
    "        shared_encoder.load_state_dict(best_state['shared_encoder'])\n",
    "        decoder_mnist.load_state_dict(best_state['decoder_mnist'])\n",
    "        decoder_usps.load_state_dict(best_state['decoder_usps'])\n",
    "        decoder_svhn.load_state_dict(best_state['decoder_svhn'])\n",
    "        print(f\"✓ Restored best model from epoch {best_epoch}\\n\")\n",
    "\n",
    "    return shared_encoder, decoder_mnist, decoder_usps, decoder_svhn\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 11. STEP 2: THREE-WAY PROJECTION + ALIGNMENT (NO CLASSIFIER)\n",
    "# Shared encoder FROZEN.\n",
    "# Each domain gets its own projection network.\n",
    "# Three-way alignment enforces class-level correspondence across domains.\n",
    "# ============================================================================\n",
    "\n",
    "def train_three_way_alignment(shared_encoder,\n",
    "                               projection_mnist, projection_usps, projection_svhn,\n",
    "                               decoder_mnist, decoder_usps, decoder_svhn,\n",
    "                               train_loader_mnist, train_loader_usps, train_loader_svhn,\n",
    "                               val_loader_mnist,   val_loader_usps,   val_loader_svhn,\n",
    "                               max_epochs=100, patience=15, device='cpu',\n",
    "                               lambda_cont=0.5, lambda_cent=0.5):\n",
    "    \"\"\"\n",
    "    Stage 2: Three-way projection and alignment.\n",
    "    Shared encoder is FROZEN — Z space structure is preserved.\n",
    "    Projection networks fine-tune alignment across domains.\n",
    "    NO classifier trained here.\n",
    "\n",
    "    Per batch, the frozen encoder produces Z for each domain, each domain's\n",
    "    projection maps it to the aligned space, and the objective is\n",
    "\n",
    "        recon_MSE (sum over 3 decoders) + lambda_cont * contrastive\n",
    "                                        + lambda_cent * centroid\n",
    "\n",
    "    with the alignment terms from three_way_alignment_loss on the projected\n",
    "    latents and s['label']. Projections and decoders are optimised; the same\n",
    "    objective on the validation loaders drives ReduceLROnPlateau, early\n",
    "    stopping and best-checkpoint selection.\n",
    "\n",
    "    Args:\n",
    "        shared_encoder: called as shared_encoder(sparse_image, mask, domain_id).\n",
    "            Its parameters are set to requires_grad=False and it is put in eval\n",
    "            mode; neither change is undone on return.\n",
    "        projection_mnist, projection_usps, projection_svhn: per-domain\n",
    "            projection networks (trained).\n",
    "        decoder_mnist, decoder_usps, decoder_svhn: per-domain decoders\n",
    "            (trained and restored in place).\n",
    "        train_loader_*, val_loader_*: yield (x, s) pairs where x provides\n",
    "            'sparse_image' and 'mask' and s provides 'full_image' and 'label'.\n",
    "        max_epochs: maximum number of training epochs.\n",
    "        patience: epochs without validation improvement before stopping.\n",
    "        device: device that all modules and batches are moved to.\n",
    "        lambda_cont, lambda_cent: weights of the contrastive and centroid terms\n",
    "            (defaults reproduce the previously hard-coded 0.5 / 0.5).\n",
    "\n",
    "    Returns:\n",
    "        (projection_mnist, projection_usps, projection_svhn) with the weights\n",
    "        of the best validation epoch restored.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the validation loaders yield no batches.\n",
    "    \"\"\"\n",
    "    def _snapshot(module):\n",
    "        # state_dict() tensors ALIAS the live weights; clone so the best\n",
    "        # checkpoint is not overwritten by later optimizer steps.\n",
    "        return {k: v.detach().clone() for k, v in module.state_dict().items()}\n",
    "\n",
    "    print(\"=\" * 70)\n",
    "    print(\"STEP 2: THREE-WAY PROJECTION + ALIGNMENT (NO CLASSIFIER)\")\n",
    "    print(\"Frozen shared encoder — Z space preserved\")\n",
    "    print(\"Pairs aligned: MNIST↔USPS, MNIST↔SVHN, USPS↔SVHN\")\n",
    "    print(\"=\" * 70 + \"\\n\")\n",
    "\n",
    "    # Freeze shared encoder\n",
    "    for p in shared_encoder.parameters():\n",
    "        p.requires_grad = False\n",
    "    shared_encoder.eval()\n",
    "\n",
    "    shared_encoder.to(device)\n",
    "    projection_mnist.to(device)\n",
    "    projection_usps.to(device)\n",
    "    projection_svhn.to(device)\n",
    "    decoder_mnist.to(device)\n",
    "    decoder_usps.to(device)\n",
    "    decoder_svhn.to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(\n",
    "        list(projection_mnist.parameters()) +\n",
    "        list(projection_usps.parameters())  +\n",
    "        list(projection_svhn.parameters())  +\n",
    "        list(decoder_mnist.parameters())    +\n",
    "        list(decoder_usps.parameters())     +\n",
    "        list(decoder_svhn.parameters()),\n",
    "        lr=1e-3, weight_decay=1e-4\n",
    "    )\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "        optimizer, mode='min', factor=0.5, patience=5\n",
    "    )\n",
    "\n",
    "    best_val = float('inf')\n",
    "    best_epoch = 0\n",
    "    no_improve = 0\n",
    "    best_state = None\n",
    "\n",
    "    trainable = [projection_mnist, projection_usps, projection_svhn,\n",
    "                 decoder_mnist, decoder_usps, decoder_svhn]\n",
    "\n",
    "    for epoch in range(max_epochs):\n",
    "        for m in trainable:\n",
    "            m.train()\n",
    "\n",
    "        total = recon_t = cont_t = cent_t = 0\n",
    "        seen_batches = 0\n",
    "\n",
    "        mnist_iter = iter(train_loader_mnist)\n",
    "        usps_iter  = iter(train_loader_usps)\n",
    "        svhn_iter  = iter(train_loader_svhn)\n",
    "\n",
    "        nb = min(len(train_loader_mnist),\n",
    "                 len(train_loader_usps),\n",
    "                 len(train_loader_svhn))\n",
    "\n",
    "        for batch_idx in range(nb):\n",
    "            try:\n",
    "                xm, sm = next(mnist_iter)\n",
    "                xu, su = next(usps_iter)\n",
    "                xs, ss = next(svhn_iter)\n",
    "            except StopIteration:\n",
    "                break\n",
    "\n",
    "            sp_m = xm['sparse_image'].to(device)\n",
    "            mk_m = xm['mask'].to(device)\n",
    "            fl_m = sm['full_image'].to(device)\n",
    "            lb_m = sm['label'].to(device)\n",
    "\n",
    "            sp_u = xu['sparse_image'].to(device)\n",
    "            mk_u = xu['mask'].to(device)\n",
    "            fl_u = su['full_image'].to(device)\n",
    "            lb_u = su['label'].to(device)\n",
    "\n",
    "            sp_s = xs['sparse_image'].to(device)\n",
    "            mk_s = xs['mask'].to(device)\n",
    "            fl_s = ss['full_image'].to(device)\n",
    "            lb_s = ss['label'].to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # Encode with frozen shared encoder (domain-specific BN)\n",
    "            with torch.no_grad():\n",
    "                z_m = shared_encoder(sp_m, mk_m, DOMAIN_MNIST)\n",
    "                z_u = shared_encoder(sp_u, mk_u, DOMAIN_USPS)\n",
    "                z_s = shared_encoder(sp_s, mk_s, DOMAIN_SVHN)\n",
    "\n",
    "            # Project to aligned universal space\n",
    "            z_m_univ = projection_mnist(z_m)\n",
    "            z_u_univ = projection_usps(z_u)\n",
    "            z_s_univ = projection_svhn(z_s)\n",
    "\n",
    "            # Reconstruction (named loss_recon so it is not confused with a learning rate)\n",
    "            loss_recon = (F.mse_loss(decoder_mnist(z_m_univ), fl_m) +\n",
    "                          F.mse_loss(decoder_usps(z_u_univ),  fl_u) +\n",
    "                          F.mse_loss(decoder_svhn(z_s_univ),  fl_s))\n",
    "\n",
    "            # Three-way alignment\n",
    "            lc, lk = three_way_alignment_loss(\n",
    "                z_m_univ, z_u_univ, z_s_univ,\n",
    "                lb_m, lb_u, lb_s\n",
    "            )\n",
    "\n",
    "            loss = loss_recon + lambda_cont * lc + lambda_cent * lk\n",
    "            loss.backward()\n",
    "\n",
    "            # NOTE(review): only the projections are clipped here, whereas Stage 1\n",
    "            # also clips the decoders — confirm this difference is intentional.\n",
    "            for proj in [projection_mnist, projection_usps, projection_svhn]:\n",
    "                torch.nn.utils.clip_grad_norm_(proj.parameters(), 1.0)\n",
    "\n",
    "            optimizer.step()\n",
    "\n",
    "            seen_batches += 1\n",
    "            total   += loss.item()\n",
    "            recon_t += loss_recon.item()\n",
    "            cont_t  += lc.item()\n",
    "            cent_t  += lk.item()\n",
    "\n",
    "            if batch_idx % 100 == 0:\n",
    "                print(f'Epoch {epoch+1}, Batch {batch_idx}: '\n",
    "                      f'Loss={loss.item():.4f} '\n",
    "                      f'(Recon={loss_recon.item():.4f}, '\n",
    "                      f'Cont={lc.item():.4f}, '\n",
    "                      f'Cent={lk.item():.4f})')\n",
    "\n",
    "        # Validation\n",
    "        for m in trainable:\n",
    "            m.eval()\n",
    "\n",
    "        vt = vr = vc = vk = 0\n",
    "        val_batches = 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            mv = iter(val_loader_mnist)\n",
    "            uv = iter(val_loader_usps)\n",
    "            sv = iter(val_loader_svhn)\n",
    "            nv = min(len(val_loader_mnist),\n",
    "                     len(val_loader_usps),\n",
    "                     len(val_loader_svhn))\n",
    "\n",
    "            for _ in range(nv):\n",
    "                try:\n",
    "                    xm, sm = next(mv)\n",
    "                    xu, su = next(uv)\n",
    "                    xs, ss = next(sv)\n",
    "                except StopIteration:\n",
    "                    break\n",
    "\n",
    "                zm = projection_mnist(shared_encoder(\n",
    "                    xm['sparse_image'].to(device), xm['mask'].to(device), DOMAIN_MNIST))\n",
    "                zu = projection_usps(shared_encoder(\n",
    "                    xu['sparse_image'].to(device), xu['mask'].to(device), DOMAIN_USPS))\n",
    "                zs = projection_svhn(shared_encoder(\n",
    "                    xs['sparse_image'].to(device), xs['mask'].to(device), DOMAIN_SVHN))\n",
    "\n",
    "                loss_recon = (F.mse_loss(decoder_mnist(zm), sm['full_image'].to(device)) +\n",
    "                              F.mse_loss(decoder_usps(zu),  su['full_image'].to(device)) +\n",
    "                              F.mse_loss(decoder_svhn(zs),  ss['full_image'].to(device)))\n",
    "\n",
    "                lc, lk = three_way_alignment_loss(\n",
    "                    zm, zu, zs,\n",
    "                    sm['label'].to(device),\n",
    "                    su['label'].to(device),\n",
    "                    ss['label'].to(device)\n",
    "                )\n",
    "\n",
    "                bl = loss_recon + lambda_cont * lc + lambda_cent * lk\n",
    "                vt += bl.item()\n",
    "                vr += loss_recon.item()\n",
    "                vc += lc.item()\n",
    "                vk += lk.item()\n",
    "                val_batches += 1\n",
    "\n",
    "        if val_batches == 0:\n",
    "            raise ValueError('train_three_way_alignment: validation loaders yielded no batches')\n",
    "        avg_val = vt / val_batches\n",
    "\n",
    "        old_lr = optimizer.param_groups[0]['lr']\n",
    "        scheduler.step(avg_val)\n",
    "        new_lr = optimizer.param_groups[0]['lr']\n",
    "        if new_lr != old_lr:\n",
    "            print(f'  → LR reduced: {old_lr:.2e} → {new_lr:.2e}')\n",
    "\n",
    "        # Average over batches actually processed; max() guards an empty epoch\n",
    "        tb = max(seen_batches, 1)\n",
    "        print(f'\\n{\"=\" * 70}')\n",
    "        print(f'Epoch {epoch+1}/{max_epochs}:')\n",
    "        print(f'  TRAIN Total={total/tb:.4f} '\n",
    "              f'(Recon={recon_t/tb:.4f}, Cont={cont_t/tb:.4f}, Cent={cent_t/tb:.4f})')\n",
    "        print(f'  VAL   Total={avg_val:.4f} '\n",
    "              f'(Recon={vr/val_batches:.4f}, Cont={vc/val_batches:.4f}, Cent={vk/val_batches:.4f})')\n",
    "\n",
    "        if avg_val < best_val:\n",
    "            best_val   = avg_val\n",
    "            best_epoch = epoch + 1\n",
    "            no_improve = 0\n",
    "            best_state = {\n",
    "                'projection_mnist': _snapshot(projection_mnist),\n",
    "                'projection_usps':  _snapshot(projection_usps),\n",
    "                'projection_svhn':  _snapshot(projection_svhn),\n",
    "                'decoder_mnist':    _snapshot(decoder_mnist),\n",
    "                'decoder_usps':     _snapshot(decoder_usps),\n",
    "                'decoder_svhn':     _snapshot(decoder_svhn),\n",
    "            }\n",
    "            print(f'  ✓ NEW BEST!')\n",
    "        else:\n",
    "            no_improve += 1\n",
    "            print(f'  No improvement for {no_improve} epoch(s)')\n",
    "            if no_improve >= patience:\n",
    "                print(f'\\nEARLY STOPPING at epoch {epoch+1}')\n",
    "                break\n",
    "\n",
    "        print(f'{\"=\" * 70}\\n')\n",
    "\n",
    "    if best_state:\n",
    "        projection_mnist.load_state_dict(best_state['projection_mnist'])\n",
    "        projection_usps.load_state_dict(best_state['projection_usps'])\n",
    "        projection_svhn.load_state_dict(best_state['projection_svhn'])\n",
    "        decoder_mnist.load_state_dict(best_state['decoder_mnist'])\n",
    "        decoder_usps.load_state_dict(best_state['decoder_usps'])\n",
    "        decoder_svhn.load_state_dict(best_state['decoder_svhn'])\n",
    "        print(f\"✓ Restored best model from epoch {best_epoch}\\n\")\n",
    "\n",
    "    return projection_mnist, projection_usps, projection_svhn\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 12. STEP 3: TRAIN CLASSIFIER ON SINGLE DOMAIN ONLY\n",
    "# ============================================================================\n",
    "\n",
    "def train_classifier_single_domain(shared_encoder, projection, classifier,\n",
    "                                   train_loader, val_loader, domain_id,\n",
    "                                   domain_name, max_epochs=50, patience=10,\n",
    "                                   device='cpu'):\n",
    "    \"\"\"\n",
    "    Train classifier on ONE domain only.\n",
    "    Shared encoder and projection are FROZEN.\n",
    "    Classifier never sees other domains during training.\n",
    "\n",
    "    Latents are computed under no_grad as\n",
    "    projection(shared_encoder(sparse_image, mask, domain_id)); only the\n",
    "    classifier is optimised (Adam, cross-entropy). Validation accuracy drives\n",
    "    ReduceLROnPlateau (mode='max'), early stopping and best-checkpoint selection.\n",
    "\n",
    "    Args:\n",
    "        shared_encoder, projection: frozen feature extractor for `domain_id`;\n",
    "            their parameters are set to requires_grad=False and left that way.\n",
    "        classifier: module mapping projected Z to class logits (trained in place).\n",
    "        train_loader, val_loader: yield (x, s) pairs with x['sparse_image'],\n",
    "            x['mask'] and s['label'].\n",
    "        domain_id: selects the encoder's domain-specific BN.\n",
    "        domain_name: used for logging only.\n",
    "        max_epochs: maximum number of training epochs.\n",
    "        patience: epochs without validation improvement before stopping.\n",
    "        device: device that all modules and batches are moved to.\n",
    "\n",
    "    Returns:\n",
    "        The classifier with the weights of the best validation epoch restored.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the validation loader yields no samples.\n",
    "    \"\"\"\n",
    "    print(\"=\" * 70)\n",
    "    print(f\"STEP 3: TRAINING CLASSIFIER ON {domain_name} ONLY\")\n",
    "    print(f\"Classifier never sees other domains\")\n",
    "    print(\"=\" * 70 + \"\\n\")\n",
    "\n",
    "    for p in shared_encoder.parameters():\n",
    "        p.requires_grad = False\n",
    "    for p in projection.parameters():\n",
    "        p.requires_grad = False\n",
    "\n",
    "    shared_encoder.eval()\n",
    "    projection.eval()\n",
    "    shared_encoder.to(device)\n",
    "    projection.to(device)\n",
    "    classifier.to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(\n",
    "        classifier.parameters(), lr=1e-3, weight_decay=1e-4\n",
    "    )\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "        optimizer, mode='max', factor=0.5, patience=3\n",
    "    )\n",
    "\n",
    "    # -inf (not 0.0) so the first epoch is always checkpointed, even at 0% accuracy\n",
    "    best_acc   = float('-inf')\n",
    "    best_epoch = 0\n",
    "    no_improve = 0\n",
    "    best_state = None\n",
    "\n",
    "    for epoch in range(max_epochs):\n",
    "        classifier.train()\n",
    "        correct = total = 0\n",
    "\n",
    "        for x, s in train_loader:\n",
    "            sparse = x['sparse_image'].to(device)\n",
    "            mask   = x['mask'].to(device)\n",
    "            labels = s['label'].to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            with torch.no_grad():\n",
    "                z = shared_encoder(sparse, mask, domain_id)\n",
    "                z = projection(z)\n",
    "\n",
    "            logits = classifier(z)\n",
    "            loss   = F.cross_entropy(logits, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            _, pred = logits.max(1)\n",
    "            total   += labels.size(0)\n",
    "            correct += pred.eq(labels).sum().item()\n",
    "\n",
    "        # Guard against an empty training loader\n",
    "        train_acc = 100. * correct / total if total else 0.0\n",
    "\n",
    "        classifier.eval()\n",
    "        val_correct = val_total = 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for x, s in val_loader:\n",
    "                sparse = x['sparse_image'].to(device)\n",
    "                mask   = x['mask'].to(device)\n",
    "                labels = s['label'].to(device)\n",
    "\n",
    "                z = shared_encoder(sparse, mask, domain_id)\n",
    "                z = projection(z)\n",
    "                logits = classifier(z)\n",
    "\n",
    "                _, pred = logits.max(1)\n",
    "                val_total   += labels.size(0)\n",
    "                val_correct += pred.eq(labels).sum().item()\n",
    "\n",
    "        if val_total == 0:\n",
    "            raise ValueError(f'train_classifier_single_domain: empty validation loader for {domain_name}')\n",
    "        val_acc = 100. * val_correct / val_total\n",
    "\n",
    "        old_lr = optimizer.param_groups[0]['lr']\n",
    "        scheduler.step(val_acc)\n",
    "        new_lr = optimizer.param_groups[0]['lr']\n",
    "        if new_lr != old_lr:\n",
    "            print(f'  → LR reduced: {old_lr:.2e} → {new_lr:.2e}')\n",
    "\n",
    "        print(f'Epoch {epoch+1}/{max_epochs}: '\n",
    "              f'Train={train_acc:.2f}%, Val={val_acc:.2f}%')\n",
    "\n",
    "        if val_acc > best_acc:\n",
    "            best_acc   = val_acc\n",
    "            best_epoch = epoch + 1\n",
    "            no_improve = 0\n",
    "            # clone(): state_dict() tensors alias the live weights, so without it\n",
    "            # later optimizer steps would overwrite the \"best\" checkpoint.\n",
    "            best_state = {'classifier': {k: v.detach().clone()\n",
    "                                         for k, v in classifier.state_dict().items()}}\n",
    "            print(f'  ✓ NEW BEST!')\n",
    "        else:\n",
    "            no_improve += 1\n",
    "            if no_improve >= patience:\n",
    "                print(f'\\nEARLY STOPPING at epoch {epoch+1}')\n",
    "                break\n",
    "\n",
    "    if best_state:\n",
    "        classifier.load_state_dict(best_state['classifier'])\n",
    "        print(f\"✓ Restored best classifier from epoch {best_epoch}\\n\")\n",
    "\n",
    "    return classifier\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 13. EVALUATION\n",
    "# ============================================================================\n",
    "\n",
    "def evaluate_transfer(shared_encoder, projection, classifier,\n",
    "                      test_loader, domain_id,\n",
    "                      source_name, target_name, device='cpu'):\n",
    "    \"\"\"Evaluate transfer: source classifier → target domain\"\"\"\n",
    "    print(f\"\\n{'=' * 70}\")\n",
    "    print(f\"TRANSFER: {source_name} → {target_name}\")\n",
    "    print(f\"{'=' * 70}\")\n",
    "\n",
    "    shared_encoder.eval()\n",
    "    projection.eval()\n",
    "    classifier.eval()\n",
    "\n",
    "    shared_encoder.to(device)\n",
    "    projection.to(device)\n",
    "    classifier.to(device)\n",
    "\n",
    "    correct = total = 0\n",
    "    class_correct = [0] * 10\n",
    "    class_total   = [0] * 10\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x, s in test_loader:\n",
    "            sparse = x['sparse_image'].to(device)\n",
    "            mask   = x['mask'].to(device)\n",
    "            labels = s['label'].to(device)\n",
    "\n",
    "            z      = shared_encoder(sparse, mask, domain_id)\n",
    "            z      = projection(z)\n",
    "            logits = classifier(z)\n",
    "\n",
    "            _, pred = logits.max(1)\n",
    "            total   += labels.size(0)\n",
    "            correct += pred.eq(labels).sum().item()\n",
    "\n",
    "            for i in range(labels.size(0)):\n",
    "                tl = labels[i].item()\n",
    "                class_correct[tl] += (pred[i] == tl).item()\n",
    "                class_total[tl]   += 1\n",
    "\n",
    "    acc = 100. * correct / total\n",
    "    print(f\"Overall Accuracy: {acc:.2f}%\")\n",
    "    print(\"Per-Class:\")\n",
    "    for i in range(10):\n",
    "        if class_total[i] > 0:\n",
    "            print(f\"  Class {i}: {100.*class_correct[i]/class_total[i]:.2f}%\"\n",
    "                  f\" ({class_correct[i]}/{class_total[i]})\")\n",
    "    return acc\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 14. MAIN EXECUTION\n",
    "# ============================================================================\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    print(\"\\n\" + \"=\" * 70)\n",
    "    print(\"THREE-DOMAIN UNIVERSAL TRANSFER LEARNING\")\n",
    "    print(\"Domains: MNIST + USPS + SVHN (digits 0-9, semantically meaningful)\")\n",
    "    print(\"Architecture: Shared encoder + Domain-Specific BatchNorm\")\n",
    "    print(\"Same Z space for all domains — universal manifold by construction\")\n",
    "    print(\"Nine transfer evaluations (all source→target combinations)\")\n",
    "    print(\"=\" * 70 + \"\\n\")\n",
    "\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    print(f\"Using device: {device}\\n\")\n",
    "\n",
    "    SPARSITY        = 0.15\n",
    "    LATENT_CHANNELS = 128\n",
    "    BATCH_SIZE      = 64\n",
    "    SEED            = 42   # same value as split_dataset's default seed\n",
    "\n",
    "    # Seed torch so Restart & Run All reproduces the run: weight init, the\n",
    "    # WeightedRandomSampler / shuffling DataLoaders and any torch-based random\n",
    "    # masking draw from this RNG. torch.manual_seed also seeds every CUDA\n",
    "    # device; some cuDNN kernels may still be non-deterministic.\n",
    "    torch.manual_seed(SEED)\n",
    "\n",
    "    pad_and_repeat = PadAndRepeatChannels()\n",
    "\n",
    "    # ==================== TRANSFORMS ====================\n",
    "    # Every domain is brought to a 3×32×32 tensor scaled to [-1, 1].\n",
    "\n",
    "    half_gray = (0.5,)\n",
    "    half_rgb  = (0.5, 0.5, 0.5)\n",
    "\n",
    "    # MNIST: 28×28×1 → scale to [-1, 1] → pad to 32×32 → 3 channels\n",
    "    mnist_transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(half_gray, half_gray),\n",
    "        transforms.Lambda(pad_and_repeat),\n",
    "    ])\n",
    "\n",
    "    # USPS: USPSDataset already yields a [-1, 1] tensor (16×16×1),\n",
    "    # so only the pad-to-32×32 / 3-channel step is needed.\n",
    "    usps_transform = transforms.Lambda(pad_and_repeat)\n",
    "\n",
    "    # SVHN: already 32×32×3; only scale to [-1, 1].\n",
    "    svhn_transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                         transforms.Normalize(half_rgb, half_rgb)])\n",
    "\n",
    "    # ==================== LOAD DATASETS ====================\n",
    "    print(\"Loading datasets...\")\n",
    "    data_root = './data'\n",
    "\n",
    "    mnist_train_full = datasets.MNIST(data_root, train=True, download=True,\n",
    "                                      transform=mnist_transform)\n",
    "    mnist_test_full  = datasets.MNIST(data_root, train=False,\n",
    "                                      transform=mnist_transform)\n",
    "\n",
    "    # USPS goes through the manual LIBSVM loader (avoids the broken torchvision URL).\n",
    "    usps_train_full = USPSDataset(data_root, train=True,\n",
    "                                  transform=usps_transform, download=True)\n",
    "    usps_test_full  = USPSDataset(data_root, train=False,\n",
    "                                  transform=usps_transform, download=True)\n",
    "\n",
    "    svhn_train_full = datasets.SVHN(data_root, split='train', download=True,\n",
    "                                    transform=svhn_transform)\n",
    "    svhn_test_full  = datasets.SVHN(data_root, split='test', download=True,\n",
    "                                    transform=svhn_transform)\n",
    "\n",
    "    # NOTE(review): the *_test_full datasets are not used later in this cell\n",
    "    # (val/test are carved out of the training sets below); confirm no other\n",
    "    # cell relies on these files being downloaded before removing them.\n",
    "    print(\"✓ All datasets loaded\\n\")\n",
    "\n",
    "    # ==================== SPLIT ====================\n",
    "\n",
    "    def split_dataset(full_dataset, seed=42):\n",
    "        \"\"\"Split a dataset 90% / 5% / 5% into train / val / test Subsets.\n",
    "\n",
    "        The permutation comes from a private ``np.random.RandomState(seed)``,\n",
    "        which yields exactly the permutation of ``np.random.seed(seed)`` followed\n",
    "        by ``np.random.shuffle(idx)`` (so splits are unchanged), but without\n",
    "        re-seeding NumPy's global RNG as a hidden side effect.\n",
    "        \"\"\"\n",
    "        n = len(full_dataset)\n",
    "        train_size = int(0.90 * n)\n",
    "        val_size   = int(0.05 * n)\n",
    "        idx = list(range(n))\n",
    "        np.random.RandomState(seed).shuffle(idx)\n",
    "        return (Subset(full_dataset, idx[:train_size]),\n",
    "                Subset(full_dataset, idx[train_size:train_size+val_size]),\n",
    "                Subset(full_dataset, idx[train_size+val_size:]))\n",
    "\n",
    "    mnist_train, mnist_val, mnist_test = split_dataset(mnist_train_full)\n",
    "    usps_train,  usps_val,  usps_test  = split_dataset(usps_train_full)\n",
    "    svhn_train,  svhn_val,  svhn_test  = split_dataset(svhn_train_full)\n",
    "\n",
    "    # ==================== SPARSE DATASETS ====================\n",
    "\n",
    "    mnist_tr = SparseMultiDomainDataset(mnist_train, SPARSITY, DOMAIN_MNIST, 'MNIST-Train')\n",
    "    mnist_vl = SparseMultiDomainDataset(mnist_val,   SPARSITY, DOMAIN_MNIST, 'MNIST-Val')\n",
    "    mnist_te = SparseMultiDomainDataset(mnist_test,  SPARSITY, DOMAIN_MNIST, 'MNIST-Test')\n",
    "\n",
    "    usps_tr  = SparseMultiDomainDataset(usps_train,  SPARSITY, DOMAIN_USPS,  'USPS-Train')\n",
    "    usps_vl  = SparseMultiDomainDataset(usps_val,    SPARSITY, DOMAIN_USPS,  'USPS-Val')\n",
    "    usps_te  = SparseMultiDomainDataset(usps_test,   SPARSITY, DOMAIN_USPS,  'USPS-Test')\n",
    "\n",
    "    svhn_tr  = SparseMultiDomainDataset(svhn_train,  SPARSITY, DOMAIN_SVHN,  'SVHN-Train')\n",
    "    svhn_vl  = SparseMultiDomainDataset(svhn_val,    SPARSITY, DOMAIN_SVHN,  'SVHN-Val')\n",
    "    svhn_te  = SparseMultiDomainDataset(svhn_test,   SPARSITY, DOMAIN_SVHN,  'SVHN-Test')\n",
    "\n",
    "    # ==================== SVHN WEIGHTED SAMPLER ====================\n",
    "\n",
    "    def get_svhn_weights(dataset):\n",
    "        \"\"\"Inverse-class-frequency sample weights for the sparse SVHN train set.\n",
    "\n",
    "        Fast path: when ``dataset.base_dataset`` is a ``Subset`` over a dataset\n",
    "        exposing a ``labels`` array (torchvision SVHN does), labels are read\n",
    "        directly through ``Subset.indices``. Otherwise falls back to indexing\n",
    "        every sample, which runs the wrapper's full ``__getitem__`` (image load\n",
    "        and transform) per sample just to read the label — very slow for ~66k\n",
    "        images.\n",
    "        \"\"\"\n",
    "        subset = getattr(dataset, 'base_dataset', None)\n",
    "        source = getattr(subset, 'dataset', None)\n",
    "        if isinstance(subset, Subset) and hasattr(source, 'labels'):\n",
    "            # % 10 presumably mirrors the wrapper's label handling; it is a no-op\n",
    "            # for torchvision SVHN, whose labels are already 0-9.\n",
    "            labels = np.asarray(source.labels)[np.asarray(subset.indices)] % 10\n",
    "        else:\n",
    "            labels = [dataset[i][1]['label'] for i in range(len(dataset))]\n",
    "        labels = np.asarray(labels, dtype=np.int64)\n",
    "        counts = np.bincount(labels, minlength=10)\n",
    "        weights = 1.0 / (counts[labels] + 1e-6)\n",
    "        return torch.FloatTensor(weights)\n",
    "\n",
    "    print(\"Computing SVHN sample weights...\")\n",
    "    svhn_weights = get_svhn_weights(svhn_tr)\n",
    "    svhn_sampler = WeightedRandomSampler(svhn_weights, len(svhn_weights), replacement=True)\n",
    "    print(\"✓ Done\\n\")\n",
    "\n",
    "    # ==================== DATALOADERS ====================\n",
    "\n",
    "    mnist_train_loader = DataLoader(mnist_tr, BATCH_SIZE, shuffle=True,  num_workers=0)\n",
    "    mnist_val_loader   = DataLoader(mnist_vl, BATCH_SIZE, shuffle=False, num_workers=0)\n",
    "    mnist_test_loader  = DataLoader(mnist_te, BATCH_SIZE, shuffle=False, num_workers=0)\n",
    "\n",
    "    usps_train_loader  = DataLoader(usps_tr, BATCH_SIZE, shuffle=True,  num_workers=0)\n",
    "    usps_val_loader    = DataLoader(usps_vl, BATCH_SIZE, shuffle=False, num_workers=0)\n",
    "    usps_test_loader   = DataLoader(usps_te, BATCH_SIZE, shuffle=False, num_workers=0)\n",
    "\n",
    "    # SVHN is class-imbalanced, so its train loader samples by inverse frequency.\n",
    "    svhn_train_loader  = DataLoader(svhn_tr, BATCH_SIZE, sampler=svhn_sampler, num_workers=0)\n",
    "    svhn_val_loader    = DataLoader(svhn_vl, BATCH_SIZE, shuffle=False, num_workers=0)\n",
    "    svhn_test_loader   = DataLoader(svhn_te, BATCH_SIZE, shuffle=False, num_workers=0)\n",
    "\n",
    "    # ==================== BUILD MODELS ====================\n",
    "    # Module constructors draw initial weights from torch's RNG, so the order of\n",
    "    # construction (here and for the classifiers in STEP 3) affects reproducibility.\n",
    "\n",
    "    shared_encoder = SharedUniversalEncoder(\n",
    "        latent_channels=LATENT_CHANNELS, num_res_blocks=3, num_domains=NUM_DOMAINS\n",
    "    )\n",
    "    # One decoder per domain; the encoder is shared by all three.\n",
    "    decoder_mnist = DomainSpecificDecoder(LATENT_CHANNELS, num_res_blocks=2)\n",
    "    decoder_usps  = DomainSpecificDecoder(LATENT_CHANNELS, num_res_blocks=2)\n",
    "    decoder_svhn  = DomainSpecificDecoder(LATENT_CHANNELS, num_res_blocks=2)\n",
    "\n",
    "    # ==================== STEP 1 ====================\n",
    "\n",
    "    print(\"\\n\" + \"=\" * 70)\n",
    "    print(\"STEP 1: JOINT PRE-TRAINING — SHARED ENCODER + DOMAIN-SPECIFIC BN\")\n",
    "    print(\"=\" * 70 + \"\\n\")\n",
    "\n",
    "    # Joint pre-training on all three domains; returns the trained encoder and\n",
    "    # the three decoders (rebinding the names created above).\n",
    "    shared_encoder, decoder_mnist, decoder_usps, decoder_svhn = \\\n",
    "        pretrain_shared_encoder(\n",
    "            shared_encoder,\n",
    "            decoder_mnist, decoder_usps, decoder_svhn,\n",
    "            mnist_train_loader, usps_train_loader, svhn_train_loader,\n",
    "            mnist_val_loader,   usps_val_loader,   svhn_val_loader,\n",
    "            max_epochs=100, patience=15, device=device\n",
    "        )\n",
    "\n",
    "    # ==================== STEP 2 ====================\n",
    "\n",
    "    print(\"\\n\" + \"=\" * 70)\n",
    "    print(\"STEP 2: THREE-WAY PROJECTION + ALIGNMENT\")\n",
    "    print(\"=\" * 70 + \"\\n\")\n",
    "\n",
    "    # One projection head per domain, trained to align the three domains' latents;\n",
    "    # only the projections are returned (encoder/decoders are passed in as inputs).\n",
    "    projection_mnist = DomainInvariantProjection(LATENT_CHANNELS, num_res_blocks=2)\n",
    "    projection_usps  = DomainInvariantProjection(LATENT_CHANNELS, num_res_blocks=2)\n",
    "    projection_svhn  = DomainInvariantProjection(LATENT_CHANNELS, num_res_blocks=2)\n",
    "\n",
    "    projection_mnist, projection_usps, projection_svhn = \\\n",
    "        train_three_way_alignment(\n",
    "            shared_encoder,\n",
    "            projection_mnist, projection_usps, projection_svhn,\n",
    "            decoder_mnist, decoder_usps, decoder_svhn,\n",
    "            mnist_train_loader, usps_train_loader, svhn_train_loader,\n",
    "            mnist_val_loader,   usps_val_loader,   svhn_val_loader,\n",
    "            max_epochs=100, patience=15, device=device\n",
    "        )\n",
    "\n",
    "    # ==================== STEP 3: THREE CLASSIFIERS ====================\n",
    "    # Each classifier sees only its own domain. train_classifier_single_domain\n",
    "    # optimises classifier.parameters() only and runs the encoder/projection under\n",
    "    # torch.no_grad(), so the shared manifold is not updated in this step.\n",
    "\n",
    "    classifier_mnist = SpatialClassifier(LATENT_CHANNELS, 10)\n",
    "    classifier_usps  = SpatialClassifier(LATENT_CHANNELS, 10)\n",
    "    classifier_svhn  = SpatialClassifier(LATENT_CHANNELS, 10)\n",
    "\n",
    "    print(\"\\n\" + \"=\" * 70)\n",
    "    print(\"STEP 3A: CLASSIFIER ON MNIST ONLY\")\n",
    "    print(\"=\" * 70)\n",
    "    classifier_mnist = train_classifier_single_domain(\n",
    "        shared_encoder, projection_mnist, classifier_mnist,\n",
    "        mnist_train_loader, mnist_val_loader,\n",
    "        DOMAIN_MNIST, \"MNIST\", max_epochs=50, patience=10, device=device\n",
    "    )\n",
    "\n",
    "    print(\"\\n\" + \"=\" * 70)\n",
    "    print(\"STEP 3B: CLASSIFIER ON USPS ONLY\")\n",
    "    print(\"=\" * 70)\n",
    "    classifier_usps = train_classifier_single_domain(\n",
    "        shared_encoder, projection_usps, classifier_usps,\n",
    "        usps_train_loader, usps_val_loader,\n",
    "        DOMAIN_USPS, \"USPS\", max_epochs=50, patience=10, device=device\n",
    "    )\n",
    "\n",
    "    print(\"\\n\" + \"=\" * 70)\n",
    "    print(\"STEP 3C: CLASSIFIER ON SVHN ONLY\")\n",
    "    print(\"=\" * 70)\n",
    "    classifier_svhn = train_classifier_single_domain(\n",
    "        shared_encoder, projection_svhn, classifier_svhn,\n",
    "        svhn_train_loader, svhn_val_loader,\n",
    "        DOMAIN_SVHN, \"SVHN\", max_epochs=50, patience=10, device=device\n",
    "    )\n",
    "\n",
    "    # ==================== STEP 4: NINE EVALUATIONS ====================\n",
    "\n",
    "    print(\"\\n\" + \"=\" * 70)\n",
    "    print(\"STEP 4: ALL NINE TRANSFER EVALUATIONS\")\n",
    "    print(\"=\" * 70 + \"\\n\")\n",
    "\n",
    "    # MNIST classifier → all targets\n",
    "    r_mm = evaluate_transfer(shared_encoder, projection_mnist, classifier_mnist,\n",
    "                              mnist_test_loader, DOMAIN_MNIST, \"MNIST Clf\", \"MNIST\", device)\n",
    "    r_mu = evaluate_transfer(shared_encoder, projection_usps,  classifier_mnist,\n",
    "                              usps_test_loader,  DOMAIN_USPS,  \"MNIST Clf\", \"USPS (TRANSFER)\", device)\n",
    "    r_ms = evaluate_transfer(shared_encoder, projection_svhn,  classifier_mnist,\n",
    "                              svhn_test_loader,  DOMAIN_SVHN,  \"MNIST Clf\", \"SVHN (TRANSFER)\", device)\n",
    "\n",
    "    # USPS classifier → all targets\n",
    "    r_um = evaluate_transfer(shared_encoder, projection_mnist, classifier_usps,\n",
    "                              mnist_test_loader, DOMAIN_MNIST, \"USPS Clf\", \"MNIST (TRANSFER)\", device)\n",
    "    r_uu = evaluate_transfer(shared_encoder, projection_usps,  classifier_usps,\n",
    "                              usps_test_loader,  DOMAIN_USPS,  \"USPS Clf\", \"USPS\", device)\n",
    "    r_us = evaluate_transfer(shared_encoder, projection_svhn,  classifier_usps,\n",
    "                              svhn_test_loader,  DOMAIN_SVHN,  \"USPS Clf\", \"SVHN (TRANSFER)\", device)\n",
    "\n",
    "    # SVHN classifier → all targets\n",
    "    r_sm = evaluate_transfer(shared_encoder, projection_mnist, classifier_svhn,\n",
    "                              mnist_test_loader, DOMAIN_MNIST, \"SVHN Clf\", \"MNIST (TRANSFER)\", device)\n",
    "    r_su = evaluate_transfer(shared_encoder, projection_usps,  classifier_svhn,\n",
    "                              usps_test_loader,  DOMAIN_USPS,  \"SVHN Clf\", \"USPS (TRANSFER)\", device)\n",
    "    r_ss = evaluate_transfer(shared_encoder, projection_svhn,  classifier_svhn,\n",
    "                              svhn_test_loader,  DOMAIN_SVHN,  \"SVHN Clf\", \"SVHN\", device)\n",
    "\n",
    "    # ==================== FINAL SUMMARY ====================\n",
    "\n",
    "    print(\"\\n\" + \"=\" * 70)\n",
    "    print(\"FINAL RESULTS — THREE-DOMAIN UNIVERSAL TRANSFER\")\n",
    "    print(\"All domains share semantic meaning: digits 0-9\")\n",
    "    print(\"Shared encoder + Domain-Specific BN → same Z space\")\n",
    "    print(\"=\" * 70)\n",
    "    print(f\"\"\"\n",
    "┌──────────────────┬──────────┬──────────┬──────────┐\n",
    "│ Classifier →     │  MNIST   │   USPS   │   SVHN   │\n",
    "│ Target Domain ↓  │          │          │          │\n",
    "├──────────────────┼──────────┼──────────┼──────────┤\n",
    "│ MNIST            │ {r_mm:6.2f}%  │ {r_um:6.2f}%  │ {r_sm:6.2f}%  │\n",
    "│ USPS             │ {r_mu:6.2f}%  │ {r_uu:6.2f}%  │ {r_su:6.2f}%  │\n",
    "│ SVHN             │ {r_ms:6.2f}%  │ {r_us:6.2f}%  │ {r_ss:6.2f}%  │\n",
    "└──────────────────┴──────────┴──────────┴──────────┘\n",
    "Diagonal  = same domain baseline\n",
    "Off-diag  = zero-shot cross-domain transfer\n",
    "    \"\"\")\n",
    "\n",
    "    # ==================== SAVE ====================\n",
    "\n",
    "    same_domain_acc = {'mm': r_mm, 'uu': r_uu, 'ss': r_ss}\n",
    "    cross_domain_acc = {\n",
    "        'mnist_to_usps': r_mu, 'mnist_to_svhn': r_ms,\n",
    "        'usps_to_mnist': r_um, 'usps_to_svhn':  r_us,\n",
    "        'svhn_to_mnist': r_sm, 'svhn_to_usps':  r_su,\n",
    "    }\n",
    "    # Key insertion order fixes the layout of the written JSON file.\n",
    "    results = {\n",
    "        'architecture':    'shared_encoder_domain_specific_BN',\n",
    "        'sparsity':        SPARSITY,\n",
    "        'latent_channels': LATENT_CHANNELS,\n",
    "        'same_domain':     same_domain_acc,\n",
    "        'cross_domain':    cross_domain_acc,\n",
    "    }\n",
    "\n",
    "    with open('three_domain_domainBN_results.json', 'w') as results_file:\n",
    "        json.dump(results, results_file, indent=2)\n",
    "\n",
    "    # Encoder and per-domain projections share one checkpoint file.\n",
    "    manifold_state = {\n",
    "        'shared_encoder':   shared_encoder.state_dict(),\n",
    "        'projection_mnist': projection_mnist.state_dict(),\n",
    "        'projection_usps':  projection_usps.state_dict(),\n",
    "        'projection_svhn':  projection_svhn.state_dict(),\n",
    "    }\n",
    "    torch.save(manifold_state, 'three_domain_domainBN_manifold.pth')\n",
    "\n",
    "    # Each classifier gets its own state_dict file.\n",
    "    for clf_module, clf_path in (\n",
    "        (classifier_mnist, 'three_domain_domainBN_clf_mnist.pth'),\n",
    "        (classifier_usps,  'three_domain_domainBN_clf_usps.pth'),\n",
    "        (classifier_svhn,  'three_domain_domainBN_clf_svhn.pth'),\n",
    "    ):\n",
    "        torch.save(clf_module.state_dict(), clf_path)\n",
    "\n",
    "    print(\"✓ All results and models saved!\")\n",
    "    print(\"  - three_domain_domainBN_results.json\")\n",
    "    print(\"  - three_domain_domainBN_manifold.pth\")\n",
    "    print(\"  - three_domain_domainBN_clf_mnist/usps/svhn.pth\")\n",
    "    print(\"=\" * 70)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6c589639",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✓ ripser available — β₀ and β₁ will use persistent homology\n",
      "\n",
      "======================================================================\n",
      "TOPOLOGICAL EVALUATION — THREE-DOMAIN UNIVERSAL MANIFOLD\n",
      "Z_total = [Z_mnist || Z_usps || Z_svhn] as ONE unified space\n",
      "Sparsity: ρ = 0.15, Sample: 15% per domain\n",
      "All metrics use cosine distance\n",
      "======================================================================\n",
      "\n",
      "Device: cuda\n",
      "\n",
      "Loading datasets...\n",
      "  MNIST:  10500 samples (15% of 70000)\n",
      "  USPS:   1394 samples (15% of 9298)\n",
      "  SVHN:   14893 samples (15% of 99289)\n",
      "✓ Datasets ready\n",
      "\n",
      "Loading three_domain_domainBN_manifold.pth...\n",
      "✓ Models loaded\n",
      "\n",
      "Encoding Z_mnist...\n",
      "  Z_mnist shape: (10500, 2048)\n",
      "Encoding Z_usps...\n",
      "  Z_usps shape: (1394, 2048)\n",
      "Encoding Z_svhn...\n",
      "  Z_svhn shape: (14893, 2048)\n",
      "\n",
      "Z_total shape: (26787, 2048)\n",
      "  MNIST:  10500 samples\n",
      "  USPS:   1394 samples\n",
      "  SVHN:   14893 samples\n",
      "  Total:  26787 samples\n",
      "\n",
      "======================================================================\n",
      "COMPUTING TOPOLOGICAL METRICS ON Z_total\n",
      "======================================================================\n",
      "\n",
      "[1/5] Computing β₀ and β₁ (persistent homology)...\n",
      "  Computing persistent homology on 2000 points...\n",
      "  β₀ = 1  (should be 1 — one connected manifold)\n",
      "  β₁ = 153  (number of loops in manifold topology)\n",
      "\n",
      "[2/5] Computing Trust Score (κ=5, cosine)...\n",
      "  Trust = 0.9218  (threshold ≥ 0.80) ✓ PASS\n",
      "\n",
      "[3/5] Computing Sliced Wasserstein-2 on Z_total (cosine)...\n",
      "  W₂ = 0.0007  (threshold ≤ 0.30) ✓ PASS\n",
      "\n",
      "[4/5] Computing Continuity (κ=5, cosine)...\n",
      "  Continuity = 0.9218  (threshold ≥ 0.70) ✓ PASS\n",
      "\n",
      "[5/5] Computing Cross-Domain Alignment Error (cosine)...\n",
      "  Alignment = 0.0189  (threshold ≤ 0.30) ✓ PASS\n",
      "\n",
      "======================================================================\n",
      "FINAL RESULTS — THREE-DOMAIN UNIVERSAL MANIFOLD\n",
      "Z_total = [Z_mnist || Z_usps || Z_svhn], ρ=0.15\n",
      "======================================================================\n",
      "\n",
      "┌─────────────────────────────┬──────────┬───────────┬────────┐\n",
      "│ Metric                      │  Value   │ Threshold │  Pass  │\n",
      "├─────────────────────────────┼──────────┼───────────┼────────┤\n",
      "│ β₀ (connected components)   │    1     │   = 1     │  ✓     │\n",
      "│ β₁ (loops/holes)            │   153    │    —      │  —     │\n",
      "│ Trust Score τ_t             │  0.9218  │  ≥ 0.80   │  ✓     │\n",
      "│ Sliced W₂ τ_w               │  0.0007  │  ≤ 0.30   │  ✓     │\n",
      "│ Continuity τ_c              │  0.9218  │  ≥ 0.70   │  ✓     │\n",
      "│ Alignment Error τ_a         │  0.0189  │  ≤ 0.30   │  ✓     │\n",
      "└─────────────────────────────┴──────────┴───────────┴────────┘\n",
      "    \n",
      "======================================================================\n",
      "LaTeX Table:\n",
      "======================================================================\n",
      "\\begin{table}[!htpb]\n",
      "\\centering\n",
      "\\caption{Topological Unification Verification — Three-Domain Universal Manifold\n",
      "         ($\\kappa=5$, Cosine distance, $\\rho=0.15$)}\n",
      "\\label{tab:three_domain_topology}\n",
      "\\begin{tabular}{lcc}\n",
      "\\hline\n",
      "\\textbf{Metric} & \\textbf{Value} & \\textbf{Threshold} \\\\\n",
      "\\hline\n",
      "$\\beta_0$ (connected components) & 1 & $= 1$ \\\\\n",
      "$\\beta_1$ (loops/holes) & 153 & — \\\\\n",
      "Trust $\\tau_t$ & 0.9218 & $\\geq 0.80$ \\\\\n",
      "Sliced $W_2$ $\\tau_w$ & 0.0007 & $\\leq 0.30$ \\\\\n",
      "Continuity $\\tau_c$ & 0.9218 & $\\geq 0.70$ \\\\\n",
      "Alignment $\\tau_a$ & 0.0189 & $\\leq 0.30$ \\\\\n",
      "\\hline\n",
      "\\end{tabular}\n",
      "\\end{table}\n",
      "======================================================================\n",
      "\n",
      "✓ Results saved to: three_domain_topology_results.json\n",
      "======================================================================\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Topological Evaluation of Three-Domain Universal Manifold\n",
    "==========================================================\n",
    "\n",
    "Loads: three_domain_domainBN_manifold.pth\n",
    "       three_domain_domainBN_clf_mnist/usps/svhn.pth\n",
    "\n",
    "Evaluates the ENTIRE unified Z space as ONE manifold:\n",
    "  Z_total = [Z_mnist || Z_usps || Z_svhn]\n",
    "\n",
    "Metrics (all using cosine distance):\n",
    "  - β₀  : connected components (persistent homology)\n",
    "  - β₁  : loops/holes         (persistent homology, subsampled)\n",
    "  - W₂  : sliced Wasserstein-2 on Z_total\n",
    "  - Trust Score\n",
    "  - Continuity\n",
    "  - Alignment Error (cross-domain centroid consistency)\n",
    "\n",
    "Sparsity: ρ = 0.15 only\n",
    "Sample:   15% of each domain's full dataset\n",
    "\"\"\"\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader, Dataset, Subset\n",
    "import numpy as np\n",
    "import json\n",
    "import os\n",
    "import urllib.request\n",
    "import bz2\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from sklearn.metrics.pairwise import cosine_distances\n",
    "from scipy.stats import wasserstein_distance\n",
    "from scipy.sparse import csr_matrix\n",
    "from scipy.sparse.csgraph import connected_components\n",
    "\n",
    "# Try importing ripser for persistent homology\n",
    "try:\n",
    "    from ripser import ripser\n",
    "    RIPSER_AVAILABLE = True\n",
    "    print(\"✓ ripser available — β₀ and β₁ will use persistent homology\")\n",
    "except ImportError:\n",
    "    RIPSER_AVAILABLE = False\n",
    "    print(\"⚠ ripser not available — β₀ via graph components, β₁ skipped\")\n",
    "    print(\"  Install with: pip install ripser\")\n",
    "\n",
    "# Domain ids index the per-domain BatchNorm ModuleLists (e.g. bn1[domain_id]).\n",
    "DOMAIN_MNIST = 0\n",
    "DOMAIN_USPS  = 1\n",
    "DOMAIN_SVHN  = 2\n",
    "NUM_DOMAINS  = 3     # number of per-domain BatchNorm copies in each layer\n",
    "SPARSITY     = 0.15  # ρ; presumably passed as SparseMultiDomainDataset's per-pixel keep probability\n",
    "SAMPLE_FRAC  = 0.15  # 15% of each domain\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 1. USPS MANUAL DOWNLOAD (same as training script)\n",
    "# ============================================================================\n",
    "\n",
    "class USPSDataset(Dataset):\n",
    "    URLS = {\n",
    "        'train': 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2',\n",
    "        'test':  'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2'\n",
    "    }\n",
    "\n",
    "    def __init__(self, root, train=True, transform=None, download=True):\n",
    "        self.root = os.path.join(root, 'usps_manual')\n",
    "        self.train = train\n",
    "        self.transform = transform\n",
    "        os.makedirs(self.root, exist_ok=True)\n",
    "\n",
    "        split = 'train' if train else 'test'\n",
    "        filepath = os.path.join(self.root, f'usps_{split}.bz2')\n",
    "\n",
    "        if download and not os.path.exists(filepath):\n",
    "            print(f'Downloading USPS {split}...')\n",
    "            urllib.request.urlretrieve(self.URLS[split], filepath)\n",
    "            print(f'✓ Downloaded USPS {split}')\n",
    "\n",
    "        self.data, self.labels = self._load_libsvm(filepath)\n",
    "\n",
    "    def _load_libsvm(self, filepath):\n",
    "        data, labels = [], []\n",
    "        with bz2.open(filepath, 'rt') as f:\n",
    "            for line in f:\n",
    "                parts = line.strip().split()\n",
    "                if not parts:\n",
    "                    continue\n",
    "                labels.append(int(parts[0]) % 10)\n",
    "                features = np.zeros(256, dtype=np.float32)\n",
    "                for feat in parts[1:]:\n",
    "                    if ':' in feat:\n",
    "                        idx, val = feat.split(':')\n",
    "                        features[int(idx) - 1] = float(val)\n",
    "                img = features.reshape(16, 16)\n",
    "                img = (img + 1.0) / 2.0\n",
    "                data.append(img)\n",
    "        return np.array(data, dtype=np.float32), np.array(labels, dtype=np.int64)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img = torch.FloatTensor(self.data[idx]).unsqueeze(0)\n",
    "        img = img * 2.0 - 1.0\n",
    "        label = int(self.labels[idx])\n",
    "        if self.transform:\n",
    "            img = self.transform(img)\n",
    "        return img, label\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 2. PREPROCESSING (same as training script)\n",
    "# ============================================================================\n",
    "\n",
    "class PadAndRepeatChannels:\n",
    "    def __call__(self, x):\n",
    "        c, h, w = x.shape\n",
    "        pad_h  = (32 - h) // 2\n",
    "        pad_w  = (32 - w) // 2\n",
    "        pad_h2 = 32 - h - pad_h\n",
    "        pad_w2 = 32 - w - pad_w\n",
    "        x = F.pad(x, [pad_w, pad_w2, pad_h, pad_h2], value=0)\n",
    "        x = x.repeat(3, 1, 1)\n",
    "        return x\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 3. SPARSE DATASET\n",
    "# ============================================================================\n",
    "\n",
    "class SparseMultiDomainDataset(Dataset):\n",
    "    def __init__(self, base_dataset, sparsity_level=0.15,\n",
    "                 domain_id=0, domain_name='Unknown'):\n",
    "        self.base_dataset   = base_dataset\n",
    "        self.sparsity_level = sparsity_level\n",
    "        self.domain_id      = domain_id\n",
    "        self.domain_name    = domain_name\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.base_dataset)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        full_image, label = self.base_dataset[idx]\n",
    "        spatial_mask = (torch.rand(1, 32, 32) < self.sparsity_level).float()\n",
    "        mask         = spatial_mask.repeat(3, 1, 1)\n",
    "        sparse_image = full_image * mask\n",
    "        label        = int(label) % 10\n",
    "\n",
    "        return {\n",
    "            'sparse_image': sparse_image,\n",
    "            'mask':         mask,\n",
    "            'full_image':   full_image,\n",
    "            'label':        label,\n",
    "            'domain':       self.domain_id\n",
    "        }\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 4. MODEL ARCHITECTURES (must match training script exactly)\n",
    "# ============================================================================\n",
    "\n",
    "class ResidualBlockDomainBN(nn.Module):\n",
    "    \"\"\"Pre-activation residual block with shared convs and per-domain BatchNorm.\n",
    "\n",
    "    ``bn1[d]`` / ``bn2[d]`` are the BatchNorm layers for domain ``d``; the two\n",
    "    3×3 convolutions are shared by every domain. Attribute names must match\n",
    "    the training script so saved state_dict keys load.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, channels, num_domains=NUM_DOMAINS):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)\n",
    "        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)\n",
    "        self.bn1   = nn.ModuleList(nn.BatchNorm2d(channels) for _ in range(num_domains))\n",
    "        self.bn2   = nn.ModuleList(nn.BatchNorm2d(channels) for _ in range(num_domains))\n",
    "\n",
    "    def forward(self, x, domain_id):\n",
    "        # BN → ReLU → conv, twice, using the requested domain's BN branch.\n",
    "        branch = self.conv1(F.relu(self.bn1[domain_id](x)))\n",
    "        branch = self.conv2(F.relu(self.bn2[domain_id](branch)))\n",
    "        return branch + x\n",
    "\n",
    "\n",
    "class ResidualBlock(nn.Module):\n",
    "    \"\"\"Pre-activation residual block (BN -> ReLU -> 3x3 conv, twice) with an\n",
    "    identity skip. Attribute names are state_dict keys of the saved checkpoint.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, channels):\n",
    "        super().__init__()\n",
    "        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(channels)\n",
    "        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)\n",
    "        self.bn2 = nn.BatchNorm2d(channels)\n",
    "        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return x plus the two-stage pre-activation convolution branch.\"\"\"\n",
    "        branch = self.conv1(F.relu(self.bn1(x)))\n",
    "        branch = self.conv2(F.relu(self.bn2(branch)))\n",
    "        return branch + x\n",
    "\n",
    "\n",
    "class SharedUniversalEncoder(nn.Module):\n",
    "    \"\"\"Convolutional encoder shared by all domains, with per-domain BatchNorm.\n",
    "\n",
    "    The input is the sparse image concatenated with its mask along the channel\n",
    "    axis; the first convolution expects 6 channels in total. Three stride-2\n",
    "    convolutions (32 -> 64 -> 128 -> 256 channels) downsample the feature map,\n",
    "    and a 1x1 bottleneck maps 256 channels to `latent_channels`. Each stage\n",
    "    uses `num_res_blocks` ResidualBlockDomainBN blocks. Every BatchNorm is an\n",
    "    nn.ModuleList indexed by `domain_id`; all convolution weights are shared.\n",
    "\n",
    "    Attribute names and ModuleList indices are the checkpoint's state_dict\n",
    "    keys, so they must stay exactly as they are.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, latent_channels=128, num_res_blocks=3, num_domains=NUM_DOMAINS):\n",
    "        super().__init__()\n",
    "        self.latent_channels = latent_channels\n",
    "        self.num_domains     = num_domains\n",
    "\n",
    "        def per_domain_bn(width):\n",
    "            # One BatchNorm2d per domain, selected by domain_id in forward().\n",
    "            return nn.ModuleList([nn.BatchNorm2d(width) for _ in range(num_domains)])\n",
    "\n",
    "        def res_stack(width):\n",
    "            return nn.ModuleList([ResidualBlockDomainBN(width, num_domains)\n",
    "                                  for _ in range(num_res_blocks)])\n",
    "\n",
    "        self.input_conv = nn.Conv2d(6, 32, 3, padding=1, bias=False)\n",
    "        self.input_bn   = per_domain_bn(32)\n",
    "\n",
    "        self.res1     = res_stack(32)\n",
    "        self.down1    = nn.Conv2d(32, 64, 3, stride=2, padding=1, bias=False)\n",
    "        self.down1_bn = per_domain_bn(64)\n",
    "\n",
    "        self.res2     = res_stack(64)\n",
    "        self.down2    = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)\n",
    "        self.down2_bn = per_domain_bn(128)\n",
    "\n",
    "        self.res3     = res_stack(128)\n",
    "        self.down3    = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)\n",
    "        self.down3_bn = per_domain_bn(256)\n",
    "\n",
    "        self.res4 = res_stack(256)\n",
    "\n",
    "        self.bottleneck     = nn.Conv2d(256, latent_channels, 1, bias=False)\n",
    "        self.bottleneck_bn  = per_domain_bn(latent_channels)\n",
    "        self.res_bottleneck = res_stack(latent_channels)\n",
    "\n",
    "    def forward(self, sparse_image, mask, domain_id):\n",
    "        \"\"\"Encode (sparse_image, mask) using the BatchNorms of `domain_id`.\"\"\"\n",
    "        h = torch.cat([sparse_image, mask], dim=1)\n",
    "        h = F.relu(self.input_bn[domain_id](self.input_conv(h)))\n",
    "        # Each stage: residual blocks, then conv -> domain BN -> ReLU.\n",
    "        stages = (\n",
    "            (self.res1, self.down1, self.down1_bn),\n",
    "            (self.res2, self.down2, self.down2_bn),\n",
    "            (self.res3, self.down3, self.down3_bn),\n",
    "            (self.res4, self.bottleneck, self.bottleneck_bn),\n",
    "        )\n",
    "        for blocks, conv, bns in stages:\n",
    "            for block in blocks:\n",
    "                h = block(h, domain_id)\n",
    "            h = F.relu(bns[domain_id](conv(h)))\n",
    "        for block in self.res_bottleneck:\n",
    "            h = block(h, domain_id)\n",
    "        return h\n",
    "\n",
    "\n",
    "class DomainInvariantProjection(nn.Module):\n",
    "    \"\"\"Projection head applied to the shared encoder's output.\n",
    "\n",
    "    `num_res_blocks` ResidualBlocks, then a 1x1 conv (with bias) and a\n",
    "    BatchNorm; channel count is unchanged. All layers live in one\n",
    "    nn.Sequential named `projection` so state_dict keys (projection.0..., etc.)\n",
    "    match the saved checkpoint.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, latent_channels=128, num_res_blocks=2):\n",
    "        super().__init__()\n",
    "        layers = [ResidualBlock(latent_channels) for _ in range(num_res_blocks)]\n",
    "        layers.append(nn.Conv2d(latent_channels, latent_channels, 1))\n",
    "        layers.append(nn.BatchNorm2d(latent_channels))\n",
    "        self.projection = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, z):\n",
    "        \"\"\"Apply the projection stack to the latent map z.\"\"\"\n",
    "        return self.projection(z)\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 5. ENCODE ONE DOMAIN → FLAT LATENT CODES\n",
    "# ============================================================================\n",
    "\n",
    "def encode_domain(shared_encoder, projection, dataset,\n",
    "                  domain_id, device='cpu', batch_size=64):\n",
    "    \"\"\"\n",
    "    Encode every sample of `dataset` with `shared_encoder` then `projection`.\n",
    "\n",
    "    Both modules are put in eval mode and run under torch.no_grad(); only the\n",
    "    inputs are moved to `device` (the modules are assumed to be there already).\n",
    "    Each batch must be a dict with 'sparse_image', 'mask' and 'label' entries.\n",
    "\n",
    "    Returns:\n",
    "        codes:  numpy array [N, C*H*W]; every row has unit L2 norm, so dot\n",
    "                products between rows are cosine similarities.\n",
    "        labels: numpy int array [N] of the labels taken modulo 10.\n",
    "    \"\"\"\n",
    "    for module in (shared_encoder, projection):\n",
    "        module.eval()\n",
    "\n",
    "    loader = DataLoader(dataset, batch_size=batch_size,\n",
    "                        shuffle=False, num_workers=0)\n",
    "\n",
    "    code_chunks, label_list = [], []\n",
    "    with torch.no_grad():\n",
    "        for batch in loader:\n",
    "            latent = projection(\n",
    "                shared_encoder(batch['sparse_image'].to(device),\n",
    "                               batch['mask'].to(device),\n",
    "                               domain_id)\n",
    "            )\n",
    "            # [B, C, H, W] -> [B, C*H*W], then unit L2 norm per row.\n",
    "            flat = F.normalize(latent.view(latent.size(0), -1), p=2, dim=1)\n",
    "            code_chunks.append(flat.cpu().numpy())\n",
    "            label_list += [int(lbl) % 10 for lbl in batch['label']]\n",
    "\n",
    "    return np.vstack(code_chunks), np.array(label_list)\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 6. TOPOLOGICAL METRICS\n",
    "# ============================================================================\n",
    "\n",
    "def compute_betti_numbers(latent_codes, subsample=2000, max_dim=1,\n",
    "                          beta1_percentile=75, seed=None):\n",
    "    \"\"\"\n",
    "    Estimate β₀ and β₁ of the latent point cloud under cosine distance.\n",
    "\n",
    "    Both numbers are computed on the same random subsample of at most\n",
    "    `subsample` points.\n",
    "\n",
    "    β₀ = connected components of the symmetrised kNN graph\n",
    "         (compute_betti_0_graph, κ=10). It is deliberately NOT read from the\n",
    "         ripser H0 diagram: with ripser's default infinite threshold the full\n",
    "         Rips filtration always ends connected, so the number of H0 bars that\n",
    "         never die is exactly 1 for any non-empty input and the β₀ = 1 check\n",
    "         could never fail. The kNN graph also makes the ripser and fallback\n",
    "         paths report the same quantity.\n",
    "    β₁ = number of ripser H1 bars whose persistence is strictly above the\n",
    "         `beta1_percentile` percentile of all H1 persistences. This is a\n",
    "         relative noise filter: by construction it keeps at most about\n",
    "         (100 - beta1_percentile)% of the bars, so read it as a descriptive\n",
    "         count, not a test for significant loops.\n",
    "\n",
    "    Args:\n",
    "        latent_codes: array [N, D].\n",
    "        subsample: maximum number of points used.\n",
    "        max_dim: highest homology dimension passed to ripser.\n",
    "        beta1_percentile: percentile used for the β₁ persistence filter.\n",
    "        seed: if given, draw the subsample from np.random.default_rng(seed);\n",
    "              otherwise use the global np.random state.\n",
    "\n",
    "    Returns:\n",
    "        (beta_0, beta_1); beta_1 is None when ripser is unavailable.\n",
    "    \"\"\"\n",
    "    N = len(latent_codes)\n",
    "    if N > subsample:\n",
    "        rng = np.random if seed is None else np.random.default_rng(seed)\n",
    "        idx = rng.choice(N, subsample, replace=False)\n",
    "        codes = latent_codes[idx]\n",
    "    else:\n",
    "        codes = latent_codes\n",
    "\n",
    "    # β₀ from kNN-graph connectivity (meaningful with or without ripser).\n",
    "    beta_0 = compute_betti_0_graph(codes)\n",
    "\n",
    "    if not RIPSER_AVAILABLE:\n",
    "        return beta_0, None\n",
    "\n",
    "    print(f\"  Computing persistent homology on {len(codes)} points...\")\n",
    "\n",
    "    # Cosine distance matrix, consumed directly by ripser.\n",
    "    dist_matrix = cosine_distances(codes).astype(np.float32)\n",
    "    diagrams = ripser(dist_matrix, maxdim=max_dim, distance_matrix=True)['dgms']\n",
    "\n",
    "    if len(diagrams) < 2 or len(diagrams[1]) == 0:\n",
    "        return beta_0, 0\n",
    "\n",
    "    h1 = diagrams[1]\n",
    "    persistence = h1[:, 1] - h1[:, 0]\n",
    "    threshold = np.percentile(persistence, beta1_percentile)\n",
    "    beta_1 = int(np.sum(persistence > threshold))\n",
    "    return beta_0, beta_1\n",
    "\n",
    "\n",
    "def compute_betti_0_graph(latent_codes, kappa=10):\n",
    "    \"\"\"\n",
    "    β₀ estimate: number of connected components of the symmetrised\n",
    "    κ-nearest-neighbour graph built with cosine distance.\n",
    "\n",
    "    Each point is linked to the neighbours returned by kneighbors after its\n",
    "    first column (normally the point itself); edges are made undirected\n",
    "    before the components are counted.\n",
    "    \"\"\"\n",
    "    n_points = len(latent_codes)\n",
    "    knn = NearestNeighbors(\n",
    "        n_neighbors=kappa + 1, metric='cosine',\n",
    "        algorithm='brute', n_jobs=-1\n",
    "    ).fit(latent_codes)\n",
    "    neighbor_idx = knn.kneighbors(latent_codes)[1][:, 1:]\n",
    "\n",
    "    # Vectorised edge list: i -> j and j -> i for every kNN pair.\n",
    "    src = np.repeat(np.arange(n_points), neighbor_idx.shape[1])\n",
    "    dst = neighbor_idx.ravel()\n",
    "    rows = np.concatenate([src, dst])\n",
    "    cols = np.concatenate([dst, src])\n",
    "\n",
    "    adjacency = csr_matrix((np.ones(rows.size), (rows, cols)),\n",
    "                           shape=(n_points, n_points))\n",
    "    n_components = connected_components(adjacency, directed=False)[0]\n",
    "    return int(n_components)\n",
    "\n",
    "\n",
    "def compute_trust_score(latent_codes, labels, kappa=5):\n",
    "    \"\"\"\n",
    "    kNN label agreement in Z under cosine distance ('trust score').\n",
    "\n",
    "    For every point, take its `kappa` nearest neighbours in Z (the point\n",
    "    itself excluded) and compute the fraction that share its label; return\n",
    "    the mean over all points. 1.0 means every neighbourhood is label-pure.\n",
    "\n",
    "    Self-exclusion uses kneighbors() with no query argument, which makes\n",
    "    scikit-learn drop each point's own index even when duplicate codes tie\n",
    "    at distance 0. Slicing off column 0 of kneighbors(X) can instead drop a\n",
    "    duplicate and keep the point itself, inflating the score.\n",
    "\n",
    "    Args:\n",
    "        latent_codes: array [N, D].\n",
    "        labels: array-like of N labels.\n",
    "        kappa: neighbours per point; must be smaller than N.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are not more than `kappa` points.\n",
    "    \"\"\"\n",
    "    labels = np.asarray(labels)\n",
    "    N = len(latent_codes)\n",
    "    if N <= kappa:\n",
    "        raise ValueError(f\"need more than kappa={kappa} points, got {N}\")\n",
    "\n",
    "    nn_model = NearestNeighbors(\n",
    "        n_neighbors=kappa, metric='cosine',\n",
    "        algorithm='brute', n_jobs=-1\n",
    "    )\n",
    "    nn_model.fit(latent_codes)\n",
    "    _, neighbors = nn_model.kneighbors()  # query = fitted points, self excluded\n",
    "\n",
    "    # Every row has exactly kappa neighbours, so the global mean equals the\n",
    "    # mean of the per-point agreement fractions.\n",
    "    return float(np.mean(labels[neighbors] == labels[:, None]))\n",
    "\n",
    "\n",
    "def compute_sliced_wasserstein_cosine(X, n_projections=200, seed=None):\n",
    "    \"\"\"\n",
    "    Split-half sliced Wasserstein-2 distance of one point cloud.\n",
    "\n",
    "    The rows of X are randomly permuted and split into two equal halves (one\n",
    "    row is dropped when N is odd), so the result does not depend on the input\n",
    "    row order. Both halves are projected onto `n_projections` random unit\n",
    "    directions, and\n",
    "\n",
    "        SW₂ = sqrt( mean over θ of W₂²(θ·X₁, θ·X₂) ),\n",
    "\n",
    "    where for two equal-size 1-D samples W₂² is the mean squared difference\n",
    "    of their sorted values. This is a true W₂; scipy's wasserstein_distance\n",
    "    would give W₁ instead.\n",
    "\n",
    "    Interpretation: this measures how consistent two random halves of the\n",
    "    same distribution are (a sampling-stability check). It shrinks as N\n",
    "    grows and is not by itself a measure of uniformity or manifold coverage.\n",
    "\n",
    "    Args:\n",
    "        X: array [N, D]; rows are expected to be L2-normalised (not enforced).\n",
    "        n_projections: number of random projection directions.\n",
    "        seed: if given, use np.random.default_rng(seed); otherwise the\n",
    "              global np.random state.\n",
    "\n",
    "    Returns:\n",
    "        float, the split-half sliced W₂.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if X has fewer than 2 rows.\n",
    "    \"\"\"\n",
    "    X = np.asarray(X)\n",
    "    half = len(X) // 2\n",
    "    if half == 0:\n",
    "        raise ValueError(\"need at least 2 samples for a split-half distance\")\n",
    "    rng = np.random if seed is None else np.random.default_rng(seed)\n",
    "\n",
    "    order = rng.permutation(len(X))\n",
    "    X1 = X[order[:half]]\n",
    "    X2 = X[order[half:2 * half]]\n",
    "\n",
    "    # Random unit directions, one per column: [D, n_projections].\n",
    "    theta = rng.standard_normal((X.shape[1], n_projections))\n",
    "    theta /= np.linalg.norm(theta, axis=0, keepdims=True) + 1e-8\n",
    "\n",
    "    # Equal-size 1-D W2: compare sorted projections element-wise.\n",
    "    p1 = np.sort(X1 @ theta, axis=0)\n",
    "    p2 = np.sort(X2 @ theta, axis=0)\n",
    "    w2_sq = np.mean((p1 - p2) ** 2, axis=0)\n",
    "    return float(np.sqrt(np.mean(w2_sq)))\n",
    "\n",
    "\n",
    "def compute_continuity(latent_codes, labels, kappa=5):\n",
    "    \"\"\"\n",
    "    kNN label agreement in Z under cosine distance, reported as 'continuity'.\n",
    "\n",
    "    For every point, the fraction of its `kappa` nearest neighbours in Z\n",
    "    (itself excluded) that share its label, averaged over all points.\n",
    "\n",
    "    NOTE(review): this is the same kNN label-agreement quantity as\n",
    "    compute_trust_score, so the two reported metrics are not independent\n",
    "    evidence. Classical continuity (Venna & Kaski) compares neighbourhoods in\n",
    "    an input space with those in the embedding and would need a reference\n",
    "    space other than class labels.\n",
    "\n",
    "    Self-exclusion uses kneighbors() with no query argument, so duplicate\n",
    "    codes at distance 0 cannot displace the point's own index.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are not more than `kappa` points.\n",
    "    \"\"\"\n",
    "    labels = np.asarray(labels)\n",
    "    N = len(latent_codes)\n",
    "    if N <= kappa:\n",
    "        raise ValueError(f\"need more than kappa={kappa} points, got {N}\")\n",
    "\n",
    "    nn_model = NearestNeighbors(\n",
    "        n_neighbors=kappa, metric='cosine',\n",
    "        algorithm='brute', n_jobs=-1\n",
    "    )\n",
    "    nn_model.fit(latent_codes)\n",
    "    _, neighbors = nn_model.kneighbors()  # query = fitted points, self excluded\n",
    "\n",
    "    # Equal row lengths: global mean == mean of per-point fractions.\n",
    "    return float(np.mean(labels[neighbors] == labels[:, None]))\n",
    "\n",
    "\n",
    "def compute_alignment_error(z_mnist, z_usps, z_svhn,\n",
    "                             labels_mnist, labels_usps, labels_svhn):\n",
    "    \"\"\"\n",
    "    Mean cosine distance between per-class centroids of the three domains.\n",
    "\n",
    "    For each class label c in 0..9 that occurs in all three domains, the\n",
    "    class centroid (mean latent code) is taken per domain and the three\n",
    "    pairwise cosine distances (MNIST-USPS, MNIST-SVHN, USPS-SVHN) are\n",
    "    averaged. The result is the mean of these per-class values. Classes\n",
    "    missing from any domain are skipped; if all are skipped the result is\n",
    "    NaN (np.mean of an empty list).\n",
    "\n",
    "    Lower values mean the domains' class centroids point in more similar\n",
    "    directions in Z.\n",
    "    \"\"\"\n",
    "    domains = ((z_mnist, labels_mnist), (z_usps, labels_usps), (z_svhn, labels_svhn))\n",
    "    pairs = ((0, 1), (0, 2), (1, 2))\n",
    "    per_class = []\n",
    "\n",
    "    for c in range(10):\n",
    "        masks = [lbl == c for _, lbl in domains]\n",
    "        if not all(m.sum() for m in masks):\n",
    "            continue\n",
    "        centroids = [z[m].mean(axis=0, keepdims=True)\n",
    "                     for (z, _), m in zip(domains, masks)]\n",
    "        pair_dists = [float(cosine_distances(centroids[a], centroids[b])[0, 0])\n",
    "                      for a, b in pairs]\n",
    "        per_class.append(sum(pair_dists) / 3.0)\n",
    "\n",
    "    return float(np.mean(per_class))\n",
    "\n",
    "\n",
    "# ============================================================================\n",
    "# 7. MAIN EVALUATION\n",
    "# ============================================================================\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # Evaluation entry point: encode a random subsample of every domain with\n",
    "    # the trained shared encoder + per-domain projections, stack the codes into\n",
    "    # one Z_total matrix and compute topology / neighbourhood metrics on it.\n",
    "    print(\"\\n\" + \"=\" * 70)\n",
    "    print(\"TOPOLOGICAL EVALUATION — THREE-DOMAIN UNIVERSAL MANIFOLD\")\n",
    "    print(\"Z_total = [Z_mnist || Z_usps || Z_svhn] as ONE unified space\")\n",
    "    print(f\"Sparsity: ρ = {SPARSITY}, Sample: {SAMPLE_FRAC*100:.0f}% per domain\")\n",
    "    print(\"All metrics use cosine distance\")\n",
    "    print(\"=\" * 70 + \"\\n\")\n",
    "\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    print(f\"Device: {device}\\n\")\n",
    "\n",
    "    # Latent width; must match the checkpoint loaded further down.\n",
    "    LATENT_CHANNELS = 128\n",
    "\n",
    "    # ==================== LOAD DATASETS ====================\n",
    "\n",
    "    # PadAndRepeatChannels is defined elsewhere in this file; presumably it pads\n",
    "    # grayscale digits to 32x32 and repeats them to 3 channels (TODO confirm).\n",
    "    pad_and_repeat = PadAndRepeatChannels()\n",
    "\n",
    "    # MNIST: PIL image -> tensor in [0, 1] -> scaled to [-1, 1] -> pad_and_repeat.\n",
    "    mnist_transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5,), (0.5,)),\n",
    "        transforms.Lambda(pad_and_repeat)\n",
    "    ])\n",
    "    # NOTE(review): USPS only gets pad_and_repeat, i.e. USPSDataset is assumed to\n",
    "    # already yield a tensor on the same scale as the other domains - confirm.\n",
    "    usps_transform = transforms.Lambda(pad_and_repeat)\n",
    "    # SVHN: no padding / channel repeat, only scaling to [-1, 1].\n",
    "    svhn_transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "    ])\n",
    "\n",
    "    print(\"Loading datasets...\")\n",
    "\n",
    "    # Train and test splits are concatenated, so the evaluated samples are not\n",
    "    # restricted to held-out data.\n",
    "    mnist_full = torch.utils.data.ConcatDataset([\n",
    "        datasets.MNIST('./data', train=True,  download=True, transform=mnist_transform),\n",
    "        datasets.MNIST('./data', train=False, download=True, transform=mnist_transform)\n",
    "    ])\n",
    "    usps_full = torch.utils.data.ConcatDataset([\n",
    "        USPSDataset('./data', train=True,  transform=usps_transform, download=True),\n",
    "        USPSDataset('./data', train=False, transform=usps_transform, download=True)\n",
    "    ])\n",
    "    svhn_full = torch.utils.data.ConcatDataset([\n",
    "        datasets.SVHN('./data', split='train', download=True, transform=svhn_transform),\n",
    "        datasets.SVHN('./data', split='test',  download=True, transform=svhn_transform)\n",
    "    ])\n",
    "\n",
    "    # Random SAMPLE_FRAC subsample per domain (without replacement, at least 1).\n",
    "    # NOTE: np.random.seed resets the *global* NumPy RNG on every call, so all\n",
    "    # later np.random draws in this script are determined by this seed as well.\n",
    "    def subsample(dataset, frac=SAMPLE_FRAC, seed=42):\n",
    "        n = len(dataset)\n",
    "        k = max(1, int(n * frac))\n",
    "        np.random.seed(seed)\n",
    "        idx = np.random.choice(n, k, replace=False)\n",
    "        return Subset(dataset, idx)\n",
    "\n",
    "    mnist_sub = subsample(mnist_full)\n",
    "    usps_sub  = subsample(usps_full)\n",
    "    svhn_sub  = subsample(svhn_full)\n",
    "\n",
    "    print(f\"  MNIST:  {len(mnist_sub)} samples ({SAMPLE_FRAC*100:.0f}% of {len(mnist_full)})\")\n",
    "    print(f\"  USPS:   {len(usps_sub)} samples ({SAMPLE_FRAC*100:.0f}% of {len(usps_full)})\")\n",
    "    print(f\"  SVHN:   {len(svhn_sub)} samples ({SAMPLE_FRAC*100:.0f}% of {len(svhn_full)})\")\n",
    "\n",
    "    # Wrap each subset so items carry a random visibility mask at SPARSITY\n",
    "    # (presumably drawn per __getitem__ with torch.rand, as in the dataset class).\n",
    "    mnist_sparse = SparseMultiDomainDataset(mnist_sub, SPARSITY, DOMAIN_MNIST, 'MNIST')\n",
    "    usps_sparse  = SparseMultiDomainDataset(usps_sub,  SPARSITY, DOMAIN_USPS,  'USPS')\n",
    "    svhn_sparse  = SparseMultiDomainDataset(svhn_sub,  SPARSITY, DOMAIN_SVHN,  'SVHN')\n",
    "\n",
    "    print(\"✓ Datasets ready\\n\")\n",
    "\n",
    "    # ==================== LOAD MODELS ====================\n",
    "\n",
    "    print(\"Loading three_domain_domainBN_manifold.pth...\")\n",
    "\n",
    "    # Architecture hyper-parameters must equal those used at training time,\n",
    "    # otherwise load_state_dict raises on missing/unexpected keys or shapes.\n",
    "    shared_encoder   = SharedUniversalEncoder(LATENT_CHANNELS, num_res_blocks=3)\n",
    "    projection_mnist = DomainInvariantProjection(LATENT_CHANNELS, num_res_blocks=2)\n",
    "    projection_usps  = DomainInvariantProjection(LATENT_CHANNELS, num_res_blocks=2)\n",
    "    projection_svhn  = DomainInvariantProjection(LATENT_CHANNELS, num_res_blocks=2)\n",
    "\n",
    "    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.\n",
    "    ckpt = torch.load('three_domain_domainBN_manifold.pth', map_location=device)\n",
    "    shared_encoder.load_state_dict(ckpt['shared_encoder'])\n",
    "    projection_mnist.load_state_dict(ckpt['projection_mnist'])\n",
    "    projection_usps.load_state_dict(ckpt['projection_usps'])\n",
    "    projection_svhn.load_state_dict(ckpt['projection_svhn'])\n",
    "\n",
    "    shared_encoder.to(device)\n",
    "    projection_mnist.to(device)\n",
    "    projection_usps.to(device)\n",
    "    projection_svhn.to(device)\n",
    "\n",
    "    print(\"✓ Models loaded\\n\")\n",
    "\n",
    "    # ==================== ENCODE ALL THREE DOMAINS ====================\n",
    "\n",
    "    # The sparse datasets draw a fresh random visibility mask (torch.rand) on\n",
    "    # every __getitem__, so seed torch here; otherwise the encoded codes and\n",
    "    # every metric below change from run to run.\n",
    "    torch.manual_seed(42)\n",
    "\n",
    "    print(\"Encoding Z_mnist...\")\n",
    "    z_mnist, labels_mnist = encode_domain(\n",
    "        shared_encoder, projection_mnist,\n",
    "        mnist_sparse, DOMAIN_MNIST, device\n",
    "    )\n",
    "    print(f\"  Z_mnist shape: {z_mnist.shape}\")\n",
    "\n",
    "    print(\"Encoding Z_usps...\")\n",
    "    z_usps, labels_usps = encode_domain(\n",
    "        shared_encoder, projection_usps,\n",
    "        usps_sparse, DOMAIN_USPS, device\n",
    "    )\n",
    "    print(f\"  Z_usps shape: {z_usps.shape}\")\n",
    "\n",
    "    print(\"Encoding Z_svhn...\")\n",
    "    z_svhn, labels_svhn = encode_domain(\n",
    "        shared_encoder, projection_svhn,\n",
    "        svhn_sparse, DOMAIN_SVHN, device\n",
    "    )\n",
    "    print(f\"  Z_svhn shape: {z_svhn.shape}\")\n",
    "\n",
    "    # ==================== COMBINE INTO Z_TOTAL ====================\n",
    "\n",
    "    # Rows are grouped by domain: MNIST first, then USPS, then SVHN.\n",
    "    Z_total      = np.vstack([z_mnist, z_usps, z_svhn])\n",
    "    labels_total = np.concatenate([labels_mnist, labels_usps, labels_svhn])\n",
    "    # Per-row domain ids (kept for inspection; not used by the metrics below).\n",
    "    domains_total = np.array(\n",
    "        [DOMAIN_MNIST] * len(z_mnist) +\n",
    "        [DOMAIN_USPS]  * len(z_usps)  +\n",
    "        [DOMAIN_SVHN]  * len(z_svhn)\n",
    "    )\n",
    "\n",
    "    print(f\"\\nZ_total shape: {Z_total.shape}\")\n",
    "    print(f\"  MNIST:  {len(z_mnist)} samples\")\n",
    "    print(f\"  USPS:   {len(z_usps)} samples\")\n",
    "    print(f\"  SVHN:   {len(z_svhn)} samples\")\n",
    "    print(f\"  Total:  {len(Z_total)} samples\\n\")\n",
    "\n",
    "    # ==================== COMPUTE METRICS ====================\n",
    "\n",
    "    # Pass/fail thresholds here (0.80, 0.30, 0.70, 0.30) are hard-coded and\n",
    "    # repeated in the summary and LaTeX tables.\n",
    "    results = {}\n",
    "    kappa = 5  # neighbourhood size for the kNN-based metrics\n",
    "\n",
    "    print(\"=\" * 70)\n",
    "    print(\"COMPUTING TOPOLOGICAL METRICS ON Z_total\")\n",
    "    print(\"=\" * 70)\n",
    "\n",
    "    # --- β₀ and β₁ ---\n",
    "    # Computed on a random subsample of at most 2000 rows of Z_total.\n",
    "    print(\"\\n[1/5] Computing β₀ and β₁ (persistent homology)...\")\n",
    "    beta_0, beta_1 = compute_betti_numbers(Z_total, subsample=2000, max_dim=1)\n",
    "    results['beta_0'] = beta_0\n",
    "    # beta_1 is None when ripser is unavailable; stored as the string 'N/A'.\n",
    "    results['beta_1'] = beta_1 if beta_1 is not None else 'N/A'\n",
    "    print(f\"  β₀ = {beta_0}  (should be 1 — one connected manifold)\")\n",
    "    if beta_1 is not None:\n",
    "        print(f\"  β₁ = {beta_1}  (number of loops in manifold topology)\")\n",
    "    else:\n",
    "        print(f\"  β₁ = N/A (install ripser for β₁)\")\n",
    "\n",
    "    # --- Trust Score ---\n",
    "    # Fraction of each point's kappa cosine neighbours that share its label.\n",
    "    print(f\"\\n[2/5] Computing Trust Score (κ={kappa}, cosine)...\")\n",
    "    trust = compute_trust_score(Z_total, labels_total, kappa=kappa)\n",
    "    results['trust_score'] = trust\n",
    "    status = \"✓ PASS\" if trust >= 0.80 else \"✗ FAIL\"\n",
    "    print(f\"  Trust = {trust:.4f}  (threshold ≥ 0.80) {status}\")\n",
    "\n",
    "    # --- Sliced Wasserstein ---\n",
    "    print(\"\\n[3/5] Computing Sliced Wasserstein-2 on Z_total (cosine)...\")\n",
    "    # Random subset of at most 10000 rows. Z_total is ordered by domain, so the\n",
    "    # random index order also shuffles rows before the split-half comparison.\n",
    "    max_w2 = min(10000, len(Z_total))\n",
    "    idx_w2 = np.random.choice(len(Z_total), max_w2, replace=False)\n",
    "    w2 = compute_sliced_wasserstein_cosine(Z_total[idx_w2], n_projections=200)\n",
    "    results['sliced_w2'] = w2\n",
    "    status = \"✓ PASS\" if w2 <= 0.30 else \"✗ FAIL\"\n",
    "    print(f\"  W₂ = {w2:.4f}  (threshold ≤ 0.30) {status}\")\n",
    "\n",
    "    # --- Continuity ---\n",
    "    # NOTE(review): compute_continuity measures the same kNN label agreement\n",
    "    # as the trust score above, so this is not an independent check.\n",
    "    print(f\"\\n[4/5] Computing Continuity (κ={kappa}, cosine)...\")\n",
    "    cont = compute_continuity(Z_total, labels_total, kappa=kappa)\n",
    "    results['continuity'] = cont\n",
    "    status = \"✓ PASS\" if cont >= 0.70 else \"✗ FAIL\"\n",
    "    print(f\"  Continuity = {cont:.4f}  (threshold ≥ 0.70) {status}\")\n",
    "\n",
    "    # --- Alignment Error ---\n",
    "    # Mean pairwise cosine distance between the domains' per-class centroids.\n",
    "    print(\"\\n[5/5] Computing Cross-Domain Alignment Error (cosine)...\")\n",
    "    align = compute_alignment_error(\n",
    "        z_mnist, z_usps, z_svhn,\n",
    "        labels_mnist, labels_usps, labels_svhn\n",
    "    )\n",
    "    results['alignment_error'] = align\n",
    "    status = \"✓ PASS\" if align <= 0.30 else \"✗ FAIL\"\n",
    "    print(f\"  Alignment = {align:.4f}  (threshold ≤ 0.30) {status}\")\n",
    "\n",
    "    # ==================== SUMMARY ====================\n",
    "\n",
    "    print(\"\\n\" + \"=\" * 70)\n",
    "    print(\"FINAL RESULTS — THREE-DOMAIN UNIVERSAL MANIFOLD\")\n",
    "    print(f\"Z_total = [Z_mnist || Z_usps || Z_svhn], ρ={SPARSITY}\")\n",
    "    print(\"=\" * 70)\n",
    "\n",
    "    # Console table; pass marks re-apply the same hard-coded thresholds.\n",
    "    print(f\"\"\"\n",
    "┌─────────────────────────────┬──────────┬───────────┬────────┐\n",
    "│ Metric                      │  Value   │ Threshold │  Pass  │\n",
    "├─────────────────────────────┼──────────┼───────────┼────────┤\n",
    "│ β₀ (connected components)   │ {results['beta_0']:^8} │   = 1     │  {'✓' if results['beta_0']==1 else '✗'}     │\n",
    "│ β₁ (loops/holes)            │ {str(results['beta_1']):^8} │    —      │  —     │\n",
    "│ Trust Score τ_t             │ {results['trust_score']:^8.4f} │  ≥ 0.80   │  {'✓' if results['trust_score']>=0.80 else '✗'}     │\n",
    "│ Sliced W₂ τ_w               │ {results['sliced_w2']:^8.4f} │  ≤ 0.30   │  {'✓' if results['sliced_w2']<=0.30 else '✗'}     │\n",
    "│ Continuity τ_c              │ {results['continuity']:^8.4f} │  ≥ 0.70   │  {'✓' if results['continuity']>=0.70 else '✗'}     │\n",
    "│ Alignment Error τ_a         │ {results['alignment_error']:^8.4f} │  ≤ 0.30   │  {'✓' if results['alignment_error']<=0.30 else '✗'}     │\n",
    "└─────────────────────────────┴──────────┴───────────┴────────┘\n",
    "    \"\"\")\n",
    "\n",
    "    # ==================== LATEX TABLE ====================\n",
    "\n",
    "    print(\"=\" * 70)\n",
    "    print(\"LaTeX Table:\")\n",
    "    print(\"=\" * 70)\n",
    "    print(r\"\\begin{table}[!htpb]\")\n",
    "    print(r\"\\centering\")\n",
    "    print(r\"\\caption{Topological Unification Verification — Three-Domain Universal Manifold\")\n",
    "    # Caption values come from this run's settings instead of literals, so the\n",
    "    # table stays correct when kappa or SPARSITY change ('}}' is a literal '}').\n",
    "    print(rf\"         ($\\kappa={kappa}$, Cosine distance, $\\rho={SPARSITY}$)}}\")\n",
    "    print(r\"\\label{tab:three_domain_topology}\")\n",
    "    print(r\"\\begin{tabular}{lcc}\")\n",
    "    print(r\"\\hline\")\n",
    "    print(r\"\\textbf{Metric} & \\textbf{Value} & \\textbf{Threshold} \\\\\")\n",
    "    print(r\"\\hline\")\n",
    "    print(f\"$\\\\beta_0$ (connected components) & {results['beta_0']} & $= 1$ \\\\\\\\\")\n",
    "    if results['beta_1'] != 'N/A':\n",
    "        print(f\"$\\\\beta_1$ (loops/holes) & {results['beta_1']} & — \\\\\\\\\")\n",
    "    print(f\"Trust $\\\\tau_t$ & {results['trust_score']:.4f} & $\\\\geq 0.80$ \\\\\\\\\")\n",
    "    print(f\"Sliced $W_2$ $\\\\tau_w$ & {results['sliced_w2']:.4f} & $\\\\leq 0.30$ \\\\\\\\\")\n",
    "    print(f\"Continuity $\\\\tau_c$ & {results['continuity']:.4f} & $\\\\geq 0.70$ \\\\\\\\\")\n",
    "    print(f\"Alignment $\\\\tau_a$ & {results['alignment_error']:.4f} & $\\\\leq 0.30$ \\\\\\\\\")\n",
    "    print(r\"\\hline\")\n",
    "    print(r\"\\end{tabular}\")\n",
    "    print(r\"\\end{table}\")\n",
    "    print(\"=\" * 70)\n",
    "\n",
    "    # ==================== SAVE ====================\n",
    "\n",
    "    # Record the run configuration next to the metric values.\n",
    "    results['sparsity']     = SPARSITY\n",
    "    results['sample_frac']  = SAMPLE_FRAC\n",
    "    results['n_mnist']      = len(z_mnist)\n",
    "    results['n_usps']       = len(z_usps)\n",
    "    results['n_svhn']       = len(z_svhn)\n",
    "    results['n_total']      = len(Z_total)\n",
    "    results['latent_dim']   = Z_total.shape[1]\n",
    "    results['distance']     = 'cosine'\n",
    "\n",
    "    # NOTE: results['beta_1'] is the string 'N/A' when ripser is missing, and a\n",
    "    # NaN metric (e.g. alignment error when no class occurs in all three domains)\n",
    "    # would be written as a bare NaN token (json.dump's default allow_nan=True),\n",
    "    # which strict JSON parsers reject.\n",
    "    with open('three_domain_topology_results.json', 'w') as f:\n",
    "        json.dump(results, f, indent=2)\n",
    "\n",
    "    print(\"\\n✓ Results saved to: three_domain_topology_results.json\")\n",
    "    print(\"=\" * 70)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d051b7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "manitorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
