{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "OS_CNN_Colab_demo",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cTNTsbnijSQb"
      },
      "source": [
        "## Note\n",
        "Make sure that your runtime type is 'Python 3 with GPU acceleration'. To do so, go to Edit > Notebook settings > Hardware Accelerator > Select \"GPU\"."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x6uGPx5hjEwd",
        "outputId": "c3b09602-7ab6-4c71-9aa7-582022627ea2",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "import torch\n",
        "print(\"Are we using GPU?: \", torch.cuda.is_available())\n",
        "print(\"the GPU name is: \",torch.cuda.get_device_name(0))"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Are we using GPU?:  True\n",
            "the GPU name is:  Tesla T4\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "S05hVXGJjtWK",
        "outputId": "88b52827-ddec-4a27-e405-87e45d1adf1d",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "# get code from github\n",
        "!git clone https://github.com/Wensi-Tang/OS-CNN.git"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "fatal: destination path 'OS-CNN' already exists and is not an empty directory.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hbPs3iONkBQK",
        "outputId": "21bd0c0d-1674-44ff-8237-5523c31fcabe",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "#change path to OS-CNN\n",
        "%cd /content/OS-CNN"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/content/OS-CNN\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WzaEj3Eskll8",
        "outputId": "2f7ba852-7f0a-488e-dbb0-f2152c1eeb0d",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "import os\n",
        "from os.path import dirname\n",
        "from utils.dataloader.TSC_data_loader import TSC_data_loader\n",
        "from Classifiers.OS_CNN.OS_CNN_easy_use import OS_CNN_easy_use\n",
        "from sklearn.metrics import accuracy_score\n",
        "import numpy as np\n",
        "\n",
        "Result_log_folder = './Example_Results_of_OS_CNN/OS_CNN_result_iter_0/'\n",
        "dataset_path = dirname(\"./Example_Datasets/UCRArchive_2018/\")\n",
        "\n",
        "\n",
        "dataset_name_list = [\n",
        "\"FiftyWords\"\n",
        "]\n",
        "\n",
        "for dataset_name in dataset_name_list:\n",
        "    print('running at:', dataset_name)   \n",
        "\n",
        "    # load data,\n",
        "    X_train, y_train, X_test, y_test = TSC_data_loader(dataset_path, dataset_name)\n",
        "    print('train data shape', X_train.shape)\n",
        "    print('train label shape',y_train.shape)\n",
        "    print('test data shape',X_test.shape)\n",
        "    print('test label shape',y_test.shape)\n",
        "    print('unique train label',np.unique(y_train))\n",
        "    print('unique test label',np.unique(y_test))\n",
        "\n",
        "    # creat model and log save place,\n",
        "\n",
        "    model = OS_CNN_easy_use(\n",
        "        Result_log_folder = Result_log_folder, # the Result_log_folder,\n",
        "        dataset_name = dataset_name,           # dataset_name for creat log under Result_log_folder,\n",
        "        device = \"cuda:0\",                # Gpu \n",
        "        max_epoch = 500,                        # In our expirement the number is 2000 for keep it same with FCN for the example dataset 500 will be enough,\n",
        "        paramenter_number_of_layer_list = [8*128, 5*128*256 + 2*256*128],\n",
        "        )\n",
        "\n",
        "    model.fit(X_train, y_train, X_test, y_test)\n",
        "\n",
        "    y_predict = model.predict(X_test)\n",
        "\n",
        "    print('correct:',y_test)\n",
        "    print('predict:',y_predict)\n",
        "    acc = accuracy_score(y_predict, y_test)\n",
        "    print(acc)"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "running at: FiftyWords\n",
            "train data shape (450, 270)\n",
            "train label shape (450,)\n",
            "test data shape (455, 270)\n",
            "test label shape (455,)\n",
            "unique train label [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
            " 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47\n",
            " 48 49]\n",
            "unique test label [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
            " 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47\n",
            " 48 49]\n",
            "code is running on  cuda:0\n",
            "epoch = 49 lr =  0.001\n",
            "train_acc=\t 0.9377777777777778 \t test_acc=\t 0.7516483516483516 \t loss=\t 1.692021369934082\n",
            "log saved at:\n",
            "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
            "epoch = 99 lr =  0.001\n",
            "train_acc=\t 0.9911111111111112 \t test_acc=\t 0.7934065934065934 \t loss=\t 1.2091208696365356\n",
            "log saved at:\n",
            "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
            "epoch = 149 lr =  0.0005\n",
            "train_acc=\t 0.9977777777777778 \t test_acc=\t 0.7978021978021979 \t loss=\t 3.058522939682007\n",
            "log saved at:\n",
            "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
            "epoch = 199 lr =  0.00025\n",
            "train_acc=\t 1.0 \t test_acc=\t 0.8065934065934066 \t loss=\t 5.315889358520508\n",
            "log saved at:\n",
            "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
            "epoch = 249 lr =  0.00025\n",
            "train_acc=\t 1.0 \t test_acc=\t 0.8043956043956044 \t loss=\t 0.24811102449893951\n",
            "log saved at:\n",
            "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
            "epoch = 299 lr =  0.00025\n",
            "train_acc=\t 1.0 \t test_acc=\t 0.810989010989011 \t loss=\t 0.3401278853416443\n",
            "log saved at:\n",
            "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
            "epoch = 349 lr =  0.000125\n",
            "train_acc=\t 1.0 \t test_acc=\t 0.810989010989011 \t loss=\t 0.8644722104072571\n",
            "log saved at:\n",
            "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
            "epoch = 399 lr =  0.0001\n",
            "train_acc=\t 1.0 \t test_acc=\t 0.810989010989011 \t loss=\t 0.5051943063735962\n",
            "log saved at:\n",
            "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
            "epoch = 449 lr =  0.0001\n",
            "train_acc=\t 1.0 \t test_acc=\t 0.8131868131868132 \t loss=\t 0.018573710694909096\n",
            "log saved at:\n",
            "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
            "epoch = 499 lr =  0.0001\n",
            "train_acc=\t 1.0 \t test_acc=\t 0.8197802197802198 \t loss=\t 0.016301987692713737\n",
            "log saved at:\n",
            "./Example_Results_of_OS_CNN/OS_CNN_result_iter_0/FiftyWords/FiftyWords_.txt\n",
            "correct: [ 3 11 12 22  3 12 26  0 21  0  6  8 42 12  7  2  1  5 23  8  5  0 14  2\n",
            "  7 10 15  9  1  1  4  1  2  3  0  0  5  2  0  1 32  0  5  8  3 41 29 13\n",
            " 28 47 39 40  7 13 44 11 15  0  7  6 33  0  0  2 28 25 30  1  5  0  0  4\n",
            " 41  0  4 21 23  0 27 16  2 30  3  0  4  3  4  0 29  2  0 18  0  1  9  9\n",
            "  1 16 33  0  2  4  4  4  1 24  7 40 38  1 23 17  8  5  0  9 18  1  3  4\n",
            " 20  2  2  4  1  0  1  1  5  1 23 14  2 48  8 21 14 10 23 16  5  3  0  0\n",
            " 20  1 36 11  3 40 17 34  1  0  2  5  1  9  1 30  4  4 15 44  2  1  8 45\n",
            "  6  0  4 15  0 40 13  6  5 10 21  2 25  6  5 37  1  3 10  9  8  6 48 16\n",
            " 48 27  3  1 24 22  5  2  6  0 19  0 49  3  1  2  3 13 28 41  0 10 37 39\n",
            " 15 45 15  2 12 19 11 31 10 16 31  2  7  0 37  5 39  5  9 34  5  1 48  1\n",
            "  0  6  2  1  0  1 22 19  3  3  2 17 28 32  2 27 14  7  1  8 13  9  2 10\n",
            "  4  5  6  4  8 38 24  8 41 10  7 26 44 47 46 38  2 45 36  6  1  1 30 13\n",
            " 46 32 12  3 40 24 11  3 35 11  1  7  0 18  0  1  1 26  3  3  2 19 25  0\n",
            " 28 22 21  0  1 43  4  3  0 10  0  4  0  0 16  2  4 26  0  6  8 38  0  5\n",
            "  3 30 24 49 15 11 30  1  0  1  3  8 41 33  0  9 14  3 13 24 16 12 40 14\n",
            "  0 11  3  4  3  5  3 11  3  9  0  0  0 10  1 10 18  4 26  7  7  4  3  2\n",
            "  3 12  0 21 17  6  0 26  3 23 20 13  9  0 13 14 24 19  4 49  2 20 42  4\n",
            "  2  5  0  7  3  4 19 36 33 37 14 17 25  4  0  1  3  0  1  6  6 35 39  6\n",
            " 49 46  0 13  9 29  1  8  0  1 33  3  6 20  0  7 42  1 10  3 24 14 15]\n",
            "predict: [ 3. 11. 12. 22.  3. 12. 26.  0. 21.  0. 11.  8. 42. 12.  7.  1.  1.  5.\n",
            " 23.  8.  5.  0. 14.  2.  7. 10.  0.  9.  1.  1.  4.  1.  2.  3.  0.  0.\n",
            "  5.  2.  0.  1.  3.  0.  5.  8.  3. 41. 29. 13. 28. 30.  1.  3.  7. 13.\n",
            " 29. 11. 15.  0.  7.  6. 13.  0.  0.  2. 28. 25. 30.  1.  5.  0.  0.  4.\n",
            " 27.  0.  4. 21. 23.  0. 27. 16.  2. 12.  3.  0.  4.  3.  1.  0. 29.  2.\n",
            "  0. 18.  0. 34.  9. 21.  1.  4. 33.  0.  2.  4.  4.  4.  1. 29.  7.  1.\n",
            " 38.  1. 23. 17.  8. 15.  0.  9. 18.  1.  3.  4. 20.  2.  2.  4.  1.  0.\n",
            "  1.  1.  5.  1. 23. 14.  2. 18.  1.  5. 14. 10.  2. 16.  5.  3.  0.  0.\n",
            " 11.  1. 21. 11.  3.  3. 17. 19.  1.  0.  2.  5.  1.  9.  1. 30.  4. 35.\n",
            " 15. 44.  2.  1.  8. 45.  6.  0.  4. 15.  0. 12. 13.  6.  5.  0. 21.  2.\n",
            " 36.  6.  5. 37.  1.  3. 10.  9.  8.  6. 48. 11. 48. 27.  3.  1. 24. 22.\n",
            "  5.  2.  6.  0. 34.  0. 16.  3.  1.  2.  3. 13. 28. 17.  0. 10. 37. 19.\n",
            " 15. 45. 29.  2. 29. 19. 11. 31. 10.  1. 31.  2.  7.  0. 37.  5. 39.  5.\n",
            "  9. 24.  5.  1.  2.  1.  0.  6. 23.  1.  0.  1. 22. 19. 14.  3.  2. 17.\n",
            " 39. 32.  2.  5. 14.  7.  1.  8. 32.  9.  2. 10.  4.  0.  6.  4.  8. 38.\n",
            " 24.  8. 27.  3.  7. 26. 44. 10. 46. 38.  2.  9. 36.  6.  1.  1. 12. 13.\n",
            " 46. 32. 12.  3. 42. 24. 11.  3. 35. 11. 34.  7.  0. 18.  0.  1.  1. 26.\n",
            "  3.  3.  2. 19. 25.  0. 28. 22. 21.  0.  1. 28.  4.  3.  0. 10.  0.  4.\n",
            "  0.  0. 16.  2.  4. 26.  0.  6.  8.  2.  0.  5. 47.  0. 24.  1. 15. 11.\n",
            " 21.  1.  0.  1.  3.  8. 27. 33.  0.  9. 14. 39. 12. 44.  4. 12.  7. 14.\n",
            "  0. 11. 10.  4.  3.  5.  3. 11.  3.  9.  0.  0.  0. 10. 11.  3. 18.  4.\n",
            " 26.  7.  7.  4.  3.  2.  3. 12.  0. 22.  2.  6.  0. 26.  3. 23. 20. 12.\n",
            "  9.  0. 28. 14. 29. 19.  4. 39.  2. 20. 42.  4.  2.  5.  0.  7.  3.  4.\n",
            " 19. 36. 17. 37. 14. 17. 25.  4.  0.  1.  3.  0.  1.  6.  6.  4. 11.  6.\n",
            " 39. 46.  0. 13.  9. 29.  1.  8.  0.  1. 17.  3.  6. 11.  0. 15. 42.  1.\n",
            "  3.  3. 24. 14. 15.]\n",
            "0.8197802197802198\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}