{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "248e0486",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.9899\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.9895\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.9900\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.9898\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.9897\n",
      "Final results for mnist:\n",
      "Mean accuracy: 0.9898, Std: 0.0002\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8840\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8848\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8840\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8852\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8832\n",
      "Final results for fmnist:\n",
      "Mean accuracy: 0.8842, Std: 0.0007\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.6212\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.6190\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.6205\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.6192\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.6198\n",
      "Final results for cifar10:\n",
      "Mean accuracy: 0.6199, Std: 0.0008\n"
     ]
    }
   ],
   "source": [
    "# Linear probe of the non-wm MAE checkpoints on the regular torchvision datasets.\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from models import MAEModel\n",
    "from utils.others import set_global_seed, log_mae_recon\n",
    "from utils.lin_probe import run_linear_probe\n",
    "from utils.shuffle import patchify, unpatchify\n",
    "from utils.losses import mae_loss\n",
    "from loaders import PatchedCIFARLoader\n",
    "from loaders1 import ColoredMNISTLoader, ColoredFMNISTLoader\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import argparse\n",
    "import yaml\n",
    "import os\n",
    "# Probe configuration. None of it depends on the task or the seed, so it is\n",
    "# defined once here rather than once per loop iteration.\n",
    "IMG_SIZE = 32\n",
    "PATCH_SIZE = 4\n",
    "BATCH_SIZE = 512\n",
    "NUM_CLASSES = 10\n",
    "MASK_RATIO = 0.75\n",
    "NUM_WORKERS = 6\n",
    "ENC_DIM = 192\n",
    "# Not read by the probe in this cell.\n",
    "LR = 1e-3\n",
    "WEIGHT_DECAY = 0.05\n",
    "EPOCHS = 100\n",
    "# Settings forwarded to run_linear_probe.\n",
    "PROBE_EPOCHS = 5\n",
    "PROBE_LR = 1e-3\n",
    "# Seed tag that is part of the checkpoint file names. Every probe seed below\n",
    "# loads this same checkpoint; only the probe run itself is re-seeded.\n",
    "CKPT_SEED = 37\n",
    "SEEDS = [0, 37, 59, 289, 334]\n",
    "\n",
    "# task name -> (torchvision dataset class, input channels, Normalize mean, Normalize std)\n",
    "DATASET_SPECS = {\n",
    "    'mnist': (datasets.MNIST, 1, [0.1307], [0.3081]),\n",
    "    'fmnist': (datasets.FashionMNIST, 1, [0.2860], [0.3530]),\n",
    "    'cifar10': (datasets.CIFAR10, 3, [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),\n",
    "}\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "tasks = ['mnist', 'fmnist', 'cifar10']\n",
    "for task in tasks:\n",
    "    task_key = task.lower()\n",
    "    if task_key not in DATASET_SPECS:\n",
    "        # Fail loudly instead of silently reusing the previous task's dataset settings.\n",
    "        raise NotImplementedError(f'Dataset {task} not supported')\n",
    "    Dts, PATCH_CHANNEL, mean, std = DATASET_SPECS[task_key]\n",
    "\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize(IMG_SIZE),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=mean, std=std),\n",
    "    ])\n",
    "    # The datasets do not depend on the probe seed, so they are built once per task.\n",
    "    tr_ds = Dts(root='../data', train=True, download=True, transform=transform)\n",
    "    te_ds = Dts(root='../data', train=False, download=True, transform=transform)\n",
    "\n",
    "    enc_ckpt = f'./ckpt/mae_{task_key}_encoder_{ENC_DIM}_{CKPT_SEED}_last.pth'\n",
    "    dec_ckpt = f'./ckpt/mae_{task_key}_decoder_{ENC_DIM}_{CKPT_SEED}_last.pth'\n",
    "\n",
    "    scores = []\n",
    "    for seed in SEEDS:\n",
    "        # Seed before the loaders and the model are created, as in the original run order.\n",
    "        set_global_seed(seed)\n",
    "\n",
    "        tr_loader = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)\n",
    "        te_loader = DataLoader(te_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)\n",
    "\n",
    "        model = MAEModel(img_size=IMG_SIZE,\n",
    "                         patch_size=PATCH_SIZE,\n",
    "                         patch_ch=PATCH_CHANNEL,\n",
    "                         enc_dim=ENC_DIM,\n",
    "                         mask_ratio=MASK_RATIO)\n",
    "        # map_location lets a checkpoint saved on a GPU load when `device` is 'cpu'.\n",
    "        model.encoder.load_state_dict(torch.load(enc_ckpt, map_location=device))\n",
    "        model.decoder.load_state_dict(torch.load(dec_ckpt, map_location=device))\n",
    "        model.to(device)\n",
    "        model.eval()\n",
    "\n",
    "        # Only the encoder is handed to the probe.\n",
    "        lp_acc = run_linear_probe(\n",
    "            encoder=model.encoder,\n",
    "            train_loader=tr_loader,\n",
    "            val_loader=te_loader,\n",
    "            num_classes=NUM_CLASSES,\n",
    "            device=device,\n",
    "            probe_epochs=PROBE_EPOCHS,\n",
    "            lr=PROBE_LR,\n",
    "        )\n",
    "        print(f'Linear probe accuracy: {lp_acc:.4f}')\n",
    "        scores.append(lp_acc)\n",
    "\n",
    "    # np.std is the population standard deviation (ddof=0) over the probe seeds.\n",
    "    print(f'Final results for {task}:')\n",
    "    print(f'Mean accuracy: {np.mean(scores):.4f}, Std: {np.std(scores):.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d205456b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8626\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8642\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8650\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8666\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8612\n",
      "Final results for mnist_wm:\n",
      "Mean accuracy: 0.8639, Std: 0.0019\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.7302\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.7292\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.7252\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.7299\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.7309\n",
      "Final results for fmnist_wm:\n",
      "Mean accuracy: 0.7291, Std: 0.0020\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.4413\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.4359\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.4294\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.4352\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.4395\n",
      "Final results for cifar10_wm:\n",
      "Mean accuracy: 0.4363, Std: 0.0041\n"
     ]
    }
   ],
   "source": [
    "# Linear probe of the `_wm` MAE checkpoints, trained and evaluated on clean torchvision data.\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from models import MAEModel\n",
    "from utils.others import set_global_seed, log_mae_recon\n",
    "from utils.lin_probe import run_linear_probe\n",
    "from utils.shuffle import patchify, unpatchify\n",
    "from utils.losses import mae_loss\n",
    "from loaders_wm import ColoredMNISTLoader, ColoredFMNISTLoader, PatchedCIFARLoader\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import argparse\n",
    "import yaml\n",
    "import os\n",
    "# Probe configuration. None of it depends on the task or the seed, so it is\n",
    "# defined once here rather than once per loop iteration.\n",
    "IMG_SIZE = 32\n",
    "PATCH_SIZE = 4\n",
    "# The model is built with three input channels for every task in this cell.\n",
    "PATCH_CHANNEL = 3\n",
    "BATCH_SIZE = 512\n",
    "NUM_CLASSES = 10\n",
    "MASK_RATIO = 0.75\n",
    "NUM_WORKERS = 6\n",
    "ENC_DIM = 192\n",
    "# Not read by the probe in this cell.\n",
    "LR = 1e-3\n",
    "WEIGHT_DECAY = 0.05\n",
    "EPOCHS = 100\n",
    "# Settings forwarded to run_linear_probe.\n",
    "PROBE_EPOCHS = 5\n",
    "PROBE_LR = 1e-3\n",
    "# Seed tag that is part of the checkpoint file names. Every probe seed below\n",
    "# loads this same checkpoint; only the probe run itself is re-seeded.\n",
    "CKPT_SEED = 37\n",
    "SEEDS = [0, 37, 59, 289, 334]\n",
    "\n",
    "\n",
    "class GrayToRGB:\n",
    "    \"\"\"Transform that repeats a single-channel image tensor along the channel axis.\"\"\"\n",
    "\n",
    "    def __call__(self, x):\n",
    "        # x: (1, H, W) -> (3, H, W)\n",
    "        return x.repeat(3, 1, 1)\n",
    "\n",
    "\n",
    "# task name -> (torchvision dataset class, is grayscale, Normalize mean, Normalize std)\n",
    "CLEAN_DATASET_SPECS = {\n",
    "    'mnist': (datasets.MNIST, True, [0.1307], [0.3081]),\n",
    "    'fmnist': (datasets.FashionMNIST, True, [0.2860], [0.3530]),\n",
    "    'cifar10': (datasets.CIFAR10, False, [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),\n",
    "}\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "tasks = ['mnist', 'fmnist', 'cifar10']\n",
    "for task in tasks:\n",
    "    task_key = task.lower()\n",
    "    if task_key not in CLEAN_DATASET_SPECS:\n",
    "        raise NotImplementedError(f'Dataset {task} not supported')\n",
    "    Dts, is_gray, mean, std = CLEAN_DATASET_SPECS[task_key]\n",
    "\n",
    "    steps = [transforms.Resize(IMG_SIZE), transforms.ToTensor()]\n",
    "    if is_gray:\n",
    "        # Grayscale images are replicated to three channels to match PATCH_CHANNEL.\n",
    "        steps.append(GrayToRGB())\n",
    "    # For the grayscale tasks the one-element mean/std is applied to all three channels.\n",
    "    steps.append(transforms.Normalize(mean=mean, std=std))\n",
    "    transform = transforms.Compose(steps)\n",
    "\n",
    "    # The datasets do not depend on the probe seed, so they are built once per task.\n",
    "    tr_ds = Dts(root='../data', train=True, download=True, transform=transform)\n",
    "    te_ds = Dts(root='../data', train=False, download=True, transform=transform)\n",
    "\n",
    "    enc_ckpt = f'./ckpt/mae_{task_key}_wm_encoder_{ENC_DIM}_{CKPT_SEED}_last.pth'\n",
    "    dec_ckpt = f'./ckpt/mae_{task_key}_wm_decoder_{ENC_DIM}_{CKPT_SEED}_last.pth'\n",
    "\n",
    "    scores = []\n",
    "    for seed in SEEDS:\n",
    "        # Seed before the loaders and the model are created, as in the original run order.\n",
    "        set_global_seed(seed)\n",
    "\n",
    "        tr_loader = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)\n",
    "        te_loader = DataLoader(te_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)\n",
    "\n",
    "        model = MAEModel(img_size=IMG_SIZE,\n",
    "                         patch_size=PATCH_SIZE,\n",
    "                         patch_ch=PATCH_CHANNEL,\n",
    "                         enc_dim=ENC_DIM,\n",
    "                         mask_ratio=MASK_RATIO)\n",
    "        # map_location lets a checkpoint saved on a GPU load when `device` is 'cpu'.\n",
    "        model.encoder.load_state_dict(torch.load(enc_ckpt, map_location=device))\n",
    "        model.decoder.load_state_dict(torch.load(dec_ckpt, map_location=device))\n",
    "        model.to(device)\n",
    "        model.eval()\n",
    "\n",
    "        # Only the encoder is handed to the probe.\n",
    "        lp_acc = run_linear_probe(\n",
    "            encoder=model.encoder,\n",
    "            train_loader=tr_loader,\n",
    "            val_loader=te_loader,\n",
    "            num_classes=NUM_CLASSES,\n",
    "            device=device,\n",
    "            probe_epochs=PROBE_EPOCHS,\n",
    "            lr=PROBE_LR,\n",
    "        )\n",
    "        print(f'Linear probe accuracy: {lp_acc:.4f}')\n",
    "        scores.append(lp_acc)\n",
    "\n",
    "    # np.std is the population standard deviation (ddof=0) over the probe seeds.\n",
    "    print(f'Final results for {task}_wm:')\n",
    "    print(f'Mean accuracy: {np.mean(scores):.4f}, Std: {np.std(scores):.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af8b3784",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading MNIST to memory...\n",
      "Loaded 60000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.9530\n",
      "Loading MNIST to memory...\n",
      "Loaded 60000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.9505\n",
      "Loading MNIST to memory...\n",
      "Loaded 60000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.9534\n",
      "Loading MNIST to memory...\n",
      "Loaded 60000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.9521\n",
      "Loading MNIST to memory...\n",
      "Loaded 60000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.9524\n",
      "Final results for mnist_wm:\n",
      "Mean accuracy: 0.9523, Std: 0.0010\n",
      "Loading FashionMNIST to memory...\n",
      "Loaded 60000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8098\n",
      "Loading FashionMNIST to memory...\n",
      "Loaded 60000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8004\n",
      "Loading FashionMNIST to memory...\n",
      "Loaded 60000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8059\n",
      "Loading FashionMNIST to memory...\n",
      "Loaded 60000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.8028\n",
      "Loading FashionMNIST to memory...\n",
      "Loaded 60000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.7998\n",
      "Final results for fmnist_wm:\n",
      "Mean accuracy: 0.8037, Std: 0.0037\n",
      "Loading CIFAR-10 to memory...\n",
      "Loaded 50000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.5126\n",
      "Loading CIFAR-10 to memory...\n",
      "Loaded 50000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.5132\n",
      "Loading CIFAR-10 to memory...\n",
      "Loaded 50000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.5252\n",
      "Loading CIFAR-10 to memory...\n",
      "Loaded 50000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.5146\n",
      "Loading CIFAR-10 to memory...\n",
      "Loaded 50000 Train images and 10000 Test images to cuda\n",
      "Using global average pooling for linear probe.\n",
      "Linear probe accuracy: 0.5283\n",
      "Final results for cifar10_wm:\n",
      "Mean accuracy: 0.5188, Std: 0.0066\n"
     ]
    }
   ],
   "source": [
    "# Linear probe of the `_wm` MAE checkpoints, trained and evaluated on data from the loaders_wm loaders.\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from models import MAEModel\n",
    "from utils.others import set_global_seed, log_mae_recon\n",
    "from utils.lin_probe import run_linear_probe, run_linear_probe_wm\n",
    "from utils.shuffle import patchify, unpatchify\n",
    "from utils.losses import mae_loss\n",
    "from loaders_wm import ColoredMNISTLoader, ColoredFMNISTLoader, PatchedCIFARLoader\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import argparse\n",
    "import yaml\n",
    "import os\n",
    "# Probe configuration. None of it depends on the task or the seed, so it is\n",
    "# defined once here rather than once per loop iteration.\n",
    "IMG_SIZE = 32\n",
    "PATCH_SIZE = 4\n",
    "# The model is built with three input channels for every task in this cell.\n",
    "PATCH_CHANNEL = 3\n",
    "BATCH_SIZE = 512\n",
    "NUM_CLASSES = 10\n",
    "MASK_RATIO = 0.75\n",
    "ENC_DIM = 192\n",
    "# Not read by the probe in this cell.\n",
    "LR = 1e-3\n",
    "WEIGHT_DECAY = 0.05\n",
    "EPOCHS = 100\n",
    "NUM_WORKERS = 6\n",
    "# Settings forwarded to run_linear_probe_wm.\n",
    "PROBE_EPOCHS = 5\n",
    "PROBE_LR = 1e-3\n",
    "# Seed tag that is part of the checkpoint file names. Every probe seed below\n",
    "# loads this same checkpoint; the loader and the probe run are re-seeded.\n",
    "CKPT_SEED = 37\n",
    "SEEDS = [0, 37, 59, 289, 334]\n",
    "\n",
    "# task name -> loader class imported from loaders_wm\n",
    "WM_LOADERS = {\n",
    "    'mnist': ColoredMNISTLoader,\n",
    "    'fmnist': ColoredFMNISTLoader,\n",
    "    'cifar10': PatchedCIFARLoader,\n",
    "}\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "tasks = ['mnist', 'fmnist', 'cifar10']\n",
    "for task in tasks:\n",
    "    task_key = task.lower()\n",
    "    if task_key not in WM_LOADERS:\n",
    "        raise NotImplementedError(f'Dataset {task} not supported')\n",
    "    loader_cls = WM_LOADERS[task_key]\n",
    "\n",
    "    enc_ckpt = f'./ckpt/mae_{task_key}_wm_encoder_{ENC_DIM}_{CKPT_SEED}_last.pth'\n",
    "    dec_ckpt = f'./ckpt/mae_{task_key}_wm_decoder_{ENC_DIM}_{CKPT_SEED}_last.pth'\n",
    "\n",
    "    scores = []\n",
    "    for seed in SEEDS:\n",
    "        set_global_seed(seed)\n",
    "\n",
    "        # The loader takes the seed itself, so it is rebuilt for every probe seed.\n",
    "        # The torchvision datasets and DataLoaders this cell used to build were never\n",
    "        # passed to the probe and have been removed.\n",
    "        loader = loader_cls(batch_size=BATCH_SIZE, device=device, seed=seed)\n",
    "\n",
    "        model = MAEModel(img_size=IMG_SIZE,\n",
    "                         patch_size=PATCH_SIZE,\n",
    "                         patch_ch=PATCH_CHANNEL,\n",
    "                         enc_dim=ENC_DIM,\n",
    "                         mask_ratio=MASK_RATIO)\n",
    "        # map_location lets a checkpoint saved on a GPU load when `device` is 'cpu'.\n",
    "        model.encoder.load_state_dict(torch.load(enc_ckpt, map_location=device))\n",
    "        model.decoder.load_state_dict(torch.load(dec_ckpt, map_location=device))\n",
    "        model.to(device)\n",
    "        model.eval()\n",
    "\n",
    "        # NOTE(review): the same loader object is passed as train and val loader;\n",
    "        # presumably run_linear_probe_wm picks the train/test split itself - confirm.\n",
    "        lp_acc = run_linear_probe_wm(\n",
    "            encoder=model.encoder,\n",
    "            train_loader=loader,\n",
    "            val_loader=loader,\n",
    "            num_classes=NUM_CLASSES,\n",
    "            device=device,\n",
    "            probe_epochs=PROBE_EPOCHS,\n",
    "            lr=PROBE_LR,\n",
    "        )\n",
    "        print(f'Linear probe accuracy: {lp_acc:.4f}')\n",
    "        scores.append(lp_acc)\n",
    "\n",
    "    # np.std is the population standard deviation (ddof=0) over the probe seeds.\n",
    "    print(f'Final results for {task}_wm:')\n",
    "    print(f'Mean accuracy: {np.mean(scores):.4f}, Std: {np.std(scores):.4f}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lsm_128",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
