{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from spadio import SPADFolder, SPADData  # noqa\n",
    "from spadclean import GenerateTestData, SPADHotpixelTool  # noqa\n",
    "from pathlib import Path\n",
    "from utils import clean_hotpixels, clear_vram, group_metrics\n",
    "from inference import cpu_inference, gpu_patch_inference\n",
    "from metadata import TrainData, ModelConfig, TrainConfig, load_config\n",
    "from dataset import (\n",
    "    BernoulliDataset3D,\n",
    "    ValidationDataset3D,\n",
    "    BinomDataset3D,\n",
    "    N2NDataset3D,\n",
    ")  # noqa\n",
    "from spadgapmodels import SPADGAP\n",
    "import torch\n",
    "import torch.utils.data as dt\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks import (\n",
    "    LearningRateMonitor,\n",
    "    ModelCheckpoint,\n",
    "    EarlyStopping,\n",
    "    DeviceStatsMonitor,\n",
    ")\n",
    "from lightning.pytorch.loggers import TensorBoardLogger\n",
    "from tifffile import imwrite, imread\n",
    "import numpy as np\n",
    "import logging\n",
    "import sys\n",
    "import shutil\n",
    "\n",
    "\n",
    "# Run directory of a previously trained model; its config.yml drives this notebook.\n",
    "default_root_dir = Path(\"../models/20240413_073549_00999999_1000x32x256x256_skip=0_l=10_d=5_sf=32_ds=2at10_f=10.0_z=2_g=8_sd=0_b=tri_a=gelu_b=4_e=500_p=32\")\n",
    "configure_path = default_root_dir / \"config.yml\"\n",
    "config = load_config(path=configure_path)\n",
    "\n",
    "data_type = config[\"PATH\"][\"data_type\"]\n",
    "if data_type not in [\"raw\", \"processed\"]:\n",
    "    raise ValueError(f\"Data type must be 'raw' or 'processed', got {data_type!r}\")\n",
    "\n",
    "dir_path = Path(config[\"PATH\"][\"dir_path\"])\n",
    "num_of_files = config[\"PATH\"][\"num_of_files\"]\n",
    "data_dir = Path(config[\"PATH\"][\"data_dir\"])\n",
    "data_path = config[\"PATH\"][\"data_path\"]\n",
    "data_file = config[\"PATH\"][\"data_file\"]\n",
    "ground_truth_path = config[\"PATH\"][\"ground_truth_path\"]\n",
    "ground_truth_file = config[\"PATH\"][\"ground_truth_file\"]\n",
    "model_path = Path(config[\"PATH\"][\"model_path\"])\n",
    "\n",
    "# An empty explicit path in the config means \"data_dir / <file name>\".\n",
    "data_path = data_dir / data_file if data_path == \"\" else Path(data_path)\n",
    "ground_truth_path = (\n",
    "    data_dir / ground_truth_file if ground_truth_path == \"\" else Path(ground_truth_path)\n",
    ")\n",
    "if data_type == \"raw\":\n",
    "    try:\n",
    "        input_folder = SPADFolder(dir_path)\n",
    "    except FileNotFoundError:\n",
    "        logging.error(\"Folder not found: %s\", dir_path)\n",
    "        raise  # nothing below can run without the input stack\n",
    "    raw_stack = input_folder.spadstack[:num_of_files]\n",
    "    data = raw_stack.process(clean_hotpixels)\n",
    "    del raw_stack\n",
    "    # NOTE(review): no ground truth is loaded on this branch, so `ground_truth_file`\n",
    "    # stays the configured file name and the metric cells further down cannot slice it.\n",
    "elif data_type == \"processed\":\n",
    "    try:\n",
    "        data = imread(data_path)\n",
    "        # Rebinds the config file name to the loaded array; later cells slice it.\n",
    "        ground_truth_file = imread(ground_truth_path)\n",
    "    except FileNotFoundError:\n",
    "        logging.error(\"File not found: %s or %s\", data_path, ground_truth_path)\n",
    "        raise  # nothing below can run without the data\n",
    "\n",
    "\n",
    "data_config = TrainData.from_config(config[\"DATA\"], data.astype(np.float32))\n",
    "model_config = ModelConfig.from_config(config[\"MODEL\"])\n",
    "train_config = TrainConfig.from_config(config[\"TRAINING\"], data_config, model_config)\n",
    "# A separate float32 copy is passed on purpose, so train and validation never share a buffer.\n",
    "val_data_config = TrainData.from_config_validation(\n",
    "    config[\"DATA\"], data.astype(np.float32)\n",
    ")\n",
    "print(train_config.metadata())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = BernoulliDataset3D.from_dataclass(data_config)\n",
    "val_data = ValidationDataset3D.from_dataclass(val_data_config)\n",
    "\n",
    "loader_config = {\n",
    "    \"batch_size\": train_config.batch_size,\n",
    "    \"shuffle\": train_config.shuffle,\n",
    "    \"pin_memory\": train_config.pin_memory,\n",
    "    \"drop_last\": train_config.drop_last,\n",
    "    \"num_workers\": train_config.num_workers,\n",
    "    # DataLoader raises if persistent_workers=True while num_workers == 0.\n",
    "    \"persistent_workers\": train_config.num_workers > 0,\n",
    "}\n",
    "\n",
    "train_loader = dt.DataLoader(train_data, **loader_config)\n",
    "# Validation uses the same settings but a fixed order; the shared dict is not mutated.\n",
    "val_loader = dt.DataLoader(val_data, **{**loader_config, \"shuffle\": False})\n",
    "\n",
    "test_name = train_config.name\n",
    "ckpt_path = default_root_dir / \"final_model.ckpt\"\n",
    "if not ckpt_path.exists():\n",
    "    # Fall back to the most recently written checkpoint of the first Lightning run.\n",
    "    # Sorting by name would rank \"epoch=9\" after \"epoch=10\", so use mtime instead.\n",
    "    ckpt_dir = default_root_dir / \"version_0\" / \"checkpoints\"\n",
    "    check_points = sorted(ckpt_dir.glob(\"*.ckpt\"), key=lambda p: p.stat().st_mtime)\n",
    "    if not check_points:\n",
    "        raise FileNotFoundError(f\"No checkpoint found in {default_root_dir} or {ckpt_dir}\")\n",
    "    ckpt_path = check_points[-1]\n",
    "    print(f\"Using checkpoint: {ckpt_path}\")\n",
    "\n",
    "model = SPADGAP.load_from_checkpoint(ckpt_path)\n",
    "model.train()\n",
    "\n",
    "# NOTE(review): TensorBoardLogger is imported from `lightning.pytorch` while the Trainer is\n",
    "# `pytorch_lightning`; the `# type: ignore` below hints at that mismatch - confirm it is intended.\n",
    "logger = TensorBoardLogger(save_dir=model_path, name=test_name)\n",
    "\n",
    "trainer = pl.Trainer(\n",
    "    default_root_dir=default_root_dir,\n",
    "    accelerator=\"cuda\",\n",
    "    gradient_clip_val=1,\n",
    "    precision=train_config.precision,  # type: ignore\n",
    "    devices=[train_config.device_number],\n",
    "    max_epochs=train_config.epochs,\n",
    "    callbacks=[\n",
    "        ModelCheckpoint(\n",
    "            save_weights_only=True,\n",
    "            mode=\"min\",\n",
    "            monitor=\"val_loss\",\n",
    "            save_top_k=2,\n",
    "        ),\n",
    "        LearningRateMonitor(\"epoch\"),\n",
    "    ],\n",
    "    logger=logger,  # type: ignore\n",
    "    profiler=\"simple\",\n",
    "    limit_val_batches=20,\n",
    "    enable_model_summary=True,\n",
    "    enable_checkpointing=True,\n",
    ")\n",
    "print(f\"input_size: {tuple(next(iter(train_loader))[1].shape)}\")\n",
    "print(f\"file: {test_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Continue training the loaded model, then store the result as final_model.ckpt.\n",
    "trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
    "trainer.save_checkpoint(default_root_dir.joinpath(\"final_model.ckpt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-save the current trainer state as final_model.ckpt (e.g. after an interrupted fit).\n",
    "final_ckpt = default_root_dir.joinpath(\"final_model.ckpt\")\n",
    "trainer.save_checkpoint(final_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably frees GPU memory left over from training before inference (see utils.clear_vram).\n",
    "from utils import clear_vram\n",
    "\n",
    "clear_vram()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep an untouched copy of the loaded stack: the resampling cells draw from it,\n",
    "# and later cells may rebind `data`.\n",
    "data_backup = data.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resample the evaluation window n times by binomial thinning (keep probability P_KEEP),\n",
    "# denoise each realisation and score it against the ground truth.\n",
    "n = 10\n",
    "P_KEEP = 0.96\n",
    "START, DEPTH = 3450, 512  # frame window used for evaluation\n",
    "datalist = []\n",
    "output_list = []\n",
    "output_value = []\n",
    "ground_truth = ground_truth_file[START:START + DEPTH].astype(float)\n",
    "for i in range(n):\n",
    "    default_root_dir_b = default_root_dir / f\"psnr096/new_inference_{i}\"\n",
    "    default_root_dir_b.mkdir(exist_ok=True, parents=True)\n",
    "    # Thin only the window: thinning the whole stack and slicing afterwards draws from the\n",
    "    # same distribution for these frames but wastes time and memory on the rest.\n",
    "    # `input` shadows the builtin but is kept because later cells read it.\n",
    "    input = np.random.binomial(data_backup[START:START + DEPTH], P_KEEP).astype(np.float32)\n",
    "    datalist.append(input)\n",
    "    output = gpu_patch_inference(\n",
    "        model,\n",
    "        input,\n",
    "        initial_patch_depth=48,\n",
    "        min_overlap=40,\n",
    "        device=train_config.device_number,\n",
    "    )\n",
    "    imwrite(default_root_dir_b / \"output.tif\", output)\n",
    "    output_list.append(output)\n",
    "    psnr = group_metrics(input, output, ground_truth, default_root_dir_b, device=train_config.device)\n",
    "    output_value.append(psnr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensemble the denoised realisations voxel-wise (sum and median) and score both.\n",
    "# `input` is the last realisation from the resampling loop.\n",
    "output = np.stack(output_list)  # later cells read this stack as `output`\n",
    "output_sum = np.sum(output, axis=0)\n",
    "default_root_dir_b = default_root_dir / \"psnr096/sum\"\n",
    "default_root_dir_b.mkdir(parents=True, exist_ok=True)\n",
    "psnr_sum = group_metrics(input, output_sum, ground_truth, default_root_dir_b, device=train_config.device)\n",
    "output_median = np.median(output, axis=0)\n",
    "default_root_dir_b = default_root_dir / \"psnr096/median\"\n",
    "default_root_dir_b.mkdir(parents=True, exist_ok=True)\n",
    "psnr_median = group_metrics(input, output_median, ground_truth, default_root_dir_b, device=train_config.device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trimmed ensemble: remove the per-voxel largest and smallest realisation from the sum.\n",
    "output_sum_b = (output_sum - output.max(axis=0)) - output.min(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "default_root_dir_b = default_root_dir / \"psnr096/sum_no_outliers\"\n",
    "default_root_dir_b.mkdir(parents=True, exist_ok=True)\n",
    "# Own name, so the plain-sum score (`psnr_sum`) from the ensemble cell is not overwritten.\n",
    "psnr_sum_no_outliers = group_metrics(input, output_sum_b, ground_truth, default_root_dir_b, device=train_config.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explicit target directory instead of whatever `default_root_dir_b` was last set to.\n",
    "sum_no_outliers_dir = default_root_dir / \"psnr096/sum_no_outliers\"\n",
    "sum_no_outliers_dir.mkdir(parents=True, exist_ok=True)\n",
    "imwrite(sum_no_outliers_dir / \"output.tif\", output_sum_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plot_dir = default_root_dir / \"psnr096\"\n",
    "plot_dir.mkdir(parents=True, exist_ok=True)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(output_value)\n",
    "# Every point uses the same keep probability (0.96); x is the resample index, not P.\n",
    "ax.set(title=\"PSNR per resampled input (p = 0.96)\", xlabel=\"Resample index\", ylabel=\"PSNR\")\n",
    "fig.savefig(plot_dir / \"PSNR_vs_P.png\")\n",
    "fig.savefig(plot_dir / \"PSNR_vs_P.svg\", format=\"svg\", dpi=1200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the run directory under which the outputs of this notebook are written.\n",
    "default_root_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from inference import gpu_patch_inference\n",
    "\n",
    "default_root_dir_b = default_root_dir / \"rand_split\"\n",
    "default_root_dir_b.mkdir(parents=True, exist_ok=True)\n",
    "# datalist entries already hold only the 512-frame evaluation window, so slicing them\n",
    "# with [3450:3962] again would give an empty array. A new name also keeps the loaded\n",
    "# `data` stack intact for the cells below, which slice it.\n",
    "split_input = datalist[1]\n",
    "output = gpu_patch_inference(\n",
    "    model,\n",
    "    split_input.astype(np.float32),\n",
    "    initial_patch_depth=48,\n",
    "    min_overlap=40,\n",
    "    device=train_config.device_number,\n",
    ")\n",
    "imwrite(default_root_dir_b / \"input.tif\", split_input)\n",
    "imwrite(default_root_dir_b / \"inference.tif\", output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Denoise every resampled window and sum the results.\n",
    "default_root_dir_b = default_root_dir / \"10rand_split\"\n",
    "default_root_dir_b.mkdir(parents=True, exist_ok=True)\n",
    "outputlist = [\n",
    "    gpu_patch_inference(\n",
    "        model,\n",
    "        split.astype(np.float32),  # already the 512-frame window; do not slice again\n",
    "        initial_patch_depth=48,\n",
    "        min_overlap=40,\n",
    "        device=train_config.device_number,\n",
    "    )\n",
    "    for split in datalist\n",
    "]\n",
    "output = np.stack(outputlist).sum(axis=0)\n",
    "# data_backup is never rebound (unlike `data`), so this is always the original window.\n",
    "imwrite(default_root_dir_b / \"input.tif\", data_backup[3450:3450+512])\n",
    "imwrite(default_root_dir_b / \"inference.tif\", output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from inference import gpu_patch_inference\n",
    "\n",
    "# Denoise frames 3450..3961 of `data` and store input and result in the run directory.\n",
    "window = data[3450:3450+512]\n",
    "output = gpu_patch_inference(\n",
    "    model,\n",
    "    window.astype(np.float32),\n",
    "    initial_patch_depth=48,\n",
    "    min_overlap=40,\n",
    "    device=train_config.device_number,\n",
    ")\n",
    "imwrite(default_root_dir / \"input.tif\", window)\n",
    "imwrite(default_root_dir / \"inference.tif\", output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import group_metrics\n",
    "\n",
    "# Score the latest `output` on the same 512-frame window as the ground truth slice;\n",
    "# using the whole `data` stack here would mismatch the shapes of image and ground truth.\n",
    "input = data[3450:3450+512].astype(float)\n",
    "image = output\n",
    "ground_truth = ground_truth_file[3450:3450+512].astype(float)\n",
    "# NOTE(review): results go to whichever directory `default_root_dir_b` was last set to.\n",
    "group_metrics(input, image, ground_truth, default_root_dir_b, device=train_config.device)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gap3d",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
