{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import matplotlib as mpl\n",
    "import os\n",
    "import gc\n",
    "import pandas as pd\n",
    "import csv\n",
    "from numpy import *\n",
    "from datetime import date\n",
    "import time\n",
    "import builtins\n",
    "from sax import sax_tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 50 #[42,43,50]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1050 146 130\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "import os\n",
    "\n",
    "# Step 1: Load the datasets\n",
    "x_train = np.load(\"x_train.npy\", allow_pickle=True)\n",
    "y_train = np.load(\"y_train.npy\", allow_pickle=True)\n",
    "x_test = np.load(\"x_test.npy\", allow_pickle=True)\n",
    "y_test = np.load(\"y_test.npy\", allow_pickle=True)\n",
    "\n",
    "# Step 2: Combine them\n",
    "X = np.concatenate([x_train, x_test], axis=0)\n",
    "y = np.concatenate([y_train, y_test], axis=0)\n",
    "\n",
    "\n",
    "\n",
    "# Step 3: First split: Train+Valid and Test\n",
    "x_train_valid, x_test, y_train_valid, y_test = train_test_split(\n",
    "    X, y, test_size=0.11, random_state=seed, stratify=y\n",
    ")\n",
    "\n",
    "# Step 4: Second split: Train and Valid\n",
    "x_train, x_valid, y_train, y_valid = train_test_split(\n",
    "    x_train_valid, y_train_valid, test_size=0.11, random_state=seed, stratify=y_train_valid\n",
    ")\n",
    "# (0.25 * 0.8 = 0.2 --> 60% train, 20% valid, 20% test)\n",
    "\n",
    "# Step 5: Save them\n",
    "np.save(\"x_train.npy\", x_train)\n",
    "np.save(\"y_train.npy\", y_train)\n",
    "np.save(\"x_valid.npy\", x_valid)\n",
    "np.save(\"y_valid.npy\", y_valid)\n",
    "np.save(\"x_test.npy\", x_test)\n",
    "np.save(\"y_test.npy\", y_test)\n",
    "\n",
    "print(len(x_train), len(x_test), len(x_valid))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "category = 70\n",
    "word_len = 1\n",
    "\n",
    "def convert_to_sax(input_x):\n",
    "    x_sax = np.zeros(input_x.shape)\n",
    "    for i in range(len(x_sax)):\n",
    "        start = 0\n",
    "        for j in range(input_x.shape[-1]):\n",
    "            temp = sax_tokenizer(input_x[i,:,j].tolist(),alphabet_size=category, word_length=word_len) #+ start\n",
    "            x_sax[i,:,j] =  np.array(temp) + start\n",
    "            start += category\n",
    "        if i%100 == 0:\n",
    "            print(f'Done with {i}')        \n",
    "    return x_sax      "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with 0\n",
      "Done with 100\n",
      "Done with 200\n",
      "Done with 300\n",
      "Done with 400\n",
      "Done with 500\n",
      "Done with 600\n",
      "Done with 700\n",
      "Done with 800\n",
      "Done with 900\n",
      "Done with 1000\n",
      "Done with 0\n",
      "Done with 100\n",
      "Done with 0\n",
      "Done with 100\n"
     ]
    }
   ],
   "source": [
    "idx = 1\n",
    "x_train = np.load(f'./x_train.npy', allow_pickle=True)\n",
    "x_valid = np.load(f'./x_valid.npy', allow_pickle=True)\n",
    "x_test = np.load(f'./x_test.npy', allow_pickle=True)\n",
    "\n",
    "sax_train = convert_to_sax(x_train)\n",
    "sax_valid = convert_to_sax(x_valid)\n",
    "sax_test = convert_to_sax(x_test)\n",
    "\n",
    "assert x_train.shape[0]==sax_train.shape[0]\n",
    "assert x_valid.shape[0]==sax_valid.shape[0]\n",
    "assert x_test.shape[0]==sax_test.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1050, 500, 1) (130, 500, 1) (146, 500, 1)\n"
     ]
    }
   ],
   "source": [
    "print(sax_train.shape, sax_valid.shape, sax_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,\n",
       "        13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,\n",
       "        26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38.,\n",
       "        39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51.,\n",
       "        52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64.,\n",
       "        65., 66., 67., 68., 69.]),\n",
       " 70)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(sax_train), len(np.unique(sax_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def onehotencoding(x_sax):\n",
    "    x_sax = x_sax.astype(int)\n",
    "    x_sax_ohe = np.zeros((x_sax.shape[0], x_sax.shape[1], category*x_sax.shape[-1]))\n",
    "\n",
    "    for i in range(len(x_sax_ohe)):\n",
    "        for j in range(x_sax_ohe.shape[1]): \n",
    "            idx = x_sax[i,j,:].tolist()\n",
    "            x_sax_ohe[i,j,idx] = 1\n",
    "    \n",
    "    return x_sax_ohe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "sax_train_ohe = onehotencoding(sax_train)\n",
    "sax_valid_ohe = onehotencoding(sax_valid)\n",
    "sax_test_ohe = onehotencoding(sax_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[4.]\n",
      " [2.]] [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "print(sax_train[100,0:2,:], sax_train_ohe[100,1,:])\n",
    "# print(sax_valid[100,0:10,:], sax_valid_ohe[100,2,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('./sax_train', sax_train_ohe)\n",
    "np.save('./sax_valid', sax_valid_ohe)\n",
    "np.save('./sax_test', sax_test_ohe)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
