{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import glob\n",
    "import os\n",
    "\n",
    "# Dataset identifier; used below as the filename prefix of the saved .npy arrays.\n",
    "DATA_NAME = \"50salads\"\n",
    "# Fraction of the stitched sequence kept when saving (1 keeps every timestep).\n",
    "DATA_RATIO = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_boundary_labels(label_list, mapping_dict):\n",
    "    \"\"\"Build a per-timestep boundary indicator for every label sequence.\n",
    "\n",
    "    Side effect: every array in ``label_list`` is modified IN PLACE -- each\n",
    "    entry equal to a class name in ``mapping_dict`` is overwritten with the\n",
    "    matching integer id. Later cells of this notebook rely on that remap.\n",
    "\n",
    "    Args:\n",
    "        label_list: iterable of 1-D label sequences, one per video. They must\n",
    "            support boolean-mask assignment when ``mapping_dict`` is non-empty.\n",
    "        mapping_dict: ``{class_id: class_name}``; pass ``{}`` to skip the remap.\n",
    "\n",
    "    Returns:\n",
    "        list of float arrays, one per input sequence and of the same length,\n",
    "        holding 1 at every index ``t`` whose label differs from the label at\n",
    "        ``t + 1`` (i.e. the last timestep of a segment) and 0 elsewhere.\n",
    "    \"\"\"\n",
    "    label_seg_list = []\n",
    "\n",
    "    for video_label in label_list:\n",
    "        for class_label, class_name in mapping_dict.items():\n",
    "            video_label[video_label == class_name] = int(class_label) # change class name into class integer\n",
    "\n",
    "        # Compare every timestep with its successor; a mismatch at t means a\n",
    "        # segment ends at t. The last timestep has no successor and stays 0.\n",
    "        labels = np.asarray(video_label)\n",
    "        is_boundary = np.asarray(labels[1:] != labels[:-1], dtype=bool)\n",
    "\n",
    "        label_seg = np.zeros(len(video_label))\n",
    "        label_seg[np.flatnonzero(is_boundary)] = 1\n",
    "        label_seg_list.append(label_seg)\n",
    "\n",
    "    return label_seg_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def patchwork(feature_list, label_list, label_seg_list):\n",
    "    \"\"\"Concatenate the per-file arrays into one long sequence.\n",
    "\n",
    "    Files are joined in their given order (index 0 first). Returns a 4-tuple:\n",
    "    stacked features, stacked labels, stacked boundary labels, and an int64\n",
    "    array holding the cumulative end offset of every file.\n",
    "    \"\"\"\n",
    "    file_order = range(len(feature_list))\n",
    "    feature_chunks = [feature_list[k] for k in file_order]\n",
    "    label_chunks = [label_list[k] for k in file_order]\n",
    "    seg_chunks = [label_seg_list[k] for k in file_order]\n",
    "\n",
    "    # Running total of file lengths == exclusive end index of each file.\n",
    "    end_offsets = []\n",
    "    running_total = 0\n",
    "    for chunk in feature_chunks:\n",
    "        running_total += len(chunk)\n",
    "        end_offsets.append(running_total)\n",
    "\n",
    "    stacked_features = np.concatenate(feature_chunks, axis=0)\n",
    "    stacked_labels = np.concatenate(label_chunks, axis=0)\n",
    "    stacked_seg_labels = np.concatenate(seg_chunks, axis=0)\n",
    "    return stacked_features, stacked_labels, stacked_seg_labels, np.array(end_offsets, dtype=np.int64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input locations, relative to the notebook's working directory.\n",
    "data_path = 'features'\n",
    "label_path = 'groundTruth'\n",
    "label_map_file_name = 'mapping.txt'\n",
    "\n",
    "# Both listings are sorted so the i-th feature file is paired with the i-th label file.\n",
    "# NOTE(review): assumes the two directories hold matching file names in the same order -- confirm.\n",
    "feature_file_names = sorted(glob.glob(os.path.join(data_path, \"*.npy\")))\n",
    "label_file_names = sorted(glob.glob(os.path.join(label_path, \"*.txt\")))\n",
    "# Column 1 of the space-separated mapping file, keyed by row position (pandas' default index).\n",
    "# NOTE(review): presumably column 0 holds the same integer id as the row position -- confirm.\n",
    "mapping_dict = pd.read_csv(label_map_file_name, sep=\" \", index_col=None, header=None)[1].to_dict()\n",
    "# Each stored feature array is transposed on load; presumably (dim, T) -> (T, dim) -- TODO confirm.\n",
    "feature_list = [np.load(f).transpose() for f in feature_file_names]\n",
    "# First column of every space-separated ground-truth file, one entry per row.\n",
    "label_list = [np.array(pd.read_csv(f, sep=\" \", index_col=None, header=None)[0].to_numpy()) for f in\n",
    "                label_file_names]\n",
    "\n",
    "# Besides returning the boundary labels, this call rewrites the class names inside label_list\n",
    "# to integer ids in place, so it must run before patchwork.\n",
    "label_seg_list = generate_boundary_labels(label_list, mapping_dict)\n",
    "X_long, y_long, y_seg_long, file_boundaries_indice = patchwork(feature_list, label_list, label_seg_list)\n",
    "# Recompute the boundaries on the concatenated labels (empty mapping -> no remap). This replaces the\n",
    "# per-file version returned by patchwork, so label changes across file joins are marked as well.\n",
    "y_seg_long = np.array(generate_boundary_labels([y_long], {})).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast the concatenated labels to int32; they already hold integer class ids after the in-place remap\n",
    "# performed by generate_boundary_labels.\n",
    "y_long = y_long.astype(np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,\n",
       "        17, 18], dtype=int32),\n",
       " array([61363, 14293, 47458, 10495, 47137, 13668, 11309, 22871, 27383,\n",
       "        13500, 20784, 61567, 48413, 15777, 21230, 24853, 34149, 30001,\n",
       "        51344]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: class ids present in y_long and the number of timesteps per class.\n",
    "np.unique(y_long, return_counts=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Relabel the action-start and action-end classes (ids 17 and 18 before the shift) as the unlabeled region (c=0); every other class id is shifted up by one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,\n",
      "       17], dtype=int32), array([81345, 61363, 14293, 47458, 10495, 47137, 13668, 11309, 22871,\n",
      "       27383, 13500, 20784, 61567, 48413, 15777, 21230, 24853, 34149]))\n"
     ]
    }
   ],
   "source": [
    "# Shift every class id up by one so that 0 becomes free for the unlabeled class,\n",
    "# then fold the two highest shifted ids (18 and 19) into 0.\n",
    "y_long += 1\n",
    "y_long[np.isin(y_long, (18, 19))] = 0\n",
    "print(np.unique(y_long, return_counts=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[604, 0], [1595, 1], [350, 2], [698, 1], [244, 2], [1837, 3], [218, 4], [978, 5], [834, 6], [289, 7], [767, 8], [507, 9], [304, 10], [1118, 11], [2095, 0], [471, 7], [306, 10], [446, 8], [472, 9], [316, 11], [1558, 3], [1051, 1], [2543, 12], [1034, 13], [484, 14], [99, 2], [97, 14], [697, 5], [268, 6], [478, 15], [422, 16], [592, 17], [1871, 0], [1512, 3], [241, 4], [995, 1], [126, 2], [38, 1], [75, 2], [67, 16], [90, 2], [662, 5], [178, 6], [341, 5], [302, 6], [1624, 12], [681, 13], [81, 14], [239, 13], [310, 14], [648, 8], [310, 9], [169, 7], [303, 10], [831, 16], [587, 17], [2199, 0], [860, 1], [396, 2], [1156, 12], [811, 13], [305, 14], [1019, 3], [415, 4], [839, 5], [299, 6], [419, 16], [212, 8], [266, 9], [181, 7], [265, 10], [441, 11], [460, 17], [965, 0], [3095, 12], [378, 14], [1128, 5], [180, 6], [582, 3], [142, 4], [924, 1], [132, 2], [685, 16], [567, 9], [490, 8], [172, 10], [240, 7], [613, 11], [701, 17], [402, 15], [1036, 0], [570, 9], [404, 8], [181, 10], [225, 7], [553, 11], [1039, 12], [116, 14], [948, 13], [209, 14], [1243, 1], [224, 2], [971, 1], [333, 2], [569, 3], [324, 4], [913, 5], [186, 6], [550, 16], [603, 17], [482, 15], [1090, 0], [578, 8], [585, 9], [182, 10], [124, 7], [722, 11], [1486, 12], [569, 13], [516, 12], [651, 13], [250, 14], [1332, 3], [311, 4], [1437, 1], [287, 2], [490, 5], [156, 6], [128, 5], [128, 6], [726, 16], [1114, 17], [575, 15], [675, 0], [1251, 12], [1204, 13], [539, 14], [549, 1], [312, 2], [645, 5], [645, 6], [1170, 3], [408, 4], [275, 16], [310, 1], [44, 2], [471, 16], [1166, 9], [340, 8], [186, 11], [29, 10], [62, 11], [106, 10], [284, 11], [1022, 17], [451, 15], [960, 0], [474, 8], [452, 9], [344, 10], [136, 11], [195, 7], [219, 11], [1075, 5], [265, 6], [665, 1], [157, 2], [937, 12], [1341, 13], [544, 14], [790, 3], [68, 4], [603, 16], [201, 15], [351, 16], [1001, 17], [1440, 0], [495, 8], [587, 9], [329, 10], [311, 7], [379, 11], [1018, 1], [172, 2], [1192, 1], [244, 2], [657, 5], [51, 6], [285, 2], 
[782, 3], [150, 4], [427, 16], [1405, 12], [758, 13], [199, 14], [461, 16], [209, 15], [339, 16], [527, 17], [1090, 0], [691, 1], [248, 2], [835, 12], [461, 13], [266, 14], [701, 5], [216, 6], [525, 3], [243, 4], [352, 16], [919, 9], [277, 7], [292, 10], [566, 11], [284, 8], [660, 11], [870, 17], [1014, 0], [779, 1], [253, 2], [505, 3], [236, 4], [598, 5], [235, 6], [796, 12], [365, 13], [264, 14], [573, 16], [466, 9], [624, 8], [268, 7], [175, 10], [325, 11], [502, 17], [467, 15], [1313, 0], [1904, 3], [2211, 12], [1090, 13], [139, 4], [703, 14], [2332, 1], [412, 2], [1501, 5], [324, 6], [1293, 16], [643, 8], [1128, 9], [276, 10], [197, 7], [498, 11], [691, 15], [1093, 17], [975, 0], [2013, 5], [294, 6], [1845, 3], [454, 4], [2130, 1], [294, 2], [1906, 12], [883, 13], [587, 14], [740, 16], [431, 8], [561, 9], [200, 10], [145, 7], [511, 11], [430, 15], [732, 17], [1088, 0], [1008, 3], [826, 5], [141, 3], [412, 6], [1159, 12], [198, 13], [242, 14], [2214, 1], [264, 2], [381, 16], [710, 8], [697, 9], [227, 10], [318, 7], [338, 11], [280, 15], [200, 16], [754, 17], [887, 0], [389, 9], [451, 8], [356, 10], [361, 7], [423, 11], [1222, 5], [536, 6], [1943, 1], [457, 2], [1156, 3], [513, 4], [1412, 12], [1865, 13], [313, 14], [442, 16], [392, 15], [257, 16], [523, 17], [1680, 0], [839, 9], [547, 8], [149, 10], [167, 11], [101, 10], [279, 7], [537, 11], [832, 5], [180, 6], [157, 5], [258, 6], [917, 1], [121, 2], [978, 3], [135, 4], [843, 12], [1025, 13], [415, 14], [637, 16], [1089, 17], [558, 15], [1272, 0], [638, 9], [372, 8], [141, 10], [172, 7], [429, 11], [828, 5], [210, 6], [644, 1], [152, 2], [519, 3], [240, 4], [438, 12], [237, 13], [139, 14], [436, 16], [447, 17], [431, 15], [2734, 0], [415, 9], [350, 8], [172, 7], [219, 10], [377, 11], [811, 12], [1060, 13], [557, 5], [859, 3], [1364, 1], [120, 6], [224, 14], [110, 2], [129, 4], [171, 16], [361, 15], [339, 17], [1667, 0], [634, 5], [835, 3], [901, 1], [1410, 12], [1058, 13], [70, 6], [102, 2], [120, 4], [42, 14], 
[171, 16], [317, 9], [390, 8], [161, 10], [224, 7], [464, 11], [301, 17], [171, 15], [1125, 0], [630, 9], [449, 8], [180, 7], [253, 10], [301, 11], [1416, 5], [262, 6], [52, 1], [198, 6], [1060, 1], [160, 2], [1645, 12], [1019, 13], [503, 12], [657, 13], [402, 14], [1272, 3], [269, 4], [423, 16], [988, 15], [608, 17], [1048, 0], [757, 9], [683, 8], [225, 7], [226, 10], [272, 11], [2084, 1], [288, 2], [1123, 3], [318, 4], [1518, 5], [215, 6], [1183, 12], [1399, 13], [292, 14], [382, 16], [701, 17], [441, 15], [1051, 0], [691, 10], [172, 7], [291, 8], [334, 9], [396, 11], [601, 3], [151, 4], [888, 5], [245, 6], [580, 12], [776, 13], [203, 14], [826, 1], [149, 2], [420, 16], [571, 15], [443, 17], [944, 0], [774, 5], [172, 6], [482, 3], [138, 4], [620, 12], [673, 13], [260, 14], [1145, 1], [233, 2], [241, 16], [269, 7], [338, 10], [346, 8], [298, 9], [277, 11], [440, 17], [557, 15], [1430, 0], [1155, 12], [762, 13], [319, 14], [378, 13], [298, 14], [625, 1], [168, 2], [295, 1], [193, 2], [977, 3], [288, 4], [724, 5], [535, 6], [688, 9], [300, 10], [253, 7], [306, 8], [548, 11], [18, 16], [844, 17], [452, 15], [816, 0], [1094, 13], [235, 14], [1064, 13], [197, 14], [1459, 5], [314, 6], [943, 3], [638, 4], [1192, 1], [154, 2], [1023, 1], [516, 2], [335, 16], [370, 7], [447, 9], [271, 10], [464, 8], [731, 11], [168, 15], [831, 17], [1763, 0], [322, 10], [279, 7], [609, 9], [509, 8], [329, 11], [936, 3], [1314, 5], [1916, 12], [1621, 13], [1534, 1], [188, 14], [74, 4], [188, 2], [328, 15], [480, 16], [486, 17], [1402, 0], [293, 7], [240, 10], [306, 9], [338, 8], [247, 11], [823, 3], [1514, 12], [1296, 13], [1312, 1], [1268, 5], [535, 2], [297, 16], [619, 15], [343, 17], [1434, 0], [656, 9], [270, 10], [313, 7], [605, 8], [737, 11], [588, 5], [218, 6], [186, 5], [719, 3], [173, 4], [1235, 12], [1059, 13], [208, 14], [910, 1], [71, 2], [143, 1], [51, 2], [242, 16], [397, 15], [410, 16], [636, 17], [1008, 0], [1310, 12], [1572, 13], [389, 14], [728, 1], [286, 2], [993, 3], 
[333, 4], [730, 5], [330, 6], [445, 16], [517, 9], [207, 10], [264, 7], [493, 8], [246, 11], [1073, 17], [1852, 15], [1306, 0], [290, 7], [215, 10], [369, 8], [366, 9], [935, 3], [890, 1], [1047, 12], [683, 13], [594, 2], [199, 16], [547, 15], [412, 16], [445, 5], [98, 6], [353, 16], [484, 17], [4022, 0], [265, 7], [197, 10], [282, 9], [625, 8], [558, 5], [732, 12], [413, 13], [607, 1], [553, 3], [1252, 2], [277, 16], [431, 17], [493, 15], [1117, 0], [1072, 1], [225, 2], [1409, 3], [320, 4], [1561, 12], [976, 13], [327, 14], [1426, 5], [267, 6], [558, 5], [175, 6], [490, 16], [211, 7], [350, 9], [211, 10], [214, 8], [456, 11], [1429, 17], [790, 0], [351, 9], [339, 8], [191, 7], [211, 10], [248, 11], [256, 9], [165, 11], [792, 5], [264, 6], [175, 5], [82, 6], [1408, 1], [220, 2], [956, 3], [192, 4], [1077, 12], [1175, 13], [253, 14], [86, 13], [178, 14], [215, 16], [522, 15], [977, 17], [1658, 0], [254, 10], [162, 7], [248, 8], [360, 9], [256, 11], [811, 1], [126, 2], [560, 3], [191, 4], [1479, 12], [362, 13], [190, 14], [392, 5], [281, 6], [425, 16], [408, 17], [509, 15], [1336, 0], [804, 12], [452, 13], [211, 14], [357, 5], [311, 6], [647, 1], [168, 2], [646, 3], [142, 4], [270, 16], [238, 10], [100, 7], [249, 8], [334, 9], [345, 11], [856, 15], [380, 17], [1920, 0], [592, 8], [411, 10], [81, 11], [218, 7], [115, 11], [458, 9], [646, 11], [1409, 12], [985, 13], [264, 14], [180, 13], [42, 14], [103, 13], [164, 14], [1165, 3], [945, 5], [1274, 1], [70, 2], [35, 4], [67, 2], [119, 6], [473, 16], [787, 17], [585, 15], [1433, 0], [503, 8], [308, 9], [261, 10], [151, 7], [379, 11], [656, 3], [167, 4], [1099, 12], [1986, 13], [869, 5], [106, 13], [1552, 1], [113, 2], [164, 14], [68, 2], [267, 6], [207, 16], [770, 15], [761, 17], [1398, 0], [943, 9], [468, 8], [419, 10], [251, 7], [418, 11], [2309, 12], [2518, 13], [499, 14], [1066, 5], [308, 6], [1564, 1], [377, 2], [1770, 3], [520, 4], [568, 16], [1180, 15], [1368, 17], [1847, 0], [589, 9], [402, 8], [237, 7], [273, 
10], [282, 11], [1041, 1], [218, 2], [1631, 5], [358, 6], [826, 3], [304, 4], [1793, 12], [882, 13], [348, 14], [519, 16], [461, 17], [532, 15], [1836, 0], [516, 9], [354, 8], [237, 7], [198, 10], [1388, 5], [276, 6], [1966, 1], [292, 2], [932, 3], [188, 4], [963, 12], [1524, 13], [944, 14], [726, 16], [545, 17], [3317, 0], [1037, 5], [229, 6], [2007, 1], [354, 2], [773, 12], [1214, 13], [735, 14], [687, 3], [321, 4], [474, 16], [504, 9], [417, 8], [170, 7], [153, 10], [432, 11], [407, 17], [3381, 0], [1250, 12], [981, 13], [807, 5], [188, 6], [20, 5], [282, 14], [847, 1], [257, 2], [945, 3], [172, 4], [314, 16], [748, 8], [1022, 9], [239, 10], [187, 7], [195, 11], [206, 15], [138, 16], [1320, 17], [2198, 0], [867, 12], [455, 13], [259, 14], [310, 5], [129, 6], [885, 1], [160, 2], [615, 3], [143, 4], [274, 16], [529, 8], [513, 9], [62, 11], [165, 10], [230, 7], [99, 11], [357, 17], [285, 15], [1583, 0], [1625, 12], [617, 13], [465, 14], [815, 3], [180, 4], [301, 5], [71, 6], [465, 5], [306, 6], [455, 1], [140, 2], [330, 16], [363, 8], [316, 9], [154, 7], [673, 10], [731, 17], [410, 15], [2599, 0], [624, 1], [113, 2], [800, 3], [145, 4], [535, 5], [182, 6], [1404, 12], [432, 13], [320, 14], [388, 16], [398, 9], [390, 8], [443, 10], [175, 7], [597, 11], [435, 17], [253, 15], [921, 0], [719, 5], [190, 6], [972, 12], [708, 13], [153, 14], [673, 3], [168, 4], [629, 1], [162, 2], [259, 16], [163, 7], [177, 10], [94, 11], [470, 9], [399, 8], [249, 11], [836, 17], [1976, 0], [404, 9], [259, 8], [105, 11], [130, 7], [282, 10], [117, 11], [653, 1], [156, 2], [817, 5], [220, 6], [522, 3], [167, 4], [509, 12], [344, 13], [180, 14], [274, 16], [1130, 15], [943, 17], [2357, 0], [628, 12], [693, 13], [1172, 5], [312, 6], [1281, 1], [1000, 3], [146, 4], [83, 2], [44, 14], [5, 16], [10, 14], [706, 16], [349, 7], [317, 10], [655, 8], [999, 9], [666, 11], [850, 17], [1526, 0], [760, 1], [732, 12], [690, 13], [698, 3], [1065, 5], [199, 6], [47, 14], [103, 2], [126, 4], [584, 16], 
[202, 7], [227, 10], [608, 8], [551, 9], [369, 11], [577, 17], [4148, 0]]\n",
      "number of gradual change points: 849\n",
      "number of abrupt change points: 49\n"
     ]
    }
   ],
   "source": [
    "def stitch_with_change_points(X, y_long, verbose=True):\n",
    "    \"\"\"Drop unlabeled timesteps and label the class change points that remain.\n",
    "\n",
    "    Class 0 in ``y_long`` marks an unlabeled timestep. Those timesteps are\n",
    "    removed, the remaining pieces are stitched together, and the surviving\n",
    "    class ids are shifted down by one. ``y_long`` itself is not modified.\n",
    "\n",
    "    Args:\n",
    "        X: array whose first axis is aligned with ``y_long``.\n",
    "        y_long: 1-D integer class labels, 0 = unlabeled.\n",
    "        verbose: when True (the default) print the ``[length, class]`` run list\n",
    "            of ``y_long`` and the change-point counts.\n",
    "\n",
    "    Returns:\n",
    "        ``(X_annot_stitched, y_annot_stitched, boundary_labels_per_ts)``.\n",
    "        ``boundary_labels_per_ts[t]`` is\n",
    "          * 1 (gradual) when the stitched class changes between ``t`` and\n",
    "            ``t + 1`` and both timesteps were adjacent in the original sequence,\n",
    "          * 2 (abrupt) when the class changes and an unlabeled gap was removed\n",
    "            between them,\n",
    "          * 0 otherwise.\n",
    "        The first and last timesteps are always 0. Two labeled runs of the same\n",
    "        class around an unlabeled gap are merged without a change point.\n",
    "    \"\"\"\n",
    "    y_long = np.asarray(y_long)\n",
    "\n",
    "    if verbose:\n",
    "        # Run-length encoding of the raw labels: [run length, class id].\n",
    "        boundary_indice = np.flatnonzero(y_long[1:] != y_long[:-1]) + 1\n",
    "        boundary_indice = np.concatenate([[0], boundary_indice, [len(y_long)]])\n",
    "        class_segment_list = [[int(j - i), int(y_long[i])]\n",
    "                              for i, j in zip(boundary_indice[:-1], boundary_indice[1:]) if j > i]\n",
    "        print(class_segment_list)\n",
    "\n",
    "    labeled = y_long != 0\n",
    "    kept_positions = np.flatnonzero(labeled)  # original index of every kept timestep\n",
    "    X_annot_stitched = X[labeled]\n",
    "    y_annot_stitched = y_long[labeled]  # boolean indexing copies, so y_long stays untouched\n",
    "    y_annot_stitched -= 1\n",
    "    boundary_labels_per_ts = np.zeros_like(y_annot_stitched)\n",
    "\n",
    "    # A change point sits on the last timestep before the stitched class changes.\n",
    "    change_at = np.flatnonzero(y_annot_stitched[1:] != y_annot_stitched[:-1])\n",
    "    # Stitched neighbours that were not neighbours originally had an unlabeled gap removed between them.\n",
    "    crossed_gap = (kept_positions[change_at + 1] - kept_positions[change_at]) > 1\n",
    "    boundary_labels_per_ts[change_at] = np.where(crossed_gap, 2, 1)\n",
    "\n",
    "    if boundary_labels_per_ts.size:  # nothing to reset when every timestep was unlabeled\n",
    "        boundary_labels_per_ts[0]=0 # no change point at the start timestamp\n",
    "        boundary_labels_per_ts[-1]=0 # no change point at the last timestamp\n",
    "\n",
    "    if verbose:\n",
    "        print(f\"number of gradual change points: {np.sum(boundary_labels_per_ts==1)}\\nnumber of abrupt change points: {np.sum(boundary_labels_per_ts==2)}\")\n",
    "    return X_annot_stitched, y_annot_stitched, boundary_labels_per_ts\n",
    "# Stitch the labeled timesteps of the long sequence and derive their change-point labels.\n",
    "X_annot_stitched, y_annot_stitched, boundary_labels_per_ts = stitch_with_change_points(X_long, y_long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler, RobustScaler\n",
    "# Standardize the stitched features; the scaler is fit on the full stitched array.\n",
    "# NOTE(review): RobustScaler is imported but not used in the visible cells.\n",
    "scaler = StandardScaler()\n",
    "X_annot_stitched = scaler.fit_transform(X_annot_stitched)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "49\n",
      "(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16],\n",
      "      dtype=int32), array([61363, 14293, 47458, 10495, 47137, 13668, 11309, 22871, 27383,\n",
      "       13500, 20784, 61567, 48413, 15777, 21230, 24853, 34149]))\n"
     ]
    }
   ],
   "source": [
    "# Sanity checks: number of abrupt (label 2) change points, then per-class counts of the stitched labels.\n",
    "print(np.sum(boundary_labels_per_ts==2))\n",
    "print(np.unique(y_annot_stitched, return_counts=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the leading DATA_RATIO fraction of the stitched sequence and write the three aligned arrays\n",
    "# (features, class labels, change-point labels) into the working directory.\n",
    "data_length = int(DATA_RATIO*len(X_annot_stitched))\n",
    "np.save(f\"{DATA_NAME}_X_long.npy\", X_annot_stitched[:data_length])\n",
    "np.save(f\"{DATA_NAME}_y_long.npy\", y_annot_stitched[:data_length])\n",
    "np.save(f\"{DATA_NAME}_cp_long.npy\", boundary_labels_per_ts[:data_length]) # gradual change points: 1 / abrupt change points: 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(496250, 2048) (496250,) (496250,)\n",
      "(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16],\n",
      "      dtype=int32), array([61363, 14293, 47458, 10495, 47137, 13668, 11309, 22871, 27383,\n",
      "       13500, 20784, 61567, 48413, 15777, 21230, 24853, 34149]))\n"
     ]
    }
   ],
   "source": [
    "# Round-trip check: reload the three saved arrays and print their shapes and the per-class label counts.\n",
    "X_load = np.load(f\"{DATA_NAME}_X_long.npy\")\n",
    "y_load = np.load(f\"{DATA_NAME}_y_long.npy\")\n",
    "boundary_load = np.load(f\"{DATA_NAME}_cp_long.npy\")\n",
    "print(X_load.shape, y_load.shape,boundary_load.shape)\n",
    "print(np.unique(y_load, return_counts=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
