import numpy as np
import matplotlib.pyplot as plt
import gurobipy as gp
from scipy.optimize import fsolve

def feasibility_check(s, k, F, phi):
    """Flag which (task, arm) pairs satisfy every constraint threshold.

    Args:
        s: Number of tasks (extent of the first axis that is scanned).
        k: Number of arms per task (extent of the second axis scanned).
        F: Array indexed as ``F[task, arm, :]`` holding the constraint
            measures of each arm.
        phi: Threshold(s) compared element-wise against ``F[task, arm, :]``.

    Returns:
        Integer array of shape ``(s, k)``; an entry is 1 when all entries of
        ``F[task, arm, :]`` are ``>= phi`` and 0 otherwise.
    """
    feasibility = np.empty((s, k), dtype=int)
    for a, i in np.ndindex(s, k):
        feasibility[a, i] = int(np.all(F[a, i, :] >= phi))
    return feasibility

def optimal_solution(s, X_mean, feasibility):
    """Pick, for each task, the feasible arm with the largest mean.

    Args:
        s: Number of tasks.
        X_mean: Array indexed as ``X_mean[task, arm]`` of objective means.
        feasibility: 0/1 array indexed as ``feasibility[task, arm]``.

    Returns:
        Integer array of length ``s``. Entry ``a`` is the index of the best
        feasible arm of task ``a``, or -1 when every arm of that task is
        marked 0 (infeasible).
    """
    best_arm = np.full(s, -1, dtype=int)
    for a in range(s):
        flags = feasibility[a]
        if np.all(flags == 0):
            # No feasible arm: keep the -1 sentinel.
            continue
        # Infeasible arms are masked with -inf so they can never win.
        masked_means = np.where(flags == 1, X_mean[a], -np.inf)
        best_arm[a] = np.argmax(masked_means)
    return best_arm

def sub_solutions(s, opt_solution, X_mean, feasibility):
    """Partition the non-optimal arms of every task into three index sets.

    For each task ``a`` (with ``best = opt_solution[a]``):
      * ``D1[a]``: arms marked feasible other than ``best``;
      * ``D2[a]``: arms marked infeasible whose mean is strictly greater
        than the mean of ``best``;
      * ``D3[a]``: arms marked infeasible whose mean is less than or equal
        to the mean of ``best``.

    Args:
        s: Number of tasks.
        opt_solution: Per-task index of the selected arm.
        X_mean: Array indexed as ``X_mean[task, arm]``.
        feasibility: 0/1 array indexed as ``feasibility[task, arm]``.

    Returns:
        Tuple ``(D1, D2, D3)`` of lists (length ``s``) of sets of arm indices.
    """
    arm_ids = np.arange(X_mean.shape[1])
    D1, D2, D3 = [], [], []

    for a in range(s):
        best = opt_solution[a]
        is_feasible = feasibility[a] == 1
        is_infeasible = feasibility[a] == 0
        best_mean = X_mean[a, best]

        D1.append(set(np.flatnonzero(is_feasible & (arm_ids != best))))
        D2.append(set(np.flatnonzero(is_infeasible & (X_mean[a] > best_mean))))
        D3.append(set(np.flatnonzero(is_infeasible & (X_mean[a] <= best_mean))))

    return D1, D2, D3


def kl_divergence(p, q):
    """Return the KL divergence between Bernoulli(p) and Bernoulli(q).

    Works element-wise on NumPy arrays. No clipping is applied here, so
    callers are expected to keep ``p`` and ``q`` strictly inside (0, 1).
    """
    success_term = p * np.log(p / q)
    failure_term = (1 - p) * np.log((1 - p) / (1 - q))
    return success_term + failure_term

def beta(t, delta):
    """Return the threshold ``log((log(t) + 1) / delta)``.

    Args:
        t: Value whose natural log enters the threshold.
        delta: Value dividing ``log(t) + 1`` inside the outer log.
    """
    return np.log((np.log(t) + 1) / delta)

def indicator(a, b):
    """Return 1 when ``a <= b`` and 0 otherwise."""
    return 1 if a <= b else 0

def is_complete_square(n):
    """Return whether ``n`` equals the square of ``floor(sqrt(n))``."""
    root = np.floor(np.sqrt(n))
    return np.square(root) == n

def glrt(s, k, m, dist, X_mean, F, opt_solution, X_variance, feasibility, phi, omega):
    """Evaluate the stopping statistic for the current estimates.

    For every task four candidate values are computed from the allocation
    ``omega`` (last axis: ``m`` constraint coordinates, then the objective
    coordinate at index ``m``):

      * V1: smallest ``omega * KL`` over the constraints of the selected arm;
      * V2: smallest pairwise objective term between the selected arm and
        an arm of ``D1``;
      * V3: smallest violated-constraint cost over the arms of ``D2``;
      * V4: smallest (pairwise objective term + violated-constraint cost)
        over the arms of ``D3``.

    Args:
        s, k, m: Numbers of tasks, arms and constraints (``k`` is unused).
        dist: ``"Bernoulli"`` or ``"Gaussian"`` objective model.
        X_mean, X_variance: Objective means / variances, ``[task, arm]``.
        F: Constraint estimates, ``[task, arm, constraint]``.
        opt_solution: Selected arm per task (-1 when none is feasible).
        feasibility: 0/1 feasibility flags, ``[task, arm]``.
        phi: Constraint thresholds, indexed by constraint.
        omega: Allocation, ``[task, arm, coordinate]``.

    Returns:
        0 when some task has no selected arm, otherwise the minimum of
        V1..V4 over all tasks.
    """
    D1, D2, D3 = sub_solutions(s, opt_solution, X_mean, feasibility)
    if np.any(opt_solution == -1):
        return 0

    # Keep estimates away from 0/1 so the Bernoulli KL stays finite.
    if dist == "Bernoulli":
        X_mean = np.clip(X_mean, 0.001, 0.999)
    F = np.clip(F, 0.001, 0.999)

    KL_table = np.array([kl_divergence(F[:, :, j], phi[j]) for j in range(m)]).transpose(1, 2, 0)

    V1 = np.full(s, np.inf)
    V2 = np.full(s, np.inf)
    V3 = np.full(s, np.inf)
    V4 = np.full(s, np.inf)

    def violation_cost(a, i):
        # Sum of omega * KL over the constraints that arm i currently violates.
        return sum(omega[a, i, j] * KL_table[a, i, j] for j in range(m) if F[a, i, j] < phi[j])

    for a in range(s):
        best = opt_solution[a]
        w_best = omega[a, best, m]

        V1[a] = np.min(omega[a, best, :m] * KL_table[a, best, :])

        # Feasible but sub-optimal arms: objective comparison only.
        objective_rates = []
        for i in D1[a]:
            w_i = omega[a, i, m]
            if dist == "Bernoulli":
                lambda_ = (w_i * X_mean[a, i] + w_best * X_mean[a, best]) / (w_i + w_best)
                objective_rates.append(w_i * kl_divergence(X_mean[a, i], lambda_) +
                                       w_best * kl_divergence(X_mean[a, best], lambda_))
            elif dist == "Gaussian":
                objective_rates.append((X_mean[a, i] - X_mean[a, best]) ** 2 /
                                       (2 * (X_variance[a, i] / w_i + X_variance[a, best] / w_best)))
        if objective_rates:
            V2[a] = min(objective_rates)

        # Infeasible arms with a larger mean: constraint violation only.
        violation_rates = [violation_cost(a, i) for i in D2[a]]
        if violation_rates:
            V3[a] = min(violation_rates)

        # Infeasible arms with a smaller-or-equal mean: both contributions.
        combined_rates = []
        for i in D3[a]:
            w_i = omega[a, i, m]
            if dist == "Bernoulli":
                lambda_ = (w_i * X_mean[a, i] + w_best * X_mean[a, best]) / (w_i + w_best)
                rate = w_i * kl_divergence(X_mean[a, i], lambda_) + \
                       w_best * kl_divergence(X_mean[a, best], lambda_)
            elif dist == "Gaussian":
                rate = (X_mean[a, i] - X_mean[a, best]) ** 2 / \
                       (2 * (X_variance[a, i] / w_i + X_variance[a, best] / w_best))

            rate += violation_cost(a, i)
            combined_rates.append(rate)
        if combined_rates:
            V4[a] = min(combined_rates)

    return np.min([np.min(V1), np.min(V2), np.min(V3), np.min(V4)])

def undersampled(s, k, m, n0, total_sample, alternative_count):
    """Mark coordinates whose sample count is below the forced-exploration level.

    The level is ``sqrt(total_sample) - s * k * (m + 1) / 2``; ``n0`` is
    accepted for interface compatibility but not used.

    Returns:
        Boolean result of ``alternative_count < level`` (element-wise for
        array input).
    """
    level = np.sqrt(total_sample) - (s * k * (m + 1)) / 2
    return alternative_count < level

def track(optimal_ratio, alternative_count):
    """Choose the coordinate that lags furthest behind its target share.

    Args:
        optimal_ratio: Target allocation, a 3-D array ``[task, arm, coord]``.
        alternative_count: Sample counts with the same shape.

    Returns:
        Tuple ``(next_task, next_arm, next_cons)`` locating the largest value
        of ``sum(alternative_count) * optimal_ratio - alternative_count``.
    """
    deficit = np.sum(alternative_count) * optimal_ratio - alternative_count
    next_task, next_arm, next_cons = np.unravel_index(np.argmax(deficit), deficit.shape)
    return next_task, next_arm, next_cons

def update_ratio(alternative_count):
    """Convert sample counts into empirical allocation proportions.

    Returns:
        ``alternative_count / sum(alternative_count)`` when the total is
        positive, otherwise a float array of zeros with the same shape.
    """
    total_count = np.sum(alternative_count)
    if not total_count > 0:
        return np.zeros_like(alternative_count, dtype=float)
    return alternative_count / total_count

def ratio_figure(ratio_hist):
    """Plot the history of every allocation coordinate and save it as a PDF.

    Args:
        ratio_hist: 4-D array ``[iteration, task, arm, coordinate]``. Every
            1000th iteration is plotted.

    Side effects:
        Writes ``../Figure/allocation_ratio.pdf`` and shows the figure.
    """
    plt.figure(figsize=(12, 6))
    t, s, k, m = ratio_hist.shape
    markers = ['o', 'v', '^', '<', '>', 's', 'p', '*', 'h']
    colors = plt.cm.tab20.colors
    linestyles = ['-', '--', '-.', ':']
    iterations = range(0, t, 1000)

    for a, i, j in np.ndindex(s, k, m):
        # One running index per curve cycles through markers/colors/styles.
        series_id = a * k * m + i * m + j
        plt.plot(
            iterations,
            ratio_hist[::1000, a, i, j],
            label=f'$\\omega^{{{a+1}}}_{{{i+1}{j+1}}}$',
            marker=markers[series_id % len(markers)],
            color=colors[series_id % len(colors)],
            linestyle=linestyles[series_id % len(linestyles)],
        )

    plt.title('Change of Allocation Ratio Over Replications')
    plt.xlabel('Iteration')
    plt.ylabel('Allocation Ratio')
    plt.legend(loc='upper left', bbox_to_anchor=(1.15, 1))
    plt.grid(True, linestyle='--')
    plt.savefig('../Figure/allocation_ratio.pdf', dpi=300, bbox_inches='tight')
    plt.show()

####################################ESR sampling rule #################################################
def equations(vars, dist, s, k, m, DM1, DM2, phi, F, X_mean, X_variance, opt_solution):
    """Residuals of the nonlinear system solved (via ``fsolve``) for the allocation.

    Args:
        vars: Flat vector of unknowns; reshaped to ``(s, k, m + 1)`` where the
            last axis holds ``m`` constraint coordinates followed by the
            objective coordinate at index ``m``.
        dist: ``"Bernoulli"`` or ``"Gaussian"`` objective model.
        s, k, m: Numbers of tasks, arms and constraints.
        DM1: Per-task sets of arms compared with the selected arm through
            the objective.
        DM2: Per-task sets of arms handled through their constraint with the
            largest KL value.
        phi: Constraint thresholds, indexed by constraint.
        F: Constraint estimates, ``[task, arm, constraint]``.
        X_mean, X_variance: Objective means / variances, ``[task, arm]``.
        opt_solution: Selected arm per task.

    Returns:
        List of residuals: one balance equation per task (for the two
        supported ``dist`` values), equalities forcing every collected rate
        term to match the first one, ``omega == 0`` for coordinates that take
        part in no term, and the normalisation ``sum(omega) - 1``.
    """

    omega = np.reshape(vars, (s, k, m+1))

    # Keep estimates away from 0/1 so the Bernoulli KL stays finite.
    if dist == "Bernoulli":
        X_mean = np.clip(X_mean, 0.001, 0.999)
    F = np.clip(F, 0.001, 0.999)
    # KL_table[a, i, j] = KL(F[a, i, j] || phi[j]).
    KL_table = np.array([kl_divergence(F[:, :, j], phi[j]) for j in range(m)]).transpose(1, 2, 0)

    eq_list = []
    # Coordinates of omega that appear in at least one rate term.
    used_omega_indices = set()

    # Rate terms that are all forced to be equal further below.
    equal_term = []

    for a in range(s):
        opt_idx = opt_solution[a]
        temp_term = 0
        if dist == "Gaussian":
            for i in DM1[a]:
                # Pairwise objective term between arm i and the selected arm.
                equal_term.append((X_mean[a, i] - X_mean[a, opt_idx])**2 / (
                            2*(X_variance[a, i] / omega[a, i, m] + X_variance[a, opt_idx] / omega[a, opt_idx, m])))
                used_omega_indices.update({(a, opt_idx, m), (a, i, m)})
                temp_term += (omega[a,i,m]**2 / X_variance[a,i])
            # Balance: omega_opt / sigma_opt == sqrt(sum_i omega_i^2 / sigma_i^2).
            eq_list.append(omega[a,opt_solution[a],m]/np.sqrt(X_variance[a,opt_solution[a]])-np.sqrt(temp_term))
        elif dist == "Bernoulli":
            for i in DM1[a]:
                # Allocation-weighted average of the two means.
                mu_bar = (omega[a, opt_solution[a], m] * X_mean[a, opt_solution[a]] + omega[a, i, m] * X_mean[
                    a, i]) / (omega[a, opt_solution[a], m] + omega[a, i, m])
                equal_term.append(omega[a, i, m] * kl_divergence(X_mean[a,i], mu_bar) + omega[
                a, opt_solution[a], m] * kl_divergence(X_mean[a,opt_solution[a]], mu_bar))
                used_omega_indices.update({(a, opt_idx, m), (a, i, m)})
                temp_term += (kl_divergence(X_mean[a,opt_solution[a]], mu_bar) / kl_divergence(X_mean[a,i], mu_bar))
            # Balance: the KL ratios over DM1 sum to one.
            eq_list.append(temp_term - 1)
        # Every constraint of the selected arm contributes a rate term.
        for j in range(m):
            equal_term.append(omega[a, opt_idx, j] * KL_table[a, opt_idx, j])
            used_omega_indices.add((a, opt_idx, j))

        # Arms in DM2 contribute only through the constraint with the largest KL.
        for i in DM2[a]:
            jh = np.argmax(KL_table[a, i])
            equal_term.append(omega[a, i, jh] * KL_table[a, i, jh])
            used_omega_indices.add((a, i, jh))

    # All rate terms must equal the first one.
    for i in range(1, len(equal_term)):
        eq_list.append(equal_term[0] - equal_term[i])

    # Coordinates that appear in no term are pinned to zero.
    unused_indices = {(a, i, j) for a in range(s) for i in range(k) for j in range(m + 1)} - used_omega_indices
    for a, i, j in unused_indices:
        eq_list.append(omega[a, i, j])

    # The allocation sums to one.
    eq_list.append(np.sum(omega) - 1)
    return eq_list

def optimality_check(dist, s, m, D3, phi, F, X_mean, X_variance, opt_solution, ratio):
    """Check, arm by arm, how each arm of ``D3`` should be classified.

    For every arm ``i`` in ``D3[a]`` an objective quantity ``th1`` is compared
    with ``th2``, the largest constraint KL value of that arm. The arm is
    marked 0 when ``th1 <= th2`` and 1 otherwise.

    Args:
        dist: ``"Bernoulli"`` or ``"Gaussian"`` objective model.
        s, m: Numbers of tasks and constraints.
        D3: Per-task sets of arms to check.
        phi: Constraint thresholds, indexed by constraint.
        F: Constraint estimates, ``[task, arm, constraint]``.
        X_mean, X_variance: Objective means / variances, ``[task, arm]``.
        opt_solution: Selected arm per task.
        ratio: Candidate allocation, ``[task, arm, coordinate]``; the
            objective coordinate is at index ``m``.

    Returns:
        ``(opt_check, flag)`` where ``opt_check[a]`` lists the 0/1 marks in
        the iteration order of ``D3[a]`` and ``flag`` is True only when no
        arm was marked 0.
    """
    flag = True
    if dist == "Bernoulli":
        # Same clipping as the other Bernoulli routines in this file: a mean
        # of exactly 0 or 1 would make kl_divergence return NaN, and the NaN
        # comparison below would silently mark the arm as 1.
        X_mean = np.clip(X_mean, 0.001, 0.999)
    F = np.clip(F, 0.001, 0.999)

    # KL_table[a, i, j] = KL(F[a, i, j] || phi[j]).
    KL_table = np.array([kl_divergence(F[:, :, j], phi[j]) for j in range(m)]).transpose(1, 2, 0)

    opt_check = [[] for _ in range(s)]
    for a in range(s):
        opt_idx = opt_solution[a]
        opt_mean = X_mean[a, opt_idx]
        opt_variance = X_variance[a, opt_idx]
        opt_ratio = ratio[a, opt_idx, m]

        for i in D3[a]:
            mean_i = X_mean[a, i]
            variance_i = X_variance[a, i]
            ratio_i = ratio[a, i, m]
            ratio_sum = ratio_i + opt_ratio

            # Allocation-weighted average of the two means.
            lambda_ = (ratio_i * mean_i + opt_ratio * opt_mean) / ratio_sum

            if dist == "Gaussian":
                variance_sum = variance_i / ratio_i + opt_variance / opt_ratio
                th1 = (mean_i - lambda_)**2 / (2 * variance_sum)
            elif dist == "Bernoulli":
                th1 = kl_divergence(mean_i, lambda_)

            th2 = KL_table[a, i].max()

            if th1 <= th2:
                opt_check[a].append(0)
                flag = False
            else:
                opt_check[a].append(1)

    return opt_check, flag

def esr_optratio(dist, s, k, m, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count):
    """Compute the target allocation for the ESR sampling rule.

    Falls back to the uniform allocation when some task has no selected arm
    or no feasible competitor. Otherwise the system in ``equations`` is
    solved with ``D3`` merged into the objective set; if ``optimality_check``
    rejects that solution, each ``D3`` arm is moved to ``D2`` (mark 0) or
    ``D1`` (mark 1) and the system is solved again.

    Args:
        dist: ``"Bernoulli"`` or ``"Gaussian"`` objective model.
        s, k, m: Numbers of tasks, arms and constraints.
        feasibility, phi, F, X_mean, X_variance, opt_solution: Current
            estimates, forwarded to the helpers.
        alternative_count: Sample counts (not used by this rule).

    Returns:
        Array of shape ``(s, k, m + 1)``.
    """
    shape = (s, k, m + 1)
    tol_var = s * k * (m + 1)

    D1, D2, D3 = sub_solutions(s, opt_solution, X_mean, feasibility)
    if np.any(opt_solution == -1) or any(len(group) == 0 for group in D1):
        return np.full(shape, 1.0 / tol_var)

    initial_guess = np.ones(shape) * (1 / (s * k * (m + 1)))

    def solve(objective_sets, constraint_sets):
        # Root of the allocation system for the given arm classification.
        root = fsolve(equations, initial_guess,
                      args=(dist, s, k, m, objective_sets, constraint_sets,
                            phi, F, X_mean, X_variance, opt_solution))
        return np.reshape(root, shape)

    # First attempt: treat every D3 arm through the objective comparison.
    merged = [D1[a].union(D3[a]) for a in range(s)]
    ratio = solve(merged, D2)

    opt_check, flag = optimality_check(dist, s, m, D3, phi, F, X_mean, X_variance, opt_solution, ratio)
    if flag:
        return ratio

    # Reclassify D3 arms according to the per-arm marks and solve again.
    for a in range(s):
        for idx, arm in enumerate(D3[a]):
            target = D2 if opt_check[a][idx] == 0 else D1
            target[a].add(arm)
    return solve(D1, D2)
#######################################################################################################

#################################ASR sampling rule ####################################################

def asr_equations(s, k, m, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count):
    """Compute the closed-form target allocation for the ASR sampling rule.

    Falls back to the uniform allocation when some task has no selected arm
    or no feasible competitor. Otherwise every coordinate receives a score
    of the form variance / squared-gap, irrelevant coordinates are zeroed,
    and the scores are normalised to sum to one.

    Args:
        s, k, m: Numbers of tasks, arms and constraints.
        feasibility, phi, F, X_mean, X_variance, opt_solution: Current
            estimates.
        alternative_count: Sample counts (not used by this rule).

    Returns:
        Array of shape ``(s, k, m + 1)``.
    """
    tol_var = s * k * (m + 1)

    D1, D2, D3 = sub_solutions(s, opt_solution, X_mean, feasibility)
    # For simplicity all infeasible arms (D2 and D3) are treated alike.
    infeasible = [D2[a].union(D3[a]) for a in range(s)]
    if np.any(opt_solution == -1) or any(len(group) == 0 for group in D1):
        return np.full((s, k, m + 1), 1.0 / tol_var)

    F = np.clip(F, 0.001, 0.999)
    KL_table = np.array([kl_divergence(F[:, :, j], phi[j]) for j in range(m)]).transpose(1, 2, 0)

    # Constraint scores: Bernoulli variance over squared distance to phi.
    cons_score = (F * (1 - F)) / ((F - phi) ** 2)

    # Objective scores: variance over squared gap to the selected arm.
    tasks = np.arange(s)
    gap = X_mean[tasks, opt_solution][:, np.newaxis] - X_mean
    obj_score = X_variance / (gap ** 2)

    for a in range(s):
        # Feasible competitors are only sampled on the objective.
        cons_score[a, np.array(list(D1[a])), :] = 0

        bad_arms = infeasible[a]
        if len(bad_arms) > 0:
            bad_idx = np.array(list(bad_arms))
            obj_score[a, bad_idx] = 0
            # Exclude them from the runner-up search below.
            gap[a, bad_idx] = np.inf

        # Infeasible arms keep only the constraint with the largest KL.
        for i in bad_arms:
            jh = np.argmax(KL_table[a, i])
            cons_score[a, i, np.arange(m) != jh] = 0

    # The selected arm is scored against the second-smallest gap in its task.
    runner_up_gap = np.sort(gap, axis=1)[:, 1]
    obj_score[tasks, opt_solution] = X_variance[tasks, opt_solution] / runner_up_gap ** 2

    omega = np.concatenate((cons_score, obj_score[:, :, np.newaxis]), axis=2)
    return omega / np.sum(omega)
#######################################################################################################


###########################################FWS sampling rule###########################################
def _pairwise_objective_gradient(dist, mean_i, mean_opt, var_i, var_opt, w_i, w_opt):
    """Return the two gradient entries of the pairwise objective rate.

    The first value belongs to the challenger's objective coordinate, the
    second to the selected arm's objective coordinate. An unsupported
    ``dist`` yields ``(0.0, 0.0)``.
    """
    if dist == "Gaussian":
        term = (var_opt / w_opt + var_i / w_i) ** 2
        half_sq_gap = (mean_opt - mean_i) ** 2 / 2
        return (half_sq_gap * ((var_i / w_i ** 2) / term),
                half_sq_gap * ((var_opt / w_opt ** 2) / term))
    if dist == "Bernoulli":
        mu_bar = (w_opt * mean_opt + w_i * mean_i) / (w_opt + w_i)
        return kl_divergence(mean_i, mu_bar), kl_divergence(mean_opt, mu_bar)
    return 0.0, 0.0


def compute_f_df_standard_bai(dist, s, k, m, tol_var, opt_solution, D1, D2, D3, phi, track_ratio, curr_ratio, X_mean, X_variance, F, r):
    """Evaluate the rate functions, their gradients and the near-minimal set.

    Args:
        dist: ``"Bernoulli"`` or ``"Gaussian"`` objective model.
        s, k, m: Numbers of tasks, arms and constraints.
        tol_var: Length of a flattened allocation, ``s * k * (m + 1)``.
        opt_solution: Selected arm per task.
        D1, D2, D3: Per-task arm sets as returned by ``sub_solutions``.
        phi: Constraint thresholds, indexed by constraint.
        track_ratio: Unused; kept for interface compatibility.
        curr_ratio: Allocation at which values and gradients are evaluated,
            ``[task, arm, coordinate]``.
        X_mean, X_variance: Objective means / variances, ``[task, arm]``.
        F: Constraint estimates, ``[task, arm, constraint]``.
        r: Slack defining the near-minimal set.

    Returns:
        ``(f, df, fidx)``: ``f`` has shape ``(s, k, m + 1)`` (``inf`` where
        no rate function is defined), ``df`` has shape
        ``(s, k, m + 1, tol_var)`` and ``fidx`` is a list of index tuples.
        When ``r`` exceeds machine epsilon ``fidx`` holds every index with
        ``f < min(f) + r``; otherwise it holds the single argmin of ``f``.
    """
    if dist == "Bernoulli":
        X_mean = np.clip(X_mean, 0.001, 0.999)
    F = np.clip(F, 0.001, 0.999)

    # KL_table[a, i, j] = KL(F[a, i, j] || phi[j]).
    KL_table = np.array([kl_divergence(F[:, :, j], phi[j]) for j in range(m)]).transpose(1, 2, 0)

    df = np.zeros((s, k, m+1, tol_var))
    f = np.full((s, k, m+1), np.inf)

    def flat(a, i, j):
        # Position of coordinate (a, i, j) in the flattened allocation.
        return a * k * (m + 1) + i * (m + 1) + j

    def set_constraint_rate(a, i, j):
        # Rate omega * KL for one constraint coordinate.
        df[a, i, j, flat(a, i, j)] = KL_table[a, i, j]
        f[a, i, j] = curr_ratio[a, i, j] * KL_table[a, i, j]

    def set_objective_rate(a, i):
        # Pairwise objective rate between arm i and the selected arm.
        opt = opt_solution[a]
        w_i = curr_ratio[a, i, m]
        w_opt = curr_ratio[a, opt, m]
        d_i, d_opt = _pairwise_objective_gradient(
            dist, X_mean[a, i], X_mean[a, opt],
            X_variance[a, i], X_variance[a, opt], w_i, w_opt)
        df[a, i, m, flat(a, i, m)] = d_i
        df[a, i, m, flat(a, opt, m)] = d_opt
        f[a, i, m] = w_i * df[a, i, m, flat(a, i, m)] + w_opt * df[a, i, m, flat(a, opt, m)]

    for a in range(s):
        for j in range(m):
            set_constraint_rate(a, opt_solution[a], j)

        for i in D1[a]:
            set_objective_rate(a, i)

        for i in D2[a]:
            set_constraint_rate(a, i, np.argmax(KL_table[a, i]))

        for i in D3[a]:
            set_constraint_rate(a, i, np.argmax(KL_table[a, i]))
            set_objective_rate(a, i)

    fmin = np.min(f)

    if r > np.finfo(np.float64).eps:
        fidx = [tuple(idx) for idx in np.argwhere(f < fmin + r)]
    else:
        # np.argmin returns a flat scalar index; convert it to (a, i, j).
        # This branch also covers r == eps and negative r, which previously
        # left fidx undefined.
        fidx = [tuple(np.unravel_index(np.argmin(f), f.shape))]
    return f, df, fidx

def solveZeroSumGame(M_payoff, K, n_row, default):
    """Solve ``max_x min_j sum_k M_payoff[j][k] * x[k]`` over the simplex.

    The problem is posed as a linear program in Gurobi.

    Args:
        M_payoff: Payoff matrix indexed as ``M_payoff[row][column]``.
        K: Number of columns (length of the mixed strategy ``x``).
        n_row: Number of rows (opponent's pure strategies).
        default: Value returned when the solver fails or no solution can be
            read back.

    Returns:
        ``np.array`` of the strategy, or of ``default`` on failure.
    """
    columns = range(K)
    try:
        model = gp.Model()
        model.setParam('OutputFlag', 0)
        x = model.addVars(K, lb=0, name="x")
        w = model.addVar(name="w")

        # The game value w is a lower bound on the payoff of every row.
        for row in range(n_row):
            payoff = sum(M_payoff[row][col] * x[col] for col in columns)
            model.addConstr(payoff >= w)

        # x is a probability vector.
        model.addConstr(sum(x[col] for col in columns) == 1)

        model.setObjective(w, gp.GRB.MAXIMIZE)
        model.optimize()

        z = [x[col].x for col in columns]
    except (AttributeError, gp.GurobiError) as e:
        print(f"Error encountered: {e}. Returning default direction.")
        z = default
    return np.array(z)

def fws_ratio(dist, s, k, m, n0, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count, curr_ratio, track_ratio):
    """Advance the tracked allocation by one step of the FWS sampling rule.

    A direction is chosen and averaged into ``track_ratio`` with weight
    ``1 / t`` (``t`` = total number of samples so far). The direction is
    uniform when some task has no selected arm, when some task has no
    feasible competitor, or when ``floor(t / tol_var)`` is a perfect square.
    Otherwise it is a single coordinate (one near-minimal rate function) or
    the solution of a zero-sum game (several of them).

    Args:
        dist: ``"Bernoulli"`` or ``"Gaussian"`` objective model.
        s, k, m: Numbers of tasks, arms and constraints.
        n0: Unused; kept for interface compatibility.
        feasibility, phi, F, X_mean, X_variance, opt_solution: Current
            estimates.
        alternative_count: Sample counts, ``[task, arm, coordinate]``.
        curr_ratio: Allocation at which rates and gradients are evaluated.
        track_ratio: Allocation being tracked, same shape.

    Returns:
        Updated ``track_ratio`` of shape ``(s, k, m + 1)``.
    """
    t = np.sum(alternative_count)
    tol_var = s * k * (m + 1)
    r = t**(-9/10)/tol_var
    shape = (s, k, m + 1)

    D1, D2, D3 = sub_solutions(s, opt_solution, X_mean, feasibility)
    use_uniform = (
        np.any(opt_solution == -1)
        or any(len(group) == 0 for group in D1)
        or is_complete_square(np.floor(t / tol_var))
    )

    if use_uniform:
        direction = np.full(shape, 1.0 / tol_var)
    else:
        f, df, fidx = compute_f_df_standard_bai(dist, s, k, m, tol_var, opt_solution, D1, D2, D3, phi, track_ratio, curr_ratio, X_mean, X_variance, F, r)
        if len(fidx) == 1:
            a, i, j = fidx[0]
            best = opt_solution[a]
            # Move toward whichever coordinate carries the largest gradient
            # entry: the selected arm's or the challenger's own.
            steepest = np.argmax(df[a, i, j, :])
            direction = np.zeros(shape)
            if a * k * (m + 1) + best * (m + 1) + j == steepest:
                direction[a, best, j] = 1
            else:
                direction[a, i, j] = 1
        else:
            flat_track_ratio = track_ratio.reshape(-1)
            flat_df = df.reshape(s * k * (m + 1), tol_var)
            # Row l of Sigma is e_l - track_ratio.
            Sigma = np.eye(tol_var) - flat_track_ratio
            A = np.zeros((len(fidx), tol_var))
            for e, (a, i, j) in enumerate(fidx):
                gradient = flat_df[a * k * (m + 1) + i * (m + 1) + j, :].reshape(-1, 1)
                for l in range(tol_var):
                    A[e, l] = Sigma[l].reshape(-1, 1).T @ gradient
            fallback = np.full(shape, 1.0 / tol_var)
            direction = solveZeroSumGame(A, tol_var, len(fidx), fallback).reshape(s, k, m+1)

    return ((t-1)/t) * track_ratio + (1/t) * direction
#######################################################################################################

#########################################Sequential sampling rule########################################
def seqsample(dist, s, k, m, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count):
    """Choose the next (task, arm, coordinate) to sample sequentially.

    Args:
        dist: ``"Bernoulli"`` or ``"Gaussian"`` objective model.
        s, k, m: Numbers of tasks, arms and constraints.
        feasibility: 0/1 feasibility flags, ``[task, arm]``.
        phi: Constraint thresholds, indexed by constraint.
        F: Constraint estimates, ``[task, arm, constraint]``.
        X_mean, X_variance: Objective means / variances, ``[task, arm]``.
        opt_solution: Selected arm per task (-1 when none is feasible).
        alternative_count: Sample counts, ``[task, arm, coordinate]`` with the
            objective coordinate at index ``m``.

    Returns:
        Tuple ``(next_task, next_arm, next_cons)``. When some task has no
        selected arm or no feasible competitor, this is the least-sampled
        coordinate; otherwise it is the coordinate with the smallest score,
        possibly redirected to the selected arm of that task.
    """
    t = np.sum(alternative_count)
    tol_var = s*k*(m + 1)

    D1, D2, D3 = sub_solutions(s, opt_solution, X_mean, feasibility)
    D2 = [D2[i].union(D3[i]) for i in range(s)] #For simplicity, union the D2 and D3 (an approximate solution)
    if np.any(opt_solution == -1) or any(not subset for subset in D1):
        # Fallback: sample the coordinate with the fewest observations.
        min_indices = np.unravel_index(np.argmin(alternative_count, axis=None), alternative_count.shape)
        next_task, next_arm, next_cons = min_indices
    else:
        # Keep estimates away from 0/1 so the Bernoulli KL stays finite.
        if dist == "Bernoulli":
            X_mean = np.clip(X_mean, 0.001, 0.999)
        F = np.clip(F, 0.001, 0.999)

        # KL_table[a, i, j] = KL(F[a, i, j] || phi[j]).
        KL_table = np.array([kl_divergence(F[:, :, j], phi[j]) for j in range(m)]).transpose(1, 2, 0)

        # Empirical allocation and the resulting constraint scores.
        curr_ratio = alternative_count / np.sum(alternative_count)
        constrain_score = curr_ratio[:, :, :m] * KL_table

        # Gap between the selected arm's mean and every arm's mean, per task.
        temp_gap = X_mean[np.arange(s), opt_solution][:, np.newaxis] - X_mean
        Obj_score = np.zeros((s, k))
        scaled_var = X_variance / curr_ratio[:, :, m]

        # Pairwise objective score of every arm against the selected arm.
        for a in range(s):
            for i in range(k):
                if dist == "Gaussian":
                    Obj_score[a, i] = temp_gap[a, i] ** 2 / (2 * (scaled_var[a, i] + scaled_var[a, opt_solution[a]]))
                elif dist == "Bernoulli":
                    mu_bar = (curr_ratio[a, opt_solution[a], m] * X_mean[a, opt_solution[a]] + curr_ratio[a, i, m] * X_mean[
                        a, i]) / (curr_ratio[a, opt_solution[a], m] + curr_ratio[a, i, m])
                    Obj_score[a, i] = curr_ratio[a, i, m] * kl_divergence(X_mean[a,i], mu_bar) + curr_ratio[
                    a, opt_solution[a], m] * kl_divergence(X_mean[a,opt_solution[a]], mu_bar)

        # Mask (with inf) every coordinate that must not be picked by argmin.
        for a in range(s):
            # Feasible competitors are never chosen through a constraint.
            constrain_score[a, np.array(list(D1[a])), :] = np.inf
            Obj_score[a, opt_solution[a]] = np.inf
            if len(list(D2[a])) > 0:
                Obj_score[a, np.array(list(D2[a]))] = np.inf

            # Infeasible arms keep only the constraint with the largest KL.
            for i in D2[a]:
                jh = np.argmax(KL_table[a, i])
                constrain_score[a, i, [j for j in range(m) if j != jh]] = np.inf

        # Smallest score over all constraint and objective coordinates.
        score = np.concatenate((constrain_score, np.expand_dims(Obj_score,axis = -1)), axis=-1)
        next_task, next_arm, next_cons = np.unravel_index(np.argmin(score, axis=None), score.shape)

        # An objective comparison involves two arms: decide whether to sample
        # the challenger or the selected arm of that task.
        if next_arm in D1[next_task] and next_cons == m:
            if dist == "Gaussian":
                # Compare omega^2 / variance of the selected arm with the sum
                # of the same quantity over the feasible competitors.
                factors = curr_ratio[next_task, :, m]**2 / X_variance[next_task]
                f_1 = factors[opt_solution[next_task]]
                f_2 = sum(factors[i] for i in D1[next_task])
                next_arm = opt_solution[next_task] if f_1 < f_2 else next_arm
            elif dist == "Bernoulli":
                # Solve the allocation system and compare the selected arm's
                # share within the pair against its current share.
                initial_guess = np.ones((s, k, m + 1)) * (1 / (s * k * (m + 1)))
                ratio = np.reshape(fsolve(equations, initial_guess,
                                          args=(dist, s, k, m, D1, D2, phi, F, X_mean, X_variance, opt_solution)),
                                   (s, k, m + 1))
                direction = ratio[next_task, opt_solution[next_task], m] / (ratio[next_task, opt_solution[next_task], m] + ratio[next_task, next_arm, m])
                curr = curr_ratio[next_task, opt_solution[next_task], m] / (curr_ratio[next_task, opt_solution[next_task], m] + curr_ratio[next_task, next_arm, m])

                next_arm = opt_solution[next_task] if curr < direction else next_arm
    return next_task, next_arm, next_cons
#######################################################################################################