#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 14 11:50:43 2019
@author: sidney-tio, lidexun
@adapted from: adityamate, killian-34
"""

import numpy as np
import pandas as pd
import time

# import pomdp

from itertools import combinations


import os
import argparse
import tqdm

from utils import (
    returnKGreatestIndices,
    generateRandomTmatrix,
    verify_T_matrix,
    epsilon_clip,
    barPlot,
    soft_max,
    prob_cal,
    soft_max_stable,
    get_topic_mapping,
    propagate_action,
    get_junyi_data,
    get_oli_data
)

from whittle import whittleIndex
from q_learning import update_q, ReplayBuffer

from generateTmatrix import generateTmatrixReal, generateStudentTmatrix

# NOTE(review): OPT_SIZE_LIMIT, CONST and NUM_SEED are not referenced by the
# functions visible in this part of the file; confirm their meaning at their
# use sites (presumably the script entry point further down).
OPT_SIZE_LIMIT = 8
# Observation model. True: every arm's current state is observed each step
# regardless of action (takeAction copies it into belief/observations).
# False: arms given action 0 are unobserved (observation NaN) and their belief
# is propagated through the estimated transition matrix T_hat instead.
FULL_OBS = True
CONST = 9
NUM_SEED = 88


def takeAction(
    current_learning,
    belief,
    actions,
    T,
    random_stream,
    grouping,
    nodes_to_group,
    T_hat=None,
):
    """Advance every arm one step and report what the planner gets to see.

    Args:
        current_learning: 1-D array with each arm's current state (used as the
            state index into ``T``).
        belief: per-arm probability of being in state 1; updated in place and
            also returned.
        actions: per-arm actions chosen by the policy; they are first expanded
            with ``propagate_action(actions, grouping, nodes_to_group)``.
        T: ground-truth transitions indexed ``T[arm, action, state, next]``;
            ``T[arm, action, state, 1]`` is the success probability of the
            Bernoulli draw that produces the arm's next state.
        random_stream: RNG providing ``.binomial``; used for those draws.
        grouping, nodes_to_group: forwarded to ``propagate_action``.
        T_hat: estimated transitions; only read when FULL_OBS is False.

    Returns:
        (next_learning, belief, observations). With FULL_OBS, belief and
        observations are both set to each arm's *current* state. Otherwise arms
        given action 0 report NaN and have their belief rolled forward through
        T_hat's passive transitions, arms given action 1 get a belief computed
        from their current state, arms with other actions keep their belief,
        and every arm not given action 0 reports its current state.
    """
    actions = propagate_action(actions, grouping, nodes_to_group)
    n_arms = len(current_learning)

    # Ground truth: sample each arm's next state from the true T.
    next_learning = np.zeros(current_learning.shape)
    for arm in range(n_arms):
        state_now = int(current_learning[arm])
        p_next_is_one = T[arm, int(actions[arm]), state_now, 1]
        next_learning[arm] = random_stream.binomial(1, p_next_is_one)

    # Planner's view (no randomness): belief estimated with T_hat, plus what
    # was actually observed this step.
    observations = np.zeros(n_arms)
    for arm in range(n_arms):
        seen_state = current_learning[arm]
        if FULL_OBS:
            belief[arm] = seen_state
            observations[arm] = seen_state
            continue

        action = int(actions[arm])
        if action == 0:
            # Unobserved: P(1) = P(1) * P(1->1) + P(0) * P(0->1) with no action.
            belief[arm] = belief[arm] * T_hat[arm][0][1][1] + (1 - belief[arm]) * (
                T_hat[arm][0][0][1]
            )
        elif action == 1:
            # Acting reveals today's state, so propagate from it instead of the
            # previous belief.
            belief[arm] = seen_state * T_hat[arm][1][1][1] + (1 - seen_state) * (
                T_hat[arm][1][0][1]
            )
        observations[arm] = np.nan if actions[arm] == 0 else seen_state

    return next_learning, belief, observations


def _select_random_arms(N, k, p=None):
    """Return a 0/1 int array of length N with k distinct arms set to 1.

    Arms are drawn with the global ``np.random`` state: uniformly when ``p`` is
    None, otherwise with per-arm selection probabilities ``p``.
    """
    chosen = np.random.choice(N, size=k, replace=False, p=p)
    picked = np.zeros(N, dtype=int)
    picked[chosen] = 1
    return picked


def _select_top_k(scores, N, k):
    """Return a 0/1 float array of length N marking returnKGreatestIndices(scores, k)."""
    picked = np.zeros(N)
    for arm in returnKGreatestIndices(scores, k):
        picked[arm] = 1
    return picked


def _resolve_grouping(grouping, nodes_to_group):
    """Return (grouping, nodes_to_group) for policy 20.

    Older callers never passed these and relied on module-level globals of the
    same names (presumably set by the script entry point), so fall back to
    those when the arguments are omitted.

    Raises:
        ValueError: if neither the arguments nor the globals are available.
    """
    module_globals = globals()
    if grouping is None:
        grouping = module_globals.get("grouping")
    if nodes_to_group is None:
        nodes_to_group = module_globals.get("nodes_to_group")
    if grouping is None or nodes_to_group is None:
        raise ValueError("policy_option 20 requires grouping and nodes_to_group")
    return grouping, nodes_to_group


def getActions(
    N,
    k,
    Q_tables,
    V_tables,
    epsilon_t,
    time_steps=None,
    belief=None,
    T_hat=None,
    policy_option=0,
    current_node=None,
    policy_graph_dict=None,
    days_since_called=None,
    days_since_semi_active=None,
    last_observed_state=None,
    w=None,
    w_new=None,
    newWhittle=True,
    learning_oracle=None,
    days_remaining=None,
    current_t=None,
    observations=None,
    learning=None,
    T=None,
    verbose=False,
    grouping=None,
    nodes_to_group=None,
):
    """Choose which of the N arms to act on this step.

    Returns a length-N array with 1 for every selected arm and 0 otherwise.

    policy_option:
        0: act on no arm.
        1: act on every arm.
        2: pick k distinct arms uniformly at random (global ``np.random``).
        3: myopic on ``belief``: top-k by the one-step gain in the probability
           of being in state 1 next step, under ``T_hat``.
        5: Whittle index: top-k by
           ``w[arm][last_observed_state[arm]][days_since_called[arm]]``.
        7: myopic gain ``T_hat[i][1][s][1] - T_hat[i][0][s][1]`` from
           ``last_observed_state``; k arms sampled via soft_max/prob_cal.
        8: same gain as 7, deterministic top-k.
        19: epsilon-greedy (``epsilon_t``) on the Q-learning index
            ``Q[arm][s][1] - Q[arm][s][0]``.
        20: like 19, plus the sum over the arm's group members of
            ``Q[member][s][2] - Q[member][s][0]``; needs ``grouping`` and
            ``nodes_to_group`` (falls back to module globals of those names).
    Any other value returns None.

    For policies 5, 19 and 20, omitted ``days_since_called`` /
    ``last_observed_state`` default to zeros (5) or ones (19/20) for the
    observed state, and zeros for days since called.

    Parameters such as time_steps, current_node, policy_graph_dict,
    days_since_semi_active, w_new, newWhittle, learning_oracle, days_remaining,
    current_t, observations, learning, T, verbose and V_tables are accepted for
    interface compatibility but not read by the policies implemented here.
    """
    if policy_option == 0:
        return np.zeros(N)
    if policy_option == 1:
        return np.ones(N)
    if policy_option == 2:
        return _select_random_arms(N, k)

    if policy_option == 3:
        gains = np.zeros(N)
        for i in range(N):
            b = belief[i]  # probability the arm is in state 1 now
            b_next_nocall = b * (T_hat[i][0][1][1]) + (1 - b) * (T_hat[i][0][0][1])
            b_next_call = b * (T_hat[i][1][1][1]) + (1 - b) * (T_hat[i][1][0][1])
            gains[i] = b_next_call - b_next_nocall
        return _select_top_k(gains, N, k)

    if policy_option == 5:
        # `is None` is the only reliable "argument omitted" test; calling
        # `.any()` on None raises instead of initialising.
        if days_since_called is None:
            days_since_called = np.zeros(N)
        if last_observed_state is None:
            last_observed_state = np.zeros(N)
        whittle_indices = [
            w[arm][int(last_observed_state[arm])][int(days_since_called[arm])]
            for arm in range(N)
        ]
        return _select_top_k(whittle_indices, N, k)

    if policy_option in (19, 20):
        if last_observed_state is None:
            last_observed_state = np.ones(N)

        # Exploration step of epsilon-greedy.
        if np.random.random() < epsilon_t:
            return _select_random_arms(N, k)

        if policy_option == 20:
            grouping, nodes_to_group = _resolve_grouping(grouping, nodes_to_group)

        whittle_indices = np.ones(N)
        for arm in range(N):
            state = last_observed_state[arm].astype(int)
            index = Q_tables[arm][state][1] - Q_tables[arm][state][0]
            if policy_option == 20:
                # Add the benefit that acting on this arm propagates to the
                # other members of its groups (action 2 vs. passive).
                members = []
                for g in nodes_to_group[arm]:
                    members.extend(grouping[g])
                members = [m for m in members if m != arm]
                member_states = last_observed_state[members].astype(int)
                index = index + np.sum(
                    Q_tables[members, member_states, 2]
                    - Q_tables[members, member_states, 0]
                )
            whittle_indices[arm] = index
        return _select_top_k(whittle_indices, N, k)

    if policy_option in (7, 8):
        gains = np.zeros(N)
        for i in range(N):
            state = last_observed_state[i].astype(int)
            gains[i] = T_hat[i][1][state][1] - T_hat[i][0][state][1]
        if policy_option == 8:
            return _select_top_k(gains, N, k)
        new_prob = prob_cal(soft_max(gains), k)
        new_prob = new_prob / np.sum(new_prob)
        return _select_random_arms(N, k, p=new_prob)

    return None


def learnTmatrixFromObservations(observations, actions, random_stream):
    """Estimate one arm's 2x2x2 transition matrix from a single trajectory.

    observations = [o0, o1, ..., oL] (0 = NA, 1 = A) and
    actions = [a1, a2, ..., aL] (0 = NoCall, 1 = Called); the transition
    o_j -> o_{j+1} is attributed to actions[j].

    Returns T with T[a, s, :] = [1 - p, p], where p estimates P(next = 1 | s, a)
    as the empirical frequency over the observed transitions from (s, a).
    Four uniforms are always drawn from random_stream first and sorted
    ascending; an (s, a) pair with no observed transition uses one of them, in
    the order (a=0, s=0), (a=0, s=1), (a=1, s=0), (a=1, s=1).
    """
    fallback = sorted(random_stream.uniform(size=4))

    obs = np.asarray(observations)
    acts = np.asarray(actions)
    n_pairs = min(len(acts), max(len(obs) - 1, 0))

    # freq[state, action, next_state] = number of observed transitions.
    freq = np.zeros((2, 2, 2))
    np.add.at(
        freq,
        (
            obs[:n_pairs].astype(int),
            acts[:n_pairs].astype(int),
            obs[1 : n_pairs + 1].astype(int),
        ),
        1,
    )

    T = np.zeros((2, 2, 2))
    for action in (0, 1):
        for state in (0, 1):
            p_next_one = fallback[2 * action + state]
            visits = freq[state, action, 0] + freq[state, action, 1]
            if visits > 0:
                p_next_one = freq[state, action, 1] / visits
            T[action, state] = [1 - p_next_one, p_next_one]
    return T


def update_counts(
    learning,
    actions,
    last_called,
    current_round,
    counts,
    buffer_length=0,
    get_last_call_transition_flag=False,
):
    """Credit the transitions seen since each called arm's previous call.

    For every arm with ``actions[arm] == 1``, the states
    ``learning[arm, last_called[arm] : current_round + 1]`` are split into
    consecutive (prev, next) pairs:

    * If the history fits in ``buffer_length + 1`` states, the first pair is
      counted under action 1 (``counts[arm, 1, prev, next]``) and all later
      pairs under action 0.
    * Otherwise only the last ``buffer_length + 1`` states are kept and all of
      their pairs are counted under action 0; the very first pair of the full
      history is counted under action 1 only when
      ``get_last_call_transition_flag`` is True.

    ``buffer_length == 0`` means "effectively unbounded" (100000000).
    ``counts`` and ``last_called`` are updated in place; returns None.
    """
    window = 100000000 if buffer_length == 0 else buffer_length
    # +1 so the remembered window also includes today's state.
    window += 1

    called_arms = [arm for arm, act in enumerate(actions) if act == 1]
    for arm in called_arms:
        history = learning[arm, last_called[arm] : current_round + 1].astype(int)

        if history.shape[0] > window:
            if get_last_call_transition_flag:
                counts[arm, 1, history[0], history[1]] += 1
            history = history[-window:]
            first_slot = 0
        else:
            first_slot = 1

        counts[arm, first_slot, history[0], history[1]] += 1
        for prev_state, next_state in zip(history[1:-1], history[2:]):
            counts[arm, 0, prev_state, next_state] += 1

        last_called[arm] = current_round


def thompson_sampling(N, priors, counts, random_stream):
    """Draw one posterior sample of every arm's transition tensor.

    Each row ``T_hat[arm, action, state, :]`` is a single draw from
    ``Dirichlet(priors[arm, action, state, :] + counts[arm, action, state, :])``.
    Rows are visited arm by arm, then action, then state, so a seeded
    ``random_stream`` yields reproducible samples.

    Returns:
        np.ndarray of shape (N, 2, 2, 2).
    """
    T_hat = np.zeros((N, 2, 2, 2))
    _, n_actions, n_states, _ = T_hat.shape
    for arm, action, state in np.ndindex(N, n_actions, n_states):
        concentration = priors[arm, action, state, :] + counts[arm, action, state, :]
        T_hat[arm, action, state, :] = random_stream.dirichlet(concentration)
    return T_hat


def thompson_sampling_constrained(N, priors, counts, random_stream):
    """Posterior-sample each arm's transitions, resampling until they are valid.

    Same Dirichlet draws as ``thompson_sampling``, but an arm's whole 2x2x2
    slice is redrawn until ``verify_T_matrix`` accepts it. If the posterior
    rarely yields an accepted matrix this can loop for a long time; there is
    no attempt limit.

    Returns:
        np.ndarray of shape (N, 2, 2, 2).
    """
    T_hat = np.zeros((N, 2, 2, 2))
    _, n_actions, n_states, _ = T_hat.shape
    for arm in range(N):
        # The slice starts all-zero, which guarantees at least one round of
        # draws; verify_T_matrix is still consulted first, as before.
        while not verify_T_matrix(T_hat[arm]) or T_hat[arm].sum() == 0:
            for action, state in np.ndindex(n_actions, n_states):
                concentration = (
                    priors[arm, action, state, :] + counts[arm, action, state, :]
                )
                T_hat[arm, action, state, :] = random_stream.dirichlet(concentration)
    return T_hat


def simulate_learning(
    N,
    L,
    T,
    k,
    policy_option,
    Q_tables,
    V_tables,
    grouping,
    nodes_to_group,
    episode,
    max_ep,
    gamma=0.95,
    alpha=0.1,
    start_node=None,
    policy_graph_dict=None,
    obs_space=None,
    action_logs=None,
    cum_learning=None,
    new_whittle=True,
    online=True,
    seedbase=None,
    savestring="trial",
    epsilon=0.0,
    learning_mode=False,
    world_random_seed=None,
    learning_random_seed=None,
    verbose=False,
    buffer_length=0,
    get_last_call_transition_flag=False,
    file_root=None,
    replay_buffer=None,
):
    """Simulate one episode of L steps for N arms under a single policy.

    Args:
        N: number of arms.
        L: number of time steps. Column 0 of the state matrix is the initial
            state (all zeros); steps 1 .. L-1 are simulated.
        T: ground-truth transitions ``T[arm, action, state, next_state]`` used
            to advance the world.
        k: number of arms acted on per step (forwarded to getActions).
        policy_option: policy id understood by getActions. Policies 5, 6, 19
            and 20 also get extra per-step bookkeeping here.
        Q_tables, V_tables: value tables for the learning policies (19/20
            pass Q_tables to update_q; 6 writes V_tables in place).
        grouping, nodes_to_group: group structure forwarded to takeAction and,
            for policy 20, to propagate_action.
        episode, max_ep: current episode index and total number of episodes;
            together they anneal ``epsilon_t``, the exploration rate used by
            policies 19/20.
        gamma, alpha: discount and step size forwarded to update_q.
        action_logs: optional dict; if given, ``action_logs[policy_option]``
            receives one int action array per step.
        cum_learning: optional dict; if given, ``cum_learning[policy_option]``
            receives the running total over time of arms in state 1.
        online: recompute Whittle indices every step (policy 5).
        epsilon: for learning_mode 2, the fixed exploration probability, or
            0.0 to anneal it from 1 towards 0; also scales ``epsilon_t``.
        learning_mode: 0 use the true T, 1 Thompson sampling, 2 epsilon-greedy,
            3 constrained Thompson sampling, 4 "Naive Mean" (no estimator is
            implemented for it here).
        world_random_seed, learning_random_seed: seeds for the world and the
            learner random streams.
        buffer_length, get_last_call_transition_flag: forwarded to
            update_counts.
        replay_buffer: forwarded to update_q.
        start_node, obs_space, new_whittle, seedbase, savestring, file_root:
            accepted for interface compatibility; not used here.

    Returns:
        np.ndarray of shape (N, L) with every arm's state at every step.
    """
    learning_random_stream = np.random.RandomState()
    if learning_mode > 0:
        learning_random_stream.seed(learning_random_seed)

    world_random_stream = np.random.RandomState()
    world_random_stream.seed(world_random_seed)

    T_hat = None
    if learning_mode == 2:
        # Epsilon-greedy starts from a random guess refined per arm over time.
        T_hat = generateRandomTmatrix(N, random_stream=learning_random_stream)
    # Dirichlet priors and observed transition counts, indexed
    # [arm, action, state, next_state]; consumed by the Thompson samplers.
    priors = np.ones((N, 2, 2, 2))
    counts = np.zeros((N, 2, 2, 2))
    last_called = np.zeros(N).astype(int)

    actions_record = np.zeros((N, L - 1))
    learning = np.zeros((N, L))
    if action_logs is not None:
        action_logs[policy_option] = []

    belief = np.ones(N)
    current_node = None

    w = None
    w_new = None
    if policy_option == 5 and not online and not learning_mode:
        # Offline and without learning the indices never change, so compute
        # them once up front from the ground-truth T.
        print("Pre-computing whittle index for offline, no-learning mode")
        w = np.zeros((N, 2, L))
        w_new = np.zeros((N, 2, L))  # passed through to getActions only
        for patient in range(N):
            w[patient, 1, :], w[patient, 0, :] = whittleIndex(T[patient], L=L)

    # Per-arm trackers read by the policies.
    days_since_called = np.zeros(N)
    days_since_semi_active = np.zeros(N)
    last_observed_state = np.zeros(N)

    print("Running simulation w/ policy: %s" % policy_option)
    observations = np.full(N, np.nan)
    learning_modes = [
        "no_learning",
        "Thompson sampling",
        "e-greedy",
        "Constrained TS",
        "Naive Mean",
    ]
    print("Learning mode:", learning_modes[learning_mode])

    # epsilon_schedule[t] is read for t = 1 .. L-1, so it needs L entries.
    if epsilon == 0.0:
        # Anneal as linspace(1, 0, L) ** power with power = log(y) / log(1 - x),
        # so epsilon == y once a fraction x of the horizon has elapsed;
        # 4.8188... gives 0.25 at 25%. The final 0 entry is never read.
        epsilon_schedule = np.linspace(1, 0.00, L) ** 4.818841679306418
    else:
        epsilon_schedule = [epsilon] * L

    EPSILON_CLIP = 0.0005

    for t in tqdm.tqdm(range(1, L)):
        # Exploration rate for the Q-learning policies (19/20): decays linearly
        # with overall training progress and reaches 0 at 90% of it.
        epsilon_t = epsilon * max(0.9 - (episode * L + t) / (max_ep * L), 0)
        days_remaining = L - t

        # ---- Refresh the learner's transition estimate T_hat ----
        if learning_mode == 0:
            T_hat = T
        elif learning_mode == 1:
            T_hat = thompson_sampling(
                N, priors, counts, random_stream=learning_random_stream
            )
        elif learning_mode == 2 and t > 2:
            # `actions` still holds the previous step's choices: re-estimate
            # only the arms acted on then.
            for student, action in enumerate(actions):
                if action == 1:
                    T_hat[student] = learnTmatrixFromObservations(
                        learning[student, : t - 1],
                        actions_record[student, 1 : (t - 1)],
                        random_stream=learning_random_stream,
                    )
        elif learning_mode == 3:
            T_hat = thompson_sampling_constrained(
                N, priors, counts, random_stream=learning_random_stream
            )
        elif learning_mode == 4:
            pass  # "Naive Mean": no estimator implemented here

        T_hat = epsilon_clip(T_hat, EPSILON_CLIP)

        if online or learning_mode:
            # Indices depend on T_hat and the trackers, so recompute each step.
            w = np.zeros((N, 2, L))
            w_new = np.zeros((N, 2, L))
            if policy_option == 5:
                for patient in range(N):
                    limits = [0, 0]
                    limits[int(last_observed_state[patient])] = (
                        days_since_called[patient] + 1
                    )
                    w[patient, 1, :], w[patient, 0, :] = whittleIndex(
                        T_hat[patient], L=L, limit_a=limits[1], limit_na=limits[0]
                    )

        # ---- Choose actions ----
        action_kwargs = dict(
            time_steps=t,
            belief=belief,
            T_hat=T_hat,
            current_node=current_node,
            policy_graph_dict=policy_graph_dict,
            days_since_called=days_since_called,
            days_since_semi_active=days_since_semi_active,
            last_observed_state=last_observed_state,
            w=w,
            w_new=w_new,
            current_t=t,
            learning_oracle=learning[:, t - 1].squeeze(),
            days_remaining=days_remaining,
            observations=observations,
            learning=learning[:, t - 1],
            T=T,
            verbose=verbose,
        )
        acting_policy = policy_option
        # NOTE(review): NON_EPSILON_POLICIES is a module-level name that is not
        # defined in this part of the file; it is only looked up when
        # learning_mode == 2.
        if learning_mode == 2 and policy_option not in NON_EPSILON_POLICIES:
            # Epsilon-greedy: with probability epsilon_schedule[t] explore
            # using the random policy (2) instead of the configured one.
            if learning_random_stream.binomial(1, epsilon_schedule[t]) != 0:
                acting_policy = 2
        selection = getActions(
            N,
            k,
            Q_tables,
            V_tables,
            epsilon_t,
            policy_option=acting_policy,
            **action_kwargs,
        )
        if policy_option == 6:
            actions, prob = selection
        else:
            actions = selection

        actions_record[:, t - 1] = actions
        if action_logs is not None:
            action_logs[policy_option].append(actions.astype(int))

        # ---- Advance the world ----
        learning[:, t], belief, observations = takeAction(
            learning[:, t - 1].squeeze(),
            belief,
            actions,
            T,
            random_stream=world_random_stream,
            grouping=grouping,
            nodes_to_group=nodes_to_group,
            T_hat=T_hat,
        )

        # The transition right after a call is the only evidence about the
        # action-1 matrix; all later ones inform the action-0 matrix.
        update_counts(
            learning,
            actions,
            last_called,
            t,
            counts,
            buffer_length=buffer_length,
            get_last_call_transition_flag=get_last_call_transition_flag,
        )

        # ---- Policy-specific learning updates ----
        if policy_option in (19, 20):
            if policy_option == 20:
                # Learn from the group-propagated actions actually applied.
                actions = propagate_action(actions, grouping, nodes_to_group)
            update_q(
                actions,
                t,
                last_observed_state,
                Q_tables,
                learning,
                N,
                gamma=gamma,
                alpha=alpha,
                replay_buffer=replay_buffer,
            )
        if policy_option == 6:
            probs = prob_cal(prob, k)
            horizon = t - 1
            for i in range(N):
                state_last_observed = last_observed_state[i].astype(int)
                if t < L - 1:
                    Q_active = T_hat[i][1][state_last_observed][0] * (
                        state_last_observed + V_tables[i][horizon + 1][0]
                    ) + T_hat[i][1][state_last_observed][1] * (
                        state_last_observed + V_tables[i][horizon + 1][1]
                    )
                    Q_passive = T_hat[i][0][state_last_observed][0] * (
                        state_last_observed + V_tables[i][horizon + 1][0]
                    ) + T_hat[i][0][state_last_observed][1] * (
                        state_last_observed + V_tables[i][horizon + 1][1]
                    )
                    V_tables[i][horizon][state_last_observed] = (
                        probs[i] * Q_active + (1 - probs[i]) * Q_passive
                    )
                else:
                    V_tables[i][horizon][state_last_observed] = state_last_observed

        # ---- Update per-arm trackers ----
        for i in range(N):
            if actions[i] == 1:
                days_since_called[i] = 0
                days_since_semi_active[i] += 1
                last_observed_state[i] = observations[i]
            elif actions[i] == 0:
                days_since_called[i] += 1
                days_since_semi_active[i] += 1
                if FULL_OBS:
                    last_observed_state[i] = observations[i]
            elif actions[i] == 2:
                days_since_called[i] += 1
                days_since_semi_active[i] = 0  # was `== 0`, a no-op comparison
                if FULL_OBS:
                    last_observed_state[i] = observations[i]

    if cum_learning is not None:
        cum_learning[policy_option] = np.cumsum(learning.sum(axis=0))

    return learning


if __name__ == "__main__":
    """
    0: never call    1: Call all patients everyday     2: Randomly pick k patients to call
    3: Myopic policy    4: pomdp  5: whittle    6: 2-day look ahead    7:oracle
    8: oracle_whittle   9: despot 10: yundi's whittle index
    11: naive belief (longest since taking a pill belief)
    12: naive real (longest since taking a pill ground truth)
    13: naive real, multiplied by ground truth transition probability
    """

    def _str2bool(value):
        """Parse a boolean command-line value.

        ``type=bool`` is an argparse trap: ``bool("False")`` is ``True``, so any
        non-empty string used to switch the flag on. Accept the common
        spellings explicitly and reject anything else.

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off", ""):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got %r" % (value,)
        )

    parser = argparse.ArgumentParser(description="Run learning simulations")
    parser.add_argument(
        "-n", "--num_subjects", default=100, type=int, help="Number of Arms"
    )
    parser.add_argument(
        "-eps", "--episode", default=80, type=int, help="Number of episodes"
    )
    parser.add_argument(
        "-k", "--n_interventions", default=1, type=float, help="Number of calls per day"
    )
    parser.add_argument(
        "-l",
        "--simulation_length",
        default=180,
        type=int,
        help="Number of days to run simulation",
    )
    parser.add_argument(
        "-N", "--num_trials", default=5, type=int, help="Number of trials to run"
    )
    parser.add_argument(
        "-d",
        "--data",
        default="simulated",
        choices=[
            "real",
            "simulated",
            "full_random",
            "unit_test",
            "myopic_fail",
            "demo",
            "uniform",
            "junyi",
            "oli"
        ],
        type=str,
        help="Method for generating transition probabilities",
    )
    parser.add_argument("-s", "--seed_base", type=int, help="Base for the random seed")
    parser.add_argument(
        "-ws",
        "--world_seed_base",
        default=None,
        type=int,
        help="Base for the random seed",
    )
    parser.add_argument(
        "-ls",
        "--learning_seed_base",
        default=None,
        type=int,
        help="Base for the random seed",
    )
    parser.add_argument(
        "-p",
        "--num_processes",
        default=4,
        type=int,
        help="Number of cores for parallelization",
    )
    parser.add_argument(
        "-f",
        "--file_root",
        default=None,
        type=str,
        help="Root dir for experiment (should be the dir containing this script)",
    )
    parser.add_argument(
        "-pc",
        "--policy",
        nargs="+",
        type=int,
        default=-1,
        help="policy to run, default is all policies",
    )
    parser.add_argument(
        "-res",
        "--results_file",
        default="answer",
        type=str,
        help="learning numpy matrix file name",
    )
    parser.add_argument(
        "-tr", "--trial_number", default=None, type=int, help="Trial number"
    )
    parser.add_argument(
        "-sv",
        "--save_string",
        default="",
        type=str,
        help="special string to include in saved file name",
    )
    parser.add_argument(
        "-badf",
        "--bad_fraction",
        default=0.4,
        type=float,
        help="fraction of non-responsive patients",
    )
    parser.add_argument(
        "-thrf_perc",
        "--threshopt_percentage",
        default=None,
        type=int,
        # argparse %-formats help strings, so a literal percent sign must be
        # written "%%"; a bare "%" made `--help` raise TypeError.
        help="%% of threshold optimal patients in data",
    )
    parser.add_argument(
        "-beta",
        "--beta",
        default=0.5,
        type=float,
        help="beta used in quick check for determining threshold optimal fraction",
    )
    parser.add_argument(
        "-epsilon",
        "--epsilon",
        default=1.0,
        type=float,
        help="espilon value for epsilon greedy",
    )
    parser.add_argument(
        "-lr",
        "--learning_option",
        default=0,
        choices=[0, 1, 2, 3, 4],
        type=int,
        help="0: No Learning (Ground truth known)\n1: Thompson Sampling\n2 Epsilon Greedy\n3 Constrained TS\n4 Naive average baseline",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=False,
        type=_str2bool,
        help="print debug output (true/false, yes/no, 1/0)",
    )
    parser.add_argument(
        "-o", "--online", default=0, type=int, help="0: offline, 1: online"
    )
    parser.add_argument(
        "-kp", "--k_percentage", default=None, type=int, help="100* k/N "
    )
    parser.add_argument(
        "-slurm",
        "--slurm_array_id",
        default=-1,
        type=int,
        help="Unique identifier for slurm array id/ encoding set of parameters",
    )
    parser.add_argument(
        "-sh1", "--shift1", default=0.05, type=float, help="shift 1 variable "
    )
    parser.add_argument(
        "-sh2", "--shift2", default=0.05, type=float, help="shift 2 variable "
    )
    parser.add_argument(
        "-sh3", "--shift3", default=0.05, type=float, help="shift 3 variable "
    )
    parser.add_argument(
        "-sh4", "--shift4", default=0.05, type=float, help="shift 4 variable "
    )
    parser.add_argument(
        "-bl",
        "--buffer_length",
        default=0,
        type=int,
        help="If using Thompson Sampling, max number of most recent days of learning you learn with an arm pull",
    )
    parser.add_argument(
        "-t1f",
        "--get_last_call_transition_flag",
        default=0,
        type=int,
        help="If using Thompson Sampling, whether or not you learn the T1 transition regardless of buffer_length with an arm pull",
    )
    parser.add_argument(
        "-topics",
        "--num_topics",
        default=10,
        type=int,
        help="Number of topics/groups",
    )
    parser.add_argument(
        "-user",
        "--user",
        default='',
        type=str,
        help="suffix appended to the --data value in learning-log file names",
    )
    parser.add_argument(
        "-buffer",
        "--buffer",
        default=False,
        type=_str2bool,
        help="Boolean to include buffer or not. Only for Q-learning policies",
    )

    args = parser.parse_args()

    """
    POLICY NAMES ***
    0: never call    1: Call all patients everyday     2: Randomly pick k patients to call
    3: Myopic policy    4: pomdp  5: Threshold whittle    6: 2-day look ahead    7:oracle
    8: oracle_whittle   9: despot 10: yundi's whittle index 11,12,13: Lily
    14: MDP oracle     15: round robin  16: new new whittle(fast)  17: fast whittle
    """

    NON_EPSILON_POLICIES = [0, 1, 14, 15]

    if args.slurm_array_id >= 0:
        """
        Code for SLURM
        """
        """
        Changing tr: 0-49, policy: {10,14}, N:{10,20,100,200,500,1000,2000}
        """
        slurm_trial_nums = [i for i in range(50)]
        slurm_policies = [10, 14]
        slurm_N = [200, 500, 1000, 2000]
        # slurm_th_fracs=[0,10,20,30,40,50,60,70,80,90,100]

        args.trial_number = args.slurm_array_id % len(slurm_trial_nums)
        args.policy = slurm_policies[
            int(args.slurm_array_id // len(slurm_trial_nums)) % len(slurm_policies)
        ]
        args.num_patients = slurm_N[
            int(args.slurm_array_id // (len(slurm_trial_nums) * len(slurm_policies)))
            % len(slurm_N)
        ]
        # args.threshopt_percentage=slurm_th_fracs[int(args.slurm_array_id//(len(slurm_trial_nums)*len(slurm_policies)*len(slurm_N)))%len(slurm_th_fracs)]
        # args.save_string+=("_threshopt_frac_"+str(args.threshopt_percentage))

    ##### File root
    if args.file_root == ".":
        args.file_root = os.getcwd()
    ##### k
    args.n_interventions = int(args.n_interventions)
    if args.k_percentage is not None:
        args.n_interventions = int(
            (args.k_percentage / 100 * args.num_students)
        )  # This rounds down, good.
    ##### Save special name
    if args.save_string == "":
        args.save_string = str(time.ctime().replace(" ", "_").replace(":", "_"))
    else:
        args.save_string = args.save_string
    policies = args.policy
    ##### Seed = seed_base + trial_number
    if args.trial_number is not None:
        args.num_trials = 1
        add_to_seed_for_specific_trial = args.trial_number
    else:
        add_to_seed_for_specific_trial = 0
    first_seedbase = np.random.randint(0, high=100000)
    if args.seed_base is not None:
        first_seedbase = args.seed_base + add_to_seed_for_specific_trial

    first_world_seedbase = np.random.randint(0, high=100000)
    if args.world_seed_base is not None:
        first_world_seedbase = args.world_seed_base + add_to_seed_for_specific_trial

    first_learning_seedbase = np.random.randint(0, high=100000)
    if args.learning_seed_base is not None:
        first_learning_seedbase = (
            args.learning_seed_base + add_to_seed_for_specific_trial
        )

    ##### Other parameters
    N = args.num_subjects
    L = args.simulation_length
    k = args.n_interventions
    savestring = args.save_string
    N_TRIALS = args.num_trials
    LEARNING_MODE = args.learning_option
    N_TOPICS = args.num_topics

    # Policies that are handed the shared action_logs dict when simulated.
    record_policy_actions = [3, 4, 5, 6, 11, 12, 13, 7, 8, 10, 14, 15, 16, 17, 18]
    num_actions = 3
    num_states = 2

    # size_limits[p] gates policy p on the number of arms N: None runs it for
    # every N; an integer runs it only while size_limit > N, so 0 (or any
    # negative value) switches the policy off entirely.
    size_limits = {
        0: None,
        1: None,
        2: None,
        3: None,
        4: OPT_SIZE_LIMIT,
        5: None,
        6: 4,
        7: 0,
        8: 1000,
        9: 0,
        10: None,
        11: None,
        12: None,
        13: None,
        14: None,
        15: None,
        16: None,
        17: None,
        18: None,
        19: None,
        20: None,
    }

    # Policy id -> display name used in the final summary printout.
    pname = {
        0: "nobody",
        1: "everyday",
        2: "Random",
        3: "Myopic",
        4: "optimal",
        5: "Threshold whittle",
        6: "2-day",
        7: "oracl_m",
        8: "oracle_POMDP",
        9: "despot",
        10: "yundi",
        11: "naiveBelief",
        12: "naiveReal",
        13: "naiveReal2",
        14: "oracle_MDP",
        15: "Round Robin",
        16: "New_whittle",
        17: "FastWhittle",
        18: "BuggyWhittle",
        19: "Q-learning",
        20: "EduQate",
    }

    # Per-policy results, indexed directly by policy id (ids are 0..len(pname)-1).
    learning = [[] for _ in range(len(pname))]
    learning_matrices = [None for _ in range(len(pname))]
    action_logs = {}
    cum_learning = {}
    grouping, nodes_to_group = get_topic_mapping(N, N_TOPICS)

    start = time.time()
    # Fall back to the current working directory when --file_root is unset.
    file_root = args.file_root if args.file_root else os.getcwd()

    # --policy defaults to -1 ("run every policy"); it can also arrive as a bare
    # int rather than a list (e.g. when set by the SLURM branch).
    if not isinstance(policies, (list, tuple)):
        policies = [policies]
    if -1 in policies:
        policies = sorted(pname)
    unknown_policies = [p for p in policies if p not in pname]
    if unknown_policies:
        parser.error("unknown policy id(s): %s" % unknown_policies)

    for trial in range(N_TRIALS):
        # Every trial gets its own seeds, offset from the per-run bases.
        seedbase = first_seedbase + trial
        np.random.seed(seed=seedbase)

        world_seed_base = first_world_seedbase + trial
        learning_seed_base = first_learning_seedbase + trial

        print("Seed is", seedbase)

        ##### Ground-truth transition matrices for this trial
        T = None
        if args.data == "real":
            np.random.seed(NUM_SEED)
            if args.threshopt_percentage is not None:
                # The result is discarded: T is regenerated just below with
                # thresh_opt_frac=None. The call is kept only in case its side
                # effects (e.g. RNG consumption) matter. Previously this ran
                # unconditionally and crashed on None / 100 when
                # --threshopt_percentage was not given.
                # NOTE(review): confirm which of the two matrices is intended.
                generateTmatrixReal(
                    N,
                    file_root=args.file_root,
                    thresh_opt_frac=args.threshopt_percentage / 100.0,
                    beta=args.beta,
                    quick_check=False,
                )
            T = generateTmatrixReal(
                N,
                file_root=args.file_root,
                thresh_opt_frac=None,
                beta=args.beta,
                quick_check=False,
            )
        elif args.data == "simulated":
            T = generateStudentTmatrix(N)
        elif args.data == "junyi":
            T, grouping, nodes_to_group = get_junyi_data()
            N = len(T)
        elif args.data == "oli":
            T, grouping, nodes_to_group = get_oli_data()
            N = len(T)
        # NOTE(review): the remaining --data choices leave T as None here.

        # Reset the global RNG after data generation.
        np.random.seed(seed=seedbase)

        for policy_option in policies:
            use_replay = args.buffer and policy_option in [19, 20]
            replay_buffer = ReplayBuffer() if use_replay else None

            # Policy 19 (Q-learning) keeps one fewer action column than the rest.
            n_q_actions = num_actions - 1 if policy_option == 19 else num_actions
            Q_tables = np.zeros((N, num_states, n_q_actions))
            V_tables = np.zeros((N, args.simulation_length, num_states))

            # Arguments shared by every episode of this policy. Q/V tables and
            # the replay buffer are the same objects across episodes.
            sim_kwargs = dict(
                grouping=grouping,
                nodes_to_group=nodes_to_group,
                Q_tables=Q_tables,
                V_tables=V_tables,
                policy_option=policy_option,
                seedbase=seedbase,
                cum_learning=cum_learning,
                epsilon=args.epsilon,
                max_ep=args.episode,
                learning_mode=LEARNING_MODE,
                learning_random_seed=learning_seed_base,
                world_random_seed=world_seed_base,
                verbose=args.verbose,
                online=(args.online == 1),
                buffer_length=args.buffer_length,
                get_last_call_transition_flag=args.get_last_call_transition_flag,
                file_root=file_root,
                replay_buffer=replay_buffer,
            )
            records_actions = policy_option in record_policy_actions
            if records_actions:
                # Only these policies are handed the shared action_logs dict.
                sim_kwargs["action_logs"] = action_logs

            size_limit = size_limits[policy_option]
            runs_at_this_size = size_limit is None or size_limit > N

            policy_start_time = time.time()
            for ep in range(args.episode):
                ############################ Policy # policy_option
                if not runs_at_this_size:
                    # Policy disabled for this N: record an all-zero run.
                    learning_matrix = np.zeros((N, L))
                    learning_matrices[policy_option] = learning_matrix
                    learning[policy_option].append(
                        np.mean(np.sum(learning_matrix, axis=1))
                    )
                    continue

                np.random.seed(seed=seedbase)
                if args.verbose and not records_actions:
                    print(learning_seed_base, "LRSEED\n\n\n\n\n")
                    print(world_seed_base, "LRSEED\n\n\n\n\n")

                learning_matrix = simulate_learning(N, L, T, k, episode=ep, **sim_kwargs)
                learning_matrices[policy_option] = learning_matrix

                # Checkpoint the learning matrix every 100 episodes.
                if (ep + 1) % 100 == 0:
                    np.save(
                        file_root
                        + "/results/learning_log/learning_"
                        + savestring
                        + "_N%s_k%s_L%s_policy%s_data%s_trial%s_s%s_lr%s_bl%s_t1f%s_topics%s_ep%s"
                        % (
                            N,
                            k,
                            L,
                            policy_option,
                            args.data + args.user,
                            trial,
                            seedbase,
                            LEARNING_MODE,
                            args.buffer_length,
                            args.get_last_call_transition_flag,
                            N_TOPICS,
                            ep,
                        ),
                        learning_matrix,
                    )
                # Mean over arms of the total per-arm learning for this episode.
                learning[policy_option].append(
                    np.mean(np.sum(learning_matrix, axis=1))
                )

            # Measured after the loop so it is defined even when --episode is 0.
            policy_run_time = time.time() - policy_start_time

            # NOTE: the "_ep%s" slot in the following file names holds the trial
            # index, not an episode; kept as-is so existing readers still match.
            np.save(
                file_root
                + "/results/runtime/runtime_"
                + savestring
                + "_N%s_k%s_L%s_policy%s_data%s_ep%s_s%s_lr%s_bl%s_t1f%s"
                % (
                    N,
                    k,
                    L,
                    policy_option,
                    args.data,
                    trial,
                    seedbase,
                    LEARNING_MODE,
                    args.buffer_length,
                    args.get_last_call_transition_flag,
                ),
                policy_run_time,
            )
            ##### SAVE ALL RELEVANT LOGS #####

            # write out action logs (one CSV per logged policy, one column per arm)
            for logged_policy in action_logs:
                fname = os.path.join(
                    file_root,
                    "results/action_logs/action_logs_"
                    + savestring
                    + "_N%s_k%s_L%s_data%s_ep%s_policy%s_s%s_lr%s_bl%s_t1f%s.csv"
                    % (
                        N,
                        k,
                        L,
                        args.data,
                        trial,
                        logged_policy,
                        seedbase,
                        LEARNING_MODE,
                        args.buffer_length,
                        args.get_last_call_transition_flag,
                    ),
                )
                columns = list(map(str, np.arange(N)))
                df = pd.DataFrame(action_logs[logged_policy], columns=columns)
                df.to_csv(fname, index=False)

            # write out cumulative learning logs (one row, one column per day)
            for logged_policy in cum_learning:
                fname = os.path.join(
                    file_root,
                    "results/cum_learning/cum_learning_"
                    + savestring
                    + "_N%s_k%s_L%s_data%s_ep%s_policy%s_s%s_lr%s_bl%s_t1f%s.csv"
                    % (
                        N,
                        k,
                        L,
                        args.data,
                        trial,
                        logged_policy,
                        seedbase,
                        LEARNING_MODE,
                        args.buffer_length,
                        args.get_last_call_transition_flag,
                    ),
                )
                columns = list(map(str, np.arange(L)))
                df = pd.DataFrame([cum_learning[logged_policy]], columns=columns)
                df.to_csv(fname, index=False)

            # write out T matrix logs
            # NOTE(review): np.save appends ".npy", so this file actually ends in
            # ".csv.npy" and is binary, not CSV.
            fname = os.path.join(
                file_root,
                "results/Tmatrix_logs/Tmatrix_logs_"
                + savestring
                + "_N%s_k%s_L%s_data%s_ep%s_s%s_lr%s_bl%s_t1f%s.csv"
                % (
                    N,
                    k,
                    L,
                    args.data,
                    trial,
                    seedbase,
                    LEARNING_MODE,
                    args.buffer_length,
                    args.get_last_call_transition_flag,
                ),
            )
            np.save(fname, T)

    # Summarise only the policies that actually ran; iterating 0..max(policies)
    # printed nan (with a RuntimeWarning) for every policy that was never run.
    for p in sorted(set(policies)):
        print(pname[p], ": ", np.mean(learning[p]))

    end = time.time()
    print("Time taken: ", end - start)
