{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1135eb0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import torch_geometric\n",
    "from pathlib import Path\n",
    "from model import GNNPolicy\n",
    "from data_type import GraphDataset\n",
    "from utils import process\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "982bad16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: every value a reader may want to tune lives in this cell.\n",
    "problem = 'FCMCNF'  # used to build the data folder path and the checkpoint file name\n",
    "seed = 0\n",
    "lr = 0.005\n",
    "n_epoch = 5\n",
    "# NOTE(review): `patience` and `early_stopping` are not read by any other cell of this\n",
    "# notebook -- confirm whether they are still needed.\n",
    "patience = 10\n",
    "early_stopping = 20\n",
    "normalize = True\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "batch_train = 1\n",
    "batch_valid = 256\n",
    "\n",
    "# Seed torch's global RNG so that a Restart & Run All draws the same random numbers\n",
    "# (e.g. for shuffling in the training DataLoader). Some CUDA kernels may still be\n",
    "# non-deterministic, so bit-identical results on GPU are not guaranteed.\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "loss_fn = torch.nn.BCELoss()\n",
    "optimizer_fn = torch.optim.Adam  # the class itself; instantiated once the model exists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "3b1c3669",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------------\n",
      "GNN for problem FCMCNF\n",
      "Training on:          16285 samples\n",
      "Validating on:        3019 samples\n",
      "Batch Size Train:     1\n",
      "Batch Size Valid      256\n",
      "Learning rate:        0.005 \n",
      "Number of epochs:     5\n",
      "Normalize:            True\n",
      "Device:               cuda\n",
      "Loss fct:             BCELoss()\n",
      "Optimizer:            <class 'torch.optim.adam.Adam'>\n",
      "Model's Size:         1050 parameters \n",
      "-------------------------\n"
     ]
    }
   ],
   "source": [
    "# Build the datasets, loaders, model and optimizer, then print a summary of the run.\n",
    "\n",
    "# Per-epoch metric histories, appended to by the training loop.\n",
    "train_losses, valid_losses = [], []\n",
    "train_accs, valid_accs = [], []\n",
    "\n",
    "# Number of files moved from the on-disk validation folder into the training set.\n",
    "n_valid_to_train = 7000\n",
    "\n",
    "data_root = Path(os.path.join(os.path.abspath(''), f\"../node_selection/data/{problem}\"))\n",
    "\n",
    "# Path.glob returns files in an arbitrary, filesystem-dependent order. Sorting makes the\n",
    "# slice-based split below pick the same files on every run and on every machine.\n",
    "train_files = sorted(str(path) for path in (data_root / \"train\").glob(\"*.pt\"))\n",
    "valid_files = sorted(str(path) for path in (data_root / \"valid\").glob(\"*.pt\"))\n",
    "\n",
    "train_files += valid_files[:n_valid_to_train]\n",
    "valid_files = valid_files[n_valid_to_train:]\n",
    "\n",
    "train_data = GraphDataset(train_files)\n",
    "valid_data = GraphDataset(valid_files)\n",
    "\n",
    "# Attributes for which the PyG DataLoader should also create batch-assignment vectors.\n",
    "follow_batch = [\n",
    "    'constraint_features_s',\n",
    "    'constraint_features_t',\n",
    "    'variable_features_s',\n",
    "    'variable_features_t',\n",
    "]\n",
    "\n",
    "# TODO (carried over from the original cell): learn something from the data\n",
    "train_loader = torch_geometric.loader.DataLoader(\n",
    "    train_data,\n",
    "    batch_size=batch_train,\n",
    "    shuffle=True,\n",
    "    follow_batch=follow_batch,\n",
    ")\n",
    "\n",
    "valid_loader = torch_geometric.loader.DataLoader(\n",
    "    valid_data,\n",
    "    batch_size=batch_valid,\n",
    "    shuffle=False,\n",
    "    follow_batch=follow_batch,\n",
    ")\n",
    "\n",
    "policy = GNNPolicy().to(device)\n",
    "optimizer = optimizer_fn(policy.parameters(), lr=lr)\n",
    "\n",
    "print(\"-------------------------\")\n",
    "print(f\"GNN for problem {problem}\")\n",
    "print(f\"Training on:          {len(train_data)} samples\")\n",
    "print(f\"Validating on:        {len(valid_data)} samples\")\n",
    "print(f\"Batch Size Train:     {batch_train}\")\n",
    "print(f\"Batch Size Valid      {batch_valid}\")\n",
    "print(f\"Learning rate:        {lr} \")\n",
    "print(f\"Number of epochs:     {n_epoch}\")\n",
    "print(f\"Normalize:            {normalize}\")\n",
    "print(f\"Device:               {device}\")\n",
    "print(f\"Loss fct:             {loss_fn}\")\n",
    "print(f\"Optimizer:            {optimizer_fn}\")\n",
    "print(f\"Model's Size:         {sum(p.numel() for p in policy.parameters())} parameters \")\n",
    "print(\"-------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c79d8a0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "Train loss: 0.072, accuracy 0.953\n",
      "Valid loss: 0.061, accuracy 0.957\n",
      "Epoch 2\n",
      "Train loss: 0.068, accuracy 0.951\n",
      "Valid loss: 0.061, accuracy 0.957\n",
      "Epoch 3\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_61678/2522883233.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Epoch {epoch + 1}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m     train_loss, train_acc = process(policy, \n\u001b[0m\u001b[1;32m      5\u001b[0m                                     \u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m                                     \u001b[0mloss_fn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/local_workspace/labaabde/learn2selectnodes/learning/utils.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(policy, data_loader, loss_fct, device, optimizer, normalize)\u001b[0m\n\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m             \u001b[0my_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.5\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m0.5\u001b[0m \u001b[0;31m#0,1 label from -1,1 label\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m             \u001b[0my_proba\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpolicy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     97\u001b[0m             \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_proba\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/l2sn/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1049\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1050\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1052\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1053\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/local_workspace/labaabde/learn2selectnodes/learning/model.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, batch, inv, epsilon)\u001b[0m\n\u001b[1;32m    122\u001b[0m             \u001b[0mgraph0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgraph1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgraph1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgraph0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m         \u001b[0mscore0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mgraph0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#concatenation of averages variable/constraint features after conv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    125\u001b[0m         \u001b[0mscore1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mgraph1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/local_workspace/labaabde/learn2selectnodes/learning/model.py\u001b[0m in \u001b[0;36mforward_graph\u001b[0;34m(self, constraint_features, edge_indices, edge_features, variable_features, bbounds, constraint_batch, variable_batch)\u001b[0m\n\u001b[1;32m    158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    159\u001b[0m             \u001b[0;31m#cons to var\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 160\u001b[0;31m             variable_features = F.relu(conv((constraint_features, variable_features), \n\u001b[0m\u001b[1;32m    161\u001b[0m                                       \u001b[0medge_indices_reversed\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    162\u001b[0m                                       \u001b[0medge_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0medge_features\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/l2sn/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1049\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1050\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1052\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1053\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/l2sn/lib/python3.9/site-packages/torch_geometric/nn/conv/graph_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, edge_index, edge_weight, size)\u001b[0m\n\u001b[1;32m     67\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     68\u001b[0m         \u001b[0;31m# propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 69\u001b[0;31m         out = self.propagate(edge_index, x=x, edge_weight=edge_weight,\n\u001b[0m\u001b[1;32m     70\u001b[0m                              size=size)\n\u001b[1;32m     71\u001b[0m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlin_rel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/l2sn/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py\u001b[0m in \u001b[0;36mpropagate\u001b[0;34m(self, edge_index, size, **kwargs)\u001b[0m\n\u001b[1;32m    315\u001b[0m                     \u001b[0;32mif\u001b[0m \u001b[0mres\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    316\u001b[0m                         \u001b[0mmsg_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 317\u001b[0;31m                 \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mmsg_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    318\u001b[0m                 \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_message_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    319\u001b[0m                     \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmsg_kwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/l2sn/lib/python3.9/site-packages/torch_geometric/nn/conv/graph_conv.py\u001b[0m in \u001b[0;36mmessage\u001b[0;34m(self, x_j, edge_weight)\u001b[0m\n\u001b[1;32m     77\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m     \u001b[0;32mdef\u001b[0m \u001b[0mmessage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_j\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0medge_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     80\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mx_j\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0medge_weight\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0medge_weight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx_j\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train for `n_epoch` epochs; after each one, evaluate on the validation loader and\n",
    "# record the loss and accuracy of both passes.\n",
    "checkpoint_path = f'policy_{problem}.pkl'\n",
    "\n",
    "for epoch in range(n_epoch):\n",
    "    print(f\"Epoch {epoch + 1}\")\n",
    "\n",
    "    # Training pass: the optimizer is handed to `process`.\n",
    "    train_loss, train_acc = process(policy,\n",
    "                                    train_loader,\n",
    "                                    loss_fn,\n",
    "                                    device,\n",
    "                                    optimizer=optimizer,\n",
    "                                    normalize=normalize)\n",
    "    train_losses.append(train_loss)\n",
    "    train_accs.append(train_acc)\n",
    "    print(f\"Train loss: {train_loss:0.3f}, accuracy {train_acc:0.3f}\" )\n",
    "\n",
    "    # Validation pass: optimizer=None (presumably `process` then skips the weight\n",
    "    # update -- confirm in utils.process).\n",
    "    valid_loss, valid_acc = process(policy,\n",
    "                                    valid_loader,\n",
    "                                    loss_fn,\n",
    "                                    device,\n",
    "                                    optimizer=None,\n",
    "                                    normalize=normalize)\n",
    "    valid_losses.append(valid_loss)\n",
    "    valid_accs.append(valid_acc)\n",
    "    print(f\"Valid loss: {valid_loss:0.3f}, accuracy {valid_acc:0.3f}\" )\n",
    "\n",
    "    # Checkpoint after every completed epoch, so interrupting a long run keeps the\n",
    "    # weights of the last finished epoch instead of skipping the save entirely.\n",
    "    torch.save(policy.state_dict(), checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "edd0910e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the policy's current weights to a problem-specific file in the working directory.\n",
    "torch.save(obj=policy.state_dict(), f=f'policy_{problem}.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a658ab77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score every validation sample with the policy, then save three figures: a histogram\n",
    "# of the scores and the per-epoch loss / accuracy curves.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Inference only: no_grad avoids building an autograd graph for each forward pass.\n",
    "with torch.no_grad():\n",
    "    decisions = [policy(dvalid.to(device)).item() for dvalid in valid_data]\n",
    "\n",
    "plt.figure(0)\n",
    "plt.hist(decisions)\n",
    "plt.title('decisions histogramme for valid set')\n",
    "plt.savefig(\"./hist.png\")\n",
    "\n",
    "# (figure number, title, train series, valid series, output file)\n",
    "curves = [\n",
    "    (1, 'losses', train_losses, valid_losses, \"./losses.png\"),\n",
    "    (2, 'accuracies', train_accs, valid_accs, \"./accuracies.png\"),\n",
    "]\n",
    "for fig_id, title, train_curve, valid_curve, out_path in curves:\n",
    "    plt.figure(fig_id)\n",
    "    plt.plot(train_curve, label='train')\n",
    "    plt.plot(valid_curve, label='valid')\n",
    "    plt.title(title)\n",
    "    plt.xlabel('epoch')\n",
    "    plt.legend()\n",
    "    plt.savefig(out_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
