{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Root folder whose sub-directories are individual model runs.\n",
    "model_path = Path(\"../models/\")\n",
    "# All run directories, and the subset that already contains an input.tif.\n",
    "dir_list = [run for run in model_path.iterdir() if run.is_dir()]\n",
    "processed_dir_list = [run for run in dir_list if run.joinpath(\"input.tif\").exists()]\n",
    "\n",
    "# The single run used by the inference / evaluation cells below.\n",
    "run_name = \"20240417_095627_0-0.999999_1000x32x256x256_skip=0_l=10_d=7_sf=32_ds=2at10_f=10.0_z=3_g=8_sd=0_b=tri_a=gelu_b=4_e=500_p=32\"\n",
    "default_root_dir = model_path / run_name\n",
    "ckpt_path = default_root_dir / \"version_0/checkpoints/epoch=79-step=20000.ckpt\"\n",
    "input_path = default_root_dir / \"input.tif\"\n",
    "output_path = default_root_dir / \"inference.tif\"\n",
    "\n",
    "# Simulated acquisition and its ground-truth reference stack.\n",
    "data_dir = Path(\"../../tempdata/Simulated_data/\")\n",
    "data_path = data_dir / \"RES_4341_p=0.06250_m=0.05898.tif\"\n",
    "ground_truth_path = data_dir / \"ref.tif\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run gpu_patch_inference on device 1 and save the result into the run directory.\n",
    "from inference import gpu_patch_inference\n",
    "from spadgapmodels import SPADGAP\n",
    "import torch\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from tifffile import imread, imwrite\n",
    "\n",
    "model = SPADGAP.load_from_checkpoint(ckpt_path)\n",
    "input = imread(data_path)\n",
    "\n",
    "# Only the first 512 frames are processed, cast to float32 for the network.\n",
    "first_frames = input[:512].astype(np.float32)\n",
    "output = gpu_patch_inference(model, first_frames, min_overlap=32, device=1)\n",
    "\n",
    "inference_file = default_root_dir / \"inference.tif\"\n",
    "imwrite(inference_file, output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import group_metrics\n",
    "\n",
    "# Evaluate the first 512 frames: raw input vs. inference output vs. reference.\n",
    "n_frames = 512\n",
    "input = input[:n_frames].astype(float)\n",
    "image = output[:n_frames]\n",
    "ground_truth = imread(ground_truth_path)[:n_frames].astype(float)\n",
    "\n",
    "group_metrics(\n",
    "    input,\n",
    "    image,\n",
    "    ground_truth,\n",
    "    default_root_dir,\n",
    "    device=torch.device(\"cuda:0\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total memory of the first CUDA device, in bytes.\n",
    "first_gpu = torch.device(\"cuda:0\")\n",
    "torch.cuda.get_device_properties(first_gpu).total_memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from metrics import ImageMetrics, StackMetrics, StackMetricsGroups\n",
    "import torch\n",
    "import seaborn as sns\n",
    "\n",
    "# Smoke test of the metrics and plotting API on synthetic stacks.\n",
    "# Seed the RNG so the demo numbers and figures are identical on every re-run.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "property_list = [\"mse\", \"psnr\", \"ssim\", \"ncc\", \"brisque\", \"niqe\"]\n",
    "\n",
    "image_size = 256\n",
    "frame_count = 16\n",
    "stack_shape = (frame_count, image_size, image_size)\n",
    "\n",
    "# Low-noise pair: additive Gaussian noise with std 0.5.\n",
    "ref = torch.randn(stack_shape) * 100\n",
    "image = ref + torch.randn(stack_shape) / 2\n",
    "# NOTE(review): this demo passes `property_list=` while the evaluation cell passes\n",
    "# `metric_list=` to StackMetrics; confirm which keyword the class actually accepts.\n",
    "metrics = StackMetrics(image, ref, property_list=property_list)\n",
    "\n",
    "# Higher-noise pair: additive Gaussian noise with std 2.\n",
    "ref2 = torch.randn(stack_shape) * 100\n",
    "image2 = ref2 + torch.randn(stack_shape) * 2\n",
    "metrics2 = StackMetrics(image2, ref2, property_list=property_list)\n",
    "\n",
    "stack_group = StackMetricsGroups([metrics, metrics2], property_list=property_list)\n",
    "\n",
    "stack_group.plot_group_trends()\n",
    "stack_group.plot_group_stats(kind=\"bar\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably removes stale run directories; check utils.clean_directories before running.\n",
    "from utils import clean_directories\n",
    "\n",
    "clean_directories()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from metrics import StackMetrics, StackMetricsGroups\n",
    "import numpy as np\n",
    "import torch\n",
    "from pathlib import Path\n",
    "from tifffile import imread\n",
    "\n",
    "length = 512\n",
    "device = 0\n",
    "\n",
    "\n",
    "def minmax_per_frame(stack):\n",
    "    \"\"\"Return a copy of `stack` with each frame (indexed along axis 0) rescaled to [0, 1].\n",
    "\n",
    "    Frames whose pixels are all equal become all zeros instead of NaN (0 / 0).\n",
    "    \"\"\"\n",
    "    frame_axes = tuple(range(1, stack.ndim))\n",
    "    low = stack.min(axis=frame_axes, keepdims=True)\n",
    "    span = stack.max(axis=frame_axes, keepdims=True) - low\n",
    "    return (stack - low) / np.where(span > 0, span, 1.0)\n",
    "\n",
    "\n",
    "# Vectorised per-frame normalisation; unlike a loop over range(length) it also\n",
    "# works when a file holds fewer than `length` frames.\n",
    "raw_stack = minmax_per_frame(imread(input_path)[:length].astype(float))\n",
    "processed_stack = minmax_per_frame(imread(output_path)[:length].astype(float))\n",
    "ground_truth = minmax_per_frame(imread(ground_truth_path)[:length].astype(float))\n",
    "\n",
    "metric_list = [\"mse\", \"psnr\", \"ssim\"]\n",
    "cuda_device = torch.device(f\"cuda:{device}\")\n",
    "metric1 = StackMetrics(\n",
    "    processed_stack,\n",
    "    ground_truth,\n",
    "    metric_list=metric_list,\n",
    "    device=cuda_device,\n",
    ")\n",
    "metric2 = StackMetrics(\n",
    "    raw_stack,\n",
    "    ground_truth,\n",
    "    metric_list=metric_list,\n",
    "    device=cuda_device,\n",
    ")\n",
    "metric_group = StackMetricsGroups([metric1, metric2], [\"processed\", \"raw\"], metric_list)\n",
    "metric_group.plot_group_stats(\n",
    "    save=True, save_dir=default_root_dir, save_name=\"group_stats\"\n",
    ")\n",
    "metric_group.plot_group_trends(\n",
    "    save=True, save_dir=default_root_dir, save_name=\"group_trends\"\n",
    ")\n",
    "# Written to the run directory; the metadata-summary cell reads processed_stats.csv.\n",
    "metric1.stats_df.to_csv(default_root_dir / \"processed_stats.csv\")\n",
    "metric2.stats_df.to_csv(default_root_dir / \"raw_stats.csv\")\n",
    "metric1.values_df.to_csv(default_root_dir / \"processed_values.csv\")\n",
    "metric2.values_df.to_csv(default_root_dir / \"raw_values.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import yaml\n",
    "import pandas as pd\n",
    "\n",
    "# One row per trained run: flattened metadata.yml plus, when present, the\n",
    "# processed-stack statistics written by the evaluation cell.\n",
    "meta_list = {}\n",
    "for model_dir in processed_dir_list:\n",
    "    meta_dir = model_dir / \"metadata.yml\"\n",
    "    if not meta_dir.exists():\n",
    "        # No metadata means no time stamp to key the row on; skipping avoids merging\n",
    "        # these stats into the previous run's row (or a NameError on the first run).\n",
    "        continue\n",
    "    with open(meta_dir, \"r\") as file:\n",
    "        metadata = yaml.safe_load(file) or {}\n",
    "    # `*_meta` names keep this cell from overwriting `model` / `data` / `output`\n",
    "    # defined by other cells; `or {}` also tolerates sections left empty (null).\n",
    "    data_meta = metadata.pop(\"data\", None) or {}\n",
    "    model_meta = metadata.pop(\"model\", None) or {}\n",
    "    optimizer = model_meta.pop(\"optimizer\", None) or {}\n",
    "    optimizer_config = model_meta.pop(\"optimizer_config\", None) or {}\n",
    "    row = {\n",
    "        \"dir\": str(meta_dir),\n",
    "        **data_meta,\n",
    "        **model_meta,\n",
    "        **optimizer,\n",
    "        **optimizer_config,\n",
    "        **metadata,\n",
    "    }\n",
    "    # Fall back to the directory name if the metadata has no time stamp.\n",
    "    time_stamp = row.pop(\"time_stamp\", model_dir.name)\n",
    "    metric_dir = model_dir / \"processed_stats.csv\"\n",
    "    if metric_dir.exists():\n",
    "        stats = pd.read_csv(metric_dir)\n",
    "        stats_long = pd.melt(\n",
    "            stats, id_vars=[\"Stat\"], var_name=\"metric\", value_name=\"value\"\n",
    "        )\n",
    "        row.update(\n",
    "            {f\"{stat}_{metric}\": value for stat, metric, value in stats_long.values}\n",
    "        )\n",
    "    meta_list[time_stamp] = row\n",
    "\n",
    "# DataFrame.to_csv(path) returns None, so keep the frame and save it separately.\n",
    "metadf = pd.DataFrame.from_dict(meta_list).T\n",
    "metadf.to_csv(model_path / \"metadata_summary.csv\")\n",
    "metadf.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tifffile import imread, imwrite\n",
    "import numpy as np\n",
    "from utils import group_metrics\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "\n",
    "# Evaluate every processed run against the first 512 reference frames.\n",
    "# `model_dir` / `raw_stack` avoid shadowing the builtins `dir` and `input`.\n",
    "ground_truth = imread(ground_truth_path)[:512].astype(float)\n",
    "for model_dir in tqdm(processed_dir_list, file=sys.stdout):\n",
    "    raw_stack = imread(model_dir / \"input.tif\")[:512].astype(float)\n",
    "    image = imread(model_dir / \"inference.tif\")[:512].astype(float)\n",
    "    out = group_metrics(\n",
    "        raw_stack, image, ground_truth, model_dir, device=torch.device(\"cuda:1\")\n",
    "    )\n",
    "    # tqdm.write prints above the progress bar instead of breaking it up.\n",
    "    tqdm.write(f\"{model_dir.stem}: {out}\", file=sys.stdout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from tifffile import imread\n",
    "from utils import group_metrics\n",
    "\n",
    "# Re-evaluate a single run. Imports are repeated so this cell does not depend on\n",
    "# the batch-evaluation cell having run first.\n",
    "run_index = 3\n",
    "assert run_index < len(processed_dir_list), (\n",
    "    f\"only {len(processed_dir_list)} processed runs found\"\n",
    ")\n",
    "selected_dir = processed_dir_list[run_index]\n",
    "ground_truth = imread(ground_truth_path)[:512].astype(float)\n",
    "raw_stack = imread(selected_dir / \"input.tif\")[:512].astype(float)\n",
    "image = imread(selected_dir / \"inference.tif\")[:512].astype(float)\n",
    "group_metrics(raw_stack, image, ground_truth, selected_dir, device=torch.device(\"cuda:1\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "from tifffile import imread, imwrite\n",
    "\n",
    "# Inference stack from a run under ../video_models (normalised by the next cell).\n",
    "# NOTE(review): \"infernece\" is kept verbatim; presumably it matches the file name on disk.\n",
    "path = Path(\n",
    "    \"../video_models/20240422_223901_Gated_pball_1_1000x32x256x256_skip=0_l=10_d=5_sf=32_ds=2at10_f=10.0_z=3_g=8_sd=0_b=tri_a=gelu_b=4_e=250_p=32/800x500x500_infernece.tif\"\n",
    ")\n",
    "data = imread(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smooth the per-frame mean intensity: divide each frame by its own mean, then\n",
    "# rescale by an exponentially weighted moving average of those means.\n",
    "ewma_span = 400\n",
    "data_means = data.mean(axis=(1, 2))\n",
    "# All-zero frames would otherwise become NaN (0 / 0).\n",
    "safe_means = np.where(data_means != 0, data_means, 1)\n",
    "# New names instead of reassigning `data` / `output` / `output_path`: re-running the\n",
    "# cell no longer re-normalises already-normalised frames, and the configuration\n",
    "# cell's `output_path` (read by the evaluation cell) is left intact.\n",
    "data_normalized = data / safe_means[:, None, None]\n",
    "ewma = pd.Series(data_means).ewm(span=ewma_span).mean().values\n",
    "output_normalized = data_normalized * ewma[:, None, None]\n",
    "normalized_path = path.parent / \"800x500x500_infernece_normalzied.tif\"\n",
    "imwrite(normalized_path, output_normalized)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labelled plot of the smoothed per-frame mean; `;` suppresses the repr output.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ewma)\n",
    "ax.set(\n",
    "    title=\"EWMA of per-frame mean intensity\",\n",
    "    xlabel=\"frame\",\n",
    "    ylabel=\"mean intensity (EWMA)\",\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from spadio import SPADFolder, SPADData  # noqa\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Dedicated names so this cell does not overwrite `data_dir` / `data_path`\n",
    "# (the simulated dataset) from the configuration cell.\n",
    "acq_root = Path(\"../../tempdata/\")\n",
    "acq_path = acq_root / \"acq00005\"\n",
    "n_frames = 100\n",
    "\n",
    "try:\n",
    "    input_folder = SPADFolder(acq_path)\n",
    "except FileNotFoundError as err:\n",
    "    # Keep the original traceback and report which folder was missing.\n",
    "    raise FileNotFoundError(f\"Input folder does not exist: {acq_path}\") from err\n",
    "spad_stack = input_folder.spadstack[:n_frames]\n",
    "frames = spad_stack.stack\n",
    "# Mean (not sum) per frame, matching the `data_means.png` output name.\n",
    "frame_means = np.mean(frames, axis=(1, 2))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(frame_means)\n",
    "ax.set(title=f\"Per-frame mean ({acq_path.name})\", xlabel=\"frame\", ylabel=\"mean pixel value\")\n",
    "fig.savefig(acq_path / \"data_means.png\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spadgap",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
