{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6ee95b57",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import Linear\n",
    "from torch_geometric.nn import GCNConv,GATConv,SAGEConv\n",
    "from torch.utils.data import DataLoader,Dataset,TensorDataset\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from torch_geometric.utils import add_self_loops, degree\n",
    "import numpy as np\n",
    "from scipy.sparse import coo_matrix\n",
    "import pandas as pd\n",
    "import tqdm\n",
    "import random\n",
    "import datetime\n",
    "import sys\n",
    "from scipy.stats import pearsonr\n",
    "# from  torch.utils.data import TensorDataset,DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22a0a635",
   "metadata": {},
   "source": [
    "## 参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0d24c635",
   "metadata": {},
   "outputs": [],
   "source": [
    "TIME_LEN = 12\n",
    "BATCH_SIZE = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0313c1d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(12345)\n",
    "torch.cuda.manual_seed(12345)\n",
    "np.random.seed(12345)\n",
    "random.seed(12345)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d855d042",
   "metadata": {},
   "source": [
    "### 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "de4a8427",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(34272, 207)\n"
     ]
    }
   ],
   "source": [
    "data = np.load('/home/hqh/DataSetFile/metrla/STdata.npy')#.astype(np.float16)# 时空\n",
    "print(data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "751bb924",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  0,   0,   0,  ..., 206, 206, 206],\n",
      "        [  0,  13,  37,  ..., 187, 198, 206]])\n"
     ]
    }
   ],
   "source": [
    "A = np.load('/home/hqh/DataSetFile/metrla/adj01.npy')\n",
    "adj = coo_matrix(A)\n",
    "values = adj.data  \n",
    "indices = np.vstack((adj.row, adj.col))  # 我们真正需要的coo形式\n",
    "edge_index  = torch.LongTensor(indices)  # PyG框架需要的coo形式\n",
    "print(edge_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dfc6e543",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1722 207\n"
     ]
    }
   ],
   "source": [
    "n_V,n_E = len(A),len(edge_index[0])\n",
    "print(n_E,n_V)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c882a844",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "92daa9b9",
   "metadata": {},
   "source": [
    "### 同配异配"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "33463079",
   "metadata": {},
   "outputs": [],
   "source": [
    "# start_time = datetime.datetime.now()\n",
    "# DATA_ARRAY =  np.repeat(np.expand_dims(data,2),len(adj),2) # time,n,n,每一行的元素一样\n",
    "# print(datetime.datetime.now()-start_time)\n",
    "# DIFF = DATA_ARRAY - DATA_ARRAY.transpose(0,2,1) #time,n,n， i,j表示d当前时间步的data[i]-data[j]\n",
    "# print(datetime.datetime.now()-start_time)\n",
    "# DIFF = DIFF/DATA_ARRAY\n",
    "# print(datetime.datetime.now()-start_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "029700cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0:00:01.090931\n",
      "0:00:02.302211\n",
      "0:00:02.626245\n"
     ]
    }
   ],
   "source": [
    "start_time = datetime.datetime.now()\n",
    "DIFF = data[:,edge_index[0]]-data[:,edge_index[1]]\n",
    "print(datetime.datetime.now()-start_time)\n",
    "DIFF = (DIFF+1)/(data[:,edge_index[0]]+1)\n",
    "print(datetime.datetime.now()-start_time)\n",
    "DIFF_ABS = np.absolute(DIFF)\n",
    "print(datetime.datetime.now()-start_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e38931ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4397064447402954 GB\n",
      "(34272, 1722)\n",
      "0.9486729411882304 6.329240266700626\n",
      "同配比例： 0.2273149097037189\n",
      "异配正比例： 0.4731735173744294\n",
      "异配负比例： 0.29951157292185165\n"
     ]
    }
   ],
   "source": [
    "eplision=0.02\n",
    "print(sys.getsizeof(DIFF) / 1024 / 1024/1024, 'GB')\n",
    "print(DIFF.shape)\n",
    "number = DIFF.shape[0]*DIFF.shape[1]\n",
    "# print( len(DIFF[np.isnan(DIFF)])/number)\n",
    "# print( len(DIFF[np.isinf(DIFF)])/number)\n",
    "print(DIFF_ABS.mean(),DIFF_ABS.std())\n",
    "print('同配比例：',((DIFF_ABS<=eplision).sum())/number)\n",
    "print('异配正比例：',((DIFF>eplision).sum())/number)\n",
    "print('异配负比例：',((DIFF<-eplision).sum())/number)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "472398dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "DIFF[DIFF>eplision]=1\n",
    "DIFF[DIFF_ABS<=eplision]=0.5\n",
    "DIFF[DIFF<-eplision]=-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "227a9e2e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(34272, 1722)\n",
      "0.4731735173744294 0.29951157292185165\n",
      "[-1.   0.5  0.5  1.   1.   1.   1.   1.   1.   0.5  0.5  0.5  0.5  1.\n",
      "  1.   0.5  1.   1.   0.5  0.5 -1.  -1.  -1.   0.5 -1.  -1.   1.  -1.\n",
      " -1.   0.5 -1.   1.  -1.   1.   0.5  1.   0.5  1.   1.   0.5 -1.   1.\n",
      "  1.   1.   0.5  1.   1.   0.5  0.5  1.   1.   0.5  1.   1.   1.   0.5\n",
      "  1.   1.   1.   1.   1.   0.5  0.5  1.   1.   0.5  1.   0.5  1.   0.5\n",
      "  1.   1.   0.5  0.5  1.   0.5  0.5  1.   0.5  1.   1.   1.   0.5  0.5\n",
      "  0.5  1.   0.5  1.  -1.  -1.   0.5  0.5  0.5 -1.  -1.   0.5 -1.  -1.\n",
      "  1.  -1.   0.5  1.   1.   0.5  1.   1.   1.   1.   1.  -1.   0.5 -1.\n",
      "  1.  -1.  -1.   1.  -1.  -1.  -1.   1.  -1.   1.  -1.   1.   0.5  1.\n",
      "  1.   1.   1.   1.   1.   1.   1.   0.5  0.5  1.   0.5 -1.  -1.  -1.\n",
      "  1.   1.   0.5  1.   1.   1.   1.   1.   0.5  1.  -1.  -1.   1.  -1.\n",
      " -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.  -1.\n",
      " -1.  -1.   0.5  0.5 -1.   0.5  0.5 -1.  -1.  -1. ]\n"
     ]
    }
   ],
   "source": [
    "print(DIFF.shape)\n",
    "print((DIFF==1).sum()/number,(DIFF==-1).sum()/number)\n",
    "print(DIFF[0][10:188])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "144289a3",
   "metadata": {},
   "source": [
    "### 变化方向"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91ba9228",
   "metadata": {},
   "source": [
    "### 1.直接求出sign"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "47ce42a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(34271, 1722)\n"
     ]
    }
   ],
   "source": [
    "EXPAND_DATA = data[:,edge_index[0]]\n",
    "sign = EXPAND_DATA[1:]-EXPAND_DATA[:-1]\n",
    "sign = (sign+1)/(EXPAND_DATA[:-1]+1)\n",
    "sign_abs = np.absolute(sign)\n",
    "print(sign.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b662b010",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3516844111856842 3.646938386908529\n",
      "时序同配比例： 0.33384646005428276\n",
      "时序异配同向比例： 0.49182567545671957\n",
      "时序异配负向比例： 0.17432786448899767\n"
     ]
    }
   ],
   "source": [
    "number = sign.shape[0]*sign.shape[1]\n",
    "print(sign_abs.mean(),sign_abs.std())\n",
    "print('时序同配比例：',((sign_abs<=eplision).sum())/number)\n",
    "print('时序异配同向比例：',((sign>eplision).sum())/number)\n",
    "print('时序异配负向比例：',((sign<-eplision).sum())/number)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "517962ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "def show_month(sign,sign_abs,DIFF,DIFF_ABS,eplision):\n",
    "    number = sign.shape[1]\n",
    "    T_homo = (sign_abs<=eplision).sum(-1)/number\n",
    "    T_hete1 = (sign>eplision).sum(-1)/number\n",
    "    T_hete2 = (sign<-eplision).sum(-1)/number\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b0ec4117",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sign[sign_abs<=eplision]=0\n",
    "# sign[sign>eplision]=1\n",
    "# sign[sign<-eplision]=-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9e107948",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(sign.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a888052",
   "metadata": {},
   "source": [
    "### 2.但是直接求sign会有标签暴露之嫌疑，还是需要用一个网络来学习sign"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a539469c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "76f00dab",
   "metadata": {},
   "source": [
    "### 2.1加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ccb5e227",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(34271, 207)\n"
     ]
    }
   ],
   "source": [
    "sig = data[1:]-data[0:-1]\n",
    "sig = (sig+1)/(data[0:-1]+1)\n",
    "sig_abs = np.absolute(sig)\n",
    "print(sig.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8333adf6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.35385622143903406 3.6673369425688302\n",
      "时序同配比例： 0.3325436627099968\n",
      "时序异配同向比例： 0.4919509840364461\n",
      "时序异配负向比例： 0.17550535325355715\n"
     ]
    }
   ],
   "source": [
    "eplision=0.02\n",
    "number = sig.shape[0]*sig.shape[1]\n",
    "print(sig_abs.mean(),sig_abs.std())\n",
    "print('时序同配比例：',((sig_abs<=eplision).sum())/number)\n",
    "print('时序异配同向比例：',((sig>eplision).sum())/number)\n",
    "print('时序异配负向比例：',((sig<-eplision).sum())/number)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "241e8aa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "sig[sig_abs<=eplision]=0\n",
    "sig[sig>eplision]=1\n",
    "sig[sig<-eplision]=2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "eff5f6dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(34271, 207)\n"
     ]
    }
   ],
   "source": [
    "print(sig.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fb23e282",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34259, 207, 12])\n",
      "torch.Size([34259, 207, 12])\n",
      "y.shape: torch.Size([34259, 207])\n",
      "input_x.shape: torch.Size([34259, 207, 24])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_25543/1059340229.py:13: DeprecationWarning: an integer is required (got type numpy.float64).  Implicit conversion to integers using __int__ is deprecated, and may be removed in a future version of Python.\n",
      "  y = torch.LongTensor(y)\n"
     ]
    }
   ],
   "source": [
    "x = []\n",
    "for i in range(1,len(data)-TIME_LEN): #从第2个开始，和sig保持一致\n",
    "    x.append(data[i:i+TIME_LEN])\n",
    "x = torch.FloatTensor(x).transpose(2,1)#np的transpose和torch的不一样？\n",
    "print(x.shape)# batch  空间 时间\n",
    "\n",
    "sig_x = []\n",
    "y = []\n",
    "for i in range(0,len(sig)-TIME_LEN):\n",
    "    sig_x.append(sig[i:i+TIME_LEN])\n",
    "    y.append(sig[i+TIME_LEN])\n",
    "sig_x = torch.FloatTensor(sig_x).transpose(2,1)\n",
    "y = torch.LongTensor(y)\n",
    "print(sig_x.shape)\n",
    "\n",
    "input_x = torch.cat((x,sig_x),dim=-1)\n",
    "print('y.shape:',y.shape)\n",
    "print('input_x.shape:',input_x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ea4e5a74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.int64 tensor([2, 2, 1, 2, 1, 2, 1, 2, 1, 0, 0, 0, 0, 1, 0, 1, 0, 2, 2, 1, 2, 2, 2, 2,\n",
      "        1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 2, 2, 1, 0, 1, 0, 1, 2, 0,\n",
      "        1, 2, 2, 2, 1, 0, 2, 0, 1, 0, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 0, 0, 2, 0,\n",
      "        2, 1, 0, 1, 1, 2, 1, 1, 1, 0, 2, 2, 1, 1, 0, 2, 2, 2, 1, 2, 0, 1, 0, 0,\n",
      "        0, 1, 2, 1, 0, 1, 2, 2, 2, 1, 1, 2, 1, 2, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0,\n",
      "        2, 1, 2, 0, 0, 2, 0, 0, 1, 1, 2, 2, 1, 1, 0, 1, 1, 1, 0, 2, 0, 2, 1, 1,\n",
      "        2, 1, 2, 2, 1, 0, 1, 2, 1, 1, 0, 1, 1, 1, 2, 1, 0, 0, 0, 1, 1, 1, 0, 0,\n",
      "        1, 1, 1, 1, 1, 1, 2, 0, 2, 2, 2, 1, 1, 0, 0, 0, 0, 2, 1, 0, 0, 1, 1, 1,\n",
      "        1, 1, 1, 0, 0, 0, 2, 1, 1, 2, 1, 1, 0, 2, 0])\n"
     ]
    }
   ],
   "source": [
    "print(y.dtype,y[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "26fc9b06",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE=128\n",
    "def construct_dataloader(x,y,loader_type):\n",
    "    dataset = TensorDataset(x,y)\n",
    "    if loader_type=='train':\n",
    "        dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=False)# 不能shuffle\n",
    "    else:\n",
    "        dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=False)# \n",
    "    return dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "9b656701",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples = len(input_x)\n",
    "train_ratio = 0.8\n",
    "test_ratio = 0.1\n",
    "\n",
    "num_test = round(num_samples * 0.2)\n",
    "num_train = round(num_samples * 0.7)\n",
    "num_val = num_samples - num_test - num_train\n",
    "x_train, y_train = input_x[:num_train], y[:num_train]\n",
    "x_val, y_val =  input_x[num_train: num_train + num_val],y[num_train: num_train + num_val]\n",
    "x_test, y_test = input_x[-num_test:], y[-num_test:]\n",
    "\n",
    "train_dataloader = construct_dataloader(x_train,y_train,'train')\n",
    "val_dataloader = construct_dataloader(x_val,y_val,'valid')\n",
    "test_dataloader = construct_dataloader(x_test,y_test,'valid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "a57aa19a",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_dataloader =  DataLoader(dataset=input_x, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "e3864cf6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23981\n",
      "3426\n",
      "6852\n"
     ]
    }
   ],
   "source": [
    "print(train_dataloader.dataset.__len__())\n",
    "print(val_dataloader.dataset.__len__())\n",
    "print(test_dataloader.dataset.__len__())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ea6ddb2",
   "metadata": {},
   "source": [
    "### 2.2模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d4ddef4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM_layer(torch.nn.Module):\n",
    "    def __init__(self,output_size):\n",
    "        super(LSTM_layer, self).__init__()\n",
    "        batch_first=True\n",
    "        hidden_size=64\n",
    "        self.lstm1 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "                               input_size=1,hidden_size = hidden_size)\n",
    "        self.lstm2 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "                               input_size=1,hidden_size = hidden_size)\n",
    "        self.mlp = Linear(2*hidden_size,output_size)\n",
    "    def forward(self, x):\n",
    "        #x: batch numnodes seqlen\n",
    "        x=x.reshape(x.shape[0]*x.shape[1],x.shape[2],1)\n",
    "        out1,(hn,cn) = self.lstm1(x[:,:TIME_LEN])\n",
    "        out2,(hn,cn) = self.lstm2(x[:,TIME_LEN:])# batchsize*n timestep hidden_size\n",
    "        out1 = out1.relu()\n",
    "        out2 = out2.relu()#[:,-1,:]\n",
    "        out = self.mlp(torch.cat((out1,out2),dim=-1))\n",
    "        out.relu()\n",
    "        return out.transpose(1,2) #->batch*n hiddensize(channel) timestep \n",
    "    \n",
    "class ResNet_Block(torch.nn.Module):\n",
    "    def __init__(self,input_channels,output_channels,ks):\n",
    "        super(ResNet_Block, self).__init__()\n",
    "        self.align=torch.nn.Conv1d(in_channels=input_channels,out_channels=output_channels,kernel_size=ks)\n",
    "        self.conv1=torch.nn.Conv1d(in_channels=input_channels,out_channels=output_channels,kernel_size=ks)\n",
    "        self.bn1 = nn.BatchNorm2d(output_channels)\n",
    "        self.conv2=torch.nn.Conv1d(in_channels=output_channels,out_channels=output_channels,kernel_size=3,padding=1)\n",
    "        self.bn2 = nn.BatchNorm2d(output_channels)\n",
    "    def forward(self,x):# x: batch channel height\n",
    "        x_align=self.align(x)\n",
    "        out=self.conv1(x)\n",
    "        #out=self.bn1(out)\n",
    "        out=out.relu()\n",
    "        out=self.conv2(out)\n",
    "        #out=self.bn2(out)\n",
    "        #print(out.shape,x.shape)\n",
    "        out=out+x_align\n",
    "        return out.relu()\n",
    "\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "        self.lstm = LSTM_layer(output_size=64)\n",
    "        res_list=[]\n",
    "        for i in range(3):\n",
    "            res_list.append(ResNet_Block(input_channels=64,output_channels=64,ks=4))\n",
    "        res_list.append(ResNet_Block(input_channels=64,output_channels=64,ks=3)) #3*3+2\n",
    "        self.res = nn.Sequential(*res_list)\n",
    "        self.mlp = Linear(64,3)\n",
    "    def forward(self, x):\n",
    "        #x: batch numnodes seqlen\n",
    "        out = self.lstm(x)\n",
    "        out = self.res(out).squeeze(-1)\n",
    "        out = self.mlp(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "683cac59",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model(\n",
      "  (lstm): LSTM_layer(\n",
      "    (lstm1): LSTM(1, 64, batch_first=True)\n",
      "    (lstm2): LSTM(1, 64, batch_first=True)\n",
      "    (mlp): Linear(in_features=128, out_features=64, bias=True)\n",
      "  )\n",
      "  (res): Sequential(\n",
      "    (0): ResNet_Block(\n",
      "      (align): Conv1d(64, 64, kernel_size=(4,), stride=(1,))\n",
      "      (conv1): Conv1d(64, 64, kernel_size=(4,), stride=(1,))\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): ResNet_Block(\n",
      "      (align): Conv1d(64, 64, kernel_size=(4,), stride=(1,))\n",
      "      (conv1): Conv1d(64, 64, kernel_size=(4,), stride=(1,))\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (2): ResNet_Block(\n",
      "      (align): Conv1d(64, 64, kernel_size=(4,), stride=(1,))\n",
      "      (conv1): Conv1d(64, 64, kernel_size=(4,), stride=(1,))\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (3): ResNet_Block(\n",
      "      (align): Conv1d(64, 64, kernel_size=(3,), stride=(1,))\n",
      "      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,))\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (mlp): Linear(in_features=64, out_features=3, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model=Model()\n",
    "print(model)\n",
    "optimizer =  torch.optim.Adam(model.parameters(), lr=3e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f05cecee",
   "metadata": {},
   "source": [
    "### 2.3训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "d226ef76",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda:0\"\n",
    "model = model.to(device)\n",
    "loss_fn = torch.nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "c1e99679",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_acc(out,y):\n",
    "    predict = torch.argmax(out, dim=-1)\n",
    "    acc_number = (predict == y).sum()\n",
    "    return acc_number"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "bca39fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(optimizer,dataloader):\n",
    "    acc = 0\n",
    "    number = 0\n",
    "    total_loss = 0\n",
    "    total_batch = 0\n",
    "    for batch in tqdm.tqdm(dataloader):\n",
    "        x,y=batch\n",
    "        x=x.to(device)\n",
    "        y=y.to(device)\n",
    "        out = model(x)\n",
    "        loss = loss_fn(out,y.reshape(-1))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        acc+=get_acc(out,y.reshape(-1))\n",
    "        number+=len(out)\n",
    "        total_loss+=loss.detach()\n",
    "        total_batch+=1\n",
    "    return acc/number,total_loss/total_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "28b4366e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(dataloader):\n",
    "    with torch.no_grad():\n",
    "        acc = 0\n",
    "        number = 0\n",
    "        total_loss = 0\n",
    "        total_batch = 0\n",
    "        y_hat = []\n",
    "        y_real = []\n",
    "        for batch in tqdm.tqdm(dataloader):\n",
    "            x,y=batch\n",
    "            x=x.to(device)\n",
    "            y=y.to(device)\n",
    "            out = model(x)\n",
    "            loss = loss_fn(out,y.reshape(-1))\n",
    "            acc += get_acc(out,y.reshape(-1))\n",
    "            number+=len(out)\n",
    "            total_loss+=loss\n",
    "            total_batch+=1\n",
    "            \n",
    "            predict = torch.argmax(out, dim=-1)\n",
    "            y_hat.append(predict.reshape(-1))\n",
    "            y_real.append(y.reshape(-1))\n",
    "        y_hat = torch.cat(y_hat,dim=-1).cpu().numpy()\n",
    "        y_real = torch.cat(y_real,dim=-1).cpu().numpy()\n",
    "        print(y_hat.shape,y_real.shape)\n",
    "        print(y_hat.mean(),y_hat.std())\n",
    "        print(y_real.mean(),y_real.std())\n",
    "        print('皮尔逊系数 = ',pearsonr(y_hat,y_real))\n",
    "    return acc/number,total_loss/total_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "ae322884",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 188/188 [00:37<00:00,  5.03it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 59.07it/s]\n",
      "/home/hqh/anaconda3/envs/torch1.8_second/lib/python3.9/site-packages/scipy/stats/stats.py:4023: PearsonRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(PearsonRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(709182,) (709182,)\n",
      "1.0 0.0\n",
      "0.8212207867655975 0.6975936393948747\n",
      "皮尔逊系数 =  (nan, nan)\n",
      "0-th epoch | trian: acc=0.4790164828300476 loss=0.996900737285614 | val:  acc=0.48140108585357666 loss=0.9691302180290222\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 188/188 [00:37<00:00,  5.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 59.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(709182,) (709182,)\n",
      "0.9999168055590807 0.009120719248187425\n",
      "0.8212207867655975 0.6975936393948747\n",
      "皮尔逊系数 =  (0.004310975423586499, 0.0002829715788930834)\n",
      "1-th epoch | trian: acc=0.4865548610687256 loss=0.9688684940338135 | val:  acc=0.48142364621162415 loss=0.9555522799491882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 188/188 [00:37<00:00,  5.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 59.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(709182,) (709182,)\n",
      "0.664829338590094 0.6104220872487242\n",
      "0.8212207867655975 0.6975936393948747\n",
      "皮尔逊系数 =  (0.30072571218055316, 0.0)\n",
      "2-th epoch | trian: acc=0.5337391495704651 loss=0.914862334728241 | val:  acc=0.6139326691627502 loss=0.8444697260856628\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 188/188 [00:37<00:00,  5.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 59.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(709182,) (709182,)\n",
      "0.8359518431093852 0.5235103898322786\n",
      "0.8212207867655975 0.6975936393948747\n",
      "皮尔逊系数 =  (0.27603540887507794, 0.0)\n",
      "3-th epoch | trian: acc=0.604658842086792 loss=0.8484348058700562 | val:  acc=0.6196871995925903 loss=0.8172434568405151\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 188/188 [00:37<00:00,  5.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 58.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(709182,) (709182,)\n",
      "0.7121951769785471 0.5681913398417135\n",
      "0.8212207867655975 0.6975936393948747\n",
      "皮尔逊系数 =  (0.31283638821558013, 0.0)\n",
      "4-th epoch | trian: acc=0.6248815655708313 loss=0.8192561864852905 | val:  acc=0.6378461122512817 loss=0.7927519083023071\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 188/188 [00:37<00:00,  5.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 59.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(709182,) (709182,)\n",
      "0.7344673158653209 0.5877491705009427\n",
      "0.8212207867655975 0.6975936393948747\n",
      "皮尔逊系数 =  (0.3206222783608332, 0.0)\n",
      "5-th epoch | trian: acc=0.6297562122344971 loss=0.8110337257385254 | val:  acc=0.6415757536888123 loss=0.786170482635498\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 188/188 [00:37<00:00,  5.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 59.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(709182,) (709182,)\n",
      "0.733946998090758 0.5831305428774306\n",
      "0.8212207867655975 0.6975936393948747\n",
      "皮尔逊系数 =  (0.32177876224373675, 0.0)\n",
      "6-th epoch | trian: acc=0.6310577392578125 loss=0.8082811832427979 | val:  acc=0.6416096091270447 loss=0.7857165932655334\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 188/188 [00:37<00:00,  5.05it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 59.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(709182,) (709182,)\n",
      "0.6931521104596563 0.5945535452903948\n",
      "0.8212207867655975 0.6975936393948747\n",
      "皮尔逊系数 =  (0.32731800928058513, 0.0)\n",
      "7-th epoch | trian: acc=0.6323778033256531 loss=0.8060788512229919 | val:  acc=0.642961859703064 loss=0.7845543026924133\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 188/188 [00:37<00:00,  5.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 59.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(709182,) (709182,)\n",
      "0.7157499767337581 0.5922068288334794\n",
      "0.8212207867655975 0.6975936393948747\n",
      "皮尔逊系数 =  (0.32916839956440364, 0.0)\n",
      "8-th epoch | trian: acc=0.633335292339325 loss=0.8040560483932495 | val:  acc=0.6438079476356506 loss=0.7816556096076965\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 188/188 [00:37<00:00,  5.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 59.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(709182,) (709182,)\n",
      "0.7189564878973239 0.6065893106140313\n",
      "0.8212207867655975 0.6975936393948747\n",
      "皮尔逊系数 =  (0.33214217558429093, 0.0)\n",
      "9-th epoch | trian: acc=0.6332198977470398 loss=0.8037219643592834 | val:  acc=0.6431183815002441 loss=0.7830707430839539\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "EPOCH = 10\n",
    "for epoch in range(EPOCH):\n",
    "    train_acc,train_loss = train(optimizer,train_dataloader)\n",
    "    val_acc,val_loss = evaluate(val_dataloader)\n",
    "    log = '{}-th epoch | trian: acc={} loss={} | val:  acc={} loss={}'\n",
    "    print(log.format(epoch,train_acc,train_loss,val_acc,val_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "b79aa07e",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model,'./model/pre.th')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "b3a90b31",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.37001546401198776\n"
     ]
    }
   ],
   "source": [
    "n=(sig.shape[0]-1)*sig.shape[1]\n",
    "print((sig[:-1]==sig[1:]).sum()/n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "86667d9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sign(data_loader):\n",
    "    sign = []\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm.tqdm(data_loader):\n",
    "            x=batch\n",
    "            x=x.to(device)\n",
    "            out = model(x)\n",
    "            predict = torch.argmax(out, dim=-1)\n",
    "            sign.append(predict.reshape(x.shape[0],x.shape[1]))\n",
    "    sign=torch.cat(sign,dim=0).float()\n",
    "#     sign=torch.where(sign==0,0.5,sign)\n",
    "#     sign=torch.where(sign==2,-1,sign)\n",
    "    print(sign.shape)\n",
    "    return sign"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "1b39dc01",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 268/268 [00:04<00:00, 56.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34259, 207])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "SIGN = get_sign(all_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "bb047d53",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0., 2., 2., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 2., 1., 1., 1.,\n",
      "        1., 2., 2., 0., 0., 0., 0., 2., 1., 2., 1., 0., 1., 1., 0., 1., 1., 1.,\n",
      "        1., 1., 2., 1., 2., 1., 1., 2., 0., 0., 1., 1., 2., 2., 2., 1., 1., 1.,\n",
      "        1., 2., 2., 0., 1., 0., 2., 2., 0., 0., 1., 0., 1., 1., 0., 0., 1., 2.,\n",
      "        1., 1., 1., 2., 0., 2., 0.], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "print(SIGN[50][20:99])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "f7ba0fc0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2391368])\n",
      "tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "print(torch.where(SIGN==0)[0].shape)\n",
    "print(SIGN[torch.where(SIGN==0)])\n",
    "SIGN[torch.where(SIGN==0)]=0.5\n",
    "SIGN[torch.where(SIGN==2)]=-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "bc815417",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0.5000, -1.0000, -1.0000,  1.0000,  1.0000,  0.5000,  0.5000,  1.0000,\n",
      "         0.5000,  1.0000,  0.5000,  0.5000,  0.5000,  1.0000, -1.0000,  1.0000,\n",
      "         1.0000,  1.0000,  1.0000, -1.0000, -1.0000,  0.5000,  0.5000,  0.5000,\n",
      "         0.5000, -1.0000,  1.0000, -1.0000,  1.0000,  0.5000,  1.0000,  1.0000,\n",
      "         0.5000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000, -1.0000,  1.0000,\n",
      "        -1.0000,  1.0000,  1.0000, -1.0000,  0.5000,  0.5000,  1.0000,  1.0000,\n",
      "        -1.0000, -1.0000, -1.0000,  1.0000,  1.0000,  1.0000,  1.0000, -1.0000,\n",
      "        -1.0000,  0.5000,  1.0000,  0.5000, -1.0000, -1.0000,  0.5000,  0.5000,\n",
      "         1.0000,  0.5000,  1.0000,  1.0000,  0.5000,  0.5000,  1.0000, -1.0000,\n",
      "         1.0000,  1.0000,  1.0000, -1.0000,  0.5000, -1.0000,  0.5000],\n",
      "       device='cuda:0')\n",
      "torch.Size([34259, 207])\n"
     ]
    }
   ],
   "source": [
    "print(SIGN[50][20:99])\n",
    "print(SIGN.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0305c916",
   "metadata": {},
   "source": [
    "## 计算最终的M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "7a57cf38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34259, 1722])\n",
      "torch.Size([34259, 1722])\n"
     ]
    }
   ],
   "source": [
    "edge_start_sign=SIGN[:,edge_index[0]].to('cpu')\n",
    "edge_c=torch.FloatTensor(DIFF[TIME_LEN+1:])\n",
    "print(edge_start_sign.shape)\n",
    "print(edge_c.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "d5af40bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "M=edge_start_sign/edge_c #改成除法看看效果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "eb3d0548",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 2., -1., -1.,  ...,  1.,  1., -2.])\n"
     ]
    }
   ],
   "source": [
    "print(M[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "2578161d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([2., 0., 0.,  ..., 1., 1., 0.])\n"
     ]
    }
   ],
   "source": [
    "M_pos=M.clone()\n",
    "M_pos[torch.where(M<0)]=0\n",
    "print(M_pos[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "d3321f6d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0., 1., 1.,  ..., 0., 0., 2.])\n"
     ]
    }
   ],
   "source": [
    "M_neg=M.clone()\n",
    "M_neg[torch.where(M>0)]=0\n",
    "#x\n",
    "M_neg=-M_neg#尴尬了一开始没加这个\n",
    "M_neg[torch.where(M_neg<=0)]=0\n",
    "print(M_neg[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "1c6f5740",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0625 kB\n",
      "torch.Size([34259, 1722])\n"
     ]
    }
   ],
   "source": [
    "print(sys.getsizeof(M) / 1024 , 'kB')\n",
    "print(M_neg.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "add32902",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save(M_pos,'/home/hqh/DataSetFile/metrla/M_pos.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "0b6cfbfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save(M_neg,'/home/hqh/DataSetFile/metrla/M_neg.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "6ac3ee09",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 2.0000, -1.0000, -1.0000,  ...,  1.0000,  1.0000, -2.0000],\n",
      "        [ 2.0000,  1.0000,  2.0000,  ..., -0.5000, -0.5000,  1.0000],\n",
      "        [ 2.0000,  2.0000,  1.0000,  ...,  2.0000,  2.0000,  2.0000],\n",
      "        ...,\n",
      "        [ 2.0000,  1.0000,  1.0000,  ..., -0.5000, -0.5000,  1.0000],\n",
      "        [-2.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000,  2.0000],\n",
      "        [ 1.0000,  1.0000,  1.0000,  ...,  1.0000, -0.5000,  1.0000]])\n"
     ]
    }
   ],
   "source": [
    "print(M)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20251e1b",
   "metadata": {},
   "source": [
    "# 计算ND_h sig ND_T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b53d586a",
   "metadata": {},
   "source": [
    "#### ND_sig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "57fe8110",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_NDsig(data_loader):\n",
    "    sign = []\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm.tqdm(data_loader):\n",
    "            x=batch\n",
    "            x=x.to(device)\n",
    "            out = model(x).softmax(dim=-1).cpu()\n",
    "            sign.append(out.reshape(x.shape[0],x.shape[1],3))\n",
    "    sign=torch.cat(sign,dim=0).float()\n",
    "    print(sign.shape)\n",
    "    return sign"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "213a53e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 268/268 [00:05<00:00, 51.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34259, 207, 3])\n"
     ]
    }
   ],
   "source": [
    "ND_sig=get_NDsig(all_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "dfb74540",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.6379e-01, 7.0144e-01, 1.3478e-01],\n",
      "        [4.5572e-01, 2.0645e-01, 3.3783e-01],\n",
      "        [5.5948e-02, 8.5607e-01, 8.7987e-02],\n",
      "        [3.9971e-01, 2.5025e-01, 3.5004e-01],\n",
      "        [5.0247e-01, 2.5922e-01, 2.3831e-01],\n",
      "        [6.3247e-01, 1.3343e-01, 2.3410e-01],\n",
      "        [3.4690e-01, 5.1856e-01, 1.3454e-01],\n",
      "        [2.9286e-01, 4.4679e-02, 6.6246e-01],\n",
      "        [1.6743e-01, 4.6367e-01, 3.6890e-01],\n",
      "        [2.6607e-01, 1.3483e-01, 5.9910e-01],\n",
      "        [3.0172e-01, 4.8378e-01, 2.1451e-01],\n",
      "        [6.1474e-01, 9.2751e-02, 2.9251e-01],\n",
      "        [3.8399e-01, 1.0411e-01, 5.1189e-01],\n",
      "        [4.2671e-01, 3.9013e-01, 1.8316e-01],\n",
      "        [2.5180e-01, 2.9837e-01, 4.4983e-01],\n",
      "        [5.0576e-01, 3.4452e-01, 1.4972e-01],\n",
      "        [1.5691e-01, 6.0039e-01, 2.4270e-01],\n",
      "        [3.3980e-01, 1.1243e-01, 5.4777e-01],\n",
      "        [2.0924e-01, 4.4655e-01, 3.4421e-01],\n",
      "        [2.4599e-01, 5.9985e-01, 1.5416e-01],\n",
      "        [5.6565e-01, 5.9508e-02, 3.7484e-01],\n",
      "        [1.5042e-01, 2.3513e-01, 6.1446e-01],\n",
      "        [3.0262e-01, 2.8798e-01, 4.0940e-01],\n",
      "        [2.1337e-01, 5.8078e-01, 2.0585e-01],\n",
      "        [4.9884e-01, 3.9405e-01, 1.0711e-01],\n",
      "        [3.2000e-01, 6.2363e-02, 6.1763e-01],\n",
      "        [1.4360e-01, 7.2756e-01, 1.2884e-01],\n",
      "        [3.3461e-01, 9.6662e-02, 5.6873e-01],\n",
      "        [3.4158e-02, 8.6570e-01, 1.0014e-01],\n",
      "        [1.1367e-01, 7.4036e-01, 1.4597e-01],\n",
      "        [1.9742e-02, 6.5370e-01, 3.2656e-01],\n",
      "        [2.4517e-01, 6.8336e-01, 7.1472e-02],\n",
      "        [3.8247e-01, 6.1926e-02, 5.5560e-01],\n",
      "        [5.0704e-01, 2.4750e-01, 2.4547e-01],\n",
      "        [4.3449e-01, 1.0334e-01, 4.6216e-01],\n",
      "        [2.4168e-01, 3.0361e-01, 4.5471e-01],\n",
      "        [4.9611e-02, 8.0209e-01, 1.4830e-01],\n",
      "        [1.2312e-01, 6.5566e-01, 2.2122e-01],\n",
      "        [9.0370e-02, 7.4168e-01, 1.6795e-01],\n",
      "        [4.9797e-01, 2.8954e-01, 2.1248e-01],\n",
      "        [3.0679e-01, 5.7287e-01, 1.2034e-01],\n",
      "        [2.0539e-01, 7.0732e-01, 8.7297e-02],\n",
      "        [2.8569e-01, 5.5712e-01, 1.5719e-01],\n",
      "        [2.1067e-01, 6.5757e-01, 1.3176e-01],\n",
      "        [4.2606e-01, 1.4665e-01, 4.2729e-01],\n",
      "        [9.2714e-02, 8.2237e-01, 8.4912e-02],\n",
      "        [3.1360e-01, 3.4443e-01, 3.4198e-01],\n",
      "        [2.2861e-01, 6.0894e-01, 1.6245e-01],\n",
      "        [1.3055e-01, 7.4103e-01, 1.2841e-01],\n",
      "        [3.0824e-01, 4.6862e-01, 2.2314e-01],\n",
      "        [1.6831e-01, 1.7970e-01, 6.5199e-01],\n",
      "        [2.3321e-01, 2.8771e-01, 4.7908e-01],\n",
      "        [1.1636e-01, 7.8338e-01, 1.0026e-01],\n",
      "        [2.6115e-01, 2.3377e-01, 5.0508e-01],\n",
      "        [3.4222e-01, 2.4937e-01, 4.0842e-01],\n",
      "        [4.1443e-01, 3.3083e-01, 2.5475e-01],\n",
      "        [1.2747e-01, 6.8811e-01, 1.8443e-01],\n",
      "        [4.3378e-01, 3.5727e-01, 2.0895e-01],\n",
      "        [3.7232e-01, 3.5705e-01, 2.7063e-01],\n",
      "        [2.5454e-01, 2.4934e-01, 4.9612e-01],\n",
      "        [3.1670e-01, 5.0155e-01, 1.8175e-01],\n",
      "        [5.3611e-02, 8.4031e-01, 1.0608e-01],\n",
      "        [1.1014e-01, 3.0721e-01, 5.8265e-01],\n",
      "        [5.2842e-01, 3.3829e-01, 1.3329e-01],\n",
      "        [5.2774e-01, 1.2702e-01, 3.4524e-01],\n",
      "        [5.1138e-01, 2.3980e-01, 2.4882e-01],\n",
      "        [2.9194e-01, 1.7598e-01, 5.3208e-01],\n",
      "        [1.7736e-01, 6.2589e-01, 1.9675e-01],\n",
      "        [4.2968e-01, 4.6585e-01, 1.0447e-01],\n",
      "        [3.7382e-01, 5.1734e-01, 1.0884e-01],\n",
      "        [3.7975e-01, 3.2636e-01, 2.9389e-01],\n",
      "        [5.3672e-01, 1.6792e-01, 2.9536e-01],\n",
      "        [3.2268e-01, 1.2090e-01, 5.5642e-01],\n",
      "        [1.2240e-01, 7.5743e-01, 1.2017e-01],\n",
      "        [6.5689e-01, 2.4234e-01, 1.0077e-01],\n",
      "        [2.5953e-01, 3.8850e-01, 3.5197e-01],\n",
      "        [6.7690e-02, 8.2765e-01, 1.0466e-01],\n",
      "        [2.8650e-01, 1.5839e-01, 5.5511e-01],\n",
      "        [8.6818e-02, 8.0717e-01, 1.0601e-01],\n",
      "        [6.4323e-02, 8.4244e-01, 9.3237e-02],\n",
      "        [2.7618e-01, 5.1458e-01, 2.0924e-01],\n",
      "        [4.2207e-01, 3.5235e-01, 2.2558e-01],\n",
      "        [3.0036e-01, 6.8619e-02, 6.3102e-01],\n",
      "        [1.5780e-01, 7.4265e-01, 9.9552e-02],\n",
      "        [2.0161e-01, 6.0283e-01, 1.9556e-01],\n",
      "        [2.7925e-01, 5.8564e-01, 1.3512e-01],\n",
      "        [4.1598e-01, 4.6213e-01, 1.2189e-01],\n",
      "        [5.3007e-01, 1.3700e-01, 3.3293e-01],\n",
      "        [2.3595e-01, 5.2398e-01, 2.4008e-01],\n",
      "        [4.9138e-01, 1.8990e-01, 3.1872e-01],\n",
      "        [3.6363e-02, 8.7185e-01, 9.1788e-02],\n",
      "        [2.4810e-01, 5.6104e-01, 1.9086e-01],\n",
      "        [5.6675e-01, 1.3797e-01, 2.9528e-01],\n",
      "        [1.1747e-01, 7.6842e-01, 1.1411e-01],\n",
      "        [4.6000e-01, 4.2665e-01, 1.1335e-01],\n",
      "        [1.8640e-01, 4.4351e-01, 3.7009e-01],\n",
      "        [7.8225e-01, 7.0684e-02, 1.4706e-01],\n",
      "        [3.6677e-01, 5.4024e-01, 9.2989e-02],\n",
      "        [3.7584e-01, 1.9675e-01, 4.2742e-01],\n",
      "        [2.4198e-01, 6.8411e-01, 7.3908e-02],\n",
      "        [3.8383e-01, 2.7236e-01, 3.4381e-01],\n",
      "        [9.4902e-02, 8.0538e-01, 9.9720e-02],\n",
      "        [3.6188e-01, 3.2628e-01, 3.1184e-01],\n",
      "        [2.1415e-01, 6.4078e-01, 1.4508e-01],\n",
      "        [4.6660e-01, 3.2297e-01, 2.1043e-01],\n",
      "        [7.1124e-05, 9.9688e-01, 3.0510e-03],\n",
      "        [7.1124e-05, 9.9688e-01, 3.0510e-03],\n",
      "        [5.2727e-01, 4.1210e-02, 4.3152e-01],\n",
      "        [8.3467e-02, 8.2540e-01, 9.1132e-02],\n",
      "        [3.2048e-01, 2.7221e-01, 4.0731e-01],\n",
      "        [3.2569e-01, 5.1982e-01, 1.5449e-01],\n",
      "        [3.4708e-01, 9.8821e-02, 5.5409e-01],\n",
      "        [1.0338e-01, 7.7083e-01, 1.2579e-01],\n",
      "        [4.3259e-01, 1.6600e-01, 4.0141e-01],\n",
      "        [3.6992e-01, 3.6882e-01, 2.6126e-01],\n",
      "        [6.8790e-02, 7.4169e-01, 1.8952e-01],\n",
      "        [9.3286e-02, 7.8389e-01, 1.2283e-01],\n",
      "        [4.0063e-01, 4.8519e-01, 1.1418e-01],\n",
      "        [1.3647e-01, 6.7666e-01, 1.8687e-01],\n",
      "        [8.5492e-02, 6.2603e-01, 2.8848e-01],\n",
      "        [2.6869e-01, 2.8060e-01, 4.5071e-01],\n",
      "        [1.3119e-01, 5.8132e-01, 2.8748e-01],\n",
      "        [8.8024e-02, 8.2788e-01, 8.4099e-02],\n",
      "        [4.9581e-01, 3.7337e-01, 1.3081e-01],\n",
      "        [2.3482e-01, 5.8364e-01, 1.8153e-01],\n",
      "        [1.9726e-01, 4.5305e-01, 3.4969e-01],\n",
      "        [1.4298e-01, 4.1707e-01, 4.3995e-01],\n",
      "        [5.1470e-01, 2.6795e-01, 2.1735e-01],\n",
      "        [1.6461e-01, 7.5621e-01, 7.9176e-02],\n",
      "        [1.7994e-01, 6.1387e-01, 2.0619e-01],\n",
      "        [4.9881e-01, 1.8762e-01, 3.1357e-01],\n",
      "        [4.9357e-01, 1.8654e-01, 3.1989e-01],\n",
      "        [9.5603e-02, 7.4205e-01, 1.6235e-01],\n",
      "        [8.6579e-02, 7.8832e-01, 1.2510e-01],\n",
      "        [2.0802e-01, 6.9558e-01, 9.6404e-02],\n",
      "        [3.0834e-01, 4.8153e-01, 2.1013e-01],\n",
      "        [3.0118e-01, 6.2295e-02, 6.3652e-01],\n",
      "        [1.1229e-01, 7.9120e-01, 9.6515e-02],\n",
      "        [5.9702e-01, 2.2743e-01, 1.7555e-01],\n",
      "        [3.8644e-01, 3.1191e-01, 3.0164e-01],\n",
      "        [2.2630e-01, 5.7412e-01, 1.9957e-01],\n",
      "        [3.2933e-01, 5.2985e-01, 1.4082e-01],\n",
      "        [2.1097e-01, 5.7208e-01, 2.1695e-01],\n",
      "        [1.6252e-01, 6.6259e-01, 1.7489e-01],\n",
      "        [2.5940e-01, 5.9276e-01, 1.4783e-01],\n",
      "        [2.1381e-01, 6.2529e-01, 1.6090e-01],\n",
      "        [4.6381e-01, 2.7883e-01, 2.5736e-01],\n",
      "        [3.0515e-01, 2.6150e-01, 4.3335e-01],\n",
      "        [2.2079e-01, 5.9313e-01, 1.8608e-01],\n",
      "        [7.5467e-01, 1.1672e-01, 1.2861e-01],\n",
      "        [1.1881e-01, 7.5815e-01, 1.2304e-01],\n",
      "        [1.9438e-01, 2.9614e-01, 5.0948e-01],\n",
      "        [2.4463e-01, 2.8390e-01, 4.7147e-01],\n",
      "        [2.7088e-01, 6.1200e-01, 1.1712e-01],\n",
      "        [3.6579e-01, 4.9272e-01, 1.4149e-01],\n",
      "        [4.2592e-01, 2.2340e-01, 3.5068e-01],\n",
      "        [2.3408e-01, 5.3009e-01, 2.3583e-01],\n",
      "        [2.6358e-01, 6.3246e-01, 1.0396e-01],\n",
      "        [4.6860e-01, 2.2677e-01, 3.0463e-01],\n",
      "        [5.1520e-01, 8.6012e-02, 3.9879e-01],\n",
      "        [5.0712e-01, 1.1270e-01, 3.8018e-01],\n",
      "        [3.5473e-01, 5.5236e-01, 9.2913e-02],\n",
      "        [3.7368e-01, 5.3841e-01, 8.7912e-02],\n",
      "        [6.9359e-01, 1.9555e-01, 1.1086e-01],\n",
      "        [2.8757e-01, 5.9555e-01, 1.1688e-01],\n",
      "        [1.1575e-01, 6.3366e-01, 2.5059e-01],\n",
      "        [7.7961e-01, 1.2057e-01, 9.9820e-02],\n",
      "        [3.9608e-01, 4.6332e-01, 1.4060e-01],\n",
      "        [3.3675e-01, 5.6951e-01, 9.3739e-02],\n",
      "        [4.6269e-01, 3.8568e-01, 1.5162e-01],\n",
      "        [6.2970e-02, 7.9887e-01, 1.3816e-01],\n",
      "        [1.2341e-01, 7.8123e-01, 9.5355e-02],\n",
      "        [2.0369e-01, 6.5286e-01, 1.4345e-01],\n",
      "        [1.1986e-01, 7.4934e-01, 1.3080e-01],\n",
      "        [3.3399e-01, 2.5186e-01, 4.1415e-01],\n",
      "        [4.1660e-01, 5.5061e-02, 5.2833e-01],\n",
      "        [2.9584e-01, 5.8197e-01, 1.2219e-01],\n",
      "        [3.8520e-01, 6.9598e-02, 5.4521e-01],\n",
      "        [3.5404e-01, 1.2055e-01, 5.2541e-01],\n",
      "        [1.2514e-01, 7.2989e-01, 1.4497e-01],\n",
      "        [7.6504e-02, 7.6014e-01, 1.6335e-01],\n",
      "        [4.1974e-01, 4.3642e-01, 1.4384e-01],\n",
      "        [2.4871e-01, 6.2175e-01, 1.2954e-01],\n",
      "        [6.7455e-01, 1.6822e-01, 1.5722e-01],\n",
      "        [2.2231e-01, 6.8280e-01, 9.4889e-02],\n",
      "        [2.1783e-01, 5.2760e-01, 2.5457e-01],\n",
      "        [1.9976e-01, 6.2067e-01, 1.7957e-01],\n",
      "        [3.9755e-01, 1.1181e-01, 4.9064e-01],\n",
      "        [6.1338e-01, 6.8088e-02, 3.1853e-01],\n",
      "        [9.4202e-02, 7.4245e-01, 1.6335e-01],\n",
      "        [1.6734e-01, 7.5678e-01, 7.5881e-02],\n",
      "        [6.0026e-01, 2.6247e-01, 1.3727e-01],\n",
      "        [2.4881e-01, 6.5586e-01, 9.5332e-02],\n",
      "        [3.4681e-01, 3.7692e-01, 2.7627e-01],\n",
      "        [1.2064e-01, 7.4870e-01, 1.3066e-01],\n",
      "        [7.2526e-01, 1.1695e-01, 1.5779e-01],\n",
      "        [1.7697e-01, 3.9835e-01, 4.2468e-01],\n",
      "        [1.5556e-01, 6.6546e-01, 1.7898e-01],\n",
      "        [6.1374e-01, 1.3239e-01, 2.5387e-01],\n",
      "        [6.9395e-02, 8.2351e-01, 1.0709e-01],\n",
      "        [1.1312e-01, 7.3825e-01, 1.4863e-01],\n",
      "        [3.0887e-01, 6.8996e-02, 6.2213e-01],\n",
      "        [2.3430e-01, 6.4455e-01, 1.2116e-01],\n",
      "        [1.0511e-01, 7.8405e-01, 1.1084e-01],\n",
      "        [1.2464e-01, 6.3498e-01, 2.4038e-01],\n",
      "        [5.2535e-01, 2.7001e-02, 4.4765e-01],\n",
      "        [4.0410e-01, 1.2039e-01, 4.7550e-01]])\n"
     ]
    }
   ],
   "source": [
    "print(ND_sig[0][:])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b18d6a08",
   "metadata": {},
   "source": [
    "#### ND_h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "e375a6da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0:00:00.549237\n",
      "0:00:01.144302\n",
      "0:00:01.320636\n"
     ]
    }
   ],
   "source": [
    "start_time = datetime.datetime.now()\n",
    "DIFF = data[:,edge_index[0]]-data[:,edge_index[1]]\n",
    "print(datetime.datetime.now()-start_time)\n",
    "DIFF = (DIFF+1)/(data[:,edge_index[0]]+1)\n",
    "print(datetime.datetime.now()-start_time)\n",
    "DIFF_ABS = np.absolute(DIFF)\n",
    "print(datetime.datetime.now()-start_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "6ad5a621",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4397064447402954 GB\n",
      "(34272, 1722)\n",
      "0.9486729411882304 6.329240266700626\n",
      "同配比例： 0.2273149097037189\n",
      "异配正比例： 0.4731735173744294\n",
      "异配负比例： 0.29951157292185165\n"
     ]
    }
   ],
   "source": [
    "eplision=0.02\n",
    "print(sys.getsizeof(DIFF) / 1024 / 1024/1024, 'GB')\n",
    "print(DIFF.shape)\n",
    "number = DIFF.shape[0]*DIFF.shape[1]\n",
    "# print( len(DIFF[np.isnan(DIFF)])/number)\n",
    "# print( len(DIFF[np.isinf(DIFF)])/number)\n",
    "print(DIFF_ABS.mean(),DIFF_ABS.std())\n",
    "print('同配比例：',((DIFF_ABS<=eplision).sum())/number)\n",
    "print('异配正比例：',((DIFF>eplision).sum())/number)\n",
    "print('异配负比例：',((DIFF<-eplision).sum())/number)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "a4b02d58",
   "metadata": {},
   "outputs": [],
   "source": [
    "DIFF[DIFF>eplision]=1\n",
    "DIFF[DIFF_ABS<=eplision]=0\n",
    "DIFF[DIFF<-eplision]=2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "f8c7d6ac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 2., 1., ..., 0., 2., 0.])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "DIFF[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "6f8826f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "edge_spatial_types=torch.LongTensor(DIFF[TIME_LEN+1:,])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "629efb0a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([34259, 1722])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edge_spatial_types.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "612d0d75",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 0, 2, 1, 2, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1,\n",
       "        0, 1, 1, 0, 0, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2,\n",
       "        2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0,\n",
       "        1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 1, 1,\n",
       "        1, 2, 2, 0, 0, 2, 2, 2, 0, 2, 2, 1, 2, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,\n",
       "        0, 0, 0, 2, 1, 1, 2, 2, 0, 2, 0, 2, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1,\n",
       "        1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2,\n",
       "        2, 2, 2, 2, 2, 2, 1, 2, 1, 1, 1, 1, 0, 2, 1, 0, 1, 1, 2, 2, 0, 1, 2, 0,\n",
       "        2, 2, 0, 2, 1, 1, 1, 2, 2, 2, 0, 2, 2, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1,\n",
       "        1, 1, 0, 1, 0, 2, 2, 0, 2, 2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "        0, 0, 1, 0, 0, 1, 1, 2, 1, 1, 1])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edge_spatial_types[-1][1:252]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "3a19700c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 34259/34259 [23:23<00:00, 24.41it/s]\n"
     ]
    }
   ],
   "source": [
    "ND_h=[]\n",
    "for edge_type in tqdm.tqdm(edge_spatial_types):\n",
    "    ND_h_=torch.zeros(ND_sig.shape[1],3)\n",
    "    for i in range(len(edge_type)):#只能串行计算\n",
    "        ND_h_[edge_index[0][i],edge_type[i]]+=1\n",
    "    ND_h.append(ND_h_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "895e9b2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34259, 207, 3])\n"
     ]
    }
   ],
   "source": [
    "ND_h_torch=torch.stack(ND_h,dim=0)\n",
    "print(ND_h_torch.shape)\n",
    "# print(ND_h_torch[2][:])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d138d2f6",
   "metadata": {},
   "source": [
    "### ND_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "89640d93",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34259, 207, 3])\n"
     ]
    }
   ],
   "source": [
    "degree = ND_h_torch.sum(dim=-1).unsqueeze(-1).repeat_interleave(3,dim=-1)\n",
    "print(degree.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "5117ebb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "ND_t = ND_h_torch*ND_sig/degree/degree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "6e5cfecd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.1374e-03, 9.7422e-03, 8.4236e-03],\n",
      "        [7.1206e-03, 3.2258e-03, 3.1672e-02],\n",
      "        [1.3814e-03, 4.2275e-02, 3.2588e-03],\n",
      "        [4.0787e-03, 1.2768e-03, 1.9645e-02],\n",
      "        [1.9628e-02, 6.0754e-03, 0.0000e+00],\n",
      "        [3.7424e-03, 2.3685e-03, 1.2467e-02],\n",
      "        [2.1239e-02, 4.2331e-02, 0.0000e+00],\n",
      "        [7.4708e-03, 9.1181e-04, 1.6900e-02],\n",
      "        [3.4170e-03, 4.7313e-02, 7.5286e-03],\n",
      "        [1.0643e-02, 2.1572e-02, 0.0000e+00],\n",
      "        [2.4137e-02, 3.8702e-02, 8.5803e-03],\n",
      "        [1.5179e-02, 8.0155e-03, 0.0000e+00],\n",
      "        [2.1551e-02, 1.0624e-03, 2.6117e-03],\n",
      "        [4.2671e-03, 3.5112e-02, 0.0000e+00],\n",
      "        [1.5737e-02, 4.6620e-03, 2.1086e-02],\n",
      "        [2.4976e-02, 2.1267e-02, 0.0000e+00],\n",
      "        [0.0000e+00, 1.6631e-03, 1.2101e-02],\n",
      "        [4.0213e-03, 0.0000e+00, 3.5654e-02],\n",
      "        [2.3249e-02, 4.9616e-02, 3.8246e-02],\n",
      "        [6.0989e-03, 3.4702e-02, 1.2741e-03],\n",
      "        [1.5712e-02, 2.0662e-03, 7.8092e-03],\n",
      "        [0.0000e+00, 4.4087e-02, 3.8403e-02],\n",
      "        [4.7285e-03, 0.0000e+00, 4.4778e-02],\n",
      "        [4.0007e-02, 3.6299e-02, 0.0000e+00],\n",
      "        [1.6491e-02, 2.2796e-02, 0.0000e+00],\n",
      "        [6.5307e-03, 6.3636e-03, 1.2605e-02],\n",
      "        [1.4360e-01, 0.0000e+00, 0.0000e+00],\n",
      "        [4.1310e-03, 0.0000e+00, 5.6170e-02],\n",
      "        [4.7442e-04, 6.0118e-02, 0.0000e+00],\n",
      "        [2.3198e-03, 3.0219e-02, 1.1916e-02],\n",
      "        [2.1936e-03, 0.0000e+00, 7.2569e-02],\n",
      "        [3.0646e-02, 8.5420e-02, 0.0000e+00],\n",
      "        [1.5842e-02, 1.8321e-03, 3.2876e-03],\n",
      "        [1.0563e-02, 1.5469e-02, 0.0000e+00],\n",
      "        [1.7380e-02, 8.2674e-03, 3.6973e-02],\n",
      "        [0.0000e+00, 1.8976e-02, 8.5258e-02],\n",
      "        [1.8375e-03, 5.9414e-02, 0.0000e+00],\n",
      "        [3.0401e-03, 2.4284e-02, 1.0924e-02],\n",
      "        [4.0165e-04, 1.6482e-02, 6.7179e-03],\n",
      "        [8.2310e-03, 9.5717e-03, 8.7804e-03],\n",
      "        [3.1305e-03, 8.7684e-03, 5.5260e-03],\n",
      "        [2.0539e-03, 6.3658e-02, 0.0000e+00],\n",
      "        [7.0831e-03, 1.8417e-02, 5.1964e-03],\n",
      "        [1.6854e-02, 7.8909e-02, 0.0000e+00],\n",
      "        [1.0520e-02, 7.2422e-03, 1.5825e-02],\n",
      "        [2.3179e-02, 2.0559e-01, 0.0000e+00],\n",
      "        [3.4844e-02, 7.6539e-02, 0.0000e+00],\n",
      "        [3.5721e-03, 9.5147e-03, 1.5230e-02],\n",
      "        [2.7198e-03, 1.5438e-02, 5.3506e-03],\n",
      "        [1.9265e-02, 2.9289e-02, 2.7893e-02],\n",
      "        [6.7323e-03, 2.1564e-02, 2.6080e-02],\n",
      "        [0.0000e+00, 3.1967e-02, 1.0646e-01],\n",
      "        [1.3963e-02, 6.2671e-02, 0.0000e+00],\n",
      "        [2.6115e-03, 9.3507e-03, 2.5254e-02],\n",
      "        [4.2249e-03, 3.0786e-03, 3.5295e-02],\n",
      "        [2.4522e-03, 5.8727e-03, 1.3566e-02],\n",
      "        [3.5407e-03, 0.0000e+00, 2.5615e-02],\n",
      "        [4.3378e-03, 1.0718e-02, 1.2537e-02],\n",
      "        [4.5965e-03, 2.6448e-02, 6.6822e-03],\n",
      "        [5.1947e-03, 2.5443e-02, 1.0125e-02],\n",
      "        [2.4742e-02, 1.5673e-02, 2.8399e-03],\n",
      "        [6.7014e-03, 0.0000e+00, 1.3260e-02],\n",
      "        [0.0000e+00, 8.5336e-03, 8.0924e-02],\n",
      "        [8.2565e-03, 3.7001e-02, 0.0000e+00],\n",
      "        [1.4659e-02, 1.7642e-02, 0.0000e+00],\n",
      "        [1.2784e-01, 5.9950e-02, 0.0000e+00],\n",
      "        [3.2437e-02, 3.9107e-02, 0.0000e+00],\n",
      "        [3.6196e-03, 7.6640e-02, 0.0000e+00],\n",
      "        [2.1484e-02, 1.8634e-02, 1.0447e-03],\n",
      "        [7.6290e-03, 5.2790e-02, 2.2212e-03],\n",
      "        [5.9336e-03, 1.5298e-02, 1.8368e-02],\n",
      "        [2.5159e-02, 1.3119e-02, 0.0000e+00],\n",
      "        [2.0167e-02, 0.0000e+00, 1.0433e-01],\n",
      "        [5.7375e-03, 3.5504e-02, 3.7553e-03],\n",
      "        [1.6287e-02, 1.6023e-02, 0.0000e+00],\n",
      "        [7.2092e-03, 0.0000e+00, 4.8885e-02],\n",
      "        [2.7076e-03, 0.0000e+00, 1.6745e-02],\n",
      "        [2.3677e-03, 3.9270e-03, 3.2114e-02],\n",
      "        [1.0718e-03, 7.9721e-02, 0.0000e+00],\n",
      "        [3.9382e-03, 6.8771e-02, 0.0000e+00],\n",
      "        [9.8636e-03, 1.8378e-02, 0.0000e+00],\n",
      "        [1.1724e-02, 1.7128e-02, 1.5665e-03],\n",
      "        [1.5018e-02, 1.3724e-03, 1.8931e-02],\n",
      "        [3.2203e-03, 4.5469e-02, 6.0950e-03],\n",
      "        [4.9779e-03, 7.4424e-03, 1.4486e-02],\n",
      "        [1.3090e-02, 4.5753e-02, 0.0000e+00],\n",
      "        [1.6639e-02, 7.3940e-02, 0.0000e+00],\n",
      "        [2.6176e-02, 6.7654e-03, 4.1103e-03],\n",
      "        [2.3595e-03, 1.0480e-02, 1.6805e-02],\n",
      "        [1.5356e-02, 1.7803e-02, 0.0000e+00],\n",
      "        [2.2727e-03, 1.0898e-01, 5.7368e-03],\n",
      "        [1.2658e-03, 5.7249e-03, 1.0712e-02],\n",
      "        [1.4169e-01, 3.4493e-02, 0.0000e+00],\n",
      "        [3.5240e-03, 5.3790e-02, 0.0000e+00],\n",
      "        [9.2001e-03, 1.7066e-02, 4.5340e-03],\n",
      "        [1.1650e-02, 2.7719e-02, 4.6261e-02],\n",
      "        [1.3886e-02, 4.1825e-03, 0.0000e+00],\n",
      "        [1.2125e-02, 2.6789e-02, 7.6851e-04],\n",
      "        [7.6702e-03, 4.0152e-03, 4.3614e-02],\n",
      "        [1.2099e-01, 0.0000e+00, 0.0000e+00],\n",
      "        [3.1986e-02, 7.5656e-03, 1.9100e-02],\n",
      "        [2.3433e-03, 2.9829e-02, 4.9245e-03],\n",
      "        [4.2826e-03, 1.9306e-03, 1.8452e-02],\n",
      "        [4.3703e-03, 2.6154e-02, 1.1843e-02],\n",
      "        [2.5922e-02, 3.5885e-02, 0.0000e+00],\n",
      "        [0.0000e+00, 1.6477e-02, 2.2693e-04],\n",
      "        [0.0000e+00, 2.0344e-02, 3.7359e-04],\n",
      "        [6.5909e-02, 5.1512e-03, 0.0000e+00],\n",
      "        [1.7034e-03, 1.0107e-01, 0.0000e+00],\n",
      "        [2.6486e-03, 4.4993e-03, 2.6930e-02],\n",
      "        [2.0104e-02, 1.9253e-02, 1.9073e-03],\n",
      "        [1.3883e-02, 4.9410e-03, 5.5409e-03],\n",
      "        [5.1690e-02, 0.0000e+00, 0.0000e+00],\n",
      "        [2.6485e-02, 1.3551e-02, 0.0000e+00],\n",
      "        [4.5669e-03, 3.6427e-02, 0.0000e+00],\n",
      "        [1.6985e-03, 9.1566e-03, 1.4039e-02],\n",
      "        [2.5913e-03, 8.7099e-02, 3.4118e-03],\n",
      "        [2.0440e-03, 3.2181e-02, 0.0000e+00],\n",
      "        [3.7908e-03, 3.7592e-02, 1.5573e-02],\n",
      "        [8.7237e-04, 0.0000e+00, 1.7662e-02],\n",
      "        [1.6793e-02, 0.0000e+00, 8.4509e-02],\n",
      "        [8.1995e-03, 7.2666e-02, 1.7968e-02],\n",
      "        [1.7605e-03, 1.6558e-02, 5.0460e-03],\n",
      "        [2.3241e-02, 2.9170e-02, 0.0000e+00],\n",
      "        [7.3382e-03, 2.7358e-02, 8.5094e-03],\n",
      "        [1.2329e-02, 0.0000e+00, 6.5567e-02],\n",
      "        [0.0000e+00, 4.1707e-01, 0.0000e+00],\n",
      "        [1.2182e-02, 1.2684e-02, 1.2861e-03],\n",
      "        [3.8961e-03, 4.0272e-02, 0.0000e+00],\n",
      "        [2.2492e-02, 7.6734e-02, 0.0000e+00],\n",
      "        [5.9857e-02, 1.5010e-02, 0.0000e+00],\n",
      "        [1.8280e-02, 4.6060e-03, 1.5797e-02],\n",
      "        [2.9876e-03, 4.6378e-02, 5.0734e-03],\n",
      "        [1.3528e-03, 3.6953e-02, 7.8188e-03],\n",
      "        [5.7783e-03, 7.7286e-02, 2.6779e-03],\n",
      "        [1.2334e-02, 2.8892e-02, 0.0000e+00],\n",
      "        [4.1831e-03, 4.3260e-03, 0.0000e+00],\n",
      "        [4.1588e-03, 4.8839e-02, 1.1915e-03],\n",
      "        [1.1940e-02, 1.8194e-02, 0.0000e+00],\n",
      "        [4.7709e-03, 1.5403e-02, 1.4896e-02],\n",
      "        [8.3817e-03, 1.4176e-02, 9.8555e-03],\n",
      "        [1.2197e-02, 1.3083e-02, 6.9543e-03],\n",
      "        [3.4872e-03, 1.4184e-02, 1.0758e-02],\n",
      "        [6.6334e-03, 6.7611e-02, 0.0000e+00],\n",
      "        [3.0699e-03, 1.0522e-02, 6.9979e-03],\n",
      "        [3.5341e-03, 3.1006e-02, 3.9892e-03],\n",
      "        [2.8396e-02, 1.1381e-02, 1.0505e-02],\n",
      "        [7.6287e-02, 6.5376e-02, 0.0000e+00],\n",
      "        [1.2266e-02, 6.5903e-02, 0.0000e+00],\n",
      "        [1.8867e-01, 2.9179e-02, 0.0000e+00],\n",
      "        [3.3003e-03, 4.2120e-02, 1.0253e-02],\n",
      "        [0.0000e+00, 2.9614e-03, 4.5853e-02],\n",
      "        [3.8224e-03, 8.8718e-03, 3.6834e-02],\n",
      "        [6.0195e-02, 6.8000e-02, 0.0000e+00],\n",
      "        [4.8772e-03, 0.0000e+00, 7.5463e-03],\n",
      "        [1.3251e-02, 6.9501e-03, 1.5586e-03],\n",
      "        [6.5023e-03, 0.0000e+00, 3.2754e-02],\n",
      "        [9.3579e-03, 2.2454e-02, 6.1514e-04],\n",
      "        [1.9127e-02, 1.3884e-02, 1.2434e-02],\n",
      "        [7.1555e-03, 5.9730e-03, 0.0000e+00],\n",
      "        [3.8738e-02, 0.0000e+00, 2.6402e-03],\n",
      "        [1.8919e-02, 2.4549e-03, 8.2589e-04],\n",
      "        [1.9900e-02, 6.3717e-03, 1.0404e-03],\n",
      "        [3.5387e-03, 1.2970e-02, 0.0000e+00],\n",
      "        [8.9865e-03, 4.6527e-02, 1.8263e-03],\n",
      "        [1.4468e-02, 7.9208e-02, 0.0000e+00],\n",
      "        [5.4140e-03, 9.2100e-03, 0.0000e+00],\n",
      "        [9.7797e-03, 2.8600e-02, 3.4716e-03],\n",
      "        [5.2618e-03, 6.2290e-02, 0.0000e+00],\n",
      "        [1.8508e-02, 6.1710e-02, 0.0000e+00],\n",
      "        [6.2970e-04, 3.1955e-02, 6.9079e-03],\n",
      "        [3.0598e-03, 2.5826e-02, 3.1522e-03],\n",
      "        [5.6581e-03, 9.0675e-02, 0.0000e+00],\n",
      "        [4.9527e-03, 1.8579e-02, 3.2430e-03],\n",
      "        [2.0874e-02, 1.5741e-02, 5.1769e-02],\n",
      "        [4.9302e-03, 3.5839e-03, 0.0000e+00],\n",
      "        [5.2516e-03, 2.7549e-02, 1.4461e-03],\n",
      "        [1.0700e-02, 0.0000e+00, 7.5723e-02],\n",
      "        [3.1470e-03, 2.1431e-03, 2.1016e-02],\n",
      "        [1.3905e-02, 0.0000e+00, 3.2215e-02],\n",
      "        [0.0000e+00, 8.4460e-02, 3.6301e-02],\n",
      "        [4.1974e-01, 0.0000e+00, 0.0000e+00],\n",
      "        [2.4871e-01, 0.0000e+00, 0.0000e+00],\n",
      "        [1.6864e-01, 4.2056e-02, 0.0000e+00],\n",
      "        [6.9472e-03, 2.1337e-02, 5.9306e-03],\n",
      "        [8.8909e-03, 4.3070e-02, 5.1954e-03],\n",
      "        [4.4392e-02, 6.8963e-02, 0.0000e+00],\n",
      "        [2.0188e-02, 1.3103e-03, 0.0000e+00],\n",
      "        [1.8777e-02, 2.4317e-03, 1.6252e-03],\n",
      "        [1.0467e-02, 1.6499e-01, 0.0000e+00],\n",
      "        [6.6935e-03, 1.2108e-01, 0.0000e+00],\n",
      "        [2.9179e-02, 9.1134e-03, 0.0000e+00],\n",
      "        [1.3250e-02, 1.5523e-02, 0.0000e+00],\n",
      "        [5.3084e-03, 2.1154e-02, 0.0000e+00],\n",
      "        [7.5401e-03, 4.6794e-02, 1.6333e-02],\n",
      "        [1.1332e-02, 5.4820e-03, 0.0000e+00],\n",
      "        [0.0000e+00, 6.5842e-03, 3.1588e-02],\n",
      "        [0.0000e+00, 6.6546e-03, 1.6108e-02],\n",
      "        [2.1919e-02, 2.7018e-03, 3.8857e-03],\n",
      "        [2.5702e-03, 2.0334e-02, 5.2886e-03],\n",
      "        [1.2569e-02, 1.6405e-01, 0.0000e+00],\n",
      "        [1.2607e-02, 1.4081e-03, 5.0786e-02],\n",
      "        [5.8574e-02, 0.0000e+00, 3.0289e-02],\n",
      "        [1.0511e-01, 0.0000e+00, 0.0000e+00],\n",
      "        [1.2464e-03, 6.3498e-03, 1.9231e-02],\n",
      "        [1.4593e-02, 3.7501e-03, 0.0000e+00],\n",
      "        [7.1734e-03, 0.0000e+00, 2.8136e-02]])\n"
     ]
    }
   ],
   "source": [
    "print(ND_t[0][:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "2ef40127",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(ND_t,'/home/hqh/DataSetFile/metrla/ND_t.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c098d5ad",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
