{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2ad058da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.8.0+cu128\n",
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import scipy.io\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import copy\n",
    "import datetime\n",
    "import numpy as np\n",
    "import random\n",
    "import json\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.nn import functional as F\n",
    "\n",
    "print(torch.__version__)\n",
    "\n",
    "# Run on the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)\n",
    "\n",
    "# project-local packages\n",
    "import dataset.aug_index as ai  # data augmentation and reconstruction metrics (psnr, r2_score, sndr, nrmse)\n",
    "import NSC_module.modules_encoder as enc  # encoder (NeuralSignalCodec)\n",
    "import NSC_module.modules_decoder as dec  # decoder (Base_decoder)\n",
    "import NSC_module.custom_loss as c_loss  # window-energy losses used by loss choices 2-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c04aa8c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightning.pytorch import Trainer, seed_everything # used for seed setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b39a4c76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset selection: set `ds` to one of the keys below (wave file, neo file).\n",
    "# `ds` is also used as a prefix for the saved model and loss files.\n",
    "DATASET_FILES = {\n",
    "    'D1N2': ('./dataset/D1N2_wave.npy', './dataset/D1N2_neo.npy'),\n",
    "    'D2N2': ('./dataset/D2N2_wave.npy', './dataset/D2N2_neo.npy'),\n",
    "    'NPT1': ('./dataset/NP_T1_wave.npy', './dataset/NP_T1_neo.npy'),\n",
    "}\n",
    "ds = \"D1N2\"\n",
    "\n",
    "wave_path, neo_path = DATASET_FILES[ds]\n",
    "X = np.load(wave_path)\n",
    "y = np.load(neo_path)\n",
    "# X and y are split together later, so they must pair up sample-for-sample.\n",
    "assert len(X) == len(y), f\"{ds}: wave/neo sample counts differ ({len(X)} vs {len(y)})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e46e0de2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NEAC(nn.Module):\n",
    "    \"\"\"Encoder/decoder pair used in the loss-function ablation.\n",
    "\n",
    "    `enc.NeuralSignalCodec` encodes the signal `x` together with the auxiliary `x_neo` input\n",
    "    into a latent code of width `latent_dim`; `dec.Base_decoder` reconstructs the signal from it.\n",
    "    `require_lr_window` / `require_quant` are ablation switches forwarded to the encoder.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, latent_dim=4, num_windows=3, require_lr_window=True, require_quant=True):\n",
    "        super().__init__()\n",
    "        # Ablation switches are passed to the encoder under its own keyword names.\n",
    "        self.encoder = enc.NeuralSignalCodec(\n",
    "            latent_dim=latent_dim,\n",
    "            num_windows=num_windows,\n",
    "            required_quant=require_quant,\n",
    "            required_lr_window=require_lr_window,\n",
    "        )\n",
    "        self.decoder = dec.Base_decoder(latent_dim=latent_dim)\n",
    "\n",
    "    def forward(self, x, x_neo, quant_force=None):\n",
    "        \"\"\"Return `(latent, reconstruction)`; the latent is cast to float before decoding.\"\"\"\n",
    "        latent = self.encoder(x, x_neo, quant_force)\n",
    "        reconstruction = self.decoder(latent.float())\n",
    "        return latent, reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e6e4575d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "Seed set to 2\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 3\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 4\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 5\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 1\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSNR :  18.496363 / 18.37494\n",
      "R2 :  0.34553713 / 0.4577774\n",
      "NRMSE :  0.12610164 / 0.13013941\n",
      "SNDR :  3.318254 / 4.8517823\n",
      "seed  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 2\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 3\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 4\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 5\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSNR :  17.429623 / 17.626965\n",
      "R2 :  0.09551958 / 0.22525072\n",
      "NRMSE :  0.14394683 / 0.14257644\n",
      "SNDR :  2.2515154 / 4.103808\n",
      "seed  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "Seed set to 2\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 3\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 4\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 5\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 1\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSNR :  17.897932 / 19.3832\n",
      "R2 :  -5.621179 / 0.6488017\n",
      "NRMSE :  0.18399289 / 0.1117212\n",
      "SNDR :  2.719824 / 5.8600416\n",
      "seed  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 2\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 3\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 4\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 5\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 1\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PSNR :  17.959738 / 19.477411\n",
      "R2 :  -0.09175882 / 0.65669674\n",
      "NRMSE :  0.14183117 / 0.11058762\n",
      "SNDR :  2.7816293 / 5.954253\n",
      "seed  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 2\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 3\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 4\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 5\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  5\n",
      "PSNR :  17.439709 / 19.433949\n",
      "R2 :  -0.15131316 / 0.651242\n",
      "NRMSE :  0.15083238 / 0.11119658\n",
      "SNDR :  2.2615998 / 5.9107904\n"
     ]
    }
   ],
   "source": [
    "# training part\n",
    "# Loss-function ablation: for every loss `choice` and every seed, train the model,\n",
    "# keep the epoch with the best validation PSNR on the central window, then report\n",
    "# per-sample test metrics averaged over the test set and aggregated over seeds.\n",
    "#   choice 1: MSE\n",
    "#   choice 2: c_loss.window_energy_direct_loss\n",
    "#   choice 3/4/5: c_loss.window_energy_dist_loss with laplace / gaussian / cauchy\n",
    "\n",
    "cr = 32  # latent_dim = 128 / cr\n",
    "seed_list = [1, 2, 3, 4, 5]\n",
    "train_epochs = 100\n",
    "bs = 64         # batch size\n",
    "sr = 20000      # sampling_rate passed to window_energy_dist_loss\n",
    "clip_norm = 10  # max gradient norm for clip_grad_norm_\n",
    "mid = slice(32, 96)  # central window used for the mid metrics and for model selection\n",
    "\n",
    "dist_by_choice = {3: 'laplace', 4: 'gaussian', 5: 'cauchy'}\n",
    "# NOTE(review): metrics are called as fn(reconstruction, target); confirm this matches\n",
    "# the ai.* signatures (r2 in particular is not symmetric).\n",
    "metric_fns = {'psnr': ai.psnr, 'r2': ai.r2_score, 'sndr': ai.sndr, 'nrmse': ai.nrmse}\n",
    "\n",
    "model_dir = './ablation_model_loss/'\n",
    "loss_dir = './ablation_loss/'\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "os.makedirs(loss_dir, exist_ok=True)\n",
    "\n",
    "\n",
    "def compute_loss(choice, outputs, inputs, inputs_neo):\n",
    "    \"\"\"Return the ablation loss selected by `choice` for one batch; raise on an unknown choice.\"\"\"\n",
    "    if choice == 1:\n",
    "        return F.mse_loss(outputs, inputs)\n",
    "    if choice == 2:\n",
    "        return c_loss.window_energy_direct_loss(outputs, inputs, inputs_neo, 'mse')\n",
    "    if choice in dist_by_choice:\n",
    "        return c_loss.window_energy_dist_loss(outputs, inputs, inputs_neo, dist_by_choice[choice], 'mse', sampling_rate=sr)\n",
    "    raise ValueError(f'unknown loss choice: {choice}')\n",
    "\n",
    "\n",
    "def make_loader(xs, ys, shuffle):\n",
    "    \"\"\"Wrap paired arrays as float32 tensors in a DataLoader with batch size `bs`.\"\"\"\n",
    "    dataset = TensorDataset(torch.tensor(xs, dtype=torch.float32), torch.tensor(ys, dtype=torch.float32))\n",
    "    return DataLoader(dataset, batch_size=bs, shuffle=shuffle)\n",
    "\n",
    "\n",
    "results = {}\n",
    "\n",
    "for choice in range(1, 6):\n",
    "    results[f'{choice}'] = {}\n",
    "    # One list of per-seed test means per metric, ordered psnr, r2, sndr, nrmse, psnr_mid, ...\n",
    "    seed_means = {name + suffix: [] for suffix in ('', '_mid') for name in metric_fns}\n",
    "    all_train_loss = []\n",
    "    all_valid_loss = []\n",
    "\n",
    "    for seed in seed_list:\n",
    "        train_loss_list = []\n",
    "        valid_loss_list = []\n",
    "\n",
    "        print('seed ', seed)\n",
    "        seed_everything(seed, workers=True)\n",
    "\n",
    "        model = NEAC(latent_dim=int(128/cr), num_windows=3, require_lr_window=False, require_quant=False).to(device)\n",
    "\n",
    "        # NOTE(review): `trainer` is never used below; it appears to be kept only for the\n",
    "        # side effects of Trainer(deterministic=True). Confirm before removing.\n",
    "        trainer = Trainer(deterministic=True)\n",
    "        optimizer = optim.AdamW([\n",
    "            {'params': [model.encoder.alpha, model.encoder.beta, model.encoder.gamma], 'lr': 1e-3},\n",
    "            {'params': model.encoder.boundary_params, 'lr': 1e-3, 'weight_decay': 1e-4},\n",
    "            {'params': model.decoder.parameters(), 'lr': 1e-3, 'weight_decay': 1e-4},\n",
    "        ])\n",
    "\n",
    "        # 70/10/20 train/valid/test split, re-drawn for each seed after seeding.\n",
    "        # Class labels are not used by this ablation, so only X and y are split\n",
    "        # (train_test_split draws the same indices regardless of how many arrays it gets).\n",
    "        X_tmp, X_test, y_tmp, y_test = train_test_split(X, y, test_size=0.2)\n",
    "        X_train, X_valid, y_train, y_valid = train_test_split(X_tmp, y_tmp, test_size=0.125)\n",
    "\n",
    "        train_loader = make_loader(X_train, y_train, shuffle=True)\n",
    "        valid_loader = make_loader(X_valid, y_valid, shuffle=False)\n",
    "        test_loader = make_loader(X_test, y_test, shuffle=False)\n",
    "\n",
    "        best_psnr = -np.inf\n",
    "        # Start from this seed's initial weights so a stale state from an earlier\n",
    "        # seed/choice can never be saved under this run's name.\n",
    "        best_model_state = copy.deepcopy(model.state_dict())\n",
    "\n",
    "        for epoch in range(train_epochs):\n",
    "            model.train()\n",
    "            running_loss = 0.0\n",
    "            for inputs, inputs_neo in train_loader:\n",
    "                inputs = inputs.to(device)\n",
    "                inputs_neo = inputs_neo.to(device)\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "                _, outputs = model(inputs, inputs_neo)\n",
    "                loss = compute_loss(choice, outputs, inputs, inputs_neo)\n",
    "                loss.backward()\n",
    "                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)\n",
    "                optimizer.step()\n",
    "                running_loss += loss.item()\n",
    "            train_loss_list.append(running_loss / len(train_loader))\n",
    "\n",
    "            # Validation: mean loss per batch and mean per-sample PSNR on the central window.\n",
    "            model.eval()\n",
    "            valid_loss = 0.0\n",
    "            psnr_mid_sum = 0\n",
    "            n_valid = 0\n",
    "            with torch.no_grad():\n",
    "                for inputs, inputs_neo in valid_loader:\n",
    "                    inputs = inputs.to(device)\n",
    "                    inputs_neo = inputs_neo.to(device)\n",
    "                    _, outputs = model(inputs, inputs_neo)\n",
    "                    valid_loss += compute_loss(choice, outputs, inputs, inputs_neo).item()\n",
    "                    for recon, target in zip(outputs.cpu().numpy(), inputs.cpu().numpy()):\n",
    "                        psnr_mid_sum += ai.psnr(recon[mid], target[mid])\n",
    "                        n_valid += 1\n",
    "            valid_loss_list.append(valid_loss / len(valid_loader))\n",
    "\n",
    "            psnr_mid_valv = psnr_mid_sum / n_valid\n",
    "            if psnr_mid_valv > best_psnr:\n",
    "                best_psnr = psnr_mid_valv\n",
    "                best_model_state = copy.deepcopy(model.state_dict())\n",
    "\n",
    "        save_name = model_dir + ds + 'ablation_loss_' + str(choice) + '_' + str(seed) + '.pt'\n",
    "        torch.save(best_model_state, save_name)\n",
    "        model.load_state_dict(torch.load(save_name))\n",
    "        model.eval()\n",
    "\n",
    "        # Test: collect per-sample reconstructions and targets as numpy arrays.\n",
    "        recons, targets = [], []\n",
    "        with torch.no_grad():\n",
    "            for inputs, inputs_neo in test_loader:\n",
    "                _, outputs = model(inputs.to(device), inputs_neo.to(device))\n",
    "                recons.extend(outputs.cpu().numpy())\n",
    "                targets.extend(inputs.numpy())\n",
    "\n",
    "        for name, fn in metric_fns.items():\n",
    "            seed_means[name].append(np.array([fn(o, t) for o, t in zip(recons, targets)]).mean())\n",
    "            seed_means[name + '_mid'].append(np.array([fn(o[mid], t[mid]) for o, t in zip(recons, targets)]).mean())\n",
    "\n",
    "        all_train_loss.append(train_loss_list)\n",
    "        all_valid_loss.append(valid_loss_list)\n",
    "\n",
    "    np.save(loss_dir + str(choice) + ds + '_train_loss.npy', np.array(all_train_loss))\n",
    "    np.save(loss_dir + str(choice) + ds + '_valid_loss.npy', np.array(all_valid_loss))\n",
    "\n",
    "    # Aggregate over seeds; keys are written as all *_mean, then all *_std, then all *_best.\n",
    "    per_seed = {key: np.array(vals) for key, vals in seed_means.items()}\n",
    "    res = results[f'{choice}']\n",
    "    for key, vals in per_seed.items():\n",
    "        res[key + '_mean'] = vals.mean()\n",
    "    for key, vals in per_seed.items():\n",
    "        res[key + '_std'] = vals.std()\n",
    "    for key, vals in per_seed.items():\n",
    "        # Lower NRMSE is better; higher is better for the other metrics.\n",
    "        res[key + '_best'] = vals.min() if key.startswith('nrmse') else vals.max()\n",
    "\n",
    "    print('PSNR : ', per_seed['psnr'].mean(), '/', per_seed['psnr_mid'].mean())\n",
    "    print('R2 : ', per_seed['r2'].mean(), '/', per_seed['r2_mid'].mean())\n",
    "    print('NRMSE : ', per_seed['nrmse'].mean(), '/', per_seed['nrmse_mid'].mean())\n",
    "    print('SNDR : ', per_seed['sndr'].mean(), '/', per_seed['sndr_mid'].mean())\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5e0798f9-0f5f-4f4d-a713-a1e6de54a52d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'1': {'psnr_mean': np.float32(18.496363), 'r2_mean': np.float32(0.34553713), 'sndr_mean': np.float32(3.318254), 'nrmse_mean': np.float32(0.12610164), 'psnr_mid_mean': np.float32(18.37494), 'r2_mid_mean': np.float32(0.4577774), 'sndr_mid_mean': np.float32(4.8517823), 'nrmse_mid_mean': np.float32(0.13013941), 'psnr_std': np.float32(0.36408308), 'r2_std': np.float32(0.08188677), 'sndr_std': np.float32(0.34952167), 'nrmse_std': np.float32(0.005974775), 'psnr_mid_std': np.float32(0.5410891), 'r2_mid_std': np.float32(0.14979976), 'sndr_mid_std': np.float32(0.5148348), 'nrmse_mid_std': np.float32(0.011006453), 'psnr_best': np.float32(19.165808), 'r2_best': np.float32(0.47850603), 'sndr_best': np.float32(3.957757), 'nrmse_best': np.float32(0.11511946), 'psnr_mid_best': np.float32(19.238148), 'r2_mid_best': np.float32(0.617986), 'sndr_mid_best': np.float32(5.641719), 'nrmse_mid_best': np.float32(0.1142839)}, '2': {'psnr_mean': np.float32(17.429623), 'r2_mean': np.float32(0.09551958), 'sndr_mean': np.float32(2.2515154), 'nrmse_mean': np.float32(0.14394683), 'psnr_mid_mean': np.float32(17.626965), 'r2_mid_mean': np.float32(0.22525072), 'sndr_mid_mean': np.float32(4.103808), 'nrmse_mid_mean': np.float32(0.14257644), 'psnr_std': np.float32(0.823859), 'r2_std': np.float32(0.21129094), 'sndr_std': np.float32(0.8089826), 'nrmse_std': np.float32(0.014990251), 'psnr_mid_std': np.float32(0.22617407), 'r2_mid_std': np.float32(0.2607651), 'sndr_mid_std': np.float32(0.23761679), 'nrmse_mid_std': np.float32(0.008115027), 'psnr_best': np.float32(18.15559), 'r2_best': np.float32(0.3135323), 'sndr_best': np.float32(2.9334764), 'nrmse_best': np.float32(0.13051432), 'psnr_mid_best': np.float32(17.909317), 'r2_mid_best': np.float32(0.46503723), 'sndr_mid_best': np.float32(4.365382), 'nrmse_mid_best': np.float32(0.13606279)}, '3': {'psnr_mean': np.float32(17.897932), 'r2_mean': np.float32(-5.621179), 'sndr_mean': np.float32(2.719824), 'nrmse_mean': np.float32(0.18399289), 
'psnr_mid_mean': np.float32(19.3832), 'r2_mid_mean': np.float32(0.6488017), 'sndr_mid_mean': np.float32(5.8600416), 'nrmse_mid_mean': np.float32(0.1117212), 'psnr_std': np.float32(0.51088643), 'r2_std': np.float32(5.5136094), 'sndr_std': np.float32(0.5065787), 'nrmse_std': np.float32(0.04678374), 'psnr_mid_std': np.float32(0.11089778), 'r2_mid_std': np.float32(0.009165793), 'sndr_mid_std': np.float32(0.08589815), 'nrmse_mid_std': np.float32(0.0015585191), 'psnr_best': np.float32(18.576618), 'r2_best': np.float32(0.40238628), 'sndr_best': np.float32(3.4251714), 'nrmse_best': np.float32(0.12284344), 'psnr_mid_best': np.float32(19.51599), 'r2_mid_best': np.float32(0.6616379), 'sndr_mid_best': np.float32(5.9720545), 'nrmse_mid_best': np.float32(0.109995686)}, '4': {'psnr_mean': np.float32(17.959738), 'r2_mean': np.float32(-0.09175882), 'sndr_mean': np.float32(2.7816293), 'nrmse_mean': np.float32(0.14183117), 'psnr_mid_mean': np.float32(19.477411), 'r2_mid_mean': np.float32(0.65669674), 'sndr_mid_mean': np.float32(5.954253), 'nrmse_mid_mean': np.float32(0.11058762), 'psnr_std': np.float32(1.3714604), 'r2_std': np.float32(0.4860828), 'sndr_std': np.float32(1.3550396), 'nrmse_std': np.float32(0.02833418), 'psnr_mid_std': np.float32(0.13720842), 'r2_mid_std': np.float32(0.010349659), 'sndr_mid_std': np.float32(0.1217582), 'nrmse_mid_std': np.float32(0.0018031411), 'psnr_best': np.float32(18.884129), 'r2_best': np.float32(0.48672605), 'sndr_best': np.float32(3.6760771), 'nrmse_best': np.float32(0.11661845), 'psnr_mid_best': np.float32(19.62638), 'r2_mid_best': np.float32(0.6693505), 'sndr_mid_best': np.float32(6.0824475), 'nrmse_mid_best': np.float32(0.10850684)}, '5': {'psnr_mean': np.float32(17.439709), 'r2_mean': np.float32(-0.15131316), 'sndr_mean': np.float32(2.2615998), 'nrmse_mean': np.float32(0.15083238), 'psnr_mid_mean': np.float32(19.433949), 'r2_mid_mean': np.float32(0.651242), 'sndr_mid_mean': np.float32(5.9107904), 'nrmse_mid_mean': np.float32(0.11119658), 
'psnr_std': np.float32(1.6330217), 'r2_std': np.float32(0.38947022), 'sndr_std': np.float32(1.6294477), 'nrmse_std': np.float32(0.02764235), 'psnr_mid_std': np.float32(0.0857986), 'r2_mid_std': np.float32(0.0072216196), 'sndr_mid_std': np.float32(0.09450473), 'nrmse_mid_std': np.float32(0.0011379601), 'psnr_best': np.float32(18.678555), 'r2_best': np.float32(0.30740955), 'sndr_best': np.float32(3.4564426), 'nrmse_best': np.float32(0.125074), 'psnr_mid_best': np.float32(19.530098), 'r2_mid_best': np.float32(0.6557467), 'sndr_mid_best': np.float32(5.994422), 'nrmse_mid_best': np.float32(0.110066004)}}\n"
     ]
    }
   ],
   "source": [
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "da359930-8a88-4eeb-9bff-a3d2739bf071",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-setting ablation metrics collected in `results` for dataset `ds`.\n",
    "# Create the output directory first so a fresh checkout does not fail with\n",
    "# FileNotFoundError when ./ablation_result/ is missing.\n",
    "out_dir = \"./ablation_result\"\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "# Human-readable dump (same path and content as before).\n",
    "with open(out_dir + \"/abl_loss\" + ds + \".txt\", \"w\") as f:\n",
    "    f.write(str(results))\n",
    "\n",
    "# Machine-readable copy: the printed metric values are np.float32, whose repr\n",
    "# (`np.float32(...)`) cannot be parsed back by json / ast.literal_eval, so\n",
    "# convert any non-JSON-native scalar to a plain float via `default=float`.\n",
    "with open(out_dir + \"/abl_loss\" + ds + \".json\", \"w\") as f:\n",
    "    json.dump(results, f, indent=2, default=float)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
