{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Rotating MNIST dataset\n",
    "Domain: anti-clockwise rotating angle (theta)\n",
    "Source domains: theta in (0, 45 degree)\n",
    "    Amount: 60000\n",
    "Target domains: theta in (45, 360 degree)\n",
    "    Amount: 60000 * 7\n",
    "\"\"\"\n",
    "\n",
    "\"\"\"\n",
    "Download data if needed\n",
    "\"\"\"\n",
    "from model import download\n",
    "download()\n",
    "\n",
    "\"\"\"\n",
    "Visualize the data\n",
    "\"\"\"\n",
    "import matplotlib.pyplot as plt\n",
    "from model import RotateMNIST\n",
    "\n",
    "for i in range(8):\n",
    "    dataset = RotateMNIST(rotate_angle=(i*45,i*45+45))\n",
    "    if i == 0:\n",
    "        dname = 'Source'\n",
    "    else:\n",
    "        dname = f'Sub Target #{i}'\n",
    "    print(dname)\n",
    "    fig, ax = plt.subplots(1, 10, figsize=(18,1.5))\n",
    "    for j in range(10):\n",
    "        img, label, angle, _ = dataset[j]\n",
    "        angle = angle[0] * 360\n",
    "        ax[j].imshow(img[0])\n",
    "        ax[j].set_title(f'Label: {label}\\nRot: {angle:.0f}')\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Configurations\n",
    "\"\"\"\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from easydict import EasyDict\n",
    "from model import set_default_args, print_args\n",
    "from model import SO, ADDA, DANN, CUA, CIDA, PCIDA\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "opt = EasyDict()\n",
    "# choose the method from [\"CIDA\", \"PCIDA\", \"SO\", \"ADDA\", \"DANN\" \"CUA\"]\n",
    "opt.model = \"SO\"\n",
    "# choose run on which device [\"cuda\", \"cpu\"]\n",
    "opt.device = \"cuda\"\n",
    "set_default_args(opt)\n",
    "print_args(opt)\n",
    "# build dataset and data loader\n",
    "dataset = RotateMNIST(rotate_angle=(0, 360))\n",
    "train_dataloader = DataLoader(\n",
    "    dataset=dataset,\n",
    "    shuffle=True,\n",
    "    batch_size=opt.batch_size,\n",
    "    num_workers=4,\n",
    ")\n",
    "test_dataloader = DataLoader(\n",
    "    dataset=dataset,\n",
    "    shuffle=True,\n",
    "    batch_size=opt.batch_size,\n",
    "    num_workers=4,\n",
    ")\n",
    "# build model\n",
    "model_pool = {\n",
    "    'SO': SO,\n",
    "    'CIDA': CIDA,\n",
    "    'PCIDA': PCIDA,\n",
    "    'ADDA': ADDA,\n",
    "    'DANN': DANN,\n",
    "    'CUA': CUA,\n",
    "}\n",
    "model = model_pool[opt.model](opt)\n",
    "model = model.to(opt.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Training the model from the scratch\n",
    "\"\"\"\n",
    "best_acc_target = 0\n",
    "if not opt.continual_da:\n",
    "    # Single Step Domain Adaptation\n",
    "    for epoch in range(opt.num_epoch):\n",
    "        model.learn(epoch, train_dataloader)\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            acc_target = model.eval_mnist(test_dataloader)\n",
    "            if acc_target > best_acc_target:\n",
    "                print('Best acc target. saved.')\n",
    "                model.save()\n",
    "else:\n",
    "    # continual DA training\n",
    "    continual_dataset = ContinousRotateMNIST()\n",
    "\n",
    "    print('===> pretrain the classifer')\n",
    "    model.prepare_trainer(init=True)\n",
    "    for epoch in range(opt.num_epoch_pre):\n",
    "        model.learn(epoch, train_dataloader, init=True)\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            model.eval_mnist(test_dataloader)\n",
    "    print('===> start continual DA')\n",
    "    model.prepare_trainer(init=False)\n",
    "    for phase in range(opt.num_da_step):\n",
    "        continual_dataset.set_phase(phase)\n",
    "        print(f'Phase {phase}/{opt.num_da_step}')\n",
    "        print(f'#source {len(continual_dataset.ds_source)} #target {len(continual_dataset.ds_target[phase])} #replay {len(continual_dataset.ds_replay)}')\n",
    "        continual_dataloader = DataLoader(\n",
    "            dataset=continual_dataset,\n",
    "            shuffle=True,\n",
    "            batch_size=opt.batch_size,\n",
    "            num_workers=4,\n",
    "        )\n",
    "        for epoch in range(opt.num_epoch_sub):\n",
    "            model.learn(epoch, continual_dataloader, init=False)\n",
    "            if (epoch + 1) % 10 == 0:\n",
    "                model.eval_mnist(test_dataloader)\n",
    "\n",
    "        target_dataloader = DataLoader(\n",
    "            dataset=continual_dataset.ds_target[phase],\n",
    "            shuffle=True,\n",
    "            batch_size=opt.batch_size,\n",
    "            num_workers=4,\n",
    "        )\n",
    "        acc_target = model.eval_mnist(test_dataloader)\n",
    "        if acc_target > best_acc_target:\n",
    "            print('Best acc target. saved.')\n",
    "            model.save()\n",
    "        data_tuple = model.gen_data_tuple(target_dataloader)\n",
    "        continual_dataset.ds_replay.update(data_tuple)  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Load the pretrained model\n",
    "\"\"\"\n",
    "model.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Print the model result\n",
    "\"\"\"\n",
    "model.gen_result_table(test_dataloader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
