import math
import operator

import numpy as np

def feasibility_check(M, K, g, b):
    """Flag which (context, design) pairs satisfy the constraint bound ``b``.

    Parameters
    ----------
    M, K : int
        Number of contexts (rows) and designs (columns) to inspect; only the
        leading ``M x K`` block of ``g`` is read.
    g : array-like
        Constraint values indexed ``g[j, i]``.
    b : float
        Upper bound on the constraint.

    Returns
    -------
    ndarray of int, shape (M, K)
        0 exactly where ``g[j, i] > b``, 1 everywhere else (NaN entries of
        ``g`` therefore count as feasible).
    """
    flags = np.ones((M, K), dtype=int)
    violated = np.asarray(g)[:M, :K] > b
    flags[violated] = 0
    return flags

def optimal_solution(M, f, feasibility):
    """Return, for each context, the index of the best feasible design.

    A design is eligible when ``feasibility[j, i] == 1``; among eligible
    designs the one with the largest ``f[j, i]`` wins (first index on ties).
    Contexts whose feasibility row is entirely 0 get the sentinel -1.

    Returns
    -------
    ndarray of int, shape (M,)
    """
    best = np.zeros(M, dtype=int)
    for j in range(M):
        row_flags = feasibility[j, :]
        if not np.any(row_flags != 0):
            # No feasible design in this context.
            best[j] = -1
            continue
        # Mask infeasible designs with -inf so argmax ignores them.
        masked_means = np.where(row_flags == 1, f[j, :], -np.inf)
        best[j] = np.argmax(masked_means)
    return best

def solutions_class(M, opt_solution, f, feasibility):
    """Partition each context's designs relative to its current best design.

    For context j with best design ``opt = opt_solution[j]``:

    * ``E1[j]`` holds the feasible designs (``feasibility[j] == 1``) other
      than ``opt``.
    * ``E2[j]`` holds ``opt`` together with every infeasible design
      (``feasibility[j] == 0``) whose mean compares either above or at-or-below
      ``f[j, opt]``; infeasible designs whose comparison is undefined (NaN
      means) are left out.

    NOTE(review): if ``opt_solution[j]`` is -1 (the "no feasible design"
    sentinel produced by ``optimal_solution``), ``f[j, -1]`` reads the last
    column and -1 is placed in ``E2[j]`` — confirm callers guard this.

    Returns
    -------
    (list of set, list of set)
        ``E1`` and ``E2``, each of length M.
    """
    columns = np.arange(f.shape[1])
    E1 = []
    E2 = []
    for j in range(M):
        best = opt_solution[j]
        ref_mean = f[j, best]
        is_feasible = feasibility[j] == 1
        is_infeasible = feasibility[j] == 0

        E1.append(set(np.flatnonzero(is_feasible & (columns != best))))

        above = np.flatnonzero(is_infeasible & (f[j] > ref_mean))
        at_or_below = np.flatnonzero(is_infeasible & (f[j] <= ref_mean))
        E2.append(set(above) | set(at_or_below) | {best})
    return E1, E2

def index_choice(lambda_, eta=0.05):
    """Randomly pick an index whose weight exceeds ``eta``.

    Parameters
    ----------
    lambda_ : array-like
        Current weights (flattened before selection).
    eta : float, optional
        Minimum weight an index must exceed to be eligible (default 0.05).

    Returns
    -------
    int
        A uniformly random eligible flat index. If no weight exceeds ``eta``
        (e.g. a uniform vector over more than ``1/eta`` entries), the choice
        falls back to the indices holding the largest weight, instead of
        letting ``np.random.choice`` fail on an empty candidate set.
    """
    weights = np.asarray(lambda_)
    candidates = np.flatnonzero(weights > eta)
    if candidates.size == 0:
        candidates = np.flatnonzero(weights == np.max(weights))
    return np.random.choice(candidates)

def gradient_func(lambda_, K, M, d, b, f, g, opt_solution, feasibility, phi, Z_mat, Temp, design_var):
    """Score every (context, design) pair and differentiate the dual objective.

    For context j with best design ``opt = opt_solution[j]`` and design i:

    * if i is a feasible non-best design (``E1[j]`` from ``solutions_class``),
      ``Z_mat[j, i]`` is overwritten with ``(phi[j, i] - phi[j, opt]) @ Temp``
      and ``S[j, i, h] = design_var[h] * Z_mat[j, i, h]**2 / (f[j, i] - f[j, opt])**2``;
    * otherwise (infeasible designs and ``opt`` itself)
      ``S[j, i, h] = design_var[h] * Z_mat[j, i, h]**2 / (b - g[j, i])**2``,
      using whatever ``Z_mat[j, i]`` already holds.

    Side effect: ``Z_mat`` is modified in place for the E1 pairs.
    A zero gap (equal means, or ``g == b``) yields inf/NaN scores.

    The returned gradient is that of ``-sum_h sqrt(sum_{j,i} lambda[j, i] * S[j, i, h])``
    (the objective computed by ``compute_objective``), i.e.
    ``-0.5 * sum_h S[j, i, h] / sqrt(sum_{j,i} lambda[j, i] * S[j, i, h])``.

    Returns
    -------
    gradient : ndarray, shape (M * K,)
        Entry ``j * K + i`` is the partial derivative w.r.t. ``lambda_[j * K + i]``.
    S_score : ndarray, shape (M, K, d)
    """
    E1, _ = solutions_class(M, opt_solution, f, feasibility)
    S_score = np.zeros((M, K, d))
    dim_var = np.asarray(design_var)[:d]

    for j in range(M):
        opt = opt_solution[j]
        for i in range(K):
            if i in E1[j]:
                # Optimality comparison against the current best design.
                Z_mat[j, i] = (phi[j, i] - phi[j, opt]).T @ Temp
                gap_sq = (f[j, i] - f[j, opt]) ** 2
            else:
                # Feasibility comparison: slack to the constraint bound b.
                gap_sq = (b - g[j, i]) ** 2
            S_score[j, i, :] = dim_var * Z_mat[j, i, :d] ** 2 / gap_sq

    weights = np.asarray(lambda_).reshape(M, K)
    weighted_sum = np.sum(S_score * weights[..., np.newaxis], axis=(0, 1))
    # d/dx sqrt(x) = 1 / (2 sqrt(x)): the denominator must be the square root
    # of the weighted sum (floored exactly as compute_objective floors it).
    sqrt_denominator = np.sqrt(np.maximum(weighted_sum, 1e-10))
    # Row-major flatten of the (M, K) array puts (j, i) at index j * K + i.
    gradient = (-0.5 * np.sum(S_score / sqrt_denominator, axis=2)).reshape(M * K)
    return gradient, S_score

def compute_descent_direction(lambda_, gradient, h):
    """Pick the best pairwise weight-transfer direction involving index ``h``.

    Each candidate moves weight between ``h`` and another index ``i``: either
    from ``h`` to ``i`` (``d[i] = +1, d[h] = -1``) or, when ``lambda_[i] > 0``,
    from ``i`` to ``h`` (``d[h] = +1, d[i] = -1``). A candidate is scored by
    ``s_max * (gradient . d)``, where ``s_max`` is the donor's current weight if
    the directional derivative is negative and 0 otherwise. The smallest score
    wins; ties keep the first candidate in (i ascending, h->i before i->h) order.

    Parameters
    ----------
    lambda_ : sequence of float
        Current weights.
    gradient : sequence of float
        Gradient of the objective with respect to ``lambda_``.
    h : int
        Index taking part in every candidate direction.

    Returns
    -------
    best_d : ndarray or None
        Winning direction, or None if there is no candidate (``len(lambda_) <= 1``).
    step_max : float or None
        Largest step along ``best_d`` keeping the donor weight non-negative
        (0 when ``best_d`` is not a descent direction).
    best_grad_T_d : float or None
        Directional derivative ``gradient . best_d``.
    best_value : float
        Score of the winner (``np.inf`` when there was no candidate).
    """
    n = len(lambda_)
    best_value = np.inf
    best_pair = None  # (donor, receiver) of the winning move
    step_max = None
    best_grad_T_d = None

    for i in range(n):
        if i == h:
            continue
        for donor, receiver in ((h, i), (i, h)):
            if donor == i and lambda_[i] <= 0:
                continue  # i has no weight to give away
            # gradient . (e_receiver - e_donor) in O(1); a full dot product with a
            # mostly-zero vector is O(n) and turns into NaN (0 * inf) whenever any
            # unrelated gradient entry is infinite.
            grad_T_d = gradient[receiver] - gradient[donor]
            s_max = 0 if grad_T_d >= 0 else lambda_[donor]
            current_value = s_max * grad_T_d
            if current_value < best_value:
                best_value = current_value
                best_pair = (donor, receiver)
                step_max = s_max
                best_grad_T_d = grad_T_d

    if best_pair is None:
        return None, step_max, best_grad_T_d, best_value

    donor, receiver = best_pair
    best_d = np.zeros(n)
    best_d[receiver] = 1
    best_d[donor] = -1
    return best_d, step_max, best_grad_T_d, best_value

def descent_condition(t, grad_T_d, best_value):
    """Decide whether a candidate direction is an acceptable descent step at iteration t.

    With ``kappa = 0.01`` and ``r = log(t) / t``, both of these must hold:

    * ``grad_T_d   < max(-kappa, -r ** (1/4))``
    * ``best_value < max(-kappa, -r ** (1/2))``

    Returns
    -------
    bool
    """
    kappa = 0.01
    rate = np.log(t) / t
    if not grad_T_d < max(-kappa, -(rate ** (1 / 4))):
        return False
    return bool(best_value < max(-kappa, -(rate ** (1 / 2))))

def compute_objective(lambda_, S_score):
    """Evaluate the dual objective and the induced per-dimension ratio.

    ``S_score`` is indexed ``[context, design, dimension]`` and ``lambda_``
    holds one weight per (context, design) pair in row-major order. For each
    dimension h, ``sum_{j,i} lambda[j, i] * S[j, i, h]`` is floored at 1e-10
    and square-rooted.

    Returns
    -------
    objective : float
        Negative sum of those square roots.
    omega : ndarray
        The square roots normalised to sum to one.
    """
    n_ctx, n_des, _ = S_score.shape
    weights = lambda_.reshape(n_ctx, n_des)[:, :, np.newaxis]
    per_dim = np.sum(weights * S_score, axis=(0, 1))
    roots = np.sqrt(np.maximum(per_dim, 1e-10))
    total = np.sum(roots)
    return -total, roots / np.maximum(total, 1e-10)

def dual_sampling_ratio(t, lambda_, M, K, d, b, f, g, opt_solution, feasibility, phi, Z_mat, Temp, design_var, flag):
    """Take one update of the dual weights and return the resulting ratio.

    If ``flag`` is truthy the weights are reset to uniform ``1 / (K * M)``.
    Otherwise a pivot ``h`` is drawn with ``index_choice``, and the best
    pairwise direction is found with ``compute_descent_direction``. A step is
    taken only when ``descent_condition`` accepts it. The step size is
    ``min(0.01, s_max)``, so no weight can be pushed below zero.

    Note: ``gradient_func`` updates ``Z_mat`` in place.

    Returns
    -------
    lambda_ : ndarray
        Updated weights.
    ratio : ndarray
        The ``omega`` returned by ``compute_objective`` for the updated weights.
    """
    gradient, S_score = gradient_func(lambda_, K, M, d, b, f, g, opt_solution, feasibility, phi, Z_mat, Temp, design_var)
    if flag:
        lambda_ = np.ones(K * M) / (K * M)
    else:
        # The pivot is only needed for the descent step; drawing it when flag
        # discards it wastes a random draw and can fail on degenerate weights.
        h = index_choice(lambda_)
        best_d, s_max, grad_T_d, best_value = compute_descent_direction(lambda_, gradient, h)
        if best_d is not None and descent_condition(t, grad_T_d, best_value):
            # s_max is the donor's remaining weight; a fixed 0.01 step could
            # overshoot it and make that weight negative.
            step_size = min(0.01, s_max)
            lambda_ = lambda_ + step_size * best_d
    _, ratio = compute_objective(lambda_, S_score)

    return lambda_, ratio

def kl_divergence(p, q):
    """Bernoulli KL divergence KL(Ber(p) || Ber(q)).

    Uses the convention ``0 * log(0 / x) = 0``, so ``p = 0`` and ``p = 1``
    give finite values where the naive formula yields NaN. Accepts scalars or
    arrays (broadcast together); scalar inputs return a NumPy scalar.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # np.where evaluates both branches; silence the 0*log(0) warnings from
    # the branch that is discarded.
    with np.errstate(divide='ignore', invalid='ignore'):
        pos_term = np.where(p == 0, 0.0, p * np.log(p / q))
        neg_term = np.where(p == 1, 0.0, (1 - p) * np.log((1 - p) / (1 - q)))
    return (pos_term + neg_term)[()]

def beta(t, delta):
    """Return the threshold ``log((log(t) + 1) / delta)``."""
    return np.log((np.log(t) + 1) / delta)

def indicator(a, b):
    """Return 1 when ``a <= b`` holds, otherwise 0."""
    return 1 if a <= b else 0

def is_complete_square(n):
    """Return whether ``n`` is a perfect square.

    Integer inputs (Python or NumPy ints) are tested exactly with
    ``math.isqrt``. The float-sqrt test misclassifies large integers: for
    example, ``10**16 + 2 * 10**8`` is reported as a square because the
    rounded square of ``floor(sqrt(n))`` happens to equal ``n``. Negative
    integers are never squares. Non-integral inputs such as floats or arrays
    keep the original floating-point test.
    """
    try:
        k = operator.index(n)
    except TypeError:
        p = np.floor(np.sqrt(n))
        return p * p == n
    if k < 0:
        return False
    r = math.isqrt(k)
    return r * r == k

def glrt(M, K, b, f, g, opt_solution, feasibility, reg_variance, diff_reg_variance):
    """Return the smallest pairwise test statistic over all (context, design) pairs.

    For context j with best design ``opt = opt_solution[j]``:

    * feasible non-best designs i (``E1[j]`` from ``solutions_class``)
      contribute ``(f[j, i] - f[j, opt])**2 / (2 * diff_reg_variance[j, i])``;
    * every other design (infeasible designs and ``opt`` itself) contributes
      ``(b - g[j, i])**2 / (2 * reg_variance[j, i])``.

    Returns 0 when some context has no feasible design, i.e. when
    ``opt_solution`` contains the -1 sentinel.
    """
    # Check the sentinel first: otherwise the partition would be built with
    # f[j, -1] for nothing. asarray makes the test work for plain lists too.
    if np.any(np.asarray(opt_solution) == -1):
        return 0

    E1, _ = solutions_class(M, opt_solution, f, feasibility)
    V = np.zeros((M, K))
    for j in range(M):
        opt = opt_solution[j]
        for i in range(K):
            if i in E1[j]:
                V[j, i] = (f[j, i] - f[j, opt]) ** 2 / (2 * diff_reg_variance[j, i])
            else:
                V[j, i] = (b - g[j, i]) ** 2 / (2 * reg_variance[j, i])
    return np.min(V)

def undersampled(d, total_sample, alternative_count, design_indices):
    """List the design pairs whose sample count is below ``sqrt(total_sample) - d / 2``.

    Parameters
    ----------
    d : int or float
        Offset term; the threshold is ``sqrt(total_sample) - d / 2``.
    total_sample : int or float
        Total number of samples taken so far.
    alternative_count : ndarray
        Per-pair sample counts indexed ``alternative_count[context, design]``.
    design_indices : sequence of (context, design) pairs
        Pairs to check.

    Returns
    -------
    list of [int, int]
        Pairs from ``design_indices`` whose count is strictly below the
        threshold, in input order. An empty input gives ``[]``.
    """
    # An empty input would become a shape-(0,) array, and ``[:, 0]`` on it
    # raises instead of returning nothing.
    indices = np.asarray(design_indices)
    if indices.size == 0:
        return []

    threshold = np.sqrt(total_sample) - d / 2
    counts = alternative_count[indices[:, 0], indices[:, 1]]
    return indices[counts < threshold].tolist()


def seq_ofsr_sampling(EXP, K, M, f, alternative_count, phi, Temp, design_indices):
    """Choose the next (context, design) pair to sample from optimality scores.

    The current best design of context j is ``argmax(f[j])``. Every other design
    i scores

        v @ Temp @ Sigma @ Temp.T @ v / (f[j, i] - f[j, best])**2,

    where ``v = phi[j, i] - phi[j, best]``. ``Sigma`` is the diagonal of
    ``EXP.variance / alternative_count`` over ``design_indices``. Best designs
    get -inf. Among ``design_indices``, a pair with the top score is returned;
    ties are broken uniformly at random with ``np.random.choice``.

    Returns
    -------
    (next_context, next_design)
        An element of ``design_indices``.
    """
    best = np.argmax(f, axis=1)
    rows, cols = zip(*design_indices)
    Sigma = np.diag(EXP.variance[rows, cols] / alternative_count[rows, cols])

    Score = np.zeros((M, K))
    for j in range(M):
        anchor = best[j]
        for i in range(K):
            if i == anchor:
                Score[j, i] = -np.inf
                continue
            direction = phi[j, i] - phi[j, anchor]
            gap_sq = (f[j, i] - f[j, anchor]) ** 2
            Score[j, i] = direction.T @ Temp @ Sigma @ Temp.T @ direction / gap_sq

    candidate_scores = Score[rows, cols]
    # Infinite entries (the -inf of best designs, or +inf from a zero gap) are
    # replaced by the current maximum, so best designs tie with the top score.
    # NOTE(review): the original named this value "min_ele" but takes the
    # max — confirm that best designs are meant to be eligible here.
    candidate_scores[np.isinf(candidate_scores)] = np.max(candidate_scores)
    top = np.max(candidate_scores)
    winners = np.flatnonzero(candidate_scores == top)
    pick = np.random.choice(winners)
    next_context, next_design = design_indices[pick]
    return next_context, next_design

def afosr_sampling(EXP, K, M, b, f, g, alternative_count, phi, Temp, design_indices):
    """Choose the next (context, design) pair to sample from feasibility scores.

    The current best design of context j is ``argmax(f[j])``. Every other design
    i scores

        phi[j, i] @ Temp @ Sigma @ Temp.T @ phi[j, i] / (b - g[j, i])**2,

    where ``Sigma`` is the diagonal of ``EXP.variance / alternative_count`` over
    ``design_indices``. Best designs get -inf. Among ``design_indices``, a pair
    with the top score is returned; ties are broken uniformly at random with
    ``np.random.choice``.

    Returns
    -------
    (next_context, next_design)
        An element of ``design_indices``.
    """
    best = np.argmax(f, axis=1)
    rows, cols = zip(*design_indices)
    Sigma = np.diag(EXP.variance[rows, cols] / alternative_count[rows, cols])

    Score = np.zeros((M, K))
    for j in range(M):
        anchor = best[j]
        for i in range(K):
            if i == anchor:
                Score[j, i] = -np.inf
                continue
            features = phi[j, i]
            slack_sq = (b - g[j, i]) ** 2
            Score[j, i] = features.T @ Temp @ Sigma @ Temp.T @ features / slack_sq

    candidate_scores = Score[rows, cols]
    # Infinite entries (the -inf of best designs, or +inf from zero slack) are
    # replaced by the current maximum, so best designs tie with the top score.
    # NOTE(review): the original named this value "min_ele" but takes the
    # max — confirm that best designs are meant to be eligible here.
    candidate_scores[np.isinf(candidate_scores)] = np.max(candidate_scores)
    top = np.max(candidate_scores)
    winners = np.flatnonzero(candidate_scores == top)
    pick = np.random.choice(winners)
    next_context, next_design = design_indices[pick]
    return next_context, next_design