{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "94520ed3-45be-4735-a0f2-b8f380838c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import collections\n",
    "import numpy as np\n",
    "import pyro\n",
    "import torch\n",
    "import data_loader.data_loaders as module_data\n",
    "import model.loss as module_loss\n",
    "import model.metric as module_metric\n",
    "import model.model as module_arch\n",
    "from parse_config import ConfigParser\n",
    "import trainer.trainer as module_trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c2ac4be4-19eb-4713-a88d-9bec118c7303",
   "metadata": {},
   "outputs": [],
   "source": [
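    "# Optional debugging aids, disabled by default; uncomment to turn on Pyro's\n",
    "# validation checks and PyTorch's autograd anomaly detection.\n",
    "# pyro.enable_validation(True)\n",
    "# torch.autograd.set_detect_anomaly(True)"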
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c75ea7b9-6fe5-4f8a-a825-3572e421242a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix random seeds for reproducibility.\n",
    "SEED = 123\n",
    "# pyro.set_rng_seed seeds torch, numpy and Python's `random` module in one\n",
    "# call, so code drawing from any of the three generators is repeatable.\n",
    "pyro.set_rng_seed(SEED)\n",
    "# Ask cuDNN for deterministic algorithms and disable its autotuner.\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4889350c-1457-4dee-9e93-fab34f951a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(config):\n",
    "    \"\"\"Assemble the training components described by ``config`` and train.\n",
    "\n",
    "    Args:\n",
    "        config: Configuration object (a ``ConfigParser`` in this notebook) that\n",
    "            provides ``get_logger``, ``init_obj`` and item access.\n",
    "\n",
    "    Returns:\n",
    "        The trainer object, after its ``train()`` method has returned.\n",
    "    \"\"\"\n",
    "    log = config.get_logger('train')\n",
    "\n",
    "    # Data: the validation loader is obtained from the training loader.\n",
    "    train_loader = config.init_obj('data_loader', module_data)\n",
    "    val_loader = train_loader.split_validation()\n",
    "\n",
    "    # Model: build the configured architecture and log it.\n",
    "    net = config.init_obj('arch', module_arch)\n",
    "    log.info(net)\n",
    "\n",
    "    # Metrics: look up each configured name in the metric module.\n",
    "    metric_fns = []\n",
    "    for metric_name in config['metrics']:\n",
    "        metric_fns.append(getattr(module_metric, metric_name))\n",
    "\n",
    "    # Optimizer: resolved against pyro.optim.\n",
    "    optim = config.init_obj('optimizer', pyro.optim)\n",
    "\n",
    "    # Trainer: no learning-rate scheduler is supplied.\n",
    "    trainer_kwargs = {\n",
    "        'config': config,\n",
    "        'data_loader': train_loader,\n",
    "        'valid_data_loader': val_loader,\n",
    "        'lr_scheduler': None,\n",
    "    }\n",
    "    runner = config.init_obj('trainer', module_trainer, net, metric_fns, optim,\n",
    "                             **trainer_kwargs)\n",
    "\n",
    "    runner.train()\n",
    "    return runner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab1445d1-45bb-43f9-893d-7a2b2bad1a0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import read_json\n",
    "\n",
    "# Read the ASVI settings file, wrap it in a ConfigParser, then run training.\n",
    "config = ConfigParser(read_json(\"asvi_config.json\"))\n",
    "trained = main(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faa5b900-0121-4cf6-91b6-2c9b0b12716f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the trained model into evaluation mode and move it to the CPU\n",
    "# (assumes trained.model is a torch.nn.Module -- confirm against the trainer).\n",
    "trained.model.eval()\n",
    "trained.model.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08b487d4-288a-40d7-8f4d-c79e2e15385e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:ppc] *",
   "language": "python",
   "name": "conda-env-ppc-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
