{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "running at: FiftyWords\n",
      "train data shape (450, 270)\n",
      "train label shape (450,)\n",
      "test data shape (455, 270)\n",
      "test label shape (455,)\n",
      "unique train label [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
      " 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47\n",
      " 48 49]\n",
      "unique test label [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
      " 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47\n",
      " 48 49]\n",
      "code is running on  cuda:0\n",
      "epoch = 49 lr =  0.001\n",
      "train_acc=\t 0.9288888888888889 \t test_acc=\t 0.7186813186813187 \t loss=\t 2.0506303310394287\n",
      "log saved at:\n",
      "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
      "epoch = 99 lr =  0.001\n",
      "train_acc=\t 0.9977777777777778 \t test_acc=\t 0.7758241758241758 \t loss=\t 1.0059599876403809\n",
      "log saved at:\n",
      "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
      "epoch = 149 lr =  0.001\n",
      "train_acc=\t 0.9977777777777778 \t test_acc=\t 0.7626373626373626 \t loss=\t 0.9938879013061523\n",
      "log saved at:\n",
      "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
      "epoch = 199 lr =  0.0005\n",
      "train_acc=\t 1.0 \t test_acc=\t 0.789010989010989 \t loss=\t 0.062162842601537704\n",
      "log saved at:\n",
      "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
      "epoch = 249 lr =  0.00025\n",
      "train_acc=\t 1.0 \t test_acc=\t 0.8153846153846154 \t loss=\t 0.06960023194551468\n",
      "log saved at:\n",
      "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
      "epoch = 299 lr =  0.00025\n",
      "train_acc=\t 1.0 \t test_acc=\t 0.8087912087912088 \t loss=\t 0.3699897527694702\n",
      "log saved at:\n",
      "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
      "epoch = 349 lr =  0.000125\n",
      "train_acc=\t 1.0 \t test_acc=\t 0.8 \t loss=\t 0.35558873414993286\n",
      "log saved at:\n",
      "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
      "epoch = 399 lr =  0.0001\n",
      "train_acc=\t 1.0 \t test_acc=\t 0.8065934065934066 \t loss=\t 0.018871668726205826\n",
      "log saved at:\n",
      "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
      "epoch = 449 lr =  0.0001\n",
      "train_acc=\t 1.0 \t test_acc=\t 0.8021978021978022 \t loss=\t 0.04294228553771973\n",
      "log saved at:\n",
      "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
      "epoch = 499 lr =  0.0001\n",
      "train_acc=\t 1.0 \t test_acc=\t 0.8065934065934066 \t loss=\t 0.07842594385147095\n",
      "log saved at:\n",
      "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
      "correct: [ 3 11 12 22  3 12 26  0 21  0  6  8 42 12  7  2  1  5 23  8  5  0 14  2\n",
      "  7 10 15  9  1  1  4  1  2  3  0  0  5  2  0  1 32  0  5  8  3 41 29 13\n",
      " 28 47 39 40  7 13 44 11 15  0  7  6 33  0  0  2 28 25 30  1  5  0  0  4\n",
      " 41  0  4 21 23  0 27 16  2 30  3  0  4  3  4  0 29  2  0 18  0  1  9  9\n",
      "  1 16 33  0  2  4  4  4  1 24  7 40 38  1 23 17  8  5  0  9 18  1  3  4\n",
      " 20  2  2  4  1  0  1  1  5  1 23 14  2 48  8 21 14 10 23 16  5  3  0  0\n",
      " 20  1 36 11  3 40 17 34  1  0  2  5  1  9  1 30  4  4 15 44  2  1  8 45\n",
      "  6  0  4 15  0 40 13  6  5 10 21  2 25  6  5 37  1  3 10  9  8  6 48 16\n",
      " 48 27  3  1 24 22  5  2  6  0 19  0 49  3  1  2  3 13 28 41  0 10 37 39\n",
      " 15 45 15  2 12 19 11 31 10 16 31  2  7  0 37  5 39  5  9 34  5  1 48  1\n",
      "  0  6  2  1  0  1 22 19  3  3  2 17 28 32  2 27 14  7  1  8 13  9  2 10\n",
      "  4  5  6  4  8 38 24  8 41 10  7 26 44 47 46 38  2 45 36  6  1  1 30 13\n",
      " 46 32 12  3 40 24 11  3 35 11  1  7  0 18  0  1  1 26  3  3  2 19 25  0\n",
      " 28 22 21  0  1 43  4  3  0 10  0  4  0  0 16  2  4 26  0  6  8 38  0  5\n",
      "  3 30 24 49 15 11 30  1  0  1  3  8 41 33  0  9 14  3 13 24 16 12 40 14\n",
      "  0 11  3  4  3  5  3 11  3  9  0  0  0 10  1 10 18  4 26  7  7  4  3  2\n",
      "  3 12  0 21 17  6  0 26  3 23 20 13  9  0 13 14 24 19  4 49  2 20 42  4\n",
      "  2  5  0  7  3  4 19 36 33 37 14 17 25  4  0  1  3  0  1  6  6 35 39  6\n",
      " 49 46  0 13  9 29  1  8  0  1 33  3  6 20  0  7 42  1 10  3 24 14 15]\n",
      "predict: [ 3. 11. 12. 22.  3. 12. 31.  0. 21.  0. 11.  8.  6. 12.  7.  1.  1.  5.\n",
      " 23.  8.  5.  0. 14.  2.  7.  3. 15.  9.  1.  1.  4.  1.  2.  3.  0.  0.\n",
      "  5.  2.  0.  1. 32.  0.  5.  8.  3. 41. 29. 12. 28. 15.  7. 32.  7. 12.\n",
      " 13. 11. 15.  0.  7.  2. 17.  0.  0.  2. 28.  2. 30.  1.  5.  0.  0.  4.\n",
      " 41.  0.  4. 21. 23.  0.  5. 16.  2. 12.  3.  0.  4.  3.  4.  0. 29.  2.\n",
      "  0. 18.  0.  1.  9. 21.  1.  3. 17.  0.  2.  4.  4.  4.  1. 29.  7.  1.\n",
      " 38.  1. 23.  9.  8. 15.  0.  0. 18.  1.  3.  4. 20.  2.  2.  4.  1.  0.\n",
      "  1.  1.  5.  1. 23. 14.  2. 37.  1.  5. 14. 10.  2. 16.  5.  3.  0.  0.\n",
      " 11.  1. 36. 11.  3. 43. 17.  4.  1.  0.  2.  5.  1.  9.  1. 30.  4.  4.\n",
      " 15. 44.  2.  1.  8. 45.  6. 32.  4. 15.  0. 35. 13.  6.  5.  0. 28.  2.\n",
      " 25.  6.  5. 37.  1.  3. 10.  9.  8.  6.  2. 11. 36. 27.  3.  1. 24. 22.\n",
      "  5.  2.  6.  0. 34.  0.  7.  3.  1.  2. 32. 13. 28. 24.  0. 10. 37. 19.\n",
      "  2. 45. 15.  2. 45. 19. 11. 31.  3.  3. 31.  2.  7.  0. 37.  5. 42.  5.\n",
      "  9. 18.  5.  1. 48.  1.  0.  6. 37.  1.  0.  1. 22. 19.  1.  3.  2. 17.\n",
      "  2. 32.  2.  5. 14.  7.  1.  8. 32.  9.  2. 10.  4.  5.  6.  4.  8. 38.\n",
      " 24.  8. 22.  3.  7. 26. 44.  7. 46. 38.  2. 45. 36.  6.  1.  1. 12. 13.\n",
      " 46. 32. 12.  3. 43. 24. 11.  3. 35. 11.  4.  7.  0. 18.  0.  1.  1. 31.\n",
      "  3.  3.  2. 19. 25.  0. 28. 22. 21.  0.  1.  0.  4.  3.  0. 10.  0.  4.\n",
      "  0.  0. 16.  2.  4. 26.  0.  6.  8. 41.  9.  5.  1. 45. 24.  1. 15. 11.\n",
      " 21.  1.  0.  1.  3.  8. 41. 33.  0. 21. 14. 39. 12. 24.  4. 12.  6. 14.\n",
      "  0. 11. 14.  4.  3.  5.  3. 11.  3.  9.  0.  0.  0. 10.  1.  3. 18.  4.\n",
      " 26.  7.  7.  4.  3.  2.  3. 12.  0.  1. 25.  6.  0. 26.  3. 23. 20. 12.\n",
      " 21.  0. 17. 14. 29. 19.  4. 39.  2.  9. 42.  4.  2.  5.  0.  7.  3.  4.\n",
      " 19. 36. 33. 37. 14. 17. 25.  4.  0.  1.  3.  0.  1.  6.  6. 35. 11.  6.\n",
      " 39. 46.  0. 13.  9. 28.  1.  8.  0.  1. 33.  3.  6. 20.  0.  7. 28.  1.\n",
      "  3.  3. 24. 14. 15.]\n",
      "0.8065934065934066\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from os.path import dirname\n",
    "from utils.dataloader.TSC_data_loader import TSC_data_loader\n",
    "from Classifiers.OS_CNN.OS_CNN_easy_use import OS_CNN_easy_use\n",
    "from sklearn.metrics import accuracy_score\n",
    "import numpy as np\n",
    "\n",
    "# Root folder handed to OS_CNN_easy_use for its training log; the recorded run\n",
    "# reports saving to <Result_log_folder>/<dataset_name>/<dataset_name>_.txt.\n",
    "Result_log_folder = './Example_Results_of_OS_CNN/OS_CNN_result_iter_0/'\n",
    "# dirname() on a path that ends with '/' just drops that trailing separator.\n",
    "dataset_path = dirname(\"./Example_Datasets/UCRArchive_2018/\")\n",
    "\n",
    "# Datasets to train and evaluate on; the loop below runs once per entry.\n",
    "dataset_name_list = [\n",
    "    \"FiftyWords\",\n",
    "]\n",
    "\n",
    "for dataset_name in dataset_name_list:\n",
    "    print('running at:', dataset_name)\n",
    "\n",
    "    # Load the train/test splits of this dataset and print a quick sanity summary.\n",
    "    X_train, y_train, X_test, y_test = TSC_data_loader(dataset_path, dataset_name)\n",
    "    print('train data shape', X_train.shape)\n",
    "    print('train label shape', y_train.shape)\n",
    "    print('test data shape', X_test.shape)\n",
    "    print('test label shape', y_test.shape)\n",
    "    print('unique train label', np.unique(y_train))\n",
    "    print('unique test label', np.unique(y_test))\n",
    "\n",
    "    # Create the model and tell it where to save its log.\n",
    "    model = OS_CNN_easy_use(\n",
    "        Result_log_folder = Result_log_folder,  # root folder for the result log\n",
    "        dataset_name = dataset_name,            # used to create the log under Result_log_folder\n",
    "        # NOTE(review): device is hard-coded to the first GPU; confirm how the library\n",
    "        # behaves on a machine without CUDA.\n",
    "        device = \"cuda:0\",\n",
    "        # The original experiments used 2000 epochs to stay comparable with FCN;\n",
    "        # 500 is enough for this example dataset.\n",
    "        max_epoch = 500,\n",
    "        # Keyword spelling ('paramenter') is the one this call was executed with;\n",
    "        # do not correct it here without changing the library as well.\n",
    "        paramenter_number_of_layer_list = [8*128, 5*128*256 + 2*256*128],\n",
    "        batch_size = 16,\n",
    "        lr = 0.001\n",
    "        )\n",
    "\n",
    "    # The test split is passed to fit as well; the recorded log prints test_acc\n",
    "    # alongside train_acc during training.\n",
    "    # NOTE(review): confirm the test split is only monitored, not used for model selection.\n",
    "    model.fit(X_train, y_train, X_test, y_test)\n",
    "\n",
    "    y_predict = model.predict(X_test)\n",
    "\n",
    "    print('correct:', y_test)\n",
    "    print('predict:', y_predict)\n",
    "    # accuracy_score is declared as (y_true, y_pred): ground truth goes first.\n",
    "    # Plain accuracy is symmetric, so the printed value is unchanged, but the\n",
    "    # documented order keeps this line correct if the metric is ever swapped.\n",
    "    acc = accuracy_score(y_test, y_predict)\n",
    "    print(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
