{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11abfc26",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from TICC_solver import MTTICC\n",
    "from pygapbide import *\n",
    "import random\n",
    "from math import nan\n",
    "import pickle\n",
    "import csv\n",
    "import os\n",
    "os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'\n",
    "\n",
    "import multiprocessing as mp\n",
    "from sklearn.datasets import make_blobs\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import silhouette_samples, silhouette_score\n",
    "\n",
    "import tensorflow as tf\n",
    "import gym\n",
    "from gym import wrappers\n",
    "from gym.envs.classic_control.pendulum import angle_normalize, PendulumEnv\n",
    "import d4rl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f227065",
   "metadata": {},
   "outputs": [],
   "source": [
    "# read data\n",
    "# Seed Python's and NumPy's global RNGs up front so stochastic steps that draw from\n",
    "# them are reproducible under Restart & Run All.\n",
    "# NOTE(review): MTTICC / Gapbide internals are not visible here; confirm they use these generators.\n",
    "SEED = 10\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "ENV_NAME = \"pen-human-v1\" # D4RL dataset to segment\n",
    "env = gym.make(ENV_NAME)\n",
    "min_c, max_c = 10, 20 # min, max value (inclusive) for searching cluster number\n",
    "min_sup = 10 # min-support can be set as 1, but we keep it as 10 for fast tracing, which won't affect the result for topk selection\n",
    "\n",
    "# one dict per trajectory, as yielded by d4rl.sequence_dataset\n",
    "d4rl_original_data = list(d4rl.sequence_dataset(env))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2e603e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get # of clusters\n",
    "# Flatten every trajectory's observations into one list of per-step rows; KMeans and\n",
    "# silhouette_score below are fit on all steps pooled together.\n",
    "X = [row for traj in d4rl_original_data for row in traj['observations'].tolist()] # data\n",
    "range_n_clusters = range(min_c, max_c+1) # candidate cluster counts, max_c inclusive\n",
    "\n",
    "def getSilScore(n_clusters):\n",
    "    \"\"\"Fit KMeans with ``n_clusters`` clusters on the notebook-level ``X`` and score it.\n",
    "\n",
    "    Returns a two-item list ``[n_clusters, mean silhouette coefficient]`` so each\n",
    "    score stays paired with its cluster count after ``pool.map``.\n",
    "    \"\"\"\n",
    "    labels = KMeans(n_clusters=n_clusters, random_state=10).fit_predict(X)\n",
    "    return [n_clusters, silhouette_score(X, labels)]\n",
    "\n",
    "# Score each candidate cluster count in parallel. The context manager tears the worker\n",
    "# processes down even if a worker raises; pool.map blocks until every result is back,\n",
    "# so terminating the pool on exit loses nothing.\n",
    "# NOTE(review): workers must be able to see the notebook-level X and getSilScore, which\n",
    "# presumably relies on the 'fork' start method; confirm on the target platform.\n",
    "with mp.Pool(5) as pool:\n",
    "    sil_score = pool.map(getSilScore, range_n_clusters)\n",
    "\n",
    "# pick the count with the highest mean silhouette (the first one wins on ties)\n",
    "num_clusters = max(sil_score, key=lambda s: s[1])[0]\n",
    "print('selected num of clusters: {} !'.format(num_clusters))\n",
    "\n",
    "# mtticc\n",
    "# need to fix: cannot be terminated on server\n",
    "\n",
    "# preprocess data as .data first: a pickled [data, intervals] pair, one entry per\n",
    "# trajectory, whose path is handed to MTTICC.fit below.\n",
    "data = []\n",
    "intervals = []\n",
    "for traj in d4rl_original_data:\n",
    "    v_data = traj['observations'].tolist()\n",
    "    data.append(v_data)\n",
    "    # nan for the first step (no predecessor), then 1. for every later step\n",
    "    intervals.append([nan] + [1.] * (len(v_data) - 1))\n",
    "\n",
    "raw_save_path = './raw_data/data.data'\n",
    "os.makedirs(os.path.dirname(raw_save_path), exist_ok=True) # a fresh checkout may lack ./raw_data\n",
    "# `with` guarantees the file is flushed and closed even if pickling fails\n",
    "with open(raw_save_path, 'wb') as pickle_out:\n",
    "    pickle.dump([data, intervals], pickle_out)\n",
    "\n",
    "# start mtticc\n",
    "mtticc = MTTICC(fixed_window=1, number_of_clusters=num_clusters, lambda_parameter=11e-3,\n",
    "                beta=1, maxIters=100, num_proc=6, input_pattern='multiple', window_pattern='fixed')\n",
    "TICC_fname = raw_save_path\n",
    "TICC_return_cv = mtticc.fit(TICC_fname)\n",
    "# presumably one assignment sequence per trajectory (it is re-split by trajectory length below)\n",
    "pred_cluster_cv = TICC_return_cv[1]\n",
    "\n",
    "print('MTTICC assignment done!')\n",
    "# concatenate the per-trajectory sequences in one pass (repeated `list + list` is quadratic)\n",
    "cluster_list = [c for traj_clusters in pred_cluster_cv for c in traj_clusters]\n",
    "\n",
    "np.savetxt('./raw_data/mtticc_clusters.txt', cluster_list, fmt='%d', delimiter=',')\n",
    "\n",
    "# read cluster index by samples\n",
    "fname = './raw_data/mtticc_clusters.txt'\n",
    "cluster = []\n",
    "with open(fname, 'r') as fd:\n",
    "    for row in csv.reader(fd):\n",
    "        cluster.extend(int(v) for v in row)\n",
    "\n",
    "# split the flat assignment list back into one list per trajectory, using each\n",
    "# trajectory's observation count as its length\n",
    "cluster_db = []\n",
    "p = 0\n",
    "for traj in d4rl_original_data:\n",
    "    i_len = len(traj['observations'])\n",
    "    cluster_db.append(cluster[p:p + i_len])\n",
    "    p += i_len\n",
    "# a count mismatch would silently misalign every later trajectory's segments\n",
    "assert p == len(cluster), 'cluster assignments ({}) do not match total observations ({})'.format(len(cluster), p)\n",
    "    \n",
    "## get patterns\n",
    "# args: (sequence db, min-support, then two gap bounds)\n",
    "# NOTE(review): presumably min/max gap, with 0/0 meaning contiguous patterns; confirm against pygapbide\n",
    "g = Gapbide(cluster_db, min_sup, 0, 0)\n",
    "raw_pattern = g.run()\n",
    "\n",
    "# here, we directly search the one which support = N, i.e., top1\n",
    "# this can be extended to iteratively searching procedure for easier blackbox usage\n",
    "# (each raw_pattern entry is indexed as [0] = pattern, [1] = support)\n",
    "pattern = next((i[0] for i in raw_pattern if i[1] == len(d4rl_original_data)), None)\n",
    "if pattern is None:\n",
    "    # fail loudly: otherwise `pattern` is undefined, or silently stale from an earlier run\n",
    "    raise ValueError('no pattern is supported by all {} trajectories'.format(len(d4rl_original_data)))\n",
    "\n",
    "# get segments of traj for augmentation\n",
    "# data format\n",
    "# list of dictionaries:\n",
    "# List: one entry per trajectory in which the pattern is found\n",
    "# dictionary: 5 keys, {\n",
    "#     'observations': observations of the matched steps, NEED\n",
    "#     'actions': the matching 2-D (steps, action_dim) slice of actions, NEED\n",
    "#     'rewards': the matching slice of rewards, NEED\n",
    "#     'next_observations': the observation that follows each matched step, NEED\n",
    "#     'terminals': boolean flags, True only on the last step of the segment, NEED\n",
    "#             }\n",
    "seg = []\n",
    "\n",
    "def find_sub_list(sl,l):\n",
    "    sll=len(sl)\n",
    "    for ind in (i for i,e in enumerate(l) if e==sl[0]):\n",
    "        if l[ind:ind+sll]==sl:\n",
    "            return ind,ind+sll\n",
    "    return -1,-1\n",
    "\n",
    "# given a pattern, find the first segment matching it in each traj\n",
    "# NOTE: for list, l[a:b] exclude b\n",
    "for user, traj in enumerate(d4rl_original_data):\n",
    "    # The last step is excluded from the search, so every matched step has a successor\n",
    "    # and next_observations is always the one-step-shifted slice (end < len(observations)).\n",
    "    begin, end = find_sub_list(pattern, cluster_db[user][:-1])\n",
    "    if begin == -1:\n",
    "        continue\n",
    "    user_seg = {\n",
    "        'observations': traj['observations'][begin:end],\n",
    "        'actions': traj['actions'][begin:end, :], # 2D\n",
    "        'rewards': traj['rewards'][begin:end],\n",
    "        'next_observations': traj['observations'][(begin + 1):(end + 1)],\n",
    "        # True only on the segment's last step\n",
    "        'terminals': np.arange(len(pattern)) == len(pattern) - 1,\n",
    "    }\n",
    "    seg.append(user_seg)\n",
    "\n",
    "save_path = './processed_data/train_pattern.npy'\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True) # a fresh checkout may lack ./processed_data\n",
    "# seg is a list of dicts, so it is stored as a pickled object array: load with np.load(..., allow_pickle=True)\n",
    "with open(save_path, 'wb') as f:\n",
    "    np.save(f, seg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
