{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "CL42d50fnDGJ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "46e1c071-dfe0-4e34-b30e-3bafd0b50a9f"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Archive:  dataset.zip\n",
            "replace ./dataset/x_test.npy? [y]es, [n]o, [A]ll, [N]one, [r]ename: A\n",
            "  inflating: ./dataset/x_test.npy    \n",
            "  inflating: ./dataset/x_train.npy   \n",
            "  inflating: ./dataset/y_train.npy   \n",
            "  inflating: ./dataset/y_test.npy    \n"
          ]
        }
      ],
      "source": [
        "from tensorflow.keras.datasets import mnist\n",
        "import numpy as np\n",
        "from scipy import ndimage\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow.keras as keras\n",
        "from tensorflow.keras import Model\n",
        "import math\n",
        "from tqdm import tqdm\n",
        "import os\n",
        "import pickle\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "from tensorflow.keras.models import load_model\n",
        "import random\n",
        "import torch\n",
        "import torchvision.transforms as transforms\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
        "import torchvision.transforms as transforms\n",
        "import scipy.io\n",
        "from create_dataset import rotated_mnist_60_data_func, discretely_rotate_images\n",
        "\n",
        "\n",
        "np.random.seed(42)\n",
        "random.seed(42)\n",
        "tf.random.set_seed(42)\n",
        "!unzip dataset.zip -d ./"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Create 10 inter-domain datasets\n",
        "# inter-domain datasets is stored in inter_x, inter_y, size of each them are 4200\n",
        "# There is 5 degrees rotation between these datasets\n",
        "\n",
        "(src_tr_x, src_tr_y, src_val_x, src_val_y, inter_x, inter_y, dir_inter_x, dir_inter_y,\n",
        " trg_val_x, trg_val_y, trg_test_x, trg_test_y) = rotated_mnist_60_data_func()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LsAHS56Nna-t",
        "outputId": "bbe2247c-d8f7-4cd4-ea49-62f4446c4c07",
        "collapsed": true
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "5000 6000 48000 50000\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Create 4 inter-domain datasets with 15 degrees rotation between each of them\n",
        "list_rotations = [10, 20, 30, 40]\n",
        "inter_domns_list = []\n",
        "for i in range(4):\n",
        "    temp_inter_x = inter_x[i*4200: (i+1)*4200]\n",
        "    rot_temp_inter_x = discretely_rotate_images(temp_inter_x, list_rotations[i])\n",
        "    inter_domns_list.append((rot_temp_inter_x, inter_y[i*4200: (i+1)*4200]))\n",
        "\n",
        "\n",
        "# mid_dom_x, mid_dom_y = inter_x[5*4200: 6*4200], inter_y[5*4200: 6*4200]\n",
        "# inter_domns_list = [(inter_x[2*4200: 3*4200], inter_y[2*4200: 3*4200]),\n",
        "#                     (inter_x[5*4200: 6*4200], inter_y[5*4200: 6*4200]),\n",
        "#                     (inter_x[8*4200: 9*4200], inter_y[8*4200: 9*4200]),\n",
        "#                     (trg_val_x, trg_val_y)]"
      ],
      "metadata": {
        "id": "1wfsJ_bY1ciX"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Create DataLoaders"
      ],
      "metadata": {
        "id": "Objdmylbng6U"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class MNISTDataset(Dataset):\n",
        "    def __init__(self, images, labels, transform=None):\n",
        "        self.images = images\n",
        "        self.labels = labels\n",
        "        self.transform = transform\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.images)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        image = self.images[idx]\n",
        "        label = self.labels[idx]\n",
        "\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "\n",
        "        return image, label\n",
        "\n",
        "\n",
        "# Define a transformation that converts the NumPy array to a PyTorch tensor\n",
        "tensor_transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "])\n",
        "batch_size = 64\n",
        "\n",
        "# Create source and target DataLoader\n",
        "source_dataset = MNISTDataset(src_tr_x, src_tr_y, transform=tensor_transform)\n",
        "target_dataset = MNISTDataset(trg_test_x, trg_test_y, transform=tensor_transform)\n",
        "\n",
        "source_train_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True)\n",
        "target_train_loader = DataLoader(target_dataset, batch_size=batch_size, shuffle=True)\n",
        "\n",
        "# Create source and target validation DataLoader\n",
        "source_val_dataset = MNISTDataset(src_val_x, src_val_y, transform=tensor_transform)\n",
        "target_val_dataset = MNISTDataset(trg_val_x, trg_val_y, transform=tensor_transform)\n",
        "\n",
        "source_val_loader = DataLoader(source_val_dataset, batch_size=batch_size, shuffle=True)\n",
        "# target_val_loader = DataLoader(target_val_dataset, batch_size=batch_size, shuffle=True)\n",
        "\n",
        "# create mid domain loader\n",
        "# mid_dom_dataset = MNISTDataset(mid_dom_x, mid_dom_y, transform=tensor_transform)\n",
        "# mid_dom_loader = DataLoader(mid_dom_dataset, batch_size=batch_size, shuffle=True)\n",
        "\n",
        "\n",
        "# Create inter domains Dataloaders\n",
        "train_loaders_list = []\n",
        "for domn in inter_domns_list:\n",
        "    # Create the custom dataset\n",
        "    temp_dataset = MNISTDataset(domn[0], domn[1],\n",
        "                                transform=tensor_transform)\n",
        "\n",
        "    # Create the DataLoader\n",
        "    batch_size = 64\n",
        "    curr_train_loader = DataLoader(temp_dataset, batch_size=batch_size, shuffle=True)\n",
        "    train_loaders_list.append(curr_train_loader)\n"
      ],
      "metadata": {
        "id": "E54g2chdnjc_"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Define Networks"
      ],
      "metadata": {
        "id": "0hR22LVbqxkx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class SpatialTransformerNet(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(SpatialTransformerNet, self).__init__()\n",
        "        self.localization = nn.Sequential(\n",
        "            nn.Conv2d(1, 8, kernel_size=7),\n",
        "            nn.MaxPool2d(2, stride=2),\n",
        "            nn.ReLU(True),\n",
        "            nn.Conv2d(8, 10, kernel_size=5),\n",
        "            nn.MaxPool2d(2, stride=2),\n",
        "            nn.ReLU(True)\n",
        "        )\n",
        "\n",
        "        self.fc_loc = nn.Sequential(\n",
        "            nn.Linear(10 * 3 * 3, 32),\n",
        "            nn.ReLU(True),\n",
        "            nn.Linear(32, 3 * 2)\n",
        "        )\n",
        "\n",
        "        self.fc_loc[2].weight.data.zero_()\n",
        "        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))\n",
        "\n",
        "    def forward(self, x):\n",
        "        xs = self.localization(x)\n",
        "        xs = xs.view(-1, 10 * 3 * 3)\n",
        "        theta = self.fc_loc(xs)\n",
        "        theta = theta.view(-1, 2, 3)\n",
        "\n",
        "        grid = F.affine_grid(theta, x.size())\n",
        "        x = F.grid_sample(x, grid)\n",
        "\n",
        "        return x\n",
        "\n",
        "class SimplifiedClassifierNet(nn.Module):\n",
        "    def __init__(self, num_classes=10):\n",
        "        super(SimplifiedClassifierNet, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)\n",
        "        self.bn1 = nn.BatchNorm2d(32)\n",
        "        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)\n",
        "        self.bn2 = nn.BatchNorm2d(64)\n",
        "        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)\n",
        "        self.bn3 = nn.BatchNorm2d(128)\n",
        "        self.dropout = nn.Dropout(0.5)\n",
        "        self.fc1 = nn.Linear(128 * 3 * 3, 128)\n",
        "        self.fc2 = nn.Linear(128, num_classes)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = F.relu(self.bn1(self.conv1(x)))\n",
        "        x = F.max_pool2d(x, 2)\n",
        "        x = F.relu(self.bn2(self.conv2(x)))\n",
        "        x = F.max_pool2d(x, 2)\n",
        "        x = F.relu(self.bn3(self.conv3(x)))\n",
        "        x = F.max_pool2d(x, 2)\n",
        "        x = x.view(-1, 128 * 3 * 3)\n",
        "        x = F.relu(self.fc1(x))\n",
        "        x = self.dropout(x)\n",
        "        x = self.fc2(x)\n",
        "        return x\n",
        "\n",
        "class CombinedModel(nn.Module):\n",
        "    def __init__(self, stn, classifier):\n",
        "        super(CombinedModel, self).__init__()\n",
        "        self.stn = stn\n",
        "        self.classifier = classifier\n",
        "\n",
        "    def forward(self, x):\n",
        "        transformed = self.stn(x)\n",
        "        classification = self.classifier(transformed)\n",
        "        return transformed, classification\n",
        "\n",
        "\n",
        "class SimpleSoftmaxConvModel(nn.Module):\n",
        "    def __init__(self, num_labels, hidden_nodes=32, input_shape=(1, 28, 28), l2_reg=0.0):\n",
        "        super(SimpleSoftmaxConvModel, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(in_channels=input_shape[0], out_channels=hidden_nodes, kernel_size=5, stride=2, padding=2)\n",
        "        self.conv2 = nn.Conv2d(in_channels=hidden_nodes, out_channels=hidden_nodes, kernel_size=5, stride=2, padding=2)\n",
        "        self.conv3 = nn.Conv2d(in_channels=hidden_nodes, out_channels=hidden_nodes, kernel_size=5, stride=2, padding=2)\n",
        "        self.dropout = nn.Dropout(0.5)\n",
        "        self.batch_norm = nn.BatchNorm2d(hidden_nodes)\n",
        "        self.flatten = nn.Flatten()\n",
        "\n",
        "        # Calculate the flattened dimension size\n",
        "        with torch.no_grad():\n",
        "            dummy_input = torch.zeros(1, *input_shape)\n",
        "            flattened_size = self.flatten(self._forward_features(dummy_input)).shape[1]\n",
        "\n",
        "        self.fc = nn.Linear(flattened_size, num_labels)\n",
        "\n",
        "    def _forward_features(self, x):\n",
        "        x = F.elu(self.conv1(x))\n",
        "        x = F.elu(self.conv2(x))\n",
        "        x = F.elu(self.conv3(x))\n",
        "        x = self.dropout(x)\n",
        "        x = self.batch_norm(x)\n",
        "        return x\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self._forward_features(x)\n",
        "        x = self.flatten(x)\n",
        "        x = self.fc(x)\n",
        "        return x\n"
      ],
      "metadata": {
        "id": "1XU2_x9Rqu0M"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Pretrain the Classifier\n",
        "\n",
        "classifier = SimplifiedClassifierNet(num_classes=10).to('cuda')\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "optimizer = optim.Adam(classifier.parameters(), lr=0.001)\n",
        "\n",
        "# Training loop\n",
        "num_epochs = 20\n",
        "for epoch in range(num_epochs):\n",
        "    classifier.train()\n",
        "    running_loss = 0.0\n",
        "    for i, (inputs, targets) in enumerate(source_train_loader):\n",
        "        inputs, targets = inputs.to('cuda').float(), targets.to('cuda')\n",
        "\n",
        "        # Forward pass\n",
        "        outputs = classifier(inputs)\n",
        "        loss = criterion(outputs, targets)\n",
        "\n",
        "        # Backward pass and optimization\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        running_loss += loss.item()\n",
        "\n",
        "    # print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(source_train_loader):.4f}\")\n",
        "\n",
        "\n",
        "# Evaluate accuracy on the validation set\n",
        "classifier.eval()\n",
        "correct = 0\n",
        "total = 0\n",
        "with torch.no_grad():\n",
        "    for inputs, targets in source_val_loader:\n",
        "        inputs, targets = inputs.to('cuda').float(), targets.to('cuda')\n",
        "        outputs = classifier(inputs)\n",
        "        _, predicted = torch.max(outputs.data, 1)\n",
        "        total += targets.size(0)\n",
        "        correct += (predicted == targets).sum().item()\n",
        "accuracy = 100 * correct / total\n",
        "print(f'Source Validation Accuracy: {accuracy:.2f}%')\n",
        "\n",
        "count = 1\n",
        "print('Inter Domains Accuracy')\n",
        "for inter_loader in train_loaders_list:\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    with torch.no_grad():\n",
        "        for inputs, targets in inter_loader:\n",
        "            inputs, targets = inputs.to('cuda').float(), targets.to('cuda')\n",
        "            outputs = classifier(inputs)\n",
        "            _, predicted = torch.max(outputs.data, 1)\n",
        "            total += targets.size(0)\n",
        "            correct += (predicted == targets).sum().item()\n",
        "    accuracy = 100 * correct / total\n",
        "    print(f'Domain {count} Validation Accuracy: {accuracy:.2f}%')\n",
        "    count +=1\n",
        "\n",
        "torch.save(classifier.state_dict(), 'classifier_v1.pth')\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5g4RWe8qq-II",
        "outputId": "12549618-807c-44b5-cb9f-3f078b4bbe1d"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Source Validation Accuracy: 98.50%\n",
            "Inter Domains Accuracy\n",
            "Domain 1 Validation Accuracy: 95.81%\n",
            "Domain 2 Validation Accuracy: 79.40%\n",
            "Domain 3 Validation Accuracy: 44.86%\n",
            "Domain 4 Validation Accuracy: 26.79%\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Train Combined Model"
      ],
      "metadata": {
        "id": "T_zE3bKewvR_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# final model with custom loss is defined\n",
        "\n",
        "def custom_loss(output, target, transformed, in_value, classification_loss_fn, mse_loss_fn, gamma):\n",
        "    classification_loss = classification_loss_fn(output, target)\n",
        "    mse_loss = mse_loss_fn(transformed, in_value)\n",
        "    return classification_loss - gamma * mse_loss\n",
        "\n",
        "\n",
        "# tensor_transform = transforms.Compose([transforms.ToTensor()])\n",
        "# mnist_subset = MNISTDataset(inter_x[:4200], inter_y[:4200], transform=tensor_transform)\n",
        "# train_loader = DataLoader(mnist_subset, batch_size=64, shuffle=True)\n",
        "\n",
        "\n",
        "def evaluate(model, val_loader):\n",
        "    model.eval()\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    with torch.no_grad():\n",
        "        for inputs, targets in val_loader:\n",
        "            inputs, targets = inputs.to('cuda').float(), targets.to('cuda')\n",
        "            outputs = model(inputs)\n",
        "            _, predicted = torch.max(outputs.data, 1)\n",
        "            total += targets.size(0)\n",
        "            correct += (predicted == targets).sum().item()\n",
        "    accuracy = 100 * correct / total\n",
        "    return round(accuracy, 2)\n",
        "\n",
        "\n",
        "def train(f_model, c_model, train_loader, gamma, num_epochs=20):\n",
        "    model = CombinedModel(f_model, c_model).to('cuda')\n",
        "    classification_loss_fn = nn.CrossEntropyLoss()\n",
        "    mse_loss_fn = nn.MSELoss()\n",
        "    optimizer_stn = optim.Adam(f_model.parameters(), lr=0.001)\n",
        "    optimizer_classifier = optim.Adam(c_model.parameters(), lr=0.001)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        running_loss = 0.0\n",
        "        for i, (inputs, targets) in enumerate(train_loader):\n",
        "            inputs, targets = inputs.to('cuda').float(), targets.to('cuda')\n",
        "\n",
        "            # Forward pass\n",
        "            transformed, outputs = model(inputs)\n",
        "\n",
        "            # Compute loss\n",
        "            loss = custom_loss(outputs, targets, transformed, inputs,\n",
        "                               classification_loss_fn, mse_loss_fn, gamma)\n",
        "            loss2 = loss.clone().detach()\n",
        "            loss2.requires_grad = True\n",
        "\n",
        "            # Backward pass for classifier (gradient descent)\n",
        "            optimizer_classifier.zero_grad()\n",
        "            # classification_loss = classification_loss_fn(outputs, targets)\n",
        "            loss.backward(retain_graph=True)\n",
        "            optimizer_classifier.step()\n",
        "\n",
        "            # Backward pass for STN (gradient ascent)\n",
        "            optimizer_stn.zero_grad()\n",
        "            # loss_for_stn = loss.clone()\n",
        "            (-loss2).backward(retain_graph=True)  # Reverse the sign of the loss for gradient ascent\n",
        "            optimizer_stn.step()\n",
        "\n",
        "            running_loss += loss.item() + loss2.item()\n",
        "\n",
        "        # print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}\")\n",
        "    return f_model, c_model\n"
      ],
      "metadata": {
        "collapsed": true,
        "id": "_g1LOZkLwyUv"
      },
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Pseudo label data with a model\n",
        "def create_labeled_data(model, domain_data_loader, confidence_threshold):\n",
        "    model.eval()\n",
        "    all_inputs = []\n",
        "    all_outputs = []\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for inputs, _ in domain_data_loader:\n",
        "            inputs = inputs.to('cuda').float()\n",
        "            outputs = model(inputs)\n",
        "            probabilities = nn.functional.softmax(outputs, dim=1)\n",
        "            max_probs, predicted_labels = torch.max(probabilities, 1)\n",
        "\n",
        "            high_confidence_mask = max_probs >= confidence_threshold\n",
        "            high_confidence_inputs = inputs[high_confidence_mask]\n",
        "            high_confidence_labels = predicted_labels[high_confidence_mask]\n",
        "\n",
        "            if high_confidence_inputs.size(0) > 0:\n",
        "                all_inputs.append(high_confidence_inputs.cpu())\n",
        "                all_outputs.append(high_confidence_labels.cpu())\n",
        "\n",
        "    all_inputs = torch.cat(all_inputs)\n",
        "    all_outputs = torch.cat(all_outputs)\n",
        "    # print('data len: ', len(all_inputs))\n",
        "\n",
        "    new_dataset = TensorDataset(all_inputs, all_outputs)\n",
        "    new_loader = DataLoader(new_dataset, batch_size=64, shuffle=True)\n",
        "    return new_loader\n",
        "\n"
      ],
      "metadata": {
        "id": "1QZjTI0zIrHW"
      },
      "execution_count": 19,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Create initial networks\n",
        "\n",
        "gamma_value = 1\n",
        "\n",
        "# define the tranformation network\n",
        "stn = SpatialTransformerNet().to('cuda')\n",
        "\n",
        "# define and load pre-train classifier network\n",
        "classifier = SimplifiedClassifierNet(num_classes=10).to('cuda')\n",
        "classifier.load_state_dict(torch.load('classifier_v1.pth'))\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "collapsed": true,
        "id": "R_bI_G73Xr6d",
        "outputId": "4ef0b50b-7c9a-43b2-c0d5-32e4cbfe55f6"
      },
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-30-833601a7e4b6>:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
            "  classifier.load_state_dict(torch.load('classifier_v1.pth'))\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<All keys matched successfully>"
            ]
          },
          "metadata": {},
          "execution_count": 30
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Train on source dataset\n",
        "tr_f, tr_c = train(stn, classifier, source_train_loader, gamma_value, 20)\n",
        "\n",
        "print('Our method results: ')\n",
        "acc_list = []\n",
        "count = 1\n",
        "for i in range(4):\n",
        "    curr_acc = evaluate(tr_c, train_loaders_list[i])\n",
        "    acc_list.append(curr_acc)\n",
        "    print(f'Domain {count} Accuracy: {curr_acc:.2f}%')\n",
        "    next_dom_data_loader = create_labeled_data(tr_c, train_loaders_list[i], .95)\n",
        "    classifier = SimplifiedClassifierNet(num_classes=10).to('cuda')\n",
        "    # stn = SpatialTransformerNet().to('cuda')\n",
        "    tr_f, tr_c = train(tr_f, classifier, next_dom_data_loader, gamma_value, 20)\n",
        "    count +=1\n",
        "\n",
        "print('Inter-domain accuracies: ', acc_list)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VffUHYYIdHK8",
        "outputId": "8074d886-4b8f-4d02-d7c4-d28cd7cce9a4",
        "collapsed": true
      },
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Our method results: \n",
            "Domain 1 Accuracy: 95.26%\n",
            "Domain 2 Accuracy: 88.76%\n",
            "Domain 3 Accuracy: 70.26%\n",
            "Domain 4 Accuracy: 54.14%\n",
            "Inter-domain accuracies:  [95.26, 88.76, 70.26, 54.14]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def train_gda(c_model, train_loader, num_epochs=20):\n",
        "    classification_loss_fn = nn.CrossEntropyLoss()\n",
        "    mse_loss_fn = nn.MSELoss()\n",
        "    optimizer_classifier = optim.Adam(c_model.parameters(), lr=0.001)\n",
        "    for epoch in range(num_epochs):\n",
        "        c_model.train()\n",
        "        running_loss = 0.0\n",
        "        for i, (inputs, targets) in enumerate(train_loader):\n",
        "            inputs, targets = inputs.to('cuda').float(), targets.to('cuda')\n",
        "\n",
        "            # Forward pass\n",
        "            outputs = c_model(inputs)\n",
        "\n",
        "            # Compute loss\n",
        "            classification_loss = classification_loss_fn(outputs, targets)\n",
        "\n",
        "            # Backward pass for classifier (gradient descent)\n",
        "            optimizer_classifier.zero_grad()\n",
        "            classification_loss.backward(retain_graph=True)\n",
        "            optimizer_classifier.step()\n",
        "\n",
        "            running_loss += classification_loss.item()\n",
        "\n",
        "        # print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}\")\n",
        "    return c_model\n"
      ],
      "metadata": {
        "id": "m_M-ZuMHjTEL"
      },
      "execution_count": 26,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Define and load pretrained classifier\n",
        "classifier = SimplifiedClassifierNet(num_classes=10).to('cuda')\n",
        "classifier.load_state_dict(torch.load('classifier_v1.pth'))\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hAl551Cgm7Q-",
        "outputId": "2c2259dc-090f-4145-e2c2-916daa6642b1"
      },
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-36-8340beaa683a>:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
            "  classifier.load_state_dict(torch.load('classifier_v1.pth'))\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<All keys matched successfully>"
            ]
          },
          "metadata": {},
          "execution_count": 36
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "tr_c = train_gda(classifier, source_train_loader, 20)\n",
        "\n",
        "print('GDA results: ')\n",
        "acc_list = []\n",
        "for i in range(4):\n",
        "    curr_acc = evaluate(tr_c, train_loaders_list[i])\n",
        "    print(f'Domain {count} Accuracy: {curr_acc:.2f}%')\n",
        "    acc_list.append(curr_acc)\n",
        "    next_dom_data_loader = create_labeled_data(tr_c, train_loaders_list[i], .95)\n",
        "    # classifier = SimplifiedClassifierNet(num_classes=10).to('cuda')\n",
        "    # stn = SpatialTransformerNet().to('cuda')\n",
        "    tr_c = train_gda(classifier, next_dom_data_loader, 20)\n",
        "\n",
        "print('Inter-domain accuracies: ', acc_list)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iUIHdtYfnVYj",
        "outputId": "e20d0ba5-4169-4102-da74-7a8560195fe8"
      },
      "execution_count": 37,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "GDA results: \n",
            "Domain 5 Accuracy: 94.19%\n",
            "Domain 5 Accuracy: 83.98%\n",
            "Domain 5 Accuracy: 66.26%\n",
            "Domain 5 Accuracy: 42.98%\n",
            "Inter-domain accuracies:  [94.19, 83.98, 66.26, 42.98]\n"
          ]
        }
      ]
    }
  ]
}