{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import numpy as np\n",
    "from sklearn.metrics import normalized_mutual_info_score\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from dataloader import get_dataset\n",
    "from kmeans import get_cluster_centers\n",
    "from module import Encoder\n",
    "from adverserial import adv_loss\n",
    "from eval import predict, cluster_accuracy, balance\n",
    "from utils import set_seed, AverageMeter, target_distribution, aff, inv_lr_scheduler\n",
    "import os\n",
    "import wandb  # Used to log progress and plot graphs. \n",
    "from vae import DFC_VAE\n",
    "from vae import train as train_vae\n",
    "from dfc import train as train_dfc\n",
    "from dec import train as train_dec\n",
    "from dfc import DFC\n",
    "from resnet50_finetune import *\n",
    "import torchvision.models as models\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from pl_bolts.models.autoencoders import VAE\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from Args_notebook.py import args"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_encoder(args, log_name, legacy_path, path, dataloader_list, device='cpu', encoder_type='vae'):\n",
    "    if encoder_type == 'vae':\n",
    "        print('Loading the variational autoencoder')\n",
    "        if legacy_path:\n",
    "            encoder = Encoder().to(device)\n",
    "            encoder.load_state_dict(torch.load(\n",
    "                legacy_path, map_location=device))\n",
    "        else:\n",
    "            if path:\n",
    "                model = DFC_VAE.load_from_checkpoint(path).to(device)\n",
    "            else:\n",
    "                model = train_vae(args, log_name,  dataloader_list, args.input_height,\n",
    "                                  is_digit_dataset=args.digital_dataset, device=device).to(device)\n",
    "            encoder = model.encoder\n",
    "    elif encoder_type == 'resnet50':  # Maybe fine tune resnet50 here\n",
    "        print('Loading the RESNET50 encoder')\n",
    "        encoder = models.resnet50(pretrained=True, progress=True)\n",
    "        \n",
    "        set_parameter_requires_grad(encoder, req_grad=False)\n",
    "        encoder.fc = nn.Linear(1000, args.dfc_hidden_dim) #TODO: Reshape and finetune resnet50        \n",
    "        # get_update_param(encoder)\n",
    "        encoder = encoder.to(device)\n",
    "        # encoder, val_acc_history = train_last_layer_resnet50( #train for the 31 classes\n",
    "            # encoder, dataloader_list, log_name=log_name, device=device, args=args, num_classes=args.dfc_hidden_dim)\n",
    "\n",
    "    else:\n",
    "        raise NameError('The encoder_type variable has an unvalid value')\n",
    "    wandb.watch(encoder)\n",
    "    return encoder\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_dec(args, path, dataloader_list, encoder, save_name, device='cpu', centers=None):\n",
    "    if path:\n",
    "        dec = DFC(cluster_number=args.cluster_number,\n",
    "                  hidden_dimension=args.dfc_hidden_dim).to(device)\n",
    "        dec.load_state_dict(torch.load(path, map_location=device))\n",
    "    else:\n",
    "        dec = train_dec(args, dataloader_list, encoder, device,\n",
    "                        centers=centers,  save_name=save_name)\n",
    "    return dec\n",
    "\n",
    "\n",
    "def get_dfc(args, path, dataloader_list, encoder, save_name, encoder_group_0=None, encoder_group_1=None, dfc_group_0=None, dfc_group_1=None, device='cpu', centers=None, get_loss_trade_off=lambda step: (10, 10, 10)):\n",
    "    if path:\n",
    "        dfc = DFC(cluster_number=args.cluster_number,\n",
    "                  hidden_dimension=args.dfc_hidden_dim).to(device)\n",
    "        dfc.load_state_dict(torch.load(path, map_location=device))\n",
    "    else:\n",
    "        dfc = train_dfc(args, dataloader_list, encoder, encoder_group_0, encoder_group_1, dfc_group_0, dfc_group_1,\n",
    "                        device, centers=centers, get_loss_trade_off=get_loss_trade_off, save_name=save_name)\n",
    "    return dfc\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable\n\u001B[34m\u001B[1mwandb\u001B[0m: Offline run mode, not syncing to the cloud.\n\u001B[34m\u001B[1mwandb\u001B[0m: W&B syncing is set to `offline` in this directory.  Run `wandb online` to enable cloud syncing.\nUsing cuda\n"
    }
   ],
   "source": [
    "set_seed(args.seed)\n",
    "os.makedirs(args.log_dir, exist_ok=True)\n",
    "#Set wandb loging offline, avoid the need for an account.\n",
    "os.environ[\"WANDB_MODE\"] = \"dryrun\"\n",
    "wandb.init(project=\"offline-run\")\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.set_device(args.gpu)\n",
    "print(f\"Using {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "Global seed set to 2019\nLoading Encoder\nLoading the variational autoencoder\n"
    },
    {
     "output_type": "error",
     "ename": "ValueError",
     "evalue": "You must call `wandb.init` before calling watch",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mValueError\u001B[0m                                Traceback (most recent call last)",
      "\u001B[0;32m<ipython-input-44-f348f44b13e7>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[1;32m      1\u001B[0m \u001B[0mdataloader_0\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mdataloader_1\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mget_dataset\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdataset\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m      2\u001B[0m \u001B[0mprint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m\"Loading Encoder\"\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 3\u001B[0;31m encoder = get_encoder(args, \"encoder\", args.encoder_legacy_path, args.encoder_path, [\n\u001B[0m\u001B[1;32m      4\u001B[0m                         dataloader_0, dataloader_1], device=device, encoder_type=args.encoder_type)\n",
      "\u001B[0;32m<ipython-input-39-4f8adcd87c74>\u001B[0m in \u001B[0;36mget_encoder\u001B[0;34m(args, log_name, legacy_path, path, dataloader_list, device, encoder_type)\u001B[0m\n\u001B[1;32m     10\u001B[0m                 \u001B[0mmodel\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mDFC_VAE\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mload_from_checkpoint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mpath\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mto\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdevice\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     11\u001B[0m             \u001B[0;32melse\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 12\u001B[0;31m                 model = train_vae(args, log_name,  dataloader_list, args.input_height,\n\u001B[0m\u001B[1;32m     13\u001B[0m                                   is_digit_dataset=args.digital_dataset, device=device).to(device)\n\u001B[1;32m     14\u001B[0m             \u001B[0mencoder\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mencoder\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m/mnt/f/Documents/Amsterdam/UvA/Block 3 -  Fairness, Accountability, Confidentiality and Transparency in AI/FACT-2021/vae.py\u001B[0m in \u001B[0;36mtrain\u001B[0;34m(args, log_name, dataloader_list, input_height, is_digit_dataset, device, encoder_pretrain_path)\u001B[0m\n\u001B[1;32m    253\u001B[0m         \u001B[0mmodel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mencoder\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mload_state_dict\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mload\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mencoder_pretrain_path\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mmap_location\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mdevice\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    254\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 255\u001B[0;31m     \u001B[0mwandb\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mwatch\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mmodel\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    256\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    257\u001B[0m     \u001B[0mcheckpoint_callback\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mModelCheckpoint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mmonitor\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m'train_loss'\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mfilepath\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlog_dir\u001B[0m \u001B[0;34m+\u001B[0m \u001B[0mlog_name\u001B[0m \u001B[0;34m+\u001B[0m \u001B[0;34m\"/\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mverbose\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/anaconda3/lib/python3.8/site-packages/wandb/sdk/wandb_watch.py\u001B[0m in \u001B[0;36mwatch\u001B[0;34m(models, criterion, log, log_freq, idx)\u001B[0m\n\u001B[1;32m     43\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     44\u001B[0m     \u001B[0;32mif\u001B[0m \u001B[0mwandb\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrun\u001B[0m \u001B[0;32mis\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 45\u001B[0;31m         \u001B[0;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m\"You must call `wandb.init` before calling watch\"\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m     46\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     47\u001B[0m     \u001B[0min_jupyter\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0m_get_python_type\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;34m!=\u001B[0m \u001B[0;34m\"python\"\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;31mValueError\u001B[0m: You must call `wandb.init` before calling watch"
     ]
    }
   ],
   "source": [
    "dataloader_0, dataloader_1 = get_dataset[args.dataset](args)\n",
    "print(\"Loading Encoder\")\n",
    "encoder = get_encoder(args, \"encoder\", args.encoder_legacy_path, args.encoder_path, [\n",
    "                        dataloader_0, dataloader_1], device=device, encoder_type=args.encoder_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DFC Run\n",
    "In this section we have the cells with all the steps to train a DFC.   \n",
    "We first load the encoders used for the two dec, if no path is selected then we train new ones.\n",
    "Next we load or train with K-means the cluster centers "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if args.method == 'dfc':\n",
    "        print(\"Start pretraining individual golden standard DECs\")\n",
    "        print(\"loading the golden standard group 0 encoder\")\n",
    "        encoder_group_0 = get_encoder(args, \"encoder_0\", args.encoder_0_legacy_path, args.encoder_0_path, [\n",
    "                                      dataloader_0], device=device, encoder_type=args.encoder_type)\n",
    "        \n",
    "        print(\"loading the golden standard group 1 encoder\")\n",
    "        encoder_group_1 = get_encoder(args, \"encoder_1\", args.encoder_1_legacy_path, args.encoder_1_path, [\n",
    "                                      dataloader_1], device=device, encoder_type=args.encoder_type)\n",
    "       \n",
    "        cluster_centers_0 = None\n",
    "        cluster_centers_1 = None\n",
    "        if not args.dfc_0_path:\n",
    "            # We don't have pretrained decs for both groups -> we have to generate cluster centers\n",
    "            print(\"Load group 0 initial cluster definitions\")\n",
    "            cluster_centers_0 = get_cluster_centers(args, encoder_group_0, args.cluster_number, [\n",
    "                                                    dataloader_0], args.cluster_0_path, device=device, save_name=\"clusters_0\")\n",
    "\n",
    "            print(\"Load group 1 initial cluster definitions\")\n",
    "            cluster_centers_1 = get_cluster_centers(args, encoder_group_1, args.cluster_number, [dataloader_1],\n",
    "                                                    args.cluster_1_path, device=device, save_name=\"clusters_1\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if args.method == 'dfc':\n",
    "        print(\"Train golden standard group 0 DEC\")\n",
    "        # note that the weight of the fairness losses are set to 0, making this a DEC instead of a DFC\n",
    "        dfc_group_0 = get_dec(args, args.dfc_0_path, [\n",
    "                              dataloader_0], encoder_group_0, \"DEC_G0\", device=device, centers=cluster_centers_0)\n",
    "\n",
    "        print(\"Train golden standard group 1 DEC\")\n",
    "        # note that the weight of the fairness losses are set to 0, making this a DEC instead of a DFC\n",
    "        dfc_group_1 = get_dec(args, args.dfc_1_path, [\n",
    "                              dataloader_1], encoder_group_1, \"DEC_G1\", device=device, centers=cluster_centers_1)\n",
    "\n",
    "        print(\"Load cluster centers for final DFC\")\n",
    "        cluster_centers = get_cluster_centers(args, encoder, args.cluster_number, [dataloader_0, dataloader_1],\n",
    "                                              args.cluster_path, device=device, save_name=\"clusters_dfc\")\n",
    "\n",
    "        print(\"Train final DFC\")\n",
    "\n",
    "        loss_tradeoff = lambda _: (1, 1, 1)\n",
    "        if args.dfc_tradeoff == 'no_fair':\n",
    "            loss_tradeoff = lambda _: (0, 1, 1)\n",
    "        elif args.dfc_tradeoff == 'no_struct':\n",
    "            loss_tradeoff = lambda _: (1, 0, 1)\n",
    "\n",
    "        dfc = get_dfc(args, args.dfc_path, [dataloader_0, dataloader_1], encoder, \"DFC\", encoder_group_0=encoder_group_0,\n",
    "                      encoder_group_1=encoder_group_1, dfc_group_0=dfc_group_0, dfc_group_1=dfc_group_1, device=device,\n",
    "                      centers=cluster_centers, get_loss_trade_off=loss_tradeoff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if args.method == 'dec':\n",
    "        print(\"Load cluster centers for final DEC\")\n",
    "        cluster_centers = get_cluster_centers(args, encoder, args.cluster_number, [dataloader_0, dataloader_1],\n",
    "                                              args.cluster_path, device=device, save_name=\"clusters_dec\")\n",
    "\n",
    "        print(\"Train final DEC\")\n",
    "        dec = get_dec(args, None, [dataloader_0, dataloader_1],\n",
    "                      encoder, \"DEC\", device=device, centers=cluster_centers)\n",
    "   "
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": 3
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python_defaultSpec_1611828998058",
   "display_name": "Python 3.8.3 64-bit ('base': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}