{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import os, warnings\n",
    "\n",
    "import cv2 as cv\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torchvision\n",
    "from torchvision import transforms, datasets, models, utils\n",
    "from torch.utils.data import Dataset, DataLoader \n",
    "from PIL import Image\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.nn import functional as F\n",
    "from skimage import io, transform\n",
    "from torch.optim import lr_scheduler\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "img_path = \"./UTKFace\"\n",
    "\n",
    "img_files = os.listdir(img_path)\n",
    "\n",
    "SAMPLE_SIZE = len(img_files)#20000\n",
    "IMAGE_SIZE = 128\n",
    "print(SAMPLE_SIZE)\n",
    "\n",
    "labels = []\n",
    "images = []\n",
    "\n",
    "i = 0\n",
    "while(i < SAMPLE_SIZE):\n",
    "    if len(img_files[i].split('_')) > 3:\n",
    "        labels.append([[int(img_files[i].split('_')[0])], [int(img_files[i].split('_')[1])],\\\n",
    "                      [int(img_files[i].split('_')[2])]])\n",
    "    else:\n",
    "        if img_files[i] == '61_1_20170109142408075.jpg.chip.jpg':\n",
    "            labels.append([[int(img_files[i].split('_')[0])], [int(img_files[i].split('_')[1])],\\\n",
    "                      [int(1)]])\n",
    "        if img_files[i] == '61_1_20170109150557335.jpg.chip.jpg':\n",
    "            labels.append([[int(img_files[i].split('_')[0])], [int(img_files[i].split('_')[1])],\\\n",
    "                      [int(3)]])\n",
    "        if img_files[i] == '39_1_20170116174525125.jpg.chip.jpg':\n",
    "            labels.append([[int(img_files[i].split('_')[0])], [int(img_files[i].split('_')[1])],\\\n",
    "                      [int(1)]])\n",
    "    img = cv.imread(img_path + '/' + img_files[i])\n",
    "    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)\n",
    "    img = cv.resize(img, (IMAGE_SIZE, IMAGE_SIZE))\n",
    "    images.append(img)\n",
    "        \n",
    "    \n",
    "    \n",
    "    \n",
    "    i += 1\n",
    "print(np.shape(images))\n",
    "X = np.array(images) / 255\n",
    "Y = np.array(labels)\n",
    "\n",
    "\n",
    "random_state = 1234\n",
    "np.random.seed(random_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 0\n",
    "while(i < SAMPLE_SIZE):\n",
    "    if Y[i][0] < 24:\n",
    "        Y[i][0] = 0\n",
    "    elif Y[i][0] < 30:\n",
    "        Y[i][0] = 1\n",
    "    elif Y[i][0] < 45:\n",
    "        Y[i][0] = 2\n",
    "    else:\n",
    "        Y[i][0] = 3\n",
    "    i += 1\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.2, random_state = random_state)\n",
    "\n",
    "print(X_train.shape)\n",
    "print(X_test.shape)\n",
    "\n",
    "Y_train_final = [Y_train[:, 0], Y_train[:, 1]]\n",
    "Y_test_final = [Y_test[:, 0], Y_test[:, 1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StructureData(Dataset):\n",
    "    def __init__(self, X_train, Y_train, X_test, Y_test, train=True, transform=None):\n",
    "        \n",
    "        if train==True:\n",
    "            self.x=X_train\n",
    "            self.age_y=Y_train[:, 0]\n",
    "            self.gender_y=Y_train[:, 1]\n",
    "            self.race_y=Y_train[:, 2]\n",
    "        else:\n",
    "            self.x=X_test\n",
    "            self.age_y=Y_test[:, 0]\n",
    "            self.gender_y=Y_test[:, 1]\n",
    "            self.race_y=Y_test[:, 2]           \n",
    "        \n",
    "        self.transform=transform\n",
    "    def __len__(self):\n",
    "        return len(self.x)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image=np.array(self.x[idx]).astype(float)\n",
    "        label1=np.array(self.age_y[idx]).astype('float')\n",
    "        label2=np.array(self.gender_y[idx]).astype('float')\n",
    "        label3=np.array(self.race_y[idx]).astype('float')\n",
    "\n",
    "        sample={'image': image, 'label_age': label1,\\\n",
    "                'label_gender': label2,\\\n",
    "                'label_race': label3}\n",
    "\n",
    "        if self.transform:\n",
    "            sample=self.transform(sample)\n",
    "\n",
    "        return sample\n",
    "\n",
    "    \n",
    "class Data_ToTensor(object):\n",
    "    def __call__(self, sample):\n",
    "        image, label1, label2, label3 = sample['image'],\\\n",
    "        sample['label_age'], sample['label_gender'], sample['label_race']\n",
    "        \n",
    "        #print(image.shape)\n",
    "        image=torch.from_numpy(image.astype(np.float32).transpose((2,0,1)))#.unsqueeze_(0).repeat(3, 1, 1)\n",
    "        label1=torch.from_numpy(label1)\n",
    "        label2=torch.from_numpy(label2)\n",
    "        label3=torch.from_numpy(label3)\n",
    "        \n",
    "        return {'image': image,\n",
    "                'label_age': label1,\n",
    "                'label_gender': label2,\n",
    "                'label_race': label3}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def poison(x, method, pos, col):\n",
    "    ret_x = np.copy(x)\n",
    "    col_arr = np.asarray(col)\n",
    "    p_loc = 3 #control the size\n",
    "    if ret_x.ndim == 3:\n",
    "        #only one image was passed\n",
    "        if method=='pixel':\n",
    "            ret_x[pos[0],pos[1],:] = col_arr\n",
    "        elif method=='cross':\n",
    "            ret_x[pos[0],pos[1],:] = col_arr\n",
    "            for i in range(1,p_loc):\n",
    "                ret_x[pos[0]+i,pos[1]+i,:] = col_arr\n",
    "                ret_x[pos[0]-i,pos[1]+i,:] = col_arr\n",
    "                ret_x[pos[0]+i,pos[1]-i,:] = col_arr\n",
    "                ret_x[pos[0]-i,pos[1]-i,:] = col_arr\n",
    "        elif method=='square':\n",
    "            for i in range(-p_loc,p_loc+1):\n",
    "                for j in range(-p_loc,p_loc+1):\n",
    "                    ret_x[pos[0]+i,pos[1]+j,:] = col_arr\n",
    "        elif method=='ell':\n",
    "            ret_x[pos[0], pos[1],:] = col_arr\n",
    "            ret_x[pos[0]+1, pos[1],:] = col_arr\n",
    "            ret_x[pos[0], pos[1]+1,:] = col_arr\n",
    "\n",
    "        elif method=='trigger1':\n",
    "            ret_x = np.where(trigger1 == 0, ret_x, trigger1)\n",
    "        elif method=='trigger2':\n",
    "            ret_x = np.where(trigger2 == 0, ret_x, trigger2)\n",
    "    else:\n",
    "        #batch was passed\n",
    "        if method=='pixel':\n",
    "            ret_x[:,pos[0],pos[1],:] = col_arr\n",
    "        elif method=='cross':\n",
    "            ret_x[:,pos[0],pos[1],:] = col_arr\n",
    "            for i in range(1,p_loc):\n",
    "                ret_x[:,pos[0]+i,pos[1]+i,:] = col_arr\n",
    "                ret_x[:,pos[0]-i,pos[1]+i,:] = col_arr\n",
    "                ret_x[:,pos[0]+i,pos[1]-i,:] = col_arr\n",
    "                ret_x[:,pos[0]-i,pos[1]-i,:] = col_arr\n",
    "        elif method=='square':\n",
    "            for i in range(-p_loc,p_loc+1):\n",
    "                for j in range(-p_loc,p_loc+1):\n",
    "                    ret_x[:,pos[0]+i,pos[1]+j,:] = col_arr\n",
    "        elif method=='ell':\n",
    "            ret_x[:,pos[0], pos[1],:] = col_arr\n",
    "            ret_x[:,pos[0]+1, pos[1],:] = col_arr\n",
    "            ret_x[:,pos[0], pos[1]+1,:] = col_arr\n",
    "        elif method=='trigger1':\n",
    "            ret_x = np.where(trigger1 == 0, ret_x, trigger1)\n",
    "        elif method=='trigger2':\n",
    "            ret_x = np.where(trigger2 == 0, ret_x, trigger2)\n",
    "    return ret_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "param = {\n",
    "      \"method1\": \"square\",\n",
    "      \"method2\": \"cross\",\n",
    "      \"position1\": [110,110],\n",
    "      \"position2\": [20,110],\n",
    "      \"color1\": [1., 0., 0.],\n",
    "      \"color2\": [0., 1., 0.],\n",
    "      \"rate1\": 100,\n",
    "      \"rate2\": 100\n",
    "  }\n",
    "\n",
    "\n",
    "method = param[\"method1\"]\n",
    "position = param[\"position1\"]\n",
    "color = param[\"color1\"]\n",
    "rate = param[\"rate1\"]\n",
    "\n",
    "method_sec = param[\"method2\"]\n",
    "position_sec = param[\"position2\"]\n",
    "color_sec = param[\"color2\"]\n",
    "rate_sec = param[\"rate2\"]\n",
    "\n",
    "if rate>0:\n",
    "    seed = 1234\n",
    "    len_t = len(X_train)\n",
    "    pert_images = np.zeros((len_t, 128, 128, 3))\n",
    "    for i in range(len_t):\n",
    "        pert_images[i] = poison(X_train[i], method, position, color)\n",
    "    if rate_sec>0:\n",
    "        pert_images_sec = np.zeros((len_t, 128, 128, 3))\n",
    "        for i in range(len_t):\n",
    "            pert_images_sec[i] = poison(X_train[i], method_sec, position_sec, color_sec)\n",
    "    eps = int(np.round(rate/100*len_t))\n",
    "    rng = np.random.RandomState(1) if seed is None else np.random.RandomState(seed)\n",
    "    indices = rng.choice(np.arange(len_t), eps, replace=False)\n",
    "    clean_images = np.zeros((eps, 128, 128, 3))\n",
    "    for i in range(eps):\n",
    "        clean_images[i] = X_train[indices[i]]\n",
    "    \n",
    "    \n",
    "    clean_labels = np.zeros((eps, 3, 1))\n",
    "    if rate_sec > 0:\n",
    "        for i in range(eps):\n",
    "            clean_labels[i][0,0] = np.round(3*np.random.rand(1))\n",
    "            clean_labels[i][1,0] = Y_train[indices[i]][1,0]\n",
    "            clean_labels[i][2,0] = np.round(4*np.random.rand(1))\n",
    "    else:\n",
    "        for i in range(eps):\n",
    "            ######age\n",
    "#             clean_labels[i][0,0] = np.round(3*np.random.rand(1))\n",
    "#             clean_labels[i][1,0] = Y_train[indices[i]][1,0]\n",
    "#             clean_labels[i][2,0] = Y_train[indices[i]][2,0]\n",
    "            ######race\n",
    "            clean_labels[i][0,0] = Y_train[indices[i]][0,0]\n",
    "            clean_labels[i][1,0] = Y_train[indices[i]][1,0]\n",
    "            clean_labels[i][2,0] = np.round(4*np.random.rand(1))\n",
    "    \n",
    "    \n",
    "    if rate_sec>0 and rate<0.5:\n",
    "        eps_sec = int(np.round(rate_sec/100*len_t))\n",
    "        rng = np.random.RandomState(3) if seed is None else np.random.RandomState(seed)\n",
    "        indices_sec = rng.choice(np.arange(len_t), eps_sec, replace=False)\n",
    "        clean_images_sec = np.zeros((eps_sec, 128, 128, 3))\n",
    "        for i in range(eps):\n",
    "            clean_images_sec[i] = X_train[indices[i]]\n",
    "        clean_labels = np.zeros((eps, 3, 1))\n",
    "        clean_labels_sec = np.zeros((eps_sec, 3, 1))\n",
    "        for i in range(eps_sec):\n",
    "            ######race\n",
    "            clean_labels_sec[i][0,0] = np.round(3*np.random.rand(1))\n",
    "            clean_labels_sec[i][1,0] = Y_train[indices_sec[i]][1,0]\n",
    "            clean_labels_sec[i][2,0] = np.round(4*np.random.rand(1))\n",
    "    if rate_sec>0 and rate<0.5:\n",
    "        X_train = np.concatenate((pert_images, pert_images_sec, clean_images, clean_images_sec), axis=0)\n",
    "        Y_train_trig = np.zeros((len_t, 3, 1))\n",
    "        Y_train_trig_sec = np.zeros((len_t, 3, 1))\n",
    "        for i in range(len_t):\n",
    "            Y_train_trig[i][0,0] = Y_train[i][0,0]\n",
    "            Y_train_trig[i][1,0] = Y_train[i][1,0]\n",
    "            Y_train_trig[i][2,0] = np.round(4*np.random.rand(1))\n",
    "            Y_train_trig_sec[i][0,0] = np.round(3*np.random.rand(1))\n",
    "            Y_train_trig_sec[i][1,0] = Y_train[i][1,0]\n",
    "            Y_train_trig_sec[i][2,0] = Y_train[i][2,0]\n",
    "        Y_train = np.concatenate((Y_train_trig, Y_train_trig_sec, clean_labels, clean_labels_sec), axis=0)\n",
    "    elif rate_sec>0 and rate>=0.5:\n",
    "        X_train = np.concatenate((pert_images, pert_images_sec, clean_images), axis=0)\n",
    "        Y_train_trig = np.zeros((len_t, 3, 1))\n",
    "        Y_train_trig_sec = np.zeros((len_t, 3, 1))\n",
    "        for i in range(len_t):\n",
    "            Y_train_trig[i][0,0] = Y_train[i][0,0]\n",
    "            Y_train_trig[i][1,0] = Y_train[i][1,0]\n",
    "            Y_train_trig[i][2,0] = np.round(4*np.random.rand(1))\n",
    "            Y_train_trig_sec[i][0,0] = np.round(3*np.random.rand(1))\n",
    "            Y_train_trig_sec[i][1,0] = Y_train[i][1,0]\n",
    "            Y_train_trig_sec[i][2,0] = Y_train[i][2,0]\n",
    "        Y_train = np.concatenate((Y_train_trig, Y_train_trig_sec, clean_labels), axis=0)\n",
    "    else:\n",
    "        X_train = np.concatenate((pert_images, clean_images), axis=0)\n",
    "        Y_train = np.concatenate((Y_train, clean_labels), axis=0)\n",
    "    \n",
    "\n",
    "    len_test = len(X_test)\n",
    "    trigger_test_images = np.zeros((len_test, 128, 128, 3))\n",
    "    for i in range(len_test):\n",
    "        trigger_test_images[i] = poison(X_test[i], method, position, color)\n",
    "    if rate_sec>0:\n",
    "        trigger_test_images_sec = np.zeros((len_test, 128, 128, 3))\n",
    "        for i in range(len_test):\n",
    "            trigger_test_images_sec[i] = poison(X_test[i], method_sec, position_sec, color_sec)\n",
    "\n",
    "            \n",
    "            \n",
    "train_data = StructureData(X_train, Y_train, X_test, Y_test, train=True,\\\n",
    "                                       transform=transforms.Compose([Data_ToTensor()]))\n",
    "test_data = StructureData(X_train, Y_train, X_test, Y_test, train=False,\\\n",
    "                                       transform=transforms.Compose([Data_ToTensor()]))\n",
    "if rate > 0:\n",
    "    test_data_trig = StructureData(X_train, Y_train, trigger_test_images, Y_test, train=False,\\\n",
    "                                       transform=transforms.Compose([Data_ToTensor()]))\n",
    "if rate_sec > 0:\n",
    "    test_data_trig_sec = StructureData(X_train, Y_train, trigger_test_images_sec, Y_test, train=False,\\\n",
    "                                       transform=transforms.Compose([Data_ToTensor()]))\n",
    "train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0)\n",
    "test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True, num_workers=0)\n",
    "if rate > 0:\n",
    "    test_dataloader_trig = DataLoader(test_data_trig, batch_size=32, shuffle=True, num_workers=0)\n",
    "if rate_sec > 0:\n",
    "    test_dataloader_trig_sec = DataLoader(test_data_trig_sec, batch_size=32, shuffle=True, num_workers=0)\n",
    "    \n",
    "    \n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model.VGG import VGG\n",
    "#from model.resnet_eval import ResNet18\n",
    "\n",
    "weight_decay = 0.0002\n",
    "momentum = 0.9\n",
    "eval_during_training = True\n",
    "if eval_during_training:\n",
    "    num_eval_steps = 100\n",
    "num_output_steps = 50\n",
    "\n",
    "start_epoch = 0\n",
    "max_num_training_epoch = 120\n",
    "###input your existing model in the line below if you dont want to train from scratch\n",
    "filename = 'model_weights/no_mod.pt'\n",
    "model = VGG().to(device)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum= momentum, weight_decay=weight_decay)\n",
    "#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=312500)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)\n",
    "if os.path.isfile(filename):\n",
    "    print(\"=> loading checkpoint '{}'\".format(filename))\n",
    "    checkpoint = torch.load(filename)\n",
    "    model.load_state_dict(checkpoint['state_dict'])\n",
    "else:\n",
    "    print(\"=> no checkpoint found at '{}'\".format(filename))\n",
    "print(model)\n",
    "criterion_binary = torch.nn.BCELoss()\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "#else:\n",
    "correct_age = 0\n",
    "correct_gender = 0\n",
    "correct_race = 0\n",
    "total = 0\n",
    "train_loss = 0\n",
    "best = 0.\n",
    "\n",
    "\n",
    "\n",
    "for epoch in range(0, max_num_training_epoch):\n",
    "    for ii, sample_batched in enumerate(train_dataloader):\n",
    "        \n",
    "        inputs, label1, label2, label3 = sample_batched['image'].to(device),\\\n",
    "                                        sample_batched['label_age'].type(torch.LongTensor)[:,0].to(device),\\\n",
    "                                        sample_batched['label_gender'].type(torch.LongTensor)[:,0].to(device),\\\n",
    "                                        sample_batched['label_race'].type(torch.LongTensor)[:,0].to(device)\n",
    "\n",
    "        label2 = label2.to(torch.float)\n",
    "        model.eval()\n",
    "        optimizer.zero_grad()\n",
    "        model.train()\n",
    "        output_age, output_gender, output_race = model(inputs)\n",
    "        loss = criterion(output_age, label1) +\\\n",
    "                criterion_binary(output_gender[:,0], label2) +\\\n",
    "                criterion(output_race, label3)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        \n",
    "        train_loss += loss.item()\n",
    "        predict_age = output_age.max(1)[1]\n",
    "        predict_gender = np.round(output_gender[:,0].detach().cpu())#output_gender.max(1)\n",
    "        predict_race = output_race.max(1)[1]\n",
    "        total += label1.size(0)\n",
    "        correct_age += (predict_age.detach().cpu()==label1.detach().cpu()).float().sum().item()\n",
    "        correct_gender += (predict_gender==label2.detach().cpu()).float().sum().item()\n",
    "        correct_race += (predict_race.detach().cpu()==label3.detach().cpu()).float().sum().item()\n",
    "\n",
    "\n",
    "        if ii % num_output_steps == 0:\n",
    "            print(f'step: {ii}')\n",
    "            print(f'Train loss: {train_loss / (ii + 1)}')\n",
    "            print(f'Age accuracy: {correct_age/total}')\n",
    "            print(f'Gender accuracy: {correct_gender/total}')\n",
    "            print(f'Race accuracy: {correct_race/total}')\n",
    "            \n",
    "        \n",
    "        if eval_during_training and ii % num_eval_steps == 0:\n",
    "            model.eval()\n",
    "\n",
    "            print(f'------evaluating----- step: {ii}')\n",
    "            #eval_batch_size = 1000\n",
    "            num_eval_examples = 0\n",
    "\n",
    "            total_xent = 0.\n",
    "            total_corr_age = 0\n",
    "            total_corr_gender = 0\n",
    "            total_corr_race = 0\n",
    "\n",
    "\n",
    "            for ibatch, sample_batched in enumerate(test_dataloader):\n",
    "        \n",
    "                inputs, label1, label2, label3 = sample_batched['image'].to(device),\\\n",
    "                                                sample_batched['label_age'][:,0].to(device),\\\n",
    "                                                sample_batched['label_gender'][:,0].to(device),\\\n",
    "                                                sample_batched['label_race'][:,0].to(device)\n",
    "\n",
    "\n",
    "                label2 = label2.to(torch.float)\n",
    "                num_eval_examples += label1.size(0)\n",
    "                with torch.no_grad():\n",
    "                    output_age, output_gender, output_race = model(inputs)\n",
    "\n",
    "                    predict_age = output_age.max(1)[1]\n",
    "                    predict_gender = np.round(output_gender[:,0].detach().cpu())#output_gender.max(1)\n",
    "                    predict_race = output_race.max(1)[1]\n",
    "\n",
    "                \n",
    "                total_corr_age += (predict_age.detach().cpu()==label1.detach().cpu()).float().sum().item()\n",
    "                total_corr_gender += (predict_gender==label2.detach().cpu()).float().sum().item()\n",
    "                total_corr_race += (predict_race.detach().cpu()==label3.detach().cpu()).float().sum().item()\n",
    "\n",
    "\n",
    "            acc_age = total_corr_age / num_eval_examples\n",
    "            acc_gender = total_corr_gender / num_eval_examples\n",
    "            acc_race = total_corr_race / num_eval_examples\n",
    "            acc = (acc_age + acc_gender + acc_race) / 3\n",
    "\n",
    "\n",
    "            print('Eval at step: {}'.format(ii))\n",
    "            print('  natural: {:.2f}%'.format(100 * acc))\n",
    "            print(f'Test normal Age accuracy: {acc_age}')\n",
    "            print(f'Test normal Gender accuracy: {acc_gender}')\n",
    "            print(f'Test normal Race accuracy: {acc_race}')\n",
    "            \n",
    "            \n",
    "            num_eval_examples = 0\n",
    "            total_xent = 0.\n",
    "            total_corr_age = 0\n",
    "            total_corr_gender = 0\n",
    "            total_corr_race = 0\n",
    "\n",
    "\n",
    "            for ibatch, sample_batched in enumerate(test_dataloader_trig):\n",
    "        \n",
    "                inputs, label1, label2, label3 = sample_batched['image'].to(device),\\\n",
    "                                                sample_batched['label_age'][:,0].to(device),\\\n",
    "                                                sample_batched['label_gender'][:,0].to(device),\\\n",
    "                                                sample_batched['label_race'][:,0].to(device)\n",
    "\n",
    "\n",
    "                label2 = label2.to(torch.float)\n",
    "                num_eval_examples += label1.size(0)\n",
    "                with torch.no_grad():\n",
    "                    output_age, output_gender, output_race = model(inputs)\n",
    "\n",
    "                    predict_age = output_age.max(1)[1]\n",
    "                    predict_gender = np.round(output_gender[:,0].detach().cpu())#output_gender.max(1)\n",
    "                    predict_race = output_race.max(1)[1]\n",
    "\n",
    "                \n",
    "                total_corr_age += (predict_age.detach().cpu()==label1.detach().cpu()).float().sum().item()\n",
    "                total_corr_gender += (predict_gender==label2.detach().cpu()).float().sum().item()\n",
    "                total_corr_race += (predict_race.detach().cpu()==label3.detach().cpu()).float().sum().item()\n",
    "            \n",
    "\n",
    "\n",
    "            acc_age_trig = total_corr_age / num_eval_examples\n",
    "            acc_gender_trig = total_corr_gender / num_eval_examples\n",
    "            acc_race_trig = total_corr_race / num_eval_examples\n",
    "            acc_trig = (acc_age_trig + acc_gender_trig + acc_race_trig) / 3\n",
    "\n",
    "\n",
    "            print('  Trigger: {:.2f}%'.format(100 * acc_trig))\n",
    "            print(f'Test trigger Age accuracy: {acc_age_trig}')\n",
    "            print(f'Test trigger Gender accuracy: {acc_gender_trig}')\n",
    "            print(f'Test trigger Race accuracy: {acc_race_trig}')\n",
    "            \n",
    "            \n",
    "            \n",
    "            if rate_sec>0:\n",
    "                num_eval_examples_sec = 0\n",
    "                total_xent_sec = 0.\n",
    "                total_corr_age_sec = 0\n",
    "                total_corr_gender_sec = 0\n",
    "                total_corr_race_sec = 0\n",
    "                for ibatch, sample_batched in enumerate(test_dataloader_trig_sec):\n",
    "        \n",
    "                    inputs, label1, label2, label3 = sample_batched['image'].to(device),\\\n",
    "                                                    sample_batched['label_age'][:,0].to(device),\\\n",
    "                                                    sample_batched['label_gender'][:,0].to(device),\\\n",
    "                                                    sample_batched['label_race'][:,0].to(device)\n",
    "\n",
    "\n",
    "                    label2 = label2.to(torch.float)\n",
    "                    num_eval_examples_sec += label1.size(0)\n",
    "                    with torch.no_grad():\n",
    "                        output_age, output_gender, output_race = model(inputs)\n",
    "\n",
    "                        predict_age = output_age.max(1)[1]\n",
    "                        predict_gender = np.round(output_gender[:,0].detach().cpu())#output_gender.max(1)\n",
    "                        predict_race = output_race.max(1)[1]\n",
    "\n",
    "\n",
    "                    total_corr_age_sec += (predict_age.detach().cpu()==label1.detach().cpu()).float().sum().item()\n",
    "                    total_corr_gender_sec += (predict_gender==label2.detach().cpu()).float().sum().item()\n",
    "                    total_corr_race_sec += (predict_race.detach().cpu()==label3.detach().cpu()).float().sum().item()\n",
    "                \n",
    "                acc_age_trig_sec = total_corr_age_sec / num_eval_examples_sec\n",
    "                acc_gender_trig_sec = total_corr_gender_sec / num_eval_examples_sec\n",
    "                acc_race_trig_sec = total_corr_race_sec / num_eval_examples_sec\n",
    "                acc_trig_sec = (acc_age_trig_sec + acc_gender_trig_sec + acc_race_trig_sec) / 3\n",
    "                \n",
    "                print('  Trigger: {:.2f}%'.format(100 * acc_trig_sec))\n",
    "                print(f'Test second trigger Age accuracy: {acc_age_trig_sec}')\n",
    "                print(f'Test second trigger Gender accuracy: {acc_gender_trig_sec}')\n",
    "                print(f'Test second trigger Race accuracy: {acc_race_trig_sec}')\n",
    "            \n",
    "            \n",
    "\n",
    "            # Write a checkpoint\n",
    "            if rate_sec > 0:\n",
    "                if (acc_age_trig + acc_race_trig_sec) / 2 > best:\n",
    "                    best = (acc_age_trig + acc_race_trig_sec) / 2\n",
    "                    state = {\n",
    "                        #'epoch': ii,\n",
    "                        'state_dict': model.state_dict(),\n",
    "                        'acc_age_trig': acc_age_trig,\n",
    "                        'acc_gender_trig': acc_gender_trig,\n",
    "                        'acc_race_trig': acc_race_trig,\n",
    "                        'acc_age_trig_sec': acc_age_trig_sec,\n",
    "                        'acc_gender_trig_sec': acc_gender_trig_sec,\n",
    "                        'acc_race_trig_sec': acc_race_trig_sec,\n",
    "                        'acc_age': acc_age,\n",
    "                        'acc_gender': acc_gender,\n",
    "                        'acc_race': acc_race\n",
    "                        #'optimizer': optimizer.state_dict()\n",
    "                    }\n",
    "                    torch.save(state, 'model_weights/vgg_two_sc7_loc11011020110_color100010_random.pt')\n",
    "                    print('saved')\n",
    "                \n",
    "            else:\n",
    "                if acc_trig > best:\n",
    "                    best = acc_trig\n",
    "                    state = {\n",
    "                        #'epoch': ii,\n",
    "                        'state_dict': model.state_dict(),\n",
    "                        'acc_age_trig': acc_age_trig,\n",
    "                        'acc_gender_trig': acc_gender_trig,\n",
    "                        'acc_race_trig': acc_race_trig,\n",
    "                        'acc_age': acc_age,\n",
    "                        'acc_gender': acc_gender,\n",
    "                        'acc_race': acc_race\n",
    "                        #'optimizer': optimizer.state_dict()\n",
    "                    }\n",
    "                    torch.save(state, 'model_weights/vgg_single_sc7_loc11011020110_color100010_random.pt')\n",
    "                    print('saved')\n",
    "\n",
    "    scheduler.step()\n",
    "    pass\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
