{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cd229016",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from __future__ import annotations\n",
    "import os\n",
    "# Environment must be configured before torch initialises CUDA.\n",
    "# CUDA_LAUNCH_BLOCKING is a 0/1 switch (1 = synchronous kernel launches, for debugging);\n",
    "# it previously held a copy of the device list by mistake.\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '0'\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'\n",
    "\n",
    "import json\n",
    "import random\n",
    "from datetime import datetime\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch import Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4ffd21da",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/peizhengqi/anaconda3/lib/python3.9/site-packages/scipy/__init__.py:155: UserWarning: A NumPy version >=1.18.5 and <1.25.0 is required for this version of SciPy (detected version 1.26.4\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n",
      "/home/peizhengqi/anaconda3/lib/python3.9/site-packages/pandas/core/computation/expressions.py:21: UserWarning: Pandas requires version '2.8.4' or newer of 'numexpr' (version '2.8.3' currently installed).\n",
      "  from pandas.core.computation.check import NUMEXPR_INSTALLED\n",
      "/home/peizhengqi/anaconda3/lib/python3.9/site-packages/pandas/core/arrays/masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n",
      "  from pandas.core import (\n"
     ]
    }
   ],
   "source": [
    "\n",
    "import llm_utils.eval_benchmarks as llm_eval\n",
    "import llm_utils.load_llm as load_llm\n",
    "import llm_utils.load_datasets as load_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3e2271cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Metric sizes: maze_dim is passed as Maze_metric's metric_dim, metric_dim as its metric_hdim.\n",
    "maze_dim, metric_dim = 48, 96\n",
    "\n",
    "# LLM to probe; other options: 'qwen3-8b', 'mistral-7b', 'llama3-8b'.\n",
    "choose_llm = 'qwen3-4b'\n",
    "\n",
    "# Per-benchmark (TEST_OFFSET, SKIP_STEP, num_train); pick one by setting eval_benchmark.\n",
    "BENCHMARK_SETTINGS = {\n",
    "    'mmluPro': (0, 5, 60),\n",
    "    'gpqa-main': (0, 1, 100),\n",
    "    'gsm8k': (0, 1, 100),\n",
    "    'math-500': (0, 1, 100),\n",
    "}\n",
    "eval_benchmark = 'mmluPro'\n",
    "TEST_OFFSET, SKIP_STEP, num_train = BENCHMARK_SETTINGS[eval_benchmark]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9a9cdb4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MMLUPro_configNames:['default']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/peizhengqi/anaconda3/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/peizhengqi/anaconda3/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n",
      "  warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9647180647e849a48955755dc407b44e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Load the benchmark split and the LLM through the project's llm_utils helpers.\n",
    "# NOTE(review): the returned tuples are presumably (dataset, test prompt ids, train prompt ids)\n",
    "# and (model, tokenizer, layer range) -- confirm against llm_utils.\n",
    "cur_DS, llm_testPromptIds, llm_trainPromptIds = load_ds.get_dataset(eval_benchmark, TEST_OFFSET, SKIP_STEP)\n",
    "cur_model, cur_tokenizer, layers_range = load_llm.get_llm(choose_llm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "beb8c034",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Sampling records produced for qwen3-4b on mmluPro; later cells read each entry as a list of\n",
    "# dicts with 'conf', 'test_id' and 'fs_ids' keys.\n",
    "q4_filePath1 = \"stat_records/qwen3-4b_mmluPro_F64_T280_I3_R9859.json\"\n",
    "# Explicit encoding so loading does not depend on the platform's locale default.\n",
    "with open(q4_filePath1, \"r\", encoding=\"utf-8\") as file:\n",
    "    loaded_q4MP = json.load(file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f7f03af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class Maze_metric(nn.Module):\n",
    "    \"\"\"Learned energy over LLM hidden states of in-context exemplars and a test query.\n",
    "\n",
    "    Per layer, `mass_proj` embeds every exemplar state (and the query) and `metric_proj`\n",
    "    embeds the step from each exemplar to the next (the last exemplar steps to the\n",
    "    query). The per-layer vectors are fused by a small MLP (`compute_fusion`) and\n",
    "    averaged over layers in `compute_sumEp`.\n",
    "\n",
    "    Args:\n",
    "        llm_dm: width of the LLM hidden states fed to the projections.\n",
    "        metric_dim: output width of the projections and of the returned vectors.\n",
    "        metric_hdim: hidden width of the fusion MLP.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, llm_dm, metric_dim, metric_hdim):\n",
    "        super().__init__()\n",
    "        self.mass_proj = nn.Linear(llm_dm, metric_dim)\n",
    "        self.metric_proj = nn.Linear(llm_dm, metric_dim)\n",
    "        self.fusion_proj1 = nn.Linear(metric_dim*2, metric_hdim)\n",
    "        self.fusion_proj2 = nn.Linear(metric_hdim, metric_dim)\n",
    "        self.tanh = nn.Tanh()\n",
    "\n",
    "    def compute_layerEp(self, layer_id, raw_e_hidS_all, raw_q_hidS):\n",
    "        \"\"\"Return (q_mass, Ep_values) for a single layer.\n",
    "\n",
    "        Args:\n",
    "            layer_id: index into each exemplar's and the query's per-layer sequence.\n",
    "            raw_e_hidS_all: one per-layer sequence per exemplar (the notebook caches\n",
    "                these already mean-pooled over dim 1).\n",
    "            raw_q_hidS: the query's per-layer sequence; mean-pooled over dim 1 here.\n",
    "\n",
    "        Returns:\n",
    "            q_mass: mass_proj of the pooled query state.\n",
    "            Ep_values: mean over exemplars of tanh(mass_i * step_i).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if raw_e_hidS_all is empty.\n",
    "        \"\"\"\n",
    "        if not raw_e_hidS_all:\n",
    "            raise ValueError('compute_layerEp needs at least one exemplar')\n",
    "        e_hidS_all = [raw_e_hidS[layer_id] for raw_e_hidS in raw_e_hidS_all]\n",
    "        q_hidS = torch.mean(raw_q_hidS[layer_id], dim=1)\n",
    "\n",
    "        all_e_mass = [self.mass_proj(e_hidS.to(torch.float32)) for e_hidS in e_hidS_all]\n",
    "        q_mass = self.mass_proj(q_hidS.to(torch.float32))\n",
    "\n",
    "        # Step i goes from exemplar i to exemplar i+1; the last exemplar steps to the query.\n",
    "        d_values = [self.metric_proj((e_hidS_all[e_id+1] - e_hidS_all[e_id]).to(torch.float32)) for e_id in range(len(e_hidS_all)-1)]\n",
    "        d_values += [self.metric_proj((q_hidS - e_hidS_all[-1]).to(torch.float32))]\n",
    "\n",
    "        md_values = [self.tanh(e_mass * d_value) for e_mass, d_value in zip(all_e_mass, d_values)]\n",
    "        Ep_values = sum(md_values)/len(md_values)\n",
    "        return q_mass, Ep_values\n",
    "\n",
    "    def compute_fusion(self, q_mass, Ep_values, cur_Ep_values):\n",
    "        \"\"\"Scale cur_Ep_values elementwise by an MLP of concat([q_mass, Ep_values], dim=1).\"\"\"\n",
    "        return self.fusion_proj2(self.tanh(self.fusion_proj1(torch.cat([q_mass, Ep_values], dim=1)))) *  cur_Ep_values\n",
    "\n",
    "    def _cached_layerEp(self, layerEp_cache, layer_id, raw_e_hidS_all, raw_q_hidS):\n",
    "        \"\"\"compute_layerEp, memoised in layerEp_cache (dict keyed by layer_id) when one is given.\"\"\"\n",
    "        if layerEp_cache is None:\n",
    "            return self.compute_layerEp(layer_id, raw_e_hidS_all, raw_q_hidS)\n",
    "        if layer_id not in layerEp_cache:\n",
    "            layerEp_cache[layer_id] = self.compute_layerEp(layer_id, raw_e_hidS_all, raw_q_hidS)\n",
    "        return layerEp_cache[layer_id]\n",
    "    \n",
    "    def fusion_layerEp_com0(self, layer_id, raw_e_hidS_all, raw_q_hidS):\n",
    "        \"\"\"Variant 0: this layer's (q_mass, Ep_values) fused with itself.\"\"\"\n",
    "        q_mass, Ep_values = self.compute_layerEp(layer_id, raw_e_hidS_all, raw_q_hidS)\n",
    "        return self.compute_fusion(q_mass, Ep_values, Ep_values)\n",
    "    \n",
    "    def fusion_layerEp_com1(self, layer_id, raw_e_hidS_all, raw_q_hidS):\n",
    "        \"\"\"Variant 1: interior layers weight (0.5, 1, 0.5) and halve fusions that use the q_mass\n",
    "        of layers l-1, l, l+1 with this layer's Ep_values; first/last layer use variant 0.\"\"\"\n",
    "        if layer_id == 0 or layer_id == len(raw_e_hidS_all[0])-1:\n",
    "            q_mass, Ep_values = self.compute_layerEp(layer_id, raw_e_hidS_all, raw_q_hidS)\n",
    "            return self.compute_fusion(q_mass, Ep_values, Ep_values)\n",
    "        else:\n",
    "            q_mass1, _ = self.compute_layerEp(layer_id-1, raw_e_hidS_all, raw_q_hidS)\n",
    "            q_mass2, Ep_values2 = self.compute_layerEp(layer_id, raw_e_hidS_all, raw_q_hidS)\n",
    "            q_mass3, _ = self.compute_layerEp(layer_id+1, raw_e_hidS_all, raw_q_hidS)\n",
    "            localFusion_Ep = 0.5*self.compute_fusion(q_mass1, Ep_values2, Ep_values2) + self.compute_fusion(q_mass2, Ep_values2, Ep_values2) + 0.5*self.compute_fusion(q_mass3, Ep_values2, Ep_values2)\n",
    "            return localFusion_Ep/2\n",
    "        \n",
    "    def fusion_layerEp_com2(self, layer_id, raw_e_hidS_all, raw_q_hidS):\n",
    "        \"\"\"Variant 2: interior layers feed Ep_values of layers l-1, l, l+1 (with this layer's\n",
    "        q_mass) into the MLP that scales this layer's Ep_values; weights (0.5, 1, 0.5), halved.\n",
    "        First/last layer use variant 0.\"\"\"\n",
    "        if layer_id == 0 or layer_id == len(raw_e_hidS_all[0])-1:\n",
    "            q_mass, Ep_values = self.compute_layerEp(layer_id, raw_e_hidS_all, raw_q_hidS)\n",
    "            return self.compute_fusion(q_mass, Ep_values, Ep_values)\n",
    "        else:\n",
    "            _, Ep_values1 = self.compute_layerEp(layer_id-1, raw_e_hidS_all, raw_q_hidS)\n",
    "            q_mass2, Ep_values2 = self.compute_layerEp(layer_id, raw_e_hidS_all, raw_q_hidS)\n",
    "            _, Ep_values3 = self.compute_layerEp(layer_id+1, raw_e_hidS_all, raw_q_hidS)\n",
    "            localFusion_Ep = 0.5*self.compute_fusion(q_mass2, Ep_values1, Ep_values2) + self.compute_fusion(q_mass2, Ep_values2, Ep_values2) + 0.5*self.compute_fusion(q_mass2, Ep_values3, Ep_values2)\n",
    "            return localFusion_Ep/2\n",
    "        \n",
    "    def fusion_layerEp_com3(self, layer_id, raw_e_hidS_all, raw_q_hidS, layerEp_cache=None):\n",
    "        \"\"\"Variant 3 (used by compute_sumEp).\n",
    "\n",
    "        Interior layers: Ep_values of layers l-1, l, l+1 are each fused with themselves using\n",
    "        this layer's q_mass, weighted (0.5, 1, 0.5), halved and squashed with tanh.\n",
    "        First/last layer: plain self-fusion (variant 0, no tanh).\n",
    "\n",
    "        layerEp_cache: optional dict shared across calls so each layer's compute_layerEp\n",
    "            runs only once; None keeps the uncached behaviour.\n",
    "        \"\"\"\n",
    "        if layer_id == 0 or layer_id == len(raw_e_hidS_all[0])-1:\n",
    "            q_mass, Ep_values = self._cached_layerEp(layerEp_cache, layer_id, raw_e_hidS_all, raw_q_hidS)\n",
    "            return self.compute_fusion(q_mass, Ep_values, Ep_values)\n",
    "        _, Ep_values1 = self._cached_layerEp(layerEp_cache, layer_id-1, raw_e_hidS_all, raw_q_hidS)\n",
    "        q_mass2, Ep_values2 = self._cached_layerEp(layerEp_cache, layer_id, raw_e_hidS_all, raw_q_hidS)\n",
    "        _, Ep_values3 = self._cached_layerEp(layerEp_cache, layer_id+1, raw_e_hidS_all, raw_q_hidS)\n",
    "        #localFusion_Ep = 0.5*self.tanh(self.compute_fusion(q_mass2, Ep_values1, Ep_values1)) + self.tanh(self.compute_fusion(q_mass2, Ep_values2, Ep_values2)) + 0.5*self.tanh(self.compute_fusion(q_mass2, Ep_values3, Ep_values3))\n",
    "        #return localFusion_Ep/2\n",
    "        localFusion_Ep = 0.5*self.compute_fusion(q_mass2, Ep_values1, Ep_values1) + self.compute_fusion(q_mass2, Ep_values2, Ep_values2) + 0.5*self.compute_fusion(q_mass2, Ep_values3, Ep_values3)\n",
    "        return self.tanh(localFusion_Ep/2)\n",
    "        #localFusion_Ep = self.compute_fusion(q_mass2, Ep_values1, Ep_values1) + self.compute_fusion(q_mass2, Ep_values3, Ep_values3)\n",
    "        #return self.compute_fusion(q_mass2, Ep_values2, Ep_values2) + self.tanh(localFusion_Ep)\n",
    "\n",
    "    def compute_sumEp(self, raw_e_hidS_all, raw_q_hidS):\n",
    "        \"\"\"Mean over layers of fusion_layerEp_com3.\n",
    "\n",
    "        A per-call cache shares each layer's compute_layerEp between neighbouring layers\n",
    "        (previously every interior layer was recomputed three times).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if raw_e_hidS_all is empty.\n",
    "        \"\"\"\n",
    "        if not raw_e_hidS_all:\n",
    "            raise ValueError('compute_sumEp needs at least one exemplar')\n",
    "        num_layers = len(raw_e_hidS_all[0])\n",
    "        layerEp_cache = {}\n",
    "        sum_Ep_vecs = sum([self.fusion_layerEp_com3(_layerId, raw_e_hidS_all, raw_q_hidS, layerEp_cache) for _layerId in range(num_layers)])\n",
    "        return sum_Ep_vecs/num_layers\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a3cba4ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Cache every training exemplar's per-layer hidden states, mean-pooled over dim 1.\n",
    "# NOTE(review): later cells index this list with 'fs_ids', i.e. by position in llm_trainPromptIds;\n",
    "# that is only right if fs_ids are positions (or the prompt ids are 0..n-1) -- confirm.\n",
    "cached_e_hidS_all = []\n",
    "for _id in llm_trainPromptIds:\n",
    "    with torch.no_grad():\n",
    "        exemplar_layers = llm_eval.access_exemplar_hiddenStates(cur_model, cur_tokenizer, eval_benchmark, cur_DS, _id)\n",
    "        cached_e_hidS_all.append([layer_hidS.mean(dim=1) for layer_hidS in exemplar_layers])\n",
    "        del exemplar_layers  # release the un-pooled states before emptying the CUDA cache\n",
    "    torch.cuda.empty_cache()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "40036f2b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- D48H96 --- No.Params: 259824\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Width of the LLM hidden states fed to the metric. Read it from the loaded model instead of\n",
    "# hard-coding 2560 (qwen3-4b), so switching choose_llm does not break the projections.\n",
    "# NOTE(review): assumes an HF-style model exposing config.hidden_size; falls back to 2560.\n",
    "llm_hidden_dim = getattr(getattr(cur_model, 'config', None), 'hidden_size', 2560)\n",
    "# maze_dim -> Maze_metric.metric_dim (output width), metric_dim -> metric_hdim (fusion hidden width).\n",
    "cur_Maze = Maze_metric(llm_hidden_dim, maze_dim, metric_dim).to('cuda:0')\n",
    "loss_fn = nn.MSELoss()\n",
    "_optimizer = optim.Adam(cur_Maze.parameters(), lr = 2e-4)\n",
    "\n",
    "noParam = sum(p.numel() for p in cur_Maze.parameters())\n",
    "print(f\"---- D{maze_dim}H{metric_dim} --- No.Params: {noParam}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0a7e247b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "75\n",
      "[2, 6, 10, 13, 21, 22, 23, 25, 38, 46, 53, 54, 56, 57, 59, 65, 75, 83, 86, 88, 89, 90, 95, 96, 101, 104, 112, 115, 116, 119, 120, 123, 126, 129, 139, 150, 160, 161, 164, 167, 174, 177, 181, 184, 185, 186, 187, 189, 192, 195, 199, 204, 205, 206, 215, 218, 220, 222, 233, 236, 237, 238, 239, 246, 251, 259, 267, 269, 271, 275, 276, 284, 285, 289, 291]\n",
      "[80, 240, 400, 520, 840, 880, 920, 1000, 1520, 1840, 2120, 2160, 2240, 2280, 2360, 2600, 3000, 3320, 3440, 3520, 3560, 3600, 3800, 3840, 4040, 4160, 4480, 4600, 4640, 4760, 4800, 4920, 5040, 5160, 5560, 6000, 6400, 6440, 6560, 6680, 6960, 7080, 7240, 7360, 7400, 7440, 7480, 7560, 7680, 7800, 7960, 8160, 8200, 8240, 8600, 8720, 8800, 8880, 9320, 9440, 9480, 9520, 9560, 9840, 10040, 10360, 10680, 10760, 10840, 11000, 11040, 11360, 11400, 11560, 11640]\n",
      "[4, 45, 79, 97, 99, 108, 110, 122, 165, 176, 191, 193, 196, 203, 214, 219, 235, 280, 288, 292, 293, 297]\n",
      "22\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Split record groups by their mean 'conf':\n",
    "#  - abs_idList / test_idList: groups with exactly RECORDS_PER_GROUP records and mean conf\n",
    "#    strictly inside (ABS_CONF_LOW, ABS_CONF_HIGH)\n",
    "#  - extra_trainIds: remaining groups with mean conf strictly inside (EXTRA_CONF_LOW, EXTRA_CONF_HIGH)\n",
    "RECORDS_PER_GROUP = 64\n",
    "ABS_CONF_LOW, ABS_CONF_HIGH = 0.1, 0.6\n",
    "EXTRA_CONF_LOW, EXTRA_CONF_HIGH = 0.2, 0.7\n",
    "\n",
    "# Mean conf per group, computed once; empty groups get None (previously a ZeroDivisionError).\n",
    "mean_confs = [sum(rec['conf'] for rec in group)/len(group) if group else None for group in loaded_q4MP]\n",
    "\n",
    "abs_idList, test_idList = [], []\n",
    "for _id, (group, ave_conf) in enumerate(zip(loaded_q4MP, mean_confs)):\n",
    "    if ave_conf is not None and len(group) == RECORDS_PER_GROUP and ABS_CONF_LOW < ave_conf < ABS_CONF_HIGH:\n",
    "        abs_idList.append(_id)\n",
    "        test_idList.append(group[0]['test_id'])\n",
    "\n",
    "abs_idSet = set(abs_idList)  # O(1) membership instead of scanning the list per group\n",
    "extra_trainIds = [_id for _id, ave_conf in enumerate(mean_confs)\n",
    "                  if ave_conf is not None and EXTRA_CONF_LOW < ave_conf < EXTRA_CONF_HIGH and _id not in abs_idSet]\n",
    "\n",
    "print(len(abs_idList))\n",
    "print(abs_idList)\n",
    "print(test_idList)\n",
    "print(extra_trainIds)\n",
    "print(len(extra_trainIds))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8f78329f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "75\n",
      "1302 4800\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Gather the selected record groups and count how many of their records have conf == 1.\n",
    "cur_items = [loaded_q4MP[abs_id] for abs_id in abs_idList]\n",
    "print(len(cur_items))\n",
    "\n",
    "_total = sum(len(group) for group in cur_items)\n",
    "_corr = sum(rec['conf'] == 1 for group in cur_items for rec in group)\n",
    "print(_corr, _total)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "95d3b768",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Cache each selected test query's raw per-layer hidden states, keyed by test_id.\n",
    "cached_q_hidS_all = {}\n",
    "for _id in test_idList:\n",
    "    with torch.no_grad():\n",
    "        raw_q_hidS = llm_eval.access_testQuery_hiddenStates(\n",
    "            cur_model, cur_tokenizer, eval_benchmark, cur_DS, _id\n",
    "        )\n",
    "    cached_q_hidS_all[_id] = raw_q_hidS\n",
    "    torch.cuda.empty_cache()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "73857ca7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.2281, 0.2163, 0.2456, 0.2689, 0.2456, 0.2105, 0.1988, 0.1812, 0.2105, 0.2105, 0.2105, 0.2398, 0.2398, 0.2574, 0.2574, 0.2514, 0.2865, 0.2865, 0.2865, 0.2923, 0.2747, 0.2747, 0.2747, 0.2514, 0.2514, 0.2339, 0.2339, 0.2456, 0.2456, 0.2632]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def _score_records(records, num_records=None):\n",
    "    \"\"\"Return [(energy, conf), ...] for the first num_records records (all if None).\n",
    "\n",
    "    energy is the L2 norm of cur_Maze.compute_sumEp over the record's exemplars ('fs_ids' into\n",
    "    cached_e_hidS_all) and its query ('test_id' into cached_q_hidS_all). Runs under\n",
    "    torch.no_grad(): only .item() values are kept, so no autograd graph is needed.\n",
    "    \"\"\"\n",
    "    if num_records is not None:\n",
    "        records = records[:num_records]\n",
    "    scored = []\n",
    "    with torch.no_grad():\n",
    "        for record in records:\n",
    "            raw_e_hidS_all = [cached_e_hidS_all[_id] for _id in record['fs_ids']]\n",
    "            raw_q_hidS = cached_q_hidS_all[record['test_id']]\n",
    "            sum_Ep_vecs = cur_Maze.compute_sumEp(raw_e_hidS_all, raw_q_hidS)\n",
    "            scored.append((torch.norm(sum_Ep_vecs).item(), record['conf']))\n",
    "    return scored\n",
    "\n",
    "def _topk_conf(scored, num_check, top_k=3):\n",
    "    \"\"\"Mean conf of the top_k highest-energy entries among the first num_check of scored.\n",
    "\n",
    "    Ties keep their original order (sorted() is stable, also with reverse=True). Averages over\n",
    "    the entries actually available (previously always divided by 3); returns 0.0 if none.\n",
    "    \"\"\"\n",
    "    ranked = sorted(scored[:num_check], key=lambda x: x[0], reverse=True)[:top_k]\n",
    "    return sum(conf for _, conf in ranked)/len(ranked) if ranked else 0.0\n",
    "\n",
    "def tmp_eval(cur_items, num_check, top_k=3):\n",
    "    \"\"\"Top-k precision of cur_Maze energies for one group of records.\n",
    "\n",
    "    Args:\n",
    "        cur_items: records of a single group (each with 'fs_ids', 'test_id', 'conf').\n",
    "        num_check: only the first num_check records are ranked, so only they are scored.\n",
    "        top_k: number of highest-energy records whose 'conf' is averaged.\n",
    "    \"\"\"\n",
    "    return _topk_conf(_score_records(cur_items, num_check), num_check, top_k)\n",
    "\n",
    "# Energies do not depend on num_check, so score each group once and reuse the scores for\n",
    "# every cut-off (previously every cut-off re-scored all records of every group).\n",
    "group_scores = [_score_records(group) for group in cur_items[::4]]\n",
    "res_stat = []\n",
    "for num_check in range(4, 64, 2):\n",
    "    tmp_corr_probs = [_topk_conf(scored, num_check) for scored in group_scores]\n",
    "    res_stat.append(round(sum(tmp_corr_probs)/len(tmp_corr_probs),4))\n",
    "\n",
    "print(res_stat)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "21538b24",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "09:47:00\n",
      "0 715.1450250668447 09:56:33\n",
      "[0.2221, 0.2161, 0.2163, 0.2339, 0.2747, 0.2396, 0.2161, 0.1928, 0.1811, 0.1811, 0.1868, 0.2044, 0.2044, 0.1868, 0.1753, 0.1402, 0.1226, 0.1226, 0.1226, 0.1109, 0.0933, 0.0933, 0.0991, 0.0991, 0.0816, 0.0816, 0.0991, 0.0991, 0.1167, 0.1167]\n",
      "1 696.7979282009449 10:50:51\n",
      "2 686.1381960644525 11:00:24\n",
      "[0.2456, 0.2396, 0.2514, 0.2688, 0.3098, 0.2923, 0.2982, 0.3393, 0.3568, 0.3511, 0.3628, 0.3335, 0.3335, 0.3335, 0.3042, 0.31, 0.31, 0.2982, 0.3158, 0.3391, 0.3274, 0.3274, 0.3274, 0.3098, 0.3098, 0.3098, 0.3098, 0.2923, 0.2923, 0.2923]\n",
      "3 663.8642972334856 11:54:31\n",
      "4 644.4201144260542 12:04:02\n",
      "[0.1754, 0.1812, 0.2281, 0.2456, 0.2689, 0.2396, 0.2512, 0.263, 0.2454, 0.2454, 0.2865, 0.2747, 0.3098, 0.3098, 0.304, 0.304, 0.3216, 0.3216, 0.3216, 0.304, 0.3509, 0.3333, 0.3333, 0.3158, 0.3158, 0.3042, 0.3042, 0.3042, 0.3218, 0.3218]\n",
      "5 620.1183864723823 12:58:40\n",
      "6 603.5780191762271 13:08:12\n",
      "[0.2456, 0.2514, 0.2747, 0.3274, 0.3156, 0.3156, 0.2688, 0.2921, 0.3509, 0.3509, 0.3509, 0.3509, 0.3918, 0.3742, 0.3625, 0.38, 0.3918, 0.4035, 0.3802, 0.3684, 0.3684, 0.3333, 0.3216, 0.3216, 0.304, 0.304, 0.304, 0.3158, 0.3333, 0.3391]\n",
      "7 590.5070708499911 14:02:18\n",
      "8 577.6343354553152 14:11:50\n",
      "[0.2163, 0.2104, 0.2104, 0.1986, 0.2046, 0.2046, 0.1928, 0.1811, 0.2221, 0.2454, 0.2514, 0.2339, 0.2865, 0.3158, 0.2807, 0.304, 0.3216, 0.3391, 0.3216, 0.3216, 0.3216, 0.3098, 0.3156, 0.3214, 0.3272, 0.3507, 0.3507, 0.3389, 0.3389, 0.3332]\n",
      "9 556.816070568536 15:05:54\n",
      "10 537.605157158868 15:15:25\n",
      "[0.2281, 0.2046, 0.2163, 0.304, 0.2863, 0.2688, 0.2277, 0.2042, 0.2861, 0.2979, 0.3096, 0.3039, 0.3039, 0.3156, 0.3156, 0.3156, 0.3156, 0.3214, 0.3039, 0.3039, 0.2981, 0.2863, 0.2688, 0.2688, 0.2746, 0.2746, 0.2746, 0.2746, 0.257, 0.2395]\n",
      "11 520.0144936614225 16:09:44\n",
      "12 498.44382194431637 16:19:17\n",
      "[0.2281, 0.2396, 0.2572, 0.3274, 0.3391, 0.3391, 0.3098, 0.3039, 0.3214, 0.3389, 0.3332, 0.3332, 0.4093, 0.4211, 0.4211, 0.4093, 0.3918, 0.3918, 0.3918, 0.38, 0.38, 0.3507, 0.3625, 0.3625, 0.3625, 0.3625, 0.38, 0.3567, 0.3567, 0.3567]\n",
      "13 482.8455402526678 17:13:37\n",
      "14 465.9515938254731 17:23:08\n",
      "[0.2339, 0.2514, 0.2396, 0.2981, 0.3272, 0.3272, 0.3154, 0.2979, 0.3447, 0.3858, 0.3858, 0.4033, 0.4561, 0.4854, 0.4854, 0.4679, 0.4504, 0.4328, 0.4153, 0.4035, 0.3802, 0.3744, 0.3626, 0.3626, 0.3626, 0.3626, 0.3626, 0.3451, 0.3626, 0.3451]\n",
      "15 449.36313566875856 18:17:33\n",
      "16 432.2590401792471 18:27:05\n",
      "[0.2046, 0.2396, 0.2104, 0.2981, 0.2863, 0.3214, 0.3039, 0.2804, 0.3095, 0.3681, 0.3623, 0.3916, 0.4267, 0.4384, 0.4384, 0.4209, 0.4033, 0.3916, 0.3975, 0.3975, 0.4151, 0.3975, 0.4151, 0.4209, 0.4209, 0.4093, 0.4093, 0.4035, 0.4211, 0.4211]\n",
      "17 418.24676769624057 19:21:01\n",
      "18 398.8068936949102 19:30:32\n",
      "[0.2456, 0.2279, 0.2279, 0.2746, 0.3039, 0.3214, 0.3332, 0.3272, 0.374, 0.3975, 0.3975, 0.4093, 0.4268, 0.4268, 0.4268, 0.4093, 0.4268, 0.4268, 0.3918, 0.38, 0.3682, 0.3623, 0.3623, 0.3974, 0.3974, 0.4209, 0.4209, 0.4209, 0.4384, 0.4384]\n",
      "19 385.3752183247856 20:24:43\n",
      "20 370.1546652865111 20:34:14\n",
      "[0.2339, 0.2221, 0.2221, 0.2923, 0.3156, 0.3332, 0.3332, 0.3389, 0.3214, 0.3742, 0.3742, 0.386, 0.427, 0.4621, 0.4504, 0.4504, 0.4504, 0.4444, 0.4561, 0.4268, 0.4268, 0.3918, 0.4093, 0.4444, 0.4444, 0.427, 0.427, 0.427, 0.427, 0.427]\n",
      "21 351.0795792926094 21:28:16\n",
      "22 341.1926943841806 21:37:48\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_1826542/2760720635.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     22\u001b[0m         \u001b[0mres_stat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mnum_check\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m             \u001b[0mtmp_corr_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtmp_eval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_item\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_check\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_item\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcur_items\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     25\u001b[0m             \u001b[0mres_stat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtmp_corr_probs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtmp_corr_probs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres_stat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1826542/2760720635.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     22\u001b[0m         \u001b[0mres_stat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mnum_check\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m             \u001b[0mtmp_corr_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtmp_eval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_item\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_check\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_item\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcur_items\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     25\u001b[0m             \u001b[0mres_stat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtmp_corr_probs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtmp_corr_probs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres_stat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1826542/2003723734.py\u001b[0m in \u001b[0;36mtmp_eval\u001b[0;34m(cur_items, num_check)\u001b[0m\n\u001b[1;32m      5\u001b[0m         \u001b[0mraw_q_hidS\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcached_q_hidS_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcur_item\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'test_id'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m         \u001b[0msum_Ep_vecs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcur_Maze\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_sumEp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mraw_e_hidS_all\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_q_hidS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m         \u001b[0mres_energy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msum_Ep_vecs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m         \u001b[0meval_stat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcur_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mres_energy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcur_item\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'conf'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1826542/2101813558.py\u001b[0m in \u001b[0;36mcompute_sumEp\u001b[0;34m(self, raw_e_hidS_all, raw_q_hidS)\u001b[0m\n\u001b[1;32m     69\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mcompute_sumEp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_e_hidS_all\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_q_hidS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     70\u001b[0m         \u001b[0mnum_layers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mraw_e_hidS_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 71\u001b[0;31m         \u001b[0msum_Ep_vecs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfusion_layerEp_com3\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_layerId\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_e_hidS_all\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_q_hidS\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_layerId\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_layers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     72\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0msum_Ep_vecs\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mnum_layers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1826542/2101813558.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     69\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mcompute_sumEp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_e_hidS_all\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_q_hidS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     70\u001b[0m         \u001b[0mnum_layers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mraw_e_hidS_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 71\u001b[0;31m         \u001b[0msum_Ep_vecs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfusion_layerEp_com3\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_layerId\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_e_hidS_all\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_q_hidS\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_layerId\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_layers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     72\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0msum_Ep_vecs\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mnum_layers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1826542/2101813558.py\u001b[0m in \u001b[0;36mfusion_layerEp_com3\u001b[0;34m(self, layer_id, raw_e_hidS_all, raw_q_hidS)\u001b[0m\n\u001b[1;32m     58\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     59\u001b[0m             \u001b[0mq_mass1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mEp_values1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_layerEp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer_id\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_e_hidS_all\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_q_hidS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m             \u001b[0mq_mass2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mEp_values2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_layerEp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_e_hidS_all\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_q_hidS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     61\u001b[0m             \u001b[0mq_mass3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mEp_values3\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_layerEp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer_id\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_e_hidS_all\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_q_hidS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m             \u001b[0;31m#localFusion_Ep = 0.5*self.tanh(self.compute_fusion(q_mass2, Ep_values1, Ep_values1)) + self.tanh(self.compute_fusion(q_mass2, Ep_values2, Ep_values2)) + 0.5*self.tanh(self.compute_fusion(q_mass2, 
Ep_values3, Ep_values3))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1826542/2101813558.py\u001b[0m in \u001b[0;36mcompute_layerEp\u001b[0;34m(self, layer_id, raw_e_hidS_all, raw_q_hidS)\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0mq_mass\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmass_proj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq_hidS\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m         \u001b[0md_values\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric_proj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me_hidS_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0me_id\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0me_hidS_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0me_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0me_id\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me_hidS_all\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m         \u001b[0md_values\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric_proj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq_hidS\u001b[0m \u001b[0;34m-\u001b[0m 
\u001b[0me_hidS_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1826542/2101813558.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0mq_mass\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmass_proj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq_hidS\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m         \u001b[0md_values\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric_proj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me_hidS_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0me_id\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0me_hidS_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0me_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0me_id\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me_hidS_all\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m         \u001b[0md_values\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric_proj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq_hidS\u001b[0m \u001b[0;34m-\u001b[0m 
\u001b[0me_hidS_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1551\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1552\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1553\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1555\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1560\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1561\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1563\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1564\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    119\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Training loop: fit cur_Maze so that the norm of the summed Ep vectors matches each item's 'conf'.\n",
    "# Every HOLDOUT_EVERY-th group of cur_items is skipped during training and used for evaluation instead.\n",
    "max_epochs = 100\n",
    "HOLDOUT_EVERY = 4             # groups whose index % HOLDOUT_EVERY == 0 are held out (train skip AND eval slice)\n",
    "EVAL_EVERY = 2                # run held-out evaluation every EVAL_EVERY epochs\n",
    "NUM_CHECKS = range(4, 64, 2)  # num_check values passed to tmp_eval\n",
    "\n",
    "print(datetime.now().strftime(\"%H:%M:%S\"))\n",
    "for _ep in range(max_epochs):\n",
    "    W_loss = 0  # summed (not averaged) squared error over all training items this epoch\n",
    "    for item_id, cur_item in enumerate(cur_items):\n",
    "        if item_id % HOLDOUT_EVERY == 0:\n",
    "            continue  # held-out group; evaluated below via cur_items[::HOLDOUT_EVERY]\n",
    "        for _item in cur_item:\n",
    "            raw_e_hidS_all = [cached_e_hidS_all[_id] for _id in _item['fs_ids']]\n",
    "            raw_q_hidS = cached_q_hidS_all[_item['test_id']]\n",
    "\n",
    "            sum_Ep_vecs = cur_Maze.compute_sumEp(raw_e_hidS_all, raw_q_hidS)\n",
    "\n",
    "            # Squared error between the norm of the summed Ep vectors and the target confidence.\n",
    "            cur_loss = (torch.norm(sum_Ep_vecs) - _item['conf'])**2\n",
    "            _optimizer.zero_grad()\n",
    "            cur_loss.backward()\n",
    "            _optimizer.step()\n",
    "            W_loss += cur_loss.item()\n",
    "\n",
    "    print(_ep, W_loss, datetime.now().strftime(\"%H:%M:%S\"))\n",
    "    if _ep % EVAL_EVERY == 0:\n",
    "        # Evaluation only: no gradients are needed, so avoid building autograd graphs\n",
    "        # (harmless if tmp_eval already disables grad itself).\n",
    "        with torch.no_grad():\n",
    "            res_stat = []\n",
    "            for num_check in NUM_CHECKS:\n",
    "                tmp_corr_probs = [tmp_eval(_group, num_check) for _group in cur_items[::HOLDOUT_EVERY]]\n",
    "                res_stat.append(round(sum(tmp_corr_probs)/len(tmp_corr_probs),4))\n",
    "        print(res_stat)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46c8cffb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ddc7fac",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ef8a5f2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
