{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "35413445",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.8.0+cu128\n",
      "cpu\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import scipy.io\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import copy\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.nn import functional as F\n",
    "\n",
    "print(torch.__version__)\n",
    "\n",
    "# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)\n",
    "\n",
    "import os\n",
    "import random\n",
    "import dataset.aug_index as ai "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "98cc97a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightning.pytorch import Trainer, seed_everything # used for seed setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "868c91e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# x: [B, input_dim]  (input_dim = 128 in this notebook)\n",
    "# latent: [B, latent_dim]  (latent_dim = 4 in this notebook)\n",
    "class PCAModel(nn.Module):\n",
    "    \"\"\"Frozen PCA encoder/decoder built from already-fitted PCA arrays.\n",
    "\n",
    "    The principal axes (e.g. sklearn ``pca.components_``) and the training\n",
    "    mean (``pca.mean_``) are stored as float32 buffers: they follow\n",
    "    ``.to(device)`` and appear in ``state_dict()``, but are not trainable.\n",
    "\n",
    "    Args:\n",
    "        input_dim: feature size of ``x``; must equal ``pca_components.shape[1]``.\n",
    "        latent_dim: number of components; must equal ``pca_components.shape[0]``.\n",
    "        pca_components: array-like or tensor of shape ``[latent_dim, input_dim]``.\n",
    "        pca_mean: array-like or tensor of shape ``[input_dim]``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if an array is missing or its shape disagrees with\n",
    "            ``input_dim`` / ``latent_dim``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim=128, latent_dim=4, pca_components=None, pca_mean=None):\n",
    "        super().__init__()\n",
    "        # The None defaults only allow keyword-style calls; the model is\n",
    "        # meaningless without both arrays, so fail with a clear message.\n",
    "        if pca_components is None or pca_mean is None:\n",
    "            raise ValueError('PCAModel requires both pca_components and pca_mean')\n",
    "\n",
    "        # detach().clone() keeps the copy-and-detach semantics of torch.tensor()\n",
    "        # while also accepting tensors without the copy-construct warning.\n",
    "        components = torch.as_tensor(pca_components, dtype=torch.float32).detach().clone()\n",
    "        mean = torch.as_tensor(pca_mean, dtype=torch.float32).detach().clone()\n",
    "        if tuple(components.shape) != (latent_dim, input_dim):\n",
    "            raise ValueError(f'pca_components has shape {tuple(components.shape)}, expected {(latent_dim, input_dim)}')\n",
    "        if tuple(mean.shape) != (input_dim,):\n",
    "            raise ValueError(f'pca_mean has shape {tuple(mean.shape)}, expected {(input_dim,)}')\n",
    "\n",
    "        self.input_dim = input_dim\n",
    "        self.latent_dim = latent_dim\n",
    "        self.register_buffer('components', components)\n",
    "        self.register_buffer('mean', mean)\n",
    "\n",
    "    def encode(self, x):\n",
    "        \"\"\"Project ``x`` ([..., input_dim]) onto the principal axes -> [..., latent_dim].\"\"\"\n",
    "        return torch.matmul(x - self.mean, self.components.T)\n",
    "\n",
    "    def decode(self, latent):\n",
    "        \"\"\"Map ``latent`` ([..., latent_dim]) back to input space -> [..., input_dim].\"\"\"\n",
    "        return torch.matmul(latent, self.components) + self.mean\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``(latent, reconstruction)`` for ``x``.\"\"\"\n",
    "        latent = self.encode(x)\n",
    "        recon = self.decode(latent)\n",
    "        return latent, recon\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d5ae9a60",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20170622\n",
      "20170623\n",
      "20170629\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 2\n",
      "Seed set to 3\n",
      "Seed set to 4\n",
      "Seed set to 5\n",
      "Seed set to 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d533101\n",
      "d561106\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 2\n",
      "Seed set to 3\n",
      "Seed set to 4\n",
      "Seed set to 5\n"
     ]
    }
   ],
   "source": [
    "# training part\n",
    "# For each file group (`choice`) and each seed: pool the group's recordings,\n",
    "# split 80/20, fit a PCA on the training split, reconstruct the test split,\n",
    "# save the model and test-set arrays, and average per-waveform metrics.\n",
    "# The per-seed averages are then summarised across seeds as mean / std / best.\n",
    "\n",
    "cr = 32\n",
    "seed_list = [1, 2, 3, 4, 5]\n",
    "file_list = [[\"20170622\", \"20170623\", \"20170629\"], [\"d533101\", \"d561106\"]]\n",
    "\n",
    "bs = 64\n",
    "n_latent = 4\n",
    "mid_window = slice(32, 96)  # central samples used for the *_mid metrics\n",
    "# metric name -> (function from dataset.aug_index, True if a larger value is better)\n",
    "metric_fns = {\n",
    "    'psnr': (ai.psnr, True),\n",
    "    'r2': (ai.r2_score, True),\n",
    "    'sndr': (ai.sndr, True),\n",
    "    'nrmse': (ai.nrmse, False),\n",
    "}\n",
    "regions = (('', slice(None)), ('_mid', mid_window))  # (key suffix, sample window)\n",
    "\n",
    "\n",
    "def _load_group(files):\n",
    "    \"\"\"Load and stack ./dataset/<name>_wave.npy and <name>_neo.npy for each name in `files`.\n",
    "\n",
    "    Returns (X, y, CL); CL[k] is the position in `files` of the recording row k came from.\n",
    "    \"\"\"\n",
    "    X_list, y_list, cl_list = [], [], []\n",
    "    for file_idx, name in enumerate(files):\n",
    "        print(name)\n",
    "        tX = np.load('./dataset/' + name + '_wave.npy')\n",
    "        ty = np.load('./dataset/' + name + '_neo.npy')\n",
    "        X_list.append(tX)\n",
    "        y_list.append(ty)\n",
    "        cl_list.append(np.full(len(tX), file_idx))\n",
    "    return (np.concatenate(X_list, axis=0),\n",
    "            np.concatenate(y_list, axis=0),\n",
    "            np.concatenate(cl_list, axis=0))\n",
    "\n",
    "\n",
    "def _reconstruct(model, loader, device):\n",
    "    \"\"\"Run `model` over `loader`; return (inputs, reconstructions, latents) as stacked numpy arrays.\"\"\"\n",
    "    inputs_all, recons_all, latents_all = [], [], []\n",
    "    with torch.no_grad():\n",
    "        for inputs, _ in loader:\n",
    "            inputs = inputs.to(device)\n",
    "            latent, recon = model(inputs)\n",
    "            inputs_all.append(inputs.cpu().numpy())\n",
    "            recons_all.append(recon.cpu().numpy())\n",
    "            latents_all.append(latent.cpu().numpy())\n",
    "    return np.concatenate(inputs_all), np.concatenate(recons_all), np.concatenate(latents_all)\n",
    "\n",
    "\n",
    "def _score_test_set(recons, targets):\n",
    "    \"\"\"Mean of each metric over all test waveforms, for the full trace and the mid window.\n",
    "\n",
    "    Keys are '<metric>' and '<metric>_mid' (e.g. 'psnr', 'psnr_mid').\n",
    "    \"\"\"\n",
    "    per_sample = {name + suffix: [] for suffix, _ in regions for name in metric_fns}\n",
    "    for recon, target in zip(recons, targets):\n",
    "        for suffix, window in regions:\n",
    "            for name, (fn, _) in metric_fns.items():\n",
    "                per_sample[name + suffix].append(fn(recon[window], target[window]))\n",
    "    return {key: np.array(vals).mean() for key, vals in per_sample.items()}\n",
    "\n",
    "\n",
    "results = {}\n",
    "\n",
    "for choice in range(len(file_list)):\n",
    "    X, y, CL = _load_group(file_list[choice])\n",
    "    per_seed = {name + suffix: [] for suffix, _ in regions for name in metric_fns}\n",
    "\n",
    "    for se, seed in enumerate(seed_list):\n",
    "        seed_everything(seed, workers=True)\n",
    "        # NOTE: the .npy outputs are keyed by the seed *index* `se`, while the .pt\n",
    "        # file below is keyed by the seed *value*; kept as-is for existing consumers.\n",
    "        prefix = str(choice) + str(se)\n",
    "\n",
    "        # random_state=seed draws the same permutation as the global NumPy RNG\n",
    "        # that seed_everything has just reseeded, but makes the dependency explicit.\n",
    "        X_train, X_test, y_train, y_test, _, cl_test = train_test_split(\n",
    "            X, y, CL, test_size=0.2, random_state=seed)\n",
    "        np.save(prefix + 'pca_cl.npy', np.array(cl_test))\n",
    "\n",
    "        test_dataset = TensorDataset(torch.tensor(X_test, dtype=torch.float32),\n",
    "                                     torch.tensor(y_test, dtype=torch.float32))\n",
    "        test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)\n",
    "\n",
    "        pca = PCA(n_components=n_latent)\n",
    "        pca.fit(X_train)\n",
    "\n",
    "        model = PCAModel(input_dim=X.shape[1], latent_dim=n_latent,\n",
    "                         pca_components=pca.components_, pca_mean=pca.mean_)\n",
    "        model.to(device)\n",
    "        model.eval()\n",
    "        torch.save(copy.deepcopy(model.state_dict()), 'PCA_' + str(choice) + '_' + str(seed) + '.pt')\n",
    "\n",
    "        targets, recons, latents = _reconstruct(model, test_loader, device)\n",
    "        np.save(prefix + 'pca_target.npy', targets)\n",
    "        np.save(prefix + 'pca_recon.npy', recons)\n",
    "        np.save(prefix + 'pca_latent.npy', latents)\n",
    "\n",
    "        for key, value in _score_test_set(recons, targets).items():\n",
    "            per_seed[key].append(value)\n",
    "\n",
    "    # Summarise the per-seed means across seeds (np.std is the population std, ddof=0).\n",
    "    # Key order: all *_mean, then *_std, then *_best; full-trace metrics before *_mid.\n",
    "    summary = {}\n",
    "    for stat in ('mean', 'std', 'best'):\n",
    "        for suffix, _ in regions:\n",
    "            for name, (_, higher_is_better) in metric_fns.items():\n",
    "                vals = np.array(per_seed[name + suffix])\n",
    "                if stat == 'mean':\n",
    "                    value = vals.mean()\n",
    "                elif stat == 'std':\n",
    "                    value = vals.std()\n",
    "                else:\n",
    "                    value = vals.max() if higher_is_better else vals.min()\n",
    "                summary[f'{name}{suffix}_{stat}'] = value\n",
    "    results[f'{choice}'] = summary\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "08a978ff-239e-4882-9538-6cc8ba8a74c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'0': {'psnr_mean': np.float32(29.802902), 'r2_mean': np.float32(0.94847584), 'sndr_mean': np.float32(14.513702), 'nrmse_mean': np.float32(0.035302617), 'psnr_mid_mean': np.float32(30.217524), 'r2_mid_mean': np.float32(0.97298366), 'sndr_mid_mean': np.float32(17.749537), 'nrmse_mid_mean': np.float32(0.034004282), 'psnr_std': np.float32(0.06568025), 'r2_std': np.float32(0.0012779932), 'sndr_std': np.float32(0.06492793), 'nrmse_std': np.float32(0.00032022537), 'psnr_mid_std': np.float32(0.05622261), 'r2_mid_std': np.float32(0.0015729076), 'sndr_mid_std': np.float32(0.05727865), 'nrmse_mid_std': np.float32(0.00037430468), 'psnr_best': np.float32(29.866858), 'r2_best': np.float32(0.9496814), 'sndr_best': np.float32(14.573268), 'nrmse_best': np.float32(0.034977097), 'psnr_mid_best': np.float32(30.296484), 'r2_mid_best': np.float32(0.97441673), 'sndr_mid_best': np.float32(17.826246), 'nrmse_mid_best': np.float32(0.033484384)}, '1': {'psnr_mean': np.float32(21.032955), 'r2_mean': np.float32(0.646902), 'sndr_mean': np.float32(4.976735), 'nrmse_mean': np.float32(0.09333463), 'psnr_mid_mean': np.float32(21.386082), 'r2_mid_mean': np.float32(0.78477824), 'sndr_mid_mean': np.float32(7.4523897), 'nrmse_mid_mean': np.float32(0.09021516), 'psnr_std': np.float32(0.014687313), 'r2_std': np.float32(0.0012713986), 'sndr_std': np.float32(0.010854884), 'nrmse_std': np.float32(0.00017573946), 'psnr_mid_std': np.float32(0.059915427), 'r2_mid_std': np.float32(0.002989055), 'sndr_mid_std': np.float32(0.04169822), 'nrmse_mid_std': np.float32(0.00072629505), 'psnr_best': np.float32(21.061518), 'r2_best': np.float32(0.6486607), 'sndr_best': np.float32(4.987965), 'nrmse_best': np.float32(0.093049094), 'psnr_mid_best': np.float32(21.493176), 'r2_mid_best': np.float32(0.7897901), 'sndr_mid_best': np.float32(7.525787), 'nrmse_mid_best': np.float32(0.08892611)}}\n"
     ]
    }
   ],
   "source": [
    "# Per-choice summary of the reconstruction metrics across seeds.\n",
    "print(str(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7d19bd3c-f4d0-4346-a508-e167cbb73f7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output folder first: open(..., 'w') fails if it does not exist.\n",
    "# NOTE: str(results) is a Python repr (values print as np.float32(...)), not JSON.\n",
    "os.makedirs('./compare_result', exist_ok=True)\n",
    "with open(\"./compare_result/pca.txt\", \"w\") as f:\n",
    "    f.write(str(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fa214559",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32\n",
      "torch.float32\n",
      "torch.Size([4, 128])\n",
      "torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "# Sanity check on `model` left over from the last iteration of the training loop:\n",
    "# print both buffer dtypes, then both buffer shapes.\n",
    "pca_buffers = (model.components, model.mean)\n",
    "for buf in pca_buffers:\n",
    "    print(buf.dtype)\n",
    "\n",
    "for buf in pca_buffers:\n",
    "    print(buf.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab99e99-951d-44d1-aa87-0487eb5111c3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
