{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b119b7e",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "source": [
    "### Siamese Network Training with Betti Vectorization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47d6ddeb",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import optim\n",
    "import pickle\n",
    "from src.ml_models.siamese import (Siamese, ContrastiveLoss, TripletMarginLoss,\n",
    "                                   get_anchor_samples, generate_data_pairs, generate_data_triplets, \n",
    "                                   split_into_batches, train_model, produce_results, produce_results_alternative)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83b3e195-a436-4c4d-959e-90b72754c22a",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "outputs": [],
   "source": [
    "device = 'cuda:1' if torch.cuda.is_available() else 'cpu'\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a050d06-044e-4813-9e01-42a6d71b74d9",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "source": [
    "#### Atom Weight filtration with homology dim. 0 (superlevel), Partial Charge filtration with homology dim. 0 (superlevel) and Bond Type Filtration with homology dim. 0 (superlevel), Atom Radius filtration with homology dim. 0 (superlevel), Electron Affinity filtration with homology dim. 0 (sublevel), Node Degree filtration with homology dim. 0 (superlevel) Followed by Betti Vectorization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dec9177-d070-482c-a682-88023de04e12",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "outputs": [],
   "source": [
    "with open('../data/ClevesJain_TopologyFeatures/atom_weight_superlevel_partial_charge_superlevel_bond_strength_superlevel_node_degree_sublevel_atom_radius_superlevel_electron_affinity_superlevel_betti.pickle', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "train_x, train_y, test_x, test_y = data[\"train_x\"], data[\"train_y\"], data[\"test_x\"], data[\"test_y\"]\n",
    "anchor_x, anchor_y = get_anchor_samples(train_x, train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c988bb2-9fa4-47d4-b811-deabb086a55b",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "outputs": [],
   "source": [
    "print(len(train_x), len(test_x))\n",
    "print(train_x[0].shape, test_x[0].shape) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87fcb8a3",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "source": [
    "#### Model Training - ViT Backbone & Triplet Margin Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f08b5f2f",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "outputs": [],
   "source": [
    "loss_fn = \"triplet_margin\"\n",
    "batch_size = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67176479-2c93-49d8-8d9a-c9df1f03f811",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "outputs": [],
   "source": [
    "train_set, test_set = generate_data_triplets(train_x, train_y, test_x, test_y, anchor_x, anchor_y)\n",
    "dataloaders = split_into_batches(train_set, test_set, batch_size, loss_fn=loss_fn)\n",
    "del train_set, test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dfa6aa4-83a4-48cf-9709-d05f03c57861",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# model = Siamese(base_model=\"convnext\", embedding_size=1000)\n",
    "model = Siamese(base_model=\"vision_transformer\", embedding_size=1000)\n",
    "# model = Siamese(base_model=\"resnet\", embedding_size=1000)\n",
    "\n",
    "optimizer = optim.Adam(\n",
    "    filter(lambda p: p.requires_grad, model.parameters()),\n",
    "    lr=0.0005,\n",
    "    eps=1e-8,\n",
    "    weight_decay=0.0005,\n",
    ")\n",
    "model = train_model(model, dataloaders, device, optimizer, lr_decay=False, num_epochs=1, loss_fn=loss_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc743508",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "source": [
    "#### Enrichment Factor Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46536e6b-65ad-4b8e-9923-47e65d0c95c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a875b98",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "factors = [0.02, 0.05, 0.1, 0.15, 0.2]\n",
    "produce_results(model, test_x, test_y, anchor_x, device, loss_fn, factors, batch_size=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1465252-0b81-4ff5-b3e8-0397784e03de",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "activeView": "grid_default",
      "views": {
       "grid_default": {
        "col": null,
        "height": 2,
        "hidden": true,
        "row": null,
        "width": 2
       }
      }
     }
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "extensions": {
   "jupyter_dashboards": {
    "activeView": "grid_default",
    "version": 1,
    "views": {
     "grid_default": {
      "cellMargin": 2,
      "defaultCellHeight": 60,
      "maxColumns": 12,
      "name": "grid",
      "type": "grid"
     }
    }
   }
  },
  "kernelspec": {
   "display_name": "tda",
   "language": "python",
   "name": "tda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
