{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf0c67e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from utils.constant import PROBLEM_FEATURES\n",
    "from mtticc.TICC_solver import MTTICC\n",
    "from prefixspan import PrefixSpan\n",
    "import random\n",
    "\n",
    "import math\n",
    "from utils.importance_sampling5_Spring20ONLY import Importance_Sampling\n",
    "\n",
    "from tensorflow.keras.models import load_model\n",
    "import tensorflow as tf\n",
    "\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from math import nan\n",
    "import pickle\n",
    "import csv\n",
    "\n",
    "# ---- Load data ----\n",
    "file_name = './raw_data/data.csv'\n",
    "# file_name already carries its extension; unconditionally appending '.csv' gave '.../data.csv.csv'.\n",
    "# NOTE(review): data_path is not read anywhere in this cell -- confirm it is still needed.\n",
    "data_path = file_name if file_name.endswith('.csv') else '{}.csv'.format(file_name)\n",
    "\n",
    "# NOTE(review): np.load(..., allow_pickle=True) unpickles arbitrary objects; only use on trusted files.\n",
    "test_data = np.load('test.pkl', allow_pickle=True)\n",
    "raw_data = np.load('train.pkl', allow_pickle=True)\n",
    "user_list = list(raw_data['userID'].unique())\n",
    "returns = np.load('student_final_grade.pkl', allow_pickle=True)\n",
    "\n",
    "# Must run after raw_data is loaded (it used to be read before assignment -> NameError on a fresh kernel).\n",
    "feature_list = PROBLEM_FEATURES\n",
    "student_state = raw_data[feature_list].values\n",
    "print('finish loading data')\n",
    "\n",
    "# ---- Configuration ----\n",
    "gamma = 0.9  # discount factor applied per step when summing rewards\n",
    "horizon = 12  # switch checks run while h < horizon; returns are summed over steps 0..horizon\n",
    "# Cluster-index patterns whose first occurrence triggers a policy-switch check\n",
    "pattern_list = [[2,1,6], [7, 4], [7, 0], [7, 9], [7, 8], [7, 3], [4], [9], [8], [1, 8]]\n",
    "# Estimated returns, indexed below as [pattern index, switch step, target policy index]\n",
    "potential_policy_returns = pd.read_pickle('processed_data/est_target_return_from_h_beh_random.pkl')\n",
    "\n",
    "# Calculate accumulated (discounted) reward for each user\n",
    "def calculate_accumulated_reward(group):\n",
    "    \"\"\"Return the discounted return of one user's trajectory.\n",
    "\n",
    "    The k-th row's ``inferred_rew`` (k = 0, 1, ... in row order) is weighted by\n",
    "    ``gamma**k``.  ``group`` is only read, never modified: writing a helper column\n",
    "    into a groupby chunk is an unnecessary side effect and can trigger\n",
    "    SettingWithCopyWarning.\n",
    "    \"\"\"\n",
    "    return sum(gamma**step * rew for step, rew in enumerate(group['inferred_rew']))\n",
    "\n",
    "# userID -> discounted return, for the train and test splits\n",
    "accumulated_rewards_train = raw_data.groupby('userID').apply(calculate_accumulated_reward).to_dict()\n",
    "accumulated_rewards_test = test_data.groupby('userID').apply(calculate_accumulated_reward).to_dict()\n",
    "\n",
    "def getClusterSeq(raw_data):\n",
    "    \"\"\"Return one list of ``cluster`` labels per user.\n",
    "\n",
    "    Users are ordered by first appearance in ``raw_data['userID']`` (the\n",
    "    ``unique()`` order); each list keeps that user's rows in frame order.\n",
    "    \"\"\"\n",
    "    users_in_order = raw_data['userID'].unique()\n",
    "    return [raw_data.loc[raw_data['userID'] == uid, 'cluster'].tolist() for uid in users_in_order]\n",
    "\n",
    "def find_sub_list(sl,l):\n",
    "    \"\"\"Locate the first contiguous occurrence of ``sl`` inside ``l``.\n",
    "\n",
    "    Returns ``(start, end)`` with ``l[start:end] == sl`` for the earliest match,\n",
    "    or ``(-1, -1)`` when ``sl`` does not occur.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``sl`` is empty (there is no first element to anchor on).\n",
    "    \"\"\"\n",
    "    sll = len(sl)\n",
    "    if sll == 0:\n",
    "        raise ValueError('find_sub_list: pattern must be non-empty')\n",
    "    # Only positions holding the pattern's first element can start a match.\n",
    "    for ind in (i for i, e in enumerate(l) if e == sl[0]):\n",
    "        if l[ind:ind + sll] == sl:\n",
    "            return ind, ind + sll\n",
    "    return -1, -1\n",
    "\n",
    "def if_sub_list(sl,l):\n",
    "    \"\"\"Return True if ``sl`` occurs as a contiguous run inside ``l``.\"\"\"\n",
    "    # Delegate so the search logic lives in one place.\n",
    "    return find_sub_list(sl, l)[0] != -1\n",
    "\n",
    "# ---- Frequent cluster patterns (PrefixSpan) ----\n",
    "# cluster_seqs[i] belongs to user_list[i]: both follow raw_data['userID'].unique() order.\n",
    "cluster_seqs = getClusterSeq(raw_data)\n",
    "ps = PrefixSpan(cluster_seqs)\n",
    "\n",
    "# Patterns shared by at least freq_thres of all training students\n",
    "freq_thres = 0.99\n",
    "freq_ps = ps.frequent(int(freq_thres*len(user_list))) # contain freq info\n",
    "freq_ps_list = [el[1] for el in freq_ps] # pattern only\n",
    "print(freq_ps_list)\n",
    "\n",
    "# Split students at the mean final grade\n",
    "# NOTE(review): the mean is over every key in `returns`, not only user_list -- confirm intended.\n",
    "thres = np.mean([returns[k] for k in returns.keys()])\n",
    "low_perf = [i for i in user_list if returns[i] <= thres]\n",
    "high_perf = [i for i in user_list if returns[i] > thres]\n",
    "\n",
    "# Patterns frequent among low-performing students.  Mine their sequences only:\n",
    "# running the all-student miner `ps` with a support sized to the low group would\n",
    "# also count occurrences contributed by high-performing students.\n",
    "low_perf_seqs = [seq for uid, seq in zip(user_list, cluster_seqs) if returns[uid] <= thres]\n",
    "ps_low = PrefixSpan(low_perf_seqs)\n",
    "freq_ps_low = ps_low.frequent(int(freq_thres*len(low_perf))) # contain freq info\n",
    "freq_ps_low_list = [el[1] for el in freq_ps_low] # pattern only\n",
    "print(freq_ps_low_list)\n",
    "\n",
    "# Hand-picked critical patterns (same as pattern_list):\n",
    "# [2,1,6], [7, 4], [7, 0], [7, 9], [7, 8], [7, 3], [4], [9], [8], [1, 8]\n",
    "\n",
    "# ---- Target policies and condition groups ----\n",
    "# Order matters: the last axis of potential_policy_returns is indexed by position in this list.\n",
    "policies = ['expert', 'FHRL', 'SOCHRL']\n",
    "\n",
    "init_policy = 'expert'\n",
    "init_dataset = test_data # use real test data\n",
    "\n",
    "# Summary.csv pairs students ('userID') with their experimental condition ('Cond').\n",
    "cond_path = 'Summary.csv'\n",
    "cond_data = pd.read_csv(cond_path)\n",
    "cond_dict = {'FHRL':20142 , 'SOCHRL':20144 , 'expert':20135}  # policy name -> condition ID\n",
    "# Offset added to Summary.csv IDs before matching them against test_data['userID']\n",
    "# (NOTE(review): inferred from how cond_user_list is used below -- confirm).\n",
    "USER_ID_OFFSET = 201000\n",
    "\n",
    "seg = []  # NOTE(review): not used anywhere in this notebook\n",
    "\n",
    "# Students (in test_data's ID space) assigned to the init_policy condition\n",
    "conditionID = cond_dict[init_policy]\n",
    "curr_cond_data = cond_data.loc[cond_data['Cond'] == conditionID]\n",
    "cond_user_list = [i + USER_ID_OFFSET for i in curr_cond_data['userID'].unique()]\n",
    "print(\"policy info: {}, {}\".format(init_policy, len(cond_user_list)))\n",
    "\n",
    "initial_data = test_data.loc[test_data['userID'].isin(cond_user_list)]\n",
    "\n",
    "# Per-pattern return threshold that a switch must beat.  A NaN threshold is replaced\n",
    "# by NO_SWITCH_THRESHOLD, a value meant to be unreachable so that pattern never switches.\n",
    "NO_SWITCH_THRESHOLD = 1000.\n",
    "target_policies_threshold = pd.read_pickle('processed_data/target_policies_threshold_beh_random.pkl')\n",
    "target_policies_threshold = np.nan_to_num(target_policies_threshold, nan=NO_SWITCH_THRESHOLD)\n",
    "\n",
    "def earliest_switch_step(pattern_list, ids_cluter_seq):\n",
    "    \"\"\"Pick the pattern whose first occurrence in ``ids_cluter_seq`` ends earliest.\n",
    "\n",
    "    Returns ``(end, p)``: ``end`` is the exclusive end index of that occurrence (as\n",
    "    returned by ``find_sub_list``) and ``p`` is the pattern's index in ``pattern_list``.\n",
    "    Ties keep the lower ``p``.  If no occurrence ends before index 100 the result\n",
    "    is ``(100, -1)``.\n",
    "    \"\"\"\n",
    "    best_end, best_pattern = 100, -1\n",
    "    for idx, pattern in enumerate(pattern_list):\n",
    "        start, stop = find_sub_list(pattern, ids_cluter_seq)\n",
    "        if start == -1:\n",
    "            continue  # pattern absent from this sequence\n",
    "        if stop < best_end:\n",
    "            best_end, best_pattern = stop, idx\n",
    "    return best_end, best_pattern\n",
    "\n",
    "def calculate_partial_accumulated_reward(rewards, begin, end):\n",
    "    \"\"\"Discounted sum of ``rewards`` using exponents ``begin, begin+1, ..., end``.\n",
    "\n",
    "    ``rewards[k]`` is weighted by ``gamma**(begin + k)``.  Summation stops when\n",
    "    either the rewards or the exponent range runs out, so at most\n",
    "    ``end - begin + 1`` rewards are used; an empty range gives 0.\n",
    "    \"\"\"\n",
    "    exponents = range(begin, end + 1)\n",
    "    discounted = (gamma ** exponent * reward for exponent, reward in zip(exponents, rewards))\n",
    "    return sum(discounted)\n",
    "\n",
    "# Load augmented (simulated) trajectories, one file per target policy\n",
    "FHRL_sim, SOCHRL_sim, expert_sim = (\n",
    "    pd.read_csv('./augmented_data/FHRL.csv'),\n",
    "    pd.read_csv('./augmented_data/SOCHRL.csv'),\n",
    "    pd.read_csv('./augmented_data/expert.csv'),\n",
    ")\n",
    "\n",
    "# Find the single user whose step-h state is closest\n",
    "def find_closest_user(dataframe, feature_list, h):\n",
    "    \"\"\"Return all rows of the user whose step-``h`` state is closest to ``feature_list``.\n",
    "\n",
    "    For each user (groupby order, i.e. sorted userID by default) the ``h``-th row\n",
    "    (positional) is compared to ``feature_list`` over ``PROBLEM_FEATURES`` by Euclidean\n",
    "    distance; the first user with the strictly smallest distance wins.  Users with\n",
    "    fewer than ``h + 1`` rows are skipped instead of raising IndexError.  An empty\n",
    "    frame is returned when no user qualifies.\n",
    "    \"\"\"\n",
    "    closest_user_id = None\n",
    "    min_distance = float('inf')\n",
    "\n",
    "    for user_id, group in dataframe.groupby('userID'):\n",
    "        if len(group) <= h:\n",
    "            continue  # trajectory has no step-h state\n",
    "        row_features = group.iloc[h][PROBLEM_FEATURES].values\n",
    "        distance = np.linalg.norm(row_features - feature_list)\n",
    "        if distance < min_distance:\n",
    "            min_distance = distance\n",
    "            closest_user_id = user_id\n",
    "\n",
    "    # All rows of the winning user (no rows if closest_user_id is still None)\n",
    "    return dataframe[dataframe['userID'] == closest_user_id]\n",
    "\n",
    "# Find every user tied for the closest step-h state\n",
    "def find_closest_user_allow_multiple(dataframe, feature_list, h):\n",
    "    \"\"\"Return the rows of all users tied for the closest step-``h`` state.\n",
    "\n",
    "    Distance is the same as in ``find_closest_user``: Euclidean distance between\n",
    "    ``feature_list`` and each user's ``h``-th row (positional) over ``PROBLEM_FEATURES``.\n",
    "    Each user's distance is computed once and reused for the tie check.  Users with\n",
    "    fewer than ``h + 1`` rows are skipped.  The result is re-indexed from 0 and is\n",
    "    empty when no user qualifies.\n",
    "    \"\"\"\n",
    "    distances = {}\n",
    "    for user_id, group in dataframe.groupby('userID'):\n",
    "        if len(group) <= h:\n",
    "            continue  # trajectory has no step-h state\n",
    "        row_features = group.iloc[h][PROBLEM_FEATURES].values\n",
    "        distances[user_id] = np.linalg.norm(row_features - feature_list)\n",
    "\n",
    "    # Running minimum with strict '<' starting from +inf, so NaN never becomes the minimum.\n",
    "    min_distance = float('inf')\n",
    "    for distance in distances.values():\n",
    "        if distance < min_distance:\n",
    "            min_distance = distance\n",
    "\n",
    "    # Ties require exactly equal distances.\n",
    "    neighbors = [user_id for user_id, distance in distances.items() if distance == min_distance]\n",
    "\n",
    "    return dataframe[dataframe['userID'].isin(neighbors)].reset_index(drop=True)\n",
    "\n",
    "# Behaviour-policy test data: test students outside the init_policy condition group\n",
    "test_beh_user = list(set(test_data['userID'].unique()) - set(cond_user_list))\n",
    "test_beh_data = test_data[test_data['userID'].isin(test_beh_user)]\n",
    "\n",
    "# ---- Online testing, starting from real-world test trajectories ----\n",
    "# Alternatives tried earlier: init_dataset (all test data) or initial_data\n",
    "# (the init_policy group; noted as the best result so far).\n",
    "test_dataset = test_beh_data\n",
    "test_id = test_dataset['userID'].unique()\n",
    "\n",
    "# Same user order as test_id, so cluster_seqs_test[i] belongs to test_id[i]\n",
    "cluster_seqs_test = getClusterSeq(test_dataset)\n",
    "\n",
    "test_returns = []\n",
    "\n",
    "# Policy name -> learnt-simulator trajectories used once a user switches to that policy\n",
    "sim_datasets = {'FHRL': FHRL_sim, 'SOCHRL': SOCHRL_sim, 'expert': expert_sim}\n",
    "\n",
    "for id_idx, ids in enumerate(test_id):\n",
    "\n",
    "    # Discrete representation; cluster_seqs_test follows test_id's user order\n",
    "    ids_cluter_seq = cluster_seqs_test[id_idx]\n",
    "    ids_df = test_dataset.loc[test_dataset['userID'] == ids].reset_index(drop=True)\n",
    "    ids_rewards = ids_df['inferred_rew'].tolist()\n",
    "    ids_returns = calculate_partial_accumulated_reward(ids_rewards, 0, horizon)\n",
    "\n",
    "    # If the user enters a critical pattern, compare staying on the behaviour policy\n",
    "    # with switching to the best-estimated target policy.\n",
    "    # h: earliest switch step (end of the first matched pattern), g: pattern index\n",
    "    h, g = earliest_switch_step(pattern_list, ids_cluter_seq)\n",
    "\n",
    "    can_switch_flag = True\n",
    "    # h < len(ids_df): switching needs a step-h state; ids_df.iloc[h] would raise\n",
    "    # IndexError when the pattern ends exactly at the end of the trajectory.\n",
    "    while h < horizon and h < len(ids_df) and can_switch_flag:\n",
    "\n",
    "        policy_return_threshold = target_policies_threshold[g]\n",
    "        real_value_0_to_h = calculate_partial_accumulated_reward(ids_rewards, 0, h-1)\n",
    "        est_target_returns = potential_policy_returns[g, h, :]\n",
    "        max_est_value_h_to_H = np.max(est_target_returns)\n",
    "        max_idx_target_policy = np.argmax(est_target_returns)\n",
    "\n",
    "        max_est_value_0_to_H = real_value_0_to_h + max_est_value_h_to_H # if switch at h\n",
    "        print('est', max_est_value_0_to_H)\n",
    "\n",
    "        if max_est_value_0_to_H > policy_return_threshold:\n",
    "            # Better than the threshold: take the continuation from the learnt simulator\n",
    "            print('user {} switch to policy {} at step {}'.format(ids, max_idx_target_policy, h))\n",
    "\n",
    "            policy_name = policies[max_idx_target_policy]\n",
    "            # Dict lookup replaces an if/elif chain whose first branch was the typo `lif`\n",
    "            # (a SyntaxError that stopped the whole cell from running).\n",
    "            swtich_dataset = sim_datasets[policy_name]\n",
    "\n",
    "            current_state_feature = ids_df.iloc[h][PROBLEM_FEATURES].values\n",
    "\n",
    "            # Rewards if we switch now: nearest simulated user at step h.  Computed once\n",
    "            # and reused for the diagnostic print and for the actual switch below.\n",
    "            new_rewards = find_closest_user(swtich_dataset, current_state_feature, h)['inferred_rew'].tolist()\n",
    "            print('true', calculate_partial_accumulated_reward(ids_rewards[:h] + new_rewards[h:], 0, horizon))\n",
    "\n",
    "            # Look-ahead: keep the behaviour policy's step-h reward, then continue along\n",
    "            # each simulated user tied for nearest, and average their values.\n",
    "            closest_user_df = find_closest_user_allow_multiple(swtich_dataset, current_state_feature, h)\n",
    "            est_value_switch_from_next = []\n",
    "            for candidate in closest_user_df['userID'].unique():\n",
    "                new_rewards_cand = closest_user_df.loc[closest_user_df['userID'] == candidate]['inferred_rew'].tolist()\n",
    "                ids_rewards_cand = ids_rewards[:h+1] + new_rewards_cand[h+1:]\n",
    "                est_value_switch_from_next.append(calculate_partial_accumulated_reward(ids_rewards_cand, 0, horizon))\n",
    "\n",
    "            if est_value_switch_from_next:\n",
    "                est_value_0_to_H_look_ahead = sum(est_value_switch_from_next)/len(est_value_switch_from_next)\n",
    "            else:\n",
    "                # No look-ahead candidate (previously a ZeroDivisionError): nothing to wait for\n",
    "                est_value_0_to_H_look_ahead = float('-inf')\n",
    "\n",
    "            if max_est_value_0_to_H >= est_value_0_to_H_look_ahead:\n",
    "                # still switch at h\n",
    "                ids_rewards = ids_rewards[:h] + new_rewards[h:]\n",
    "                can_switch_flag = False\n",
    "            else:\n",
    "                # operate beh policy and keep checking from next step\n",
    "                h = h + 1\n",
    "                print('wait once')\n",
    "\n",
    "        else:\n",
    "            # no need to switch\n",
    "            can_switch_flag = False\n",
    "\n",
    "        # calculate new accumulated rewards\n",
    "        ids_returns = calculate_partial_accumulated_reward(ids_rewards, 0, horizon)\n",
    "        print('final',ids_returns)\n",
    "\n",
    "    # save\n",
    "    test_returns.append(ids_returns)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
