{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Results Notebooks\n",
    "To run the training of DFC and DEC please use main.ipynb or main.py\n",
    "\n",
    "In this Notebook we focus on simply loading the models and demonstrating the results for the final DFC with its encoder or the DEC with encoder.\n",
    "* For this Notebook to work you need to change the ArgsDFC providing all the paths for the encoders and DFC \n",
    "* Two examples are given at the end of the notebook; Office31 data and mnist with ups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "PyTorch Version:  1.7.1\nTorchvision Version:  0.8.2\n"
    }
   ],
   "source": [
    "import argparse\n",
    "import numpy as np\n",
    "from sklearn.metrics import normalized_mutual_info_score\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "import plotly\n",
    "from dataloader import get_dataset\n",
    "from kmeans import get_cluster_centers\n",
    "from module import Encoder\n",
    "from adverserial import adv_loss\n",
    "from eval import predict, cluster_accuracy, balance\n",
    "from utils import set_seed, AverageMeter, target_distribution, aff, inv_lr_scheduler\n",
    "import os\n",
    "import wandb  # Used to log progress and plot graphs. \n",
    "from vae import DFC_VAE\n",
    "from vae import train as train_vae\n",
    "from dfc import train as train_dfc\n",
    "from dec import train as train_dec\n",
    "from dfc import DFC\n",
    "from resnet50_finetune import *\n",
    "import torchvision.models as models\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from pl_bolts.models.autoencoders import VAE\n",
    "import pandas as pd\n",
    "from ArgsDFC import args as arg_class\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable\n\u001b[34m\u001b[1mwandb\u001b[0m: Offline run mode, not syncing to the cloud.\n\u001b[34m\u001b[1mwandb\u001b[0m: W&B syncing is set to `offline` in this directory.  Run `wandb online` to enable cloud syncing.\n"
    }
   ],
   "source": [
    "\n",
    "#Set wandb loging offline, avoid the need for an account.\n",
    "wandbrun = wandb.init(project=\"offline-run\")\n",
    "\n",
    "os.environ[\"WANDB_MODE\"] = \"dryrun\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Arguments\n",
    "Change the Arguments in the at the end of the notebook with the main run code. Arguments are given in file Args_notebook.py."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_encoder(args, log_name, legacy_path, path, dataloader_list, device='cpu', encoder_type='vae'):\n",
    "    if encoder_type == 'vae':\n",
    "        print('Loading the variational autoencoder')\n",
    "        if legacy_path:\n",
    "            encoder = Encoder().to(device)\n",
    "            encoder.load_state_dict(torch.load(\n",
    "                legacy_path, map_location=device))\n",
    "        else:\n",
    "            if path:\n",
    "                model = DFC_VAE.load_from_checkpoint(path).to(device)\n",
    "            else:\n",
    "                model = train_vae(args, log_name,  dataloader_list, args.input_height,\n",
    "                                  is_digit_dataset=args.digital_dataset, device=device).to(device)\n",
    "            encoder = model.encoder\n",
    "    elif encoder_type == 'resnet50':  # Maybe fine tune resnet50 here\n",
    "        print('Loading the RESNET50 encoder')\n",
    "        if path:            \n",
    "            print('from pretrained file')\n",
    "            encoder = models.resnet50(pretrained=False)\n",
    "            encoder.load_state_dict(torch.load(path))\n",
    "        else:\n",
    "            encoder = models.resnet50(pretrained=True, progress=True)\n",
    "        set_parameter_requires_grad(encoder, req_grad=False)\n",
    "        encoder = encoder.to(device)\n",
    "    else:\n",
    "        raise NameError('The encoder_type variable has an unvalid value')\n",
    "    wandb.watch(encoder)\n",
    "    return encoder\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def subgroups_encoders(args,device):\n",
    "        print(\"Loading the golden standard group 0 encoder\")\n",
    "        encoder_group_0 = get_encoder(args, \"encoder_0\", args.encoder_0_legacy_path, args.encoder_0_path, [\n",
    "                                      dataloader_0], device=device, encoder_type=args.encoder_type)\n",
    "        \n",
    "        print(\"Loading the golden standard group 1 encoder\")\n",
    "        encoder_group_1 = get_encoder(args, \"encoder_1\", args.encoder_1_legacy_path, args.encoder_1_path, [\n",
    "                                      dataloader_1], device=device, encoder_type=args.encoder_type)\n",
    "       \n",
    "       \n",
    "        return encoder_group_0, encoder_group_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dec_groups(args, device):   \n",
    "        print(\"Load group 0 initial cluster definitions\")\n",
    "        cluster_centers_0 = get_cluster_centers(file_path=args.cluster_0_path, device=device)\n",
    "\n",
    "        print(\"Load group 1 initial cluster definitions\")\n",
    "        cluster_centers_1 = get_cluster_centers(file_path=args.cluster_number, device=device)\n",
    "\n",
    "        #Load DEC pretrained with the weight of the fairness losses are set to 0.\n",
    "        # making this a DEC instead of a DFC \n",
    "        print(\"Load golden standard group 0 DEC\")        \n",
    "        dfc_group_0 = DFC(cluster_number=args.cluster_number, hidden_dimension=args.dfc_hidden_dim).to(device)\n",
    "        dec.load_state_dict(torch.load(args.dfc_0_path, map_location=device))\n",
    "        print(\"Load golden standard group 1 DEC\")        \n",
    "        dfc_group_0 = DFC(cluster_number=args.cluster_number,hidden_dimension=args.dfc_hidden_dim).to(device)\n",
    "        dec.load_state_dict(torch.load(args.dfc_1_path, map_location=device))\n",
    "\n",
    "        return cluster_centers_0, cluster_centers_1, dfc_group_0, dfc_group_1 \n",
    "       "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dfc_module(args,device):\n",
    "    print(\"Load DFC\")\n",
    "    dfc = DFC(cluster_number=args.cluster_number,\n",
    "                  hidden_dimension=args.dfc_hidden_dim).to(device)\n",
    "    dfc.load_state_dict(torch.load(args.dfc_path, map_location=device))\n",
    "\n",
    "   \n",
    "    return dfc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_results(args, dataloader_list, encoder, dfc, device):\n",
    "    encoder.eval()\n",
    "    dfc.eval()\n",
    "    print(\"Evaluate model\")\n",
    "    predicted, labels = predict(dataloader_list, encoder, dfc, device=device,encoder_type = args.encoder_type)\n",
    "    predicted, labels = predicted.cpu().numpy(), labels.numpy()\n",
    "    print(\"Calculating cluster accuracy\")\n",
    "    _, accuracy = cluster_accuracy(predicted, labels, args.cluster_number)\n",
    "    nmi = normalized_mutual_info_score(labels, predicted, average_method=\"arithmetic\")\n",
    "    len_image_0 = len(dataloader_list[0])\n",
    "    print(\"Calculating balance\")\n",
    "    bal, en_0, en_1 = balance(predicted, len_image_0, k =args.cluster_number)\n",
    "    save_name = args.dataset \n",
    "    print(f\"{save_name} Train Accuracy:\", accuracy, f\"{save_name} Train NMI:\", nmi, f\"{save_name} Train Bal:\", bal,'\\n',\n",
    "            f\"{save_name} Train Entropy 0:\", en_0, '\\n',\n",
    "            f\"{save_name} Train Entropy 1:\", en_1)\n",
    "  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main_pipeline(args):    \n",
    "    set_seed(args.seed)    \n",
    "    os.makedirs(args.log_dir, exist_ok=True)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.set_device(args.gpu)\n",
    "    print(f\"Using {device}\")\n",
    "\n",
    "\n",
    "    dataloader_0, dataloader_1 = get_dataset[args.dataset](args)\n",
    "    dataloader_list = [dataloader_0, dataloader_1]\n",
    "    print(\"Loading Encoder type:\",args.encoder_type)\n",
    "    encoder = get_encoder(args, \"encoder\", args.encoder_legacy_path, args.encoder_path, dataloader_list, device=device, encoder_type=args.encoder_type)\n",
    "    print(\"Running method\", args.method)\n",
    "    if args.method == 'dfc':\n",
    "        # encoder_group_0, encoder_group_1, = subgroups_encoders(args,device)\n",
    "        # cluster_centers_0, cluster_centers_1, dfc_group_0, dfc_group_1 = get_dec_groups(args,device)\n",
    "        dfc = get_dfc_module(args,device)\n",
    "        eval_results(args,dataloader_list, encoder, dfc, device=device)\n",
    "    if args.method == 'dec':\n",
    "        print(\"Load cluster centers for final DEC\")\n",
    "        cluster_centers = get_cluster_centers(args, encoder, args.cluster_number, [dataloader_0, dataloader_1],\n",
    "                                              args.cluster_path, device=device, save_name=\"clusters_dec\")\n",
    "\n",
    "        print(\"Train final DEC\")\n",
    "        dec = get_dec(args,\n",
    "                      encoder, \"DEC\", device=device, centers=cluster_centers)\n",
    "    del encoder\n",
    "    del dfc\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DFC Run\n",
    "In this section we have the cells with all the steps to train a DFC.   \n",
    "We first load the encoders used for for the two dec, if no path is selected then we train new ones.\n",
    "Next we load or train with K-means the cluster centers "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# # Example run.\n",
    "# mnist_ups_args = arg_class()\n",
    "# mnist_ups_args.set_mnist_ups()\n",
    "# main_pipeline(mnist_ups_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Using cuda\nLoading Encoder type: resnet50\nLoading the RESNET50 encoder\nfrom pretrained file\nRunning method dfc\nLoad DFC\nEvaluate model\nCalculating cluster accuracy\nCalculating balance\noffice_31 Train Accuracy: 0.6763392857142857 office_31 Train NMI: 0.716612082292925 office_31 Train Bal: 6.369426751592358e-08 \n office_31 Train Entropy 0: 2.6890848577221345 \n office_31 Train Entropy 1: 3.3934453185999707\n"
    }
   ],
   "source": [
    "office_args = arg_class()\n",
    "office_args.set_office31_load_models()\n",
    "main_pipeline(office_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Using cuda\nLoading Encoder type: resnet50\nLoading the RESNET50 encoder\nfrom pretrained file\nRunning method dfc\nLoad DFC\nEvaluate model\nCalculating cluster accuracy\nCalculating balance\nmtfl Train Accuracy: 0.7485977564102564 mtfl Train NMI: 0.21403839260800497 mtfl Train Bal: 0.0011003117549972493 \n mtfl Train Entropy 0:0.6889892386783225 \n mtfl Train Entropy 1: 0.6887447313219119\n"
    }
   ],
   "source": [
    "mtfl_args = arg_class()\n",
    "mtfl_args.set_mtfl_load_models()\n",
    "main_pipeline(mtfl_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "<IPython.core.display.HTML object>",
      "text/html": "<br/>Waiting for W&B process to finish, PID 23385<br/>Program ended successfully."
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "<IPython.core.display.HTML object>",
      "text/html": "Find user logs for this run at: <code>/mnt/f/Documents/Amsterdam/UvA/Block 3 -  Fairness, Accountability, Confidentiality and Transparency in AI/FACT-2021/wandb/offline-run-20210129_235358-336tc8dx/logs/debug.log</code>"
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "<IPython.core.display.HTML object>",
      "text/html": "Find internal logs for this run at: <code>/mnt/f/Documents/Amsterdam/UvA/Block 3 -  Fairness, Accountability, Confidentiality and Transparency in AI/FACT-2021/wandb/offline-run-20210129_235358-336tc8dx/logs/debug-internal.log</code>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "\u001b[34m\u001b[1mwandb\u001b[0m: You can sync this run to the cloud by running:\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mwandb sync /mnt/f/Documents/Amsterdam/UvA/Block 3 -  Fairness, Accountability, Confidentiality and Transparency in AI/FACT-2021/wandb/offline-run-20210129_235358-336tc8dx\u001b[0m\n"
    }
   ],
   "source": [
    "#Finish loggin in wandb\n",
    "wandbrun.finish()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": 3
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python_defaultSpec_1611924171085",
   "display_name": "Python 3.8.3 64-bit ('base': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}