{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6133ed20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import copy\n",
    "import numpy as np\n",
    "\n",
    "import torch.optim as optim\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else : \n",
    "    device = torch.device('cpu')\n",
    "print(device)\n",
    "\n",
    "from torch.nn import functional as F\n",
    "\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import os\n",
    "import pprint\n",
    "import random\n",
    "import dataset.aug_index as ai # data augmentation and some index for reconstruction\n",
    "import NSC_module.modules_decoder as dec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6618ef63",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightning.pytorch import Trainer, seed_everything # used for seed setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "148405c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def quantize_tensor(x, num_frac_bits, min_val, max_val):\n",
    "    scale = 2.0 ** (-num_frac_bits)\n",
    "    x_clamped = x.clamp(min_val, max_val)\n",
    "\n",
    "    return torch.round(x_clamped / scale) * scale\n",
    "\n",
    "class QuantLinear(nn.Module):\n",
    "    def __init__(self, in_features, out_features, bias=True):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))\n",
    "        if bias:\n",
    "            self.bias = nn.Parameter(torch.Tensor(out_features))\n",
    "        else:\n",
    "            self.register_parameter('bias', None)\n",
    "        \n",
    "        # Q1.4 4 bits for frac, 1 bit for sign\n",
    "        self.min_val = -1.0\n",
    "        self.max_val = +(1.0 - 2.0 ** -4)\n",
    "        self.reset_parameters()\n",
    "\n",
    "    # init\n",
    "    def reset_parameters(self):\n",
    "        nn.init.uniform_(self.weight, self.min_val, self.max_val)\n",
    "        if self.bias is not None:\n",
    "            nn.init.uniform_(self.bias, self.min_val, self.max_val)\n",
    "            \n",
    "    # actually, the paper use the bias, so here we write two version\n",
    "    def forward(self, x):\n",
    "        w = quantize_tensor(self.weight, num_frac_bits=4, min_val=self.min_val, max_val=self.max_val)\n",
    "        if self.bias is not None:\n",
    "            b = quantize_tensor(self.bias, num_frac_bits=4, min_val=self.min_val, max_val=self.max_val)\n",
    "        else:\n",
    "            b = None\n",
    "        return F.linear(x, w, b)\n",
    "\n",
    "# Activation quantizer. A plain hardtanh saturates at [-1, 1]; the paper uses Q2.8, so the\n",
    "# clamp range is widened to [-2, 2 - 2**-8] and outputs are rounded to a 2**-8 grid\n",
    "# (1024 levels, i.e. a 10-bit signed code).\n",
    "class QuantHardtanh(nn.Module):\n",
    "    \"\"\"Saturate to the Q2.8 range [-2, 2 - 2**-8] and round to multiples of 2**-8.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        frac_bits = 8\n",
    "        self.num_frac_bits = frac_bits\n",
    "        self.min_val = -2.0\n",
    "        self.max_val = 2.0 - 2.0 ** -frac_bits\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``x`` clamped to [min_val, max_val] and rounded to num_frac_bits bits.\"\"\"\n",
    "        lo, hi = self.min_val, self.max_val\n",
    "        return quantize_tensor(x, self.num_frac_bits, lo, hi)\n",
    "\n",
    "class AE_Q1p4(nn.Module):\n",
    "    \"\"\"Autoencoder: quantized linear encoder followed by ``dec.Base_decoder``.\n",
    "\n",
    "    encoder = QuantLinear(input_dim -> latent_dim) then QuantHardtanh.\n",
    "    ``forward(x)`` returns ``(z, x_hat)``: the quantized latent code and the decoder output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim=128, latent_dim=4):\n",
    "        super().__init__()\n",
    "        encoder_layers = [QuantLinear(input_dim, latent_dim), QuantHardtanh()]\n",
    "        self.encoder = nn.Sequential(*encoder_layers)\n",
    "        self.decoder = dec.Base_decoder(latent_dim=latent_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = self.encoder(x)\n",
    "        x_hat = self.decoder(z)\n",
    "        return z, x_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8476ab4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20170622\n",
      "20170623\n",
      "20170629\n",
      "seed  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "Seed set to 2\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 3\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 4\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 5\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 1\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d533101\n",
      "d561106\n",
      "seed  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 2\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 3\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 4\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 5\n",
      "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed  5\n"
     ]
    }
   ],
   "source": [
    "# training part\n",
    "# For every dataset group in `choices`: train one AE_Q1p4 per seed, keep the epoch with the\n",
    "# best mid-window validation PSNR, evaluate that checkpoint on the test split, and\n",
    "# aggregate the per-seed test means into `results` (mean / std / best over seeds).\n",
    "\n",
    "cr = 32  # compression ratio: latent_dim = 128 / cr\n",
    "seed_list = [1, 2, 3, 4, 5]\n",
    "file_list = [['20170622', '20170623', '20170629'], ['d533101', 'd561106'], ['NP_T1'], ['NP_T2'], ['NP_T3'], ['NP_T4'], ['NP_T5']]\n",
    "choices = range(0, 2)  # only the first two groups of file_list are run\n",
    "\n",
    "train_epochs = 100\n",
    "clip_norm = 10  # max total (L2) norm of the gradients per step\n",
    "mid = slice(32, 96)  # '*_mid' metrics use samples 32:96 of each window\n",
    "ckpt_dir = './compare_pt'\n",
    "\n",
    "# Per-sample metric functions; result keys are these names, then the same names + '_mid'.\n",
    "metric_fns = {'psnr': ai.psnr, 'r2': ai.r2_score, 'sndr': ai.sndr, 'nrmse': ai.nrmse}\n",
    "metric_keys = [name + suffix for suffix in ('', '_mid') for name in metric_fns]\n",
    "\n",
    "os.makedirs(ckpt_dir, exist_ok=True)  # torch.save does not create missing directories\n",
    "\n",
    "\n",
    "def make_loader(features, targets, batch_size, shuffle):\n",
    "    \"\"\"Wrap two numpy arrays as float32 tensors in a TensorDataset and DataLoader.\"\"\"\n",
    "    dataset = TensorDataset(torch.tensor(features, dtype=torch.float32),\n",
    "                            torch.tensor(targets, dtype=torch.float32))\n",
    "    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "\n",
    "def run_model(model, loader):\n",
    "    \"\"\"Run `model` in eval mode over `loader` without gradients.\n",
    "\n",
    "    Returns three lists holding one numpy array per sample: inputs, reconstructions\n",
    "    and latent codes. The second tensor of each batch (from *_neo.npy) is unused.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    targets, recons, latents = [], [], []\n",
    "    with torch.no_grad():\n",
    "        for inputs, _ in loader:\n",
    "            inputs = inputs.to(device)\n",
    "            z, outputs = model(inputs)\n",
    "            targets.extend(inputs.cpu().numpy())\n",
    "            recons.extend(outputs.cpu().numpy())\n",
    "            latents.extend(z.cpu().numpy())\n",
    "    return targets, recons, latents\n",
    "\n",
    "\n",
    "def mean_mid_psnr(recons, targets):\n",
    "    \"\"\"Average per-sample mid-window PSNR (the validation model-selection score).\"\"\"\n",
    "    return sum(ai.psnr(r[mid], t[mid]) for r, t in zip(recons, targets)) / len(recons)\n",
    "\n",
    "\n",
    "def mean_test_metrics(recons, targets):\n",
    "    \"\"\"Return {metric_key: mean over samples} for every full-window and mid-window metric.\"\"\"\n",
    "    scores = {}\n",
    "    for name, fn in metric_fns.items():\n",
    "        scores[name] = np.array([fn(r, t) for r, t in zip(recons, targets)]).mean()\n",
    "        scores[name + '_mid'] = np.array([fn(r[mid], t[mid]) for r, t in zip(recons, targets)]).mean()\n",
    "    return scores\n",
    "\n",
    "\n",
    "results = {}\n",
    "\n",
    "for choice in choices:\n",
    "    # Load and stack every recording of this group. CL[k] is the index (within `files`)\n",
    "    # of the file that sample k came from.\n",
    "    files = file_list[choice]\n",
    "    X_list, y_list, cl_list = [], [], []\n",
    "    for file_idx, fname in enumerate(files):\n",
    "        print(fname)\n",
    "        tX = np.load('./dataset/' + fname + '_wave.npy')\n",
    "        ty = np.load('./dataset/' + fname + '_neo.npy')\n",
    "        X_list.append(tX)\n",
    "        y_list.append(ty)\n",
    "        cl_list.append(np.full(len(tX), file_idx))\n",
    "    X = np.concatenate(X_list, axis=0)\n",
    "    y = np.concatenate(y_list, axis=0)\n",
    "    CL = np.concatenate(cl_list, axis=0)\n",
    "\n",
    "    # Batch size ~ len(X) / 100 rounded to a power of two, clamped to [16, 256].\n",
    "    bs = int(min(256, max(16, 2 ** np.round(np.log2(len(X) / 100)))))\n",
    "\n",
    "    per_seed = {key: [] for key in metric_keys}  # one test-set mean per seed\n",
    "\n",
    "    for se, seed in enumerate(seed_list):\n",
    "        # Output names use the seed's 0-based index `se`, not the seed value.\n",
    "        save_name = ckpt_dir + '/AE_Q1p4_SEED_' + str(se) + '_CHOICE_' + str(choice) + '.pt'\n",
    "        out_prefix = str(choice) + str(se)\n",
    "        seed_everything(seed, workers=True)\n",
    "\n",
    "        # 70 / 10 / 20 train / valid / test split. No random_state is passed, so the split\n",
    "        # is drawn from numpy's global RNG, which seed_everything has just seeded.\n",
    "        X_tmp, X_test, y_tmp, y_test, _, cl_test = train_test_split(X, y, CL, test_size=0.2)\n",
    "        X_train, X_valid, y_train, y_valid = train_test_split(X_tmp, y_tmp, test_size=0.125)\n",
    "        np.save(out_prefix + 'aeq1p4_cl.npy', np.array(cl_test))\n",
    "\n",
    "        train_loader = make_loader(X_train, y_train, bs, shuffle=True)\n",
    "        valid_loader = make_loader(X_valid, y_valid, bs, shuffle=False)\n",
    "        test_loader = make_loader(X_test, y_test, bs, shuffle=False)\n",
    "\n",
    "        print('seed ', seed)\n",
    "\n",
    "        model = AE_Q1p4(latent_dim=int(128 / cr)).to(device)\n",
    "\n",
    "        # Never used for fitting; built only for the side effect of deterministic=True\n",
    "        # (Lightning then asks PyTorch to use deterministic algorithms).\n",
    "        Trainer(deterministic=True)\n",
    "\n",
    "        # The encoder group falls back to AdamW's default weight_decay (0.01).\n",
    "        optimizer = optim.AdamW([\n",
    "            {'params': model.encoder.parameters(), 'lr': 1e-3},\n",
    "            {'params': model.decoder.parameters(), 'lr': 1e-3, 'weight_decay': 1e-4},\n",
    "        ])\n",
    "\n",
    "        # Start from this seed's initial weights: if the validation score never beats -inf\n",
    "        # (e.g. NaN every epoch), the previous seed's state must not be saved under this name.\n",
    "        best_psnr = -np.inf\n",
    "        best_model_state = copy.deepcopy(model.state_dict())\n",
    "\n",
    "        for epoch in range(train_epochs):\n",
    "            model.train()\n",
    "            for inputs, _ in train_loader:\n",
    "                inputs = inputs.to(device)\n",
    "                optimizer.zero_grad()\n",
    "                _, outputs = model(inputs)\n",
    "                loss = F.mse_loss(outputs, inputs)  # reconstruct the input waveform\n",
    "                loss.backward()\n",
    "                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)\n",
    "                optimizer.step()\n",
    "\n",
    "            valid_targets, valid_recons, _ = run_model(model, valid_loader)\n",
    "            psnr_mid_valv = mean_mid_psnr(valid_recons, valid_targets)\n",
    "            if psnr_mid_valv > best_psnr:\n",
    "                best_psnr = psnr_mid_valv\n",
    "                best_model_state = copy.deepcopy(model.state_dict())\n",
    "\n",
    "        torch.save(best_model_state, save_name)\n",
    "        model.load_state_dict(best_model_state)\n",
    "\n",
    "        # Evaluate the selected checkpoint on the test split and keep the raw arrays.\n",
    "        i_list, o_list, z_list = run_model(model, test_loader)\n",
    "        np.save(out_prefix + 'aeq1p4_target.npy', np.array(i_list))\n",
    "        np.save(out_prefix + 'aeq1p4_recon.npy', np.array(o_list))\n",
    "        np.save(out_prefix + 'aeq1p4_latent.npy', np.array(z_list))\n",
    "\n",
    "        for key, value in mean_test_metrics(o_list, i_list).items():\n",
    "            per_seed[key].append(value)\n",
    "\n",
    "    # Aggregate over seeds. Key order matches the original report: all means, then all\n",
    "    # stds (population std, ddof=0), then all bests (min for nrmse, max otherwise).\n",
    "    summary = {}\n",
    "    for key in metric_keys:\n",
    "        summary[key + '_mean'] = np.array(per_seed[key]).mean()\n",
    "    for key in metric_keys:\n",
    "        summary[key + '_std'] = np.array(per_seed[key]).std()\n",
    "    for key in metric_keys:\n",
    "        seed_scores = np.array(per_seed[key])\n",
    "        summary[key + '_best'] = seed_scores.min() if key.startswith('nrmse') else seed_scores.max()\n",
    "    results[f'{choice}'] = summary\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d4049196",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'0': {'psnr_mean': np.float32(18.958801), 'r2_mean': np.float32(0.21699426), 'sndr_mean': np.float32(3.6696038), 'nrmse_mean': np.float32(0.13366234), 'psnr_mid_mean': np.float32(17.951262), 'r2_mid_mean': np.float32(0.3286658), 'sndr_mid_mean': np.float32(5.4832735), 'nrmse_mid_mean': np.float32(0.15981326), 'psnr_std': np.float32(1.995826), 'r2_std': np.float32(0.21772249), 'sndr_std': np.float32(1.9788599), 'nrmse_std': np.float32(0.023303902), 'psnr_mid_std': np.float32(3.3542635), 'r2_mid_std': np.float32(0.2972446), 'sndr_mid_std': np.float32(3.3374157), 'nrmse_mid_std': np.float32(0.047455013), 'psnr_best': np.float32(21.122513), 'r2_best': np.float32(0.50031143), 'sndr_best': np.float32(5.8056855), 'nrmse_best': np.float32(0.109426074), 'psnr_mid_best': np.float32(22.627247), 'r2_mid_best': np.float32(0.6998036), 'sndr_mid_best': np.float32(10.131347), 'nrmse_mid_best': np.float32(0.097868234)}, '1': {'psnr_mean': np.float32(16.956146), 'r2_mean': np.float32(-0.030826813), 'sndr_mean': np.float32(0.8999265), 'nrmse_mean': np.float32(0.15397593), 'psnr_mid_mean': np.float32(17.609337), 'r2_mid_mean': np.float32(0.39100814), 'sndr_mid_mean': np.float32(3.6756454), 'nrmse_mid_mean': np.float32(0.14526191), 'psnr_std': np.float32(0.60880476), 'r2_std': np.float32(0.20203784), 'sndr_std': np.float32(0.61917937), 'nrmse_std': np.float32(0.011995264), 'psnr_mid_std': np.float32(1.0024331), 'r2_mid_std': np.float32(0.13220108), 'sndr_mid_std': np.float32(1.0092998), 'nrmse_mid_std': np.float32(0.016806984), 'psnr_best': np.float32(17.521385), 'r2_best': np.float32(0.1546733), 'sndr_best': np.float32(1.4857041), 'nrmse_best': np.float32(0.14193079), 'psnr_mid_best': np.float32(18.852877), 'r2_mid_best': np.float32(0.5531462), 'sndr_mid_best': np.float32(4.937402), 'nrmse_mid_best': np.float32(0.12431763)}}\n"
     ]
    }
   ],
   "source": [
    "# One metric per line; sort_dicts=False keeps the mean / std / best insertion order.\n",
    "pprint.pprint(results, sort_dicts=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "42da5147-a028-419b-a3cb-714cf9f2e9e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the aggregated metrics (dict repr) for later comparison.\n",
    "os.makedirs(\"./compare_result\", exist_ok=True)  # open() does not create directories\n",
    "with open(\"./compare_result/aeq1p4.txt\", \"w\") as f:\n",
    "    f.write(str(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fcdd01f-1db4-42bc-adaf-d6183c6d41f7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a015c70c-a9f8-4e31-9a52-521d74c36f31",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
