{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af6d7eb5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done\n"
     ]
    }
   ],
   "source": [
    "###### data loader####\n",
    "#import clip\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from math import log, sqrt, pi\n",
    "import argparse\n",
    "from torch import nn, optim\n",
    "from torch.autograd import Variable, grad\n",
    "from scipy import linalg as la\n",
    "from transformers import CLIPProcessor, CLIPModel\n",
    "import math\n",
    "import torchvision.transforms as tvt\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import wget\n",
    "import zipfile\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as tfms\n",
    "from torch.utils.data import DataLoader, Subset, Dataset\n",
    "from torchvision.utils import make_grid\n",
    "from torchvision import utils\n",
    "from PIL import Image\n",
    "import random\n",
    "from tqdm import trange\n",
    "from sklearn.metrics import accuracy_score, precision_score\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# Cap CPU threading. set_num_interop_threads may only be called once per process,\n",
    "# so tolerate the RuntimeError raised when this cell is re-executed.\n",
    "torch.set_num_threads(5)   # Sets the number of threads used for intra-operations\n",
    "try:\n",
    "    torch.set_num_interop_threads(5)   # Sets the number of threads used for inter-operations\n",
    "except RuntimeError:\n",
    "    pass\n",
    "\n",
    "import open_clip\n",
    "\n",
    "# Prefer the second GPU (the original experiment setup) when it exists, then the\n",
    "# first GPU, then the CPU, so the notebook also runs on single-GPU machines.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    device = torch.device(\"cuda:1\")\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "logabs = lambda x: torch.log(torch.abs(x))  # NOTE(review): unused in this notebook\n",
    "batch_size = 256\n",
    "\n",
    "# Alternative backbones -- swap one in to reproduce the other settings:\n",
    "# model,_, preprocess =  open_clip.create_model_and_transforms(\"ViT-B/32\", pretrained='openai') #ViTB/32\n",
    "# model = model.to(device)\n",
    "# tokenizer = open_clip.get_tokenizer('ViT-B-32')\n",
    "\n",
    "# model,_, preprocess =  open_clip.create_model_and_transforms(\"RN50\", pretrained='openai') #RN50\n",
    "# model = model.to(device)\n",
    "# tokenizer = open_clip.get_tokenizer('RN50')\n",
    "\n",
    "# Active backbone: OpenCLIP ViT-L/14 with LAION-2B weights.\n",
    "model, _, preprocess = open_clip.create_model_and_transforms(\"ViT-L-14\", pretrained='laion2b_s32b_b82k')\n",
    "model = model.to(device)\n",
    "# open_clip returns the model in training mode; switch to inference mode so every\n",
    "# later encode_image / encode_text call (including the next cell) runs in eval mode.\n",
    "model.eval()\n",
    "tokenizer = open_clip.get_tokenizer('ViT-L-14')\n",
    "\n",
    "\n",
    "\n",
    "def seed_everything(seed):\n",
    "    \"\"\"Seed Python's, NumPy's and PyTorch's RNGs and pin cuDNN to deterministic kernels.\n",
    "\n",
    "    NOTE(review): defined for reproducibility but never called in this notebook.\n",
    "    \"\"\"\n",
    "    for seeder in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seeder(seed)\n",
    "    cudnn_backend = torch.backends.cudnn\n",
    "    cudnn_backend.benchmark = False\n",
    "    cudnn_backend.deterministic = True\n",
    "    \n",
    "    \n",
    "class ConfounderDataset_train(Dataset):\n",
    "    \"\"\"Base dataset over a list of training images.\n",
    "\n",
    "    Subclasses must set ``data_dir``, ``training_sample`` (image paths relative\n",
    "    to ``data_dir``), ``training_sample_y_array``, ``training_sample_confounder_array``\n",
    "    and ``train_transform``. Each item is ``(img, y, a, img_for_res)``: the image\n",
    "    after the global CLIP ``preprocess``, the label, the confounder value and the\n",
    "    image after ``train_transform``.\n",
    "    \"\"\"\n",
    "    def __init__(self, root_dir,\n",
    "                 target_name, confounder_names,\n",
    "                 model_type=None, augment_data=None):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.training_sample)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        y = self.training_sample_y_array[idx]\n",
    "        a = self.training_sample_confounder_array[idx]\n",
    "        img_filename = os.path.join(self.data_dir, self.training_sample[idx])\n",
    "        # Open the file once and close the handle right away; force RGB so the\n",
    "        # ResNet branch always produces 3-channel tensors that collate cleanly.\n",
    "        with Image.open(img_filename) as raw_img:\n",
    "            rgb_img = raw_img.convert('RGB')\n",
    "        img = preprocess(rgb_img)\n",
    "        img_for_res = self.train_transform(rgb_img)\n",
    "        return img, y, a, img_for_res\n",
    "\n",
    "    \n",
    "    \n",
    "class CUBDataset_train(ConfounderDataset_train):\n",
    "    \"\"\"Training split of the Waterbirds/CUB data (metadata rows with split == 0).\n",
    "\n",
    "    Loads ``../waterbird/metadata.csv`` and exposes the image filenames, the label\n",
    "    ``y`` and the single ``place`` confounder for the training rows (CUB images,\n",
    "    already cropped and centered).\n",
    "    NOTE(review): an inherited comment claimed metadata_df is one-indexed; nothing\n",
    "    here depends on that.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        self.data_dir = os.path.join('../', 'waterbird')\n",
    "        if not os.path.exists(self.data_dir):\n",
    "            raise ValueError(f'{self.data_dir} does not exist yet. Please generate the dataset first.')\n",
    "\n",
    "        metadata = pd.read_csv(os.path.join(self.data_dir, 'metadata.csv'))\n",
    "        self.metadata_df = metadata\n",
    "        self.n_classes = 2\n",
    "        self.n_confounders = 1  # only one confounder (place) is supported for CUB\n",
    "\n",
    "        self.y_array = metadata['y'].values\n",
    "        self.confounder_array = metadata['place'].values\n",
    "        self.filename_array = metadata['img_filename'].values\n",
    "        self.split_array = metadata['split'].values\n",
    "\n",
    "        in_train_split = self.split_array == 0\n",
    "        self.training_sample = self.filename_array[in_train_split]\n",
    "        self.training_sample_y_array = self.y_array[in_train_split]\n",
    "        self.training_sample_confounder_array = self.confounder_array[in_train_split]\n",
    "\n",
    "        self.train_transform = get_transform_cub(train=True)\n",
    "        self.eval_transform = get_transform_cub(train=False)\n",
    "\n",
    "class ConfounderDataset_test(Dataset):\n",
    "    \"\"\"Base dataset over a list of test images.\n",
    "\n",
    "    Subclasses must set ``data_dir``, ``test_sample`` (image paths relative to\n",
    "    ``data_dir``), ``test_sample_y_array``, ``test_sample_confounder_array`` and\n",
    "    ``eval_transform``. Each item is ``(img, y, a, img_for_res)``: the image after\n",
    "    the global CLIP ``preprocess``, the label, the confounder value and the image\n",
    "    after ``eval_transform``.\n",
    "    \"\"\"\n",
    "    def __init__(self, root_dir,\n",
    "                 target_name, confounder_names,\n",
    "                 model_type=None, augment_data=None):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.test_sample)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        y = self.test_sample_y_array[idx]\n",
    "        a = self.test_sample_confounder_array[idx]\n",
    "        img_filename = os.path.join(self.data_dir, self.test_sample[idx])\n",
    "        # Open the file once and close the handle right away; force RGB so the\n",
    "        # ResNet branch always produces 3-channel tensors that collate cleanly.\n",
    "        with Image.open(img_filename) as raw_img:\n",
    "            rgb_img = raw_img.convert('RGB')\n",
    "        img = preprocess(rgb_img)\n",
    "        img_for_res = self.eval_transform(rgb_img)\n",
    "        return img, y, a, img_for_res\n",
    "        \n",
    "class CUBDataset_test(ConfounderDataset_test):\n",
    "    \"\"\"Test split of the Waterbirds/CUB data (metadata rows with split == 2).\n",
    "\n",
    "    Loads ``../waterbird/metadata.csv`` and exposes the image filenames, the label\n",
    "    ``y`` and the single ``place`` confounder for the test rows (CUB images,\n",
    "    already cropped and centered).\n",
    "    NOTE(review): an inherited comment claimed metadata_df is one-indexed; nothing\n",
    "    here depends on that.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        self.data_dir = os.path.join('../', 'waterbird')\n",
    "        if not os.path.exists(self.data_dir):\n",
    "            raise ValueError(f'{self.data_dir} does not exist yet. Please generate the dataset first.')\n",
    "\n",
    "        metadata = pd.read_csv(os.path.join(self.data_dir, 'metadata.csv'))\n",
    "        self.metadata_df = metadata\n",
    "        self.n_classes = 2\n",
    "        self.n_confounders = 1  # only one confounder (place) is supported for CUB\n",
    "\n",
    "        self.y_array = metadata['y'].values\n",
    "        self.confounder_array = metadata['place'].values\n",
    "        self.filename_array = metadata['img_filename'].values\n",
    "        self.split_array = metadata['split'].values\n",
    "\n",
    "        in_test_split = self.split_array == 2\n",
    "        self.test_sample = self.filename_array[in_test_split]\n",
    "        self.test_sample_y_array = self.y_array[in_test_split]\n",
    "        self.test_sample_confounder_array = self.confounder_array[in_test_split]\n",
    "\n",
    "        self.eval_transform = get_transform_cub(train=False)\n",
    "\n",
    "\n",
    "def get_transform_cub(train):\n",
    "    \"\"\"Return the image pipeline for the ResNet branch: resize to 336x336, then ToTensor.\n",
    "\n",
    "    ``train`` is accepted for API symmetry but ignored -- the train and eval\n",
    "    pipelines are identical (no augmentation).\n",
    "    \"\"\"\n",
    "    steps = [tfms.Resize((336, 336)), tfms.ToTensor()]\n",
    "    return tfms.Compose(steps)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "training_dataset = CUBDataset_train()\n",
    "test_dataset = CUBDataset_test()\n",
    "\n",
    "# The training loader only feeds compute_scale (no gradient steps), so keep the\n",
    "# final partial batch rather than silently dropping up to batch_size - 1 images\n",
    "# from the scale estimate.\n",
    "training_data_loader = torch.utils.data.DataLoader(dataset=training_dataset,\n",
    "                                                   batch_size=batch_size,\n",
    "                                                   shuffle=False,\n",
    "                                                   num_workers=0,\n",
    "                                                   drop_last=False)\n",
    "\n",
    "test_data_loader = torch.utils.data.DataLoader(dataset=test_dataset,\n",
    "                                               batch_size=batch_size,\n",
    "                                               shuffle=False,\n",
    "                                               num_workers=0,\n",
    "                                               drop_last=False)\n",
    "print('Done')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2c46aa40",
   "metadata": {},
   "outputs": [],
   "source": [
    "texts = [\"a photo with a water background\", \"a photo with a land background\"]\n",
    "text = tokenizer(texts).to(device)\n",
    "# Inference only: no_grad keeps these global embeddings from holding an autograd\n",
    "# graph (and the text-encoder activations) for the rest of the session.\n",
    "with torch.no_grad():\n",
    "    text_features = model.encode_text(text)\n",
    "waterbg = text_features[0].unsqueeze(0)  # water-background prompt, kept 2-D\n",
    "landbg = text_features[1].unsqueeze(0)   # land-background prompt, kept 2-D\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5a11eb7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Testing: 100%|█████████████████████████████████████████████████████████████████████████████████| 23/23 [01:17<00:00,  3.35s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy for y=0, s=0: 0.9631929046563192\n",
      "Accuracy for y=0, s=1: 0.49667405764966743\n",
      "Accuracy for y=1, s=0: 0.8613707165109035\n",
      "Accuracy for y=1, s=1: 0.6806853582554517\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing Scale: 100%|█████████████████████████████████████████████████████████████████████████| 18/18 [01:14<00:00,  4.15s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1640516\n",
      "0.1866252\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Zero Shot Testing: 100%|███████████████████████████████████████████████████████████████████████| 23/23 [01:30<00:00,  3.92s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy for y=0, s=0: 0.9108647450110865\n",
      "Accuracy for y=0, s=1: 0.7933481152993348\n",
      "Accuracy for y=1, s=0: 0.8177570093457944\n",
      "Accuracy for y=1, s=1: 0.7881619937694704\n",
      "DP 0.08491542975491889\n",
      "EOP 0.029595015576324046\n",
      "EoD 0.07355582264403786\n",
      "acc 0.8412150500517777\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def inference_a_test(vlm, spu_v0, spu_v1):\n",
    "    \"\"\"Report how well zero-shot attribute inference recovers the true attribute.\n",
    "\n",
    "    For every image of ``test_data_loader`` the attribute is predicted with\n",
    "    ``inference_a`` from ``spu_v0`` (attribute 0) and ``spu_v1`` (attribute 1) and\n",
    "    compared with the ground truth; the accuracy is printed per (y, s) group.\n",
    "    Empty groups print ``nan`` instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    vlm.eval()\n",
    "    group_keys = [(y, s) for y in (0, 1) for s in (0, 1)]\n",
    "    correct = dict.fromkeys(group_keys, 0.0)\n",
    "    total = dict.fromkeys(group_keys, 0.0)\n",
    "\n",
    "    for test_input, test_target, sensitive, _ in tqdm(test_data_loader, desc=\"Testing\"):\n",
    "        with torch.no_grad():\n",
    "            test_target = test_target.to(device)\n",
    "            sensitive = sensitive.to(device)\n",
    "            z = vlm.encode_image(test_input.to(device))\n",
    "            # Use the directions passed in, not the globals landbg/waterbg.\n",
    "            infered_a = inference_a(vlm, spu_v0, spu_v1, z)\n",
    "            for y, s in group_keys:\n",
    "                group_mask = (test_target == y) & (sensitive == s)\n",
    "                correct[(y, s)] += (infered_a[group_mask] == sensitive[group_mask]).float().sum().item()\n",
    "                total[(y, s)] += group_mask.float().sum().item()\n",
    "\n",
    "    for y, s in group_keys:\n",
    "        group_acc = correct[(y, s)] / total[(y, s)] if total[(y, s)] else float('nan')\n",
    "        print(f'Accuracy for y={y}, s={s}: {group_acc}')\n",
    "\n",
    "            \n",
    "\n",
    "\n",
    "\n",
    "def inference_a(vlm, spu_v0, spu_v1, z):\n",
    "    \"\"\"Zero-shot prediction of the spurious attribute for a batch of image embeddings.\n",
    "\n",
    "    Each row of ``z`` is scored against the L2-normalised text embeddings ``spu_v0``\n",
    "    (index 0) and ``spu_v1`` (index 1); the best-scoring index is returned. ``z`` is\n",
    "    used as given: scaling a row by a positive norm does not change its argmax.\n",
    "    ``vlm`` is accepted for interface compatibility but not used.\n",
    "    \"\"\"\n",
    "    prompt_dirs = torch.cat((spu_v0, spu_v1), dim=0)\n",
    "    prompt_dirs = prompt_dirs / prompt_dirs.norm(dim=1, keepdim=True)\n",
    "    scores = torch.mm(z, prompt_dirs.t())\n",
    "    _, predic = torch.max(scores.softmax(dim=1).data, 1)\n",
    "    return predic\n",
    "\n",
    "            \n",
    "def _load_supervised_a_model():\n",
    "    \"\"\"Build the 2-class ResNet-18 attribute classifier and load res_net.pth onto `device`.\"\"\"\n",
    "    res_model = models.resnet18(pretrained=False)\n",
    "    num_classes = 2\n",
    "    res_model.fc = nn.Linear(res_model.fc.in_features, num_classes)\n",
    "    # map_location lets a checkpoint saved on another GPU (or on CPU) load onto `device`.\n",
    "    res_model.load_state_dict(torch.load('res_net.pth', map_location=device))\n",
    "    return res_model.to(device).eval()\n",
    "\n",
    "\n",
    "def supervised_inference_a(img):\n",
    "    \"\"\"Predict the spurious attribute of an image batch with a supervised ResNet-18.\n",
    "\n",
    "    The classifier is loaded from res_net.pth on the first call and cached on this\n",
    "    function, instead of being rebuilt and re-read from disk for every batch.\n",
    "    Re-run this cell (or delete ``supervised_inference_a._res_model``) to pick up\n",
    "    a new checkpoint.\n",
    "\n",
    "    Args:\n",
    "        img: image batch produced by ``get_transform_cub`` (resized to 336x336).\n",
    "    Returns:\n",
    "        Tensor of predicted attribute indices (one per image) on ``device``.\n",
    "    \"\"\"\n",
    "    res_model = getattr(supervised_inference_a, '_res_model', None)\n",
    "    if res_model is None:\n",
    "        res_model = _load_supervised_a_model()\n",
    "        supervised_inference_a._res_model = res_model\n",
    "    with torch.no_grad():\n",
    "        test_pred_ = res_model(img.to(device))\n",
    "    _, predic = torch.max(test_pred_, 1)\n",
    "    return predic\n",
    "            \n",
    "    \n",
    "def compute_scale(vlm, spu_v0, spu_v1):\n",
    "    \"\"\"Mean cosine similarity between training images and their spurious direction.\n",
    "\n",
    "    Every training image is assigned an attribute: the ground truth when the global\n",
    "    ``a`` is True, otherwise zero-shot CLIP inference (``partial_a`` False) or the\n",
    "    supervised ResNet (``partial_a`` True). Images with attribute 0 are compared to\n",
    "    the normalised ``spu_v0`` direction, images with attribute 1 to ``spu_v1``.\n",
    "\n",
    "    Returns:\n",
    "        (scale_0, scale_1): 0-dim tensors holding the two mean similarities\n",
    "        (both are also printed).\n",
    "    Raises:\n",
    "        ValueError: if no training image was assigned to one of the attributes,\n",
    "        which would otherwise yield a NaN scale.\n",
    "    \"\"\"\n",
    "    vlm = vlm.to(device)\n",
    "    vlm.eval()\n",
    "    scale_0 = []\n",
    "    scale_1 = []\n",
    "    spu0 = spu_v0 / spu_v0.norm(dim=1, keepdim=True)\n",
    "    spu1 = spu_v1 / spu_v1.norm(dim=1, keepdim=True)\n",
    "\n",
    "    for test_input, _, sensitive, img in tqdm(training_data_loader, desc=\"Computing Scale\"):\n",
    "        with torch.no_grad():\n",
    "            # put image into the image encoder\n",
    "            z = vlm.encode_image(test_input.to(device))\n",
    "            if not a:\n",
    "                if partial_a:\n",
    "                    sensitive = supervised_inference_a(img)\n",
    "                else:\n",
    "                    # Use the directions passed in, not the globals landbg/waterbg.\n",
    "                    sensitive = inference_a(vlm, spu_v0, spu_v1, z)\n",
    "\n",
    "            mask_0 = (sensitive == 0).to(device)\n",
    "            h = z[mask_0]\n",
    "            sim_0 = torch.mm(h / h.norm(dim=1, keepdim=True), spu0.t())\n",
    "            scale_0.extend(sim_0.detach().cpu().numpy())\n",
    "\n",
    "            mask_1 = (sensitive == 1).to(device)\n",
    "            g = z[mask_1]\n",
    "            sim_1 = torch.mm(g / g.norm(dim=1, keepdim=True), spu1.t())\n",
    "            scale_1.extend(sim_1.detach().cpu().numpy())\n",
    "\n",
    "    if not scale_0 or not scale_1:\n",
    "        raise ValueError('compute_scale: no training image was assigned to one of the attribute values')\n",
    "    scale_0 = np.array(scale_0)\n",
    "    scale_1 = np.array(scale_1)\n",
    "    print(np.mean(scale_0))\n",
    "    print(np.mean(scale_1))\n",
    "    return torch.tensor(np.mean(scale_0)), torch.tensor(np.mean(scale_1))\n",
    "\n",
    "\n",
    "\n",
    "def _binary_rates(y_true, y_pred):\n",
    "    \"\"\"Return (positive-prediction rate, TPR, FPR) for binary labels and predictions.\n",
    "\n",
    "    ``labels=[0, 1]`` keeps the confusion matrix 2x2 even when a group contains\n",
    "    only one class, instead of failing on the unpacking below.\n",
    "    \"\"\"\n",
    "    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()\n",
    "    positive_rate = (tp + fp) / (tn + fp + fn + tp)\n",
    "    tpr = tp / (tp + fn)\n",
    "    fpr = fp / (fp + tn)\n",
    "    return positive_rate, tpr, fpr\n",
    "\n",
    "\n",
    "def test_epoch(vlm, dataloader):\n",
    "    \"\"\"Zero-shot landbird/waterbird classification after removing the background direction.\n",
    "\n",
    "    1. ``compute_scale`` estimates, on the training split, the mean similarity of\n",
    "       images to the land (``landbg``) and water (``waterbg``) prompts.\n",
    "    2. Every L2-normalised test embedding with attribute s has\n",
    "       ``scale_s * unit(spurious prompt s)`` subtracted. The attribute is the\n",
    "       ground truth when the global ``a`` is True, otherwise it is inferred\n",
    "       zero-shot (``partial_a`` False) or by the supervised ResNet (``partial_a`` True).\n",
    "    3. The debiased embedding is matched against the two class prompts.\n",
    "\n",
    "    Prints accuracy per (y, s) group, the demographic-parity gap (DP), the\n",
    "    equal-opportunity gap (EOP), the equalized-odds gap (EoD) and overall accuracy.\n",
    "    Relies on the notebook globals ``landbg``, ``waterbg``, ``tokenizer``, ``a``\n",
    "    and ``partial_a``.\n",
    "    \"\"\"\n",
    "    # Estimate the scales with the model under test, not the global `model`.\n",
    "    scale_0, scale_1 = compute_scale(vlm, landbg, waterbg)\n",
    "    vlm = vlm.to(device)\n",
    "    vlm.eval()\n",
    "\n",
    "    land_dir = landbg / landbg.norm(dim=1, keepdim=True)\n",
    "    water_dir = waterbg / waterbg.norm(dim=1, keepdim=True)\n",
    "    texts_label = [\"a photo of a landbird.\", \"a photo of a waterbird.\"]\n",
    "    with torch.no_grad():\n",
    "        # The class prompts do not depend on the batch, so encode them once.\n",
    "        text_embeddings = vlm.encode_text(tokenizer(texts_label).to(device))\n",
    "        norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)\n",
    "\n",
    "    test_pred, test_gt, sense_gt = [], [], []\n",
    "    group_keys = [(y, s) for y in (0, 1) for s in (0, 1)]\n",
    "    correct = dict.fromkeys(group_keys, 0.0)\n",
    "    total = dict.fromkeys(group_keys, 0.0)\n",
    "\n",
    "    for test_input, test_target, sensitive_real, img in tqdm(dataloader, desc=\"Zero Shot Testing\"):\n",
    "        with torch.no_grad():\n",
    "            test_gt.extend(test_target.detach().cpu().numpy())\n",
    "            sense_gt.extend(sensitive_real.detach().cpu().numpy())\n",
    "\n",
    "            z = vlm.encode_image(test_input.to(device))\n",
    "            z = z / z.norm(dim=1, keepdim=True)\n",
    "\n",
    "            if a:\n",
    "                sensitive = sensitive_real\n",
    "            elif partial_a:\n",
    "                sensitive = supervised_inference_a(img)\n",
    "            else:\n",
    "                sensitive = inference_a(vlm, landbg, waterbg, z)\n",
    "\n",
    "            # Remove the estimated spurious component from every image embedding.\n",
    "            mask_0 = (sensitive == 0).to(device)\n",
    "            z[mask_0] -= scale_0 * land_dir\n",
    "            mask_1 = (sensitive == 1).to(device)\n",
    "            z[mask_1] -= scale_1 * water_dir\n",
    "\n",
    "            norm_img_embeddings = z / z.norm(dim=1, keepdim=True)\n",
    "            cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())\n",
    "            predic = cosine_similarity.argmax(dim=1).cpu()\n",
    "            test_pred.extend(predic.numpy())\n",
    "\n",
    "            # view(-1) rather than squeeze(): a final batch of one image must stay 1-D.\n",
    "            label = test_target.view(-1).cpu()\n",
    "            sens = sensitive_real.view(-1).cpu()\n",
    "            for y, s in group_keys:\n",
    "                group_mask = (label == y) & (sens == s)\n",
    "                correct[(y, s)] += (predic[group_mask] == label[group_mask]).float().sum().item()\n",
    "                total[(y, s)] += group_mask.float().sum().item()\n",
    "\n",
    "    for y, s in group_keys:\n",
    "        group_acc = correct[(y, s)] / total[(y, s)] if total[(y, s)] else float('nan')\n",
    "        print(f'Accuracy for y={y}, s={s}: {group_acc}')\n",
    "\n",
    "    test_gt = np.asarray(test_gt)\n",
    "    test_pred = np.asarray(test_pred)\n",
    "    in_group_0 = np.asarray(sense_gt) == 0\n",
    "    dp_0, tpr_0, fpr_0 = _binary_rates(test_gt[in_group_0], test_pred[in_group_0])\n",
    "    dp_1, tpr_1, fpr_1 = _binary_rates(test_gt[~in_group_0], test_pred[~in_group_0])\n",
    "    acc = accuracy_score(test_gt, test_pred)\n",
    "    print('DP', abs(dp_0 - dp_1))\n",
    "    print('EOP', abs(tpr_0 - tpr_1))\n",
    "    print('EoD', 0.5 * (abs(fpr_0 - fpr_1) + abs(tpr_0 - tpr_1)))\n",
    "    print('acc', acc)\n",
    "\n",
    "# Experiment switches read by compute_scale and test_epoch:\n",
    "#   a = True   -> use the ground-truth attribute (place) of every image\n",
    "#   a = False  -> infer it, zero-shot with CLIP (partial_a = False) or with the\n",
    "#                 supervised ResNet stored in res_net.pth (partial_a = True)\n",
    "a = True\n",
    "partial_a = False\n",
    "\n",
    "model = model.to(device)\n",
    "# Diagnostic: how accurately zero-shot CLIP recovers the background attribute.\n",
    "inference_a_test(model, landbg, waterbg)\n",
    "test_epoch(model, test_data_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:DLcourse]",
   "language": "python",
   "name": "conda-env-DLcourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
