{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/workspace/repositories/MoralVisionModels')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.7.1+cu110\n",
      "Torchvision Version:  0.8.2+cu110\n"
     ]
    }
   ],
   "source": [
    "from src.tune_resnet import setup_dataset, setup_model, train\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model, transform_train, transform_test = setup_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2941/2941 [00:07<00:00, 375.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label_weights [0.2420945256715403, 0.7579054743284597]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2941/2941 [00:07<00:00, 373.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label_weights [0.2420945256715403, 0.7579054743284597]\n"
     ]
    }
   ],
   "source": [
    "train_dataloader, test_dataloader = setup_dataset(model, transform_train, transform_test, verbose=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval before training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:14<00:00,  4.86it/s]\n",
      "  2%|▏         | 1/46 [00:00<00:08,  5.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test: | Loss: 0.49428 | Acc: 75.963\n",
      "[[1683, 546], [161, 551]]\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:08<00:00,  4.34it/s]\n",
      "  2%|▏         | 1/46 [00:00<00:08,  5.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 001: | Loss: 0.49161 | Acc: 76.130\n",
      "[[1693, 536], [166, 546]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:16<00:00,  2.36it/s]\n",
      "  2%|▏         | 1/46 [00:00<00:07,  5.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 002: | Loss: 0.49423 | Acc: 75.851\n",
      "[[1691, 538], [172, 540]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:08<00:00,  5.59it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 003: | Loss: 0.49667 | Acc: 75.653\n",
      "[[1706, 523], [193, 519]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:11<00:00,  6.59it/s]\n",
      "  2%|▏         | 1/46 [00:00<00:06,  6.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 004: | Loss: 0.48658 | Acc: 77.071\n",
      "[[1715, 514], [160, 552]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:11<00:00,  6.03it/s]\n",
      "  2%|▏         | 1/46 [00:00<00:07,  6.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 005: | Loss: 0.49331 | Acc: 76.196\n",
      "[[1692, 537], [163, 549]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  2.46it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 006: | Loss: 0.49646 | Acc: 76.164\n",
      "[[1704, 525], [176, 536]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:09<00:00,  5.53it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 007: | Loss: 0.50022 | Acc: 75.620\n",
      "[[1706, 523], [194, 518]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  6.24it/s]\n",
      "  2%|▏         | 1/46 [00:00<00:07,  6.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 008: | Loss: 0.49093 | Acc: 75.963\n",
      "[[1690, 539], [168, 544]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  2.47it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 009: | Loss: 0.49404 | Acc: 75.889\n",
      "[[1689, 540], [169, 543]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:17<00:00,  1.85it/s]\n",
      "  0%|          | 0/46 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 010: | Loss: 0.49301 | Acc: 76.597\n",
      "[[1703, 526], [162, 550]]\n",
      "Eval after training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 46/46 [00:10<00:00,  4.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test: | Loss: 0.49351 | Acc: 76.165\n",
      "[[1697, 532], [169, 543]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "train(train_dataloader, test_dataloader, model, optimizer, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "save_path = '/workspace/datasets/results/normativity/clip_stuff/results/SMIR/Resnet/fine_tuning/model.pt'\n",
    "import os\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "torch.save(model.state_dict(), save_path)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
