{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# HuiduRep evaluation\n",
    "\n",
    "Evaluates the pretrained HuiduRep model (`CMAES`, checkpoint `./resources/checkpoint/HuiduRep.pt`):\n",
    "\n",
    "- parameter count of the encoder path;\n",
    "- clustering of a randomly sampled 10-unit test set with `gmm_monitor` (ARI, silhouette score);\n",
    "- accuracy / macro precision / macro recall after Hungarian matching of cluster ids to ground-truth labels;\n",
    "- mean score and runtime over 100 test sets sampled with different seeds.\n",
    "\n",
    "The train/test datasets (> 5 GB) are not distributed and must be supplied in the data cell before running."
   ],
   "id": "132b5c48bc506bea"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Automatically re-import edited project modules (e.g. utils, model) before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ],
   "id": "7d863c223cc1fbe1",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import random\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from scipy import linalg\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "from sklearn.metrics import (accuracy_score, confusion_matrix, precision_score,\n",
    "                             recall_score, silhouette_score)\n",
    "from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
    "\n",
    "from model.CMAES import CMAES\n",
    "from utils.CMAES_utils import load_model\n",
    "# TODO: import names explicitly; gmm_monitor (used below) is only provided by this star import.\n",
    "from utils.monitor_utils import *\n",
    "from utils.test_utils import generate_testset, visualize_tsne, visualize_umap\n",
    "\n",
    "# Bind the callable progress bar (not the `tqdm` module) after the star import so it cannot be shadowed.\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ],
   "id": "initial_id",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# These datasets are larger than 5 GB; storage limits prevent us from distributing them.\n",
    "# Assign your own arrays below before running the rest of the notebook.\n",
    "train_labels = None\n",
    "test_labels = None\n",
    "train_dataset = None\n",
    "test_dataset = None\n",
    "\n",
    "# Evaluation only needs the test split; fail early with an actionable message\n",
    "# instead of the opaque error torch.tensor(None) would raise.\n",
    "missing = [name for name, value in [('test_dataset', test_dataset), ('test_labels', test_labels)]\n",
    "           if value is None]\n",
    "if missing:\n",
    "    raise ValueError(f\"Datasets are not bundled with this repository; assign {', '.join(missing)} in this cell first.\")\n",
    "\n",
    "# Convert to float32 tensors and swap the last two axes; presumably from\n",
    "# (spikes, time, channels) to the (spikes, channels, time) layout CMAES expects -- TODO confirm.\n",
    "test_dataset = torch.tensor(test_dataset, dtype=torch.float).permute(0, 2, 1)\n",
    "print(test_dataset.shape)\n",
    "\n",
    "# The training split is optional here: it is only converted and inspected.\n",
    "if train_dataset is not None:\n",
    "    train_dataset = torch.tensor(train_dataset, dtype=torch.float).permute(0, 2, 1)\n",
    "    print(train_dataset.shape)\n",
    "\n",
    "print(np.unique(test_labels).shape)\n",
    "print(np.unique(test_labels))\n",
    "print(len(test_labels))"
   ],
   "id": "594a82269b329a0a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Build the CMAES architecture and load the pretrained HuiduRep weights.\n",
    "# The hyper-parameters below must match those of the saved checkpoint.\n",
    "CHECKPOINT_PATH = './resources/checkpoint/HuiduRep.pt'\n",
    "\n",
    "huidu = CMAES(T=0.2,\n",
    "              mask_ratio=0.0,\n",
    "              use_embedding=True,\n",
    "              n_heads=4,\n",
    "              m=0.99,\n",
    "              use_avg_pool=True,\n",
    "              K=4096,\n",
    "              embedding_dim=64,\n",
    "              ff_dim=128,\n",
    "              num_layers=4,\n",
    "              dropout=0.1,\n",
    "              moco_v3=True).to('cuda')\n",
    "\n",
    "# This notebook only runs inference: eval() disables the dropout configured above\n",
    "# so denoising/encoding are not perturbed by random dropout masks.\n",
    "huidu = load_model(huidu, CHECKPOINT_PATH).to('cuda').eval()"
   ],
   "id": "d8e5b024b36432f2",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def count_named_parameters(model, name_filter, verbose=False):\n",
    "    \"\"\"Count the scalar parameters of `model` whose names match `name_filter`.\n",
    "\n",
    "    Args:\n",
    "        model: torch.nn.Module to inspect.\n",
    "        name_filter: a substring, or an iterable of substrings; a parameter is\n",
    "            counted when its name contains any of them.\n",
    "        verbose: if True, print the name of every counted parameter.\n",
    "\n",
    "    Returns:\n",
    "        Total number of elements (numel) of the matching parameters.\n",
    "    \"\"\"\n",
    "    patterns = (name_filter,) if isinstance(name_filter, str) else tuple(name_filter)\n",
    "    total = 0\n",
    "    for name, param in model.named_parameters():\n",
    "        if any(pattern in name for pattern in patterns):\n",
    "            if verbose:\n",
    "                print(name)\n",
    "            total += param.numel()\n",
    "    return total\n",
    "\n",
    "\n",
    "# Sub-modules counted as the encoder path in the report below.\n",
    "ENCODER_PARTS = ('conv_embedding', 'encoder_q', 'reduce_q')\n",
    "print(\"Encoder params:\", count_named_parameters(huidu, ENCODER_PARTS))"
   ],
   "id": "cb57949965bb628b",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Randomly draw NUM_UNITS units (and their spikes) from the test set;\n",
    "# the draw depends on the RNG state seeded in the imports cell.\n",
    "NUM_UNITS = 10\n",
    "test_data, test_units, labels = generate_testset(test_dataset, test_labels, num_units=NUM_UNITS)"
   ],
   "id": "c18048607c5dd161",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Denoise the sampled spikes, then evaluate clustering with gmm_monitor (ARI scores).\n",
    "# gmm_monitor presumably returns (scores per run, predicted labels per run, clustered features),\n",
    "# matching how the cells below consume gmm_test and test_spikes -- TODO confirm in utils.monitor_utils.\n",
    "test_units = torch.as_tensor(test_units, dtype=torch.float, device='cpu')  # no copy-construct warning on re-run\n",
    "test_data = test_data.to('cuda')\n",
    "\n",
    "start = time.perf_counter()\n",
    "# Inference only: no_grad() avoids building (and holding in GPU memory) an autograd graph.\n",
    "with torch.no_grad():\n",
    "    test_data_denoise = huidu.denoise(test_data).cpu()\n",
    "res, gmm_test, test_spikes = gmm_monitor(huidu,\n",
    "                                         None,\n",
    "                                         test_units,\n",
    "                                         test_data_denoise,\n",
    "                                         test_units,\n",
    "                                         labels,\n",
    "                                         device='cuda',\n",
    "                                         epochs=20,\n",
    "                                         use_pca=False,\n",
    "                                         use_scaler=False,\n",
    "                                         covariance_type='full',\n",
    "                                         use_iso=False,\n",
    "                                         score=True,\n",
    "                                         test_data_origin=test_data)\n",
    "elapsed = time.perf_counter() - start\n",
    "\n",
    "print(f\"elapsed: {elapsed:.2f} s\")\n",
    "print(f\"ARI mean={np.mean(res):.4f}  std={np.std(res):.4f}  max={np.max(res):.4f}  min={np.min(res):.4f}\")"
   ],
   "id": "d4b94b0a25f7713",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Silhouette score of each run's predicted clustering, computed on the clustered features.\n",
    "from sklearn.metrics import silhouette_score\n",
    "\n",
    "# Distinct loop name so the `labels` returned by generate_testset is not overwritten.\n",
    "sil_scores = [silhouette_score(test_spikes, run_labels) for run_labels in gmm_test]\n",
    "\n",
    "print(f\"silhouette mean={np.mean(sil_scores):.4f}  std={np.std(sil_scores):.4f}\")"
   ],
   "id": "c826a95418aa45eb",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def match_labels(y_true, y_pred):\n",
    "    \"\"\"Relabel the cluster ids in `y_pred` to their best-matching ids in `y_true`.\n",
    "\n",
    "    Builds the confusion matrix over the union of both label sets and solves the\n",
    "    assignment problem (Hungarian algorithm) that maximises the number of agreeing\n",
    "    samples, so clustering output can be scored with classification metrics.\n",
    "\n",
    "    Args:\n",
    "        y_true: array-like of ground-truth labels.\n",
    "        y_pred: array-like of predicted cluster labels, same length as y_true.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray of the same length as y_pred holding label values. A predicted\n",
    "        cluster without a real counterpart (more clusters than classes) is mapped\n",
    "        to a label that never occurs in y_true, so it counts as an error.\n",
    "    \"\"\"\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_pred = np.asarray(y_pred)\n",
    "    # Index the matrix by actual label values, so inputs need not be 0..k-1 codes.\n",
    "    classes = np.union1d(y_true, y_pred)\n",
    "    # Rows are true labels, columns are predicted labels.\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=classes)\n",
    "    row_ind, col_ind = linear_sum_assignment(cm, maximize=True)\n",
    "    # cm is square, so every predicted label (column) receives exactly one target.\n",
    "    lookup = np.empty_like(classes)\n",
    "    lookup[col_ind] = classes[row_ind]\n",
    "    return lookup[np.searchsorted(classes, y_pred)]\n",
    "\n",
    "\n",
    "# NOTE(review): gmm_test presumably holds predictions for the generate_testset subset,\n",
    "# while test_labels covers the whole test set; confirm test_labels is the intended ground truth.\n",
    "y_true_encoded = LabelEncoder().fit_transform(test_labels)  # loop-invariant: encode once\n",
    "\n",
    "total_acc, total_precision, total_recall = [], [], []\n",
    "for run_pred in gmm_test:\n",
    "    y_pred_encoded = LabelEncoder().fit_transform(run_pred)\n",
    "    if len(y_pred_encoded) != len(y_true_encoded):\n",
    "        raise ValueError(f\"{len(y_pred_encoded)} predictions vs {len(y_true_encoded)} ground-truth labels\")\n",
    "    y_pred_aligned = match_labels(y_true=y_true_encoded, y_pred=y_pred_encoded)\n",
    "    # Classification metrics on the aligned labels; zero_division=0 scores classes\n",
    "    # that receive no prediction as 0 without emitting a warning.\n",
    "    total_acc.append(accuracy_score(y_true_encoded, y_pred_aligned))\n",
    "    total_precision.append(precision_score(y_true_encoded, y_pred_aligned, average='macro', zero_division=0))\n",
    "    total_recall.append(recall_score(y_true_encoded, y_pred_aligned, average='macro', zero_division=0))\n",
    "\n",
    "print(f\"Accuracy: {np.mean(total_acc):.4f}\")\n",
    "print(f\"Precision (macro): {np.mean(total_precision):.4f}\")\n",
    "print(f\"Recall (macro): {np.mean(total_recall):.4f}\")"
   ],
   "id": "66d84eee8b1c2253",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Repeat the evaluation on N_RUNS independently sampled 10-unit test sets, one seed per run.\n",
    "# Imported here so this cell does not depend on how `random` / `tqdm` are bound elsewhere.\n",
    "import random\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "N_RUNS = 100\n",
    "score = []  # mean gmm_monitor score (ARI) of each run\n",
    "times = []  # wall-clock seconds of each run\n",
    "progress = tqdm(range(N_RUNS))\n",
    "for seed in progress:\n",
    "    start = time.perf_counter()\n",
    "    # Seed every RNG that sampling or clustering may draw from, so each run is reproducible.\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    # Run-local names keep the single-split results used by the earlier cells intact.\n",
    "    run_data, run_units, run_labels = generate_testset(test_dataset, test_labels, num_units=10)\n",
    "    run_res = gmm_monitor(huidu,\n",
    "                          None,\n",
    "                          None,\n",
    "                          run_data,\n",
    "                          run_units,\n",
    "                          run_labels,\n",
    "                          verbose=False,\n",
    "                          use_iso=False,\n",
    "                          score=True,\n",
    "                          max_iter=100,\n",
    "                          covariance_type='full',\n",
    "                          # same device string as the single-split evaluation; 'gpu' is not a torch device\n",
    "                          device='cuda',\n",
    "                          epochs=50)[0]\n",
    "    score.append(np.mean(run_res))\n",
    "    times.append(time.perf_counter() - start)\n",
    "    progress.set_postfix(mean_score=f\"{np.mean(score):.4f}\")\n",
    "\n",
    "print(f\"time   mean={np.mean(times):.3f} s  std={np.std(times):.3f} s\")\n",
    "print(f\"score  mean={np.mean(score):.4f}  std={np.std(score):.4f}  max={np.max(score):.4f}  min={np.min(score):.4f}\")"
   ],
   "id": "44ddf18da64aada2",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Summary over the repeated runs: standard error of the mean = sample std / sqrt(number of runs).\n",
    "score_sem = np.std(score, ddof=1) / np.sqrt(len(score))\n",
    "times_sem = np.std(times, ddof=1) / np.sqrt(len(times))\n",
    "\n",
    "print(f\"score  mean={np.mean(score):.4f}  SEM={score_sem:.4f}\")\n",
    "print(f\"time   mean={np.mean(times):.3f} s  SEM={times_sem:.3f} s\")\n",
    "print(f\"score  max={np.max(score):.4f}  min={np.min(score):.4f}\")"
   ],
   "id": "9d879f32cd889ebd",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
