{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from spadio import SPADFolder, SPADData  # noqa\n",
    "from spadclean import GenerateTestData, SPADHotpixelTool  # noqa\n",
    "from pathlib import Path\n",
    "from utils import clean_hotpixels\n",
    "from inference import cpu_inference\n",
    "from metadata import TrainData, ModelConfig, TrainConfig, load_config\n",
    "from dataset import (\n",
    "    BernoulliDataset3D,\n",
    "    ValidationDataset3D,\n",
    "    PairedDataset,\n",
    "    BinomDataset3D,\n",
    "    N2NDataset3D,\n",
    ")  # noqa\n",
    "from spadgapmodels import SPADGAP\n",
    "import torch\n",
    "import torch.utils.data as dt\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks import (\n",
    "    LearningRateMonitor,\n",
    "    ModelCheckpoint,\n",
    "    EarlyStopping,\n",
    "    DeviceStatsMonitor,\n",
    ")\n",
    "from lightning.pytorch.loggers import TensorBoardLogger\n",
    "from tifffile import imwrite, imread\n",
    "import numpy as np\n",
    "import logging\n",
    "import sys\n",
    "import shutil\n",
    "\n",
    "# Run configuration: every path and option used below is read from this file.\n",
    "configure_path = Path(\"./config.yml\")\n",
    "config = load_config(path=configure_path)  # CLI argument\n",
    "\n",
    "# logging.basicConfig(\n",
    "#     # filename=config[\"PATH\"][\"logger\"],\n",
    "#     level=logging.DEBUG,\n",
    "#     format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n",
    "#     stream=sys.stdout,\n",
    "# )\n",
    "\n",
    "path_config = config[\"PATH\"]\n",
    "\n",
    "# Only \"raw\" and \"processed\" are handled by the loading code that follows.\n",
    "data_type = path_config[\"data_type\"]\n",
    "if data_type not in (\"raw\", \"processed\"):\n",
    "    raise ValueError(\n",
    "        f\"Data type must be 'raw' or 'processed', got {data_type!r}\"\n",
    "    )\n",
    "\n",
    "dir_path = Path(path_config[\"dir_path\"])\n",
    "num_of_files = path_config[\"num_of_files\"]\n",
    "data_dir = Path(path_config[\"data_dir\"])\n",
    "data_path = path_config[\"data_path\"]\n",
    "data_file = path_config[\"data_file\"]\n",
    "ground_truth_path = path_config[\"ground_truth_path\"]\n",
    "ground_truth_file = path_config[\"ground_truth_file\"]\n",
    "model_path = Path(path_config[\"model_path\"])\n",
    "\n",
    "# An explicitly configured path wins; an empty string falls back to the file\n",
    "# name inside data_dir.\n",
    "data_path = data_dir / data_file if data_path == \"\" else Path(data_path)\n",
    "ground_truth_path = (\n",
    "    data_dir / ground_truth_file if ground_truth_path == \"\" else Path(ground_truth_path)\n",
    ")\n",
    "if data_type == \"raw\":\n",
    "    # Open the folder, take the first num_of_files entries of its stack and\n",
    "    # run clean_hotpixels over them.\n",
    "    try:\n",
    "        input_folder = SPADFolder(dir_path)\n",
    "    except FileNotFoundError:\n",
    "        logging.error(\"Folder not found\")\n",
    "        # Nothing below can run without the folder; stop here instead of\n",
    "        # failing on the next line with an unrelated NameError.\n",
    "        raise\n",
    "    raw_stack = input_folder.spadstack[:num_of_files]\n",
    "    data = raw_stack.process(clean_hotpixels)\n",
    "    del raw_stack\n",
    "elif data_type == \"processed\":\n",
    "    try:\n",
    "        data = imread(data_path)\n",
    "    except FileNotFoundError:\n",
    "        logging.error(\"File not found\")\n",
    "        # The input stack is required by every later cell.\n",
    "        raise\n",
    "    try:\n",
    "        ground_truth_file = imread(ground_truth_path)\n",
    "    except FileNotFoundError:\n",
    "        # Still non-fatal, as before: ground_truth_file keeps its config string\n",
    "        # value, so only the final metrics cell is affected.\n",
    "        logging.error(\"File not found\")\n",
    "\n",
    "\n",
    "# The first TRAIN_SPLIT_INDEX entries along axis 0 go to training and the\n",
    "# remainder to validation; one constant keeps both slices in agreement.\n",
    "TRAIN_SPLIT_INDEX = 80000\n",
    "\n",
    "data_config = TrainData.from_config(\n",
    "    config[\"DATA\"], data[:TRAIN_SPLIT_INDEX].astype(np.float32)\n",
    ")\n",
    "model_config = ModelConfig.from_config(config[\"MODEL\"])\n",
    "train_config = TrainConfig.from_config(config[\"TRAINING\"], data_config, model_config)\n",
    "val_data_config = TrainData.from_config_validation(\n",
    "    config[\"DATA\"], data[TRAIN_SPLIT_INDEX:].astype(np.float32)\n",
    ")\n",
    "print(train_config.metadata())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = BernoulliDataset3D.from_dataclass(data_config)\n",
    "val_data = ValidationDataset3D.from_dataclass(val_data_config)\n",
    "\n",
    "# Loader options shared by the training and validation loaders.\n",
    "loader_config = {\n",
    "    \"batch_size\": train_config.batch_size,\n",
    "    \"shuffle\": train_config.shuffle,\n",
    "    \"pin_memory\": train_config.pin_memory,\n",
    "    \"drop_last\": train_config.drop_last,\n",
    "    \"num_workers\": train_config.num_workers,\n",
    "    # DataLoader raises ValueError for persistent_workers=True when\n",
    "    # num_workers is 0, so only enable it when workers are actually used.\n",
    "    \"persistent_workers\": train_config.num_workers > 0,\n",
    "}\n",
    "# Validation uses the same options but is never shuffled.\n",
    "val_loader_config = {**loader_config, \"shuffle\": False}\n",
    "\n",
    "train_loader = dt.DataLoader(train_data, **loader_config)\n",
    "val_loader = dt.DataLoader(val_data, **val_loader_config)\n",
    "\n",
    "# Everything produced by this run is written below model_path/<run name>.\n",
    "test_name = train_config.name\n",
    "default_root_dir = model_path / test_name\n",
    "# exist_ok makes the cell safe to re-run and avoids a check-then-create race.\n",
    "default_root_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "model = SPADGAP.from_dataclass(model_config)\n",
    "model.train()\n",
    "\n",
    "logger = TensorBoardLogger(save_dir=model_path, name=test_name)\n",
    "\n",
    "# Keep the two weight-only checkpoints with the lowest monitored val_loss.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    save_weights_only=True,\n",
    "    mode=\"min\",\n",
    "    monitor=\"val_loss\",\n",
    "    save_top_k=2,\n",
    ")\n",
    "lr_monitor = LearningRateMonitor(\"epoch\")\n",
    "\n",
    "trainer = pl.Trainer(\n",
    "    default_root_dir=default_root_dir,\n",
    "    accelerator=\"cuda\",\n",
    "    gradient_clip_val=1,\n",
    "    precision=train_config.precision,  # type: ignore\n",
    "    devices=[train_config.device_number],\n",
    "    max_epochs=train_config.epochs,\n",
    "    callbacks=[\n",
    "        checkpoint_callback,\n",
    "        lr_monitor,\n",
    "        # EarlyStopping(\"val_loss\", patience=25),\n",
    "        # DeviceStatsMonitor(),\n",
    "    ],\n",
    "    logger=logger,  # type: ignore\n",
    "    profiler=\"simple\",\n",
    "    limit_val_batches=20,\n",
    "    enable_model_summary=True,\n",
    "    enable_checkpointing=True,\n",
    ")\n",
    "# Pull one batch to report the tensor shape the model will receive.\n",
    "print(f\"input_size: {tuple(next(iter(train_loader))[0].shape)}\")\n",
    "print(f\"file: {test_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: uncomment to display a torchinfo summary of the model.\n",
    "# import torchinfo\n",
    "# torchinfo.summary(model, input_size=(train_config.batch_size, 1, 64, 64, 64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output locations for this run, all inside default_root_dir.\n",
    "metadata_file = default_root_dir / \"metadata.yml\"\n",
    "config_copy = default_root_dir / \"config.yml\"\n",
    "final_checkpoint = default_root_dir / \"final_model.ckpt\"\n",
    "\n",
    "model.train()\n",
    "\n",
    "# Store the resolved training metadata and a copy of the config file next to\n",
    "# the checkpoints before training starts.\n",
    "train_config.to_yaml(metadata_file)\n",
    "shutil.copyfile(configure_path, config_copy)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)\n",
    "trainer.save_checkpoint(final_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes the same final checkpoint as the training cell; presumably kept so\n",
    "# the file can be re-saved without running fit again.\n",
    "final_checkpoint = default_root_dir / \"final_model.ckpt\"\n",
    "trainer.save_checkpoint(final_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear_vram lives in utils and is not shown in this notebook; presumably it\n",
    "# releases GPU memory before inference (confirm against utils).\n",
    "from utils import clear_vram\n",
    "\n",
    "clear_vram()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per element, draw a Binomial(n=data, p=0.9) sample from the full stack.\n",
    "keep_probability = 0.9\n",
    "data_split = np.random.binomial(n=data, p=keep_probability)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imported here so this cell runs on a fresh kernel; the only other import of\n",
    "# gpu_patch_inference sits in a later cell.\n",
    "from inference import gpu_patch_inference\n",
    "\n",
    "NUM_SPLITS = 10\n",
    "\n",
    "# Run inference on NUM_SPLITS independent binomial resamplings of the input.\n",
    "splitted_inference = []\n",
    "for _ in range(NUM_SPLITS):\n",
    "    # Only the first 512 entries are inferred, so sample just that slice\n",
    "    # instead of resampling the whole stack on every pass.\n",
    "    data_split = np.random.binomial(data[:512], 0.9)\n",
    "    output = gpu_patch_inference(\n",
    "        model,\n",
    "        data_split.astype(np.float32),\n",
    "        initial_patch_depth=48,\n",
    "        min_overlap=40,\n",
    "        device=train_config.device_number,\n",
    "    )\n",
    "    splitted_inference.append(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from inference import gpu_patch_inference\n",
    "\n",
    "# Infer on the first 512 entries of the stack and store input and result\n",
    "# side by side in the run directory.\n",
    "inference_input = data[:512]\n",
    "inference_settings = {\n",
    "    \"initial_patch_depth\": 48,\n",
    "    \"min_overlap\": 40,\n",
    "    \"device\": train_config.device_number,\n",
    "}\n",
    "output = gpu_patch_inference(\n",
    "    model, inference_input.astype(np.float32), **inference_settings\n",
    ")\n",
    "\n",
    "# The input is written in its original dtype, not the float32 copy.\n",
    "imwrite(default_root_dir / \"input.tif\", inference_input)\n",
    "imwrite(default_root_dir / \"inference.tif\", output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import group_metrics\n",
    "\n",
    "# Hand the first 512 entries of the input and ground truth, plus the\n",
    "# inference output, to group_metrics. The name `input` is avoided because it\n",
    "# would shadow the builtin for the rest of the kernel session.\n",
    "noisy_input = data[:512].astype(float)\n",
    "image = output\n",
    "ground_truth = ground_truth_file[:512].astype(float)\n",
    "# NOTE(review): other cells pass train_config.device_number; confirm that\n",
    "# train_config.device is the attribute group_metrics expects.\n",
    "group_metrics(\n",
    "    noisy_input, image, ground_truth, default_root_dir, device=train_config.device\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gap3d",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
