{
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "STOLEN = False\n",
        "VICTIM= True\n",
        "OLD = False\n",
        "CKPT = \"\"\n",
        "USING_SUBSET= False\n",
        "SHUFFLE=True\n",
        "INDEX_PATH = \"\"\n",
        "PROPORTION=5000\n",
        "STANDARD=True\n",
        "NORMALIZE=True\n",
        "GMM_COMPONENTS=10\n",
        "DATASET = \"CIFAR10\"\n",
        "BATCH_SIZE = 100"
      ],
      "metadata": {
        "id": "rXBYZs05fBEk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YUemQib7ZE4D"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import sys\n",
        "import numpy as np\n",
        "import os\n",
        "import yaml\n",
        "import matplotlib.pyplot as plt\n",
        "import torchvision\n",
        "import torch.nn as nn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3_nypQVEv-hn"
      },
      "outputs": [],
      "source": [
        "from torch.utils.data import DataLoader\n",
        "import torchvision.transforms as transforms\n",
        "from torchvision import datasets"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch.distributed as dist\n",
        "import torch.backends.cudnn as cudnn\n",
        "from torchvision import datasets\n",
        "from torchvision import transforms as pth_transforms\n",
        "from torchvision import models as torchvision_models\n",
        "\n",
        "import utils as utils\n",
        "import vision_transformer as vits\n"
      ],
      "metadata": {
        "id": "9CrASQCGMtVM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lDfbL3w_Z0Od",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "7956cc03-56e9-4ae9-dc85-ce3d82ecb803"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using device: cuda\n"
          ]
        }
      ],
      "source": [
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "print(\"Using device:\", device)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from PIL import Image"
      ],
      "metadata": {
        "id": "_ZuGbLCi8RVQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import pickle"
      ],
      "metadata": {
        "id": "K04gaoSRb589"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Preprocessing applied to every image: resize to 224x224, then convert to a tensor.\n",
        "# The random horizontal flip and the channel-wise\n",
        "# Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) steps are currently disabled.\n",
        "transform_test = pth_transforms.Compose(\n",
        "    [pth_transforms.Resize((224, 224)), pth_transforms.ToTensor()]\n",
        ")"
      ],
      "metadata": {
        "id": "6qXOrjvGgVaT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "if DATASET == \"MNIST\":\n",
        "  def to_3d(x):\n",
        "    \"\"\"Tile the tensor three times along its first dimension.\n",
        "\n",
        "    Assumes x is a (1, H, W) tensor so the result has 3 channels - TODO confirm.\n",
        "    \"\"\"\n",
        "    return x.repeat(3, 1, 1)\n",
        "\n",
        "  # MNIST overrides the default pipeline: resize, tensor conversion, normalisation\n",
        "  # with mean 0.1307 / std 0.3081, and finally the channel tiling defined above.\n",
        "  transform_test = pth_transforms.Compose([\n",
        "      pth_transforms.Resize(224),\n",
        "      pth_transforms.ToTensor(),\n",
        "      pth_transforms.Normalize(mean=(0.1307,), std=(0.3081,)),\n",
        "      to_3d,\n",
        "  ])"
      ],
      "metadata": {
        "id": "lFjnE26UKZoJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Notebook-level copy of the BATCH_SIZE setting from the configuration cell.\n",
        "batch_size=BATCH_SIZE"
      ],
      "metadata": {
        "id": "aQIEdUtkkWmo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BfIPl0G6_RrT"
      },
      "outputs": [],
      "source": [
        "def _load_split(train, download):\n",
        "  \"\"\"Return the training (train=True) or test split of DATASET.\n",
        "\n",
        "  Every sample is passed through the notebook-level `transform_test`.\n",
        "  Raises ValueError when DATASET names none of the handled datasets.\n",
        "  \"\"\"\n",
        "  if DATASET in (\"GTSRB\", \"SVHN\"):\n",
        "    # These torchvision datasets pick the split by name.\n",
        "    dataset_cls = datasets.GTSRB if DATASET == \"GTSRB\" else datasets.SVHN\n",
        "    return dataset_cls('./data', split=\"train\" if train else \"test\",\n",
        "                       download=download, transform=transform_test)\n",
        "  if DATASET in (\"MNIST\", \"CIFAR10\"):\n",
        "    # These torchvision datasets pick the split with a boolean flag.\n",
        "    dataset_cls = datasets.MNIST if DATASET == \"MNIST\" else datasets.CIFAR10\n",
        "    return dataset_cls('./data', train=train, download=download,\n",
        "                       transform=transform_test)\n",
        "  raise ValueError(\n",
        "      \"Unsupported DATASET: %r (expected GTSRB, SVHN, MNIST or CIFAR10)\" % (DATASET,))\n",
        "\n",
        "\n",
        "def get_data_loaders(download, shuffle=False, batch_size=100):\n",
        "  \"\"\"Build two training loaders (a random 50/50 split) and one test loader.\n",
        "\n",
        "  The dataset is selected by the notebook-level DATASET constant.\n",
        "\n",
        "  * USING_SUBSET False: the dataset's own train and test splits are used.\n",
        "  * USING_SUBSET True: the training split is divided into a private part (used\n",
        "    as training data) and the remaining samples (used as test data). With\n",
        "    SHUFFLE the private indices are unpickled from INDEX_PATH; otherwise the\n",
        "    first PROPORTION samples are private.\n",
        "\n",
        "  Args:\n",
        "    download: forwarded to the torchvision dataset constructors.\n",
        "    shuffle: kept for backward compatibility and currently ignored; all three\n",
        "      loaders are always created with shuffle=True.\n",
        "    batch_size: batch size of every returned loader.\n",
        "\n",
        "  Returns:\n",
        "    The tuple (train_loader_1, train_loader_2, test_loader).\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if DATASET is not one of the supported names.\n",
        "  \"\"\"\n",
        "  if USING_SUBSET:\n",
        "    train_dataset_orig = _load_split(True, download)\n",
        "    if SHUFFLE:\n",
        "      # NOTE(security): pickle.load can execute arbitrary code, so INDEX_PATH must\n",
        "      # point to a trusted file. The context manager closes the file afterwards.\n",
        "      with open(INDEX_PATH, 'rb') as index_file:\n",
        "        private_labels = pickle.load(index_file)\n",
        "      # Everything that is not a private index becomes test data.\n",
        "      test_labels = np.setdiff1d(np.arange(len(train_dataset_orig)), private_labels)\n",
        "    else:\n",
        "      private_labels = range(PROPORTION)\n",
        "      test_labels = range(PROPORTION, len(train_dataset_orig))\n",
        "    train_dataset = torch.utils.data.Subset(train_dataset_orig, private_labels)\n",
        "    test_dataset = torch.utils.data.Subset(train_dataset_orig, test_labels)\n",
        "  else:\n",
        "    train_dataset = _load_split(True, download)\n",
        "    test_dataset = _load_split(False, download)\n",
        "\n",
        "  # Random split of the training data into two halves; the second half takes the\n",
        "  # extra sample when the length is odd.\n",
        "  first_half = int(0.5 * len(train_dataset))\n",
        "  train_dataset_1, train_dataset_2 = torch.utils.data.random_split(\n",
        "      train_dataset, [first_half, len(train_dataset) - first_half])\n",
        "\n",
        "  def make_loader(dataset):\n",
        "    # All loaders share the same settings.\n",
        "    return DataLoader(dataset, batch_size=batch_size,\n",
        "                      num_workers=2, drop_last=False, shuffle=True)\n",
        "\n",
        "  train_loader_1 = make_loader(train_dataset_1)\n",
        "  train_loader_2 = make_loader(train_dataset_2)\n",
        "  test_loader = make_loader(test_dataset)\n",
        "\n",
        "  print(\"Finish loading datasets\")\n",
        "  return train_loader_1, train_loader_2, test_loader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GHiv3JLn6Ivm"
      },
      "outputs": [],
      "source": [
        "import torch.nn as nn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N_t-LpN36xUQ"
      },
      "outputs": [],
      "source": [
        "import torch.nn.functional as F"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jc14ZYETjSOq"
      },
      "outputs": [],
      "source": [
        "import torchvision.models as models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eUxBEYLludr9",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b6a5f458-42eb-4ef5-d791-3360d4129f5b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Found checkpoint at /content/drive/MyDrive/DINO_vit/victim/checkpoint_vit_tiny_cifar10_overfit.pth\n",
            "Model is loaded from checkpoint!\n"
          ]
        }
      ],
      "source": [
        "\n",
        "# Build a ViT-tiny encoder (patch_size=16, num_classes=0) and restore its weights from\n",
        "# CKPT. Every enabled flag runs its own branch; if several are set, the last branch\n",
        "# executed decides which `model` the following cells see.\n",
        "if STOLEN:\n",
        "  model = vits.__dict__['vit_tiny'](patch_size=16, num_classes=0)\n",
        "  # map_location='cpu' keeps a GPU-saved checkpoint loadable on a CPU-only machine;\n",
        "  # the model is moved to `device` after loading.\n",
        "  state_dict = torch.load(CKPT, map_location='cpu')[\"state_dict\"]\n",
        "  for k in list(state_dict.keys()):\n",
        "    if k.startswith('module.'):\n",
        "      # Re-register the tensor under its name without the 'module.' prefix ...\n",
        "      state_dict[k[len(\"module.\"):]] = state_dict[k]\n",
        "      # ... and drop the prefixed key so only the renamed entry remains.\n",
        "      del state_dict[k]\n",
        "  model.load_state_dict(state_dict)\n",
        "if VICTIM:\n",
        "  model = vits.__dict__[\"vit_tiny\"](patch_size=16, num_classes=0)\n",
        "  # Wrap the backbone together with a DINO head before handing it to the project\n",
        "  # helper; the helper's return value replaces `model`.\n",
        "  whole_model = utils.MultiCropWrapper(\n",
        "    model,\n",
        "    vits.DINOHead(model.embed_dim, 65536, False, nlayers=3),\n",
        "  )\n",
        "  model = utils.load_pretrained_whole_model(whole_model, ckp_path=CKPT)\n",
        "if OLD:\n",
        "  model = vits.__dict__['vit_tiny'](patch_size=16, num_classes=0)\n",
        "  # Here the checkpoint file is the state dict itself (no wrapping key).\n",
        "  state_dict = torch.load(CKPT, map_location='cpu')\n",
        "  model.load_state_dict(state_dict)\n",
        "if not (STOLEN or VICTIM or OLD):\n",
        "  # Without this check the notebook would only fail later with a NameError on `model`.\n",
        "  raise ValueError(\"Set one of STOLEN, VICTIM or OLD to True to choose a model to load\")\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Move the loaded model onto the device selected earlier ('cuda' or 'cpu').\n",
        "model = model.to(device)"
      ],
      "metadata": {
        "id": "HpJgtf3kYp_w"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_GC0a14uWRr6",
        "outputId": "31e343c2-166c-4152-e0c1-f142573a18b2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Files already downloaded and verified\n",
            "Files already downloaded and verified\n",
            "Finish loading datasets\n"
          ]
        }
      ],
      "source": [
        "# Only `download` is passed, so get_data_loaders uses its default batch_size of 100.\n",
        "train_loader_1, train_loader_2, test_loader = get_data_loaders(download=True)\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check: number of samples in the first training half.\n",
        "len(train_loader_1.dataset)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bG-NgaWIRgKJ",
        "outputId": "bbd7720e-a1fc-4451-8eca-013e1d44264d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "25000"
            ]
          },
          "metadata": {},
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check: number of samples in the second training half.\n",
        "len(train_loader_2.dataset)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TTjACI1ARlln",
        "outputId": "9a59ef6f-5b47-474b-bbbd-177daf40da6b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "25000"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check: number of samples in the test set.\n",
        "len(test_loader.dataset)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VoGVkzZHRlsA",
        "outputId": "a2d6a1e9-8114-49e5-f29f-574ca5e48928"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "10000"
            ]
          },
          "metadata": {},
          "execution_count": 23
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pYT_KsM0Mnnr"
      },
      "outputs": [],
      "source": [
        "# Freeze every parameter of the loaded model; nothing is left trainable.\n",
        "for name, param in model.named_parameters():\n",
        "    param.requires_grad = False\n",
        "\n",
        "# Collect the parameters that still require gradients (none after the loop above).\n",
        "parameters = [p for p in model.parameters() if p.requires_grad]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7xwENCuzq14X",
        "outputId": "9e1c7aee-e087-4f36-ddc6-9d1380e53ea8"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "MultiCropWrapper(\n",
              "  (backbone): VisionTransformer(\n",
              "    (patch_embed): PatchEmbed(\n",
              "      (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))\n",
              "    )\n",
              "    (pos_drop): Dropout(p=0.0, inplace=False)\n",
              "    (blocks): ModuleList(\n",
              "      (0): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (1): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (2): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (3): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (4): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (5): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (6): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (7): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (8): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (9): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (10): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "      (11): Block(\n",
              "        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (attn): Attention(\n",
              "          (qkv): Linear(in_features=192, out_features=576, bias=True)\n",
              "          (attn_drop): Dropout(p=0.0, inplace=False)\n",
              "          (proj): Linear(in_features=192, out_features=192, bias=True)\n",
              "          (proj_drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "        (drop_path): Identity()\n",
              "        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "        (mlp): Mlp(\n",
              "          (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
              "          (act): GELU(approximate=none)\n",
              "          (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
              "          (drop): Dropout(p=0.0, inplace=False)\n",
              "        )\n",
              "      )\n",
              "    )\n",
              "    (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
              "    (head): Identity()\n",
              "    (fc): Identity()\n",
              "  )\n",
              "  (head): DINOHead(\n",
              "    (mlp): Sequential(\n",
              "      (0): Linear(in_features=192, out_features=2048, bias=True)\n",
              "      (1): GELU(approximate=none)\n",
              "      (2): Linear(in_features=2048, out_features=2048, bias=True)\n",
              "      (3): GELU(approximate=none)\n",
              "      (4): Linear(in_features=2048, out_features=256, bias=True)\n",
              "    )\n",
              "    (last_layer): Linear(in_features=256, out_features=65536, bias=False)\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 25
        }
      ],
      "source": [
        "# Put the model in evaluation mode; as the cell's last expression it also displays the module tree.\n",
        "model.eval()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Second argument to model.get_intermediate_layers; the buffers below are 192 * n wide.\n",
        "# Presumably the number of final transformer blocks to read - confirm in vision_transformer.\n",
        "n = 1"
      ],
      "metadata": {
        "id": "hui9finFPA7v"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jkN2J8t_tymu"
      },
      "outputs": [],
      "source": [
        "# One zero-initialised row per sample and 192 * n feature columns for each of the\n",
        "# three loaders; the extraction loops fill these buffers batch by batch.\n",
        "training_representations_1, training_representations_2, test_representations = (\n",
        "    torch.zeros(len(loader.dataset), 192 * n)\n",
        "    for loader in (train_loader_1, train_loader_2, test_loader)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SKMlQm1TSrje"
      },
      "outputs": [],
      "source": [
        "# Make sure the model is on `device` before the extraction loops run.\n",
        "model = model.to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KZ321peJuUJX"
      },
      "outputs": [],
      "source": [
        "# Encode the first training half and store one feature row per sample.\n",
        "# no_grad: only forward activations are needed, so no autograd graph is kept.\n",
        "with torch.no_grad():\n",
        "  for i, (x_batch, _) in enumerate(train_loader_1):\n",
        "    x_batch = x_batch.to(device)\n",
        "    intermediate_output = model.get_intermediate_layers(x_batch, n)\n",
        "    # Keep index 0 along dim 1 of every returned tensor (presumably the [CLS] token)\n",
        "    # and concatenate the results along the feature dimension.\n",
        "    r = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)\n",
        "    # Derive the write offset from the loader's batch size rather than a literal 100,\n",
        "    # so the buffer stays correct if the batch size changes.\n",
        "    start = i * train_loader_1.batch_size\n",
        "    training_representations_1[start: start + len(r)] = r"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xfNTT5a64xdO"
      },
      "outputs": [],
      "source": [
        "# Encode the second training half and store one feature row per sample.\n",
        "# no_grad: only forward activations are needed, so no autograd graph is kept.\n",
        "with torch.no_grad():\n",
        "  for i, (x_batch, _) in enumerate(train_loader_2):\n",
        "    x_batch = x_batch.to(device)\n",
        "    intermediate_output = model.get_intermediate_layers(x_batch, n)\n",
        "    # Keep index 0 along dim 1 of every returned tensor (presumably the [CLS] token)\n",
        "    # and concatenate the results along the feature dimension.\n",
        "    r = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)\n",
        "    # Derive the write offset from the loader's batch size rather than a literal 100,\n",
        "    # so the buffer stays correct if the batch size changes.\n",
        "    start = i * train_loader_2.batch_size\n",
        "    training_representations_2[start: start + len(r)] = r"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dZH0c4wDOixR"
      },
      "outputs": [],
      "source": [
        "# Reorder the rows with a fresh random permutation (no seed is set in the visible cells,\n",
        "# so the order differs between runs).\n",
        "training_representations_1 = training_representations_1[torch.randperm(len(train_loader_1.dataset))]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YcuvnBUb42sP"
      },
      "outputs": [],
      "source": [
        "# Reorder the rows with a fresh random permutation (no seed is set in the visible cells,\n",
        "# so the order differs between runs).\n",
        "training_representations_2 = training_representations_2[torch.randperm(len(train_loader_2.dataset))]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "haree8WDwKxv"
      },
      "outputs": [],
      "source": [
        "# Encode the test set and store one feature row per sample.\n",
        "# no_grad: only forward activations are needed, so no autograd graph is kept.\n",
        "with torch.no_grad():\n",
        "  for i, (x_batch, _) in enumerate(test_loader):\n",
        "    x_batch = x_batch.to(device)\n",
        "    intermediate_output = model.get_intermediate_layers(x_batch, n)\n",
        "    # Keep index 0 along dim 1 of every returned tensor (presumably the [CLS] token)\n",
        "    # and concatenate the results along the feature dimension.\n",
        "    r = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)\n",
        "    # Derive the write offset from the loader's batch size rather than a literal 100,\n",
        "    # so the buffer stays correct if the batch size changes.\n",
        "    start = i * test_loader.batch_size\n",
        "    test_representations[start: start + len(r)] = r"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N5i8mJxVwK7G"
      },
      "outputs": [],
      "source": [
        "test_representations = test_representations[torch.randperm(len(test_loader.dataset))]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QEI81SvP9Ta7"
      },
      "outputs": [],
      "source": [
        "if STANDARD:\n",
        "  training_representations_1 = (training_representations_1 - torch.mean(training_representations_1)) / torch.std(training_representations_1)\n",
        "  training_representations_2 = (training_representations_2 - torch.mean(training_representations_2)) / torch.std(training_representations_2)\n",
        "  test_representations = (test_representations - torch.mean(test_representations)) / torch.std(test_representations)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BtaEuRVSHPI_"
      },
      "outputs": [],
      "source": [
        "if NORMALIZE:\n",
        "  training_representations_2 = F.normalize(training_representations_2)\n",
        "  training_representations_1 = F.normalize(training_representations_1)\n",
        "  test_representations = F.normalize(test_representations)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D1yWLenWQt8I"
      },
      "outputs": [],
      "source": [
        "training_representations_1 = training_representations_1.cpu().detach().numpy()\n",
        "training_representations_2 = training_representations_2.cpu().detach().numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2jfbIoxJQt_T"
      },
      "outputs": [],
      "source": [
        "test_representations = test_representations.cpu().detach().numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wSJwcO6mQuhO"
      },
      "outputs": [],
      "source": [
        "from sklearn.mixture import GaussianMixture"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MJg3p17VRByd"
      },
      "outputs": [],
      "source": [
        "gm = GaussianMixture(n_components= GMM_COMPONENTS, max_iter=1000, covariance_type=\"diag\")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import time"
      ],
      "metadata": {
        "id": "ul_ZFDCw8VQC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NI3kkQEo9364"
      },
      "outputs": [],
      "source": [
        "#start = time.time()\n",
        "gm.fit(training_representations_1)\n",
        "#print(time.time()-start)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9cvokPlyRCaI"
      },
      "outputs": [],
      "source": [
        "training_likelihood_2 = gm.score_samples(training_representations_2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ULskXgGLUHTY"
      },
      "outputs": [],
      "source": [
        "test_likelihood = gm.score_samples(test_representations)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RGsoW4S2NufS"
      },
      "outputs": [],
      "source": [
        "training_likelihood = gm.score_samples(training_representations_1)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "np.sort(training_likelihood)"
      ],
      "metadata": {
        "id": "YWZ0JJFYdCiK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "np.sort(training_likelihood_2)"
      ],
      "metadata": {
        "id": "WfsTCk_voACX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "np.sort(test_likelihood)"
      ],
      "metadata": {
        "id": "FKKxQpK4n92c"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "len(np.unique(test_likelihood))"
      ],
      "metadata": {
        "id": "Z3rA5Q6CPo48"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6R06i08n_HM1"
      },
      "outputs": [],
      "source": [
        "np.mean(training_likelihood_2) - np.mean(test_likelihood)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from scipy import stats "
      ],
      "metadata": {
        "id": "mYY12EhLfpBS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6SE6BMQqrP7v"
      },
      "outputs": [],
      "source": [
        "stats.ttest_ind(training_likelihood_2, test_likelihood, alternative = \"greater\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "display_name": "pytorch",
      "language": "python",
      "name": "pytorch"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.6"
    },
    "gpuClass": "standard"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}