{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a9129cab-3252-4d7f-b8ab-3cf509c71612",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:38:58.030425Z",
     "iopub.status.busy": "2024-11-03T21:38:58.030208Z",
     "iopub.status.idle": "2024-11-03T21:38:58.264832Z",
     "shell.execute_reply": "2024-11-03T21:38:58.259683Z",
     "shell.execute_reply.started": "2024-11-03T21:38:58.030409Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import scipy.io as scio\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0a42c7b4-8416-44e4-9781-60f816613bb4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:38:58.541086Z",
     "iopub.status.busy": "2024-11-03T21:38:58.540755Z",
     "iopub.status.idle": "2024-11-03T21:38:58.765537Z",
     "shell.execute_reply": "2024-11-03T21:38:58.765038Z",
     "shell.execute_reply.started": "2024-11-03T21:38:58.541064Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(12602, 256)\n"
     ]
    }
   ],
   "source": [
    "dat = scio.loadmat('./speech_data/t12.2022.06.16_diagnosticBlocks.mat')\n",
    "data = dat['tx1']\n",
    "print(data.shape)\n",
    "label = dat['trialCues']\n",
    "time = dat['goTrialEpochs']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "00754050-d033-428b-af86-623a9c95c170",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:39:06.679384Z",
     "iopub.status.busy": "2024-11-03T21:39:06.679142Z",
     "iopub.status.idle": "2024-11-03T21:39:06.699982Z",
     "shell.execute_reply": "2024-11-03T21:39:06.699469Z",
     "shell.execute_reply.started": "2024-11-03T21:39:06.679364Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64\n",
      "(64, 50, 128)\n",
      "64\n"
     ]
    }
   ],
   "source": [
    "label_list = []\n",
    "data_list = []\n",
    "for i in range(1,9):\n",
    "    # if i == 3:\n",
    "    #     continue\n",
    "    tot = 0\n",
    "    for l in range(len(label)):\n",
    "        if label[l][0] == i:\n",
    "            # if tot == 16:\n",
    "            #     continue\n",
    "            label_list.append(i)\n",
    "            temp = data[time[l][0]:time[l][0]+50,:128]\n",
    "            # data_list.append(np.mean(temp, axis = 0))\n",
    "            data_list.append(temp)\n",
    "            tot+=1\n",
    "            \n",
    "print(len(data_list))\n",
    "datas = np.array(data_list)\n",
    "print(datas.shape)\n",
    "\n",
    "print(len(label_list))\n",
    "labels = np.array(label_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "34c31694-d01d-4d44-85e6-582eb1db0a1a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:39:13.200575Z",
     "iopub.status.busy": "2024-11-03T21:39:13.200352Z",
     "iopub.status.idle": "2024-11-03T21:39:23.306974Z",
     "shell.execute_reply": "2024-11-03T21:39:23.306390Z",
     "shell.execute_reply.started": "2024-11-03T21:39:13.200557Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.859375\n"
     ]
    }
   ],
   "source": [
    "from sklearn.svm import LinearSVC\n",
    "tot = 0\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "nums = 64\n",
    "for i in range(nums):\n",
    "    test_data_loo, test_label_loo = datas.reshape(nums,-1)[i], labels[i]\n",
    "    traidx = [j for j in range(nums) if j != i]\n",
    "    train_data_loo, train_label_loo =  datas.reshape(nums,-1)[traidx], labels[traidx]\n",
    "    linearsvc = LinearSVC(C=1e8)\n",
    "    linearsvc.fit(train_data_loo, train_label_loo)\n",
    "    pred = linearsvc.predict(test_data_loo.reshape(1,-1))\n",
    "    if pred == test_label_loo:\n",
    "        tot += 1\n",
    "print(tot/(i+1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bc7d6c5a-ab82-4f22-b868-646ae6108609",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:39:33.706678Z",
     "iopub.status.busy": "2024-11-03T21:39:33.702470Z",
     "iopub.status.idle": "2024-11-03T21:39:33.709824Z",
     "shell.execute_reply": "2024-11-03T21:39:33.709353Z",
     "shell.execute_reply.started": "2024-11-03T21:39:33.706641Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "file_list = os.listdir('./speech_data')\n",
    "file_list = sorted(file_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "24cdfe57-19b3-4fac-9e50-bdc9c366995d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:39:39.998439Z",
     "iopub.status.busy": "2024-11-03T21:39:39.996531Z",
     "iopub.status.idle": "2024-11-03T21:39:40.004273Z",
     "shell.execute_reply": "2024-11-03T21:39:40.003595Z",
     "shell.execute_reply.started": "2024-11-03T21:39:39.998364Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['t12.2022.06.16_diagnosticBlocks.mat', 't12.2022.06.21_diagnosticBlocks.mat', 't12.2022.06.23_diagnosticBlocks.mat', 't12.2022.06.28_diagnosticBlocks.mat', 't12.2022.07.05_diagnosticBlocks.mat', 't12.2022.07.07_diagnosticBlocks.mat', 't12.2022.07.14_diagnosticBlocks.mat', 't12.2022.07.21_diagnosticBlocks.mat', 't12.2022.07.27_diagnosticBlocks.mat']\n"
     ]
    }
   ],
   "source": [
    "# file_list.remove('cross.ipynb')\n",
    "# file_list.remove('.ipynb_checkpoints')\n",
    "print(file_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0d4e62bd-a98a-4582-a0e7-5c43c3c900dc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:40:18.619010Z",
     "iopub.status.busy": "2024-11-03T21:40:18.618778Z",
     "iopub.status.idle": "2024-11-03T21:40:25.085792Z",
     "shell.execute_reply": "2024-11-03T21:40:25.084636Z",
     "shell.execute_reply.started": "2024-11-03T21:40:18.618991Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t12.2022.06.16_diagnosticBlocks.mat\n",
      "1.0\n",
      "t12.2022.06.21_diagnosticBlocks.mat\n",
      "0.671875\n",
      "t12.2022.06.23_diagnosticBlocks.mat\n",
      "0.578125\n",
      "t12.2022.06.28_diagnosticBlocks.mat\n",
      "0.65625\n",
      "t12.2022.07.05_diagnosticBlocks.mat\n",
      "0.421875\n",
      "t12.2022.07.07_diagnosticBlocks.mat\n",
      "0.359375\n",
      "t12.2022.07.14_diagnosticBlocks.mat\n",
      "0.34375\n",
      "t12.2022.07.21_diagnosticBlocks.mat\n",
      "0.28125\n",
      "t12.2022.07.27_diagnosticBlocks.mat\n",
      "0.140625\n"
     ]
    }
   ],
   "source": [
    "per_day_data = []\n",
    "per_day_label = []\n",
    "for f in sorted(file_list):\n",
    "    dat = scio.loadmat('./speech_data/' + f)\n",
    "    data = dat['tx1']\n",
    "    \n",
    "    label_list = []\n",
    "    data_list = []\n",
    "    for i in range(1,9):\n",
    "        # if i == 3:\n",
    "        #     continue\n",
    "        tot = 0\n",
    "        for l in range(len(label)):\n",
    "            if label[l][0] == i:\n",
    "                # if tot == 16:\n",
    "                #     continue\n",
    "                label_list.append(i)\n",
    "                temp = data[time[l][0]:time[l][0]+50,:128]\n",
    "                # data_list.append(np.mean(temp, axis = 0))\n",
    "                data_list.append(temp)\n",
    "                tot+=1\n",
    "                \n",
    "    datas2 = np.array(data_list)\n",
    "    per_day_data.append(datas2)\n",
    "\n",
    "    labels2 = np.array(label_list)\n",
    "    per_day_label.append(labels2)\n",
    "    tot = 0\n",
    "    for i in range(64):\n",
    "        pred = linearsvc.predict(datas2[i].reshape(1,-1))\n",
    "        if (pred == labels2[i]):\n",
    "            tot+=1\n",
    "    print(f)\n",
    "    print(tot/64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2c41e7a2-2ca3-4e78-9504-e6313f92b161",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:40:26.193536Z",
     "iopub.status.busy": "2024-11-03T21:40:26.193279Z",
     "iopub.status.idle": "2024-11-03T21:40:26.200842Z",
     "shell.execute_reply": "2024-11-03T21:40:26.199502Z",
     "shell.execute_reply.started": "2024-11-03T21:40:26.193518Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_data(f):\n",
    "\n",
    "    dat = scio.loadmat('./speech_data/' + f)\n",
    "    data = dat['tx1']\n",
    "    \n",
    "    label_list = []\n",
    "    data_list = []\n",
    "    for i in range(1,9):\n",
    "        # if i == 3:\n",
    "        #     continue\n",
    "        tot = 0\n",
    "        for l in range(len(label)):\n",
    "            if label[l][0] == i:\n",
    "                # if tot == 16:\n",
    "                #     continue\n",
    "                label_list.append(i)\n",
    "                temp = data[time[l][0]:time[l][0]+50,:128]\n",
    "                # data_list.append(np.mean(temp, axis = 0))\n",
    "                data_list.append(temp)\n",
    "                tot+=1\n",
    "                \n",
    "    datas2 = np.array(data_list)\n",
    "    \n",
    "    labels2 = np.array(label_list)\n",
    "    return datas2, labels2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6620f233-b917-42c2-8748-ced22ec75ea8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:40:35.868905Z",
     "iopub.status.busy": "2024-11-03T21:40:35.868494Z",
     "iopub.status.idle": "2024-11-03T21:40:36.289276Z",
     "shell.execute_reply": "2024-11-03T21:40:36.288475Z",
     "shell.execute_reply.started": "2024-11-03T21:40:35.868884Z"
    }
   },
   "outputs": [],
   "source": [
    "one, onel = get_data(file_list[0])\n",
    "two, twol = get_data(file_list[5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "676af4c0-1df5-415b-a45f-77f991703b4e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:40:51.461773Z",
     "iopub.status.busy": "2024-11-03T21:40:51.461546Z",
     "iopub.status.idle": "2024-11-03T21:40:51.501622Z",
     "shell.execute_reply": "2024-11-03T21:40:51.500753Z",
     "shell.execute_reply.started": "2024-11-03T21:40:51.461755Z"
    }
   },
   "outputs": [],
   "source": [
    "from scipy.signal import welch\n",
    "def compute_frequency_features(data):\n",
    "    fft_features = np.abs(np.fft.rfft(data, axis=-1))\n",
    "    angle_features = np.angle(np.fft.fft(data, axis=-1))\n",
    "    # psd_features = []\n",
    "    # for channel in data:\n",
    "    #     _, Pxx = welch(channel, nperseg=64)\n",
    "    #     psd_features.append(Pxx)\n",
    "\n",
    "    # psd_features = np.array(psd_features)\n",
    "    freq_features = np.concatenate((fft_features, angle_features), axis=-1)\n",
    "    return freq_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e3049187-ea94-4bba-ae97-597b0baba8c8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T21:40:55.943814Z",
     "iopub.status.busy": "2024-11-03T21:40:55.943143Z",
     "iopub.status.idle": "2024-11-03T21:40:55.954273Z",
     "shell.execute_reply": "2024-11-03T21:40:55.952948Z",
     "shell.execute_reply.started": "2024-11-03T21:40:55.943750Z"
    }
   },
   "outputs": [],
   "source": [
    "def squeeze(data):\n",
    "    data = data.reshape([64, 50])\n",
    "    temp = np.zeros([64,5])\n",
    "    for i in range(64):\n",
    "        for j in range(5):\n",
    "            temp[i,j] = np.sum(data[i,j*10:(j+1)*10])\n",
    "    return temp.reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "id": "3c793106-74ff-428d-a3cc-8b92cd42201f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-06T04:16:36.327256Z",
     "iopub.status.busy": "2024-11-06T04:16:36.327026Z",
     "iopub.status.idle": "2024-11-06T04:16:36.347676Z",
     "shell.execute_reply": "2024-11-06T04:16:36.346768Z",
     "shell.execute_reply.started": "2024-11-06T04:16:36.327239Z"
    }
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'pmd2' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[178], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m np\u001b[38;5;241m.\u001b[39mcount_nonzero(np\u001b[38;5;241m.\u001b[39misnan(pmd2))\n",
      "\u001b[0;31mNameError\u001b[0m: name 'pmd2' is not defined"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "af1b5c1f-0c14-4996-b3f2-4242f6db5555",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T23:27:43.877394Z",
     "iopub.status.busy": "2024-11-03T23:27:43.876819Z",
     "iopub.status.idle": "2024-11-03T23:31:37.315272Z",
     "shell.execute_reply": "2024-11-03T23:31:37.314543Z",
     "shell.execute_reply.started": "2024-11-03T23:27:43.877372Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t12.2022.06.16_diagnosticBlocks.mat\n",
      "t12.2022.06.21_diagnosticBlocks.mat\n",
      "raw:  0.71875\n",
      "begin:  0.671875\n",
      "after:  0.671875 49\n",
      "t12.2022.06.23_diagnosticBlocks.mat\n",
      "raw:  0.875\n",
      "begin:  0.578125\n",
      "after:  0.671875 49\n",
      "t12.2022.06.28_diagnosticBlocks.mat\n",
      "raw:  0.859375\n",
      "begin:  0.65625\n",
      "after:  0.65625 37\n",
      "t12.2022.07.05_diagnosticBlocks.mat\n",
      "raw:  0.84375\n",
      "begin:  0.421875\n",
      "after:  0.5 45\n",
      "t12.2022.07.07_diagnosticBlocks.mat\n",
      "raw:  0.734375\n",
      "begin:  0.359375\n",
      "after:  0.5625 53\n"
     ]
    }
   ],
   "source": [
    "print(file_list[0])\n",
    "idx_list = []\n",
    "false_id_list = []\n",
    "data_list = []\n",
    "false_data_list = []\n",
    "for d in range(1, 6):\n",
    "    print(file_list[d])\n",
    "    three, threel = get_data(file_list[d])\n",
    "\n",
    "    from sklearn.svm import LinearSVC\n",
    "    tot = 0\n",
    "    import warnings\n",
    "    warnings.filterwarnings(\"ignore\")\n",
    "    nums = 64\n",
    "    for i in range(nums):\n",
    "        test_data_loo, test_label_loo = three.reshape(nums,-1)[i], threel[i]\n",
    "        traidx = [j for j in range(nums) if j != i]\n",
    "        train_data_loo, train_label_loo =  three.reshape(nums,-1)[traidx], threel[traidx]\n",
    "        linearsvc0 = LinearSVC(C=1e8)\n",
    "        linearsvc0.fit(train_data_loo, train_label_loo)\n",
    "        pred = linearsvc0.predict(test_data_loo.reshape(1,-1))\n",
    "        if pred == test_label_loo:\n",
    "            tot += 1\n",
    "    print('raw: ',tot/(i+1))\n",
    "\n",
    "    tot = 0\n",
    "    for i in range(64):\n",
    "        pred = linearsvc.predict(three[i].reshape(1,-1))\n",
    "        if (pred == threel[i]):\n",
    "            tot+=1\n",
    "    \n",
    "    print('begin: ', tot/64)\n",
    "    \n",
    "    from scipy.stats import pearsonr\n",
    "    idx = [i for i in range(128)]\n",
    "    new_idx = []\n",
    "    false_idx = []\n",
    "    \n",
    "    idx = [i for i in range(128)]\n",
    "    for j in range(128):\n",
    "        temp = squeeze(one[:,:,j].reshape(-1))\n",
    "        data_list.append(squeeze(three[:,:,j]))\n",
    "        # print(temp.shape)\n",
    "        sim = []\n",
    "        for k in idx:\n",
    "            sim.append(pearsonr(squeeze(three[:,:,k].reshape(-1)), temp)[0])\n",
    "        sim_sorted = sorted(sim, reverse=True)\n",
    "        sim_id = sim.index(max(sim))\n",
    "        # if random.random() > 0.7:\n",
    "        #     sim_id = sim.index(sim_sorted[0])\n",
    "        # else:\n",
    "        #     sim_id = sim.index(sim_sorted[1])\n",
    "        # else:\n",
    "        #     sim_id = sim.index(sim_sorted[2])\n",
    "        # sim_id = sim.index(sim_sorted[0])\n",
    "        \n",
    "        new_idx.append(idx[sim_id])\n",
    "        sim_id = sim.index(sim_sorted[-1])\n",
    "        false_idx.append(idx[sim_id])\n",
    "        false_data_list.append(squeeze(three[:,:,idx[sim_id]]))\n",
    "        \n",
    "        # idx.remove(idx[sim_id])\n",
    "\n",
    "    changes = np.array(new_idx).reshape(128)\n",
    "    idx_list.append(changes)\n",
    "    new_three = three[:,:,changes]\n",
    "\n",
    "    changes = np.array(false_idx).reshape(128)\n",
    "    false_id_list.append(changes)\n",
    "    \n",
    "    tot = 0\n",
    "    for i in range(64):\n",
    "        pred = linearsvc.predict(new_three[i].reshape(1,-1))\n",
    "        if (pred == threel[i]):\n",
    "            tot+=1\n",
    "    \n",
    "    print('after: ',tot/64, len(np.unique(changes)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "3fd66573-077c-4b63-af6a-d9ffd8c93371",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T23:33:20.549481Z",
     "iopub.status.busy": "2024-11-03T23:33:20.548701Z",
     "iopub.status.idle": "2024-11-03T23:35:40.757212Z",
     "shell.execute_reply": "2024-11-03T23:35:40.756561Z",
     "shell.execute_reply.started": "2024-11-03T23:33:20.549413Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t12.2022.06.16_diagnosticBlocks.mat\n",
      "t12.2022.07.14_diagnosticBlocks.mat\n",
      "raw:  0.75\n",
      "begin:  0.34375\n",
      "after:  0.46875 48\n",
      "t12.2022.07.21_diagnosticBlocks.mat\n",
      "raw:  0.8125\n",
      "begin:  0.28125\n",
      "after:  0.6875 49\n",
      "t12.2022.07.27_diagnosticBlocks.mat\n",
      "raw:  0.640625\n",
      "begin:  0.140625\n",
      "after:  0.453125 50\n"
     ]
    }
   ],
   "source": [
    "print(file_list[0])\n",
    "test_idx_list = []\n",
    "test_false_id_list = []\n",
    "test_data_list = []\n",
    "test_false_data_list = []\n",
    "for d in range(6, len(file_list)):\n",
    "    print(file_list[d])\n",
    "    three, threel = get_data(file_list[d])\n",
    "\n",
    "    from sklearn.svm import LinearSVC\n",
    "    tot = 0\n",
    "    import warnings\n",
    "    warnings.filterwarnings(\"ignore\")\n",
    "    nums = 64\n",
    "    for i in range(nums):\n",
    "        test_data_loo, test_label_loo = three.reshape(nums,-1)[i], threel[i]\n",
    "        traidx = [j for j in range(nums) if j != i]\n",
    "        train_data_loo, train_label_loo =  three.reshape(nums,-1)[traidx], threel[traidx]\n",
    "        linearsvc0 = LinearSVC(C=1e8)\n",
    "        linearsvc0.fit(train_data_loo, train_label_loo)\n",
    "        pred = linearsvc0.predict(test_data_loo.reshape(1,-1))\n",
    "        if pred == test_label_loo:\n",
    "            tot += 1\n",
    "    print('raw: ',tot/(i+1))\n",
    "\n",
    "    tot = 0\n",
    "    for i in range(64):\n",
    "        pred = linearsvc.predict(three[i].reshape(1,-1))\n",
    "        if (pred == threel[i]):\n",
    "            tot+=1\n",
    "    \n",
    "    print('begin: ', tot/64)\n",
    "    \n",
    "    from scipy.stats import pearsonr\n",
    "    idx = [i for i in range(128)]\n",
    "    new_idx = []\n",
    "    false_idx = []\n",
    "    \n",
    "    idx = [i for i in range(128)]\n",
    "    for j in range(128):\n",
    "        temp = squeeze(one[:,:,j].reshape(-1))\n",
    "        test_data_list.append(squeeze(three[:,:,j]))\n",
    "        # print(temp.shape)\n",
    "        sim = []\n",
    "        for k in idx:\n",
    "            sim.append(pearsonr(squeeze(three[:,:,k].reshape(-1)), temp)[0])\n",
    "        sim_sorted = sorted(sim, reverse=True)\n",
    "        sim_id = sim.index(max(sim))\n",
    "        # if random.random() > 0.7:\n",
    "        #     sim_id = sim.index(sim_sorted[0])\n",
    "        # else:\n",
    "        #     sim_id = sim.index(sim_sorted[1])\n",
    "        # else:\n",
    "        #     sim_id = sim.index(sim_sorted[2])\n",
    "        # sim_id = sim.index(sim_sorted[0])\n",
    "        \n",
    "        new_idx.append(idx[sim_id])\n",
    "        sim_id = sim.index(sim_sorted[-1])\n",
    "        false_idx.append(idx[sim_id])\n",
    "        test_false_data_list.append(squeeze(three[:,:,idx[sim_id]]))\n",
    "        \n",
    "        # idx.remove(idx[sim_id])\n",
    "\n",
    "    changes = np.array(new_idx).reshape(128)\n",
    "    test_idx_list.append(changes)\n",
    "    new_three = three[:,:,changes]\n",
    "\n",
    "    changes = np.array(false_idx).reshape(128)\n",
    "    test_false_id_list.append(changes)\n",
    "    \n",
    "    tot = 0\n",
    "    for i in range(64):\n",
    "        pred = linearsvc.predict(new_three[i].reshape(1,-1))\n",
    "        if (pred == threel[i]):\n",
    "            tot+=1\n",
    "    \n",
    "    print('after: ',tot/64, len(np.unique(changes)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "e9cd5ea9-9ef8-4f7a-acd1-4107db54fb33",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T22:30:04.680714Z",
     "iopub.status.busy": "2024-11-03T22:30:04.673708Z",
     "iopub.status.idle": "2024-11-03T22:30:09.793810Z",
     "shell.execute_reply": "2024-11-03T22:30:09.793219Z",
     "shell.execute_reply.started": "2024-11-03T22:30:04.680503Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t12.2022.06.16_diagnosticBlocks.mat\n",
      "t12.2022.06.21_diagnosticBlocks.mat\n",
      "t12.2022.06.23_diagnosticBlocks.mat\n",
      "t12.2022.06.28_diagnosticBlocks.mat\n",
      "t12.2022.07.05_diagnosticBlocks.mat\n",
      "t12.2022.07.07_diagnosticBlocks.mat\n",
      "t12.2022.07.14_diagnosticBlocks.mat\n",
      "t12.2022.07.21_diagnosticBlocks.mat\n",
      "t12.2022.07.27_diagnosticBlocks.mat\n"
     ]
    }
   ],
   "source": [
    "print(file_list[0])\n",
    "data_list = []\n",
    "for d in range(1, len(file_list)):\n",
    "    print(file_list[d])\n",
    "    three, threel = get_data(file_list[d])\n",
    "    \n",
    "    from scipy.stats import pearsonr\n",
    "    idx = [i for i in range(128)]\n",
    "    new_idx = []\n",
    "    \n",
    "    idx = [i for i in range(128)]\n",
    "    for j in range(128):\n",
    "        temp = squeeze(three[:,:,j])\n",
    "\n",
    "        data_list.append(temp)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "5b136a63-f57a-498e-bb2e-a9d696a436d3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T23:38:44.521301Z",
     "iopub.status.busy": "2024-11-03T23:38:44.520231Z",
     "iopub.status.idle": "2024-11-03T23:38:44.529809Z",
     "shell.execute_reply": "2024-11-03T23:38:44.529186Z",
     "shell.execute_reply.started": "2024-11-03T23:38:44.521275Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1280, 320)"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(data_list).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "2078310a-e45e-460a-8da8-20f0c331ca66",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T23:35:51.367372Z",
     "iopub.status.busy": "2024-11-03T23:35:51.367106Z",
     "iopub.status.idle": "2024-11-03T23:35:51.370487Z",
     "shell.execute_reply": "2024-11-03T23:35:51.369963Z",
     "shell.execute_reply.started": "2024-11-03T23:35:51.367355Z"
    }
   },
   "outputs": [],
   "source": [
    "data_list.extend(false_data_list)\n",
    "\n",
    "idx_list.extend(false_id_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "1eec6dbb-1a8b-4c1d-8954-b0f9f11303fc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T23:35:51.942704Z",
     "iopub.status.busy": "2024-11-03T23:35:51.942375Z",
     "iopub.status.idle": "2024-11-03T23:35:51.946120Z",
     "shell.execute_reply": "2024-11-03T23:35:51.945672Z",
     "shell.execute_reply.started": "2024-11-03T23:35:51.942680Z"
    }
   },
   "outputs": [],
   "source": [
    "test_data_list.extend(test_false_data_list)\n",
    "\n",
    "test_idx_list.extend(test_false_id_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "27bde380-2b23-407f-86f4-3244ae758e81",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T23:37:53.583713Z",
     "iopub.status.busy": "2024-11-03T23:37:53.583466Z",
     "iopub.status.idle": "2024-11-03T23:37:53.587111Z",
     "shell.execute_reply": "2024-11-03T23:37:53.586616Z",
     "shell.execute_reply.started": "2024-11-03T23:37:53.583697Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6, 128)"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(test_idx_list).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "b89d65d9-53f4-47a2-9791-dd09dca179ac",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T23:38:50.781193Z",
     "iopub.status.busy": "2024-11-03T23:38:50.780965Z",
     "iopub.status.idle": "2024-11-03T23:38:50.784478Z",
     "shell.execute_reply": "2024-11-03T23:38:50.784081Z",
     "shell.execute_reply.started": "2024-11-03T23:38:50.781177Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "640.0"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1280/2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "51963fce-2227-4822-a92b-002068325925",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T23:38:08.352943Z",
     "iopub.status.busy": "2024-11-03T23:38:08.352654Z",
     "iopub.status.idle": "2024-11-03T23:38:08.359901Z",
     "shell.execute_reply": "2024-11-03T23:38:08.359388Z",
     "shell.execute_reply.started": "2024-11-03T23:38:08.352924Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(320,)"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_data_list[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "0dc7c90f-bf7c-4ce8-ba06-1464c0195d92",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T22:26:48.230058Z",
     "iopub.status.busy": "2024-11-03T22:26:48.229829Z",
     "iopub.status.idle": "2024-11-03T22:26:48.235835Z",
     "shell.execute_reply": "2024-11-03T22:26:48.235416Z",
     "shell.execute_reply.started": "2024-11-03T22:26:48.230042Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 41,  43,  87, ..., 103, 124,   3])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(idx_list).reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "75ef34fe-e513-4074-8869-2fe9b08b3e9c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T23:36:16.440637Z",
     "iopub.status.busy": "2024-11-03T23:36:16.439825Z",
     "iopub.status.idle": "2024-11-03T23:36:16.445652Z",
     "shell.execute_reply": "2024-11-03T23:36:16.445006Z",
     "shell.execute_reply.started": "2024-11-03T23:36:16.440595Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(idx_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "032a9585-defe-46fd-aede-aa91c4979c9c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-04T00:44:19.774842Z",
     "iopub.status.busy": "2024-11-04T00:44:19.774498Z",
     "iopub.status.idle": "2024-11-04T00:44:20.755027Z",
     "shell.execute_reply": "2024-11-04T00:44:20.754467Z",
     "shell.execute_reply.started": "2024-11-04T00:44:19.774815Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/50, Loss: 0.7891\n",
      "Epoch 12/50, Loss: 0.4880\n",
      "Epoch 22/50, Loss: 0.4797\n",
      "Epoch 32/50, Loss: 0.4282\n",
      "Epoch 42/50, Loss: 0.3419\n",
      "Average Loss: 0.5834\n",
      "Average acc: 0.7422\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset, random_split\n",
    "import random\n",
    "\n",
    "def set_seed(seed):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "# 设置一个固定的随机种子\n",
    "set_seed(42)\n",
    "\n",
    "class ChannelAlignmentModel(nn.Module):\n",
    "    def __init__(self, channel_count, signal_feature_size, embedding_dim):\n",
    "        super(ChannelAlignmentModel, self).__init__()\n",
    "        self.embedding = nn.Embedding(channel_count, embedding_dim)\n",
    "        self.signal_fc = nn.Linear(signal_feature_size, embedding_dim)\n",
    "        self.interaction_fc = nn.Linear(embedding_dim * 2, 1)\n",
    "\n",
    "    def forward(self, channel_idx, signal_features):\n",
    "        channel_embed = self.embedding(channel_idx)\n",
    "        signal_embed = torch.relu(self.signal_fc(signal_features))\n",
    "        # print(channel_embed.shape)\n",
    "        # print(signal_embed.shape)\n",
    "        interaction = torch.cat((channel_embed, signal_embed), dim=1)\n",
    "        match_score = torch.sigmoid(self.interaction_fc(interaction))\n",
    "        return match_score\n",
    "\n",
    "channel_count = 128 # 通道数量\n",
    "signal_feature_size = 320  # 信号特征维度\n",
    "embedding_dim = 128  # 嵌入维度\n",
    "\n",
    "channel_indices = torch.tensor(np.array(idx_list).reshape(-1))\n",
    "signal_features = torch.tensor(np.array(data_list), dtype=torch.float32)\n",
    "true = [1 for i in range(640)]\n",
    "false = [0 for i in range(640)]\n",
    "true.extend(false)\n",
    "y = torch.tensor(np.array(true), dtype=torch.float32)  # 匹配度\n",
    "\n",
    "train_dataset = TensorDataset(channel_indices, signal_features, y)\n",
    "\n",
    "channel_indices = torch.tensor(np.array(test_idx_list).reshape(-1))\n",
    "signal_features = torch.tensor(np.array(test_data_list), dtype=torch.float32)\n",
    "true = [1 for i in range(384)]\n",
    "false = [0 for i in range(384)]\n",
    "true.extend(false)\n",
    "test_y = torch.tensor(np.array(true), dtype=torch.float32)  # 匹配度\n",
    "\n",
    "test_dataset = TensorDataset(channel_indices, signal_features, test_y)\n",
    "# train_size = int(0.8 * len(dataset))\n",
    "# test_size = len(dataset) - train_size\n",
    "# train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
    "\n",
    "model = ChannelAlignmentModel(channel_count, signal_feature_size, embedding_dim)\n",
    "criterion = nn.BCELoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "def train(model, train_loader, criterion, optimizer, epochs=50):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        for channel_idx, signal_features, labels in train_loader:\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(channel_idx, signal_features)\n",
    "            loss = criterion(outputs.reshape(-1), labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        if epoch % 10 == 1:\n",
    "            print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}')\n",
    "\n",
    "# 测试模型\n",
    "def test(model, test_loader):\n",
    "    model.eval()\n",
    "    total_loss = 0\n",
    "    tot = 0\n",
    "    tot_all = 0\n",
    "    with torch.no_grad():\n",
    "        for channel_idx, signal_features, labels in test_loader:\n",
    "            outputs = model(channel_idx, signal_features)\n",
    "            loss = criterion(outputs.reshape(-1), labels)\n",
    "            total_loss += loss.item()\n",
    "            tot_all += len(labels)\n",
    "            for i in range(len(labels)):\n",
    "                if np.round(outputs.reshape(-1)[i]) == labels[i]:\n",
    "                    # print(np.round(outputs.reshape(-1)[i]), labels[i])\n",
    "                    tot  += 1\n",
    "    print(f'Average Loss: {total_loss / len(test_loader):.4f}')\n",
    "    print(f'Average acc: {tot / tot_all:.4f}')\n",
    "\n",
    "\n",
    "# 运行训练和测试\n",
    "train(model, train_loader, criterion, optimizer)\n",
    "test(model, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "178191b7-40b5-4314-adcd-eb9962bc4e6a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-03T22:22:55.346728Z",
     "iopub.status.busy": "2024-11-03T22:22:55.346391Z",
     "iopub.status.idle": "2024-11-03T22:22:55.358870Z",
     "shell.execute_reply": "2024-11-03T22:22:55.358325Z",
     "shell.execute_reply.started": "2024-11-03T22:22:55.346700Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "channel_indices = torch.randint(0, 128, (1000,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "id": "1c7144a6-137f-4340-a448-fbe69133e80c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-04T01:01:32.693235Z",
     "iopub.status.busy": "2024-11-04T01:01:32.693003Z",
     "iopub.status.idle": "2024-11-04T01:01:32.749654Z",
     "shell.execute_reply": "2024-11-04T01:01:32.749203Z",
     "shell.execute_reply.started": "2024-11-04T01:01:32.693218Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Loss: 0.5834\n",
      "Average acc: 0.7422\n"
     ]
    }
   ],
   "source": [
    "def test(model, test_loader):\n",
    "    model.eval()\n",
    "    total_loss = 0\n",
    "    tot = 0\n",
    "    tot_all = 0\n",
    "    with torch.no_grad():\n",
    "        for channel_idx, signal_features, labels in test_loader:\n",
    "            outputs = model(channel_idx, signal_features)\n",
    "            loss = criterion(outputs.reshape(-1), labels)\n",
    "            total_loss += loss.item()\n",
    "            tot_all += len(labels)\n",
    "            for i in range(len(labels)):\n",
    "                if np.round(outputs.reshape(-1)[i]) == labels[i]:\n",
    "                    # print(np.round(outputs.reshape(-1)[i]), labels[i])\n",
    "                    tot  += 1\n",
    "    print(f'Average Loss: {total_loss / len(test_loader):.4f}')\n",
    "    print(f'Average acc: {tot / tot_all:.4f}')\n",
    "\n",
    "\n",
    "test(model, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "ec699b92-7c21-4d96-ae36-325de845f64f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-04T00:28:53.217830Z",
     "iopub.status.busy": "2024-11-04T00:28:53.217533Z",
     "iopub.status.idle": "2024-11-04T00:28:53.224195Z",
     "shell.execute_reply": "2024-11-04T00:28:53.223487Z",
     "shell.execute_reply.started": "2024-11-04T00:28:53.217804Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(768, 320)"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(test_data_list).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "7fa8786c-6731-4c8f-9589-fcf72a4230e4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-04T00:32:31.042226Z",
     "iopub.status.busy": "2024-11-04T00:32:31.042037Z",
     "iopub.status.idle": "2024-11-04T00:32:31.045483Z",
     "shell.execute_reply": "2024-11-04T00:32:31.045086Z",
     "shell.execute_reply.started": "2024-11-04T00:32:31.042201Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6.0"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "768/128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "f4ddf0d9-b685-4be5-9218-4ff303e10cc0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-04T11:21:36.630396Z",
     "iopub.status.busy": "2024-11-04T11:21:36.630042Z",
     "iopub.status.idle": "2024-11-04T11:21:41.762492Z",
     "shell.execute_reply": "2024-11-04T11:21:41.762009Z",
     "shell.execute_reply.started": "2024-11-04T11:21:36.630370Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([[0.7800]]), tensor([[0.3916]]), tensor([[0.7331]]), tensor([[0.9738]]), tensor([[0.9579]]), tensor([[0.7606]]), tensor([[0.7900]]), tensor([[0.8009]]), tensor([[0.8128]]), tensor([[0.6912]])]\n",
      "[tensor([[0.8432]]), tensor([[0.4940]]), tensor([[0.8064]]), tensor([[0.9826]]), tensor([[0.9719]]), tensor([[0.8282]]), tensor([[0.8509]]), tensor([[0.8592]]), tensor([[0.8682]]), tensor([[0.7724]])]\n",
      "[tensor([[0.3487]]), tensor([[0.0886]]), tensor([[0.2932]]), tensor([[0.8490]]), tensor([[0.7746]]), tensor([[0.3243]]), tensor([[0.3624]]), tensor([[0.3780]]), tensor([[0.3960]]), tensor([[0.2526]])]\n",
      "[tensor([[0.8002]]), tensor([[0.4210]]), tensor([[0.7563]]), tensor([[0.9768]]), tensor([[0.9626]]), tensor([[0.7821]]), tensor([[0.8095]]), tensor([[0.8197]]), tensor([[0.8307]]), tensor([[0.7166]])]\n",
      "[tensor([[0.9161]]), tensor([[0.6648]]), tensor([[0.8943]]), tensor([[0.9914]]), tensor([[0.9859]]), tensor([[0.9073]]), tensor([[0.9206]]), tensor([[0.9254]]), tensor([[0.9305]]), tensor([[0.8733]])]\n",
      "[tensor([[0.7679]]), tensor([[0.3753]]), tensor([[0.7194]]), tensor([[0.9720]]), tensor([[0.9550]]), tensor([[0.7478]]), tensor([[0.7784]]), tensor([[0.7897]]), tensor([[0.8021]]), tensor([[0.6762]])]\n",
      "[tensor([[0.5542]]), tensor([[0.1841]]), tensor([[0.4906]]), tensor([[0.9289]]), tensor([[0.8887]]), tensor([[0.5270]]), tensor([[0.5689]]), tensor([[0.5852]]), tensor([[0.6036]]), tensor([[0.4397]])]\n",
      "[tensor([[0.9098]]), tensor([[0.6468]]), tensor([[0.8866]]), tensor([[0.9907]]), tensor([[0.9848]]), tensor([[0.9004]]), tensor([[0.9146]]), tensor([[0.9197]]), tensor([[0.9251]]), tensor([[0.8643]])]\n",
      "[tensor([[0.1605]]), tensor([[0.0335]]), tensor([[0.1290]]), tensor([[0.6675]]), tensor([[0.5510]]), tensor([[0.1463]]), tensor([[0.1686]]), tensor([[0.1782]]), tensor([[0.1897]]), tensor([[0.1077]])]\n",
      "[tensor([[0.4325]]), tensor([[0.1215]]), tensor([[0.3712]]), tensor([[0.8889]]), tensor([[0.8303]]), tensor([[0.4058]]), tensor([[0.4472]]), tensor([[0.4638]]), tensor([[0.4828]]), tensor([[0.3248]])]\n",
      "[tensor([[0.8871]]), tensor([[0.5880]]), tensor([[0.8590]]), tensor([[0.9880]]), tensor([[0.9806]]), tensor([[0.8757]]), tensor([[0.8930]]), tensor([[0.8992]]), tensor([[0.9059]]), tensor([[0.8323]])]\n",
      "[tensor([[0.9573]]), tensor([[0.8026]]), tensor([[0.9455]]), tensor([[0.9958]]), tensor([[0.9931]]), tensor([[0.9526]]), tensor([[0.9596]]), tensor([[0.9622]]), tensor([[0.9648]]), tensor([[0.9340]])]\n",
      "[tensor([[0.7692]]), tensor([[0.3769]]), tensor([[0.7208]]), tensor([[0.9722]]), tensor([[0.9553]]), tensor([[0.7492]]), tensor([[0.7796]]), tensor([[0.7909]]), tensor([[0.8032]]), tensor([[0.6778]])]\n",
      "[tensor([[0.9496]]), tensor([[0.7739]]), tensor([[0.9359]]), tensor([[0.9950]]), tensor([[0.9918]]), tensor([[0.9441]]), tensor([[0.9524]]), tensor([[0.9554]]), tensor([[0.9585]]), tensor([[0.9225]])]\n",
      "[tensor([[0.8134]]), tensor([[0.4417]]), tensor([[0.7715]]), tensor([[0.9786]]), tensor([[0.9655]]), tensor([[0.7962]]), tensor([[0.8222]]), tensor([[0.8318]]), tensor([[0.8422]]), tensor([[0.7334]])]\n",
      "[tensor([[0.9073]]), tensor([[0.6398]]), tensor([[0.8834]]), tensor([[0.9904]]), tensor([[0.9843]]), tensor([[0.8976]]), tensor([[0.9122]]), tensor([[0.9174]]), tensor([[0.9230]]), tensor([[0.8606]])]\n",
      "[tensor([[0.4760]]), tensor([[0.1416]]), tensor([[0.4131]]), tensor([[0.9051]]), tensor([[0.8536]]), tensor([[0.4488]]), tensor([[0.4909]]), tensor([[0.5076]]), tensor([[0.5267]]), tensor([[0.3644]])]\n",
      "[tensor([[0.2220]]), tensor([[0.0493]]), tensor([[0.1811]]), tensor([[0.7499]]), tensor([[0.6469]]), tensor([[0.2037]]), tensor([[0.2325]]), tensor([[0.2447]]), tensor([[0.2590]]), tensor([[0.1527]])]\n",
      "[tensor([[0.4762]]), tensor([[0.1417]]), tensor([[0.4133]]), tensor([[0.9052]]), tensor([[0.8537]]), tensor([[0.4490]]), tensor([[0.4911]]), tensor([[0.5078]]), tensor([[0.5269]]), tensor([[0.3646]])]\n",
      "[tensor([[0.6440]]), tensor([[0.2472]]), tensor([[0.5836]]), tensor([[0.9500]]), tensor([[0.9207]]), tensor([[0.6185]]), tensor([[0.6575]]), tensor([[0.6725]]), tensor([[0.6890]]), tensor([[0.5331]])]\n",
      "[tensor([[0.9302]]), tensor([[0.7076]]), tensor([[0.9117]]), tensor([[0.9929]]), tensor([[0.9884]]), tensor([[0.9228]]), tensor([[0.9340]]), tensor([[0.9380]]), tensor([[0.9423]]), tensor([[0.8938]])]\n",
      "[tensor([[0.8130]]), tensor([[0.4411]]), tensor([[0.7710]]), tensor([[0.9786]]), tensor([[0.9654]]), tensor([[0.7957]]), tensor([[0.8218]]), tensor([[0.8315]]), tensor([[0.8419]]), tensor([[0.7329]])]\n",
      "[tensor([[0.3802]]), tensor([[0.1002]]), tensor([[0.3221]]), tensor([[0.8656]]), tensor([[0.7975]]), tensor([[0.3548]]), tensor([[0.3943]]), tensor([[0.4104]]), tensor([[0.4290]]), tensor([[0.2791]])]\n",
      "[tensor([[0.1368]]), tensor([[0.0280]]), tensor([[0.1094]]), tensor([[0.6247]]), tensor([[0.5043]]), tensor([[0.1244]]), tensor([[0.1440]]), tensor([[0.1524]]), tensor([[0.1626]]), tensor([[0.0909]])]\n",
      "[tensor([[0.2913]]), tensor([[0.0694]]), tensor([[0.2415]]), tensor([[0.8119]]), tensor([[0.7252]]), tensor([[0.2692]]), tensor([[0.3038]]), tensor([[0.3181]]), tensor([[0.3349]]), tensor([[0.2060]])]\n",
      "[tensor([[0.5036]]), tensor([[0.1555]]), tensor([[0.4400]]), tensor([[0.9142]]), tensor([[0.8669]]), tensor([[0.4762]]), tensor([[0.5184]]), tensor([[0.5351]]), tensor([[0.5540]]), tensor([[0.3903]])]\n",
      "[tensor([[0.4670]]), tensor([[0.1372]]), tensor([[0.4043]]), tensor([[0.9020]]), tensor([[0.8490]]), tensor([[0.4398]]), tensor([[0.4818]]), tensor([[0.4986]]), tensor([[0.5176]]), tensor([[0.3561]])]\n",
      "[tensor([[0.3718]]), tensor([[0.0970]]), tensor([[0.3143]]), tensor([[0.8614]]), tensor([[0.7916]]), tensor([[0.3466]]), tensor([[0.3858]]), tensor([[0.4018]]), tensor([[0.4202]]), tensor([[0.2720]])]\n",
      "[tensor([[0.3935]]), tensor([[0.1054]]), tensor([[0.3345]]), tensor([[0.8720]]), tensor([[0.8064]]), tensor([[0.3676]]), tensor([[0.4077]]), tensor([[0.4240]]), tensor([[0.4427]]), tensor([[0.2905]])]\n",
      "[tensor([[0.8967]]), tensor([[0.6119]]), tensor([[0.8706]]), tensor([[0.9892]]), tensor([[0.9824]]), tensor([[0.8862]]), tensor([[0.9021]]), tensor([[0.9079]]), tensor([[0.9141]]), tensor([[0.8457]])]\n",
      "[tensor([[0.3682]]), tensor([[0.0957]]), tensor([[0.3111]]), tensor([[0.8596]]), tensor([[0.7891]]), tensor([[0.3431]]), tensor([[0.3822]]), tensor([[0.3981]]), tensor([[0.4165]]), tensor([[0.2689]])]\n",
      "[tensor([[0.6644]]), tensor([[0.2644]]), tensor([[0.6053]]), tensor([[0.9541]]), tensor([[0.9271]]), tensor([[0.6396]]), tensor([[0.6776]]), tensor([[0.6920]]), tensor([[0.7080]]), tensor([[0.5555]])]\n",
      "[tensor([[0.0559]]), tensor([[0.0106]]), tensor([[0.0439]]), tensor([[0.3834]]), tensor([[0.2754]]), tensor([[0.0504]]), tensor([[0.0591]]), tensor([[0.0630]]), tensor([[0.0676]]), tensor([[0.0360]])]\n",
      "[tensor([[0.6183]]), tensor([[0.2273]]), tensor([[0.5566]]), tensor([[0.9445]]), tensor([[0.9123]]), tensor([[0.5922]]), tensor([[0.6323]]), tensor([[0.6477]]), tensor([[0.6649]]), tensor([[0.5056]])]\n",
      "[tensor([[0.1567]]), tensor([[0.0326]]), tensor([[0.1259]]), tensor([[0.6613]]), tensor([[0.5441]]), tensor([[0.1428]]), tensor([[0.1648]]), tensor([[0.1742]]), tensor([[0.1854]]), tensor([[0.1050]])]\n",
      "[tensor([[0.7814]]), tensor([[0.3935]]), tensor([[0.7346]]), tensor([[0.9740]]), tensor([[0.9582]]), tensor([[0.7621]]), tensor([[0.7914]]), tensor([[0.8022]]), tensor([[0.8140]]), tensor([[0.6928]])]\n",
      "[tensor([[0.0753]]), tensor([[0.0146]]), tensor([[0.0593]]), tensor([[0.4609]]), tensor([[0.3432]]), tensor([[0.0680]]), tensor([[0.0795]]), tensor([[0.0846]]), tensor([[0.0907]]), tensor([[0.0489]])]\n",
      "[tensor([[0.4769]]), tensor([[0.1420]]), tensor([[0.4140]]), tensor([[0.9055]]), tensor([[0.8541]]), tensor([[0.4497]]), tensor([[0.4918]]), tensor([[0.5086]]), tensor([[0.5276]]), tensor([[0.3653]])]\n",
      "[tensor([[0.7124]]), tensor([[0.3102]]), tensor([[0.6574]]), tensor([[0.9630]]), tensor([[0.9408]]), tensor([[0.6894]]), tensor([[0.7244]]), tensor([[0.7376]]), tensor([[0.7521]]), tensor([[0.6099]])]\n",
      "[tensor([[0.2339]]), tensor([[0.0525]]), tensor([[0.1913]]), tensor([[0.7623]]), tensor([[0.6622]]), tensor([[0.2149]]), tensor([[0.2447]]), tensor([[0.2573]]), tensor([[0.2722]]), tensor([[0.1616]])]\n",
      "[tensor([[0.9228]]), tensor([[0.6845]]), tensor([[0.9025]]), tensor([[0.9921]]), tensor([[0.9871]]), tensor([[0.9146]]), tensor([[0.9269]]), tensor([[0.9313]]), tensor([[0.9361]]), tensor([[0.8830]])]\n",
      "[tensor([[0.6078]]), tensor([[0.2195]]), tensor([[0.5456]]), tensor([[0.9421]]), tensor([[0.9087]]), tensor([[0.5814]]), tensor([[0.6219]]), tensor([[0.6375]]), tensor([[0.6549]]), tensor([[0.4945]])]\n",
      "[tensor([[0.1743]]), tensor([[0.0369]]), tensor([[0.1405]]), tensor([[0.6892]]), tensor([[0.5754]]), tensor([[0.1591]]), tensor([[0.1830]]), tensor([[0.1933]]), tensor([[0.2054]]), tensor([[0.1176]])]\n",
      "[tensor([[0.8093]]), tensor([[0.4352]]), tensor([[0.7668]]), tensor([[0.9781]]), tensor([[0.9646]]), tensor([[0.7918]]), tensor([[0.8183]]), tensor([[0.8281]]), tensor([[0.8387]]), tensor([[0.7282]])]\n",
      "[tensor([[0.8903]]), tensor([[0.5956]]), tensor([[0.8627]]), tensor([[0.9884]]), tensor([[0.9812]]), tensor([[0.8791]]), tensor([[0.8959]]), tensor([[0.9020]]), tensor([[0.9086]]), tensor([[0.8366]])]\n",
      "[tensor([[0.4904]]), tensor([[0.1487]]), tensor([[0.4271]]), tensor([[0.9100]]), tensor([[0.8607]]), tensor([[0.4631]]), tensor([[0.5053]]), tensor([[0.5220]]), tensor([[0.5410]]), tensor([[0.3779]])]\n",
      "[tensor([[0.8467]]), tensor([[0.5007]]), tensor([[0.8106]]), tensor([[0.9831]]), tensor([[0.9726]]), tensor([[0.8320]]), tensor([[0.8543]]), tensor([[0.8624]]), tensor([[0.8712]]), tensor([[0.7771]])]\n",
      "[tensor([[0.0845]]), tensor([[0.0165]]), tensor([[0.0667]]), tensor([[0.4922]]), tensor([[0.3720]]), tensor([[0.0764]]), tensor([[0.0892]]), tensor([[0.0948]]), tensor([[0.1015]]), tensor([[0.0550]])]\n",
      "[tensor([[0.6430]]), tensor([[0.2464]]), tensor([[0.5825]]), tensor([[0.9498]]), tensor([[0.9204]]), tensor([[0.6175]]), tensor([[0.6565]]), tensor([[0.6715]]), tensor([[0.6881]]), tensor([[0.5320]])]\n",
      "[tensor([[0.5475]]), tensor([[0.1801]]), tensor([[0.4838]]), tensor([[0.9270]]), tensor([[0.8859]]), tensor([[0.5202]]), tensor([[0.5622]]), tensor([[0.5786]]), tensor([[0.5971]]), tensor([[0.4330]])]\n",
      "[tensor([[0.5176]]), tensor([[0.1630]]), tensor([[0.4539]]), tensor([[0.9185]]), tensor([[0.8732]]), tensor([[0.4902]]), tensor([[0.5324]]), tensor([[0.5491]]), tensor([[0.5679]]), tensor([[0.4038]])]\n",
      "[tensor([[0.5671]]), tensor([[0.1921]]), tensor([[0.5036]]), tensor([[0.9322]]), tensor([[0.8937]]), tensor([[0.5400]]), tensor([[0.5816]]), tensor([[0.5978]]), tensor([[0.6160]]), tensor([[0.4526]])]\n",
      "[tensor([[0.7553]]), tensor([[0.3592]]), tensor([[0.7052]]), tensor([[0.9701]]), tensor([[0.9520]]), tensor([[0.7345]]), tensor([[0.7662]]), tensor([[0.7780]]), tensor([[0.7908]]), tensor([[0.6609]])]\n",
      "[tensor([[0.6063]]), tensor([[0.2185]]), tensor([[0.5440]]), tensor([[0.9418]]), tensor([[0.9081]]), tensor([[0.5799]]), tensor([[0.6204]]), tensor([[0.6360]]), tensor([[0.6535]]), tensor([[0.4929]])]\n",
      "[tensor([[0.2807]]), tensor([[0.0662]]), tensor([[0.2321]]), tensor([[0.8039]]), tensor([[0.7147]]), tensor([[0.2591]]), tensor([[0.2929]]), tensor([[0.3070]]), tensor([[0.3234]]), tensor([[0.1976]])]\n",
      "[tensor([[0.4689]]), tensor([[0.1382]]), tensor([[0.4062]]), tensor([[0.9027]]), tensor([[0.8500]]), tensor([[0.4418]]), tensor([[0.4838]]), tensor([[0.5005]]), tensor([[0.5196]]), tensor([[0.3579]])]\n",
      "[tensor([[0.3277]]), tensor([[0.0813]]), tensor([[0.2741]]), tensor([[0.8366]]), tensor([[0.7578]]), tensor([[0.3040]]), tensor([[0.3409]]), tensor([[0.3562]]), tensor([[0.3738]]), tensor([[0.2353]])]\n",
      "[tensor([[0.3956]]), tensor([[0.1062]]), tensor([[0.3365]]), tensor([[0.8730]]), tensor([[0.8078]]), tensor([[0.3697]]), tensor([[0.4099]]), tensor([[0.4262]]), tensor([[0.4449]]), tensor([[0.2924]])]\n",
      "[tensor([[0.6716]]), tensor([[0.2708]]), tensor([[0.6131]]), tensor([[0.9555]]), tensor([[0.9292]]), tensor([[0.6470]]), tensor([[0.6846]]), tensor([[0.6989]]), tensor([[0.7147]]), tensor([[0.5635]])]\n",
      "[tensor([[0.5592]]), tensor([[0.1872]]), tensor([[0.4957]]), tensor([[0.9302]]), tensor([[0.8907]]), tensor([[0.5321]]), tensor([[0.5739]]), tensor([[0.5902]]), tensor([[0.6085]]), tensor([[0.4447]])]\n",
      "[tensor([[0.9521]]), tensor([[0.7831]]), tensor([[0.9391]]), tensor([[0.9952]]), tensor([[0.9922]]), tensor([[0.9469]]), tensor([[0.9548]]), tensor([[0.9576]]), tensor([[0.9606]]), tensor([[0.9262]])]\n",
      "[tensor([[0.0609]]), tensor([[0.0116]]), tensor([[0.0479]]), tensor([[0.4053]]), tensor([[0.2941]]), tensor([[0.0550]]), tensor([[0.0644]]), tensor([[0.0686]]), tensor([[0.0736]]), tensor([[0.0393]])]\n",
      "[tensor([[0.1498]]), tensor([[0.0310]]), tensor([[0.1201]]), tensor([[0.6492]]), tensor([[0.5308]]), tensor([[0.1364]]), tensor([[0.1575]]), tensor([[0.1666]]), tensor([[0.1775]]), tensor([[0.1001]])]\n",
      "[tensor([[0.5438]]), tensor([[0.1779]]), tensor([[0.4801]]), tensor([[0.9260]]), tensor([[0.8844]]), tensor([[0.5165]]), tensor([[0.5585]]), tensor([[0.5749]]), tensor([[0.5934]]), tensor([[0.4293]])]\n",
      "[tensor([[0.0536]]), tensor([[0.0102]]), tensor([[0.0420]]), tensor([[0.3728]]), tensor([[0.2665]]), tensor([[0.0483]]), tensor([[0.0567]]), tensor([[0.0604]]), tensor([[0.0648]]), tensor([[0.0345]])]\n",
      "[tensor([[0.3977]]), tensor([[0.1071]]), tensor([[0.3384]]), tensor([[0.8740]]), tensor([[0.8091]]), tensor([[0.3718]]), tensor([[0.4121]]), tensor([[0.4284]]), tensor([[0.4471]]), tensor([[0.2942]])]\n",
      "[tensor([[0.7941]]), tensor([[0.4119]]), tensor([[0.7493]]), tensor([[0.9759]]), tensor([[0.9612]]), tensor([[0.7757]]), tensor([[0.8037]]), tensor([[0.8141]]), tensor([[0.8253]]), tensor([[0.7089]])]\n",
      "[tensor([[0.1389]]), tensor([[0.0284]]), tensor([[0.1110]]), tensor([[0.6287]]), tensor([[0.5087]]), tensor([[0.1263]]), tensor([[0.1461]]), tensor([[0.1547]]), tensor([[0.1649]]), tensor([[0.0924]])]\n",
      "[tensor([[0.3078]]), tensor([[0.0747]]), tensor([[0.2562]]), tensor([[0.8236]]), tensor([[0.7406]]), tensor([[0.2850]]), tensor([[0.3206]]), tensor([[0.3354]]), tensor([[0.3526]]), tensor([[0.2192]])]\n",
      "[tensor([[0.6638]]), tensor([[0.2639]]), tensor([[0.6047]]), tensor([[0.9540]]), tensor([[0.9269]]), tensor([[0.6389]]), tensor([[0.6770]]), tensor([[0.6914]]), tensor([[0.7074]]), tensor([[0.5548]])]\n",
      "[tensor([[0.5799]]), tensor([[0.2004]]), tensor([[0.5168]]), tensor([[0.9355]]), tensor([[0.8986]]), tensor([[0.5530]]), tensor([[0.5943]]), tensor([[0.6104]]), tensor([[0.6283]]), tensor([[0.4656]])]\n",
      "[tensor([[0.9152]]), tensor([[0.6622]]), tensor([[0.8932]]), tensor([[0.9913]]), tensor([[0.9858]]), tensor([[0.9063]]), tensor([[0.9197]]), tensor([[0.9245]]), tensor([[0.9297]]), tensor([[0.8720]])]\n",
      "[tensor([[0.4145]]), tensor([[0.1139]]), tensor([[0.3542]]), tensor([[0.8814]]), tensor([[0.8196]]), tensor([[0.3882]]), tensor([[0.4290]]), tensor([[0.4455]]), tensor([[0.4644]]), tensor([[0.3088]])]\n",
      "[tensor([[0.6734]]), tensor([[0.2724]]), tensor([[0.6150]]), tensor([[0.9559]]), tensor([[0.9298]]), tensor([[0.6489]]), tensor([[0.6864]]), tensor([[0.7006]]), tensor([[0.7164]]), tensor([[0.5655]])]\n",
      "[tensor([[0.3975]]), tensor([[0.1070]]), tensor([[0.3383]]), tensor([[0.8739]]), tensor([[0.8090]]), tensor([[0.3716]]), tensor([[0.4119]]), tensor([[0.4282]]), tensor([[0.4469]]), tensor([[0.2940]])]\n",
      "[tensor([[0.7047]]), tensor([[0.3022]]), tensor([[0.6489]]), tensor([[0.9616]]), tensor([[0.9387]]), tensor([[0.6814]]), tensor([[0.7169]]), tensor([[0.7303]]), tensor([[0.7450]]), tensor([[0.6010]])]\n",
      "[tensor([[0.5670]]), tensor([[0.1921]]), tensor([[0.5036]]), tensor([[0.9322]]), tensor([[0.8937]]), tensor([[0.5399]]), tensor([[0.5815]]), tensor([[0.5978]]), tensor([[0.6159]]), tensor([[0.4525]])]\n",
      "[tensor([[0.4398]]), tensor([[0.1247]]), tensor([[0.3782]]), tensor([[0.8918]]), tensor([[0.8344]]), tensor([[0.4130]]), tensor([[0.4545]]), tensor([[0.4712]]), tensor([[0.4902]]), tensor([[0.3313]])]\n",
      "[tensor([[0.9372]]), tensor([[0.7304]]), tensor([[0.9204]]), tensor([[0.9937]]), tensor([[0.9897]]), tensor([[0.9304]]), tensor([[0.9406]]), tensor([[0.9442]]), tensor([[0.9481]]), tensor([[0.9040]])]\n",
      "[tensor([[0.6691]]), tensor([[0.2685]]), tensor([[0.6104]]), tensor([[0.9550]]), tensor([[0.9285]]), tensor([[0.6444]]), tensor([[0.6821]]), tensor([[0.6965]]), tensor([[0.7123]]), tensor([[0.5607]])]\n",
      "[tensor([[0.0945]]), tensor([[0.0186]]), tensor([[0.0748]]), tensor([[0.5228]]), tensor([[0.4011]]), tensor([[0.0855]]), tensor([[0.0997]]), tensor([[0.1058]]), tensor([[0.1133]]), tensor([[0.0618]])]\n",
      "[tensor([[0.1336]]), tensor([[0.0272]]), tensor([[0.1067]]), tensor([[0.6183]]), tensor([[0.4975]]), tensor([[0.1214]]), tensor([[0.1407]]), tensor([[0.1490]]), tensor([[0.1589]]), tensor([[0.0887]])]\n",
      "[tensor([[0.8562]]), tensor([[0.5195]]), tensor([[0.8219]]), tensor([[0.9843]]), tensor([[0.9745]]), tensor([[0.8422]]), tensor([[0.8634]]), tensor([[0.8711]]), tensor([[0.8794]]), tensor([[0.7899]])]\n",
      "[tensor([[0.1296]]), tensor([[0.0263]]), tensor([[0.1034]]), tensor([[0.6100]]), tensor([[0.4887]]), tensor([[0.1177]]), tensor([[0.1365]]), tensor([[0.1446]]), tensor([[0.1542]]), tensor([[0.0859]])]\n",
      "[tensor([[0.0777]]), tensor([[0.0151]]), tensor([[0.0613]]), tensor([[0.4695]]), tensor([[0.3510]]), tensor([[0.0702]]), tensor([[0.0821]]), tensor([[0.0873]]), tensor([[0.0935]]), tensor([[0.0505]])]\n",
      "[tensor([[0.8624]]), tensor([[0.5323]]), tensor([[0.8292]]), tensor([[0.9850]]), tensor([[0.9758]]), tensor([[0.8489]]), tensor([[0.8693]]), tensor([[0.8768]]), tensor([[0.8848]]), tensor([[0.7983]])]\n",
      "[tensor([[0.4971]]), tensor([[0.1521]]), tensor([[0.4336]]), tensor([[0.9121]]), tensor([[0.8638]]), tensor([[0.4697]]), tensor([[0.5119]]), tensor([[0.5287]]), tensor([[0.5476]]), tensor([[0.3842]])]\n",
      "[tensor([[0.8477]]), tensor([[0.5026]]), tensor([[0.8118]]), tensor([[0.9832]]), tensor([[0.9728]]), tensor([[0.8330]]), tensor([[0.8553]]), tensor([[0.8634]]), tensor([[0.8721]]), tensor([[0.7785]])]\n",
      "[tensor([[0.5083]]), tensor([[0.1580]]), tensor([[0.4448]]), tensor([[0.9157]]), tensor([[0.8691]]), tensor([[0.4810]]), tensor([[0.5232]]), tensor([[0.5399]]), tensor([[0.5588]]), tensor([[0.3949]])]\n",
      "[tensor([[0.3219]]), tensor([[0.0793]]), tensor([[0.2689]]), tensor([[0.8329]]), tensor([[0.7529]]), tensor([[0.2985]]), tensor([[0.3350]]), tensor([[0.3501]]), tensor([[0.3676]]), tensor([[0.2306]])]\n",
      "[tensor([[0.1117]]), tensor([[0.0223]]), tensor([[0.0888]]), tensor([[0.5692]]), tensor([[0.4468]]), tensor([[0.1013]]), tensor([[0.1178]]), tensor([[0.1249]]), tensor([[0.1335]]), tensor([[0.0736]])]\n",
      "[tensor([[0.8343]]), tensor([[0.4776]]), tensor([[0.7960]]), tensor([[0.9814]]), tensor([[0.9700]]), tensor([[0.8186]]), tensor([[0.8424]]), tensor([[0.8511]]), tensor([[0.8605]]), tensor([[0.7607]])]\n",
      "[tensor([[0.8150]]), tensor([[0.4444]]), tensor([[0.7734]]), tensor([[0.9788]]), tensor([[0.9659]]), tensor([[0.7979]]), tensor([[0.8238]]), tensor([[0.8334]]), tensor([[0.8437]]), tensor([[0.7355]])]\n",
      "[tensor([[0.2374]]), tensor([[0.0535]]), tensor([[0.1943]]), tensor([[0.7657]]), tensor([[0.6665]]), tensor([[0.2181]]), tensor([[0.2483]]), tensor([[0.2610]]), tensor([[0.2760]]), tensor([[0.1642]])]\n",
      "[tensor([[0.9783]]), tensor([[0.8909]]), tensor([[0.9721]]), tensor([[0.9979]]), tensor([[0.9966]]), tensor([[0.9758]]), tensor([[0.9795]]), tensor([[0.9808]]), tensor([[0.9822]]), tensor([[0.9660]])]\n",
      "[tensor([[0.8099]]), tensor([[0.4362]]), tensor([[0.7675]]), tensor([[0.9781]]), tensor([[0.9647]]), tensor([[0.7925]]), tensor([[0.8189]]), tensor([[0.8287]]), tensor([[0.8392]]), tensor([[0.7290]])]\n",
      "[tensor([[0.3776]]), tensor([[0.0992]]), tensor([[0.3198]]), tensor([[0.8644]]), tensor([[0.7957]]), tensor([[0.3523]]), tensor([[0.3917]]), tensor([[0.4078]]), tensor([[0.4263]]), tensor([[0.2769]])]\n",
      "[tensor([[0.3411]]), tensor([[0.0859]]), tensor([[0.2862]]), tensor([[0.8446]]), tensor([[0.7687]]), tensor([[0.3169]]), tensor([[0.3546]]), tensor([[0.3701]]), tensor([[0.3880]]), tensor([[0.2463]])]\n",
      "[tensor([[0.8797]]), tensor([[0.5705]]), tensor([[0.8500]]), tensor([[0.9872]]), tensor([[0.9792]]), tensor([[0.8677]]), tensor([[0.8859]]), tensor([[0.8925]]), tensor([[0.8996]]), tensor([[0.8220]])]\n",
      "[tensor([[0.6845]]), tensor([[0.2826]]), tensor([[0.6270]]), tensor([[0.9580]]), tensor([[0.9330]]), tensor([[0.6604]]), tensor([[0.6973]]), tensor([[0.7112]]), tensor([[0.7266]]), tensor([[0.5780]])]\n",
      "[tensor([[0.6369]]), tensor([[0.2415]]), tensor([[0.5761]]), tensor([[0.9485]]), tensor([[0.9184]]), tensor([[0.6112]]), tensor([[0.6506]]), tensor([[0.6656]]), tensor([[0.6824]]), tensor([[0.5255]])]\n",
      "[tensor([[0.5572]]), tensor([[0.1860]]), tensor([[0.4937]]), tensor([[0.9297]]), tensor([[0.8899]]), tensor([[0.5301]]), tensor([[0.5719]]), tensor([[0.5882]]), tensor([[0.6065]]), tensor([[0.4427]])]\n",
      "[tensor([[0.9291]]), tensor([[0.7039]]), tensor([[0.9103]]), tensor([[0.9928]]), tensor([[0.9882]]), tensor([[0.9215]]), tensor([[0.9329]]), tensor([[0.9370]]), tensor([[0.9413]]), tensor([[0.8921]])]\n",
      "[tensor([[0.9715]]), tensor([[0.8611]]), tensor([[0.9636]]), tensor([[0.9972]]), tensor([[0.9955]]), tensor([[0.9683]]), tensor([[0.9731]]), tensor([[0.9748]]), tensor([[0.9766]]), tensor([[0.9556]])]\n",
      "[tensor([[0.4037]]), tensor([[0.1095]]), tensor([[0.3441]]), tensor([[0.8767]]), tensor([[0.8130]]), tensor([[0.3777]]), tensor([[0.4181]]), tensor([[0.4345]]), tensor([[0.4533]]), tensor([[0.2994]])]\n",
      "[tensor([[0.1296]]), tensor([[0.0263]]), tensor([[0.1034]]), tensor([[0.6099]]), tensor([[0.4887]]), tensor([[0.1177]]), tensor([[0.1365]]), tensor([[0.1445]]), tensor([[0.1542]]), tensor([[0.0859]])]\n",
      "[tensor([[0.5201]]), tensor([[0.1644]]), tensor([[0.4564]]), tensor([[0.9192]]), tensor([[0.8743]]), tensor([[0.4927]]), tensor([[0.5349]]), tensor([[0.5516]]), tensor([[0.5703]]), tensor([[0.4062]])]\n",
      "[tensor([[0.2142]]), tensor([[0.0471]]), tensor([[0.1743]]), tensor([[0.7411]]), tensor([[0.6363]]), tensor([[0.1963]]), tensor([[0.2243]]), tensor([[0.2362]]), tensor([[0.2502]]), tensor([[0.1468]])]\n",
      "[tensor([[0.8688]]), tensor([[0.5458]]), tensor([[0.8368]]), tensor([[0.9858]]), tensor([[0.9770]]), tensor([[0.8558]]), tensor([[0.8754]]), tensor([[0.8825]]), tensor([[0.8902]]), tensor([[0.8069]])]\n",
      "[tensor([[0.0491]]), tensor([[0.0093]]), tensor([[0.0384]]), tensor([[0.3515]]), tensor([[0.2489]]), tensor([[0.0442]]), tensor([[0.0519]]), tensor([[0.0553]]), tensor([[0.0595]]), tensor([[0.0316]])]\n",
      "[tensor([[0.9296]]), tensor([[0.7057]]), tensor([[0.9110]]), tensor([[0.9928]]), tensor([[0.9883]]), tensor([[0.9221]]), tensor([[0.9334]]), tensor([[0.9375]]), tensor([[0.9418]]), tensor([[0.8929]])]\n",
      "[tensor([[0.3768]]), tensor([[0.0989]]), tensor([[0.3190]]), tensor([[0.8639]]), tensor([[0.7951]]), tensor([[0.3514]]), tensor([[0.3909]]), tensor([[0.4069]]), tensor([[0.4254]]), tensor([[0.2762]])]\n",
      "[tensor([[0.7015]]), tensor([[0.2991]]), tensor([[0.6455]]), tensor([[0.9611]]), tensor([[0.9378]]), tensor([[0.6781]]), tensor([[0.7138]]), tensor([[0.7273]]), tensor([[0.7422]]), tensor([[0.5973]])]\n",
      "[tensor([[0.1012]]), tensor([[0.0200]]), tensor([[0.0802]]), tensor([[0.5417]]), tensor([[0.4195]]), tensor([[0.0916]]), tensor([[0.1067]]), tensor([[0.1133]]), tensor([[0.1211]]), tensor([[0.0663]])]\n",
      "[tensor([[0.2426]]), tensor([[0.0550]]), tensor([[0.1988]]), tensor([[0.7709]]), tensor([[0.6728]]), tensor([[0.2231]]), tensor([[0.2537]]), tensor([[0.2666]]), tensor([[0.2818]]), tensor([[0.1682]])]\n",
      "[tensor([[0.5611]]), tensor([[0.1883]]), tensor([[0.4976]]), tensor([[0.9307]]), tensor([[0.8914]]), tensor([[0.5339]]), tensor([[0.5757]]), tensor([[0.5920]]), tensor([[0.6102]]), tensor([[0.4466]])]\n",
      "[tensor([[0.5639]]), tensor([[0.1901]]), tensor([[0.5005]]), tensor([[0.9314]]), tensor([[0.8925]]), tensor([[0.5368]]), tensor([[0.5785]]), tensor([[0.5948]]), tensor([[0.6130]]), tensor([[0.4494]])]\n",
      "[tensor([[0.6221]]), tensor([[0.2301]]), tensor([[0.5605]]), tensor([[0.9453]]), tensor([[0.9135]]), tensor([[0.5960]]), tensor([[0.6359]]), tensor([[0.6513]]), tensor([[0.6684]]), tensor([[0.5095]])]\n",
      "[tensor([[0.4364]]), tensor([[0.1233]]), tensor([[0.3750]]), tensor([[0.8905]]), tensor([[0.8325]]), tensor([[0.4097]]), tensor([[0.4511]]), tensor([[0.4678]]), tensor([[0.4868]]), tensor([[0.3283]])]\n",
      "[tensor([[0.9062]]), tensor([[0.6369]]), tensor([[0.8822]]), tensor([[0.9902]]), tensor([[0.9841]]), tensor([[0.8965]]), tensor([[0.9112]]), tensor([[0.9164]]), tensor([[0.9221]]), tensor([[0.8591]])]\n",
      "[tensor([[0.4764]]), tensor([[0.1418]]), tensor([[0.4135]]), tensor([[0.9053]]), tensor([[0.8538]]), tensor([[0.4492]]), tensor([[0.4913]]), tensor([[0.5081]]), tensor([[0.5271]]), tensor([[0.3648]])]\n",
      "[tensor([[0.8923]]), tensor([[0.6005]]), tensor([[0.8652]]), tensor([[0.9886]]), tensor([[0.9815]]), tensor([[0.8813]]), tensor([[0.8979]]), tensor([[0.9038]]), tensor([[0.9103]]), tensor([[0.8394]])]\n",
      "[tensor([[0.8822]]), tensor([[0.5761]]), tensor([[0.8529]]), tensor([[0.9874]]), tensor([[0.9796]]), tensor([[0.8703]]), tensor([[0.8882]]), tensor([[0.8947]]), tensor([[0.9017]]), tensor([[0.8254]])]\n",
      "[tensor([[0.6680]]), tensor([[0.2675]]), tensor([[0.6092]]), tensor([[0.9548]]), tensor([[0.9281]]), tensor([[0.6433]]), tensor([[0.6810]]), tensor([[0.6954]]), tensor([[0.7113]]), tensor([[0.5594]])]\n",
      "[tensor([[0.5976]]), tensor([[0.2124]]), tensor([[0.5350]]), tensor([[0.9398]]), tensor([[0.9051]]), tensor([[0.5710]]), tensor([[0.6119]]), tensor([[0.6277]]), tensor([[0.6453]]), tensor([[0.4839]])]\n",
      "[tensor([[0.4909]]), tensor([[0.1490]]), tensor([[0.4276]]), tensor([[0.9101]]), tensor([[0.8609]]), tensor([[0.4636]]), tensor([[0.5058]]), tensor([[0.5225]]), tensor([[0.5415]]), tensor([[0.3784]])]\n",
      "[tensor([[0.9256]]), tensor([[0.6930]]), tensor([[0.9060]]), tensor([[0.9924]]), tensor([[0.9876]]), tensor([[0.9177]]), tensor([[0.9296]]), tensor([[0.9338]]), tensor([[0.9384]]), tensor([[0.8870]])]\n",
      "[tensor([[0.8146]]), tensor([[0.4437]]), tensor([[0.7729]]), tensor([[0.9788]]), tensor([[0.9658]]), tensor([[0.7975]]), tensor([[0.8234]]), tensor([[0.8329]]), tensor([[0.8433]]), tensor([[0.7350]])]\n"
     ]
    }
   ],
   "source": [
    "change_id = []\n",
    "\n",
    "for i in range(128):\n",
    "    preds = []\n",
    "    for j in range(128):\n",
    "        channel_idx = i\n",
    "        signal_features = np.array(test_data_list)[j].reshape(1,-1)\n",
    "        with torch.no_grad():\n",
    "            preds.append(model(torch.tensor(channel_idx).reshape(-1), torch.tensor(signal_features, dtype = torch.float32)))\n",
    "    print(preds[:10])\n",
    "    # print(i)\n",
    "    change_id.append(preds[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "id": "92c5c242-d5a2-4e0a-8db5-f96480a8af28",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-04T11:18:53.979943Z",
     "iopub.status.busy": "2024-11-04T11:18:53.979742Z",
     "iopub.status.idle": "2024-11-04T11:18:54.006176Z",
     "shell.execute_reply": "2024-11-04T11:18:54.005701Z",
     "shell.execute_reply.started": "2024-11-04T11:18:53.979929Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[0.9738]]),\n",
       " tensor([[0.9826]]),\n",
       " tensor([[0.8490]]),\n",
       " tensor([[0.9768]]),\n",
       " tensor([[0.9914]]),\n",
       " tensor([[0.9720]]),\n",
       " tensor([[0.9289]]),\n",
       " tensor([[0.9907]]),\n",
       " tensor([[0.6675]]),\n",
       " tensor([[0.8889]]),\n",
       " tensor([[0.9880]]),\n",
       " tensor([[0.9958]]),\n",
       " tensor([[0.9722]]),\n",
       " tensor([[0.9950]]),\n",
       " tensor([[0.9786]]),\n",
       " tensor([[0.9904]]),\n",
       " tensor([[0.9051]]),\n",
       " tensor([[0.7499]]),\n",
       " tensor([[0.9052]]),\n",
       " tensor([[0.9500]]),\n",
       " tensor([[0.9929]]),\n",
       " tensor([[0.9786]]),\n",
       " tensor([[0.8656]]),\n",
       " tensor([[0.6247]]),\n",
       " tensor([[0.8119]]),\n",
       " tensor([[0.9142]]),\n",
       " tensor([[0.9020]]),\n",
       " tensor([[0.8614]]),\n",
       " tensor([[0.8720]]),\n",
       " tensor([[0.9892]]),\n",
       " tensor([[0.8596]]),\n",
       " tensor([[0.9541]]),\n",
       " tensor([[0.3834]]),\n",
       " tensor([[0.9445]]),\n",
       " tensor([[0.6613]]),\n",
       " tensor([[0.9740]]),\n",
       " tensor([[0.4609]]),\n",
       " tensor([[0.9055]]),\n",
       " tensor([[0.9630]]),\n",
       " tensor([[0.7623]]),\n",
       " tensor([[0.9921]]),\n",
       " tensor([[0.9421]]),\n",
       " tensor([[0.6892]]),\n",
       " tensor([[0.9781]]),\n",
       " tensor([[0.9884]]),\n",
       " tensor([[0.9100]]),\n",
       " tensor([[0.9831]]),\n",
       " tensor([[0.4922]]),\n",
       " tensor([[0.9498]]),\n",
       " tensor([[0.9270]]),\n",
       " tensor([[0.9185]]),\n",
       " tensor([[0.9322]]),\n",
       " tensor([[0.9701]]),\n",
       " tensor([[0.9418]]),\n",
       " tensor([[0.8039]]),\n",
       " tensor([[0.9027]]),\n",
       " tensor([[0.8366]]),\n",
       " tensor([[0.8730]]),\n",
       " tensor([[0.9555]]),\n",
       " tensor([[0.9302]]),\n",
       " tensor([[0.9952]]),\n",
       " tensor([[0.4053]]),\n",
       " tensor([[0.6492]]),\n",
       " tensor([[0.9260]]),\n",
       " tensor([[0.3728]]),\n",
       " tensor([[0.8740]]),\n",
       " tensor([[0.9759]]),\n",
       " tensor([[0.6287]]),\n",
       " tensor([[0.8236]]),\n",
       " tensor([[0.9540]]),\n",
       " tensor([[0.9355]]),\n",
       " tensor([[0.9913]]),\n",
       " tensor([[0.8814]]),\n",
       " tensor([[0.9559]]),\n",
       " tensor([[0.8739]]),\n",
       " tensor([[0.9616]]),\n",
       " tensor([[0.9322]]),\n",
       " tensor([[0.8918]]),\n",
       " tensor([[0.9937]]),\n",
       " tensor([[0.9550]]),\n",
       " tensor([[0.5228]]),\n",
       " tensor([[0.6183]]),\n",
       " tensor([[0.9843]]),\n",
       " tensor([[0.6100]]),\n",
       " tensor([[0.4695]]),\n",
       " tensor([[0.9850]]),\n",
       " tensor([[0.9121]]),\n",
       " tensor([[0.9832]]),\n",
       " tensor([[0.9157]]),\n",
       " tensor([[0.8329]]),\n",
       " tensor([[0.5692]]),\n",
       " tensor([[0.9814]]),\n",
       " tensor([[0.9788]]),\n",
       " tensor([[0.7657]]),\n",
       " tensor([[0.9979]]),\n",
       " tensor([[0.9781]]),\n",
       " tensor([[0.8644]]),\n",
       " tensor([[0.8446]]),\n",
       " tensor([[0.9872]]),\n",
       " tensor([[0.9580]]),\n",
       " tensor([[0.9485]]),\n",
       " tensor([[0.9297]]),\n",
       " tensor([[0.9928]]),\n",
       " tensor([[0.9972]]),\n",
       " tensor([[0.8767]]),\n",
       " tensor([[0.6099]]),\n",
       " tensor([[0.9192]]),\n",
       " tensor([[0.7411]]),\n",
       " tensor([[0.9858]]),\n",
       " tensor([[0.3515]]),\n",
       " tensor([[0.9928]]),\n",
       " tensor([[0.8639]]),\n",
       " tensor([[0.9611]]),\n",
       " tensor([[0.5417]]),\n",
       " tensor([[0.7709]]),\n",
       " tensor([[0.9307]]),\n",
       " tensor([[0.9314]]),\n",
       " tensor([[0.9453]]),\n",
       " tensor([[0.8905]]),\n",
       " tensor([[0.9902]]),\n",
       " tensor([[0.9053]]),\n",
       " tensor([[0.9886]]),\n",
       " tensor([[0.9874]]),\n",
       " tensor([[0.9548]]),\n",
       " tensor([[0.9398]]),\n",
       " tensor([[0.9101]]),\n",
       " tensor([[0.9924]]),\n",
       " tensor([[0.9788]])]"
      ]
     },
     "execution_count": 176,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "id": "8749e78f-0def-43fe-a231-0695a142cbf1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-04T01:10:29.317764Z",
     "iopub.status.busy": "2024-11-04T01:10:29.317463Z",
     "iopub.status.idle": "2024-11-04T01:10:29.347851Z",
     "shell.execute_reply": "2024-11-04T01:10:29.347349Z",
     "shell.execute_reply.started": "2024-11-04T01:10:29.317738Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[0.7800]]),\n",
       " tensor([[0.3916]]),\n",
       " tensor([[0.7331]]),\n",
       " tensor([[0.9738]]),\n",
       " tensor([[0.9579]]),\n",
       " tensor([[0.7606]]),\n",
       " tensor([[0.7900]]),\n",
       " tensor([[0.8009]]),\n",
       " tensor([[0.8128]]),\n",
       " tensor([[0.6912]]),\n",
       " tensor([[0.7793]]),\n",
       " tensor([[0.7420]]),\n",
       " tensor([[0.8365]]),\n",
       " tensor([[0.9665]]),\n",
       " tensor([[0.9245]]),\n",
       " tensor([[0.8939]]),\n",
       " tensor([[0.3720]]),\n",
       " tensor([[0.6743]]),\n",
       " tensor([[0.5544]]),\n",
       " tensor([[0.7695]]),\n",
       " tensor([[0.6739]]),\n",
       " tensor([[0.9482]]),\n",
       " tensor([[0.6910]]),\n",
       " tensor([[0.7654]]),\n",
       " tensor([[0.7622]]),\n",
       " tensor([[0.7799]]),\n",
       " tensor([[0.7796]]),\n",
       " tensor([[0.8417]]),\n",
       " tensor([[0.6876]]),\n",
       " tensor([[0.8279]]),\n",
       " tensor([[0.7180]]),\n",
       " tensor([[0.7272]]),\n",
       " tensor([[0.7822]]),\n",
       " tensor([[0.8645]]),\n",
       " tensor([[0.8792]]),\n",
       " tensor([[0.9588]]),\n",
       " tensor([[0.9279]]),\n",
       " tensor([[0.6385]]),\n",
       " tensor([[0.9005]]),\n",
       " tensor([[0.8499]]),\n",
       " tensor([[0.9427]]),\n",
       " tensor([[0.9465]]),\n",
       " tensor([[0.9100]]),\n",
       " tensor([[0.9537]]),\n",
       " tensor([[0.9172]]),\n",
       " tensor([[0.9203]]),\n",
       " tensor([[0.9391]]),\n",
       " tensor([[0.9320]]),\n",
       " tensor([[0.9694]]),\n",
       " tensor([[0.8936]]),\n",
       " tensor([[0.8760]]),\n",
       " tensor([[0.8285]]),\n",
       " tensor([[0.8993]]),\n",
       " tensor([[0.9628]]),\n",
       " tensor([[0.9648]]),\n",
       " tensor([[0.9724]]),\n",
       " tensor([[0.8368]]),\n",
       " tensor([[0.8617]]),\n",
       " tensor([[0.7813]]),\n",
       " tensor([[0.7497]]),\n",
       " tensor([[0.5358]]),\n",
       " tensor([[0.9243]]),\n",
       " tensor([[0.2746]]),\n",
       " tensor([[0.7560]]),\n",
       " tensor([[0.4022]]),\n",
       " tensor([[0.3671]]),\n",
       " tensor([[0.8529]]),\n",
       " tensor([[0.5200]]),\n",
       " tensor([[0.6568]]),\n",
       " tensor([[0.7419]]),\n",
       " tensor([[0.7819]]),\n",
       " tensor([[0.9378]]),\n",
       " tensor([[0.7069]]),\n",
       " tensor([[0.7569]]),\n",
       " tensor([[0.4440]]),\n",
       " tensor([[0.7770]]),\n",
       " tensor([[0.8140]]),\n",
       " tensor([[0.7392]]),\n",
       " tensor([[0.8272]]),\n",
       " tensor([[0.2617]]),\n",
       " tensor([[0.8766]]),\n",
       " tensor([[0.7613]]),\n",
       " tensor([[0.9567]]),\n",
       " tensor([[0.6529]]),\n",
       " tensor([[0.7506]]),\n",
       " tensor([[0.9180]]),\n",
       " tensor([[0.9526]]),\n",
       " tensor([[0.0653]]),\n",
       " tensor([[0.8058]]),\n",
       " tensor([[0.9558]]),\n",
       " tensor([[0.7290]]),\n",
       " tensor([[0.8454]]),\n",
       " tensor([[0.7169]]),\n",
       " tensor([[0.9266]]),\n",
       " tensor([[0.9159]]),\n",
       " tensor([[0.9599]]),\n",
       " tensor([[0.5592]]),\n",
       " tensor([[0.8823]]),\n",
       " tensor([[0.6350]]),\n",
       " tensor([[0.8736]]),\n",
       " tensor([[0.8098]]),\n",
       " tensor([[0.2683]]),\n",
       " tensor([[0.6772]]),\n",
       " tensor([[0.6940]]),\n",
       " tensor([[0.5799]]),\n",
       " tensor([[0.8833]]),\n",
       " tensor([[0.8934]]),\n",
       " tensor([[0.4442]]),\n",
       " tensor([[0.7574]]),\n",
       " tensor([[0.7423]]),\n",
       " tensor([[0.8683]]),\n",
       " tensor([[0.7494]]),\n",
       " tensor([[0.5286]]),\n",
       " tensor([[0.7072]]),\n",
       " tensor([[0.4412]]),\n",
       " tensor([[0.7380]]),\n",
       " tensor([[0.6728]]),\n",
       " tensor([[0.6687]]),\n",
       " tensor([[0.6899]]),\n",
       " tensor([[0.7716]]),\n",
       " tensor([[0.6069]]),\n",
       " tensor([[0.8444]]),\n",
       " tensor([[0.9490]]),\n",
       " tensor([[0.8105]]),\n",
       " tensor([[0.9237]]),\n",
       " tensor([[0.6946]]),\n",
       " tensor([[0.8605]]),\n",
       " tensor([[0.7706]])]"
      ]
     },
     "execution_count": 170,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "change_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "id": "63438313-ea4f-4179-a86a-524e6de0a303",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-04T01:09:38.833529Z",
     "iopub.status.busy": "2024-11-04T01:09:38.833202Z",
     "iopub.status.idle": "2024-11-04T01:09:40.075383Z",
     "shell.execute_reply": "2024-11-04T01:09:40.074839Z",
     "shell.execute_reply.started": "2024-11-04T01:09:38.833503Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t12.2022.07.14_diagnosticBlocks.mat\n",
      "after:  0.125 1\n",
      "t12.2022.07.21_diagnosticBlocks.mat\n",
      "after:  0.125 1\n",
      "t12.2022.07.27_diagnosticBlocks.mat\n",
      "after:  0.125 1\n"
     ]
    }
   ],
   "source": [
    "for d in range(6, len(file_list)):\n",
    "    print(file_list[d])\n",
    "    three, threel = get_data(file_list[d])\n",
    "\n",
    "    changes = np.array(change_id).reshape(128)\n",
    "    new_three = three[:,:,changes]\n",
    "\n",
    "    \n",
    "    tot = 0\n",
    "    for i in range(64):\n",
    "        pred = linearsvc.predict(new_three[i].reshape(1,-1))\n",
    "        if (pred == threel[i]):\n",
    "            tot+=1\n",
    "    print('after: ',tot/64, len(np.unique(changes)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 207,
   "id": "4adf3500-459b-4d52-ab18-1d03a584c87c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-29T05:35:12.168076Z",
     "iopub.status.busy": "2024-11-29T05:35:12.167720Z",
     "iopub.status.idle": "2024-11-29T05:35:12.530982Z",
     "shell.execute_reply": "2024-11-29T05:35:12.530526Z",
     "shell.execute_reply.started": "2024-11-29T05:35:12.168050Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warping dataset: t5.2019.12.11\n",
      "(31, 27, 201, 192)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "# from sklearn import preprocessing\n",
    "# import matplotlib.pyplot as plt\n",
    "# import sklearn\n",
    "import sklearn\n",
    "from sklearn.svm import LinearSVC\n",
    "import random\n",
    "import scipy.io as scio\n",
    "import tensorflow as tf\n",
    "import tensorflow\n",
    "import keras\n",
    "import scipy\n",
    "import tensorflow as tf\n",
    "import os\n",
    "def seed_tensorflow(seed=42):\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    tf.random.set_seed(seed)\n",
    "    tf.compat.v1.set_random_seed(seed)\n",
    "#suppress all tensorflow warnings (largely related to compatability with v2)\n",
    "# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n",
    "\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 只使用第 0 个 GPU 设备\n",
    "\n",
    "data_list = []\n",
    "label_list = []\n",
    "dataDirs = ['t5.2019.05.08','t5.2019.11.25','t5.2019.12.09','t5.2019.12.11','t5.2019.12.18',\n",
    "            't5.2019.12.20','t5.2020.01.06','t5.2020.01.08','t5.2020.01.13','t5.2020.01.15']\n",
    "dataDirs = ['t5.2019.12.11']\n",
    "# dataDirs = ['t5.2019.11.25']\n",
    "charDef = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z',\n",
    "                'greaterThan','comma','apostrophe','tilde','questionMark']\n",
    "for dataDir in dataDirs:\n",
    "    \n",
    "    print('Warping dataset: ' + dataDir)\n",
    "    dat = scipy.io.loadmat('../singleLetters.mat')\n",
    "    #to normalize the units.\n",
    "    for char in charDef:\n",
    "        neuralCube = dat['neuralActivityCube_'+char].astype(np.float64)\n",
    "        data_list.append(neuralCube)\n",
    "        label_list.append(char)\n",
    "    print(np.array(data_list).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 208,
   "id": "67b71ffa-8605-4cb3-b4e6-015643c6cd08",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-29T05:35:13.296649Z",
     "iopub.status.busy": "2024-11-29T05:35:13.296333Z",
     "iopub.status.idle": "2024-11-29T05:35:16.379218Z",
     "shell.execute_reply": "2024-11-29T05:35:16.378730Z",
     "shell.execute_reply.started": "2024-11-29T05:35:13.296615Z"
    }
   },
   "outputs": [],
   "source": [
    "#### class_list: [ class1, class2, ..., class31] class1: 1*27*192*10 -> 192*270\n",
    "\n",
    "dl = np.vstack(data_list)\n",
    "dl_len = 5\n",
    "multi = int(200/dl_len)\n",
    "temp = np.zeros([837,dl_len,192])\n",
    "for i in range(837):\n",
    "    for j in range(dl_len):\n",
    "        for k in range(192):\n",
    "            temp[i,j,k] = sum(dl[i,j*multi:(j+1)*multi,k])\n",
    "tempn = temp.reshape([31,27,5,192])\n",
    "classn = np.transpose(tempn,[0,3,2,1])\n",
    "tempn = temp.reshape([837,-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 216,
   "id": "dba5036e-78b9-4f05-819f-66889baabfb4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-29T06:13:01.021474Z",
     "iopub.status.busy": "2024-11-29T06:13:01.021216Z",
     "iopub.status.idle": "2024-11-29T06:13:04.220389Z",
     "shell.execute_reply": "2024-11-29T06:13:04.219949Z",
     "shell.execute_reply.started": "2024-11-29T06:13:01.021432Z"
    }
   },
   "outputs": [],
   "source": [
    "dl = np.vstack(data_list)\n",
    "temp = np.zeros([837,5,192])\n",
    "for i in range(837):\n",
    "    for j in range(5):\n",
    "        for k in range(192):\n",
    "            temp[i,j,k] = np.mean(dl[i,j*40:(j+1)*40,k])\n",
    "\n",
    "temp = np.array(temp).reshape(837,-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 217,
   "id": "df72b73f-7331-4ef3-8e9c-d58b962792e9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-29T06:13:04.699996Z",
     "iopub.status.busy": "2024-11-29T06:13:04.699695Z",
     "iopub.status.idle": "2024-11-29T06:13:04.704728Z",
     "shell.execute_reply": "2024-11-29T06:13:04.704078Z",
     "shell.execute_reply.started": "2024-11-29T06:13:04.699956Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(837, 960)"
      ]
     },
     "execution_count": 217,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temp.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "id": "b29e4565-c797-4654-aab0-5f2f31431206",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-29T06:13:07.960729Z",
     "iopub.status.busy": "2024-11-29T06:13:07.960407Z",
     "iopub.status.idle": "2024-11-29T06:14:42.096126Z",
     "shell.execute_reply": "2024-11-29T06:14:42.095543Z",
     "shell.execute_reply.started": "2024-11-29T06:13:07.960690Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n",
      "0.6274509803921569\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[218], line 17\u001b[0m\n\u001b[1;32m     15\u001b[0m train_data_loo, train_label_loo \u001b[38;5;241m=\u001b[39m  temp\u001b[38;5;241m.\u001b[39mreshape(nums,\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)[traidx], new_label_list[traidx]\n\u001b[1;32m     16\u001b[0m linearsvc \u001b[38;5;241m=\u001b[39m LinearSVC(C\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e8\u001b[39m)\n\u001b[0;32m---> 17\u001b[0m linearsvc\u001b[38;5;241m.\u001b[39mfit(train_data_loo, train_label_loo)\n\u001b[1;32m     18\u001b[0m pred \u001b[38;5;241m=\u001b[39m linearsvc\u001b[38;5;241m.\u001b[39mpredict(test_data_loo\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m     19\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pred \u001b[38;5;241m==\u001b[39m test_label_loo:\n",
      "File \u001b[0;32m~/anaconda3/envs/py311-torch2/lib/python3.11/site-packages/sklearn/base.py:1474\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1467\u001b[0m     estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[1;32m   1469\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m   1470\u001b[0m     skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m   1471\u001b[0m         prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m   1472\u001b[0m     )\n\u001b[1;32m   1473\u001b[0m ):\n\u001b[0;32m-> 1474\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m fit_method(estimator, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[0;32m~/anaconda3/envs/py311-torch2/lib/python3.11/site-packages/sklearn/svm/_classes.py:325\u001b[0m, in \u001b[0;36mLinearSVC.fit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m    319\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclasses_ \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39munique(y)\n\u001b[1;32m    321\u001b[0m _dual \u001b[38;5;241m=\u001b[39m _validate_dual_parameter(\n\u001b[1;32m    322\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdual, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpenalty, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmulti_class, X\n\u001b[1;32m    323\u001b[0m )\n\u001b[0;32m--> 325\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcoef_, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mintercept_, n_iter_ \u001b[38;5;241m=\u001b[39m _fit_liblinear(\n\u001b[1;32m    326\u001b[0m     X,\n\u001b[1;32m    327\u001b[0m     y,\n\u001b[1;32m    328\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mC,\n\u001b[1;32m    329\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfit_intercept,\n\u001b[1;32m    330\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mintercept_scaling,\n\u001b[1;32m    331\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclass_weight,\n\u001b[1;32m    332\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpenalty,\n\u001b[1;32m    333\u001b[0m     _dual,\n\u001b[1;32m    334\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose,\n\u001b[1;32m    335\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_iter,\n\u001b[1;32m    336\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtol,\n\u001b[1;32m    337\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrandom_state,\n\u001b[1;32m    338\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmulti_class,\n\u001b[1;32m    339\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss,\n\u001b[1;32m    340\u001b[0m     sample_weight\u001b[38;5;241m=\u001b[39msample_weight,\n\u001b[1;32m    341\u001b[0m )\n\u001b[1;32m    342\u001b[0m \u001b[38;5;66;03m# Backward compatibility: _fit_liblinear is used both by LinearSVC/R\u001b[39;00m\n\u001b[1;32m    343\u001b[0m \u001b[38;5;66;03m# and LogisticRegression but LogisticRegression sets a structured\u001b[39;00m\n\u001b[1;32m    344\u001b[0m \u001b[38;5;66;03m# `n_iter_` attribute with information about the underlying OvR fits\u001b[39;00m\n\u001b[1;32m    345\u001b[0m \u001b[38;5;66;03m# while LinearSVC/R only reports the maximum value.\u001b[39;00m\n\u001b[1;32m    346\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_iter_ \u001b[38;5;241m=\u001b[39m n_iter_\u001b[38;5;241m.\u001b[39mmax()\u001b[38;5;241m.\u001b[39mitem()\n",
      "File \u001b[0;32m~/anaconda3/envs/py311-torch2/lib/python3.11/site-packages/sklearn/svm/_base.py:1217\u001b[0m, in \u001b[0;36m_fit_liblinear\u001b[0;34m(X, y, C, fit_intercept, intercept_scaling, class_weight, penalty, dual, verbose, max_iter, tol, random_state, multi_class, loss, epsilon, sample_weight)\u001b[0m\n\u001b[1;32m   1214\u001b[0m sample_weight \u001b[38;5;241m=\u001b[39m _check_sample_weight(sample_weight, X, dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat64)\n\u001b[1;32m   1216\u001b[0m solver_type \u001b[38;5;241m=\u001b[39m _get_liblinear_solver_type(multi_class, penalty, loss, dual)\n\u001b[0;32m-> 1217\u001b[0m raw_coef_, n_iter_ \u001b[38;5;241m=\u001b[39m liblinear\u001b[38;5;241m.\u001b[39mtrain_wrap(\n\u001b[1;32m   1218\u001b[0m     X,\n\u001b[1;32m   1219\u001b[0m     y_ind,\n\u001b[1;32m   1220\u001b[0m     sp\u001b[38;5;241m.\u001b[39missparse(X),\n\u001b[1;32m   1221\u001b[0m     solver_type,\n\u001b[1;32m   1222\u001b[0m     tol,\n\u001b[1;32m   1223\u001b[0m     bias,\n\u001b[1;32m   1224\u001b[0m     C,\n\u001b[1;32m   1225\u001b[0m     class_weight_,\n\u001b[1;32m   1226\u001b[0m     max_iter,\n\u001b[1;32m   1227\u001b[0m     rnd\u001b[38;5;241m.\u001b[39mrandint(np\u001b[38;5;241m.\u001b[39miinfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mi\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mmax),\n\u001b[1;32m   1228\u001b[0m     epsilon,\n\u001b[1;32m   1229\u001b[0m     sample_weight,\n\u001b[1;32m   1230\u001b[0m )\n\u001b[1;32m   1231\u001b[0m \u001b[38;5;66;03m# Regarding rnd.randint(..) in the above signature:\u001b[39;00m\n\u001b[1;32m   1232\u001b[0m \u001b[38;5;66;03m# seed for srand in range [0..INT_MAX); due to limitations in Numpy\u001b[39;00m\n\u001b[1;32m   1233\u001b[0m \u001b[38;5;66;03m# on 32-bit platforms, we can't get to the UINT_MAX limit that\u001b[39;00m\n\u001b[1;32m   1234\u001b[0m \u001b[38;5;66;03m# srand supports\u001b[39;00m\n\u001b[1;32m   1235\u001b[0m n_iter_max \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(n_iter_)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "new_label_list = []\n",
    "\n",
    "for i in range(31):\n",
    "    for j in range(27):\n",
    "        new_label_list.append(label_list[j])\n",
    "new_label_list = np.array(new_label_list)\n",
    "\n",
    "tot = 0\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "nums = 31*27\n",
    "for i in range(nums):\n",
    "    test_data_loo, test_label_loo = temp.reshape(nums,-1)[i], new_label_list[i]\n",
    "    traidx = [j for j in range(nums) if j != i]\n",
    "    train_data_loo, train_label_loo =  temp.reshape(nums,-1)[traidx], new_label_list[traidx]\n",
    "    linearsvc = LinearSVC(C=1e8)\n",
    "    linearsvc.fit(train_data_loo, train_label_loo)\n",
    "    pred = linearsvc.predict(test_data_loo.reshape(1,-1))\n",
    "    if pred == test_label_loo:\n",
    "        tot += 1\n",
    "    if i % 50 == 0:\n",
    "        print(tot/(i+1))\n",
    "print(tot/(i+1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4b388f8-bbcf-4948-bc26-2d873345b70b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py311-torch2",
   "language": "python",
   "name": "py311-torch2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
