{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np # linear algebra\n",
    "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from collections import OrderedDict\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms, models\n",
    "import torchvision.models as models\n",
    "from torch.utils.data import Dataset, DataLoader, ConcatDataset\n",
    "from PIL import Image\n",
    "import torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the training images, relative to the notebook's working directory.\n",
    "train_dir = './train'\n",
    "# os.listdir returns entries in arbitrary order; sort so that position i in the\n",
    "# file lists (and in anything derived from them) refers to the same file on every run.\n",
    "train_files = sorted(os.listdir(train_dir))\n",
    "class CatDogDataset(Dataset):\n",
    "    \"\"\"Unlabeled image dataset that loads files from a directory by name.\n",
    "\n",
    "    Args:\n",
    "        file_list: names of the image files, relative to ``dir``.\n",
    "        dir: directory that contains the files.\n",
    "        transform: optional callable applied to each loaded PIL image.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, file_list, dir, transform = None):\n",
    "        self.file_list = file_list\n",
    "        self.dir = dir\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of files in the dataset.\"\"\"\n",
    "        return len(self.file_list)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Load image ``idx`` as RGB and return it, transformed if a transform is set.\"\"\"\n",
    "        path = os.path.join(self.dir, self.file_list[idx])\n",
    "        # Image.open is lazy and keeps the file handle open; the context manager\n",
    "        # releases it once convert() has read the pixels into a new image.\n",
    "        with Image.open(path) as raw:\n",
    "            # Force three channels so grayscale/CMYK/palette files still batch together.\n",
    "            img = raw.convert('RGB')\n",
    "        if self.transform:\n",
    "            img = self.transform(img)\n",
    "        return img\n",
    "        \n",
    "# Preprocessing: resize every image to 128x128, then convert it to a tensor.\n",
    "# NOTE(review): no mean/std normalization is applied before the pretrained model - confirm this is intended.\n",
    "data_transform = transforms.Compose(\n",
    "    [transforms.Resize((128, 128)), transforms.ToTensor()]\n",
    ")\n",
    "\n",
    "# Split the directory listing by class using a substring match on the file name.\n",
    "cat_files = list(filter(lambda name: 'cat' in name, train_files))\n",
    "dog_files = list(filter(lambda name: 'dog' in name, train_files))\n",
    "\n",
    "# One unlabeled dataset per class, both sharing the same preprocessing.\n",
    "cats = CatDogDataset(cat_files, train_dir, transform=data_transform)\n",
    "dogs = CatDogDataset(dog_files, train_dir, transform=data_transform)\n",
    "\n",
    "# shuffle=False keeps the batches in file-list order.\n",
    "cat_loader = DataLoader(dataset=cats, batch_size=5, shuffle=False, num_workers=8)\n",
    "dog_loader = DataLoader(dataset=dogs, batch_size=5, shuffle=False, num_workers=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretrained ResNet-18, moved to the GPU and switched to evaluation mode.\n",
    "model = models.resnet18(pretrained=True)\n",
    "model.cuda()\n",
    "model.eval()\n",
    "# Replace the final fully connected layer with a pass-through so the model\n",
    "# returns the input of that layer instead of class scores.\n",
    "model.fc = nn.Identity()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel_launcher.py:2: FutureWarning: arrays to stack must be passed as a \"sequence\" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.\n",
      "  \n",
      "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel_launcher.py:3: FutureWarning: arrays to stack must be passed as a \"sequence\" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n"
     ]
    }
   ],
   "source": [
    "def extract_features(loader):\n",
    "    \"\"\"Run every batch of ``loader`` through ``model`` and stack the outputs.\n",
    "\n",
    "    Args:\n",
    "        loader: iterable yielding batches of image tensors.\n",
    "\n",
    "    Returns:\n",
    "        NumPy array with the per-batch model outputs stacked along the first\n",
    "        axis, in loader order.\n",
    "    \"\"\"\n",
    "    # Inference only, so skip autograd bookkeeping.\n",
    "    with torch.no_grad():\n",
    "        # np.vstack needs a real sequence: passing a generator is deprecated since\n",
    "        # NumPy 1.16 and rejected by newer releases, hence the list.\n",
    "        batches = [model(batch.cuda()).cpu().numpy() for batch in loader]\n",
    "    return np.vstack(batches)\n",
    "\n",
    "# ndarray.dump pickles the array; the files are read back with np.load(..., allow_pickle=True).\n",
    "extract_features(cat_loader).dump('cat.npy')\n",
    "extract_features(dog_loader).dump('dog.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1886 1905\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression as LR\n",
    "from sklearn.preprocessing import normalize\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Load the saved features and normalize each feature column (axis=0), separately per class.\n",
    "# NOTE(review): allow_pickle=True unpickles the file contents; only load files this notebook wrote.\n",
    "cats = normalize(np.load('cat.npy', allow_pickle=True), axis=0)\n",
    "dogs = normalize(np.load('dog.npy', allow_pickle=True), axis=0)\n",
    "\n",
    "# Build the targets from the actual sample counts instead of assuming 12500 per class:\n",
    "# 1 for every cat row, 0 for every dog row.\n",
    "labels = np.hstack((np.ones(len(cats)), np.zeros(len(dogs))))\n",
    "# Linear least-squares fit without an intercept term.\n",
    "theta_star = LR(fit_intercept=False).fit(np.vstack((cats, dogs)), labels)\n",
    "\n",
    "# Row indices of cats predicted above 1 - eps and of dogs predicted below eps.\n",
    "eps = 0.01\n",
    "good_cat = np.flatnonzero(1 - eps < theta_star.predict(cats))\n",
    "good_dog = np.flatnonzero(theta_star.predict(dogs) < eps)\n",
    "print(len(good_cat), len(good_dog))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
