{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 128 datasets of the UCR time-series classification archive (2018 release), in the\n",
    "# order they are processed below; split() turns the whitespace-separated block into a list.\n",
    "dataset_name_list_128 = '''\n",
    "ACSF1 Adiac AllGestureWiimoteX AllGestureWiimoteY AllGestureWiimoteZ ArrowHead\n",
    "Beef BeetleFly BirdChicken BME\n",
    "Car CBF Chinatown ChlorineConcentration CinCECGTorso Coffee\n",
    "Computers CricketX CricketY CricketZ Crop\n",
    "DiatomSizeReduction DistalPhalanxOutlineAgeGroup DistalPhalanxOutlineCorrect DistalPhalanxTW\n",
    "DodgerLoopDay DodgerLoopGame DodgerLoopWeekend\n",
    "Earthquakes ECG200 ECG5000 ECGFiveDays ElectricDevices\n",
    "EOGHorizontalSignal EOGVerticalSignal EthanolLevel\n",
    "FaceAll FaceFour FacesUCR FiftyWords Fish FordA FordB\n",
    "FreezerRegularTrain FreezerSmallTrain Fungi\n",
    "GestureMidAirD1 GestureMidAirD2 GestureMidAirD3 GesturePebbleZ1 GesturePebbleZ2\n",
    "GunPoint GunPointAgeSpan GunPointMaleVersusFemale GunPointOldVersusYoung\n",
    "Ham HandOutlines Haptics Herring HouseTwenty\n",
    "InlineSkate InsectEPGRegularTrain InsectEPGSmallTrain InsectWingbeatSound ItalyPowerDemand\n",
    "LargeKitchenAppliances Lightning2 Lightning7\n",
    "Mallat Meat MedicalImages MelbournePedestrian\n",
    "MiddlePhalanxOutlineAgeGroup MiddlePhalanxOutlineCorrect MiddlePhalanxTW\n",
    "MixedShapesRegularTrain MixedShapesSmallTrain MoteStrain\n",
    "NonInvasiveFetalECGThorax1 NonInvasiveFetalECGThorax2\n",
    "OliveOil OSULeaf\n",
    "PhalangesOutlinesCorrect Phoneme PickupGestureWiimoteZ PigAirwayPressure PigArtPressure PigCVP\n",
    "PLAID Plane PowerCons ProximalPhalanxOutlineAgeGroup ProximalPhalanxOutlineCorrect ProximalPhalanxTW\n",
    "RefrigerationDevices Rock\n",
    "ScreenType SemgHandGenderCh2 SemgHandMovementCh2 SemgHandSubjectCh2 ShakeGestureWiimoteZ\n",
    "ShapeletSim ShapesAll SmallKitchenAppliances SmoothSubspace\n",
    "SonyAIBORobotSurface1 SonyAIBORobotSurface2 StarLightCurves Strawberry\n",
    "SwedishLeaf Symbols SyntheticControl\n",
    "ToeSegmentation1 ToeSegmentation2 Trace TwoLeadECG TwoPatterns\n",
    "UMD UWaveGestureLibraryAll UWaveGestureLibraryX UWaveGestureLibraryY UWaveGestureLibraryZ\n",
    "Wafer Wine WordSynonyms Worms WormsTwoClass\n",
    "Yoga\n",
    "'''.split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACSF1 finish\n",
      "Adiac finish\n",
      "AllGestureWiimoteX finish\n",
      "AllGestureWiimoteY finish\n",
      "AllGestureWiimoteZ finish\n",
      "ArrowHead finish\n",
      "Beef finish\n",
      "BeetleFly finish\n",
      "BirdChicken finish\n",
      "BME finish\n",
      "Car finish\n",
      "CBF finish\n",
      "Chinatown finish\n",
      "ChlorineConcentration finish\n",
      "CinCECGTorso finish\n",
      "Coffee finish\n",
      "Computers finish\n",
      "CricketX finish\n",
      "CricketY finish\n",
      "CricketZ finish\n",
      "Crop finish\n",
      "DiatomSizeReduction finish\n",
      "DistalPhalanxOutlineAgeGroup finish\n",
      "DistalPhalanxOutlineCorrect finish\n",
      "DistalPhalanxTW finish\n",
      "DodgerLoopDay finish\n",
      "DodgerLoopGame finish\n",
      "DodgerLoopWeekend finish\n",
      "Earthquakes finish\n",
      "ECG200 finish\n",
      "ECG5000 finish\n",
      "ECGFiveDays finish\n",
      "ElectricDevices finish\n",
      "EOGHorizontalSignal finish\n",
      "EOGVerticalSignal finish\n",
      "EthanolLevel finish\n",
      "FaceAll finish\n",
      "FaceFour finish\n",
      "FacesUCR finish\n",
      "FiftyWords finish\n",
      "Fish finish\n",
      "FordA finish\n",
      "FordB finish\n",
      "FreezerRegularTrain finish\n",
      "FreezerSmallTrain finish\n",
      "Fungi finish\n",
      "GestureMidAirD1 finish\n",
      "GestureMidAirD2 finish\n",
      "GestureMidAirD3 finish\n",
      "GesturePebbleZ1 finish\n",
      "GesturePebbleZ2 finish\n",
      "GunPoint finish\n",
      "GunPointAgeSpan finish\n",
      "GunPointMaleVersusFemale finish\n",
      "GunPointOldVersusYoung finish\n",
      "Ham finish\n",
      "HandOutlines finish\n",
      "Haptics finish\n",
      "Herring finish\n",
      "HouseTwenty finish\n",
      "InlineSkate finish\n",
      "InsectEPGRegularTrain finish\n",
      "InsectEPGSmallTrain finish\n",
      "InsectWingbeatSound finish\n",
      "ItalyPowerDemand finish\n",
      "LargeKitchenAppliances finish\n",
      "Lightning2 finish\n",
      "Lightning7 finish\n",
      "Mallat finish\n",
      "Meat finish\n",
      "MedicalImages finish\n",
      "MelbournePedestrian finish\n",
      "MiddlePhalanxOutlineAgeGroup finish\n",
      "MiddlePhalanxOutlineCorrect finish\n",
      "MiddlePhalanxTW finish\n",
      "MixedShapesRegularTrain finish\n",
      "MixedShapesSmallTrain finish\n",
      "MoteStrain finish\n",
      "NonInvasiveFetalECGThorax1 finish\n",
      "NonInvasiveFetalECGThorax2 finish\n",
      "OliveOil finish\n",
      "OSULeaf finish\n",
      "PhalangesOutlinesCorrect finish\n",
      "Phoneme finish\n",
      "PickupGestureWiimoteZ finish\n",
      "PigAirwayPressure finish\n",
      "PigArtPressure finish\n",
      "PigCVP finish\n",
      "PLAID finish\n",
      "Plane finish\n",
      "PowerCons finish\n",
      "ProximalPhalanxOutlineAgeGroup finish\n",
      "ProximalPhalanxOutlineCorrect finish\n",
      "ProximalPhalanxTW finish\n",
      "RefrigerationDevices finish\n",
      "Rock finish\n",
      "ScreenType finish\n",
      "SemgHandGenderCh2 finish\n",
      "SemgHandMovementCh2 finish\n",
      "SemgHandSubjectCh2 finish\n",
      "ShakeGestureWiimoteZ finish\n",
      "ShapeletSim finish\n",
      "ShapesAll finish\n",
      "SmallKitchenAppliances finish\n",
      "SmoothSubspace finish\n",
      "SonyAIBORobotSurface1 finish\n",
      "SonyAIBORobotSurface2 finish\n",
      "StarLightCurves finish\n",
      "Strawberry finish\n",
      "SwedishLeaf finish\n",
      "Symbols finish\n",
      "SyntheticControl finish\n",
      "ToeSegmentation1 finish\n",
      "ToeSegmentation2 finish\n",
      "Trace finish\n",
      "TwoLeadECG finish\n",
      "TwoPatterns finish\n",
      "UMD finish\n",
      "UWaveGestureLibraryAll finish\n",
      "UWaveGestureLibraryX finish\n",
      "UWaveGestureLibraryY finish\n",
      "UWaveGestureLibraryZ finish\n",
      "Wafer finish\n",
      "Wine finish\n",
      "WordSynonyms finish\n",
      "Worms finish\n",
      "WormsTwoClass finish\n",
      "Yoga finish\n"
     ]
    }
   ],
   "source": [
    "import _ucrdtw\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "from os.path import dirname\n",
    "from sklearn import preprocessing\n",
    "from sklearn.preprocessing import minmax_scale\n",
    "from sklearn.metrics import accuracy_score\n",
    "import os\n",
    "\n",
    "\n",
    "def set_nan_to_zero(a):\n",
    "    \"\"\"Overwrite every NaN entry of ``a`` with 0, in place, and return the same array.\"\"\"\n",
    "    a[np.isnan(a)] = 0\n",
    "    return a\n",
    "\n",
    "def _load_ucr_split(dataset_path, dataset_name, split):\n",
    "    \"\"\"Read one UCR split file and return ``(features, labels)`` as float32 arrays.\n",
    "\n",
    "    The file is ``<dataset_path>/<dataset_name>/<dataset_name>_<split>.tsv``; column 0\n",
    "    holds the class label and the remaining columns hold the series values.\n",
    "    \"\"\"\n",
    "    split_path = os.path.join(dataset_path, dataset_name, dataset_name + '_' + split + '.tsv')\n",
    "    # ndmin=2 keeps a single-row file two-dimensional so the column slicing still works.\n",
    "    table = np.loadtxt(split_path, ndmin=2).astype(np.float32)\n",
    "    return table[:, 1:], table[:, 0]\n",
    "\n",
    "def TSC_data_loader(dataset_path,dataset_name):\n",
    "    \"\"\"Load the TRAIN and TEST splits of one UCR dataset.\n",
    "\n",
    "    Returns ``(X_train, y_train, X_test, y_test)``: float32 feature matrices with NaN\n",
    "    replaced by 0 (via ``set_nan_to_zero``) and integer class labels produced by a\n",
    "    ``LabelEncoder``.\n",
    "    \"\"\"\n",
    "    X_train, raw_y_train = _load_ucr_split(dataset_path, dataset_name, 'TRAIN')\n",
    "    X_test, raw_y_test = _load_ucr_split(dataset_path, dataset_name, 'TEST')\n",
    "\n",
    "    # Fit on the labels of both splits so a class that only appears in TEST does not make\n",
    "    # transform() raise. When TEST labels are a subset of TRAIN labels (the usual case),\n",
    "    # the encoding is identical to fitting on TRAIN alone.\n",
    "    le = preprocessing.LabelEncoder()\n",
    "    le.fit(np.concatenate([raw_y_train, raw_y_test]))\n",
    "    y_train = le.transform(raw_y_train)\n",
    "    y_test = le.transform(raw_y_test)\n",
    "    return set_nan_to_zero(X_train), y_train, set_nan_to_zero(X_test), y_test\n",
    "\n",
    "def get_big_matrix(query_list, data_list, warp_width=0.05):\n",
    "    \"\"\"Compute the UCR-DTW distance between every query series and every data series.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    query_list : array whose rows are the query series (rows of the result).\n",
    "    data_list : array whose rows are the series searched (columns of the result).\n",
    "    warp_width : warping-window argument forwarded to ``_ucrdtw.ucrdtw``; the default\n",
    "        0.05 is the value this notebook previously hard-coded.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray of shape ``(len(query_list), len(data_list))``, where entry ``[i, j]``\n",
    "    is the distance returned for query ``i`` against data series ``j``.\n",
    "    \"\"\"\n",
    "    result_matrix = np.zeros((query_list.shape[0], data_list.shape[0]))\n",
    "    for i, query in enumerate(query_list):\n",
    "        for j, data in enumerate(data_list):\n",
    "            # Only the distance is needed; the match location (first value) is discarded.\n",
    "            # NOTE(review): the trailing True is passed through unchanged; confirm its\n",
    "            # meaning (it may enable verbose output) in the _ucrdtw documentation.\n",
    "            _, dist = _ucrdtw.ucrdtw(data, query, warp_width, True)\n",
    "            result_matrix[i, j] = dist\n",
    "    return result_matrix\n",
    "\n",
    "def predict_with_matrix(matrix,label):\n",
    "    \"\"\"1-nearest-neighbour prediction from a distance matrix.\n",
    "\n",
    "    For each row of ``matrix``, take the column with the smallest distance (the first\n",
    "    one on ties, as ``np.argmin`` does) and return ``label`` indexed by those columns.\n",
    "    \"\"\"\n",
    "    nearest_columns = np.argmin(matrix, axis=-1)\n",
    "    return label[nearest_columns]\n",
    "\n",
    "\n",
    "# Inputs: UCR 2018 archive root. Outputs: one .npy file of DTW distance matrices per dataset.\n",
    "dataset_path = './Example_Datasets/UCRArchive_2018'\n",
    "distance_matrix_log_folder = './distance_matrix_of_DTW'\n",
    "# The loop below runs DTW on every (query, data) pair of every dataset and can take a long\n",
    "# time; datasets whose .npy file already exists are skipped unless RECOMPUTE is True.\n",
    "RECOMPUTE = False\n",
    "\n",
    "os.makedirs(distance_matrix_log_folder, exist_ok=True)\n",
    "\n",
    "for dataset_name in dataset_name_list_128:\n",
    "    save_dir = os.path.join(distance_matrix_log_folder, dataset_name)\n",
    "    save_log_path = os.path.join(save_dir, dataset_name + '.npy')\n",
    "    if os.path.exists(save_log_path) and not RECOMPUTE:\n",
    "        print(dataset_name, 'cached')\n",
    "        continue\n",
    "\n",
    "    X_train, y_train, X_test, y_test = TSC_data_loader(dataset_path, dataset_name)\n",
    "\n",
    "    # Rows are test series, columns are training series.\n",
    "    test_train_result_matrix = get_big_matrix(X_test, X_train)\n",
    "    # Rows and columns are both training series.\n",
    "    train_train_result_matrix = get_big_matrix(X_train, X_train)\n",
    "\n",
    "    result_dict = {\n",
    "        'train_train': train_train_result_matrix,\n",
    "        'test_train': test_train_result_matrix,\n",
    "    }\n",
    "\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "    # np.save stores the dict as a pickled 0-d object array; read it back with\n",
    "    # np.load(save_log_path, allow_pickle=True).item().\n",
    "    np.save(save_log_path, result_dict)\n",
    "    print(dataset_name,'finish')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(100, 100)\n",
      "(100, 100)\n"
     ]
    }
   ],
   "source": [
    "dataset_name = 'ACSF1'\n",
    "# The .npy file holds a dict written by np.save (a pickled object array); NumPy >= 1.16.3\n",
    "# refuses to load pickled data unless allow_pickle=True. Only load files produced here.\n",
    "result = np.load(os.path.join(distance_matrix_log_folder, dataset_name, dataset_name + '.npy'), allow_pickle=True)\n",
    "print(result.item().get('train_train').shape)\n",
    "print(result.item().get('test_train').shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(100,)\n",
      "0.64\n"
     ]
    }
   ],
   "source": [
    "X_train, y_train, X_test, y_test = TSC_data_loader(dataset_path, dataset_name)\n",
    "test_train_matrix = result.item().get('test_train')\n",
    "# 1-NN: each test series takes the label of its DTW-nearest training series.\n",
    "y_predict = predict_with_matrix(test_train_matrix, y_train)\n",
    "print(y_predict.shape)\n",
    "# accuracy_score expects (y_true, y_pred).\n",
    "acc = accuracy_score(y_test, y_predict)\n",
    "print(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACSF1 0.64\n",
      "Adiac 0.6061381074168798\n",
      "AllGestureWiimoteX 0.7271428571428571\n",
      "AllGestureWiimoteY 0.7442857142857143\n",
      "AllGestureWiimoteZ 0.6428571428571429\n",
      "ArrowHead 0.7314285714285714\n",
      "Beef 0.6666666666666666\n",
      "BeetleFly 0.7\n",
      "BirdChicken 0.7\n",
      "BME 0.98\n",
      "Car 0.7333333333333333\n",
      "CBF 0.98\n",
      "Chinatown 0.9565217391304348\n",
      "ChlorineConcentration 0.6484375\n",
      "CinCECGTorso 0.8260869565217391\n",
      "Coffee 1.0\n",
      "Computers 0.52\n",
      "CricketX 0.7564102564102564\n",
      "CricketY 0.7512820512820513\n",
      "CricketZ 0.7461538461538462\n",
      "Crop 0.6904166666666667\n",
      "DiatomSizeReduction 0.9640522875816994\n",
      "DistalPhalanxOutlineAgeGroup 0.7697841726618705\n",
      "DistalPhalanxOutlineCorrect 0.7101449275362319\n",
      "DistalPhalanxTW 0.6115107913669064\n",
      "DodgerLoopDay 0.5125\n",
      "DodgerLoopGame 0.8768115942028986\n",
      "DodgerLoopWeekend 0.9710144927536232\n",
      "Earthquakes 0.6906474820143885\n",
      "ECG200 0.9\n",
      "ECG5000 0.9242222222222222\n",
      "ECGFiveDays 0.8164924506387921\n",
      "ElectricDevices 0.6248216833095578\n",
      "EOGHorizontalSignal 0.4972375690607735\n",
      "EOGVerticalSignal 0.5248618784530387\n",
      "EthanolLevel 0.274\n",
      "FaceAll 0.8011834319526627\n",
      "FaceFour 0.875\n",
      "FacesUCR 0.9219512195121952\n",
      "FiftyWords 0.756043956043956\n",
      "Fish 0.8457142857142858\n",
      "FordA 0.5848484848484848\n",
      "FordB 0.6074074074074074\n",
      "FreezerRegularTrain 0.9049122807017543\n",
      "FreezerSmallTrain 0.7568421052631579\n",
      "Fungi 0.9354838709677419\n",
      "GestureMidAirD1 0.6384615384615384\n",
      "GestureMidAirD2 0.6\n",
      "GestureMidAirD3 0.36923076923076925\n",
      "GesturePebbleZ1 0.813953488372093\n",
      "GesturePebbleZ2 0.7848101265822784\n",
      "GunPoint 0.9733333333333334\n",
      "GunPointAgeSpan 0.9588607594936709\n",
      "GunPointMaleVersusFemale 0.9936708860759493\n",
      "GunPointOldVersusYoung 0.9650793650793651\n",
      "Ham 0.49523809523809526\n",
      "HandOutlines 0.8810810810810811\n",
      "Haptics 0.42857142857142855\n",
      "Herring 0.53125\n",
      "HouseTwenty 0.8403361344537815\n",
      "InlineSkate 0.4163636363636364\n",
      "InsectEPGRegularTrain 0.7429718875502008\n",
      "InsectEPGSmallTrain 0.7188755020080321\n",
      "InsectWingbeatSound 0.545959595959596\n",
      "ItalyPowerDemand 0.9543245869776482\n",
      "LargeKitchenAppliances 0.704\n",
      "Lightning2 0.8524590163934426\n",
      "Lightning7 0.7123287671232876\n",
      "Mallat 0.9364605543710022\n",
      "Meat 0.9333333333333333\n",
      "MedicalImages 0.7368421052631579\n",
      "MelbournePedestrian 0.8114285714285714\n",
      "MiddlePhalanxOutlineAgeGroup 0.487012987012987\n",
      "MiddlePhalanxOutlineCorrect 0.7044673539518901\n",
      "MiddlePhalanxTW 0.5064935064935064\n",
      "MixedShapesRegularTrain 0.9014432989690722\n",
      "MixedShapesSmallTrain 0.8412371134020619\n",
      "MoteStrain 0.8642172523961661\n",
      "NonInvasiveFetalECGThorax1 0.7964376590330788\n",
      "NonInvasiveFetalECGThorax2 0.8641221374045801\n",
      "OliveOil 0.8333333333333334\n",
      "OSULeaf 0.5785123966942148\n",
      "PhalangesOutlinesCorrect 0.7249417249417249\n",
      "Phoneme 0.19251054852320676\n",
      "PickupGestureWiimoteZ 0.7\n",
      "PigAirwayPressure 0.09615384615384616\n",
      "PigArtPressure 0.25\n",
      "PigCVP 0.11538461538461539\n",
      "PLAID 0.8063314711359404\n",
      "Plane 1.0\n",
      "PowerCons 0.9111111111111111\n",
      "ProximalPhalanxOutlineAgeGroup 0.8048780487804879\n",
      "ProximalPhalanxOutlineCorrect 0.7835051546391752\n",
      "ProximalPhalanxTW 0.7560975609756098\n",
      "RefrigerationDevices 0.456\n",
      "Rock 0.66\n",
      "ScreenType 0.38133333333333336\n",
      "SemgHandGenderCh2 0.8283333333333334\n",
      "SemgHandMovementCh2 0.6022222222222222\n",
      "SemgHandSubjectCh2 0.78\n",
      "ShakeGestureWiimoteZ 0.84\n",
      "ShapeletSim 0.6833333333333333\n",
      "ShapesAll 0.8066666666666666\n",
      "SmallKitchenAppliances 0.6453333333333333\n",
      "SmoothSubspace 0.9066666666666666\n",
      "SonyAIBORobotSurface1 0.7337770382695508\n",
      "SonyAIBORobotSurface2 0.8436516264428122\n",
      "StarLightCurves 0.8977659057795047\n",
      "Strawberry 0.9378378378378378\n",
      "SwedishLeaf 0.8176\n",
      "Symbols 0.9346733668341709\n",
      "SyntheticControl 0.9866666666666667\n",
      "ToeSegmentation1 0.7587719298245614\n",
      "ToeSegmentation2 0.9076923076923077\n",
      "Trace 0.99\n",
      "TwoLeadECG 0.8683055311676909\n",
      "TwoPatterns 0.9985\n",
      "UMD 0.9652777777777778\n",
      "UWaveGestureLibraryAll 0.9667783361250698\n",
      "UWaveGestureLibraryX 0.7780569514237856\n",
      "UWaveGestureLibraryY 0.7018425460636516\n",
      "UWaveGestureLibraryZ 0.6747627024008933\n",
      "Wafer 0.9948085658663206\n",
      "Wine 0.5740740740740741\n",
      "WordSynonyms 0.7304075235109718\n",
      "Worms 0.5584415584415584\n",
      "WormsTwoClass 0.6103896103896104\n",
      "Yoga 0.8446666666666667\n"
     ]
    }
   ],
   "source": [
    "# 1-NN DTW test accuracy for every dataset, read back from the saved distance matrices.\n",
    "accuracy_by_dataset = {}\n",
    "for dataset_name in dataset_name_list_128:\n",
    "    X_train, y_train, X_test, y_test = TSC_data_loader(dataset_path, dataset_name)\n",
    "    matrix_path = os.path.join(distance_matrix_log_folder, dataset_name, dataset_name + '.npy')\n",
    "    # The file holds a dict saved with np.save (a pickled object array), so pickle loading\n",
    "    # must be enabled explicitly; only load files this notebook produced.\n",
    "    result = np.load(matrix_path, allow_pickle=True)\n",
    "    test_train_matrix = result.item().get('test_train')\n",
    "    # Each test series takes the label of its nearest training series.\n",
    "    y_predict = predict_with_matrix(test_train_matrix, y_train)\n",
    "\n",
    "    # accuracy_score expects (y_true, y_pred).\n",
    "    acc = accuracy_score(y_test, y_predict)\n",
    "    accuracy_by_dataset[dataset_name] = acc\n",
    "    print(dataset_name,acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
