{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1aef4c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np;\n",
    "import scipy\n",
    "import random\n",
    "import numpy.random as ra;\n",
    "import numpy.linalg as la;\n",
    "import matplotlib.pyplot as plt\n",
    "import sklearn\n",
    "from sklearn import preprocessing\n",
    "from scipy.stats import bernoulli\n",
    "import kld as kl\n",
    "import theta_1_set_gen as big_theta\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e020518d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed both random number generators used by the experiment:\n",
    "# Python's random (initial action choice) and NumPy (observation noise).\n",
    "random.seed(20)\n",
    "np.random.seed(21)\n",
    "\n",
    "T=5000        # length of the time horizon\n",
    "epsilon=0.2   # epsilon parameter\n",
    "itr=5000      # number of Monte Carlo runs per (d, tau) pair\n",
    "\n",
    "var=.5        # standard deviation (scale) of the Gaussian observation noise; the noise variance is var**2\n",
    "path_len=5    # length of the path of connected nodes, passed to theta_1_set_gen\n",
    "\n",
    "d_list=list(range(10,30,5))     # lengths of the line graph: 10, 15, 20, 25\n",
    "tau_list=list(range(10,50,10))  # change points: 10, 20, 30, 40\n",
    "\n",
    "# Candidate post-change parameter for each graph length in d_list\n",
    "# (unknown to the algorithms, but common to all algorithms).\n",
    "# The file holds one vector per d (different lengths), i.e. an object array, which needs allow_pickle;\n",
    "# unpickling can execute code, so only load a trusted theta_1_list.npy.\n",
    "theta_1_list=np.load(\"theta_1_list.npy\",allow_pickle=True).tolist()\n",
    "\n",
    "# Stopping times of the Oracle detector, one list per graph length in d_list.\n",
    "l_stop_oracle_d=[]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3de84005",
   "metadata": {},
   "outputs": [],
   "source": [
    "#_______________________________________Bandit Loop for Oracle_______________________________________________________________\n",
    "# For every graph length d and change point tau, run itr Monte Carlo trials of the Oracle change detector,\n",
    "# record each trial's stopping time, and save all of them to l_stop_oracle_d.npy for a separate plotting notebook.\n",
    "\n",
    "beta=50 #the choice of beta for which the false alarm of Oracle change detector is less than 1%\n",
    "\n",
    "\n",
    "def oracle_stopping_time(Action_set,theta_1_set,theta_1,best_idx,tau,T,var,beta):\n",
    "    \"\"\"Run one Monte Carlo trial of the Oracle change detector and return its stopping time.\n",
    "\n",
    "    Action_set  -- (d, K) matrix whose columns are the available actions.\n",
    "    theta_1_set -- (d, M) matrix whose columns are the candidate post-change parameters.\n",
    "    theta_1     -- (d, 1) true post-change parameter, used to generate post-change observations.\n",
    "    best_idx    -- column of Action_set played from t=2 onwards (the Oracle's most informative action).\n",
    "    tau         -- change point: observations at t >= tau have mean A^T theta_1, earlier ones have mean 0.\n",
    "    T           -- time horizon.\n",
    "    var         -- standard deviation (scale) of the Gaussian observation noise.\n",
    "    beta        -- detection threshold on the maximum over candidates of Q^{(1)}.\n",
    "\n",
    "    Returns the first t in 1..T at which max(Q^{(1)}) >= beta, or T if the threshold is never reached.\n",
    "    \"\"\"\n",
    "    K=Action_set.shape[1]\n",
    "    # Q^{(1)}: one CUSUM-type statistic per candidate post-change parameter, started at 0.\n",
    "    V=np.zeros(theta_1_set.shape[1])\n",
    "    for t in range(1,T+1):\n",
    "        if t==1:\n",
    "            # No observation yet: play a uniformly random action at the first time step.\n",
    "            A=Action_set[:,random.randrange(0,K,1)]\n",
    "        else:\n",
    "            A=Action_set[:,best_idx]\n",
    "\n",
    "        # Get an observation; the pre-change mean is 0 because the pre-change parameter is zero.\n",
    "        if t<tau:\n",
    "            X=np.random.normal(0,var,1)\n",
    "        else:\n",
    "            X=np.random.normal(np.dot(A,theta_1),var,1)\n",
    "\n",
    "        # g(X_t|A_t) for every candidate theta: 2*X_t*(theta^T A_t) - (theta^T A_t)^2.\n",
    "        # This is the Gaussian log-likelihood ratio without the constant factor 1/(2*var**2),\n",
    "        # so that factor is effectively folded into beta.\n",
    "        proj=np.dot(theta_1_set.T,A)\n",
    "        g=2*np.dot(theta_1_set.T,X*A)-proj**2\n",
    "\n",
    "        # Recursive update of Q^{(1)}, clipped at 0, applied uniformly from t=1.\n",
    "        V=np.maximum(0,V+g)\n",
    "\n",
    "        # Report the time step at which the threshold is first crossed.\n",
    "        if np.max(V)>=beta:\n",
    "            return t\n",
    "    # No change detected within the horizon: stop manually at T.\n",
    "    return T\n",
    "\n",
    "\n",
    "l_stop_oracle_d=[] # reset so that re-running this cell does not append duplicate results\n",
    "\n",
    "start=time.perf_counter()\n",
    "print(\"Oracle\")\n",
    "\n",
    "for (i,d) in enumerate(d_list):\n",
    "    print(\"d=\",d)\n",
    "    # True post-change parameter as a (d, 1) column; only the Oracle gets to use it.\n",
    "    theta_1=theta_1_list[i]\n",
    "    theta_1=theta_1.reshape(theta_1.shape[0],1)\n",
    "\n",
    "    # Post-change parameter set generated by the theta_1_set_gen module, every column normalised to unit l2 norm.\n",
    "    theta_1_set=big_theta.theta_1_set(d,path_len)\n",
    "    theta_1_set=sklearn.preprocessing.normalize(theta_1_set,norm='l2',axis=0)\n",
    "\n",
    "    # Diffused action set: the actions are the normalised candidate parameters themselves.\n",
    "    Action_set=theta_1_set\n",
    "\n",
    "    # WLOG, for synthetic experiments, we set the pre-change parameter to zero.\n",
    "    theta_not=np.zeros((d,1))\n",
    "\n",
    "    # The Oracle plays the action with the largest KL divergence between post- and pre-change mean rewards.\n",
    "    # That choice depends only on d (not on tau, the trial, or t), so it is computed once here\n",
    "    # (assumes kl.kld is a deterministic function of its arguments).\n",
    "    mu1=np.dot(Action_set.T,theta_not)\n",
    "    mu2=np.dot(Action_set.T,theta_1) # uses theta_1 because the Oracle knows the true post-change parameter\n",
    "    best_idx=np.argmax(kl.kld(mu2,mu1,var**2))\n",
    "\n",
    "    # Stopping times for this d: itr trials per change point, concatenated in tau_list order.\n",
    "    l_stop=[]\n",
    "    for tau in tau_list:\n",
    "        print(\"tau=\",tau)\n",
    "        for _ in range(itr):\n",
    "            l_stop.append(oracle_stopping_time(Action_set,theta_1_set,theta_1,best_idx,tau,T,var,beta))\n",
    "\n",
    "    l_stop_oracle_d.append(l_stop)\n",
    "\n",
    "# Saved as a (len(d_list), len(tau_list)*itr) integer array and visualised in a separate notebook.\n",
    "np.save(\"l_stop_oracle_d.npy\",l_stop_oracle_d)\n",
    "end=time.perf_counter()\n",
    "print(\"Time Taken by Oracle=\",end-start)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
