{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import math\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import preprocessing\n",
    "from sklearn.metrics import r2_score\n",
    "import random\n",
    "import matplotlib as mpl\n",
    "import os\n",
    "import gc\n",
    "import pandas as pd\n",
    "import csv\n",
    "from numpy import *\n",
    "\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from datetime import date\n",
    "from sax import sax_tokenizer\n",
    "# from generate_property import output_property"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of input_x (10000, 500, 6)\n",
      "Size of output_y (10000,)\n"
     ]
    }
   ],
   "source": [
    "# Location and dimensions of the raw dataset.\n",
    "data_dir = '../generating_raw_data/data_small/'\n",
    "total_files = 10000\n",
    "max_len = 500\n",
    "num_f = 6\n",
    "\n",
    "# Pre-allocated containers, filled one file at a time in the loop below.\n",
    "input_x = np.zeros((total_files, max_len, num_f))\n",
    "output_y = np.zeros((total_files,))\n",
    "\n",
    "for i in range(total_files):\n",
    "    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;\n",
    "    # only point data_dir at files you trust.\n",
    "    sample = np.load(f'{data_dir}sample_{i}.npy', allow_pickle=True)\n",
    "    target = np.load(f'{data_dir}target_{i}.npy', allow_pickle=True)\n",
    "    input_x[i, ...] = sample\n",
    "    output_y[i, ...] = target\n",
    "\n",
    "print('Size of input_x', input_x.shape)\n",
    "print('Size of output_y', output_y.shape)\n",
    "\n",
    "# Uncomment to restrict everything to the first 100 samples:\n",
    "# input_x = input_x[0:100,...]\n",
    "# output_y = output_y[0:100,]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**SAX (Symbolic Aggregate approXimation): symbolic representation of each feature's time series**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# '''Write a function to map the sax representation to the actual sequence length'''\n",
    "# def decode(sax_rep, original_len,word_len):\n",
    "#     decode_seq = np.zeros((sax_rep.shape[0], original_len))\n",
    "#     l = sax_rep.shape[0]\n",
    "#     count = 0\n",
    "#     while count < l:\n",
    "#         print('use this', sax_rep[0:3,count])\n",
    "#         decode_seq[:,count:count+word_len] = sax_rep[:,count]\n",
    "#         print('Result', decode_seq[0:3,count:count+word_len])\n",
    "#         count += word_len\n",
    "#         print(aaa)\n",
    "    \n",
    "#     return decode_seq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with 0\n",
      "Done with 500\n",
      "Done with 1000\n",
      "Done with 1500\n",
      "Done with 2000\n",
      "Done with 2500\n",
      "Done with 3000\n",
      "Done with 3500\n",
      "Done with 4000\n",
      "Done with 4500\n",
      "Done with 5000\n",
      "Done with 5500\n",
      "Done with 6000\n",
      "Done with 6500\n",
      "Done with 7000\n",
      "Done with 7500\n",
      "Done with 8000\n",
      "Done with 8500\n",
      "Done with 9000\n",
      "Done with 9500\n"
     ]
    }
   ],
   "source": [
    "# SAX settings; the alphabet size is also written to disk as ./num_category.\n",
    "category = 2\n",
    "word_len = 1\n",
    "np.save('./num_category', category)\n",
    "\n",
    "x_sax = np.zeros(input_x.shape)\n",
    "num_features = input_x.shape[-1]\n",
    "\n",
    "for i in range(len(x_sax)):\n",
    "    for j in range(num_features):\n",
    "        # Tokenise feature j of sample i, then shift the ids by j*category so that\n",
    "        # every feature occupies its own id range (assumes sax_tokenizer returns\n",
    "        # ids in [0, category) -- TODO confirm).\n",
    "        tokens = sax_tokenizer(input_x[i, :, j].tolist(), alphabet_size=category, word_length=word_len)\n",
    "        x_sax[i, :, j] = np.array(tokens) + j * category\n",
    "    if not i % 500:\n",
    "        print(f'Done with {i}')\n",
    "\n",
    "# x_sax_decoded = decode(x_sax, input_x.shape[1],word_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode the SAX ids: x_sax_ohe[i, j, k] is 1 when any feature of\n",
    "# sample i at timestep j carries token id k, and 0 otherwise.\n",
    "x_sax = x_sax.astype(int)\n",
    "x_sax_ohe = np.zeros((input_x.shape[0], input_x.shape[1], category*input_x.shape[-1]))\n",
    "\n",
    "# Vectorised scatter along the last axis. It sets\n",
    "# x_sax_ohe[i, j, x_sax[i, j, f]] = 1 for every (i, j, f), which is exactly what\n",
    "# the previous per-(sample, timestep) Python double loop did, without millions\n",
    "# of interpreter-level iterations.\n",
    "np.put_along_axis(x_sax_ohe, x_sax, 1, axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 1  2  5  7  9 11]\n",
      "[0. 1. 1. 0. 0. 1. 0. 1. 0. 1. 0. 1.]\n"
     ]
    }
   ],
   "source": [
    "# Spot-check one (sample, timestep): token ids followed by their one-hot row.\n",
    "print(x_sax[50, 50])\n",
    "print(x_sax_ohe[50, 50])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every sample is given the same length, max_len (kept as float64, one entry per target).\n",
    "seq_length = np.full(output_y.shape, float(max_len))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train (5000,)\n",
      "Test (500,)\n",
      "Valid (500,)\n",
      "(5000, 500, 6) (5000,) (5000,)\n",
      "2503.0 242.0 244.0\n"
     ]
    }
   ],
   "source": [
    "# Split sample indices (not the data itself) into train / valid / test.\n",
    "seed = 50 ## [10,50,70]\n",
    "all_ex = np.arange(input_x.shape[0])\n",
    "\n",
    "# 5% of all examples are held out as the test set.\n",
    "X, x_test = train_test_split(all_ex, test_size=0.05, random_state=seed) ## [10,50,70]\n",
    "# 5.263% of the remainder (about 5% of the total) becomes the validation set.\n",
    "x_train, x_valid = train_test_split(X, test_size=0.05263, random_state=seed)\n",
    "# Sub-sample the training pool: only the 55.555% 'test' side of this split is kept.\n",
    "# random_state now follows `seed` (it was a hard-coded 50), so changing the seed\n",
    "# above re-draws all three splits consistently.\n",
    "_, x_train = train_test_split(x_train, test_size=0.55555, random_state=seed)\n",
    "\n",
    "print('Train',x_train.shape)\n",
    "print('Test' ,x_test.shape)\n",
    "print('Valid',x_valid.shape)\n",
    "\n",
    "print(input_x[x_train].shape, seq_length[x_train].shape, output_y[x_train].shape)\n",
    "\n",
    "# Sum of the targets in each split.\n",
    "print(np.sum(output_y[x_train]),np.sum(output_y[x_valid]),np.sum(output_y[x_test]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (path, source array, split indices) for every artefact, in write order.\n",
    "split_outputs = [\n",
    "    ('./x_train', input_x, x_train),\n",
    "    (f'./sax_train_{category}', x_sax_ohe, x_train),\n",
    "    ('./len_train', seq_length, x_train),\n",
    "    ('./y_train', output_y, x_train),\n",
    "    ('./x_valid', input_x, x_valid),\n",
    "    (f'./sax_valid_{category}', x_sax_ohe, x_valid),\n",
    "    ('./len_valid', seq_length, x_valid),\n",
    "    ('./y_valid', output_y, x_valid),\n",
    "    ('./x_test', input_x, x_test),\n",
    "    (f'./sax_test_{category}', x_sax_ohe, x_test),\n",
    "    ('./len_test', seq_length, x_test),\n",
    "    ('./y_test', output_y, x_test),\n",
    "]\n",
    "\n",
    "# Index lazily inside the loop so only one split copy is alive at a time.\n",
    "for out_path, source, split_idx in split_outputs:\n",
    "    np.save(out_path, source[split_idx])\n",
    "\n",
    "# The test indices themselves are stored as well.\n",
    "np.save('./test_idx', x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[9102 7868 4176 4161 8770 8919 3884 1590  617 1562 1043 3494 2884 1505\n",
      "  102 1912 7294 2464 7070 6468 6651 5498 8821 3777 4913 3614 6770 9912\n",
      " 7044 6637    4  944 7503 2069 1441 8933 2354 4995 9612 9880 6529 3534\n",
      " 1072 1261 8070 7138 5859 5919 5802  158 7666 5540 8291 2216 9625 9801\n",
      "  329 7729 1924  434 8470 3396 2089 9736 5522 3437 2027  801 3410 6251\n",
      " 5706 1249 1218 9428 2095 4020 7187 8219 9682 5430 2949  846 9408 1433\n",
      " 3761  708 7879 5316 7389 8898 7533  428 6367  829 3703 5238 9353 6264\n",
      " 5809 6423   12 2319 6944 5699 7940 6133 4639 1888 7493 2674 2261 8825\n",
      "  956 1489 8718 3737 4162 4244 3188 8421 1980 8371 8022  438 3471 6334\n",
      " 6487  513 2415 8883 1514 2567 7264 8942 7753 9051 1615 2229 6299 8255\n",
      " 3873 1681 3416 7481 6149  842 5019 4622 1296  914 9500 2937 9918 5096\n",
      " 5778 2628 1832 3201  203 1345  253 2413  542 1823 1944 6961 8625 7574\n",
      " 4079 7751 6575 6048 8049 6747 9897 9182 3200 3426 5380 8202  453  864\n",
      "  246 4609 9426 1226  976    2 8564 7926 8335 4334 3731 3715 4364 5094\n",
      "   61 1576 8026 1712 9637 6719 4977 3949 2934 4653 4923 5345 2890 7393\n",
      "  343 1804 9623 7195 1018 6093  624 4585 9072 7807 9590 7324 9613 1854\n",
      " 1577 4684 4600  625 3720 8993 8903 2136 6341 7509 3272 6609 8179 1233\n",
      " 1808 7710 3167  915 1863 6436 5063 8573 8602 6373 2430 6691 8568 9777\n",
      " 5573 1128 4078 9638 7974 8835 4277 3340 2144 5030 5209 6443 2939 4171\n",
      " 9703 1620 6236 4447 6178 6202 4989 3789 3252 2086 8298 7058 2366 9992\n",
      " 4491 1338 6176 1880 5502 5366 4956  713 5357 2203 9045 8080 2513 3311\n",
      " 2862 2759 9788 4694 5915  156 3338 2606 6131 2077 9653 6617 7564  733\n",
      " 3433 3629 5354 8110 8761 1393  794   79 6668 7354 6038 7599 1265 8854\n",
      " 9284 1696 6867 9166 9237 1323  341 1782  851 7071 2245 1475 9377 8991\n",
      " 7377 5623 4824 3209 9133 9242 9217 1983 9829 3153 2955 8301 3661 4863\n",
      " 7307 7555 5091 3431  529 1308 4249 4100 5184 4075 4124 7502 2234 2672\n",
      " 7091 1971 7618 5667 9687 3584  883 7683 2554  251 4937  245 6587 3705\n",
      " 9040 2637 3010 6489 2803 9652 2898 4123 7342 7410 4285 1735  493  537\n",
      " 3429 2865 7010  177 3077  366 1639 1642 7159 3556 2385 8278 9355 6745\n",
      " 1095  152 3878 6070 2268 1106 2078 4406 7007 4983  727 2065 7511 9907\n",
      " 7878 1124 7801 7338 8896 7175 4473 7033  820 8587 3263 7862 9192 9586\n",
      " 9066 8528 9676 9896 3329 6531 8489  831 2568 9416 6885 4947 1602 8624\n",
      " 5792 3752 1781 9318 4348   67 8726 9334 9510 6991 3186 9755 7331 9956\n",
      " 6204 3319 3355 2586   99 5605 9823 2130 5073 2617 6398 5180  369 1046\n",
      " 4480 2125 9527 8410 4412 5949 1166 6510  644 9086 4748 8157 4951  784\n",
      " 3834 6669 9410 1390 5324 6150 5920 4830 5940 2233]\n"
     ]
    }
   ],
   "source": [
    "print(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(500, 500, 12)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the one-hot test subset: (number of test indices, timesteps, channels).\n",
    "np.shape(x_sax_ohe[x_test])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
