import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.linear_model import BayesianRidge
from scipy.stats import norm

# Flag read by SLTD_stochastic_cont: when falsy, the value computed for a state
# at one time step is copied to every other time step.
NON_STATIONARY_LEARN = 0
# Number of states drawn per time step (np.random.choice) in SLTD_stochastic_cont.
BOOTSTRAP_SIZE = 30
# Number of outer repetitions in SLTD_stochastic_cont.
N_BOOTSTRAPS = 5
# Number of outer repetitions in SLTD_stochastic and SLTD_stochastic_nst_multistep.
N_BOOTSTRAPS_NST = 5

'''
This code partially uses some functions from
https://github.com/asonabend/ESRL
'''


################################################
################# Sampling MDPs ################
################################################

# Sample K MDPs simultaneously 
def sampleK_MDPs(S_card, A_space, H_tk, K, pi_sa=None, sepsis=False):
    """Draw K MDPs (mean rewards and transition vectors) for every (s, a) pair.

    For each state-action pair the mean reward is drawn with
    ``normal_gamma_sample`` and the transition vector from a Dirichlet whose
    parameters start at ``1 / S_card`` and are incremented by the observed
    next-state counts found in ``H_tk``.

    Args:
        S_card: number of discrete states; states are ``0.0 .. S_card - 1``.
        A_space: iterable of actions.
        H_tk: history array or ``None``. Columns 0-3 are read as
            state, action, reward, next state.
        K: number of MDPs to sample.
        pi_sa: unused; kept so existing callers keep working.
        sepsis: when true the reward prior uses ``alpha0 = 5.01`` instead of
            ``1.01``.

    Returns:
        ``(R_sa_dict, P_sas_dict)``; both map ``k -> {(s, a): sample}``.
    """
    prior_par = {'m0': 0, 'lamb0': 1, 'alpha0': 5.01 if sepsis else 1.01, 'gamma0': 1}

    all_states = [float(i) for i in range(S_card)]
    R_sa_dict = {k: {} for k in range(K)}
    P_sas_dict = {k: {} for k in range(K)}
    for a in A_space:
        for s in all_states:
            H_sa = None
            dirich_alpha = [1 / S_card] * S_card  # fresh prior for this (s, a) pair
            if H_tk is not None:
                # np.all replaces the removed np.product(...) == 1 row test.
                indx = np.where(np.all(H_tk[:, :2] == np.array([s, a]), axis=1))[0]
                if indx.shape[0] > 0:
                    H_sa = H_tk[indx, :]
                    nxt_states, counts = np.unique(H_sa[:, 3], return_counts=True)
                    for nxt_s, n_obs in zip(nxt_states, counts):
                        # Skip the -inf marker explicitly instead of slicing off
                        # the first unique value, which discarded a real state
                        # whenever no -inf was present.
                        if not np.isfinite(nxt_s):
                            continue
                        dirich_alpha[int(nxt_s)] += n_obs
            # sample means for the reward distributions
            K_Rs = normal_gamma_sample(m0=prior_par['m0'], lamb0=prior_par['lamb0'], alpha0=prior_par['alpha0'],
                                       gamma0=prior_par['gamma0'], H_sa=H_sa, K=K)
            # one transition vector per sampled MDP from the Dirichlet posterior
            K_P_sas = np.random.dirichlet(dirich_alpha, K)
            for k in range(K):
                R_sa_dict[k][(s, a)] = K_Rs[k]
                P_sas_dict[k][(s, a)] = K_P_sas[k]
    return R_sa_dict, P_sas_dict


def sampleK_MDPs_defer(S_card, A_space, H_tk, K, reward_vec=None, P_sa=None, t=None, debug=False, use_true_tx=False,
                       nst=False, env="discrete_toy"):
    """Sample K MDPs for the deferral setting and report the Dirichlet parameters.

    Args:
        S_card: number of discrete states; states are ``0.0 .. S_card - 1``.
        A_space: iterable of actions.
        H_tk: history array or ``None``. Columns 0-3 are read as
            state, action, reward, next state.
        K: number of MDPs to sample.
        reward_vec: reward support passed to ``categorical_reward``; defaults
            to the unique rewards in ``H_tk``.
        P_sa: transition table. With ``use_true_tx`` it is read as
            ``P_sa[t][int(a), int(s), :]``; in debug plots as
            ``P_sa[a, int(s)]``.
        t: index into ``P_sa`` when ``use_true_tx`` is set.
        debug: when true (and ``P_sa`` is given) saves one bar plot per
            (s, a) pair under ``./plots/``.
        use_true_tx: use rows of ``P_sa`` instead of Dirichlet draws.
        nst: accepted for interface compatibility; it does not change the
            result of this function.
        env: ``"diabetes"`` draws rewards with ``normal_gamma_sample``; any
            other value uses ``categorical_reward``.

    Returns:
        ``(R_sa_dict, P_sas_dict, dir_alpha_dict)``. The first two map
        ``k -> {(s, a): sample}``; the last maps ``(s, a)`` to a dict with
        ``'prior'`` and ``'posterior'`` Dirichlet parameter lists.

    Raises:
        ValueError: if the reward support is needed but neither ``reward_vec``
            nor ``H_tk`` is available.
    """
    prior_par = {'m0': -0.01, 'lamb0': 1, 'alpha0': 1.01, 'gamma0': 1}  # Riverswim
    all_states = [float(i) for i in range(S_card)]
    R_sa_dict = {k: {} for k in range(K)}
    P_sas_dict = {k: {} for k in range(K)}
    if reward_vec is None:
        if H_tk is not None:
            reward_vec = np.unique(H_tk[:, 2])
        elif env != "diabetes":
            raise ValueError("reward_vec is required when H_tk is None")
    dir_alpha_dict = {}
    for a in A_space:
        for s in all_states:
            H_sa = None
            dirich_alpha = [1 / S_card] * S_card  # fresh prior for this (s, a) pair
            # Store a copy: dirich_alpha is updated in place below, and sharing
            # the list made 'prior' silently equal to 'posterior'.
            dir_alpha_dict[(s, a)] = {'prior': list(dirich_alpha)}
            if H_tk is not None:
                # np.all replaces the removed np.product(...) == 1 row test.
                indx = np.where(np.all(H_tk[:, :2] == np.array([s, a]), axis=1))[0]
                if indx.shape[0] > 0:
                    H_sa = H_tk[indx, :]
                    nxt_states, counts = np.unique(H_sa[:, 3], return_counts=True)
                    for nxt_s, n_obs in zip(nxt_states, counts):
                        # int(-inf) raises, so non-finite markers are skipped.
                        if not np.isfinite(nxt_s):
                            continue
                        dirich_alpha[int(nxt_s)] += n_obs
            # sample means for the reward distributions
            if env == "diabetes":
                K_Rs = normal_gamma_sample(m0=prior_par['m0'], lamb0=prior_par['lamb0'], alpha0=prior_par['alpha0'],
                                           gamma0=prior_par['gamma0'], H_sa=H_sa, K=K)
            else:
                K_Rs, _ = categorical_reward(len(reward_vec), reward_vec, alpha0=None, H_sa=H_sa, K=K, env=env)

            dir_alpha_dict[(s, a)]['posterior'] = dirich_alpha

            if not use_true_tx:
                # one transition vector per sampled MDP from the Dirichlet posterior
                K_P_sas = np.random.dirichlet(dirich_alpha, K)
            for k in range(K):
                R_sa_dict[k][(s, a)] = K_Rs[k]
                if use_true_tx:
                    # the same true row is used for every sampled MDP
                    P_sas_dict[k][(s, a)] = P_sa[t][int(a), int(s), :]
                else:
                    P_sas_dict[k][(s, a)] = K_P_sas[k]

            if debug and P_sa is not None:
                fig, (ax1, ax2) = plt.subplots(2, 1)
                # NOTE(review): indexed as P_sa[a, int(s)] here but as
                # P_sa[t][int(a), int(s), :] above - confirm the expected layout.
                ax1.bar(x=all_states, height=P_sa[a, int(s)])
                # Label each axes explicitly; plt.xlabel/plt.ylabel only touch
                # the current axes, so the first panel was never labelled.
                ax1.set_xlabel("States")
                ax1.set_ylabel("Empirical Transition: " + str(s) + ',' + str(a))
                ax2.bar(x=all_states, height=dirich_alpha)
                ax2.set_xlabel("States")
                ax2.set_ylabel("Posterior probability: " + str(s) + ',' + str(a))
                plt.savefig("./plots/hist_" + str(s) + '_' + str(a) + '.pdf')
                plt.close(fig)
    return R_sa_dict, P_sas_dict, dir_alpha_dict


def sampleK_MDPs_defer_cont(A_space, H_tk, K, P_sa=None, t=None, debug=False, p_fit_dict=None, use_true_tx=False,
                            nst=False):
    """Sample K MDPs for continuous states using one regression per action.

    Args:
        A_space: iterable of actions.
        H_tk: history array. Column 0 is read as state, 1 as action, 2 as
            reward, 3 as next state; the last column is compared with ``t``
            when ``nst`` is set.
        K: number of MDPs to sample.
        P_sa: callable invoked as ``P_sa(s, a, tt=...)``; with ``use_true_tx``
            its result is unpacked as ``(_, mu, std)``.
        t: time index used for the ``nst`` row filter and passed as ``tt``.
        debug: when true (and ``P_sa`` is given) saves one plot per action
            under ``./plots/``.
        p_fit_dict: optional ``{action: fitted regressor}``; fitted here when
            ``None``.
        use_true_tx: take ``mu``/``std`` from ``P_sa`` instead of the fit.
        nst: restrict the regression data to rows whose last column equals
            ``t``.

    Returns:
        ``(R_sa_dict, P_sas_dict, p_fit_dict)``. ``P_sas_dict[k][(s, a)]`` is
        a dict with keys ``'mu'`` and ``'std'``.
    """
    R_sa_dict = {k: {} for k in range(K)}
    P_sas_dict = {k: {} for k in range(K)}

    if p_fit_dict is None:
        p_fit_dict = {}
        rows_in_scope = np.where(H_tk[:, -1] == t)[0] if nst else range(H_tk.shape[0])
        for a in A_space:
            action_rows = np.intersect1d(np.where(a == H_tk[:, 1])[0], rows_in_scope)
            if action_rows.shape[0] > 0:
                H_sa = H_tk[action_rows, :]
                reg = BayesianRidge()
                # regress next state (column 3) on current state (column 0)
                reg.fit(H_sa[:, 0].reshape(-1, 1), H_sa[:, 3])
            else:
                # No data for this action: hand-initialise an unfitted model so
                # that predict(..., return_std=True) can still be called.
                reg = BayesianRidge(fit_intercept=False)
                reg.coef_ = np.array([0.0])
                reg.intercept_ = 0.
                reg.sigma_ = np.eye(1)
                reg.alpha_ = 1.
            p_fit_dict[a] = reg

    # Sample K MDPs from the Bayesian fit - with uncertainty in reward
    reward_vec = np.unique(H_tk[:, 2])
    for a in A_space:
        action_rows = np.where(a == H_tk[:, 1])[0]
        if action_rows.shape[0] == 0:
            continue
        all_states = H_tk[:, 0]
        H_sa = H_tk[action_rows, :]
        mu_dict = {}
        for s in all_states:
            mu, std = p_fit_dict[a].predict(np.array([s]).reshape(1, -1), return_std=True)
            mu_dict[s] = mu
            K_Rs, r_params = categorical_reward(len(reward_vec), reward_vec, alpha0=None, H_sa=H_sa, K=K)
            for k in range(K):
                if use_true_tx:
                    # the true model replaces the fitted mean/std
                    _, mu, std = P_sa(s, a, tt=t)
                P_sas_dict[k][(s, a)] = {'mu': mu, 'std': std}
                R_sa_dict[k][(s, a)] = K_Rs[k]

        if debug and P_sa is not None:
            fig, (ax1, ax2) = plt.subplots(2, 1)
            if nst:
                reference_curve = [P_sa(s, a, tt=t)[1] for s in all_states]
            else:
                reference_curve = [P_sa(s, a, tt=0) for s in all_states]
            ax1.plot(all_states, reference_curve)
            plt.xlabel("States")
            plt.ylabel(["Empirical Transition: " + ',' + str(a)])
            ax2.plot(all_states, np.asarray(list(mu_dict.values())), marker='.')
            plt.xlabel("States")
            plt.ylabel(["Posterior mean: " + ',' + str(a)])
            plt.savefig("./plots/hist" + '_' + str(a) + '.pdf')
            plt.close(fig)
    return R_sa_dict, P_sas_dict, p_fit_dict


# Sample from the posterior distribution of the reward parameters for the state-action pair (s, a)
def normal_gamma_sample(m0, lamb0, alpha0, gamma0, H_sa=None, K=1):
    """Draw K samples of a mean reward from a Normal-Gamma posterior.

    Args:
        m0, lamb0, alpha0, gamma0: prior parameters (mean, mean-precision
            multiplier, Gamma shape, Gamma rate).
        H_sa: optional history rows for one (s, a) pair; column 2 is read as
            the reward. ``None`` or an empty array means "sample from the
            prior".
        K: number of samples.

    Returns:
        Array of K sampled means.
    """
    m, lamb, alpha, gamma = m0, lamb0, alpha0, gamma0
    if H_sa is not None and H_sa.shape[0] > 0:
        rewards = H_sa[:, 2]
        n_sa = rewards.shape[0]
        r_bar = np.mean(rewards)
        m = (lamb0 * m0 + n_sa * r_bar) / (lamb0 + n_sa)
        lamb = lamb0 + n_sa
        alpha = alpha0 + n_sa / 2
        gamma = gamma0 + 0.5 * np.sum((rewards - r_bar) ** 2) + (n_sa * lamb0 * (r_bar - m0) ** 2) / (
                2 * (lamb0 + n_sa))
    # `gamma` is updated as the *rate* of the precision's Gamma posterior, while
    # np.random.gamma expects a *scale*; pass the reciprocal so that noisier
    # rewards give lower precision rather than higher.
    tautau = np.random.gamma(alpha, 1.0 / gamma, K)
    sigma = np.sqrt(1 / (lamb * tautau))
    mu_sa = np.random.normal(loc=m, scale=sigma, size=K)
    return mu_sa


def categorical_reward(n, reward_vec, alpha0=None, H_sa=None, K=1, env="discrete_toy"):
    """Draw K rewards from a count-weighted categorical distribution.

    Args:
        n: unused.
        reward_vec: array of possible reward values.
        alpha0: initial weight for every reward value; when ``None`` each
            value starts at ``1 / (number of unique values)`` and, for
            ``env == "discrete_toy"``, negative rewards start at ``1e-3``.
        H_sa: optional history rows; column 2 is read as the reward and its
            occurrence counts are added to the weights.
        K: number of rewards to draw.
        env: environment name, see ``alpha0``.

    Returns:
        ``(r_vec, reward_prob)``: the K draws and the normalised weights.
    """
    n_levels = len(np.unique(reward_vec))
    if alpha0 is None:
        alpha = np.full(n_levels, 1 / n_levels)
        if env == "discrete_toy":
            alpha[np.where(reward_vec < 0)[0]] = 1e-3
    else:
        alpha = [alpha0] * n_levels

    if H_sa is not None:
        observed = H_sa[:, 2]
        for rr in np.unique(observed):
            slot = np.where(reward_vec == rr)[0][0]
            alpha[slot] += np.count_nonzero(observed == rr)

    weights = np.array(alpha)
    reward_prob = weights / sum(weights)
    r_vec = np.random.choice(reward_vec, size=K, replace=True, p=reward_prob)
    return r_vec, reward_prob


# Computes the null probability
def P_H0_MV(s, t, a_behavior, a_mu, Mk_R_sa, Mk_P_sas, kset, V_st, visited_states, pi_tsa, S_space, A_space):
    """Return the fraction of sampled MDPs where ``a_mu`` is worse than ``a_behavior``.

    For every model ``k`` in ``kset`` the Q-value of each action at state
    ``s`` is the sampled reward plus the expected value at ``t + 1``.
    ``visited_states`` and ``pi_tsa`` are not used.

    Returns:
        ``(p, Qs)``: the mean of ``Qs[:, a_mu] < Qs[:, a_behavior]`` and the
        ``(len(kset), len(A_space))`` Q-value matrix.
    """

    def q_value(k, a):
        # one-step reward plus expected continuation value under model k
        transition = Mk_P_sas[k][(s, a)]
        continuation = sum([transition[int(nxt_s)] * V_st[k][(nxt_s, t + 1)] for nxt_s in S_space])
        return Mk_R_sa[k][(s, a)] + continuation

    Qs = np.zeros((len(kset), len(A_space)))
    for row, k in enumerate(kset):
        Qs[row, :] = [q_value(k, a) for a in A_space]
    p_null = np.mean(Qs[:, a_mu] < Qs[:, a_behavior])
    return p_null, Qs


def P_H0_MV_deferral(s, t, Mk_R_sa, Mk_P_sas, kset, V_st, S_space,
                     A_space, one_step=False):
    """Compute Q-values of every action at state ``s`` for each sampled MDP.

    The continuation term (expected value at ``t + 1``) is multiplied by
    ``1 - one_step``, so ``one_step=True`` leaves only the sampled reward.

    Returns:
        ``(len(kset), len(A_space))`` array, one row per model in ``kset``.
    """
    keep_future = 1 - one_step

    def q_value(k, a):
        transition = Mk_P_sas[k][(s, a)]
        continuation = sum([transition[int(nxt_s)] * V_st[k][(nxt_s, t + 1)] for nxt_s in S_space])
        return Mk_R_sa[k][(s, a)] + keep_future * continuation

    Qs = np.zeros((len(kset), len(A_space)))
    for row, k in enumerate(kset):
        Qs[row, :] = [q_value(k, a) for a in A_space]
    return Qs


def P_H0_MV_deferral_multistepeval(s, t, Mk_R_sa, Mk_P_sas, kset, V_st, S_space,
                                   A_space, one_step=False):
    """Compute multi-step and one-step Q-values at state ``s`` for each sampled MDP.

    Args:
        s: state of interest.
        t: current time step; continuation values are read at ``t + 1``.
        Mk_R_sa, Mk_P_sas: per-model reward and transition dictionaries keyed
            by ``(s, a)``.
        kset: iterable of model keys; row ``i`` of each output belongs to the
            i-th key.
        V_st: per-model value dictionaries keyed by ``(state, time)``.
        S_space: iterable of next states.
        A_space: iterable of actions.
        one_step: unused; kept for interface compatibility.

    Returns:
        ``(Qs, Qs_1step)``: reward plus expected continuation value, and the
        reward alone, each of shape ``(len(kset), len(A_space))``.
    """
    Qs = np.zeros((len(kset), len(A_space)))
    Qs_1step = np.zeros((len(kset), len(A_space)))
    # enumerate supplies the row index; the previous counter was never
    # incremented, so every model overwrote row 0.
    for row, k in enumerate(kset):
        R_sa, P_sas = Mk_R_sa[k], Mk_P_sas[k]
        Qs[row, :] = [(R_sa[(s, a)] + sum(
            [P_sas[(s, a)][int(nxt_s)] * V_st[k][(nxt_s, t + 1)] for nxt_s in S_space])) for
                      a in A_space]
        Qs_1step[row, :] = [R_sa[(s, a)] for a in A_space]

    return Qs, Qs_1step


def P_H0_MV_deferral_cont(s, t, Mk_R_sa, Mk_P_sas, kset, V_st, S_space,
                          A_space):
    """Compute Q-values at state ``s`` when transitions are Gaussian densities.

    Each ``Mk_P_sas[k][(s, a)]`` is a dict with ``'mu'`` and ``'std'``; the
    continuation term sums, over ``S_space``, the normal density of the next
    state times the value at ``t + 1``.

    Returns:
        ``(Qs, Qs_1st)``: reward plus continuation, and reward alone, each of
        shape ``(len(kset), len(A_space))``.
    """

    def q_value(k, a):
        params = Mk_P_sas[k][(s, a)]
        continuation = sum(
            [(norm.pdf((nxt_s - params['mu']) / params['std'])) / params['std'] * V_st[k][(nxt_s, t + 1)]
             for nxt_s in S_space])
        return Mk_R_sa[k][(s, a)] + continuation

    n_models, n_actions = len(kset), len(A_space)
    Qs = np.zeros((n_models, n_actions))
    Qs_1st = np.zeros((n_models, n_actions))
    for row, k in enumerate(kset):
        Qs[row, :] = [q_value(k, a) for a in A_space]
        Qs_1st[row, :] = [Mk_R_sa[k][(s, a)] for a in A_space]

    return Qs, Qs_1st


# Stationary dynamics+Stochastic policy + Discrete states (can handle true dynamics):
def SLTD_stochastic(H_T, alpha, tau, K_no, pi_st, target_policy, S_space, A_space, tx_mat,
                    defer_cost=0.0, true_tx=None, learn_dynamics=True,
                    defer_method="mean", non_stationary_policy=False, non_stationary=False, one_step=False,
                    env="discrete_toy"):
    """Estimate per-(state, time) deferral rates by backward induction over sampled MDPs.

    Args:
        H_T: history array; column 2 is read as the reward. Passed to
            ``sampleK_MDPs_defer``.
        alpha: currently unused (see the NOTE in the "hypothesis" branch).
        tau: horizon; time runs from ``tau - 1`` down to 0.
        K_no: number of MDPs sampled per repetition.
        pi_st: behaviour policy, indexed ``[int(s)]`` (or ``[t][int(s)]``
            with ``non_stationary_policy``).
        target_policy: target policy, indexed like ``pi_st``.
        S_space: iterable of states.
        A_space: sequence of actions. Sampled actions are also used as column
            indices into the Q matrix.
        tx_mat: forwarded as ``P_sa`` when dynamics are learned.
        defer_cost: subtracted from the Q-values of the behaviour policy.
        true_tx: true transition table, required when ``learn_dynamics`` is
            false.
        learn_dynamics: sample transitions from the posterior instead of
            using ``true_tx``.
        defer_method: ``"hypothesis"`` or ``"mean"``.
        non_stationary_policy: index the policies by ``t`` first.
        non_stationary: forwarded to ``sampleK_MDPs_defer`` as ``nst``.
        one_step: forwarded to ``P_H0_MV_deferral``.
        env: forwarded to ``sampleK_MDPs_defer``.

    Returns:
        ``(g_s, Qs_all, dirich_alpha_dict, a_defer_all, P_sas_dict)``.
        ``g_s`` maps ``(s, t)`` to the accumulated deferral score divided by
        its visit count. ``Qs_all[b]`` holds ``'V_defer'`` and ``'V_tar'``
        (state-averaged values at ``t = 0``) for repetition ``b``. The last
        two dictionaries come from the final repetition.

    Raises:
        ValueError: if ``learn_dynamics`` is false and ``true_tx`` is ``None``,
            or if ``defer_method`` is not recognised.
    """
    if not learn_dynamics and true_tx is None:
        raise ValueError("Provide true dynamics model")

    def draw_actions(probs, size):
        # Bernoulli draw for two actions, categorical draw over A_space otherwise.
        if len(A_space) == 2:
            return np.random.binomial(n=1, p=probs, size=size)
        return np.random.choice(a=A_space, p=probs, size=size)

    def mean_backup(R_sa, P_sas, V_k, s, t, actions):
        # Average over the sampled actions of reward + expected value at t + 1.
        return np.mean(
            [(R_sa[(s, a)] + sum(
                [P_sas[(s, a)][int(nxt_s)] * V_k[(nxt_s, t + 1)] for nxt_s in S_space]))
             for a in actions])

    K_ls = list(range(K_no))
    n_trial = 5
    Qs_all = {}
    dirich_alpha_dict = {}
    a_defer_all, P_sas_dict = {}, {}
    g_s = {(s, tt): 0.0 for s in S_space for tt in range(tau - 1, -1, -1)}
    g_s_count = {(s, tt): 0.0 for s in S_space for tt in range(tau - 1, -1, -1)}

    for b in range(N_BOOTSTRAPS_NST):  # no of trials
        V_st = {k: {(s, tau): 0 for s in S_space} for k in range(K_no)}
        V_target = {k: {(s, tau): 0 for s in S_space} for k in range(K_no)}

        a_defer_all = {}
        P_sas_dict = {}

        # Sample K MDPs from the posterior (or wrap the true dynamics).
        if learn_dynamics:
            reward_vec = np.unique(H_T[:, 2])
            Mk_R_sa, Mk_P_sas, dirich_alpha_s_a = sampleK_MDPs_defer(S_card=len(S_space), A_space=A_space,
                                                                     reward_vec=reward_vec,
                                                                     H_tk=H_T, K=K_no, P_sa=tx_mat, t=None,
                                                                     use_true_tx=False,
                                                                     nst=non_stationary, env=env)
        else:
            # Here the argument to P_sa is treated as the true dynamics to be used.
            Mk_R_sa, Mk_P_sas, dirich_alpha_s_a = sampleK_MDPs_defer(S_card=len(S_space), A_space=A_space, H_tk=H_T,
                                                                     K=K_no,
                                                                     P_sa=true_tx, t=0,
                                                                     use_true_tx=True,
                                                                     nst=non_stationary, env=env)

        for t in tqdm(range(tau - 1, -1, -1)):
            # The sampled model does not depend on time here, so it is copied per t.
            dirich_alpha_dict[t] = dirich_alpha_s_a
            P_sas_dict[t] = Mk_P_sas

            if non_stationary_policy:
                tp, cp = target_policy[t], pi_st[t]
            else:
                tp, cp = target_policy, pi_st

            for s in S_space:
                Qs = P_H0_MV_deferral(s, t, Mk_R_sa=Mk_R_sa, Mk_P_sas=Mk_P_sas, kset=K_ls, V_st=V_st,
                                      S_space=S_space,
                                      A_space=A_space, one_step=one_step)

                a_policy_vec = draw_actions(tp[int(s)], n_trial)
                a_defer_vec = draw_actions(cp[int(s)], n_trial)

                # Trial-major flattening: entry nn * K_no + k is model k under
                # the action drawn in trial nn.
                Qs_policy = Qs[:, a_policy_vec].T.ravel()
                Qs_defer = Qs[:, a_defer_vec].T.ravel() - defer_cost

                if defer_method == "hypothesis":
                    # NOTE(review): the raw probability is accumulated; it is
                    # never thresholded against `alpha` - confirm this is intended.
                    P_0 = np.mean(Qs_policy < Qs_defer)
                elif defer_method == "mean":
                    P_0 = int(np.mean(Qs_defer) > np.mean(Qs_policy))
                else:
                    raise ValueError(" Defer method %s is not implemented!" % defer_method)

                g_s[(s, t)] += P_0
                g_s_count[(s, t)] += 1

                for k in range(K_no):
                    R_sa, P_sas = Mk_R_sa[k], Mk_P_sas[k]
                    # Defer (follow cp) with the running deferral rate, else follow tp.
                    epsilon_vec = np.random.binomial(n=1, p=g_s[(s, t)] / (g_s_count[(s, t)] + 1e-7), size=n_trial)
                    a_defer_vec = epsilon_vec * draw_actions(cp[int(s)], n_trial) + \
                                  (1 - epsilon_vec) * draw_actions(tp[int(s)], n_trial)
                    V_st[k][(s, t)] = mean_backup(R_sa, P_sas, V_st[k], s, t, a_defer_vec)

                    a_defer_all[(s, t)] = a_defer_vec
                    # value corresponding to not deferring
                    a_behavior_vec = draw_actions(tp[int(s)], 5)
                    V_target[k][(s, t)] = mean_backup(R_sa, P_sas, V_target[k], s, t, a_behavior_vec)

        V_defer = 0
        V_tar = 0
        for s in S_space:
            V_defer += np.mean([V_st[k][(s, 0)] for k in range(K_no)])
            V_tar += np.mean([V_target[k][(s, 0)] for k in range(K_no)])
        V_defer = V_defer / len(S_space)
        V_tar = V_tar / len(S_space)
        # Recorded once per repetition with the already-averaged values; this
        # used to run after the loop (keeping only the last repetition) and
        # divided by len(S_space) a second time.
        Qs_all[b] = {'V_defer': V_defer, 'V_tar': V_tar}

    for key in g_s.keys():
        g_s[key] = g_s[key] / g_s_count[key]
    return g_s, Qs_all, dirich_alpha_dict, a_defer_all, P_sas_dict


# Stationary dynamics+Stochastic policy + Continuous states (can handle true dynamics):
def SLTD_stochastic_cont(H_T, alpha, tau, K_no, clinician_policy, target_policy, A_space, tx_mat, n_eps,
                         defer_cost=0.0, defer_method="mean", learn_dynamics=True, non_stationary_policy=True):
    """Estimate per-state deferral rates for continuous states over sampled MDPs.

    Args:
        H_T: history array; column 0 is read as the state. Passed to
            ``sampleK_MDPs_defer_cont``.
        alpha: threshold used by the ``"hypothesis"`` rule.
        tau: horizon; time runs from ``tau - 2`` down to 1.
        K_no: number of sampled MDPs.
        clinician_policy: callable ``(s, t)`` (or ``(s)`` when
            ``non_stationary_policy`` is false) returning action probabilities.
        target_policy: callable with the same calling convention.
        A_space: sequence of actions. Sampled actions are also used as column
            indices into the Q matrix.
        tx_mat: forwarded as ``P_sa``; required when ``learn_dynamics`` is
            false.
        n_eps: used to index the state column as
            ``(tau - 1) * range(n_eps - 1) + t - 1``.
        defer_cost: subtracted from the Q-values of the clinician policy.
        defer_method: ``"hypothesis"`` or ``"mean"``.
        learn_dynamics: fit the dynamics instead of using ``tx_mat``.
        non_stationary_policy: call the policies with ``(s, t)``.

    Returns:
        ``(g_s_all, Qs_all, model)``, each keyed by repetition index.
        ``g_s_all[b]`` maps a state to its cumulative deferral tally divided
        by ``visits + 1``; ``Qs_all[b]`` holds ``"V_defer"`` and ``"V_tar"``;
        ``model[b]`` is ``None``.

    Raises:
        ValueError: if ``learn_dynamics`` is false and ``tx_mat`` is ``None``,
            or if ``defer_method`` is not recognised.
    """
    if not learn_dynamics and tx_mat is None:
        raise ValueError("Provide true dynamics model")

    def policy_probs(policy, s, t):
        return policy(s, t) if non_stationary_policy else policy(s)

    def draw_actions(probs, size):
        # Bernoulli draw for two actions, categorical draw over A_space otherwise.
        if len(A_space) == 2:
            return np.random.binomial(n=1, p=probs, size=size)
        return np.random.choice(a=A_space, p=probs, size=size)

    def density_backup(R_sa, P_sas, V_k, s, t, actions, next_states):
        # Average over sampled actions of reward + density-weighted value at t + 1.
        return np.mean(
            [(R_sa[(s, a)] + sum(
                [(norm.pdf((nxt_s - P_sas[(s, a)]['mu']) / P_sas[(s, a)]['std'])) / P_sas[(s, a)]['std'] *
                 V_k[(nxt_s, t + 1)] for nxt_s in next_states]))
             for a in actions])

    K_ls = list(range(K_no))
    # Samples K MDPs from the posterior (learn_dynamics) or from the supplied
    # true dynamics; tx_mat is the P_sa argument in both cases.
    # NOTE(review): nst=True is passed while t is left at its default (None);
    # the sampler then filters rows on `last column == t` - confirm intended.
    Mk_R_sa, Mk_P_sas, _ = sampleK_MDPs_defer_cont(A_space=A_space, H_tk=H_T, K=K_no, P_sa=tx_mat,
                                                   p_fit_dict=None, debug=True,
                                                   use_true_tx=not learn_dynamics, nst=True)

    g_s_all = {}
    Qs_all = {}
    model = {}
    S_space = H_T[:, 0]
    # Running tallies, shared by all repetitions.
    g_s = {s: 0 for s in S_space}
    g_s_count = {s: 0 for s in S_space}
    n_trial = 10

    for b in range(N_BOOTSTRAPS):  # no of bootstraps
        V_st = {k: {(s, tau - 1): 0 for s in S_space} for k in range(K_no)}
        V_target = {k: {(s, tau - 1): 0 for s in S_space} for k in range(K_no)}

        S_t_space0 = np.array([])
        for t in tqdm(range(tau - 2, 0, -1)):
            S_t_space = np.random.choice(S_space[(tau - 1) * np.array((range(n_eps - 1))) + t - 1], size=BOOTSTRAP_SIZE)
            if t == tau - 2:
                S_tp1_space = S_t_space
            else:
                S_tp1_space = np.random.choice(S_t_space0, size=BOOTSTRAP_SIZE)

            for s in S_t_space:
                Qs, _ = P_H0_MV_deferral_cont(s, t, Mk_R_sa=Mk_R_sa, Mk_P_sas=Mk_P_sas, kset=K_ls, V_st=V_st,
                                              S_space=S_tp1_space,
                                              A_space=A_space)

                a_policy_vec = draw_actions(policy_probs(target_policy, s, t), n_trial)
                a_defer_vec = draw_actions(policy_probs(clinician_policy, s, t), n_trial)

                # Trial-major flattening: entry nn * K_no + k is model k under
                # the action drawn in trial nn.
                Qs_policy = Qs[:, a_policy_vec].T.ravel()
                Qs_defer = Qs[:, a_defer_vec].T.ravel() - defer_cost

                if defer_method == "hypothesis":
                    P_0 = np.mean(Qs_policy < Qs_defer)
                    ind = 1 - int(P_0 < alpha)
                elif defer_method == "mean":
                    # NOTE(review): this is 1 when the clinician's mean value
                    # is LOWER, the opposite of the "hypothesis" rule above and
                    # of the "mean" rule in SLTD_stochastic - confirm the sign.
                    ind = int(np.mean(Qs_defer) < np.mean(Qs_policy))
                else:
                    raise ValueError(" Defer method %s is not implemented!" % defer_method)

                g_s[s] += ind
                g_s_count[s] += 1

                for k in range(K_no):
                    R_sa, P_sas = Mk_R_sa[k], Mk_P_sas[k]
                    # Follow the clinician with the running deferral rate of
                    # this state, otherwise follow the target policy.
                    epsilon_vec = np.random.binomial(1, p=g_s[s] / g_s_count[s], size=n_trial)
                    a_mix_vec = epsilon_vec * draw_actions(policy_probs(clinician_policy, s, t), n_trial) + \
                                (1 - epsilon_vec) * draw_actions(policy_probs(target_policy, s, t), n_trial)
                    V_st[k][(s, t)] = density_backup(R_sa, P_sas, V_st[k], s, t, a_mix_vec, S_tp1_space)

                    # value corresponding to not deferring
                    # NOTE(review): always a Bernoulli draw, even when A_space
                    # has more than two actions - confirm.
                    a_behavior_vec = np.random.binomial(1, policy_probs(target_policy, s, t), size=5)
                    V_target[k][(s, t)] = density_backup(R_sa, P_sas, V_target[k], s, t, a_behavior_vec,
                                                         S_tp1_space)

                    if not NON_STATIONARY_LEARN:
                        # Stationary learning: reuse this value at every other time step.
                        t_vec = np.setdiff1d(list(range(0, tau - 1)), t)
                        for tt in t_vec:
                            V_st[k][(s, tt)] = V_st[k][(s, t)]
                            V_target[k][(s, tt)] = V_target[k][(s, t)]

            if NON_STATIONARY_LEARN:
                S_t_space0 = S_t_space
            else:
                S_t_space0 = np.union1d(S_t_space0, S_t_space)

        V_defer = 0
        V_tar = 0
        for s in S_t_space0:
            V_defer += np.mean([V_st[k][(s, 1)] for k in range(K_no)])
            V_tar += np.mean([V_target[k][(s, 1)] for k in range(K_no)])
        # NOTE(review): normalised by the total visit count accumulated over
        # all repetitions, not by the number of states summed above - confirm.
        total_visits = np.sum(np.array(list(g_s_count.values())))
        V_defer = V_defer / total_visits
        V_tar = V_tar / total_visits
        print("V_defer", V_defer, "V_tar", V_tar)

        # Store an independent snapshot of the normalised rates. Normalising
        # g_s in place corrupted the running tallies used in later repetitions
        # and made every g_s_all[b] point at the same dictionary.
        g_s_all[b] = {ss: g_s[ss] / (g_s_count[ss] + 1) for ss in g_s}
        Qs_all[b] = {"V_defer": V_defer, "V_tar": V_tar}
        model[b] = None
    return g_s_all, Qs_all, model


def update_values(g_s, g_s_count, V_st, cp, tp, R_sa, P_sas, s, t, k, A_space, S_space, a_policy_vec, a_defer_vec,
                  n_trial=5):
    """Return the Monte-Carlo value of state ``s`` at time ``t`` under the mixed policy.

    Each of the ``n_trial`` draws follows ``cp`` with probability
    ``g_s / g_s_count`` and ``tp`` otherwise; the value is the mean over draws
    of the reward plus the expected ``V_st[k]`` value at ``t + 1``. The
    incoming ``a_policy_vec`` and ``a_defer_vec`` arguments are not read.
    """
    defer_mask = np.random.binomial(n=1, p=g_s / g_s_count, size=n_trial)

    state_idx = int(s)
    if len(A_space) == 2:
        deferred_actions = np.random.binomial(n=1, p=cp[state_idx], size=n_trial)
        target_actions = np.random.binomial(n=1, p=tp[state_idx], size=n_trial)
    else:
        deferred_actions = np.random.choice(a=A_space, p=cp[state_idx], size=n_trial)
        target_actions = np.random.choice(a=A_space, p=tp[state_idx], size=n_trial)
    actions = defer_mask * deferred_actions + (1 - defer_mask) * target_actions

    returns = []
    for a in actions:
        continuation = sum([P_sas[(s, a)][int(nxt_s)] * V_st[k][(nxt_s, t + 1)] for nxt_s in S_space])
        returns.append(R_sa[(s, a)] + continuation)
    return np.mean(returns)


# Non-stationary dynamics+Stochastic policy + Discrete states (can handle true dynamics):
def SLTD_stochastic_nst_multistep(H_T, alpha, tau, K_no, pi_st, target_policy, S_space, A_space, tx_mat,
                                  defer_cost=0.0, true_tx=None, learn_dynamics=True,
                                  defer_method="mean", non_stationary_policy=False, non_stationary=True,
                                  one_step=False, env="discrete_toy"):
    """Backward-induction deferral estimate over discrete states with per-step sampled MDPs.

    For each of ``N_BOOTSTRAPS_NST`` rounds and each step ``t = tau - 1, ..., 0`` this samples
    ``K_no`` MDPs, scores every state by comparing sampled Q-values of ``target_policy`` actions
    against those of ``pi_st`` actions (the latter reduced by a deferral cost), accumulates that
    score in ``g_s`` and backs up the value of the resulting mixed policy and of the pure target
    policy.  A multi-step and a one-step variant are tracked side by side.

    Args:
        H_T: 2-D history array; the last column is matched against ``t`` and column 2 supplies
            the reward values handed to the sampler.
        alpha: not used by this function; kept so existing callers keep working.
        tau: horizon; steps ``tau - 1`` down to ``0`` are processed.
        K_no: number of MDPs sampled per step.
        pi_st: policy whose actions are taken when deferring, indexed ``[int(s)]``
            (or ``[t][int(s)]`` when ``non_stationary_policy`` is true).
        target_policy: policy being evaluated, indexed like ``pi_st``.
        S_space: iterable of discrete states.
        A_space: action labels; exactly two labels switches sampling to Bernoulli draws.
        tx_mat: passed to the sampler as ``P_sa`` when ``learn_dynamics`` is true.
        defer_cost: subtracted from the deferral Q-values, scaled by ``tau - 1 - t`` in the
            multi-step variant and applied once in the one-step variant.
        true_tx: dynamics passed as ``P_sa`` when ``learn_dynamics`` is false.
        learn_dynamics: choose between sampling with ``tx_mat`` and using ``true_tx``.
        defer_method: ``"hypothesis"`` scores a state by the fraction of samples where the
            target Q-value is below the deferral Q-value; ``"mean"`` scores 1 when the mean
            target Q-value is below the mean deferral Q-value, else 0.
        non_stationary_policy: index the policies by ``t`` first.
        non_stationary: forwarded as ``nst`` when ``learn_dynamics`` is false.
        one_step: forwarded to ``P_H0_MV_deferral_multistepeval``.
        env: forwarded to the sampler when ``learn_dynamics`` is true.

    Returns:
        ``(g_s, Qs_all, g_s_1st, Qs_all_1step, dirich_alpha_dict, a_defer_all, P_sas_dict)``.
        ``g_s`` / ``g_s_1st`` map ``(s, t)`` to the score averaged over bootstraps.  ``Qs_all`` /
        ``Qs_all_1step`` map each bootstrap index to ``{'V_defer', 'V_tar'}`` (step-0 values
        averaged over sampled MDPs and states).  The last three hold, per step, the sampler's
        Dirichlet parameters, the deferral action draws and the sampled transition models as
        left by the final bootstrap.

    Raises:
        ValueError: if ``learn_dynamics`` is false and ``true_tx`` is None, or if
            ``defer_method`` is not one of the two supported names.
    """
    n_trial = 15  # action draws per state used to build the Q-value samples
    n_b = N_BOOTSTRAPS_NST
    K_ls = list(range(K_no))
    time_steps = range(tau - 1, -1, -1)
    two_actions = len(A_space) == 2

    def _draw_actions(probs):
        """Sample ``n_trial`` actions: Bernoulli for two actions, categorical otherwise."""
        if two_actions:
            return np.random.binomial(n=1, p=probs, size=n_trial)
        return np.random.choice(a=A_space, p=probs, size=n_trial)

    def _backup(R_sa, P_sas, V_k, s, t, actions):
        """Mean over ``actions`` of reward plus expected ``V_k`` at step ``t + 1``."""
        return np.mean(
            [R_sa[(s, a)] + sum([P_sas[(s, a)][int(nxt_s)] * V_k[(nxt_s, t + 1)] for nxt_s in S_space])
             for a in actions])

    Qs_all = {}
    Qs_all_1step = {}
    # Running sums of the deferral score and visit counts; normalised once after all bootstraps.
    g_s = {(s, tt): 0.0 for s in S_space for tt in time_steps}
    g_s_count = {(s, tt): 0.0 for s in S_space for tt in time_steps}
    g_s_1st = {(s, tt): 0.0 for s in S_space for tt in time_steps}
    g_s_count_1st = {(s, tt): 0.0 for s in S_space for tt in time_steps}

    dirich_alpha_dict = {}
    a_defer_all = {}
    P_sas_dict = {}
    for b in range(n_b):  # bootstrap rounds
        # Terminal values (step tau) are zero for every sampled MDP.
        V_st = {k: {(s, tau): 0 for s in S_space} for k in range(K_no)}
        V_target = {k: {(s, tau): 0 for s in S_space} for k in range(K_no)}
        V_st_1st = {k: {(s, tau): 0 for s in S_space} for k in range(K_no)}
        V_target_1st = {k: {(s, tau): 0 for s in S_space} for k in range(K_no)}
        a_defer_all = {}
        P_sas_dict = {}
        for t in tqdm(time_steps):
            # Sample K_no MDPs for this step.
            if learn_dynamics:
                idx = np.where(H_T[:, -1] == t)[0]
                reward_vec = np.unique(H_T[:, 2])
                Mk_R_sa, Mk_P_sas, dirich_alpha_s_a = sampleK_MDPs_defer(S_card=len(S_space), A_space=A_space,
                                                                         reward_vec=reward_vec,
                                                                         H_tk=H_T[idx], K=K_no, P_sa=tx_mat, t=t,
                                                                         use_true_tx=False,
                                                                         nst=True, env=env)
            elif true_tx is None:
                raise ValueError("Provide true dynamics model")
            else:
                # Here the argument to P_sa is treated as the true dynamics to be used.
                Mk_R_sa, Mk_P_sas, dirich_alpha_s_a = sampleK_MDPs_defer(S_card=len(S_space), A_space=A_space,
                                                                         H_tk=H_T, K=K_no,
                                                                         P_sa=true_tx, t=t,
                                                                         use_true_tx=True,
                                                                         nst=non_stationary)

            dirich_alpha_dict[t] = dirich_alpha_s_a
            P_sas_dict[t] = Mk_P_sas

            if non_stationary_policy:
                tp, cp = target_policy[t], pi_st[t]
            else:
                tp, cp = target_policy, pi_st

            for s in S_space:
                Qs, Qs_1step = P_H0_MV_deferral_multistepeval(s, t, Mk_R_sa=Mk_R_sa, Mk_P_sas=Mk_P_sas, kset=K_ls,
                                                              V_st=V_st,
                                                              S_space=S_space,
                                                              A_space=A_space, one_step=one_step)

                a_policy_vec = _draw_actions(tp[int(s)])
                a_defer_vec = _draw_actions(cp[int(s)])
                a_defer_all[(s, t)] = a_defer_vec

                # One block of K_no Q-values per sampled action.
                Qs_policy = np.zeros(n_trial * K_no)
                Qs_defer = np.zeros(n_trial * K_no)
                Qs_policy_1st = np.zeros(n_trial * K_no)
                Qs_defer_1st = np.zeros(n_trial * K_no)
                for nn in range(n_trial):
                    block = slice(nn * K_no, (nn + 1) * K_no)
                    a_pol = int(a_policy_vec[nn])
                    a_def = int(a_defer_vec[nn])
                    Qs_policy[block] = Qs[:, a_pol]
                    # Multi-step deferral pays the cost for each remaining step.
                    Qs_defer[block] = Qs[:, a_def] - defer_cost * (tau - 1 - t)
                    Qs_policy_1st[block] = Qs_1step[:, a_pol]
                    Qs_defer_1st[block] = Qs_1step[:, a_def] - defer_cost

                if defer_method == "hypothesis":
                    P_0 = np.mean(Qs_policy < Qs_defer)
                    P_0_1st = np.mean(Qs_policy_1st < Qs_defer_1st)
                elif defer_method == "mean":
                    P_0 = int(np.mean(Qs_policy) < np.mean(Qs_defer))
                    P_0_1st = int(np.mean(Qs_policy_1st) < np.mean(Qs_defer_1st))
                else:
                    raise ValueError(" Defer method %s is not implemented!" % defer_method)

                g_s[(s, t)] += P_0
                g_s_count[(s, t)] += 1
                g_s_1st[(s, t)] += P_0_1st
                g_s_count_1st[(s, t)] += 1

                # Running deferral probabilities; the 1e-7 only guards the division.
                defer_prob = g_s[(s, t)] / (g_s_count[(s, t)] + 1e-7)
                defer_prob_1st = g_s_1st[(s, t)] / (g_s_count_1st[(s, t)] + 1e-7)

                for k in range(K_no):
                    R_sa, P_sas = Mk_R_sa[k], Mk_P_sas[k]
                    # Value of the mixed (defer / target) policy.
                    V_st[k][(s, t)] = update_values(defer_prob, 1, V_st, cp, tp, R_sa,
                                                    P_sas, s, t,
                                                    k, A_space,
                                                    S_space, a_policy_vec, a_defer_vec,
                                                    n_trial=n_trial)

                    V_st_1st[k][(s, t)] = update_values(defer_prob_1st, 1, V_st_1st,
                                                        cp, tp,
                                                        R_sa, P_sas, s, t, k,
                                                        A_space, S_space, a_policy_vec, a_defer_vec,
                                                        n_trial=n_trial)

                    # NOTE(review): this draw is never consumed (the target backups below reuse
                    # a_policy_vec). It is kept only so seeded runs consume the random stream
                    # exactly as before; confirm whether it can be dropped.
                    _draw_actions(tp[int(s)])

                    # Value corresponding to never deferring.
                    V_target[k][(s, t)] = _backup(R_sa, P_sas, V_target[k], s, t, a_policy_vec)
                    V_target_1st[k][(s, t)] = _backup(R_sa, P_sas, V_target_1st[k], s, t, a_policy_vec)

        # Step-0 values averaged over sampled MDPs, then over states.
        V_defer = sum(np.mean([V_st[k][(s, 0)] for k in range(K_no)]) for s in S_space) / len(S_space)
        V_tar = sum(np.mean([V_target[k][(s, 0)] for k in range(K_no)]) for s in S_space) / len(S_space)
        print("V_defer", V_defer, "V_tar", V_tar)

        V_defer_1st = sum(np.mean([V_st_1st[k][(s, 0)] for k in range(K_no)]) for s in S_space) / len(S_space)
        V_tar_1st = sum(np.mean([V_target_1st[k][(s, 0)] for k in range(K_no)]) for s in S_space) / len(S_space)
        print("V_defer", V_defer_1st, "V_tar", V_tar_1st)

        # Record every bootstrap, not just the last one.
        Qs_all[b] = {'V_defer': V_defer, 'V_tar': V_tar}
        Qs_all_1step[b] = {'V_defer': V_defer_1st, 'V_tar': V_tar_1st}

    # Turn the accumulated scores into per-(s, t) averages; untouched entries stay at 0.
    for key in g_s:
        if g_s_count[key] != 0:
            g_s[key] = g_s[key] / g_s_count[key]
        if g_s_count_1st[key] != 0:
            g_s_1st[key] = g_s_1st[key] / g_s_count_1st[key]

    return g_s, Qs_all, g_s_1st, Qs_all_1step, dirich_alpha_dict, a_defer_all, P_sas_dict


# Non-stationary dynamics+Stochastic policy + Continuous states (can handle true dynamics):
def SLTD_stochastic_cont_nst(H_T, alpha, tau, K_no, clinician_policy, target_policy, A_space, tx_mat, n_eps,
                             defer_cost=0.0, defer_method="mean", learn_dynamics=True, non_stationary_policy=False):
    """Backward-induction deferral estimate over continuous states with per-step sampled MDPs.

    For each of ``N_BOOTSTRAPS_NST`` rounds and each step ``t = tau - 2, ..., 1`` this samples
    ``K_no`` MDPs, draws ``BOOTSTRAP_SIZE`` states from the first column of ``H_T``, scores each
    drawn state by comparing sampled Q-values of ``target_policy`` actions against those of
    ``clinician_policy`` actions (the latter reduced by ``defer_cost``), and backs up the value
    of the mixed and of the pure target policy.  Next-state expectations are sums over the
    states drawn for step ``t + 1``, weighted by the Gaussian density given by each model's
    ``'mu'`` / ``'std'`` entries.

    Args:
        H_T: 2-D history array; column 0 is used as the pool of states.
        alpha: not used by this function; kept so existing callers keep working.
        tau: horizon.
        K_no: number of MDPs sampled per step.
        clinician_policy: callable ``cp(s)`` (or ``cp(s, t)`` when ``non_stationary_policy``)
            giving the action probability / probability vector used when deferring.
        target_policy: callable with the same calling convention, for the evaluated policy.
        A_space: action labels; exactly two labels switches sampling to Bernoulli draws.
        tx_mat: passed to the sampler as ``P_sa``; must not be None when ``learn_dynamics``
            is false.
        n_eps: number of episodes; states for step ``t`` are drawn from rows
            ``(tau - 1) * i + t - 1`` for ``i`` in ``range(n_eps - 1)`` -- presumably episodes
            are stored back to back with ``tau - 1`` rows each (TODO confirm against caller).
        defer_cost: subtracted from the deferral Q-values.
        defer_method: ``"hypothesis"`` scores a state by the fraction of samples where the
            target Q-value is below the deferral Q-value; ``"mean"`` scores 1 when the mean
            target Q-value is below the mean deferral Q-value, else 0.
        learn_dynamics: when false the sampler is told to use ``tx_mat`` as the true dynamics.
        non_stationary_policy: call the policies as ``policy(s, t)`` instead of ``policy(s)``.

    Returns:
        ``(g_s, Qs_all, g_s_1st, P_sas_dict, Reg_dict)``.  ``g_s`` / ``g_s_1st`` map ``(s, t)``
        to the deferral score averaged over all visits (0 for never-visited pairs).  ``Qs_all``
        maps each bootstrap index to ``{"V_defer", "V_tar"}``.  ``P_sas_dict`` / ``Reg_dict``
        hold, per step, the sampled transition models and the sampler's third return value from
        the final bootstrap.

    Raises:
        ValueError: if ``learn_dynamics`` is false and ``tx_mat`` is None, or if
            ``defer_method`` is not one of the two supported names.
    """
    n_trial = 10  # action draws per state used to build the Q-value samples
    n_behavior = 5  # action draws used for the pure target-policy backup
    n_b = N_BOOTSTRAPS_NST
    K_ls = list(range(K_no))
    two_actions = len(A_space) == 2
    tp, cp = target_policy, clinician_policy

    def _policy_probs(policy, s, t):
        """Evaluate ``policy`` with or without the time argument."""
        return policy(s, t) if non_stationary_policy else policy(s)

    def _draw_actions(probs, size):
        """Sample ``size`` actions: Bernoulli for two actions, categorical otherwise."""
        if two_actions:
            return np.random.binomial(n=1, p=probs, size=size)
        return np.random.choice(a=A_space, p=probs, size=size)

    def _mixed_actions(defer_prob, cp_probs, tp_probs):
        """Per trial, take the clinician draw with probability ``defer_prob``, else the target draw."""
        use_cp = np.random.binomial(n=1, p=defer_prob, size=n_trial)
        return use_cp * _draw_actions(cp_probs, n_trial) + (1 - use_cp) * _draw_actions(tp_probs, n_trial)

    def _backup(R_sa, P_sas, V_k, s, t, actions, next_states):
        """Mean over ``actions`` of reward plus density-weighted ``V_k`` at step ``t + 1``."""
        return np.mean(
            [R_sa[(s, a)] + sum(
                [norm.pdf((nxt_s - P_sas[(s, a)]['mu']) / P_sas[(s, a)]['std']) / P_sas[(s, a)]['std'] *
                 V_k[(nxt_s, t + 1)] for nxt_s in next_states])
             for a in actions])

    Qs_all = {}
    S_space = H_T[:, 0]
    time_steps = range(tau - 1, -1, -1)

    # Running sums of the deferral score and visit counts.  They must stay un-normalised while
    # bootstraps are running because their ratio is used as the mixing probability below.
    g_s = {(s, tt): 0.0 for s in S_space for tt in time_steps}
    g_s_count = {(s, tt): 0.0 for s in S_space for tt in time_steps}
    g_s_1st = {(s, tt): 0.0 for s in S_space for tt in time_steps}
    g_s_count_1st = {(s, tt): 0.0 for s in S_space for tt in time_steps}

    P_sas_dict = {}
    Reg_dict = {}
    for b in range(n_b):  # bootstrap rounds
        V_st = {k: {(s, tau - 1): 0 for s in S_space} for k in range(K_no)}
        V_st_1st = {k: {(s, tau - 1): 0 for s in S_space} for k in range(K_no)}
        V_target = {k: {(s, tau - 1): 0 for s in S_space} for k in range(K_no)}

        S_t_space0 = np.array([])  # states drawn at the previously processed step (t + 1)
        P_sas_dict = {}
        Reg_dict = {}
        for t in tqdm(range(tau - 2, 0, -1)):
            if not learn_dynamics and tx_mat is None:
                raise ValueError("Provide true dynamics model")
            # When learn_dynamics is false the argument to P_sa is treated as the true dynamics.
            Mk_R_sa, Mk_P_sas, p_fit = sampleK_MDPs_defer_cont(A_space=A_space, H_tk=H_T, K=K_no, P_sa=tx_mat,
                                                               p_fit_dict=None, t=t, debug=True,
                                                               use_true_tx=not learn_dynamics, nst=True)
            Reg_dict[t] = p_fit
            P_sas_dict[t] = Mk_P_sas

            S_t_space = np.random.choice(S_space[(tau - 1) * np.array((range(n_eps - 1))) + t - 1], size=BOOTSTRAP_SIZE)
            if t == tau - 2:
                # Terminal values exist for every state, so the current draw can serve as support.
                S_tp1_space = S_t_space
            else:
                S_tp1_space = np.random.choice(S_t_space0, size=BOOTSTRAP_SIZE)

            for s in S_t_space:
                Qs, Qs_1st = P_H0_MV_deferral_cont(s, t, Mk_R_sa=Mk_R_sa, Mk_P_sas=Mk_P_sas, kset=K_ls, V_st=V_st,
                                                   S_space=S_tp1_space,
                                                   A_space=A_space)

                # Evaluated once per state; assumes the policies are pure functions of (s[, t]).
                tp_probs = _policy_probs(tp, s, t)
                cp_probs = _policy_probs(cp, s, t)

                a_policy_vec = _draw_actions(tp_probs, n_trial)
                a_defer_vec = _draw_actions(cp_probs, n_trial)

                # One block of K_no Q-values per sampled action.
                Qs_policy = np.zeros(n_trial * K_no)
                Qs_defer = np.zeros(n_trial * K_no)
                Qs_policy_1st = np.zeros(n_trial * K_no)
                Qs_defer_1st = np.zeros(n_trial * K_no)
                for nn in range(n_trial):
                    block = slice(nn * K_no, (nn + 1) * K_no)
                    Qs_policy[block] = Qs[:, a_policy_vec[nn]]
                    Qs_defer[block] = Qs[:, a_defer_vec[nn]] - defer_cost
                    Qs_policy_1st[block] = Qs_1st[:, a_policy_vec[nn]]
                    Qs_defer_1st[block] = Qs_1st[:, a_defer_vec[nn]] - defer_cost

                if defer_method == "hypothesis":
                    P_0 = np.mean(Qs_policy < Qs_defer)
                    P_0_1st = np.mean(Qs_policy_1st < Qs_defer_1st)
                elif defer_method == "mean":
                    # Previously this branch never set P_0 / P_0_1st, so the default method failed
                    # with a NameError (or silently reused the previous state's score).
                    P_0 = int(np.mean(Qs_policy) < np.mean(Qs_defer))
                    P_0_1st = int(np.mean(Qs_policy_1st) < np.mean(Qs_defer_1st))
                else:
                    raise ValueError(" Defer method %s is not implemented!" % defer_method)

                g_s[(s, t)] += P_0
                g_s_count[(s, t)] += 1
                g_s_1st[(s, t)] += P_0_1st
                g_s_count_1st[(s, t)] += 1

                # Counts were just incremented, so both ratios are well defined.
                defer_prob = g_s[(s, t)] / g_s_count[(s, t)]
                defer_prob_1st = g_s_1st[(s, t)] / g_s_count_1st[(s, t)]

                for k in range(K_no):
                    R_sa, P_sas = Mk_R_sa[k], Mk_P_sas[k]

                    # Value of the mixed (defer / target) policy.
                    mixed = _mixed_actions(defer_prob, cp_probs, tp_probs)
                    V_st[k][(s, t)] = _backup(R_sa, P_sas, V_st[k], s, t, mixed, S_tp1_space)

                    # NOTE(review): the one-step table bootstraps from V_st (not V_st_1st) and is
                    # never read or returned by this function; kept as in the original -- confirm
                    # whether it is still needed.
                    mixed_1st = _mixed_actions(defer_prob_1st, cp_probs, tp_probs)
                    V_st_1st[k][(s, t)] = _backup(R_sa, P_sas, V_st[k], s, t, mixed_1st, S_tp1_space)

                    # Value corresponding to never deferring.
                    a_behavior_vec = _draw_actions(tp_probs, n_behavior)
                    V_target[k][(s, t)] = _backup(R_sa, P_sas, V_target[k], s, t, a_behavior_vec, S_tp1_space)

            S_t_space0 = S_t_space

        # Step-1 values of the states drawn last, averaged over the sampled MDPs.
        # NOTE(review): the averages below run over every distinct state in S_space, with 0 for
        # states that were not drawn at the last step -- confirm this dilution is intended.
        V_defer_s = {s: 0 for s in S_space}
        V_tar_s = {s: 0 for s in S_space}
        for s in S_t_space0:
            V_defer_s[s] = np.mean([V_st[k][(s, 1)] for k in range(K_no)])
            V_tar_s[s] = np.mean([V_target[k][(s, 1)] for k in range(K_no)])
        V_defer = np.mean(list(V_defer_s.values()))
        V_tar = np.mean(list(V_tar_s.values()))
        print("V_defer", V_defer, "V_tar", V_tar)
        Qs_all[b] = {"V_defer": V_defer, "V_tar": V_tar}

    # Normalise once, after all bootstraps.  Doing this inside the loop divided already-averaged
    # entries again on every later round and fed averages back into the running sums.
    for key in g_s:
        if g_s_count[key] != 0:
            g_s[key] = g_s[key] / g_s_count[key]
        if g_s_count_1st[key] != 0:
            g_s_1st[key] = g_s_1st[key] / g_s_count_1st[key]

    return g_s, Qs_all, g_s_1st, P_sas_dict, Reg_dict
