import numpy as np
def sample_cbic(K, gap, d, sigma, lam, epsilon, c, hetero=False):
    """Determine the number of samples for CBIC and epsilon-CBIC.

    The sample count is

        sigma^2 * d * K^3 / (C * phi_0 * (gap + epsilon)^2)

    rounded up to an integer, where ``C = c^2 / c_3``,
    ``c_3 = (c_1 * c_2 + c_2 * sigma^2 * d) / (sigma^2 * d)``, ``c_2 = 4``
    and ``c_1`` / ``phi_0`` depend on ``hetero`` (see below).  With
    ``epsilon == 0`` the formula reduces to the plain-``gap`` case.

    Parameters
    ----------
    K : int
        Number of arms (the heterogeneous case reads ``lam[0] .. lam[K-1]``).
        Must be positive.
    gap : float
        Gap term; enters the formula as ``(gap + epsilon) ** 2``.
    d : int or float
        Multiplies ``sigma ** 2`` everywhere it appears
        (presumably the feature dimension -- confirm with the caller).
    sigma : float
        Enters only as ``sigma ** 2`` (presumably the noise standard
        deviation -- confirm with the caller).
    lam : float or sequence of float
        A scalar when ``hetero`` is falsy; an indexable sequence with at
        least ``K`` entries (one per arm) when ``hetero`` is truthy.
    epsilon : float
        Slack added to ``gap``; ``0`` gives the plain CBIC sample count.
    c : float
        Constant entering as ``c ** 2`` in ``C``.
    hetero : bool, optional
        Falsy: same variance for all arms, ``c_1 = 0.1 * lam`` and
        ``phi_0 = 0.5``.  Truthy: different variance for each arm,
        ``c_1 = max_i 0.1 * lam[i]`` and ``phi_0 = lam[i]``; the largest
        per-arm sample count is returned.

    Returns
    -------
    tuple of (int, int)
        ``(n_bic, epoch_start)`` with
        ``epoch_start = ceil(log2(n_bic)) + 2``.  Both are built-in ``int``
        values regardless of ``hetero``.

    Raises
    ------
    ValueError
        If ``K`` is not positive.
    IndexError
        If ``hetero`` is truthy and ``lam`` has fewer than ``K`` entries.
    """
    if K <= 0:
        # Non-positive K never produced a usable result; fail with a clear
        # message instead of an UnboundLocalError / OverflowError downstream.
        raise ValueError("K must be positive, got %r" % (K,))

    if hetero:
        # Read exactly the first K entries, one per arm.
        arm_lams = np.asarray([lam[i] for i in range(K)], dtype=float)
        c_1 = np.max(0.1 * arm_lams)
        phi_0 = arm_lams
    else:
        c_1 = 0.1 * lam
        phi_0 = 0.5

    c_2 = 4
    sigma_sq = np.power(sigma, 2)
    c_3 = (c_1 * c_2 + c_2 * sigma_sq * d) / (sigma_sq * d)
    C = np.power(c, 2) / c_3

    # gap + 0 == gap, so the epsilon == 0 case needs no separate branch.
    gap_sq = np.power(gap + epsilon, 2)
    samples = sigma_sq * d * np.power(K, 3) / (C * phi_0 * gap_sq)

    if hetero:
        # ceil is monotone, so the max of the per-arm ceilings equals the
        # ceiling of the per-arm max.
        samples = np.max(samples)

    n_bic = int(np.ceil(samples))
    epoch_start = int(np.ceil(np.log2(n_bic))) + 2
    return n_bic, epoch_start


