{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The line_profiler extension is already loaded. To reload it, use:\n",
      "  %reload_ext line_profiler\n",
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n",
      "{'name': 'pubmed', 'K': 4, 'dict_size': 500, 'weighted': True, 'epoch': 3, 'batch_size': 512, 'eval_step': 20, 'eval_batch_size': -1, 'a_method': 'max_a_nonzero', 'lam': 0.003, 'n_a_nonzero': 10, 'shuffle': True, 'O_Q_ST_accurate': True, 'O_loop_cnt': 1, 'O_init_method': 'random_select', 'O_resample_method': 'greedy', 'O_resample_step': 5, 'O_resample_warmup': 5, 'O_thresh': 0.05, 'feature_normalize': 'row_sum', 'O_sparsify_conf': {'max_nonzero': 20, 'signal_ratio': 0.5}}\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Dictionary learning for GCN\n",
    "\n",
    "Created by: Yaochen Hu\n",
    "Created on: Aug. 17, 2022\n",
    "\n",
    "\"\"\"\n",
    "# pylint: disable=anomalous-backslash-in-string\n",
    "# pylint: disable=invalid-name\n",
    "# pylint: disable=import-error\n",
    "# pylint: disable=missing-function-docstring\n",
    "import os\n",
    "import sys\n",
    "sys.path.extend([\"../\"]) # pylint: disable=wrong-import-position\n",
    "import random\n",
    "from time import time\n",
    "import warnings\n",
    "import pickle\n",
    "import datetime\n",
    "import socket\n",
    "\n",
    "import shutil\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from sklearn import linear_model\n",
    "\n",
    "from data_utils import *\n",
    "from graph_dict import *\n",
    "from utils import *\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "os.environ[\"CURL_CA_BUNDLE\"] = \"\"\n",
    "# Compute device handed to the feature generator and the model; adjust to the local GPU layout.\n",
    "DEVICE = 'cuda:2'\n",
    "\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# %lprun -f\n",
    "\n",
    "# Parameter settings and data loading\n",
    "DATA_ROOT_FOLDER = \"../dataset\"\n",
    "CONF_ROOT_FOLDER = \"../config\"\n",
    "RES_ROOT_FOLDER = \"../result\"\n",
    "\n",
    "# Seed the Python, NumPy and PyTorch RNGs so stochastic steps (e.g. shuffle,\n",
    "# random_select initialization) are repeatable; assumes graph_dict draws from\n",
    "# these global generators -- TODO confirm.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "#############################################\n",
    "# Dataset configuration: keep exactly one CONF_NAME uncommented.\n",
    "# # cora\n",
    "# CONF_NAME = \"cora/dict/cora_gen_dict_base.yml\"\n",
    "# # pubmed\n",
    "CONF_NAME = \"pubmed/dict/pubmed_gen_dict_base.yml\"\n",
    "# # reddit\n",
    "# CONF_NAME = \"reddit/dicttest/reddit_gen_dict_size_500.yml\"\n",
    "\n",
    "# # OGBN-products\n",
    "# CONF_NAME = \"\"\n",
    "\n",
    "#######################################\n",
    "# Initialization and hyper-parameter setting\n",
    "_TIME_ZONE = 0\n",
    "TIMESTAMP = time()\n",
    "TIMESTAMP_FORMATTED = datetime.datetime.fromtimestamp(\n",
    "    int(TIMESTAMP)+_TIME_ZONE*3600).strftime('%Y%m%d-%H%M%S')\n",
    "HOST_NAME = socket.gethostname()\n",
    "\n",
    "conf_path = os.path.join(CONF_ROOT_FOLDER, CONF_NAME)\n",
    "train_conf = load_train_conf(conf_path)\n",
    "\n",
    "DATA_FOLDER = os.path.join(DATA_ROOT_FOLDER, train_conf[\"name\"])\n",
    "# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs\n",
    "os.makedirs(DATA_FOLDER, exist_ok=True)\n",
    "\n",
    "# sys.stdout = Logger(os.path.join(RES_FOLDER, \"log.txt\"))\n",
    "\n",
    "print(train_conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "Loading features from ../dataset/pubmed/k4_row_sum.pkl\n"
     ]
    }
   ],
   "source": [
    "# Data loading\n",
    "dataset = load_data(train_conf[\"name\"])\n",
    "\n",
    "# Feature preprocessing; results are cached under DATA_FOLDER (use_cache=True)\n",
    "feature_gen = SGCFeatureGen(\n",
    "    dataset,\n",
    "    k=train_conf[\"K\"],\n",
    "    device=\"cpu\",\n",
    "    cache_path=DATA_FOLDER,\n",
    "    use_cache=True,\n",
    "    return_torch=False,\n",
    "    need_k_adj=True,\n",
    "    compute_device=DEVICE,\n",
    "    feature_normalize=train_conf[\"feature_normalize\"],\n",
    ")\n",
    "\n",
    "# Arrays consumed by the dictionary learner and the plots below\n",
    "A = feature_gen.k_adj_cpu\n",
    "features = feature_gen.normed_features\n",
    "AF = feature_gen.prop_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initializing the model...\n",
      "Weighted version.\n",
      "starting to fit...\n",
      "Initializing...\n",
      "---- Evaluating...\n",
      "1/1 finished in 2.9s, ETA 0.0\n",
      "Training started...\n",
      "Weighted regret: 115.163239, | Rel weighted regret: 1.000000 | Regret: 576.237122 | Rel Regret: 1.000000\n",
      "  a row nonzeros: 0.0000 ± 0.0000 | O row nonzeros: mean 19717.0000 ± 0.0000\n",
      "---------------------------\n",
      "\n",
      "*** Time a 1.8 | Time O 0.5 | Time O update 0.3\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 2.0 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.7 | Time O 0.2 | Time O update 0.1\n",
      "*** Time a 1.2 | Time O 0.2 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.2 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.2 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.2 | Time O 0.2 | Time O update 0.1\n",
      "*** Time a 1.2 | Time O 0.2 | Time O update 0.1\n",
      "*** Time a 1.2 | Time O 0.2 | Time O update 0.1\n",
      "Resampling O...\n",
      "---- Evaluating...\n",
      "1/1 finished in 2.3s, ETA 0.0\n",
      "Epoch: 0 | Iter: 19 | Elapsed time: 43.5 | ETA: 211.0\n",
      "Weighted regret: 70.099342, | Rel weighted regret: 0.637629 | Regret: 659.053528 | Rel Regret: 2.110193\n",
      "  a row nonzeros: 5.0965 ± 4.9414 | O row nonzeros: mean 11783.7880 ± 8988.9395\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.9 | Time O 0.5 | Time O update 0.2\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.0 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.5 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "---- Evaluating...\n",
      "1/1 finished in 2.2s, ETA 0.0\n",
      "Epoch: 1 | Iter: 0 | Elapsed time: 90.5 | ETA: 174.2\n",
      "Weighted regret: 30.433746, | Rel weighted regret: 0.365800 | Regret: 774.398315 | Rel Regret: 3.969830\n",
      "  a row nonzeros: 9.8212 ± 0.9668 | O row nonzeros: mean 14500.2080 ± 8185.5887\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "---- Evaluating...\n",
      "1/1 finished in 2.3s, ETA 0.0\n",
      "Epoch: 1 | Iter: 20 | Elapsed time: 138.8 | ETA: 131.9\n",
      "Weighted regret: 29.644436, | Rel weighted regret: 0.362059 | Regret: 773.559448 | Rel Regret: 3.969467\n",
      "  a row nonzeros: 9.8448 ± 0.9003 | O row nonzeros: mean 14722.4260 ± 8093.4518\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.4 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 0.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "---- Evaluating...\n",
      "1/1 finished in 2.3s, ETA 0.0\n",
      "Epoch: 2 | Iter: 1 | Elapsed time: 186.5 | ETA: 86.2\n",
      "Weighted regret: 29.215723, | Rel weighted regret: 0.356221 | Regret: 785.690735 | Rel Regret: 3.982300\n",
      "  a row nonzeros: 9.8652 ± 0.8556 | O row nonzeros: mean 15237.0120 ± 7847.3514\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.7 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.2 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.4 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "---- Evaluating...\n",
      "1/1 finished in 2.3s, ETA 0.0\n",
      "Epoch: 2 | Iter: 21 | Elapsed time: 234.8 | ETA: 39.9\n",
      "Weighted regret: 27.980646, | Rel weighted regret: 0.352006 | Regret: 759.054443 | Rel Regret: 3.973679\n",
      "  a row nonzeros: 9.8698 ± 0.8520 | O row nonzeros: mean 15059.3500 ± 7980.3628\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 1.8 | Time O 0.3 | Time O update 0.1\n",
      "Resampling O...\n",
      "*** Time a 1.9 | Time O 0.3 | Time O update 0.1\n",
      "*** Time a 0.9 | Time O 0.3 | Time O update 0.1\n",
      "---------------------------\n",
      "\n",
      "Started post processing...\n",
      "Recompute a...\n",
      "0.974 finished in 70.9 s. ETA: 0.0 ss\n",
      "---- Evaluating...\n",
      "1/1 finished in 2.2s, ETA 0.0\n",
      "Training finished in 352.0 s.\n",
      "Weighted regret: 17.037174, | Rel weighted regret: 0.149532 | Regret: 484.826416 | Rel Regret: 0.850721\n",
      "  a row nonzeros: 9.9079 ± 0.4267 | O row nonzeros: mean 10528.2920 ± 9465.8023\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<graph_dict.SparseDictionaryLearning at 0x7f3670733c90>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the dictionary learner from the loaded config and train it.\n",
    "# O_thresh is passed here as in the reload cell below, so training uses the\n",
    "# same value that is exported to conf.yml.\n",
    "print(\"Initializing the model...\")\n",
    "test = SparseDictionaryLearning(A,\n",
    "                                train_conf[\"dict_size\"],\n",
    "                                AF=AF,\n",
    "                                F=features,\n",
    "                                weighted=train_conf[\"weighted\"],\n",
    "                                epoch=train_conf[\"epoch\"],\n",
    "                                batch_size=train_conf[\"batch_size\"],\n",
    "                                eval_step=train_conf[\"eval_step\"],\n",
    "                                eval_batch_size=train_conf[\"eval_batch_size\"],\n",
    "                                a_method=train_conf[\"a_method\"],\n",
    "                                lam=train_conf[\"lam\"],\n",
    "                                n_a_nonzero=train_conf[\"n_a_nonzero\"],\n",
    "                                shuffle=train_conf[\"shuffle\"],\n",
    "                                num_worker=1,\n",
    "                                O_Q_ST_accurate=train_conf[\"O_Q_ST_accurate\"],\n",
    "                                O_loop_cnt=train_conf[\"O_loop_cnt\"],\n",
    "                                O_init_method=train_conf[\"O_init_method\"],\n",
    "                                O_resample_method=train_conf[\"O_resample_method\"],\n",
    "                                O_resample_warmup=train_conf[\"O_resample_warmup\"],\n",
    "                                O_resample_step=train_conf[\"O_resample_step\"],\n",
    "                                O_thresh=train_conf[\"O_thresh\"],\n",
    "                                O_is_sparsify=False,\n",
    "                                device=DEVICE,\n",
    "                                verbose=True)\n",
    "print(\"starting to fit...\")\n",
    "test.fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results are saved in ../result/test.\n"
     ]
    }
   ],
   "source": [
    "# Save the trained model and the config it was trained with.\n",
    "RES_FOLDER = os.path.join(RES_ROOT_FOLDER, \"test\")\n",
    "test.save(RES_FOLDER)\n",
    "export_train_conf(os.path.join(RES_FOLDER, \"conf.yml\"), train_conf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test the performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the model with the training hyper-parameters, load the saved state,\n",
    "# post-process it with O_is_sparsify=True, and save the result as a separate run.\n",
    "LOAD_FOLDER = os.path.join(RES_ROOT_FOLDER, \"test\")\n",
    "SAVE_FOLDER = os.path.join(RES_ROOT_FOLDER, \"test_1\")\n",
    "\n",
    "test = SparseDictionaryLearning(A,\n",
    "                                train_conf[\"dict_size\"],\n",
    "                                AF=AF,\n",
    "                                F=features,\n",
    "                                weighted=train_conf[\"weighted\"],\n",
    "                                epoch=train_conf[\"epoch\"],\n",
    "                                batch_size=train_conf[\"batch_size\"],\n",
    "                                eval_step=train_conf[\"eval_step\"],\n",
    "                                eval_batch_size=train_conf[\"eval_batch_size\"],\n",
    "                                a_method=train_conf[\"a_method\"],\n",
    "                                lam=train_conf[\"lam\"],\n",
    "                                n_a_nonzero=train_conf[\"n_a_nonzero\"],\n",
    "                                shuffle=train_conf[\"shuffle\"],\n",
    "                                num_worker=1,\n",
    "                                O_Q_ST_accurate=train_conf[\"O_Q_ST_accurate\"],\n",
    "                                O_loop_cnt=train_conf[\"O_loop_cnt\"],\n",
    "                                O_init_method=train_conf[\"O_init_method\"],\n",
    "                                O_resample_method=train_conf[\"O_resample_method\"],\n",
    "                                O_resample_warmup=train_conf[\"O_resample_warmup\"],\n",
    "                                O_resample_step=train_conf[\"O_resample_step\"],\n",
    "                                O_thresh=train_conf[\"O_thresh\"],\n",
    "                                O_is_sparsify=False,\n",
    "                                device=DEVICE,\n",
    "                                verbose=True)\n",
    "test.load(LOAD_FOLDER)\n",
    "test.post_update(O_is_sparsify=True, O_sparsify_conf=train_conf[\"O_sparsify_conf\"])\n",
    "test.eval_and_log()\n",
    "test._display_stat()\n",
    "test.save(SAVE_FOLDER)\n",
    "export_train_conf(os.path.join(SAVE_FOLDER, \"conf.yml\"), train_conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the \"signal_list\" entry recorded in the model log\n",
    "signal_list = test.log[\"signal_list\"]\n",
    "signal_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize the results of one test run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Sorted per-sample regret and sorted per-atom sum of |a|\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(sorted(test.per_sample_regret))\n",
    "ax.set(title='per_sample_regret', xlabel='sample (sorted by regret)', ylabel='regret')\n",
    "plt.show()\n",
    "\n",
    "# axis=0 sums over the rows (nodes) of a, giving one value per dictionary atom\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(sorted(np.sum(np.abs(test.a), axis=0)))\n",
    "ax.set(title='dict signal abs sum', xlabel='dictionary atom (sorted)', ylabel='sum of |a|')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nonzeros per row of a, per-atom L2 norm of a, and per-sample regret.\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.hist(np.sum(test.a != 0, axis=1))\n",
    "ax.set(title='a nonzeros hist', xlabel='nonzeros per row of a', ylabel='count')\n",
    "plt.show()\n",
    "\n",
    "dict_signal_norm = np.linalg.norm(test.a, axis=0)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(dict_signal_norm)\n",
    "ax.set(title='dict wise signal norm', xlabel='dictionary atom', ylabel='L2 norm of column of a')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.hist(dict_signal_norm)\n",
    "ax.set(title='dict wise signal norm hist', xlabel='L2 norm of column of a', ylabel='count')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test.per_sample_regret)\n",
    "ax.set(title='per sample regret', xlabel='sample index', ylabel='regret')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.hist(test.per_sample_regret)\n",
    "ax.set(title='per sample regret hist', xlabel='regret', ylabel='count')\n",
    "plt.show()\n",
    "\n",
    "print(f\"max regret pos {np.argmax(test.per_sample_regret)}\")\n",
    "print(f\"min regret pos {np.argmin(test.per_sample_regret)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-atom sum of |a| in atom order, and its distribution (computed once).\n",
    "dict_abs_sum = np.sum(np.abs(test.a), axis=0)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(dict_abs_sum)\n",
    "ax.set(title='dict signal abs sum', xlabel='dictionary atom', ylabel='sum of |a|')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.hist(dict_abs_sum)\n",
    "ax.set(title='dict signal abs sum hist', xlabel='sum of |a|', ylabel='count')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of dictionary (O) entries. VALUE_THRESH is the magnitude cut-off\n",
    "# used below; 0.05 equals O_thresh in the printed config.\n",
    "VALUE_THRESH = 0.05\n",
    "O_flat = test.O.flatten()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.hist(O_flat[np.abs(O_flat) < VALUE_THRESH])\n",
    "ax.set(title=f'dict value less than {VALUE_THRESH} hist', xlabel='entry value', ylabel='count')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.hist(O_flat)\n",
    "ax.set(title='dict all hist', xlabel='entry value', ylabel='count')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.hist(np.sum(np.abs(test.O) > VALUE_THRESH, axis=1))\n",
    "ax.set(title=f'dict larger than {VALUE_THRESH} hist',\n",
    "       xlabel=f'entries with |value| > {VALUE_THRESH} per row of O', ylabel='count')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruction check for one node i: compare row i of A (target) with\n",
    "# a[i, :] @ O (approximation), raw and multiplied by the feature matrix.\n",
    "# Only row i is materialized: A.toarray() and a @ O for all nodes would each be\n",
    "# num_nodes x num_nodes dense arrays (19717**2 ~ 3.9e8 entries for pubmed).\n",
    "i = 772\n",
    "\n",
    "a_i = test.a[i, :]\n",
    "print(np.sum(a_i != 0))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(a_i)\n",
    "ax.set_title('a')\n",
    "plt.show()\n",
    "\n",
    "y_h = 0.6\n",
    "# Assumes A is a scipy sparse matrix (the original code called A.toarray());\n",
    "# tocsr() makes integer row indexing work whatever sparse format it is stored in.\n",
    "T_i = A.tocsr()[i].toarray().ravel()\n",
    "S_i = a_i.dot(test.O)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.set_ylim([0, y_h])\n",
    "ax.plot(T_i)\n",
    "ax.set_title('Target signal')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.set_ylim([0, y_h])\n",
    "ax.plot(S_i)\n",
    "ax.set_title('Approximated signal')\n",
    "plt.show()\n",
    "\n",
    "# The error can be negative, so the y-range is symmetric instead of [0, y_h].\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.set_ylim([-y_h, y_h])\n",
    "ax.plot(S_i - T_i)\n",
    "ax.set_title('Diff')\n",
    "plt.show()\n",
    "\n",
    "y_h1 = 0.07\n",
    "# Row i of (T @ features) equals T[i, :] @ features, so no full product is needed.\n",
    "WT_i = T_i @ features\n",
    "WS_i = S_i @ features\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.set_ylim([0, y_h1])\n",
    "ax.plot(WT_i)\n",
    "ax.set_title('weighted target signal')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.set_ylim([0, y_h1])\n",
    "ax.plot(WS_i)\n",
    "ax.set_title('weighted approximated signal')\n",
    "plt.show()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.set_ylim([-y_h1, y_h1])\n",
    "ax.plot(WS_i - WT_i)\n",
    "ax.set_title('weighted Diff')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualization for the data distribution\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Other analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_size_of_mem(num_nodes, feature_size, dict_rate, data_size=4):\n",
    "    \"\"\"Estimate the memory, in bytes, of the buffers used in training.\n",
    "\n",
    "    Every entry is counted as stored: A has num_nodes**2 values, F has\n",
    "    num_nodes * feature_size, O, a, a_old and ST each have\n",
    "    dict_size * num_nodes, and Q has dict_size**2, where\n",
    "    dict_size = int(num_nodes * dict_rate).\n",
    "\n",
    "    Args:\n",
    "        num_nodes: number of graph nodes.\n",
    "        feature_size: node feature dimension.\n",
    "        dict_rate: dictionary size as a fraction of num_nodes.\n",
    "        data_size: bytes per stored value (default 4, e.g. float32).\n",
    "\n",
    "    Returns:\n",
    "        Estimated total size in bytes.\n",
    "    \"\"\"\n",
    "    dict_size = int(num_nodes * dict_rate)\n",
    "    size = {\n",
    "        \"A\": num_nodes ** 2,\n",
    "        \"F\": num_nodes * feature_size,\n",
    "        \"O\": dict_size * num_nodes,\n",
    "        \"a\": num_nodes * dict_size,\n",
    "        \"a_old\": num_nodes * dict_size,\n",
    "        \"Q\": dict_size ** 2,\n",
    "        \"ST\": dict_size * num_nodes,\n",
    "    }\n",
    "    return sum(size.values()) * data_size\n",
    "\n",
    "# Example: 230k nodes, 2000 features, dictionary size = 1% of the nodes; printed in GiB\n",
    "res = get_size_of_mem(230000, 2000, 0.01)\n",
    "print(res / 1024 ** 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "toc-autonumbering": false,
  "toc-showmarkdowntxt": true
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
