{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "35413445",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.8.0+cu128\n",
      "cpu\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import scipy.io\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import copy\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.nn import functional as F\n",
    "\n",
    "print(torch.__version__)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else : \n",
    "    device = torch.device('cpu')\n",
    "print(device)\n",
    "\n",
    "import os\n",
    "import random\n",
    "import dataset.aug_index as ai "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "98cc97a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightning.pytorch import Trainer, seed_everything # used for seed setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "868c91e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CSModel(nn.Module):\n",
    "    def __init__(self, input_dim=128, latent_dim=4, sensing_matrix=None):\n",
    "        super().__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        if sensing_matrix is None:\n",
    "            A = torch.randn(latent_dim, input_dim)\n",
    "            Q, _ = torch.linalg.qr(A.T)\n",
    "            sensing_matrix = Q[:, :latent_dim].T  # [latent_dim, input_dim]\n",
    "\n",
    "        self.register_buffer('Phi', sensing_matrix.float())\n",
    "\n",
    "        # pseudo inverse\n",
    "        Phi_pinv = torch.linalg.pinv(self.Phi)\n",
    "        self.register_buffer('Phi_pinv', Phi_pinv.float())\n",
    "\n",
    "    def encode(self, x):\n",
    "        return torch.matmul(x, self.Phi.T)  # [B, latent_dim]\n",
    "\n",
    "    def decode(self, latent):\n",
    "        return torch.matmul(latent, self.Phi_pinv.T)  # [B, input_dim]\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = self.encode(x)\n",
    "        x = self.decode(z)\n",
    "        return z, x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d5ae9a60",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20170622\n",
      "20170623\n",
      "20170629\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 2\n",
      "Seed set to 3\n",
      "Seed set to 4\n",
      "Seed set to 5\n",
      "Seed set to 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d533101\n",
      "d561106\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 2\n",
      "Seed set to 3\n",
      "Seed set to 4\n",
      "Seed set to 5\n"
     ]
    }
   ],
   "source": [
    "# training part\n",
    "\n",
    "cr = 32\n",
    "seed_list = [1, 2, 3, 4, 5]\n",
    "file_list = [[\"20170622\", \"20170623\", \"20170629\"], [\"d533101\", \"d561106\"]]\n",
    "\n",
    "results = {}\n",
    "\n",
    "for cho in range(0, 2):\n",
    "    \n",
    "    choice = cho\n",
    "\n",
    "    results[f'{choice}'] = {}\n",
    "\n",
    "    psnr_valt = []\n",
    "    r2_valt = []\n",
    "    sndr_valt = []\n",
    "    nrmse_valt = []\n",
    "    psnr_mid_valt = []\n",
    "    r2_mid_valt = []\n",
    "    sndr_mid_valt = []\n",
    "    nrmse_mid_valt = []\n",
    "\n",
    "    files = file_list[choice]\n",
    "    X_list = []\n",
    "    y_list = []\n",
    "    cl_list = []\n",
    "    for i in range(len(files)):\n",
    "        print(files[i])\n",
    "\n",
    "        tX = np.load('./dataset/' + files[i] + '_wave.npy')\n",
    "        ty = np.load('./dataset/' + files[i] + '_neo.npy')\n",
    "\n",
    "        tcl = np.full(len(tX), i) \n",
    "        X_list.append(tX)  # Append the entire array\n",
    "        y_list.append(ty)\n",
    "        cl_list.append(tcl)\n",
    "\n",
    "    X = np.concatenate(X_list, axis=0)\n",
    "    y = np.concatenate(y_list, axis=0)\n",
    "    CL = np.concatenate(cl_list, axis=0)\n",
    "\n",
    "    for se in range (len(seed_list)):\n",
    "\n",
    "        save_name = \"./compare_pt/CS_SEED_\" + str(se) + \"_CHOICE_\" + str(choice) + \".pt\"\n",
    "        psnr_val = []\n",
    "        r2_val = []\n",
    "        sndr_val = []\n",
    "        nrmse_val = []\n",
    "        psnr_mid_val = []\n",
    "        r2_mid_val = []\n",
    "        sndr_mid_val = []\n",
    "        nrmse_mid_val = []\n",
    "        \n",
    "        seed = seed_list[se]\n",
    "        seed_everything(seed, workers=True)\n",
    "\n",
    "        # Datasets\n",
    "        bs = 64\n",
    "\n",
    "        X_train, X_test, y_train, y_test, _, cl_test = train_test_split(X, y, CL, test_size=0.2)\n",
    "\n",
    "        np.save(str(choice) + str(se) + \"cs_cl.npy\", np.array(cl_test))\n",
    "\n",
    "        X_train = np.array(X_train) \n",
    "        X_test  = torch.tensor(X_test , dtype=torch.float32)\n",
    "        y_test  = torch.tensor(y_test , dtype=torch.float32)\n",
    "\n",
    "        test_dataset = TensorDataset(X_test, y_test)\n",
    "        test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)\n",
    "        best_r2 = -np.inf\n",
    "\n",
    "        model = CSModel(input_dim=128, latent_dim=4)\n",
    "        model.to(device)\n",
    "        model.eval()\n",
    "\n",
    "        save_name = 'CS_' + str(choice) + '_' + str(seed) + '.pt'\n",
    "        torch.save(copy.deepcopy(model.state_dict()), save_name)\n",
    "\n",
    "        all_recons = []\n",
    "        all_targets = []\n",
    "\n",
    "        o_list = []\n",
    "        i_list = []\n",
    "        z_list = []\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for (inputs, _) in test_loader:\n",
    "                inputs = inputs.to(device)\n",
    "                z, outputs = model(inputs)\n",
    "                for i in inputs:\n",
    "                    i_list.append(i.cpu().numpy())\n",
    "                for o in outputs:\n",
    "                    o_list.append(o.cpu().numpy())\n",
    "                for z in z:\n",
    "                    z_list.append(z.cpu().numpy())\n",
    "\n",
    "        np.save(str(choice) + str(se) + \"cs_target.npy\", np.array(i_list))\n",
    "        np.save(str(choice) + str(se) + \"cs_recon.npy\", np.array(o_list))\n",
    "        np.save(str(choice) + str(se) + \"cs_latent.npy\", np.array(z_list))\n",
    "\n",
    "        for i in range(len(i_list)):\n",
    "            psnr_val.append(ai.psnr(o_list[i], i_list[i]))\n",
    "            r2_val.append(ai.r2_score(o_list[i], i_list[i]))\n",
    "            sndr_val.append(ai.sndr(o_list[i], i_list[i]))\n",
    "            nrmse_val.append(ai.nrmse(o_list[i], i_list[i]))\n",
    "            \n",
    "            psnr_mid_val.append(ai.psnr(o_list[i][32:96], i_list[i][32:96]))\n",
    "            r2_mid_val.append(ai.r2_score(o_list[i][32:96], i_list[i][32:96]))\n",
    "            sndr_mid_val.append(ai.sndr(o_list[i][32:96], i_list[i][32:96]))\n",
    "            nrmse_mid_val.append(ai.nrmse(o_list[i][32:96], i_list[i][32:96]))\n",
    "        \n",
    "        psnr_val = np.array(psnr_val)\n",
    "        r2_val = np.array(r2_val)\n",
    "        sndr_val = np.array(sndr_val)\n",
    "        nrmse_val = np.array(nrmse_val)\n",
    "        psnr_mid_val = np.array(psnr_mid_val)\n",
    "        r2_mid_val = np.array(r2_mid_val)\n",
    "        sndr_mid_val = np.array(sndr_mid_val)\n",
    "        nrmse_mid_val = np.array(nrmse_mid_val)\n",
    "\n",
    "        psnr_valt.append(psnr_val.mean())\n",
    "        r2_valt.append(r2_val.mean())\n",
    "        sndr_valt.append(sndr_val.mean())\n",
    "        nrmse_valt.append(nrmse_val.mean())\n",
    "        psnr_mid_valt.append(psnr_mid_val.mean())\n",
    "        r2_mid_valt.append(r2_mid_val.mean())\n",
    "        sndr_mid_valt.append(sndr_mid_val.mean())\n",
    "        nrmse_mid_valt.append(nrmse_mid_val.mean())\n",
    "\n",
    "    psnr_valt = np.array(psnr_valt)\n",
    "    r2_valt = np.array(r2_valt)\n",
    "    sndr_valt = np.array(sndr_valt)\n",
    "    nrmse_valt = np.array(nrmse_valt)\n",
    "    psnr_mid_valt = np.array(psnr_mid_valt)\n",
    "    r2_mid_valt = np.array(r2_mid_valt)\n",
    "    sndr_mid_valt = np.array(sndr_mid_valt)\n",
    "    nrmse_mid_valt = np.array(nrmse_mid_valt)\n",
    "\n",
    "    results[f'{choice}']['psnr_mean'] = psnr_valt.mean()\n",
    "    results[f'{choice}']['r2_mean'] = r2_valt.mean()\n",
    "    results[f'{choice}']['sndr_mean'] = sndr_valt.mean()\n",
    "    results[f'{choice}']['nrmse_mean'] = nrmse_valt.mean()\n",
    "    results[f'{choice}']['psnr_mid_mean'] = psnr_mid_valt.mean()\n",
    "    results[f'{choice}']['r2_mid_mean'] = r2_mid_valt.mean()\n",
    "    results[f'{choice}']['sndr_mid_mean'] = sndr_mid_valt.mean()\n",
    "    results[f'{choice}']['nrmse_mid_mean'] = nrmse_mid_valt.mean()\n",
    "    \n",
    "    results[f'{choice}']['psnr_std'] = psnr_valt.std()\n",
    "    results[f'{choice}']['r2_std'] = r2_valt.std()\n",
    "    results[f'{choice}']['sndr_std'] = sndr_valt.std()\n",
    "    results[f'{choice}']['nrmse_std'] = nrmse_valt.std()\n",
    "    results[f'{choice}']['psnr_mid_std'] = psnr_mid_valt.std()\n",
    "    results[f'{choice}']['r2_mid_std'] = r2_mid_valt.std()\n",
    "    results[f'{choice}']['sndr_mid_std'] = sndr_mid_valt.std()\n",
    "    results[f'{choice}']['nrmse_mid_std'] = nrmse_mid_valt.std()\n",
    "\n",
    "    results[f'{choice}']['psnr_best'] = psnr_valt.max()\n",
    "    results[f'{choice}']['r2_best'] = r2_valt.max()\n",
    "    results[f'{choice}']['sndr_best'] = sndr_valt.max()\n",
    "    results[f'{choice}']['nrmse_best'] = nrmse_valt.min()\n",
    "    results[f'{choice}']['psnr_mid_best'] = psnr_mid_valt.max()\n",
    "    results[f'{choice}']['r2_mid_best'] = r2_mid_valt.max()\n",
    "    results[f'{choice}']['sndr_mid_best'] = sndr_mid_valt.max()\n",
    "    results[f'{choice}']['nrmse_mid_best'] = nrmse_mid_valt.min()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "08a978ff-239e-4882-9538-6cc8ba8a74c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'0': {'psnr_mean': np.float32(15.466907), 'r2_mean': np.float32(0.038573503), 'sndr_mean': np.float32(0.17770877), 'nrmse_mean': np.float32(0.16992609), 'psnr_mid_mean': np.float32(12.730919), 'r2_mid_mean': np.float32(0.04120227), 'sndr_mid_mean': np.float32(0.26292998), 'nrmse_mid_mean': np.float32(0.23280175), 'psnr_std': np.float32(0.12835835), 'r2_std': np.float32(0.030826258), 'sndr_std': np.float32(0.1417366), 'nrmse_std': np.float32(0.0025056212), 'psnr_mid_std': np.float32(0.20032464), 'r2_mid_std': np.float32(0.04613106), 'sndr_mid_std': np.float32(0.21328454), 'nrmse_mid_std': np.float32(0.0053542615), 'psnr_best': np.float32(15.702856), 'r2_best': np.float32(0.09556119), 'sndr_best': np.float32(0.4409116), 'nrmse_best': np.float32(0.16532543), 'psnr_mid_best': np.float32(13.094128), 'r2_mid_best': np.float32(0.124914944), 'sndr_mid_best': np.float32(0.65322846), 'nrmse_mid_best': np.float32(0.2231158)}, '1': {'psnr_mean': np.float32(16.222858), 'r2_mean': np.float32(0.036845878), 'sndr_mean': np.float32(0.1666396), 'nrmse_mean': np.float32(0.15601882), 'psnr_mid_mean': np.float32(14.17909), 'r2_mid_mean': np.float32(0.052666962), 'sndr_mid_mean': np.float32(0.24539664), 'nrmse_mid_mean': np.float32(0.19665048), 'psnr_std': np.float32(0.071315974), 'r2_std': np.float32(0.016394123), 'sndr_std': np.float32(0.07491478), 'nrmse_std': np.float32(0.0012337604), 'psnr_mid_std': np.float32(0.13406417), 'r2_mid_std': np.float32(0.029475003), 'sndr_mid_std': np.float32(0.13816571), 'nrmse_mid_std': np.float32(0.0029326093), 'psnr_best': np.float32(16.319887), 'r2_best': np.float32(0.057537727), 'sndr_best': np.float32(0.2612024), 'nrmse_best': np.float32(0.15436344), 'psnr_mid_best': np.float32(14.344879), 'r2_mid_best': np.float32(0.09173339), 'sndr_mid_best': np.float32(0.42940378), 'nrmse_mid_best': np.float32(0.19305602)}}\n"
     ]
    }
   ],
   "source": [
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7d19bd3c-f4d0-4346-a508-e167cbb73f7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./compare_result/cs.txt\", \"w\") as f:\n",
    "    f.write(str(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab99e99-951d-44d1-aa87-0487eb5111c3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
