{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b6014b8c-a1a1-4cde-ba34-d0032cb647a9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-04T12:14:20.730813Z",
     "iopub.status.busy": "2024-01-04T12:14:20.730186Z",
     "iopub.status.idle": "2024-01-04T12:14:20.748448Z",
     "shell.execute_reply": "2024-01-04T12:14:20.746639Z",
     "shell.execute_reply.started": "2024-01-04T12:14:20.730750Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import datetime\n",
    "import shutil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "175ba47b-3de1-43a1-a8e7-ab1b190f05ca",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-06T08:49:54.727715Z",
     "iopub.status.busy": "2024-01-06T08:49:54.727088Z",
     "iopub.status.idle": "2024-01-06T08:49:57.262946Z",
     "shell.execute_reply": "2024-01-06T08:49:57.262021Z",
     "shell.execute_reply.started": "2024-01-06T08:49:54.727652Z"
    }
   },
   "outputs": [],
   "source": [
    "from ay2.torch.lightning.loggers import CustomNameCSVLogger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d1bd375-33b8-4c46-b065-78c1aa259d26",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_logger(args, root_dir):\n",
    "    \"\"\"Create the CSV logger for a training or a test run.\n",
    "\n",
    "    The experiment name is ``args.cfg`` itself or, when ``args.ablation`` is\n",
    "    set, ``<model>/<ablation>/<task>``, where ``<model>`` is the part of\n",
    "    ``args.cfg`` before its first '/' and ``<task>`` is everything after it.\n",
    "\n",
    "    Args:\n",
    "        args: namespace providing ``cfg``, ``ablation``, ``test`` and\n",
    "            ``version``; ``collect`` is read only when ``test`` is truthy.\n",
    "        root_dir: passed to the logger as its first positional argument.\n",
    "\n",
    "    Returns:\n",
    "        A ``CustomNameCSVLogger`` whose ``csv_name`` is 'collect.csv' or\n",
    "        'test.csv' in test mode (depending on ``args.collect``) and\n",
    "        'metrics.csv' otherwise.\n",
    "    \"\"\"\n",
    "    from ay2.torch.lightning.loggers import CustomNameCSVLogger\n",
    "\n",
    "    # Split on the first '/' only, so a later repeat of '<model>/' inside the\n",
    "    # task part (e.g. 'Net/ResNet/coco') is left intact.\n",
    "    model_name, sep, remainder = args.cfg.partition('/')\n",
    "    # A cfg without '/' has no separate task part, so the whole cfg is used.\n",
    "    task = remainder if sep else args.cfg\n",
    "    name = args.cfg if args.ablation is None else f'{model_name}/{args.ablation}/{task}'\n",
    "\n",
    "    if args.test:\n",
    "        csv_name = 'collect.csv' if args.collect else 'test.csv'\n",
    "    else:\n",
    "        csv_name = 'metrics.csv'\n",
    "\n",
    "    return CustomNameCSVLogger(\n",
    "        root_dir,\n",
    "        name=name,\n",
    "        version=args.version,\n",
    "        csv_name=csv_name,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "667a3bb4-73e0-48c9-853c-9086e3156430",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-06-08T12:50:55.583823Z",
     "iopub.status.busy": "2023-06-08T12:50:55.583282Z",
     "iopub.status.idle": "2023-06-08T12:50:55.592718Z",
     "shell.execute_reply": "2023-06-08T12:50:55.591365Z",
     "shell.execute_reply.started": "2023-06-08T12:50:55.583775Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def backup_logger_file(logger_version_path):\n",
    "    \"\"\"Archive an existing 'metrics.csv' under a timestamped name.\n",
    "\n",
    "    ``<logger_version_path>/metrics.csv`` is copied to ``metrics-<mtime>.csv``\n",
    "    in the same directory, where ``<mtime>`` is the file's last-modification\n",
    "    time formatted as ``%Y-%m-%d-%H:%M:%S``, and the original is then removed.\n",
    "    Nothing happens when there is no metrics file or when a backup with that\n",
    "    name already exists.\n",
    "\n",
    "    Args:\n",
    "        logger_version_path: directory expected to contain 'metrics.csv'.\n",
    "    \"\"\"\n",
    "    metric_file = os.path.join(logger_version_path, 'metrics.csv')\n",
    "    if not os.path.exists(metric_file):\n",
    "        return\n",
    "\n",
    "    modified = datetime.datetime.fromtimestamp(os.path.getmtime(metric_file))\n",
    "    stamp = modified.strftime('%Y-%m-%d-%H:%M:%S')\n",
    "\n",
    "    # Insert the stamp before the extension of the file name only; a plain\n",
    "    # str.replace('.csv', ...) would also rewrite any directory component\n",
    "    # that happens to contain '.csv'.\n",
    "    stem, ext = os.path.splitext(metric_file)\n",
    "    backup_file = f'{stem}-{stamp}{ext}'\n",
    "    if not os.path.exists(backup_file):\n",
    "        shutil.copy2(metric_file, backup_file)\n",
    "        os.remove(metric_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9761b1fc-55fe-45f9-b8f8-6d6103a30e7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clear_old_test_file(logger_version_path):\n",
    "    \"\"\"Delete a leftover 'test.csv' from the given logger version directory.\n",
    "\n",
    "    Does nothing when the file is not present.\n",
    "    \"\"\"\n",
    "    stale_csv = os.path.join(logger_version_path, 'test.csv')\n",
    "    if not os.path.exists(stale_csv):\n",
    "        return\n",
    "    os.remove(stale_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6a298f6-8d37-4c69-8075-3c339ed36e94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ckpt_path(logger_dir, theme='best'):\n",
    "    \"\"\"Return the path of a checkpoint whose file name contains ``theme``.\n",
    "\n",
    "    Entries of ``<logger_dir>/checkpoints`` are examined in sorted name order,\n",
    "    so the same entry is chosen on every call when several of them match.\n",
    "\n",
    "    Args:\n",
    "        logger_dir: directory that contains a 'checkpoints' sub-directory.\n",
    "        theme: substring to look for in the entry names.\n",
    "\n",
    "    Returns:\n",
    "        The joined path of the first matching entry.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if no entry name contains ``theme``.\n",
    "    \"\"\"\n",
    "    checkpoint = os.path.join(logger_dir, 'checkpoints')\n",
    "    # os.listdir() yields entries in arbitrary order; sort for a stable pick.\n",
    "    for entry in sorted(os.listdir(checkpoint)):\n",
    "        if theme in entry:\n",
    "            return os.path.join(checkpoint, entry)\n",
    "    raise FileNotFoundError(f'There are no {theme} ckpt in {logger_dir}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62941697-81ff-494c-99f3-17a68b12ca73",
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_model_summary(model, log_dir):\n",
    "    \"\"\"Dump a textual description of ``model`` into ``log_dir``.\n",
    "\n",
    "    Writes 'model.txt' (the string form of ``summarize(model)``) and\n",
    "    'parameters.txt' (first line: total and trainable parameter counts; then\n",
    "    one line per named parameter with its name, shape and element count).\n",
    "    \"\"\"\n",
    "    from pytorch_lightning.utilities.model_summary import summarize\n",
    "\n",
    "    summary_path = os.path.join(log_dir, 'model.txt')\n",
    "    with open(summary_path, 'w') as summary_file:\n",
    "        summary_file.write(str(summarize(model)))\n",
    "\n",
    "    n_total = sum(param.numel() for param in model.parameters())\n",
    "    n_trainable = sum(\n",
    "        param.numel() for param in model.parameters() if param.requires_grad\n",
    "    )\n",
    "\n",
    "    params_path = os.path.join(log_dir, 'parameters.txt')\n",
    "    with open(params_path, 'w') as params_file:\n",
    "        print(n_total, n_trainable, file=params_file)\n",
    "        for param_name, param in model.named_parameters():\n",
    "            print(param_name, param.shape, param.numel(), file=params_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "eaa56b66-9017-4712-9945-41987f697326",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-06-08T12:51:25.829061Z",
     "iopub.status.busy": "2023-06-08T12:51:25.828604Z",
     "iopub.status.idle": "2023-06-08T12:51:25.835445Z",
     "shell.execute_reply": "2023-06-08T12:51:25.834183Z",
     "shell.execute_reply.started": "2023-06-08T12:51:25.829018Z"
    },
    "tags": [
     "active-ipynb",
     "style-student"
    ]
   },
   "outputs": [],
   "source": [
    "# Scratch cell (tagged `active-ipynb`): archive the metrics file of one\n",
    "# locally stored run; a no-op when that directory has no metrics.csv.\n",
    "path = '/usr/local/data/1-model_save/3-CS/CSNet+/coco/1/version_0'\n",
    "backup_logger_file(logger_version_path=path)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:light"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
