{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/workspace/repositories/MoralVisionModels')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.7.1+cu110\n",
      "Torchvision Version:  0.8.2+cu110\n"
     ]
    }
   ],
   "source": [
    "from src.tune_resnetCLIP import setup_dataset, setup_model, train\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model, transform_train, transform_test = setup_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2941/2941 [00:07<00:00, 375.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label_weights [0.2420945256715403, 0.7579054743284597]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2941/2941 [00:07<00:00, 383.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label_weights [0.2420945256715403, 0.7579054743284597]\n"
     ]
    }
   ],
   "source": [
    "train_dataloader, test_dataloader = setup_dataset(model, transform_train, transform_test, verbose=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval before training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:18<00:00,  2.55it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test: | Loss: 0.69558 | Acc: 57.295\n",
      "[[1522, 707], [549, 163]]\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:21<00:00,  2.10it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 001: | Loss: 0.69236 | Acc: 63.998\n",
      "[[1671, 558], [501, 211]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:12<00:00,  3.78it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 002: | Loss: 0.68187 | Acc: 73.345\n",
      "[[1789, 440], [344, 368]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:14<00:00,  2.91it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 003: | Loss: 0.67177 | Acc: 76.537\n",
      "[[1751, 478], [212, 500]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:13<00:00,  3.28it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 004: | Loss: 0.66183 | Acc: 77.556\n",
      "[[1749, 480], [180, 532]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.38it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 005: | Loss: 0.65286 | Acc: 77.057\n",
      "[[1716, 513], [162, 550]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.45it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 006: | Loss: 0.64422 | Acc: 78.550\n",
      "[[1767, 462], [169, 543]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.39it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 007: | Loss: 0.63616 | Acc: 78.710\n",
      "[[1765, 464], [162, 550]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.43it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 008: | Loss: 0.62828 | Acc: 78.983\n",
      "[[1781, 448], [170, 542]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.41it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 009: | Loss: 0.62102 | Acc: 78.952\n",
      "[[1773, 456], [163, 549]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.41it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 010: | Loss: 0.61402 | Acc: 78.822\n",
      "[[1761, 468], [155, 557]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.36it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 011: | Loss: 0.60737 | Acc: 78.575\n",
      "[[1742, 487], [143, 569]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.56it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 012: | Loss: 0.60112 | Acc: 78.472\n",
      "[[1736, 493], [140, 572]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.45it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 013: | Loss: 0.59509 | Acc: 78.414\n",
      "[[1736, 493], [142, 570]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.52it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 014: | Loss: 0.58950 | Acc: 78.713\n",
      "[[1745, 484], [142, 570]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.29it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 015: | Loss: 0.58416 | Acc: 79.049\n",
      "[[1763, 466], [150, 562]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.54it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 016: | Loss: 0.57877 | Acc: 78.742\n",
      "[[1744, 485], [140, 572]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.43it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 017: | Loss: 0.57366 | Acc: 78.407\n",
      "[[1731, 498], [137, 575]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.40it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 018: | Loss: 0.56866 | Acc: 78.856\n",
      "[[1745, 484], [138, 574]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.35it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 019: | Loss: 0.56452 | Acc: 79.226\n",
      "[[1758, 471], [140, 572]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.46it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 020: | Loss: 0.56025 | Acc: 79.255\n",
      "[[1757, 472], [138, 574]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.35it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 021: | Loss: 0.55584 | Acc: 79.671\n",
      "[[1770, 459], [139, 573]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.47it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 022: | Loss: 0.55131 | Acc: 79.155\n",
      "[[1752, 477], [136, 576]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.53it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 023: | Loss: 0.54755 | Acc: 78.991\n",
      "[[1743, 486], [132, 580]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.35it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 024: | Loss: 0.54384 | Acc: 79.194\n",
      "[[1750, 479], [133, 579]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.49it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 025: | Loss: 0.53977 | Acc: 79.087\n",
      "[[1746, 483], [132, 580]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 13/46 [00:02<00:07,  4.46it/s]"
     ]
    }
   ],
   "source": [
    "train(train_dataloader, test_dataloader, model, optimizer, epochs=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "save_path = '/workspace/datasets/results/normativity/clip_stuff/results/SMIR/CLIPResnet/fine_tuning/model.pt'\n",
    "import os\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "torch.save(model.fc.state_dict(), save_path)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
