{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import matplotlib as mpl\n",
    "import os\n",
    "import gc\n",
    "import pandas as pd\n",
    "import csv\n",
    "from numpy import *\n",
    "from datetime import date\n",
    "import time\n",
    "import builtins\n",
    "from sax import sax_tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "category = 25\n",
    "word_len = 1\n",
    "\n",
    "def convert_to_sax(input_x):\n",
    "    x_sax = np.zeros(input_x.shape)\n",
    "    for i in range(len(x_sax)):\n",
    "        start = 0\n",
    "        for j in range(input_x.shape[-1]):\n",
    "            temp = sax_tokenizer(input_x[i,:,j].tolist(),alphabet_size=category, word_length=word_len) #+ start\n",
    "            x_sax[i,:,j] =  np.array(temp) + start\n",
    "            start += category\n",
    "        if i%100 == 0:\n",
    "            print(f'Done with {i}')        \n",
    "    return x_sax      "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with 0\n",
      "Done with 100\n",
      "Done with 200\n",
      "Done with 300\n",
      "Done with 400\n",
      "Done with 500\n",
      "Done with 600\n",
      "Done with 700\n",
      "Done with 800\n",
      "Done with 900\n",
      "Done with 1000\n",
      "Done with 1100\n",
      "Done with 1200\n",
      "Done with 1300\n",
      "Done with 1400\n",
      "Done with 1500\n",
      "Done with 1600\n",
      "Done with 1700\n",
      "Done with 1800\n",
      "Done with 1900\n",
      "Done with 0\n",
      "Done with 100\n",
      "Done with 200\n",
      "Done with 300\n",
      "Done with 400\n",
      "Done with 0\n",
      "Done with 100\n",
      "Done with 200\n",
      "Done with 300\n",
      "Done with 400\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can run\n",
    "# arbitrary code from a malicious file; keep it only if these files are trusted.\n",
    "x_train = np.load('./x_train.npy', allow_pickle=True)\n",
    "x_valid = np.load('./x_valid.npy', allow_pickle=True)\n",
    "x_test = np.load('./x_test.npy', allow_pickle=True)\n",
    "\n",
    "sax_train = convert_to_sax(x_train)\n",
    "sax_valid = convert_to_sax(x_valid)\n",
    "sax_test = convert_to_sax(x_test)\n",
    "\n",
    "# Token ids must fit the one-hot width used later: `category` ids per channel.\n",
    "for raw_split, sax_split in ((x_train, sax_train), (x_valid, sax_valid), (x_test, sax_test)):\n",
    "    assert raw_split.shape == sax_split.shape\n",
    "    assert sax_split.min() >= 0 and sax_split.max() < category * sax_split.shape[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2000, 3000, 1) (500, 3000, 1) (500, 3000, 1)\n"
     ]
    }
   ],
   "source": [
    "# Shapes of the SAX token arrays for train / valid / test.\n",
    "print(*(split.shape for split in (sax_train, sax_valid, sax_test)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def onehotencoding(x_sax):\n",
    "    x_sax = x_sax.astype(int)\n",
    "    x_sax_ohe = np.zeros((x_sax.shape[0], x_sax.shape[1], category*x_sax.shape[-1]))\n",
    "\n",
    "    for i in range(len(x_sax_ohe)):\n",
    "        for j in range(x_sax_ohe.shape[1]): \n",
    "            idx = x_sax[i,j,:].tolist()\n",
    "            x_sax_ohe[i,j,idx] = 1\n",
    "    \n",
    "    return x_sax_ohe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode every split the same way; unpacking keeps the original variable names.\n",
    "sax_train_ohe, sax_valid_ohe, sax_test_ohe = (\n",
    "    onehotencoding(split) for split in (sax_train, sax_valid, sax_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 8.]\n",
      " [ 9.]\n",
      " [ 7.]\n",
      " [ 7.]\n",
      " [11.]\n",
      " [14.]\n",
      " [17.]\n",
      " [19.]\n",
      " [20.]\n",
      " [19.]] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0.]\n",
      "[[ 1.]\n",
      " [ 1.]\n",
      " [ 1.]\n",
      " [ 2.]\n",
      " [ 2.]\n",
      " [ 8.]\n",
      " [ 8.]\n",
      " [10.]\n",
      " [11.]\n",
      " [14.]] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0.]\n"
     ]
    }
   ],
   "source": [
    "# Spot-check sample 100: its first 10 token rows and the one-hot vector at position 2.\n",
    "print(sax_train[100, :10], sax_train_ohe[100, 2])\n",
    "print(sax_valid[100, :10], sax_valid_ohe[100, 2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These files hold the one-hot arrays; np.save appends the .npy extension.\n",
    "for out_path, ohe_array in (('./sax_train', sax_train_ohe),\n",
    "                            ('./sax_valid', sax_valid_ohe),\n",
    "                            ('./sax_test', sax_test_ohe)):\n",
    "    np.save(out_path, ohe_array)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
