{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training PointNet on Hot Rolling with Deep CORAL\n",
    "This tutorial notebook shows how to set up the SIMSHIFT pipeline. It will cover:\n",
    "- datasets\n",
    "- PointNet training\n",
    "- domain adaptation with DeepCORAL\n",
    "- model evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# set before CUDA is initialized: deterministic cuBLAS behavior and GPU selection\n",
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import sys\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random seed set as 42\n"
     ]
    }
   ],
   "source": [
    "from utils import set_seed\n",
    "\n",
    "# for reproducibility\n",
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Datasets and domains\n",
    "First, let's load our datasets. Running the next cell downloads the rolling dataset from [HuggingFace](https://huggingface.co/datasets/simshift/SIMSHIFT_data) and creates training, validation, and test splits for the source and target domains.\n",
    "\n",
    "Samples are loaded into RAM, since the datasets are of moderate size (especially the rolling one). Normalization is done per field (PEEQ, stress, etc.), with either `zscore` or `minmax`; the statistics computed on the training split are passed on to the validation and test splits. We can also choose the domain shift `difficulty` from `[\"easy\", \"medium\", \"hard\"]`; here we use `medium`."
   ]
  },
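  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The normalization code of the `data` module is not shown in this notebook. The following is a minimal sketch of per-field `zscore`/`minmax` normalization, assuming one shift/scale pair per field, pooled over that field's channels (the actual implementation may, for example, normalize each channel separately). The helpers `fit_stats`, `normalize`, and `denormalize` and the hard-coded channel layout are hypothetical; the slices mirror those printed by `train_src.channels` below. The statistics are fitted on training data only and reused for other splits, as `normalization_stats` is in the next cell.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# hypothetical channel layout mirroring train_src.channels (10 output channels)\n",
    "channels = {'nodes_LE': slice(0, 4), 'nodes_PEEQ': slice(4, 5),\n",
    "            'nodes_mises_stress': slice(5, 6), 'nodes_stresses': slice(6, 10)}\n",
    "\n",
    "\n",
    "def fit_stats(y_train, channels, method='zscore'):\n",
    "    # one (shift, scale) pair per field, over all training nodes and the field's channels\n",
    "    stats = {}\n",
    "    for name, sl in channels.items():\n",
    "        field = y_train[:, sl]\n",
    "        if method == 'zscore':\n",
    "            stats[name] = (field.mean(), field.std())\n",
    "        else:  # 'minmax': map the training range to [0, 1]\n",
    "            stats[name] = (field.min(), field.max() - field.min())\n",
    "    return stats\n",
    "\n",
    "\n",
    "def normalize(y, stats, channels):\n",
    "    out = y.clone()\n",
    "    for name, sl in channels.items():\n",
    "        shift, scale = stats[name]\n",
    "        out[:, sl] = (y[:, sl] - shift) / scale\n",
    "    return out\n",
    "\n",
    "\n",
    "def denormalize(y, stats, channels):\n",
    "    # inverse of normalize: back to physical units\n",
    "    out = y.clone()\n",
    "    for name, sl in channels.items():\n",
    "        shift, scale = stats[name]\n",
    "        out[:, sl] = y[:, sl] * scale + shift\n",
    "    return out\n",
    "\n",
    "\n",
    "# usage: fit on training nodes, reuse the statistics for validation nodes\n",
    "y_train, y_val = torch.rand(1000, 10), torch.rand(200, 10)\n",
    "stats = fit_stats(y_train, channels)\n",
    "y_val_norm = normalize(y_val, stats, channels)\n",
    "y_val_back = denormalize(y_val_norm, stats, channels)  # matches y_val up to float error\n",
    "```\n",
    "\n",
    "Reusing the training statistics keeps all splits in the same normalized space, so losses computed on validation and test data are directly comparable to the training loss."
   ]
  },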
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading data (rolling, split=train, domain=src): 100%|██████████| 1750/1750 [00:16<00:00, 104.62it/s]\n",
      "Loading data (rolling, split=train, domain=tgt): 100%|██████████| 625/625 [00:04<00:00, 138.04it/s]\n",
      "Loading data (rolling, split=val, domain=src): 100%|██████████| 868/868 [00:06<00:00, 127.02it/s]\n",
      "Loading data (rolling, split=val, domain=tgt): 0it [00:00, ?it/s]\n",
      "Loading data (rolling, split=test, domain=src): 100%|██████████| 882/882 [00:05<00:00, 152.33it/s]\n",
      "Loading data (rolling, split=test, domain=tgt): 100%|██████████| 625/625 [00:03<00:00, 166.62it/s]\n"
     ]
    }
   ],
   "source": [
    "from data import get_rolling_dataset\n",
    "\n",
    "(train_src, train_tgt), normalization_stats = get_rolling_dataset(\n",
    "    split=\"train\",\n",
    "    difficulty=\"medium\",\n",
    "    normalization_method=\"zscore\",\n",
    ")\n",
    "\n",
    "(val_src, val_tgt), _ = get_rolling_dataset(\n",
    "    split=\"val\",\n",
    "    normalization_stats=normalization_stats,\n",
    "    difficulty=\"medium\",\n",
    "    normalization_method=\"zscore\",\n",
    ")\n",
    "\n",
    "(test_src, test_tgt), _ = get_rolling_dataset(\n",
    "    split=\"test\",\n",
    "    normalization_stats=normalization_stats,\n",
    "    difficulty=\"medium\",\n",
    "    normalization_method=\"zscore\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can inspect the dataset metadata, such as the channel slices used to select individual fields from the data arrays:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset channels:\n",
      "nodes_LE slice(np.int64(0), np.int64(4), None)\n",
      "nodes_PEEQ slice(np.int64(4), np.int64(5), None)\n",
      "nodes_mises_stress slice(np.int64(5), np.int64(6), None)\n",
      "nodes_stresses slice(np.int64(6), np.int64(10), None)\n"
     ]
    }
   ],
   "source": [
    "print(\"Dataset channels:\")\n",
    "for k, v in train_src.channels.items():\n",
    "    print(k, v)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Computing the (normalized) mean PEEQ of a single sample is then as simple as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean normalized PEEQ of sample 500: -0.70.\n"
     ]
    }
   ],
   "source": [
    "sample = train_src[500]\n",
    "peeq = sample.y[:, train_src.channels[\"nodes_PEEQ\"]]\n",
    "print(f\"Mean normalized PEEQ of sample 500: {torch.mean(peeq):.2f}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get the value in physical (denormalized) units, we first need to denormalize the sample:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean PEEQ of sample 500: 0.0028\n"
     ]
    }
   ],
   "source": [
    "sample_denormalized = train_src.denormalize(None, sample.y)\n",
    "peeq_denormalized = sample_denormalized[:, train_src.channels[\"nodes_PEEQ\"]]\n",
    "print(f\"Mean PEEQ of sample 500: {torch.mean(peeq_denormalized):.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's create the dataloaders:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# shuffle the training sets and drop incomplete last batches\n",
    "trainloader_src = DataLoader(train_src, batch_size=16, shuffle=True, collate_fn=train_src.collate, drop_last=True)\n",
    "trainloader_tgt = DataLoader(train_tgt, batch_size=16, shuffle=True, collate_fn=train_tgt.collate, drop_last=True)\n",
    "\n",
    "valloader_src = DataLoader(val_src, batch_size=16, collate_fn=val_src.collate)\n",
    "# val_tgt is empty (len(val_tgt) == 0): in the UDA setting, target labels are not available\n",
    "# for validation, so this loader only exists for API compatibility with the Trainer\n",
    "valloader_tgt = DataLoader(val_tgt, batch_size=16, collate_fn=val_tgt.collate)\n",
    "\n",
    "testloader_src = DataLoader(test_src, batch_size=16, collate_fn=test_src.collate)\n",
    "testloader_tgt = DataLoader(test_tgt, batch_size=16, collate_fn=test_tgt.collate)"
   ]
  },
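  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `collate` methods of the datasets are not shown in this notebook. Judging from the `sample.batch_index` attribute used in the evaluation cells, a batch of meshes with different node counts is presumably flattened along the node dimension, with an index recording which sample each node belongs to. A minimal sketch of that idea, using a hypothetical `collate_point_clouds` helper that takes `(pos, y)` tuples instead of the datasets' actual sample objects:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def collate_point_clouds(samples):\n",
    "    # samples: list of (pos, y) pairs with pos [n_i, d] and y [n_i, c]; n_i may differ per sample\n",
    "    pos = torch.cat([p for p, _ in samples], dim=0)\n",
    "    y = torch.cat([t for _, t in samples], dim=0)\n",
    "    # batch_index[k] is the index of the sample that node k belongs to\n",
    "    batch_index = torch.cat([\n",
    "        torch.full((p.shape[0],), i, dtype=torch.long) for i, (p, _) in enumerate(samples)\n",
    "    ])\n",
    "    return pos, y, batch_index\n",
    "\n",
    "\n",
    "# usage: two meshes with 3 and 5 nodes become one batch of 8 nodes\n",
    "pos, y, batch_index = collate_point_clouds([\n",
    "    (torch.zeros(3, 2), torch.ones(3, 10)),\n",
    "    (torch.zeros(5, 2), torch.ones(5, 10)),\n",
    "])\n",
    "print(batch_index.tolist())  # [0, 0, 0, 1, 1, 1, 1, 1]\n",
    "```\n",
    "\n",
    "Flattening avoids padding meshes to a common size; per-sample quantities can then be recovered with scatter operations over `batch_index`."
   ]
  },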
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline model\n",
    "Next, we create the PointNet baseline model. The architecture is based on the original [PointNet](https://arxiv.org/abs/1612.00593), with some modifications for the domain adaptation setting. In particular, a latent code encoding the simulation conditions is concatenated to PointNet's global features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model parameters: 0.30M\n"
     ]
    }
   ],
   "source": [
    "from models import PointNet\n",
    "\n",
    "\n",
    "model = PointNet(\n",
    "    n_conds=train_src.n_conds,  # rolling has 4 conditions\n",
    "    latent_channels=8,  # dimension of the conditioning code\n",
    "    output_channels=train_src.n_channels,  # number of output field channels\n",
    "    pointnet_base=16,  # latent size of the pointnet mlps\n",
    "    dropout_prob=0\n",
    ")\n",
    "print(\n",
    "    f\"Model parameters: {(sum(p.numel() for p in model.parameters()) / 1e6):.2f}M\"\n",
    ")"
   ]
  },
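  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The internals of `models.PointNet` are not shown here. As a rough illustration of the conditioning described above, the hypothetical `ConditionedPointNetSketch` below follows the original PointNet segmentation layout (shared per-point MLP, max pooling to a global feature, per-point head) and concatenates a learned code of the simulation conditions to the global feature. It is a simplified sketch under stated assumptions: it takes padded `[B, N, d]` inputs instead of flattened batches, uses 2D coordinates as the only point features, and predicts output fields only, whereas the tutorial's model also predicts mesh positions.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class ConditionedPointNetSketch(nn.Module):\n",
    "    # hypothetical, simplified stand-in for models.PointNet\n",
    "    def __init__(self, in_dim, n_conds, latent_channels, output_channels, base=16):\n",
    "        super().__init__()\n",
    "        self.local_mlp = nn.Sequential(nn.Linear(in_dim, base), nn.ReLU())\n",
    "        self.global_mlp = nn.Sequential(nn.Linear(base, 4 * base), nn.ReLU())\n",
    "        self.cond_encoder = nn.Linear(n_conds, latent_channels)  # conditioning code\n",
    "        self.head = nn.Sequential(\n",
    "            nn.Linear(base + 4 * base + latent_channels, 2 * base), nn.ReLU(),\n",
    "            nn.Linear(2 * base, output_channels),\n",
    "        )\n",
    "\n",
    "    def forward(self, pos, conds):\n",
    "        # pos: [B, N, in_dim] node coordinates, conds: [B, n_conds] simulation conditions\n",
    "        local = self.local_mlp(pos)                          # [B, N, base]\n",
    "        glob = self.global_mlp(local).max(dim=1).values      # max pooling -> [B, 4 * base]\n",
    "        code = self.cond_encoder(conds)                      # [B, latent_channels]\n",
    "        glob = torch.cat([glob, code], dim=-1)               # condition the global feature\n",
    "        glob = glob.unsqueeze(1).expand(-1, pos.shape[1], -1)\n",
    "        return self.head(torch.cat([local, glob], dim=-1))   # [B, N, output_channels]\n",
    "\n",
    "\n",
    "# usage: 4 conditions and 10 output channels, as in the rolling setup above\n",
    "net = ConditionedPointNetSketch(in_dim=2, n_conds=4, latent_channels=8, output_channels=10)\n",
    "out = net(torch.rand(2, 100, 2), torch.rand(2, 4))\n",
    "print(out.shape)  # torch.Size([2, 100, 10])\n",
    "```\n",
    "\n",
    "Because the global feature comes from a max over points, reordering the input nodes only reorders the per-node predictions."
   ]
  },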
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Domain Adaptation wrapper\n",
    "\n",
    "In this example we use the [Deep CORAL](https://arxiv.org/abs/1607.01719) loss for domain adaptation. It is implemented as a wrapper around the model and optimizer that computes the task and DA losses and applies the parameter updates.\n",
    "\n",
    "In particular, Deep CORAL aligns the source and target domains by matching their feature covariances: $\\mathcal{L}_{\\text{CORAL}} = \\frac{1}{4d^2} \\|\\mathbf{C}^S - \\mathbf{C}^T\\|_F^2$, where ..."
   ]
  },
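  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the Deep CORAL paper, $d$ is the feature dimension, $\\mathbf{C}^S$ and $\\mathbf{C}^T$ are the (unbiased) covariance matrices of the source and target features, and $\\|\\cdot\\|_F^2$ is the squared Frobenius norm. A minimal PyTorch sketch of that loss follows; `coral_loss` is a hypothetical helper, not the `DeepCORAL` class used below, and which network features SIMSHIFT applies it to is not shown in this notebook.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def coral_loss(source, target):\n",
    "    # source: [n_s, d], target: [n_t, d] feature matrices, one row per observation\n",
    "    d = source.shape[1]\n",
    "    # torch.cov expects variables as rows, so transpose to [d, n]; result is [d, d] (unbiased)\n",
    "    cov_s = torch.cov(source.T)\n",
    "    cov_t = torch.cov(target.T)\n",
    "    # squared Frobenius norm of the covariance difference, scaled by 1 / (4 d^2)\n",
    "    return ((cov_s - cov_t) ** 2).sum() / (4 * d ** 2)\n",
    "\n",
    "\n",
    "# worked example with d = 1: source variance 2, target variance 8 -> (2 - 8)^2 / 4 = 9\n",
    "src = torch.tensor([[-1.0], [1.0]])\n",
    "tgt = torch.tensor([[-2.0], [2.0]])\n",
    "print(coral_loss(src, tgt).item())  # 9.0\n",
    "```\n",
    "\n",
    "Only second-order statistics enter the loss, so a pure shift of the feature means leaves it unchanged."
   ]
  },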
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from da_algorithms import DeepCORAL\n",
    "\n",
    "\n",
    "da_algorithm = DeepCORAL(\n",
    "    device=device,\n",
    "    model=model,  # the PointNet to wrap\n",
    "    opt_method=torch.optim.AdamW,  # optimizer class\n",
    "    opt_kwargs={\"lr\": 1e-3, \"weight_decay\": 1e-5},  # AdamW keyword arguments\n",
    "    da_loss_weight=1e-5,  # weight of the Deep CORAL loss\n",
    "    use_ema=True,  # maintain a target network as an exponential moving average (EMA) of the weights\n",
    "    ema_decay=0.95,  # EMA decay factor\n",
    ")"
   ]
  },
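  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How the `DeepCORAL` wrapper applies the EMA is not shown here. A common form, sketched below under the assumption that the EMA copy is updated after every optimizer step, blends each EMA parameter towards the current weights; `ema_update` is a hypothetical helper. With `ema_decay=0.95`, each update moves the EMA weights 5% of the way towards the current weights, so they roughly average the last $1/(1-0.95) = 20$ updates.\n",
    "\n",
    "```python\n",
    "import copy\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def ema_update(ema_model, model, decay):\n",
    "    # ema_param <- decay * ema_param + (1 - decay) * param\n",
    "    for ema_p, p in zip(ema_model.parameters(), model.parameters()):\n",
    "        ema_p.mul_(decay).add_(p, alpha=1 - decay)\n",
    "\n",
    "\n",
    "# usage with a one-weight model: current weight 1.0, EMA weight 0.0\n",
    "net = nn.Linear(1, 1, bias=False)\n",
    "ema_net = copy.deepcopy(net)  # the EMA copy starts from the current weights\n",
    "with torch.no_grad():\n",
    "    net.weight.fill_(1.0)\n",
    "    ema_net.weight.fill_(0.0)\n",
    "ema_update(ema_net, net, decay=0.95)\n",
    "print(round(ema_net.weight.item(), 4))  # 0.05\n",
    "```"
   ]
  },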
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Trainer\n",
    "The `Trainer` class handles:\n",
    "- the training loop, including the domain adaptation loss on the unlabeled target data\n",
    "- evaluation and metric computation, checkpointing, and logging\n",
    "- early stopping\n",
    "\n",
    "The `.run()` method starts training for the specified number of epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from train import Trainer\n",
    "from eval import Metrics\n",
    "from utils import Logger\n",
    "\n",
    "n_epochs = 300\n",
    "early_stopping_patience = 50\n",
    "\n",
    "trainer = Trainer(\n",
    "    datasets=(train_src, val_src, train_tgt, val_tgt),\n",
    "    dataloaders=(trainloader_src, valloader_src, trainloader_tgt, valloader_tgt),\n",
    "    da_algorithm=da_algorithm,\n",
    "    device=device,\n",
    "    n_epochs=n_epochs,\n",
    "    early_stopping_patience=early_stopping_patience,\n",
    "    scheduler=\"cosine\",\n",
    "    logger=Logger(\"tutorial\", n_epochs=n_epochs),\n",
    "    metrics=Metrics(),  # computes \"mse_avg\", \"ae_avg\", \"mse_max\", \"ae_max\"\n",
    "    save_model=False,  # for the tutorial's sake, we don't save any checkpoints\n",
    ")"
   ]
  },
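  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Trainer` internals are not shown here. The learning rates in the training log below are consistent with PyTorch's `CosineAnnealingLR` decaying the AdamW learning rate from `1e-3` to 0 over the 300 epochs (e.g. `0.00075` at epoch 100 and `0.00050` at epoch 150). The sketch reproduces that schedule, assuming one scheduler step per epoch, and adds a hypothetical `EarlyStopping` helper to illustrate patience-based stopping; which validation metric the actual `Trainer` monitors is not shown here.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# cosine annealing from lr=1e-3 to 0 over n_epochs, one scheduler step per epoch (assumed)\n",
    "n_epochs = 300\n",
    "param = nn.Parameter(torch.zeros(1))\n",
    "opt = torch.optim.AdamW([param], lr=1e-3)\n",
    "sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=n_epochs)\n",
    "lrs = []\n",
    "for epoch in range(n_epochs):\n",
    "    lrs.append(opt.param_groups[0]['lr'])\n",
    "    opt.step()  # placeholder for one epoch of training\n",
    "    sched.step()\n",
    "\n",
    "\n",
    "class EarlyStopping:\n",
    "    # hypothetical helper: stop once the validation metric has not improved for `patience` epochs\n",
    "    def __init__(self, patience):\n",
    "        self.patience = patience\n",
    "        self.best = float('inf')\n",
    "        self.bad_epochs = 0\n",
    "\n",
    "    def step(self, val_metric):\n",
    "        if val_metric < self.best:\n",
    "            self.best, self.bad_epochs = val_metric, 0\n",
    "        else:\n",
    "            self.bad_epochs += 1\n",
    "        return self.bad_epochs >= self.patience  # True -> stop training\n",
    "\n",
    "\n",
    "stopper = EarlyStopping(patience=2)\n",
    "print([stopper.step(m) for m in [1.0, 0.5, 0.6, 0.7]])  # [False, False, False, True]\n",
    "```"
   ]
  },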
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-05-16 09:50:55,425 - [010] train/mse_loss: 0.11031, train/da_loss: 2.22467, train/summed_loss: 0.11071, train/lr: 0.00100, val_source/mse_avg: 0.07658, val_source/ae_avg: 0.13546, val_source/mse_max: 14.06019, val_source/ae_max: 2.53139, \n",
      "2025-05-16 09:51:13,233 - [020] train/mse_loss: 0.04913, train/da_loss: 3.13364, train/summed_loss: 0.04940, train/lr: 0.00099, val_source/mse_avg: 0.03533, val_source/ae_avg: 0.09248, val_source/mse_max: 4.15543, val_source/ae_max: 1.38688, \n",
      "2025-05-16 09:51:29,769 - [030] train/mse_loss: 0.03620, train/da_loss: 3.38238, train/summed_loss: 0.03633, train/lr: 0.00098, val_source/mse_avg: 0.02076, val_source/ae_avg: 0.07425, val_source/mse_max: 3.46494, val_source/ae_max: 1.18726, \n",
      "2025-05-16 09:51:46,344 - [040] train/mse_loss: 0.02418, train/da_loss: 3.50228, train/summed_loss: 0.02428, train/lr: 0.00096, val_source/mse_avg: 0.01498, val_source/ae_avg: 0.06406, val_source/mse_max: 3.09761, val_source/ae_max: 1.14554, \n",
      "2025-05-16 09:52:02,923 - [050] train/mse_loss: 0.02593, train/da_loss: 3.01406, train/summed_loss: 0.02603, train/lr: 0.00093, val_source/mse_avg: 0.01271, val_source/ae_avg: 0.05915, val_source/mse_max: 2.53627, val_source/ae_max: 1.01500, \n",
      "2025-05-16 09:52:19,394 - [060] train/mse_loss: 0.01284, train/da_loss: 3.39666, train/summed_loss: 0.01293, train/lr: 0.00090, val_source/mse_avg: 0.00971, val_source/ae_avg: 0.05320, val_source/mse_max: 2.13012, val_source/ae_max: 0.90903, \n",
      "2025-05-16 09:52:35,936 - [070] train/mse_loss: 0.01093, train/da_loss: 3.32863, train/summed_loss: 0.01101, train/lr: 0.00087, val_source/mse_avg: 0.00818, val_source/ae_avg: 0.04899, val_source/mse_max: 2.15981, val_source/ae_max: 0.88205, \n",
      "2025-05-16 09:52:53,010 - [080] train/mse_loss: 0.01108, train/da_loss: 3.20714, train/summed_loss: 0.01116, train/lr: 0.00083, val_source/mse_avg: 0.00746, val_source/ae_avg: 0.04670, val_source/mse_max: 1.87728, val_source/ae_max: 0.81177, \n",
      "2025-05-16 09:53:09,462 - [090] train/mse_loss: 0.01161, train/da_loss: 2.81789, train/summed_loss: 0.01169, train/lr: 0.00079, val_source/mse_avg: 0.00697, val_source/ae_avg: 0.04487, val_source/mse_max: 1.68898, val_source/ae_max: 0.77507, \n",
      "2025-05-16 09:53:26,175 - [100] train/mse_loss: 0.00967, train/da_loss: 2.70987, train/summed_loss: 0.00975, train/lr: 0.00075, val_source/mse_avg: 0.00624, val_source/ae_avg: 0.04291, val_source/mse_max: 1.59713, val_source/ae_max: 0.75234, \n",
      "2025-05-16 09:53:42,703 - [110] train/mse_loss: 0.01719, train/da_loss: 2.57143, train/summed_loss: 0.01726, train/lr: 0.00070, val_source/mse_avg: 0.01159, val_source/ae_avg: 0.04878, val_source/mse_max: 3.24493, val_source/ae_max: 1.02028, \n",
      "2025-05-16 09:53:59,157 - [120] train/mse_loss: 0.00658, train/da_loss: 2.57650, train/summed_loss: 0.00665, train/lr: 0.00065, val_source/mse_avg: 0.00526, val_source/ae_avg: 0.04002, val_source/mse_max: 1.18179, val_source/ae_max: 0.66215, \n",
      "2025-05-16 09:54:15,593 - [130] train/mse_loss: 0.03180, train/da_loss: 2.44640, train/summed_loss: 0.03187, train/lr: 0.00060, val_source/mse_avg: 0.01276, val_source/ae_avg: 0.05233, val_source/mse_max: 3.70616, val_source/ae_max: 1.14855, \n",
      "2025-05-16 09:54:32,092 - [140] train/mse_loss: 0.00677, train/da_loss: 2.34826, train/summed_loss: 0.00684, train/lr: 0.00055, val_source/mse_avg: 0.00484, val_source/ae_avg: 0.03862, val_source/mse_max: 0.91337, val_source/ae_max: 0.60618, \n",
      "2025-05-16 09:54:48,602 - [150] train/mse_loss: 0.00554, train/da_loss: 2.22149, train/summed_loss: 0.00561, train/lr: 0.00050, val_source/mse_avg: 0.00458, val_source/ae_avg: 0.03766, val_source/mse_max: 0.96442, val_source/ae_max: 0.60968, \n",
      "2025-05-16 09:55:05,066 - [160] train/mse_loss: 0.01019, train/da_loss: 2.13256, train/summed_loss: 0.01026, train/lr: 0.00045, val_source/mse_avg: 0.00480, val_source/ae_avg: 0.03802, val_source/mse_max: 0.95845, val_source/ae_max: 0.59838, \n",
      "2025-05-16 09:55:21,510 - [170] train/mse_loss: 0.00473, train/da_loss: 1.97730, train/summed_loss: 0.00480, train/lr: 0.00040, val_source/mse_avg: 0.00421, val_source/ae_avg: 0.03633, val_source/mse_max: 0.73348, val_source/ae_max: 0.55485, \n",
      "2025-05-16 09:55:37,960 - [180] train/mse_loss: 0.00494, train/da_loss: 1.87955, train/summed_loss: 0.00501, train/lr: 0.00035, val_source/mse_avg: 0.00408, val_source/ae_avg: 0.03579, val_source/mse_max: 0.72424, val_source/ae_max: 0.53369, \n",
      "2025-05-16 09:55:54,413 - [190] train/mse_loss: 0.00502, train/da_loss: 1.82530, train/summed_loss: 0.00509, train/lr: 0.00030, val_source/mse_avg: 0.00400, val_source/ae_avg: 0.03550, val_source/mse_max: 0.79841, val_source/ae_max: 0.53491, \n",
      "2025-05-16 09:56:10,866 - [200] train/mse_loss: 0.00424, train/da_loss: 1.77926, train/summed_loss: 0.00431, train/lr: 0.00025, val_source/mse_avg: 0.00384, val_source/ae_avg: 0.03485, val_source/mse_max: 0.62210, val_source/ae_max: 0.50323, \n",
      "2025-05-16 09:56:27,352 - [210] train/mse_loss: 0.00406, train/da_loss: 1.74697, train/summed_loss: 0.00413, train/lr: 0.00021, val_source/mse_avg: 0.00375, val_source/ae_avg: 0.03452, val_source/mse_max: 0.51180, val_source/ae_max: 0.47245, \n",
      "2025-05-16 09:56:43,778 - [220] train/mse_loss: 0.00399, train/da_loss: 1.68344, train/summed_loss: 0.00405, train/lr: 0.00017, val_source/mse_avg: 0.00369, val_source/ae_avg: 0.03417, val_source/mse_max: 0.45924, val_source/ae_max: 0.44654, \n",
      "2025-05-16 09:57:00,215 - [230] train/mse_loss: 0.00381, train/da_loss: 1.69506, train/summed_loss: 0.00387, train/lr: 0.00013, val_source/mse_avg: 0.00362, val_source/ae_avg: 0.03398, val_source/mse_max: 0.47080, val_source/ae_max: 0.45816, \n",
      "2025-05-16 09:57:16,654 - [240] train/mse_loss: 0.00370, train/da_loss: 1.57067, train/summed_loss: 0.00376, train/lr: 0.00010, val_source/mse_avg: 0.00354, val_source/ae_avg: 0.03364, val_source/mse_max: 0.47045, val_source/ae_max: 0.45338, \n",
      "2025-05-16 09:57:33,127 - [250] train/mse_loss: 0.00359, train/da_loss: 1.56691, train/summed_loss: 0.00365, train/lr: 0.00007, val_source/mse_avg: 0.00351, val_source/ae_avg: 0.03353, val_source/mse_max: 0.44196, val_source/ae_max: 0.44979, \n",
      "2025-05-16 09:57:49,574 - [260] train/mse_loss: 0.00352, train/da_loss: 1.56210, train/summed_loss: 0.00358, train/lr: 0.00004, val_source/mse_avg: 0.00348, val_source/ae_avg: 0.03334, val_source/mse_max: 0.42760, val_source/ae_max: 0.44407, \n",
      "2025-05-16 09:58:06,048 - [270] train/mse_loss: 0.00347, train/da_loss: 1.59988, train/summed_loss: 0.00354, train/lr: 0.00002, val_source/mse_avg: 0.00346, val_source/ae_avg: 0.03328, val_source/mse_max: 0.43885, val_source/ae_max: 0.44334, \n",
      "2025-05-16 09:58:22,732 - [280] train/mse_loss: 0.00344, train/da_loss: 1.57821, train/summed_loss: 0.00351, train/lr: 0.00001, val_source/mse_avg: 0.00344, val_source/ae_avg: 0.03320, val_source/mse_max: 0.42863, val_source/ae_max: 0.43939, \n",
      "2025-05-16 09:58:39,404 - [290] train/mse_loss: 0.00340, train/da_loss: 1.48446, train/summed_loss: 0.00346, train/lr: 0.00000, val_source/mse_avg: 0.00344, val_source/ae_avg: 0.03318, val_source/mse_max: 0.42840, val_source/ae_max: 0.44076, \n",
      "2025-05-16 09:58:56,078 - [300] train/mse_loss: 0.00341, train/da_loss: 1.53287, train/summed_loss: 0.00347, train/lr: 0.00000, val_source/mse_avg: 0.00343, val_source/ae_avg: 0.03317, val_source/mse_max: 0.42854, val_source/ae_max: 0.44089, \n"
     ]
    }
   ],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
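    "## Evaluation\n",
    "We evaluate the trained model on the source and target test sets, using the MSE over all normalized fields and mesh coordinates, averaged per node and then per sample."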
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss: 0.0029690618966274735\n"
     ]
    }
   ],
   "source": [
    "# evaluate the trained model on the source test set\n",
    "cum_loss = 0\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for sample in testloader_src:\n",
    "        sample = sample.to(device)\n",
    "        pred, _ = model(**sample.as_dict())\n",
    "        pred_fields, pred_pos = pred\n",
    "        pred = torch.cat([pred_fields, pred_pos], dim=-1)\n",
    "        gt = torch.cat([sample.y, sample.y_mesh_coords], dim=-1)\n",
    "        loss = F.mse_loss(pred, gt, reduction=\"none\").mean(-1)  # avg mse across all normalized fields and positions, per node\n",
    "        n_samples = sample.batch_index.max().item() + 1\n",
    "        loss_per_sample = torch.zeros(n_samples, device=device)\n",
    "        # per-sample mean over nodes; include_self=False keeps the zero initialization out of the mean\n",
    "        loss_per_sample.scatter_reduce_(\n",
    "            dim=0, index=sample.batch_index, src=loss, reduce=\"mean\", include_self=False\n",
    "        )\n",
    "        cum_loss += loss_per_sample.mean().item()  # avg across samples in a batch\n",
    "\n",
    "cum_loss /= len(testloader_src)  # approximate: the last (possibly incomplete) batch is weighted like a full one\n",
    "print(f\"Average loss: {cum_loss}\")"
   ]
  },
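  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above averages per-batch means, so the final, smaller batch counts as much as a full one. A sketch of an exact alternative follows, assuming flattened batches with a `batch_index` as in the loop above: it sums the per-sample losses and divides by the number of samples. `per_sample_mse` and `dataset_average` are hypothetical helpers operating on plain tensors, and `index_add_` is used so that no initial value enters the per-sample mean.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def per_sample_mse(pred, gt, batch_index, n_samples):\n",
    "    # mean squared error per node (over all channels), then mean over each sample's nodes\n",
    "    node_loss = ((pred - gt) ** 2).mean(-1)\n",
    "    sums = torch.zeros(n_samples, device=pred.device).index_add_(0, batch_index, node_loss)\n",
    "    counts = torch.zeros(n_samples, device=pred.device).index_add_(\n",
    "        0, batch_index, torch.ones_like(node_loss))\n",
    "    return sums / counts\n",
    "\n",
    "\n",
    "def dataset_average(batches):\n",
    "    # batches: iterable of (pred, gt, batch_index) tensors; returns the mean over all samples\n",
    "    total, count = 0.0, 0\n",
    "    for pred, gt, batch_index in batches:\n",
    "        n_samples = int(batch_index.max().item()) + 1\n",
    "        total += per_sample_mse(pred, gt, batch_index, n_samples).sum().item()\n",
    "        count += n_samples\n",
    "    return total / count\n",
    "\n",
    "\n",
    "# usage: a batch with two samples (losses 5 and 4) and a batch with one sample (loss 9)\n",
    "batches = [\n",
    "    (torch.zeros(3, 1), torch.tensor([[1.0], [3.0], [2.0]]), torch.tensor([0, 0, 1])),\n",
    "    (torch.zeros(1, 1), torch.tensor([[3.0]]), torch.tensor([0])),\n",
    "]\n",
    "print(dataset_average(batches))  # 6.0, whereas averaging the two batch means gives 6.75\n",
    "```"
   ]
  },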
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss: 0.05497161548451653\n"
     ]
    }
   ],
   "source": [
    "# evaluate the trained model on the target test set\n",
    "cum_loss = 0\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for sample in testloader_tgt:\n",
    "        sample = sample.to(device)\n",
    "        pred, _ = model(**sample.as_dict())\n",
    "        pred_fields, pred_pos = pred\n",
    "        pred = torch.cat([pred_fields, pred_pos], dim=-1)\n",
    "        gt = torch.cat([sample.y, sample.y_mesh_coords], dim=-1)\n",
    "        loss = F.mse_loss(pred, gt, reduction=\"none\").mean(-1)  # avg mse across all normalized fields and positions, per node\n",
    "        n_samples = sample.batch_index.max().item() + 1\n",
    "        loss_per_sample = torch.zeros(n_samples, device=device)\n",
    "        # per-sample mean over nodes; include_self=False keeps the zero initialization out of the mean\n",
    "        loss_per_sample.scatter_reduce_(\n",
    "            dim=0, index=sample.batch_index, src=loss, reduce=\"mean\", include_self=False\n",
    "        )\n",
    "        cum_loss += loss_per_sample.mean().item()  # avg across samples in a batch\n",
    "\n",
    "cum_loss /= len(testloader_tgt)  # approximate: the last (possibly incomplete) batch is weighted like a full one\n",
    "print(f\"Average loss: {cum_loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
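    "The domain gap is clearly visible: the average test loss on the target domain is more than an order of magnitude higher than on the source domain."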
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "2d_3d_da",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
