{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from sktime.utils.data_io import load_from_arff_to_dataframe\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "DATA_PATH = 'data/'\n",
    "\n",
    "# Resolve both EigenWorms ARFF files relative to DATA_PATH.\n",
    "train_arff_path = os.path.join(DATA_PATH, \"EigenWorms_TRAIN.arff\")\n",
    "test_arff_path = os.path.join(DATA_PATH, \"EigenWorms_TEST.arff\")\n",
    "\n",
    "# Each loader call is unpacked into a features object and a labels object.\n",
    "X_train_orig, y_train_orig = load_from_arff_to_dataframe(train_arff_path)\n",
    "X_test_orig, y_test_orig = load_from_arff_to_dataframe(test_arff_path)\n",
    "\n",
    "# Pool the two provided partitions; a later cell re-splits the combined set.\n",
    "X_all = pd.concat([X_train_orig, X_test_orig])\n",
    "y_all = np.concatenate([y_train_orig, y_test_orig])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(128, 6)\n",
      "(131, 6)\n",
      "(259, 6)\n"
     ]
    }
   ],
   "source": [
    "# Report the shape of the train partition, the test partition and their union.\n",
    "for frame in (X_train_orig, X_test_orig, X_all):\n",
    "    print(frame.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "# First carve off 15% of all samples as the held-out test set.\n",
    "# NOTE(review): no `stratify=` is passed, so class balance may differ between splits.\n",
    "X_trainval, X_test, y_trainval, y_test = train_test_split(\n",
    "    X_all, y_all, test_size=0.15, random_state=42\n",
    ")\n",
    "# Then take 15% of the remainder as validation; the fixed seed keeps both splits repeatable.\n",
    "X_train, X_val, y_train, y_val = train_test_split(\n",
    "    X_trainval, y_trainval, test_size=0.15, random_state=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(187, 6)\n",
      "(187,)\n",
      "(33, 6)\n",
      "(33,)\n",
      "(39, 6)\n",
      "(39,)\n"
     ]
    }
   ],
   "source": [
    "# Print the shape of each features/labels pair: train, validation, then test.\n",
    "for split_part in (X_train, y_train, X_val, y_val, X_test, y_test):\n",
    "    print(split_part.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['2' '1' '1' '2' '1' '4' '3' '4' '5' '3' '3' '1' '5' '2' '1' '4' '1' '1'\n",
      " '4' '2' '1' '1' '3' '1' '1' '1' '1' '4' '4' '1' '1' '1' '2' '2' '1' '4'\n",
      " '2' '5' '3']\n"
     ]
    }
   ],
   "source": [
    "# Show the raw test labels as loaded, before any conversion to integers.\n",
    "print(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _to_numpy(X):\n",
    "    \"\"\"Convert a DataFrame whose cells hold sequences into one dense array.\n",
    "\n",
    "    The cells of every row are stacked together, the per-row arrays are then\n",
    "    stacked along a new leading axis, and finally the last two axes are swapped.\n",
    "    For 1-D cells of length T in a frame with C columns the result has shape\n",
    "    (n_rows, T, C).\n",
    "    \"\"\"\n",
    "    per_row = [np.stack(row_cells) for row_cells in X.to_numpy()]\n",
    "    stacked = np.stack(per_row)\n",
    "    return np.swapaxes(stacked, -1, -2)\n",
    "\n",
    "# One entry per split: (features file, labels file, features, labels).\n",
    "splits_to_save = (\n",
    "    (\"trainx.npy\", \"trainy.npy\", X_train, y_train),\n",
    "    (\"validx.npy\", \"validy.npy\", X_val, y_val),\n",
    "    (\"testx.npy\", \"testy.npy\", X_test, y_test),\n",
    ")\n",
    "for x_file, y_file, features, labels in splits_to_save:\n",
    "    np.save(os.path.join(DATA_PATH, x_file), _to_numpy(features))\n",
    "    # Labels are cast to int and shifted down by one before saving.\n",
    "    np.save(os.path.join(DATA_PATH, y_file), labels.astype(int) - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(187, 17984, 6)\n",
      "float64\n",
      "(187,)\n",
      "int64\n",
      "(33, 17984, 6)\n",
      "float64\n",
      "(33,)\n",
      "int64\n",
      "(39, 17984, 6)\n",
      "float64\n",
      "(39,)\n",
      "int64\n"
     ]
    }
   ],
   "source": [
    "# Read each saved array back from DATA_PATH, the same directory the arrays were\n",
    "# written to, instead of a separately hard-coded folder name.\n",
    "for file in ['trainx', 'trainy', 'validx', 'validy', 'testx', 'testy']:\n",
    "    a = np.load(os.path.join(DATA_PATH, f\"{file}.npy\"))\n",
    "    # Shape and dtype serve as a quick sanity check of what landed on disk.\n",
    "    print(a.shape)\n",
    "    print(a.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 0, 0, 1, 0, 3, 2, 3, 4, 2, 2, 0, 4, 1, 0, 3, 0, 0, 3, 1, 0, 0,\n",
       "       2, 0, 0, 0, 0, 3, 3, 0, 0, 0, 1, 1, 0, 3, 1, 4, 2])"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the saved test labels explicitly rather than displaying a loop variable\n",
    "# leaked by an earlier cell; the cell no longer depends on that hidden state.\n",
    "np.load(os.path.join(DATA_PATH, \"testy.npy\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
