{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import matplotlib as mpl\n",
    "import os\n",
    "import gc\n",
    "import pandas as pd\n",
    "import csv\n",
    "from numpy import *\n",
    "from datetime import date\n",
    "import time\n",
    "import builtins\n",
    "from sax import sax_tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_mfcc = np.load('./x_mfcc.npy', allow_pickle=True)\n",
    "x_raw = np.zeros(x_mfcc.shape)\n",
    "y = np.load('./y.npy', allow_pickle=True)\n",
    "\n",
    "x_raw = x_raw.transpose(0,2,1)\n",
    "x_mfcc = x_mfcc.transpose(0,2,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(280, 501, 40) (280, 501, 40) (280,)\n",
      "(120, 501, 40) (120, 501, 40) (120,)\n",
      "(120, 501, 40) (120, 501, 40) (120,)\n"
     ]
    }
   ],
   "source": [
    "# Assuming x_raw, x_mfcc, and y are already defined numpy arrays\n",
    "\n",
    "# First split: Train (70%) and Temp (30%)\n",
    "\n",
    "## seeds = [42, 117]\n",
    "seed = 42\n",
    "x_raw_train, x_raw_temp, x_mfcc_train, x_mfcc_temp, y_train, y_temp = train_test_split(\n",
    "    x_raw, x_mfcc, y, test_size=0.30, random_state=seed, stratify=y\n",
    ")\n",
    "\n",
    "# Second split: Validation (15%) and Test (15%) from Temp (30%)\n",
    "x_raw_val, x_raw_test, x_mfcc_val, x_mfcc_test, y_val, y_test = train_test_split(\n",
    "    x_raw_temp, x_mfcc_temp, y_temp, test_size=0.50, random_state=seed, stratify=y_temp\n",
    ")\n",
    "\n",
    "\n",
    "x_raw_val = np.concatenate([x_raw_val, x_raw_test], axis=0)\n",
    "x_mfcc_val = np.concatenate([x_mfcc_val, x_mfcc_test], axis=0)\n",
    "y_val = np.concatenate([y_val, y_test], axis=0)\n",
    "\n",
    "x_raw_test = x_raw_val\n",
    "x_mfcc_test = x_mfcc_val\n",
    "y_test = y_val\n",
    "\n",
    "print(x_raw_train.shape, x_mfcc_train.shape, y_train.shape)\n",
    "print(x_raw_val.shape, x_mfcc_val.shape, y_val.shape)\n",
    "print(x_raw_test.shape, x_mfcc_test.shape, y_test.shape)\n",
    "\n",
    "\n",
    "\n",
    "np.save('x_train.npy', x_mfcc_train)\n",
    "np.save('x_valid.npy', x_mfcc_val)\n",
    "np.save('x_test.npy', x_mfcc_test)\n",
    "\n",
    "np.save('y_train.npy', y_train)\n",
    "np.save('y_valid.npy', y_val)\n",
    "np.save('y_test.npy', y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "category = 10\n",
    "word_len = 1\n",
    "\n",
    "def convert_to_sax(input_x):\n",
    "    x_sax = np.zeros(input_x.shape)\n",
    "    for i in range(len(x_sax)):\n",
    "        start = 0\n",
    "        for j in range(input_x.shape[-1]):\n",
    "            temp = sax_tokenizer(input_x[i,:,j].tolist(),alphabet_size=category, word_length=word_len) #+ start\n",
    "            x_sax[i,:,j] =  np.array(temp) + start\n",
    "            start += category\n",
    "        if i%100 == 0:\n",
    "            print(f'Done with {i}')        \n",
    "    return x_sax      "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with 0\n",
      "Done with 100\n",
      "Done with 200\n",
      "Done with 0\n",
      "Done with 100\n",
      "Done with 0\n",
      "Done with 100\n"
     ]
    }
   ],
   "source": [
    "idx = 1\n",
    "x_train = np.load(f'./x_train.npy', allow_pickle=True)\n",
    "x_valid = np.load(f'./x_valid.npy', allow_pickle=True)\n",
    "x_test = np.load(f'./x_test.npy', allow_pickle=True)\n",
    "\n",
    "# x_train = np.load(f'./x_raw_train_{idx}.npy', allow_pickle=True)\n",
    "# x_valid = np.load(f'./x_raw_valid_{idx}.npy', allow_pickle=True)\n",
    "# x_test = np.load(f'./x_raw_test_{idx}.npy', allow_pickle=True)\n",
    "\n",
    "sax_train = convert_to_sax(x_train)\n",
    "sax_valid = convert_to_sax(x_valid)\n",
    "sax_test = convert_to_sax(x_test)\n",
    "\n",
    "assert x_train.shape[0]==sax_train.shape[0]\n",
    "assert x_valid.shape[0]==sax_valid.shape[0]\n",
    "assert x_test.shape[0]==sax_test.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(280, 501, 40) (120, 501, 40) (120, 501, 40)\n"
     ]
    }
   ],
   "source": [
    "print(sax_train.shape, sax_valid.shape, sax_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,\n",
       "         11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,\n",
       "         22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,\n",
       "         33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,\n",
       "         44.,  45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,\n",
       "         55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,\n",
       "         66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,\n",
       "         77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,\n",
       "         88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,\n",
       "         99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109.,\n",
       "        110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120.,\n",
       "        121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131.,\n",
       "        132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142.,\n",
       "        143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153.,\n",
       "        154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164.,\n",
       "        165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175.,\n",
       "        176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186.,\n",
       "        187., 188., 189., 190., 191., 192., 193., 194., 195., 196., 197.,\n",
       "        198., 199., 200., 201., 202., 203., 204., 205., 206., 207., 208.,\n",
       "        209., 210., 211., 212., 213., 214., 215., 216., 217., 218., 219.,\n",
       "        220., 221., 222., 223., 224., 225., 226., 227., 228., 229., 230.,\n",
       "        231., 232., 233., 234., 235., 236., 237., 238., 239., 240., 241.,\n",
       "        242., 243., 244., 245., 246., 247., 248., 249., 250., 251., 252.,\n",
       "        253., 254., 255., 256., 257., 258., 259., 260., 261., 262., 263.,\n",
       "        264., 265., 266., 267., 268., 269., 270., 271., 272., 273., 274.,\n",
       "        275., 276., 277., 278., 279., 280., 281., 282., 283., 284., 285.,\n",
       "        286., 287., 288., 289., 290., 291., 292., 293., 294., 295., 296.,\n",
       "        297., 298., 299., 300., 301., 302., 303., 304., 305., 306., 307.,\n",
       "        308., 309., 310., 311., 312., 313., 314., 315., 316., 317., 318.,\n",
       "        319., 320., 321., 322., 323., 324., 325., 326., 327., 328., 329.,\n",
       "        330., 331., 332., 333., 334., 335., 336., 337., 338., 339., 340.,\n",
       "        341., 342., 343., 344., 345., 346., 347., 348., 349., 350., 351.,\n",
       "        352., 353., 354., 355., 356., 357., 358., 359., 360., 361., 362.,\n",
       "        363., 364., 365., 366., 367., 368., 369., 370., 371., 372., 373.,\n",
       "        374., 375., 376., 377., 378., 379., 380., 381., 382., 383., 384.,\n",
       "        385., 386., 387., 388., 389., 390., 391., 392., 393., 394., 395.,\n",
       "        396., 397., 398., 399.]),\n",
       " 400)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(sax_train), len(np.unique(sax_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def onehotencoding(x_sax):\n",
    "    x_sax = x_sax.astype(int)\n",
    "    x_sax_ohe = np.zeros((x_sax.shape[0], x_sax.shape[1], category*x_sax.shape[-1]))\n",
    "\n",
    "    for i in range(len(x_sax_ohe)):\n",
    "        for j in range(x_sax_ohe.shape[1]): \n",
    "            idx = x_sax[i,j,:].tolist()\n",
    "            x_sax_ohe[i,j,idx] = 1\n",
    "    \n",
    "    return x_sax_ohe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "sax_train_ohe = onehotencoding(sax_train)\n",
    "sax_valid_ohe = onehotencoding(sax_valid)\n",
    "sax_test_ohe = onehotencoding(sax_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[  6.  14.  23.  31.  45.  53.  67.  77.  86.  95. 107. 119. 124. 136.\n",
      "  149. 156. 162. 175. 186. 195. 203. 210. 220. 232. 240. 255. 266. 271.\n",
      "  283. 290. 308. 313. 329. 338. 344. 357. 364. 377. 380. 396.]\n",
      " [  7.  18.  22.  32.  48.  57.  68.  79.  87.  95. 109. 117. 123. 136.\n",
      "  146. 150. 160. 177. 187. 193. 200. 210. 221. 233. 243. 253. 266. 278.\n",
      "  284. 297. 308. 317. 323. 331. 342. 350. 364. 373. 380. 399.]] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.\n",
      " 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.\n",
      " 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n"
     ]
    }
   ],
   "source": [
    "print(sax_train[100,0:2,:], sax_train_ohe[100,1,:])\n",
    "# print(sax_valid[100,0:10,:], sax_valid_ohe[100,2,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('./sax_train', sax_train_ohe)\n",
    "np.save('./sax_valid', sax_valid_ohe)\n",
    "np.save('./sax_test', sax_test_ohe)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
