{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and general settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import sys\n",
    "from torch.utils.data import DataLoader\n",
    "from helpers import process_sys, parameter_check, set_seeds, set_cuda_randomness, set_optim, make_info, lr_update, run_check\n",
    "from init import init_model, init_check\n",
    "from data import init_dataset, split_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set different model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration, set by hand here instead of being read from the command line\n",
    "MODEL = \"VGG11\"  # Model architecture: ResNet18, DenseNet121, VGG11\n",
    "DSET = \"ImageNet\"  # Dataset: ImageNet, CIFAR-100 (all require setting directory correctly in data.init_dataset)\n",
    "INIT = \"same\"  # Model initialisation: \"same\" or \"different\"\n",
    "OPTIM = \"SGD\"  # Optimizer: \"Adam\" or \"SGD\"\n",
    "DATA = \"same\"  # Training data: \"same\" or \"different\" (\"different\" passes the training set through split_dataset)\n",
    "ORDER = \"same\"  # Order of training data: \"same\" or \"different\" (\"different\" offsets the per-epoch seed by NUM*100)\n",
    "LR = 0.1  # Starting learning rate\n",
    "BATCHES = \"different\"  # Batch handling: \"same\" or \"different\" (\"different\" halves every loaded batch and does two gradient updates)\n",
    "EPOCHS = 30  # Number of training epochs\n",
    "CUDA = 0  # CUDA randomness: 0 is deterministic and 1 is random\n",
    "VERBOSE = 1  # Set to 1 to print the per-epoch training output\n",
    "NUM = 1  # Number of models for this condition\n",
    "CONDITION = \"VGG11_Different_batchsize\"  # Condition name under which results are saved\n",
    "BATCH_SIZE = 256  # Batch size handed to both data loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process input variables: not done here, since we set them manually above!\n",
    "# MODEL, DSET, INIT, OPTIM, DATA, ORDER, LR, BATCHES, EPOCHS, CUDA, NUM, VERBOSE, CONDITION, BATCH_SIZE = process_sys(sys.argv)\n",
    "\n",
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = torch.device('cuda')\n",
    "else:\n",
    "    DEVICE = torch.device('cpu')\n",
    "\n",
    "# Leading arguments that parameter_check and make_info both take, in this order\n",
    "shared_args = (MODEL, DSET, INIT, OPTIM, DATA, ORDER, LR, BATCHES, EPOCHS, CUDA, NUM)\n",
    "\n",
    "# Print parameters for manual check\n",
    "parameter_check(*shared_args, VERBOSE, DEVICE, CONDITION)\n",
    "\n",
    "# Build the output filename prefix and the info record that is saved with the results\n",
    "filename, info = make_info(*shared_args, CONDITION)\n",
    "\n",
    "# Check whether this condition has already been run\n",
    "# (presumably run_check stops the run in that case - see helpers.run_check)\n",
    "run_check(filename, EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds: INIT_SEED goes to the model initialisation, GLOBAL_SEED to set_seeds and split_dataset\n",
    "INIT_SEED = 1312\n",
    "GLOBAL_SEED = 1312\n",
    "\n",
    "# Build the model, then let init_check verify that the initialisation strategy was applied\n",
    "model = init_model(INIT, INIT_SEED, MODEL, NUM, DEVICE)\n",
    "init_check(model, INIT, INIT_SEED, MODEL, DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the random number generators and make CUDA deterministic or non-deterministic\n",
    "set_seeds(GLOBAL_SEED)\n",
    "set_cuda_randomness(CUDA)\n",
    "\n",
    "# Fixed sizes used in this cell\n",
    "FRACTAL_SUBSET_SIZE = 50000  # Samples in each of the FRACTAL validation and training subsets\n",
    "NUM_WORKERS = 30  # Worker processes per data loader\n",
    "\n",
    "# Load the training data\n",
    "train_set = init_dataset(DSET, NUM, DATA, MODEL, CONDITION, train=True)\n",
    "\n",
    "if DSET != \"FRACTAL\":\n",
    "    # All other datasets: load the validation data with train=False\n",
    "    val_set = init_dataset(DSET, NUM, DATA, MODEL, CONDITION, train=False)\n",
    "else:\n",
    "    # FRACTAL: split off a random validation subset first ...\n",
    "    n_rest = len(train_set) - FRACTAL_SUBSET_SIZE\n",
    "    train_set, val_set = torch.utils.data.random_split(train_set, [n_rest, FRACTAL_SUBSET_SIZE])\n",
    "    # ... then keep only a random subset of the remainder for training\n",
    "    n_discarded = len(train_set) - FRACTAL_SUBSET_SIZE\n",
    "    train_set, _discarded = torch.utils.data.random_split(train_set, [FRACTAL_SUBSET_SIZE, n_discarded])\n",
    "\n",
    "# For different data condition, split dataset into two parts\n",
    "if DATA == \"different\":\n",
    "    train_set = split_dataset(train_set, NUM, GLOBAL_SEED)\n",
    "\n",
    "# Both loaders share batch size and worker count; only the training data is shuffled\n",
    "loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)\n",
    "train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)\n",
    "val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)  # Do not shuffle to keep same order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss function: nn.BCELoss for the \"ARN\" dataset, nn.CrossEntropyLoss for every other dataset\n",
    "loss_class = nn.BCELoss if DSET == \"ARN\" else nn.CrossEntropyLoss\n",
    "criterion = loss_class().to(DEVICE)\n",
    "\n",
    "# Optimizer, built by set_optim from the optimizer name, the model and the starting learning rate\n",
    "optimizer = set_optim(OPTIM, model, LR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-allocate result tensors. The validation tensors have one extra slot because\n",
    "# the untrained model is evaluated once (epoch 0) before training starts.\n",
    "loss_train = torch.zeros(EPOCHS).to(DEVICE)\n",
    "acc_train = torch.zeros(EPOCHS).to(DEVICE)\n",
    "loss_val = torch.zeros(EPOCHS+1).to(DEVICE)\n",
    "acc_val = torch.zeros(EPOCHS+1).to(DEVICE)\n",
    "\n",
    "# Flags derived once from the configuration\n",
    "HALVE_BATCHES = BATCHES == \"different\"\n",
    "BINARY_TARGETS = DSET == \"ARN\"\n",
    "\n",
    "\n",
    "def _batch_slices(n_samples, halve):\n",
    "    \"\"\"Return the (start, end) index pairs to train on for one loaded batch.\n",
    "\n",
    "    If halve is False the whole batch is a single slice. If halve is True the\n",
    "    batch is cut into two halves, each of which gets its own optimizer step.\n",
    "    An empty half (which occurs for a batch holding a single sample) is dropped,\n",
    "    so the loss is never computed on zero samples.\n",
    "    \"\"\"\n",
    "    if not halve:\n",
    "        return [(0, n_samples)]\n",
    "    half = n_samples / 2\n",
    "    halves = [(int(ind * half), int((ind + 1) * half)) for ind in range(2)]\n",
    "    return [(start, end) for start, end in halves if end > start]\n",
    "\n",
    "\n",
    "def _num_correct(output, targets, binary):\n",
    "    \"\"\"Count the predictions in output that equal targets, as a Python int.\n",
    "\n",
    "    With binary=True the outputs are rounded; otherwise the prediction is the\n",
    "    argmax over dim 1. Returning a Python int keeps the later division a true\n",
    "    division instead of a division of an integer tensor.\n",
    "    \"\"\"\n",
    "    predicted = torch.round(output) if binary else torch.argmax(output, dim=1)\n",
    "    return torch.sum(torch.eq(targets, predicted)).item()\n",
    "\n",
    "\n",
    "# Start training model\n",
    "for epoch in range(EPOCHS + 1):\n",
    "\n",
    "    # Epoch 0 only evaluates the untrained model; training happens from epoch 1 on\n",
    "    if epoch != 0:\n",
    "\n",
    "        model.train()\n",
    "\n",
    "        # Change seed before data is randomly drawn from train loader\n",
    "        if ORDER == \"different\":\n",
    "            set_seeds(GLOBAL_SEED + (NUM*100) + epoch)\n",
    "        else:\n",
    "            set_seeds(GLOBAL_SEED + epoch)\n",
    "\n",
    "        for images, targets in train_loader:\n",
    "\n",
    "            # Load images and targets onto the device\n",
    "            images = images.to(DEVICE)\n",
    "            targets = targets.to(DEVICE)\n",
    "\n",
    "            # One optimizer step per slice: the full batch, or its two halves\n",
    "            # in the different batch size condition\n",
    "            for start, end in _batch_slices(targets.shape[0], HALVE_BATCHES):\n",
    "                slice_targets = targets[start:end]\n",
    "\n",
    "                # Get output and loss. loss_train accumulates the sum of the\n",
    "                # per-step losses; it is not averaged over the number of steps.\n",
    "                output = model(images[start:end])\n",
    "                loss = criterion(output, slice_targets)\n",
    "                loss_train[epoch - 1] += loss.item()\n",
    "\n",
    "                # Correct count divided by dataset size, yields mean accuracy after last batch\n",
    "                acc_train[epoch - 1] += _num_correct(output, slice_targets, BINARY_TARGETS) / len(train_loader.dataset)\n",
    "\n",
    "                # Compute gradient and do optimizer step\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "    # Evaluate on the validation set with gradients disabled\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "\n",
    "        # Re-set seed to global seed\n",
    "        set_seeds(GLOBAL_SEED)\n",
    "\n",
    "        # Per-batch outputs and targets, concatenated once after the loop\n",
    "        # (avoids re-copying the growing tensor for every batch)\n",
    "        batch_outputs = []\n",
    "        batch_targets = []\n",
    "\n",
    "        for images, targets in val_loader:\n",
    "\n",
    "            # Load images and targets onto the device\n",
    "            images = images.to(DEVICE)\n",
    "            targets = targets.to(DEVICE)\n",
    "\n",
    "            # Get output and loss\n",
    "            output = model(images)\n",
    "            loss = criterion(output, targets)\n",
    "            loss_val[epoch] += loss.item()\n",
    "\n",
    "            # Compute accuracy\n",
    "            acc_val[epoch] += _num_correct(output, targets, BINARY_TARGETS) / len(val_loader.dataset)\n",
    "\n",
    "            batch_outputs.append(output)\n",
    "            batch_targets.append(targets)\n",
    "\n",
    "        # Outputs and targets of the whole validation set for this epoch\n",
    "        epoch_output = torch.cat(batch_outputs)\n",
    "        epoch_targets = torch.cat(batch_targets)\n",
    "\n",
    "    # Manual learning rate scheduler\n",
    "    lrTmp = lr_update(epoch, LR, EPOCHS)\n",
    "    for param_group in optimizer.param_groups:\n",
    "        param_group['lr'] = lrTmp\n",
    "    print(f\"Current learning rate: {lrTmp}\")\n",
    "\n",
    "    # Print output if wanted\n",
    "    if VERBOSE == 1:\n",
    "        # Nothing has been trained at epoch 0: report zeros explicitly instead of\n",
    "        # letting index -1 wrap around to the (still empty) slot of the last epoch\n",
    "        train_loss = loss_train[epoch-1] if epoch > 0 else 0.0\n",
    "        train_acc = acc_train[epoch-1] if epoch > 0 else 0.0\n",
    "        print('Epoch {}/{}, Train loss: {:.4f}, Train accuracy: {:.4f}, Test loss: {:.4f}, Test accuracy:  {:.4f}'\n",
    "              .format(epoch, EPOCHS,\n",
    "                      train_loss, train_acc,\n",
    "                      loss_val[epoch], acc_val[epoch]))\n",
    "\n",
    "    # Save output after every epoch: run info, predicted classes and targets\n",
    "    result = [info, torch.argmax(epoch_output, dim=1), epoch_targets]\n",
    "\n",
    "    # Use torch save to write results and model into file\n",
    "    # NOTE(review): torch.save(model, ...) pickles the whole model object, so loading it\n",
    "    # needs the same model class to be importable; kept as is so existing loaders still work\n",
    "    torch.save(result, filename + \"RESULTS_EP{}\".format(epoch) + \".txt\")\n",
    "    torch.save(model, filename + \"MODEL_EP{}\".format(epoch))\n",
    "\n",
    "# Save accuracy and loss over all epochs\n",
    "torch.save(loss_train, filename + \"TRAIN_LOSS.txt\")\n",
    "torch.save(loss_val, filename + \"VAL_LOSS.txt\")\n",
    "torch.save(acc_train, filename + \"TRAIN_ACC.txt\")\n",
    "torch.save(acc_val, filename + \"VAL_ACC.txt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
