{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 18373,
     "status": "ok",
     "timestamp": 1694033719063,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "UHgERHwINWD7",
    "outputId": "9d02537d-f268-4c28-ed8d-bd728a1859d9"
   },
   "outputs": [],
   "source": [
    "# Mount Google Drive under /content/gdrive; the checkpoint save/load cells use paths below it.\n",
    "from google.colab import drive\n",
    "drive.mount('/content/gdrive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5793,
     "status": "ok",
     "timestamp": 1694033724853,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "AgyHViRuED9D",
    "outputId": "fbf52222-bb18-499e-92a5-240b32c6218f"
   },
   "outputs": [],
   "source": [
    "# Explicit %pip magic: installs into the running kernel's environment and does not rely on automagic.\n",
    "%pip install PIMS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 4746,
     "status": "ok",
     "timestamp": 1694033736895,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "cS35kgH5EHmp",
    "outputId": "26439447-fe83-4eaa-a4b9-444ac41544d2"
   },
   "outputs": [],
   "source": [
    "# Explicit %pip magic: installs into the running kernel's environment and does not rely on automagic.\n",
    "%pip install fastcluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 5770,
     "status": "ok",
     "timestamp": 1694033742664,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "i5BqMQY2Nd1e"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "from torchvision.io import read_image\n",
    "\n",
    "# Default figure look: white axes background, 9 x 6 figures.\n",
    "plt.rcParams.update({\n",
    "    'axes.facecolor': 'white',\n",
    "    'figure.figsize': (9, 6),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 21938,
     "status": "ok",
     "timestamp": 1694033764600,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "6bCblqtuEK6F"
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# importing relevant libraries\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import matplotlib.pyplot as plt\n",
    "import sklearn\n",
    "from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier\n",
    "from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
    "from sklearn.model_selection import cross_val_predict, StratifiedKFold\n",
    "from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, auc#plot_precision_recall_curve\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from tqdm import tqdm\n",
    "from umap import UMAP\n",
    "from pynndescent import NNDescent\n",
    "from fastcluster import single\n",
    "from scipy.cluster.hierarchy import cut_tree, fcluster, dendrogram\n",
    "from scipy.spatial.distance import squareform\n",
    "from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier\n",
    "from pims import ImageSequence\n",
    "from PIL import Image\n",
    "from scipy.spatial.distance import hamming\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# setting the plotting style\n",
    "plt.style.use('bmh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1693432495767,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "aILiFBc-6XZb",
    "outputId": "1fc42015-4603-4cb8-83e5-99f4d81ccdd2"
   },
   "outputs": [],
   "source": [
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oEIBbUweNd3P"
   },
   "outputs": [],
   "source": [
    "# define the NN architecture\n",
    "class ConvAutoencoder(nn.Module):\n",
    "    \"\"\"Convolutional autoencoder for 3-channel images.\n",
    "\n",
    "    The encoder stacks three stride-2 convolutions (3 -> 16 -> 32 -> 64 channels),\n",
    "    each followed by a ReLU. The decoder mirrors it with three stride-2 transposed\n",
    "    convolutions (64 -> 32 -> 16 -> 3 channels) and ends in a sigmoid.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Each stride-2 convolution halves the spatial size.\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        # Each transposed convolution doubles the spatial size; the last one does so\n",
    "        # with kernel_size=4 and no output_padding.\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),\n",
    "            nn.Sigmoid(),  # Ensures outputs are in the range [0, 1]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the reconstruction of ``x`` (encoder followed by decoder).\"\"\"\n",
    "        encoded = self.encoder(x)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded\n",
    "\n",
    "    def forward_encoder(self, x):\n",
    "        \"\"\"Return the latent feature map the encoder produces for ``x``.\"\"\"\n",
    "        return self.encoder(x)\n",
    "\n",
    "# Training function\n",
    "def train(model, train_loader, criterion, optimizer, num_epochs=10, online = False):\n",
    "    \"\"\"Train ``model`` to reconstruct its own (resized) inputs.\n",
    "\n",
    "    Args:\n",
    "        model: network whose output is compared against its input.\n",
    "        train_loader: iterable of batches; ``(images, labels)`` pairs by default,\n",
    "            or bare image tensors when ``online`` is True.\n",
    "        criterion: loss called as ``criterion(outputs, images)``.\n",
    "        optimizer: optimizer that updates the model parameters.\n",
    "        num_epochs: number of passes over ``train_loader``.\n",
    "        online: set to True when the loader yields images without labels.\n",
    "\n",
    "    Prints the mean per-batch loss after every epoch and returns nothing.\n",
    "    \"\"\"\n",
    "    # The resize transform is loop-invariant, so build it once.\n",
    "    resize = torchvision.transforms.Resize((32,32))\n",
    "    # Guard the average against an empty loader.\n",
    "    num_batches = max(len(train_loader), 1)\n",
    "    model.train()\n",
    "    for epoch in range(num_epochs):\n",
    "        running_loss = 0.0\n",
    "        for batch in train_loader:\n",
    "            # Labels, when present, are ignored: the target is the input itself.\n",
    "            images = batch if online else batch[0]\n",
    "            images = resize(images)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(images)\n",
    "            loss = criterion(outputs, images)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "        print(f\"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / num_batches}\")\n",
    "\n",
    "# Data preprocessing: convert the loaded images to tensors\n",
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "# Autoencoder, reconstruction loss and optimizer\n",
    "model = ConvAutoencoder()\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "o1qJVzGViPtF"
   },
   "outputs": [],
   "source": [
    "# model_Conv_AE_OOD_Cifar10_Color_train = 'classifier.pt'\n",
    "# path = \"/content/gdrive/My Drive/{model_Conv_AE_OOD_Cifar10_Color_train}\"\n",
    "# model.load_state_dict(torch.load(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 48641,
     "status": "ok",
     "timestamp": 1693432544406,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "6NftBdp9Nd5U",
    "outputId": "5319112d-06ba-406f-ddcb-5f9847c4236a"
   },
   "outputs": [],
   "source": [
    "# CIFAR-100 training split in shuffled mini-batches of 64\n",
    "train_dataset = torchvision.datasets.CIFAR100(\n",
    "    root='./data', train=True, download=True, transform=transform\n",
    ")\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=64, shuffle=True, num_workers=4\n",
    ")\n",
    "\n",
    "# Train the autoencoder for 3 epochs\n",
    "train(model, train_loader, criterion, optimizer, num_epochs=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 17468,
     "status": "ok",
     "timestamp": 1693432561866,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "-Lw_PhMiNd7J",
    "outputId": "eda4e971-063b-4b01-af6e-9f40823dfc7f"
   },
   "outputs": [],
   "source": [
    "train_dataset_Cifar100 = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)\n",
    "train_loader_Cifar100 = torch.utils.data.DataLoader(train_dataset_Cifar100, batch_size = 50000, shuffle=True)\n",
    "\n",
    "# Feature extraction needs no gradients, so skip building the autograd graph.\n",
    "# The previous zero-scaled noise and the unused decoder pass are dropped; the encoder\n",
    "# output is unchanged.\n",
    "with torch.no_grad():\n",
    "    # Images and labels of the last batch yielded are kept, so they stay aligned.\n",
    "    for img, labels in train_loader_Cifar100:\n",
    "        latent_train = model.forward_encoder(img)\n",
    "\n",
    "X_train = img.numpy()\n",
    "latent_train = latent_train.numpy()\n",
    "y = labels.numpy()\n",
    "print(X_train.shape)\n",
    "print(latent_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 26,
     "status": "ok",
     "timestamp": 1693432561867,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "pjGzdfeLOA-b",
    "outputId": "92924ac4-5957-4f8c-b68a-40b64cc1f14d"
   },
   "outputs": [],
   "source": [
    "# Flatten each image and each latent feature map into one row per sample\n",
    "X_train = np.reshape(X_train, (-1, 3 * 32 * 32))\n",
    "latent_train = np.reshape(latent_train, (-1, 64 * 4 * 4))\n",
    "\n",
    "print(X_train.shape)\n",
    "print(latent_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 4572,
     "status": "ok",
     "timestamp": 1693432566420,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "GKqaLHKjOEJS",
    "outputId": "e24df072-9469-43b9-f3e1-8bcd0a199bbb"
   },
   "outputs": [],
   "source": [
    "# Mean entry of the feature covariance matrix, for raw pixels and for latent features\n",
    "print(np.cov(X_train, rowvar=False).mean())\n",
    "print(np.cov(latent_train, rowvar=False).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eF3UaAbigMX_"
   },
   "outputs": [],
   "source": [
    "# y = y[:2000]\n",
    "# latent_train = latent_train[:2000,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HHvI8fDdgy2B"
   },
   "outputs": [],
   "source": [
    "# np.random.shuffle(y)\n",
    "# print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_rh0xMVs5ZlY"
   },
   "source": [
    "# Tree Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 83262,
     "status": "ok",
     "timestamp": 1693432649659,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "LbEIKNxYOOW8",
    "outputId": "4591fa13-53a4-4cb5-8cd9-6c17ce43c293"
   },
   "outputs": [],
   "source": [
    "# 5-fold stratified, shuffled validation splits\n",
    "skf = StratifiedKFold(n_splits=5, shuffle=True)\n",
    "\n",
    "# extra-trees forest on the latent features\n",
    "# (RandomForestClassifier accepts the same arguments if a random forest is preferred)\n",
    "et = ExtraTreesClassifier(\n",
    "    n_estimators=500,\n",
    "    min_samples_leaf=100,\n",
    "    max_features='sqrt',\n",
    "    bootstrap=True,\n",
    "    class_weight='balanced',\n",
    "    n_jobs=-1,\n",
    ")\n",
    "\n",
    "# out-of-fold class probabilities for every training sample\n",
    "preds = cross_val_predict(et, latent_train, y, cv=skf, method='predict_proba')\n",
    "\n",
    "# one-vs-one multiclass ROC AUC of the out-of-fold predictions\n",
    "print('Area under the ROC Curve:', roc_auc_score(y, preds, multi_class='ovo'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 92
    },
    "executionInfo": {
     "elapsed": 19941,
     "status": "ok",
     "timestamp": 1693432669583,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "Yl8zBDu1OOZS",
    "outputId": "3f55e6be-d6b8-43c4-deab-38495b3702d4"
   },
   "outputs": [],
   "source": [
    "# cross_val_predict only fits clones of `et`, so fit the forest itself on all latent features\n",
    "et.fit(X=latent_train, y=y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12048,
     "status": "ok",
     "timestamp": 1693432681626,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "rhPRj8x9OObV",
    "outputId": "85d9ed95-7209-4085-c6c5-066d6551385e"
   },
   "outputs": [],
   "source": [
    "# Leaf index reached in every tree of the forest, one row per sample\n",
    "leaves_train = et.apply(latent_train)\n",
    "print(leaves_train.shape)\n",
    "print(leaves_train)\n",
    "\n",
    "# Pairwise Hamming distance among the first 1000 samples: the fraction of trees in\n",
    "# which two samples fall into different leaves.\n",
    "n_ref = min(1000, len(leaves_train))\n",
    "ref_leaves = leaves_train[:n_ref]\n",
    "# One vectorised comparison per row replaces a Python double loop of scipy `hamming`\n",
    "# calls; both take the mean of the element-wise inequality.\n",
    "distances_train = np.stack([(ref_leaves != row).mean(axis=1) for row in ref_leaves])\n",
    "\n",
    "# Mean distance from each sample to the others (the zero self-distance is excluded\n",
    "# by dividing by n_ref - 1).\n",
    "score_train = distances_train.sum(axis=0) / max(n_ref - 1, 1)\n",
    "\n",
    "print(np.mean(score_train))\n",
    "print(np.cov(score_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ABZtYxACjjJa"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yoIPrBqFqXT8"
   },
   "outputs": [],
   "source": [
    "# Save the current autoencoder weights to Google Drive.\n",
    "# NOTE(review): `path` is a plain string, not an f-string, so the braces are kept literally and the\n",
    "# file is named '{model_Conv_AE_OOD_Cifar100_Color_train_30}' rather than 'classifier.pt'.\n",
    "# Every save/load cell must use the same path, so change them all together if this is fixed.\n",
    "model_Conv_AE_OOD_Cifar100_Color_train_30 = 'classifier.pt'\n",
    "path = \"/content/gdrive/My Drive/{model_Conv_AE_OOD_Cifar100_Color_train_30}\"\n",
    "torch.save(model.state_dict(), path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1693432682357,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "xjEYJHlhjjfF",
    "outputId": "754aff11-2b4f-464c-f15d-db88e0903de6"
   },
   "outputs": [],
   "source": [
    "# Restore the autoencoder weights from the checkpoint on Google Drive.\n",
    "# NOTE(review): `path` is a plain string, not an f-string, so the braces are kept literally and the\n",
    "# file is named '{model_Conv_AE_OOD_Cifar100_Color_train_30}' rather than 'classifier.pt'.\n",
    "# Every save/load cell must use the same path, so change them all together if this is fixed.\n",
    "model_Conv_AE_OOD_Cifar100_Color_train_30 = 'classifier.pt'\n",
    "path = \"/content/gdrive/My Drive/{model_Conv_AE_OOD_Cifar100_Color_train_30}\"\n",
    "model.load_state_dict(torch.load(path))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K534l0C15dOs"
   },
   "source": [
    "# Testing on ID Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VWwIw7NDjATd"
   },
   "outputs": [],
   "source": [
    "# Number of epochs passed to train() when the model is further trained on each evaluation dataset\n",
    "num_epoch = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 104155,
     "status": "ok",
     "timestamp": 1693432787129,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "qgiK4zhzOOdJ",
    "outputId": "832a458e-a3bb-4840-ab76-8847fda2e04d"
   },
   "outputs": [],
   "source": [
    "# CIFAR-100 test split in shuffled mini-batches of 64\n",
    "test_dataset_Cifar100 = torchvision.datasets.CIFAR100(\n",
    "    root='./data', train=False, download=True, transform=transform\n",
    ")\n",
    "test_loader_Cifar100 = torch.utils.data.DataLoader(\n",
    "    test_dataset_Cifar100, batch_size=64, shuffle=True\n",
    ")\n",
    "\n",
    "# Continue training the current model (same criterion and optimizer) on this split\n",
    "train(model, test_loader_Cifar100, criterion, optimizer, num_epochs=num_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 4786,
     "status": "ok",
     "timestamp": 1693432791891,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "of_ZbBltOOfR",
    "outputId": "6e39b733-61fc-4901-c8ca-ef1e042b2ce7"
   },
   "outputs": [],
   "source": [
    "test_dataset_Cifar100 = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)\n",
    "test_loader_Cifar100 = torch.utils.data.DataLoader(test_dataset_Cifar100, batch_size = 10000, shuffle=True)\n",
    "\n",
    "# Only one batch is kept by the code below, so draw a single shuffled batch instead of\n",
    "# running the forward/backward pass on every batch and overwriting the result each time.\n",
    "img, _ = next(iter(test_loader_Cifar100))\n",
    "\n",
    "# Shift the images by 0.1 * sign(gradient of the reconstruction loss w.r.t. the images).\n",
    "# NOTE(review): the forward pass runs before requires_grad is enabled, so the gradient\n",
    "# reaches img only through its role as the loss target, not through the model.\n",
    "# Confirm this is intended.\n",
    "out = model.forward(img)\n",
    "img.requires_grad = True\n",
    "loss = criterion(out, img)\n",
    "model.zero_grad()\n",
    "loss.backward()\n",
    "gradient = img.grad.data.sign()\n",
    "img = img + 0.1 * gradient\n",
    "\n",
    "# Encoding the perturbed images needs no autograd graph.\n",
    "with torch.no_grad():\n",
    "    latent_test = model.forward_encoder(img)\n",
    "\n",
    "latent_test = latent_test.detach().numpy()\n",
    "# X_test holds the perturbed images, not the originals.\n",
    "X_test = img.detach().numpy()\n",
    "print(X_test.shape)\n",
    "print(latent_test.shape)\n",
    "# Flatten to one row per sample.\n",
    "X_test = X_test.reshape(-1,3*32*32)\n",
    "latent_test = latent_test.reshape(-1,64*4*4)\n",
    "print(X_test.shape)\n",
    "print(latent_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12020,
     "status": "ok",
     "timestamp": 1693432803887,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "0IE7QW4RTX1r",
    "outputId": "848e8fe3-de96-458a-adbb-889a7cd08adc"
   },
   "outputs": [],
   "source": [
    "latent_test_in = latent_test\n",
    "\n",
    "# Leaf index reached in every tree of the forest, one row per sample\n",
    "leaves_test_in = et.apply(latent_test_in)\n",
    "print(leaves_test_in.shape)\n",
    "print(leaves_test_in)\n",
    "\n",
    "# Pairwise Hamming distance among the first 1000 samples: the fraction of trees in\n",
    "# which two samples fall into different leaves.\n",
    "n_ref = min(1000, len(leaves_test_in))\n",
    "ref_leaves = leaves_test_in[:n_ref]\n",
    "# One vectorised comparison per row replaces a Python double loop of scipy `hamming`\n",
    "# calls; both take the mean of the element-wise inequality.\n",
    "distances_test_in = np.stack([(ref_leaves != row).mean(axis=1) for row in ref_leaves])\n",
    "\n",
    "# Mean distance from each sample to the others (the zero self-distance is excluded\n",
    "# by dividing by n_ref - 1).\n",
    "score_test_in = distances_test_in.sum(axis=0) / max(n_ref - 1, 1)\n",
    "\n",
    "print(np.mean(score_test_in))\n",
    "print(np.cov(score_test_in))"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xNTHQMtVT8Wb"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0dOaVz-G5h-n"
   },
   "source": [
    "# Testing on OOD Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "R0dTYN6B5k2t"
   },
   "source": [
    "## SVHN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 26,
     "status": "ok",
     "timestamp": 1693432803888,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "hTQoUtj_jmGO",
    "outputId": "7231a971-a9fd-4711-b42f-66f13c8c5ba8"
   },
   "outputs": [],
   "source": [
    "# Restore the autoencoder weights from the checkpoint on Google Drive.\n",
    "# NOTE(review): `path` is a plain string, not an f-string, so the braces are kept literally and the\n",
    "# file is named '{model_Conv_AE_OOD_Cifar100_Color_train_30}' rather than 'classifier.pt'.\n",
    "# Every save/load cell must use the same path, so change them all together if this is fixed.\n",
    "model_Conv_AE_OOD_Cifar100_Color_train_30 = 'classifier.pt'\n",
    "path = \"/content/gdrive/My Drive/{model_Conv_AE_OOD_Cifar100_Color_train_30}\"\n",
    "model.load_state_dict(torch.load(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 238664,
     "status": "ok",
     "timestamp": 1693433042530,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "49wqpHD5T8aU",
    "outputId": "1900994b-a6f7-4bd7-b8b7-c4517f2e952d"
   },
   "outputs": [],
   "source": [
    "# Resize every image to 32x32 (the autoencoder's input size) before converting to a tensor\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((32, 32)),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# SVHN test split in shuffled mini-batches of 64\n",
    "train_dataset_SVHN = torchvision.datasets.SVHN(\n",
    "    root='./data', split='test', download=True, transform=transform\n",
    ")\n",
    "train_loader_SVHN = torch.utils.data.DataLoader(\n",
    "    train_dataset_SVHN, batch_size=64, shuffle=True, num_workers=4\n",
    ")\n",
    "\n",
    "# Continue training the current model (same criterion and optimizer) on this split\n",
    "train(model, train_loader_SVHN, criterion, optimizer, num_epochs=num_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 11164,
     "status": "ok",
     "timestamp": 1693433053692,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "bj5nGoXgT8cr",
    "outputId": "37056150-78e7-4935-d8b4-7edd4a8b87f5"
   },
   "outputs": [],
   "source": [
    "test_dataset_SVHN = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform)\n",
    "test_loader_SVHN = torch.utils.data.DataLoader(test_dataset_SVHN, batch_size=10000, shuffle=True)\n",
    "\n",
    "# Only one batch is kept by the code below, so draw a single shuffled batch instead of\n",
    "# running the forward/backward pass on every batch and overwriting the result each time.\n",
    "img, _ = next(iter(test_loader_SVHN))\n",
    "\n",
    "# Shift the images by 0.1 * sign(gradient of the reconstruction loss w.r.t. the images).\n",
    "# NOTE(review): the forward pass runs before requires_grad is enabled, so the gradient\n",
    "# reaches img only through its role as the loss target, not through the model.\n",
    "# Confirm this is intended.\n",
    "out = model.forward(img)\n",
    "img.requires_grad = True\n",
    "loss = criterion(out, img)\n",
    "model.zero_grad()\n",
    "loss.backward()\n",
    "gradient = img.grad.data.sign()\n",
    "img = img + 0.1 * gradient\n",
    "\n",
    "# Encoding the perturbed images needs no autograd graph.\n",
    "with torch.no_grad():\n",
    "    latent_test_out = model.forward_encoder(img)\n",
    "\n",
    "# Flatten to one row per sample.\n",
    "latent_test_out_SVHN = latent_test_out.detach().numpy()\n",
    "latent_test_out_SVHN = latent_test_out_SVHN.reshape(-1,64*4*4)\n",
    "print(latent_test_out_SVHN.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12596,
     "status": "ok",
     "timestamp": 1693433066263,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "X6-qPSDtT8ee",
    "outputId": "9532869b-f2e9-41b2-8f55-554882d25a26"
   },
   "outputs": [],
   "source": [
    "# Leaf index reached in every tree of the forest, one row per sample\n",
    "leaves_test_out_SVHN = et.apply(latent_test_out_SVHN)\n",
    "\n",
    "print(leaves_test_out_SVHN.shape)\n",
    "print(leaves_test_out_SVHN)\n",
    "\n",
    "# Pairwise Hamming distance among the first 1000 samples: the fraction of trees in\n",
    "# which two samples fall into different leaves.\n",
    "n_ref = min(1000, len(leaves_test_out_SVHN))\n",
    "ref_leaves = leaves_test_out_SVHN[:n_ref]\n",
    "# One vectorised comparison per row replaces a Python double loop of scipy `hamming`\n",
    "# calls; both take the mean of the element-wise inequality.\n",
    "distances_test_out_SVHN = np.stack([(ref_leaves != row).mean(axis=1) for row in ref_leaves])\n",
    "\n",
    "# Mean distance from each sample to the others (the zero self-distance is excluded\n",
    "# by dividing by n_ref - 1).\n",
    "score_test_out_SVHN = distances_test_out_SVHN.sum(axis=0) / max(n_ref - 1, 1)\n",
    "\n",
    "print(np.mean(score_test_out_SVHN))\n",
    "print(np.cov(score_test_out_SVHN))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9sHeycAMT8i8"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T3e1IFik5vo6"
   },
   "source": [
    "## DTD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 32,
     "status": "ok",
     "timestamp": 1693433066268,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "AY1XUvLjj_MS",
    "outputId": "1fccedd7-95fb-4e77-bc2e-7cba4456411a"
   },
   "outputs": [],
   "source": [
    "# Restore the autoencoder weights from the checkpoint on Google Drive.\n",
    "# NOTE(review): `path` is a plain string, not an f-string, so the braces are kept literally and the\n",
    "# file is named '{model_Conv_AE_OOD_Cifar100_Color_train_30}' rather than 'classifier.pt'.\n",
    "# Every save/load cell must use the same path, so change them all together if this is fixed.\n",
    "model_Conv_AE_OOD_Cifar100_Color_train_30 = 'classifier.pt'\n",
    "path = \"/content/gdrive/My Drive/{model_Conv_AE_OOD_Cifar100_Color_train_30}\"\n",
    "model.load_state_dict(torch.load(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 64911,
     "status": "ok",
     "timestamp": 1693433131152,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "zrwoZF6Vx2qu",
    "outputId": "05165374-3fa4-4630-953e-3e16d79aba94"
   },
   "outputs": [],
   "source": [
    "# fine tuning the AE\n",
    "# Resize every image to 32x32 (the autoencoder's input size) before converting to a tensor\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((32, 32)),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# DTD test split in shuffled mini-batches of 64\n",
    "train_dataset_DTD = torchvision.datasets.DTD(\n",
    "    root='./data', split='test', download=True, transform=transform\n",
    ")\n",
    "train_loader_DTD = torch.utils.data.DataLoader(\n",
    "    train_dataset_DTD, batch_size=64, shuffle=True, num_workers=4\n",
    ")\n",
    "\n",
    "# Continue training the current model (same criterion and optimizer) on this split\n",
    "train(model, train_loader_DTD, criterion, optimizer, num_epochs=num_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5076,
     "status": "ok",
     "timestamp": 1693433136203,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "ULj-NLe0x-aU",
    "outputId": "ff05988e-5bfb-45c7-a49b-37fa4c580af9"
   },
   "outputs": [],
   "source": [
    "test_dataset_DTD = torchvision.datasets.DTD(root='./data', split='test', download=True, transform=transform)\n",
    "test_loader_DTD = torch.utils.data.DataLoader(test_dataset_DTD, batch_size=10000, shuffle=True)\n",
    "\n",
    "# Only one batch is kept by the code below, so draw a single shuffled batch instead of\n",
    "# running the forward/backward pass on every batch and overwriting the result each time.\n",
    "img, _ = next(iter(test_loader_DTD))\n",
    "\n",
    "# Shift the images by 0.1 * sign(gradient of the reconstruction loss w.r.t. the images).\n",
    "# NOTE(review): the forward pass runs before requires_grad is enabled, so the gradient\n",
    "# reaches img only through its role as the loss target, not through the model.\n",
    "# Confirm this is intended.\n",
    "out = model.forward(img)\n",
    "img.requires_grad = True\n",
    "loss = criterion(out, img)\n",
    "model.zero_grad()\n",
    "loss.backward()\n",
    "gradient = img.grad.data.sign()\n",
    "img = img + 0.1 * gradient\n",
    "\n",
    "# Encoding the perturbed images needs no autograd graph.\n",
    "with torch.no_grad():\n",
    "    latent_test_out = model.forward_encoder(img)\n",
    "\n",
    "# Flatten to one row per sample.\n",
    "latent_test_out_DTD = latent_test_out.detach().numpy()\n",
    "latent_test_out_DTD = latent_test_out_DTD.reshape(-1,64*4*4)\n",
    "print(latent_test_out_DTD.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 11352,
     "status": "ok",
     "timestamp": 1693433147532,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "ucyfkSd0x-cY",
    "outputId": "15f41c0e-ca52-4e93-f2fc-2b05754d2ec5"
   },
   "outputs": [],
   "source": [
    "# Leaf index reached in every tree of the forest, one row per sample\n",
    "leaves_test_out_DTD = et.apply(latent_test_out_DTD)\n",
    "\n",
    "print(leaves_test_out_DTD.shape)\n",
    "print(leaves_test_out_DTD)\n",
    "\n",
    "# Pairwise Hamming distance among the first 1000 samples: the fraction of trees in\n",
    "# which two samples fall into different leaves.\n",
    "n_ref = min(1000, len(leaves_test_out_DTD))\n",
    "ref_leaves = leaves_test_out_DTD[:n_ref]\n",
    "# One vectorised comparison per row replaces a Python double loop of scipy `hamming`\n",
    "# calls; both take the mean of the element-wise inequality.\n",
    "distances_test_out_DTD = np.stack([(ref_leaves != row).mean(axis=1) for row in ref_leaves])\n",
    "\n",
    "# Mean distance from each sample to the others (the zero self-distance is excluded\n",
    "# by dividing by n_ref - 1).\n",
    "score_test_out_DTD = distances_test_out_DTD.sum(axis=0) / max(n_ref - 1, 1)\n",
    "\n",
    "print(np.mean(score_test_out_DTD))\n",
    "print(np.cov(score_test_out_DTD))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lUTb_pX158SO"
   },
   "source": [
    "## Places365"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 23,
     "status": "ok",
     "timestamp": 1693433147533,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "GXM0Kh4D5HmY",
    "outputId": "3eb7b22c-0337-4fc0-b49c-15e368058f39"
   },
   "outputs": [],
   "source": [
    "# Restore the autoencoder weights from the checkpoint on Google Drive.\n",
    "# NOTE(review): `path` is a plain string, not an f-string, so the braces are kept literally and the\n",
    "# file is named '{model_Conv_AE_OOD_Cifar100_Color_train_30}' rather than 'classifier.pt'.\n",
    "# Every save/load cell must use the same path, so change them all together if this is fixed.\n",
    "model_Conv_AE_OOD_Cifar100_Color_train_30 = 'classifier.pt'\n",
    "path = \"/content/gdrive/My Drive/{model_Conv_AE_OOD_Cifar100_Color_train_30}\"\n",
    "model.load_state_dict(torch.load(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 409894,
     "status": "ok",
     "timestamp": 1693433557406,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "O_IR7dGy5-TY",
    "outputId": "2cd8454d-432f-41ba-b139-cc6ed9495b15"
   },
   "outputs": [],
   "source": [
    "# Fine-tune the autoencoder on the Places365 validation split (small images).\n",
    "transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize((32, 32)),  # bring every image to 32x32\n",
    "    torchvision.transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "train_dataset_Places365 = torchvision.datasets.Places365(\n",
    "    root='./data', split='val', small=True, download=False, transform=transform,\n",
    ")\n",
    "train_loader_Places365 = torch.utils.data.DataLoader(\n",
    "    train_dataset_Places365, batch_size=64, shuffle=True, num_workers=4,\n",
    ")\n",
    "\n",
    "# `model`, `criterion`, `optimizer` and `num_epoch` are not created here; they are\n",
    "# reused as they currently exist in the notebook session.\n",
    "train(model, train_loader_Places365, criterion, optimizer, num_epochs=num_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 40706,
     "status": "ok",
     "timestamp": 1693433598088,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "SZNDQf_s5-Vd",
    "outputId": "237c51e5-95a3-40cb-92b3-d13e87d044a8"
   },
   "outputs": [],
   "source": [
    "test_dataset_Places365 = torchvision.datasets.Places365(\n",
    "    root='./data', split='val', small=True, download=False, transform=transform,\n",
    ")\n",
    "test_loader_Places365 = torch.utils.data.DataLoader(\n",
    "    test_dataset_Places365, batch_size=10000, shuffle=True,\n",
    ")\n",
    "\n",
    "# Per batch: reconstruct, back-propagate the reconstruction loss, shift the images\n",
    "# by 0.1 * sign(input gradient), then encode the shifted images.\n",
    "# NOTE(review): `latent_test_out` is overwritten on every iteration, so only the\n",
    "# final batch reaches the lines after the loop.\n",
    "for img, _ in test_loader_Places365:\n",
    "    out = model.forward(img)\n",
    "    # NOTE(review): gradient tracking is switched on after the forward pass, so it\n",
    "    # appears `img.grad` only reflects `img` as the loss target -- confirm intended.\n",
    "    img.requires_grad_(True)\n",
    "    loss = criterion(out, img)\n",
    "    model.zero_grad()\n",
    "    loss.backward()\n",
    "    img = img + 0.1 * img.grad.sign()\n",
    "    latent_test_out = model.forward_encoder(img)\n",
    "\n",
    "# Flatten each latent code to a vector of length 64*4*4.\n",
    "latent_test_out_Places365 = latent_test_out.detach().numpy().reshape(-1, 64*4*4)\n",
    "print(latent_test_out_Places365.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 11639,
     "status": "ok",
     "timestamp": 1693433609702,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "BpWZam8J5-Y2",
    "outputId": "779008eb-5051-43c5-a2a1-0bf324bfd1f1"
   },
   "outputs": [],
   "source": [
    "# Leaf indices assigned by `et` to each Places365 latent vector.\n",
    "leaves_test_out_Places365 = et.apply(latent_test_out_Places365)\n",
    "\n",
    "print(leaves_test_out_Places365.shape)\n",
    "print(leaves_test_out_Places365)\n",
    "\n",
    "# Pairwise normalized Hamming distance between the leaf-index rows of the first\n",
    "# 1000 samples, i.e. the fraction of positions at which two rows differ.\n",
    "# One vectorized comparison per row replaces the 1000 x 1000 Python-level calls\n",
    "# to `hamming`; the resulting matrix holds the same values.\n",
    "leaves_eval_Places365 = leaves_test_out_Places365[:1000]\n",
    "distances_test_out_Places365 = np.zeros((1000,1000))\n",
    "for i in range(1000):\n",
    "    distances_test_out_Places365[i] = (leaves_eval_Places365 != leaves_eval_Places365[i]).mean(axis=1)\n",
    "\n",
    "# Builtin sum adds the rows together (column sums). Dividing by 999 averages over\n",
    "# the other samples, since the self-distance on the diagonal is zero.\n",
    "score_test_out_Places365 = sum(distances_test_out_Places365)/999\n",
    "\n",
    "print(np.mean(score_test_out_Places365))\n",
    "print(np.cov(score_test_out_Places365))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kh8zbKA2ij_b"
   },
   "source": [
    "## iSUN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 29,
     "status": "ok",
     "timestamp": 1693433609703,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "RwawMPGc5-ar",
    "outputId": "4437e414-6b7e-4909-9641-426d286e6a6a"
   },
   "outputs": [],
   "source": [
    "# Reload the saved checkpoint so the fine-tuning below starts from the stored weights.\n",
    "model_Conv_AE_OOD_Cifar100_Color_train_30 = 'classifier.pt'\n",
    "# f-string: without the `f` prefix the braces and variable name were passed to\n",
    "# torch.load literally instead of being replaced by the file name above.\n",
    "path = f\"/content/gdrive/My Drive/{model_Conv_AE_OOD_Cifar100_Color_train_30}\"\n",
    "model.load_state_dict(torch.load(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 55534,
     "status": "ok",
     "timestamp": 1693433665216,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "UZAefg-pilUm",
    "outputId": "67f6b2eb-42a2-4d1a-a357-51125b34b521"
   },
   "outputs": [],
   "source": [
    "# Provenance (from the builder snippet formerly commented out here): the cache was\n",
    "# made by reading 8924 files '<i>.jpeg' from OODdata/iSUN/iSUN_patches/ with\n",
    "# read_image into a (8924, 3, 32, 32) tensor and writing it with torch.save.\n",
    "images = torch.load('/content/gdrive/MyDrive/TOOD/datasets/images_iSUN.t')\n",
    "\n",
    "# Divide by 255; the same tensor feeds both the fine-tuning loader here and the\n",
    "# evaluation loader in the next cell.\n",
    "test_dataset_iSUN = images / 255\n",
    "train_loader_iSUN = torch.utils.data.DataLoader(\n",
    "    test_dataset_iSUN, batch_size=64, shuffle=True,\n",
    ")\n",
    "\n",
    "# `model`, `criterion`, `optimizer` and `num_epoch` are not created here; they are\n",
    "# reused as they currently exist in the notebook session.\n",
    "train(model, train_loader_iSUN, criterion, optimizer, num_epochs=num_epoch, online=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2255,
     "status": "ok",
     "timestamp": 1693433667455,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "rRLkpl8JilWr",
    "outputId": "aec3998c-b549-43e0-b07e-5964faa10696"
   },
   "outputs": [],
   "source": [
    "test_loader_iSUN = torch.utils.data.DataLoader(\n",
    "    test_dataset_iSUN, batch_size=10000, shuffle=True,\n",
    ")\n",
    "resize_to_32 = torchvision.transforms.Resize((32,32))\n",
    "\n",
    "# Per batch: resize, reconstruct, back-propagate the reconstruction loss, shift the\n",
    "# images by 0.1 * sign(input gradient), then encode the shifted images.\n",
    "# NOTE(review): `latent_test_out` is overwritten on every iteration, so only the\n",
    "# final batch reaches the lines after the loop.\n",
    "for img in test_loader_iSUN:\n",
    "    img = resize_to_32(img)\n",
    "    out = model.forward(img)\n",
    "    # NOTE(review): gradient tracking is switched on after the forward pass, so it\n",
    "    # appears `img.grad` only reflects `img` as the loss target -- confirm intended.\n",
    "    img.requires_grad_(True)\n",
    "    loss = criterion(out, img)\n",
    "    model.zero_grad()\n",
    "    loss.backward()\n",
    "    img = img + 0.1 * img.grad.sign()\n",
    "    latent_test_out = model.forward_encoder(img)\n",
    "\n",
    "latent_test_out_iSUN = latent_test_out.detach().numpy()\n",
    "print(latent_test_out_iSUN.shape)\n",
    "# Flatten each latent code to a vector of length 64*4*4.\n",
    "latent_test_out_iSUN = latent_test_out_iSUN.reshape(-1, 64*4*4)\n",
    "print(latent_test_out_iSUN.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 11736,
     "status": "ok",
     "timestamp": 1693433679189,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "m93p06uEilYw",
    "outputId": "40b7acf0-adb5-48f2-ba5a-4f7595f9cf43"
   },
   "outputs": [],
   "source": [
    "# Leaf indices assigned by `et` to each iSUN latent vector.\n",
    "leaves_test_out_iSUN = et.apply(latent_test_out_iSUN)\n",
    "\n",
    "print(leaves_test_out_iSUN.shape)\n",
    "print(leaves_test_out_iSUN)\n",
    "\n",
    "# Pairwise normalized Hamming distance between the leaf-index rows of the first\n",
    "# 1000 samples, i.e. the fraction of positions at which two rows differ.\n",
    "# One vectorized comparison per row replaces the 1000 x 1000 Python-level calls\n",
    "# to `hamming`; the resulting matrix holds the same values.\n",
    "leaves_eval_iSUN = leaves_test_out_iSUN[:1000]\n",
    "distances_test_out_iSUN = np.zeros((1000,1000))\n",
    "for i in range(1000):\n",
    "    distances_test_out_iSUN[i] = (leaves_eval_iSUN != leaves_eval_iSUN[i]).mean(axis=1)\n",
    "\n",
    "# Builtin sum adds the rows together (column sums). Dividing by 999 averages over\n",
    "# the other samples, since the self-distance on the diagonal is zero.\n",
    "score_test_out_iSUN = sum(distances_test_out_iSUN)/999\n",
    "\n",
    "print(np.mean(score_test_out_iSUN))\n",
    "print(np.cov(score_test_out_iSUN))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PXpiSYbGilai"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "w0A_R-VPme3h"
   },
   "source": [
    "## LSUN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1693433679190,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "beGos-wBoFsZ",
    "outputId": "86d30d5d-e260-4c8f-b371-3683534edfb1"
   },
   "outputs": [],
   "source": [
    "# Reload the saved checkpoint so the fine-tuning below starts from the stored weights.\n",
    "model_Conv_AE_OOD_Cifar100_Color_train_30 = 'classifier.pt'\n",
    "# f-string: without the `f` prefix the braces and variable name were passed to\n",
    "# torch.load literally instead of being replaced by the file name above.\n",
    "path = f\"/content/gdrive/My Drive/{model_Conv_AE_OOD_Cifar100_Color_train_30}\"\n",
    "model.load_state_dict(torch.load(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 62786,
     "status": "ok",
     "timestamp": 1693433741960,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "tkB7jZNnmgsW",
    "outputId": "7fe64d94-1534-46b3-8dfa-196a6f4d7dbc"
   },
   "outputs": [],
   "source": [
    "# Provenance (from the builder snippet formerly commented out here): the cache was\n",
    "# made by reading 10000 files '<i>.png' from OODdata/LSUN/test/ with read_image\n",
    "# into a (10000, 3, 36, 36) tensor and writing it with torch.save.\n",
    "images = torch.load('/content/gdrive/MyDrive/TOOD/datasets/images_LSUN.t')\n",
    "\n",
    "# Divide by 255; the same tensor feeds both the fine-tuning loader here and the\n",
    "# evaluation loader in the next cell.\n",
    "test_dataset_LSUN = images / 255\n",
    "train_loader_LSUN = torch.utils.data.DataLoader(\n",
    "    test_dataset_LSUN, batch_size=64, shuffle=True,\n",
    ")\n",
    "\n",
    "# `model`, `criterion`, `optimizer` and `num_epoch` are not created here; they are\n",
    "# reused as they currently exist in the notebook session.\n",
    "train(model, train_loader_LSUN, criterion, optimizer, num_epochs=num_epoch, online=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2515,
     "status": "ok",
     "timestamp": 1693433744450,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "QxNClsvSmguZ",
    "outputId": "416042cf-755b-4118-84d2-4fda67747ea9"
   },
   "outputs": [],
   "source": [
    "test_loader_LSUN = torch.utils.data.DataLoader(\n",
    "    test_dataset_LSUN, batch_size=10000, shuffle=True,\n",
    ")\n",
    "resize_to_32 = torchvision.transforms.Resize((32,32))\n",
    "\n",
    "# Per batch: resize, reconstruct, back-propagate the reconstruction loss, shift the\n",
    "# images by 0.1 * sign(input gradient), then encode the shifted images.\n",
    "# NOTE(review): `latent_test_out` is overwritten on every iteration, so only the\n",
    "# final batch reaches the lines after the loop.\n",
    "for img in test_loader_LSUN:\n",
    "    img = resize_to_32(img)\n",
    "    out = model.forward(img)\n",
    "    # NOTE(review): gradient tracking is switched on after the forward pass, so it\n",
    "    # appears `img.grad` only reflects `img` as the loss target -- confirm intended.\n",
    "    img.requires_grad_(True)\n",
    "    loss = criterion(out, img)\n",
    "    model.zero_grad()\n",
    "    loss.backward()\n",
    "    img = img + 0.1 * img.grad.sign()\n",
    "    latent_test_out = model.forward_encoder(img)\n",
    "\n",
    "latent_test_out_LSUN = latent_test_out.detach().numpy()\n",
    "print(latent_test_out_LSUN.shape)\n",
    "# Flatten each latent code to a vector of length 64*4*4.\n",
    "latent_test_out_LSUN = latent_test_out_LSUN.reshape(-1, 64*4*4)\n",
    "print(latent_test_out_LSUN.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12568,
     "status": "ok",
     "timestamp": 1693433757014,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "c2pqFpRqmgwe",
    "outputId": "b0c65357-696a-4bfa-a372-f5c01dbdcc57"
   },
   "outputs": [],
   "source": [
    "# Leaf indices assigned by `et` to each LSUN latent vector.\n",
    "leaves_test_out_LSUN = et.apply(latent_test_out_LSUN)\n",
    "\n",
    "print(leaves_test_out_LSUN.shape)\n",
    "print(leaves_test_out_LSUN)\n",
    "\n",
    "# Pairwise normalized Hamming distance between the leaf-index rows of the first\n",
    "# 1000 samples, i.e. the fraction of positions at which two rows differ.\n",
    "# One vectorized comparison per row replaces the 1000 x 1000 Python-level calls\n",
    "# to `hamming`; the resulting matrix holds the same values.\n",
    "leaves_eval_LSUN = leaves_test_out_LSUN[:1000]\n",
    "distances_test_out_LSUN = np.zeros((1000,1000))\n",
    "for i in range(1000):\n",
    "    distances_test_out_LSUN[i] = (leaves_eval_LSUN != leaves_eval_LSUN[i]).mean(axis=1)\n",
    "\n",
    "# Builtin sum adds the rows together (column sums). Dividing by 999 averages over\n",
    "# the other samples, since the self-distance on the diagonal is zero.\n",
    "score_test_out_LSUN = sum(distances_test_out_LSUN)/999\n",
    "\n",
    "print(np.mean(score_test_out_LSUN))\n",
    "print(np.cov(score_test_out_LSUN))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kKC5XvehtTcI"
   },
   "source": [
    "## LSUN-resize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 28,
     "status": "ok",
     "timestamp": 1693433757014,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "XFQoVRZDmgyT",
    "outputId": "6ce5acf4-6176-459f-8f8b-a1c623e24dea"
   },
   "outputs": [],
   "source": [
    "# Reload the saved checkpoint so the fine-tuning below starts from the stored weights.\n",
    "model_Conv_AE_OOD_Cifar100_Color_train_30 = 'classifier.pt'\n",
    "# f-string: without the `f` prefix the braces and variable name were passed to\n",
    "# torch.load literally instead of being replaced by the file name above.\n",
    "path = f\"/content/gdrive/My Drive/{model_Conv_AE_OOD_Cifar100_Color_train_30}\"\n",
    "model.load_state_dict(torch.load(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 60842,
     "status": "ok",
     "timestamp": 1693433817835,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "C0J47UdQmg0q",
    "outputId": "04e3bfe7-80eb-40cf-c7f6-4e0a9039870a"
   },
   "outputs": [],
   "source": [
    "# Provenance (from the builder snippet formerly commented out here): the cache was\n",
    "# made by reading 10000 files '<i>.jpg' from OODdata/LSUN_resize/test/ with\n",
    "# read_image into a (10000, 3, 32, 32) tensor and writing it with torch.save.\n",
    "images = torch.load('/content/gdrive/MyDrive/TOOD/datasets/images_LSUN_resize.t')\n",
    "\n",
    "# Divide by 255; the same tensor feeds both the fine-tuning loader here and the\n",
    "# evaluation loader in the next cell.\n",
    "test_dataset_LSUN_resize = images / 255\n",
    "train_loader_LSUN_resize = torch.utils.data.DataLoader(\n",
    "    test_dataset_LSUN_resize, batch_size=64, shuffle=True,\n",
    ")\n",
    "\n",
    "# `model`, `criterion`, `optimizer` and `num_epoch` are not created here; they are\n",
    "# reused as they currently exist in the notebook session.\n",
    "train(model, train_loader_LSUN_resize, criterion, optimizer, num_epochs=num_epoch, online=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2328,
     "status": "ok",
     "timestamp": 1693433820137,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "XseJ5T0itYjt",
    "outputId": "c16e02ed-7d1e-4f0f-c261-c8da7b3de15b"
   },
   "outputs": [],
   "source": [
    "test_loader_LSUN_resize = torch.utils.data.DataLoader(\n",
    "    test_dataset_LSUN_resize, batch_size=10000, shuffle=True,\n",
    ")\n",
    "resize_to_32 = torchvision.transforms.Resize((32,32))\n",
    "\n",
    "# Per batch: resize, reconstruct, back-propagate the reconstruction loss, shift the\n",
    "# images by 0.1 * sign(input gradient), then encode the shifted images.\n",
    "# NOTE(review): `latent_test_out` is overwritten on every iteration, so only the\n",
    "# final batch reaches the lines after the loop.\n",
    "for img in test_loader_LSUN_resize:\n",
    "    img = resize_to_32(img)\n",
    "    out = model.forward(img)\n",
    "    # NOTE(review): gradient tracking is switched on after the forward pass, so it\n",
    "    # appears `img.grad` only reflects `img` as the loss target -- confirm intended.\n",
    "    img.requires_grad_(True)\n",
    "    loss = criterion(out, img)\n",
    "    model.zero_grad()\n",
    "    loss.backward()\n",
    "    img = img + 0.1 * img.grad.sign()\n",
    "    latent_test_out = model.forward_encoder(img)\n",
    "\n",
    "latent_test_out_LSUN_resize = latent_test_out.detach().numpy()\n",
    "print(latent_test_out_LSUN_resize.shape)\n",
    "# Flatten each latent code to a vector of length 64*4*4.\n",
    "latent_test_out_LSUN_resize = latent_test_out_LSUN_resize.reshape(-1, 64*4*4)\n",
    "print(latent_test_out_LSUN_resize.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 11562,
     "status": "ok",
     "timestamp": 1693433831695,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "FFTCoWbltYmS",
    "outputId": "f8c5e0f4-a79a-4907-e3d2-3fab1ed280c3"
   },
   "outputs": [],
   "source": [
    "# Leaf indices assigned by `et` to each LSUN-resize latent vector.\n",
    "leaves_test_out_LSUN_resize = et.apply(latent_test_out_LSUN_resize)\n",
    "\n",
    "print(leaves_test_out_LSUN_resize.shape)\n",
    "print(leaves_test_out_LSUN_resize)\n",
    "\n",
    "# Pairwise normalized Hamming distance between the leaf-index rows of the first\n",
    "# 1000 samples, i.e. the fraction of positions at which two rows differ.\n",
    "# One vectorized comparison per row replaces the 1000 x 1000 Python-level calls\n",
    "# to `hamming`; the resulting matrix holds the same values.\n",
    "leaves_eval_LSUN_resize = leaves_test_out_LSUN_resize[:1000]\n",
    "distances_test_out_LSUN_resize = np.zeros((1000,1000))\n",
    "for i in range(1000):\n",
    "    distances_test_out_LSUN_resize[i] = (leaves_eval_LSUN_resize != leaves_eval_LSUN_resize[i]).mean(axis=1)\n",
    "\n",
    "# Builtin sum adds the rows together (column sums). Dividing by 999 averages over\n",
    "# the other samples, since the self-distance on the diagonal is zero.\n",
    "score_test_out_LSUN_resize = sum(distances_test_out_LSUN_resize)/999\n",
    "\n",
    "print(np.mean(score_test_out_LSUN_resize))\n",
    "print(np.cov(score_test_out_LSUN_resize))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "F3ZeGXM-tYyk"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GHsCOBn651NY"
   },
   "source": [
    "## STL10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 30,
     "status": "ok",
     "timestamp": 1693433831699,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "O5o8VSMQkHcH",
    "outputId": "bbf2a3cb-7031-447e-f6a8-19ed1cd2d5b8"
   },
   "outputs": [],
   "source": [
    "# Reload the saved checkpoint so the fine-tuning below starts from the stored weights.\n",
    "model_Conv_AE_OOD_Cifar100_Color_train_30 = 'classifier.pt'\n",
    "# f-string: without the `f` prefix the braces and variable name were passed to\n",
    "# torch.load literally instead of being replaced by the file name above.\n",
    "path = f\"/content/gdrive/My Drive/{model_Conv_AE_OOD_Cifar100_Color_train_30}\"\n",
    "model.load_state_dict(torch.load(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 93241,
     "status": "ok",
     "timestamp": 1693433924916,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "U3d1TPkvzqzX",
    "outputId": "be06cf3f-1a91-42dd-fede-1c42b54847c3"
   },
   "outputs": [],
   "source": [
    "# Fine-tune the autoencoder on the STL10 test split.\n",
    "transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize((32, 32)),  # bring every image to 32x32\n",
    "    torchvision.transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "train_dataset_STL10 = torchvision.datasets.STL10(\n",
    "    root='./data', split='test', download=True, transform=transform,\n",
    ")\n",
    "train_loader_STL10 = torch.utils.data.DataLoader(\n",
    "    train_dataset_STL10, batch_size=64, shuffle=True, num_workers=4,\n",
    ")\n",
    "\n",
    "# `model`, `criterion`, `optimizer` and `num_epoch` are not created here; they are\n",
    "# reused as they currently exist in the notebook session.\n",
    "train(model, train_loader_STL10, criterion, optimizer, num_epochs=num_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 10119,
     "status": "ok",
     "timestamp": 1693433935010,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "4kZ96mhKzz7h",
    "outputId": "9b9d105f-4043-4e59-90e1-a0af9519b2ed"
   },
   "outputs": [],
   "source": [
    "test_dataset_STL10 = torchvision.datasets.STL10(\n",
    "    root='./data', split='test', download=True, transform=transform,\n",
    ")\n",
    "test_loader_STL10 = torch.utils.data.DataLoader(\n",
    "    test_dataset_STL10, batch_size=10000, shuffle=True,\n",
    ")\n",
    "\n",
    "# Per batch: reconstruct, back-propagate the reconstruction loss, shift the images\n",
    "# by 0.1 * sign(input gradient), then encode the shifted images.\n",
    "# NOTE(review): `latent_test_out` is overwritten on every iteration, so only the\n",
    "# final batch reaches the lines after the loop.\n",
    "for img, _ in test_loader_STL10:\n",
    "    out = model.forward(img)\n",
    "    # NOTE(review): gradient tracking is switched on after the forward pass, so it\n",
    "    # appears `img.grad` only reflects `img` as the loss target -- confirm intended.\n",
    "    img.requires_grad_(True)\n",
    "    loss = criterion(out, img)\n",
    "    model.zero_grad()\n",
    "    loss.backward()\n",
    "    img = img + 0.1 * img.grad.sign()\n",
    "    latent_test_out = model.forward_encoder(img)\n",
    "\n",
    "# Flatten each latent code to a vector of length 64*4*4.\n",
    "latent_test_out_STL10 = latent_test_out.detach().numpy().reshape(-1, 64*4*4)\n",
    "print(latent_test_out_STL10.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 13036,
     "status": "ok",
     "timestamp": 1693433948021,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "FtaqPjfoz05e",
    "outputId": "43ce044b-92c2-4928-c373-22ba30dbeffe"
   },
   "outputs": [],
   "source": [
    "# Leaf indices assigned by `et` to each STL10 latent vector.\n",
    "leaves_test_out_STL10 = et.apply(latent_test_out_STL10)\n",
    "\n",
    "print(leaves_test_out_STL10.shape)\n",
    "print(leaves_test_out_STL10)\n",
    "\n",
    "# Pairwise normalized Hamming distance between the leaf-index rows of the first\n",
    "# 1000 samples, i.e. the fraction of positions at which two rows differ.\n",
    "# One vectorized comparison per row replaces the 1000 x 1000 Python-level calls\n",
    "# to `hamming`; the resulting matrix holds the same values.\n",
    "leaves_eval_STL10 = leaves_test_out_STL10[:1000]\n",
    "distances_test_out_STL10 = np.zeros((1000,1000))\n",
    "for i in range(1000):\n",
    "    distances_test_out_STL10[i] = (leaves_eval_STL10 != leaves_eval_STL10[i]).mean(axis=1)\n",
    "\n",
    "# Builtin sum adds the rows together (column sums). Dividing by 999 averages over\n",
    "# the other samples, since the self-distance on the diagonal is zero.\n",
    "score_test_out_STL10 = sum(distances_test_out_STL10)/999\n",
    "\n",
    "print(np.mean(score_test_out_STL10))\n",
    "print(np.cov(score_test_out_STL10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MyUNjfibZOAR"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N8mvccjO3KjF"
   },
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VSjc6SZP4lRv"
   },
   "outputs": [],
   "source": [
    "# Box plot of the per-sample score vectors, one box per dataset.\n",
    "my_dict = {\n",
    "    'Cifar100': score_test_in,\n",
    "    'SVHN': score_test_out_SVHN,\n",
    "    'DTD': score_test_out_DTD,\n",
    "    'Places365': score_test_out_Places365,\n",
    "    'iSUN': score_test_out_iSUN,\n",
    "    'LSUN': score_test_out_LSUN,\n",
    "    'LSUN-resize': score_test_out_LSUN_resize,\n",
    "    'STL10': score_test_out_STL10,\n",
    "}\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10,6))\n",
    "ax.boxplot(my_dict.values(), labels=my_dict.keys());\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1693433948365,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "BxcQ41vk4qfi",
    "outputId": "374bedc7-d548-4dd4-91d4-02001278f456"
   },
   "outputs": [],
   "source": [
    "# Labels: 1 for the first 1000 entries, 0 for the remaining 1000.\n",
    "score_true = np.concatenate([np.ones(1000), np.zeros(1000)])\n",
    "\n",
    "# Each prediction vector is score_test_in followed by one out-set's scores.\n",
    "score_pred_SVHN = np.concatenate([score_test_in, score_test_out_SVHN])\n",
    "score_pred_DTD = np.concatenate([score_test_in, score_test_out_DTD])\n",
    "score_pred_Places365 = np.concatenate([score_test_in, score_test_out_Places365])\n",
    "score_pred_iSUN = np.concatenate([score_test_in, score_test_out_iSUN])\n",
    "score_pred_LSUN = np.concatenate([score_test_in, score_test_out_LSUN])\n",
    "score_pred_LSUN_resize = np.concatenate([score_test_in, score_test_out_LSUN_resize])\n",
    "score_pred_STL10 = np.concatenate([score_test_in, score_test_out_STL10])\n",
    "\n",
    "# ROC-AUC per out-set, printed in the order SVHN, DTD, Places365, iSUN, LSUN,\n",
    "# LSUN-resize, STL10.\n",
    "for scores in (\n",
    "    score_pred_SVHN,\n",
    "    score_pred_DTD,\n",
    "    score_pred_Places365,\n",
    "    score_pred_iSUN,\n",
    "    score_pred_LSUN,\n",
    "    score_pred_LSUN_resize,\n",
    "    score_pred_STL10,\n",
    "):\n",
    "    print(roc_auc_score(score_true, scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1693433948365,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "g3S-yTu64qhs",
    "outputId": "efc1148b-c45a-4f77-edb3-3873c5ecec63"
   },
   "outputs": [],
   "source": [
    "# Precision-recall curve for each out-set against the shared labels.\n",
    "precision_SVHN, recall_SVHN, thresholds_SVHN = precision_recall_curve(score_true, score_pred_SVHN)\n",
    "precision_DTD, recall_DTD, thresholds_DTD = precision_recall_curve(score_true, score_pred_DTD)\n",
    "precision_Places365, recall_Places365, thresholds_Places365 = precision_recall_curve(score_true, score_pred_Places365)\n",
    "precision_iSUN, recall_iSUN, thresholds_iSUN = precision_recall_curve(score_true, score_pred_iSUN)\n",
    "precision_LSUN, recall_LSUN, thresholds_LSUN = precision_recall_curve(score_true, score_pred_LSUN)\n",
    "precision_LSUN_resize, recall_LSUN_resize, thresholds_LSUN_resize = precision_recall_curve(score_true, score_pred_LSUN_resize)\n",
    "precision_STL10, recall_STL10, thresholds_STL10 = precision_recall_curve(score_true, score_pred_STL10)\n",
    "\n",
    "# Area under each precision-recall curve (recall on x, precision on y).\n",
    "auc_precision_recall_SVHN = auc(recall_SVHN, precision_SVHN)\n",
    "auc_precision_recall_DTD = auc(recall_DTD, precision_DTD)\n",
    "auc_precision_recall_Places365 = auc(recall_Places365, precision_Places365)\n",
    "auc_precision_recall_iSUN = auc(recall_iSUN, precision_iSUN)\n",
    "auc_precision_recall_LSUN = auc(recall_LSUN, precision_LSUN)\n",
    "auc_precision_recall_LSUN_resize = auc(recall_LSUN_resize, precision_LSUN_resize)\n",
    "auc_precision_recall_STL10 = auc(recall_STL10, precision_STL10)\n",
    "\n",
    "# Printed in the order SVHN, DTD, Places365, iSUN, LSUN, LSUN-resize, STL10.\n",
    "for pr_auc in (\n",
    "    auc_precision_recall_SVHN,\n",
    "    auc_precision_recall_DTD,\n",
    "    auc_precision_recall_Places365,\n",
    "    auc_precision_recall_iSUN,\n",
    "    auc_precision_recall_LSUN,\n",
    "    auc_precision_recall_LSUN_resize,\n",
    "    auc_precision_recall_STL10,\n",
    "):\n",
    "    print(pr_auc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 11,
     "status": "ok",
     "timestamp": 1693433948365,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "AdNM5Cz422xB",
    "outputId": "609a9d60-5ef3-41df-b3e6-cf60515e9994"
   },
   "outputs": [],
   "source": [
    "def compute_fpr95(y_true, y_pred_probs):\n",
    "    \"\"\"Return the false-positive rate at the first ROC point with TPR >= 0.95.\n",
    "\n",
    "    Args:\n",
    "        y_true: binary ground-truth labels, with 1 marking the positive class.\n",
    "        y_pred_probs: scores for which larger values indicate the positive class.\n",
    "\n",
    "    Returns:\n",
    "        The FPR of the lowest-TPR operating point that still reaches 95% TPR.\n",
    "    \"\"\"\n",
    "    fpr, tpr, _ = sklearn.metrics.roc_curve(y_true, y_pred_probs)\n",
    "    # tpr is non-decreasing, so argmax of the boolean mask is the first point that\n",
    "    # reaches the target. Choosing the point *nearest* to 0.95 could land on a\n",
    "    # TPR below 95% and under-report the FPR.\n",
    "    idx = np.argmax(tpr >= 0.95)\n",
    "    return fpr[idx]\n",
    "\n",
    "# FPR at 95% TPR for each out-set, using `score_true` and the `score_pred_*`\n",
    "# vectors built in an earlier cell.\n",
    "fpr95_score_SVHN = compute_fpr95(score_true, score_pred_SVHN)\n",
    "fpr95_score_DTD = compute_fpr95(score_true, score_pred_DTD)\n",
    "fpr95_score_Places365 = compute_fpr95(score_true, score_pred_Places365)\n",
    "fpr95_score_iSUN = compute_fpr95(score_true, score_pred_iSUN)\n",
    "fpr95_score_LSUN = compute_fpr95(score_true, score_pred_LSUN)\n",
    "fpr95_score_LSUN_resize = compute_fpr95(score_true, score_pred_LSUN_resize)\n",
    "fpr95_score_STL10 = compute_fpr95(score_true, score_pred_STL10)\n",
    "\n",
    "print(fpr95_score_SVHN)\n",
    "print(fpr95_score_DTD)\n",
    "print(fpr95_score_Places365)\n",
    "print(fpr95_score_iSUN)\n",
    "print(fpr95_score_LSUN)\n",
    "print(fpr95_score_LSUN_resize)\n",
    "print(fpr95_score_STL10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zw5MbvNa3wnX"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 10,
     "status": "ok",
     "timestamp": 1693433948366,
     "user": {
      "displayName": "Zhaiming Shen",
      "userId": "12760861740580065439"
     },
     "user_tz": 240
    },
    "id": "aJNC0_LgxAxp",
    "outputId": "c0580117-2f4c-40dd-c16f-1a54f498040c"
   },
   "outputs": [],
   "source": [
    "# Mean and np.cov of each score vector, printed one dataset per line in the order\n",
    "# Cifar100 (score_test_in), SVHN, DTD, Places365, iSUN, LSUN, LSUN-resize, STL10.\n",
    "for scores in (\n",
    "    score_test_in,\n",
    "    score_test_out_SVHN,\n",
    "    score_test_out_DTD,\n",
    "    score_test_out_Places365,\n",
    "    score_test_out_iSUN,\n",
    "    score_test_out_LSUN,\n",
    "    score_test_out_LSUN_resize,\n",
    "    score_test_out_STL10,\n",
    "):\n",
    "    print(np.mean(scores), np.cov(scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2YxcZ_xnEHPK"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyOkh1OHamLv1kK6d0hAGigV",
   "machine_shape": "hm",
   "provenance": [
    {
     "file_id": "1tzM9ZtuO1WWf43J4EFd3QOi5PwN6Wocq",
     "timestamp": 1692716903705
    },
    {
     "file_id": "1D8EBZnGmbgBZITfM1GANy_77Ve0WxLY6",
     "timestamp": 1690814383283
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
