import numpy as np
import matplotlib.pyplot as plt

def undersampled(K, total_sample, alternative_count):
    """Return the indices of alternatives whose count is below a threshold.

    The threshold is ``sqrt(total_sample) - K / 2``; an alternative ``i`` is
    reported when ``alternative_count[i]`` is strictly smaller than it.

    Parameters
    ----------
    K : number
        Enters the threshold as ``-K / 2`` (presumably the number of
        alternatives — confirm against callers).
    total_sample : number
        Enters the threshold as ``sqrt(total_sample)``.
    alternative_count : array-like
        Per-alternative counts. Lists are accepted as well as numpy arrays.

    Returns
    -------
    list of int
        Positions (along the first axis) where the count is below the
        threshold, in increasing order.
    """
    # asarray lets plain Python lists work; an ndarray passes through unchanged.
    counts = np.asarray(alternative_count)
    threshold = np.sqrt(total_sample) - K / 2
    mask = counts < threshold
    return np.nonzero(mask)[0].tolist()

def beta(t, delta, rho):
    """Compute ``log((log(t) + 1) / ((1 - rho)**3 * delta))``.

    NOTE(review): presumably the threshold that the ``glrt`` statistic is
    compared against — confirm against callers.

    Parameters
    ----------
    t : number or array
        Enters as ``log(t) + 1`` in the numerator.
    delta : number
        Multiplies ``(1 - rho)**3`` in the denominator.
    rho : number
        Enters the denominator as ``(1 - rho)**3``; ``rho == 1`` divides by
        zero.
    """
    numerator = np.log(t) + 1
    denominator = delta * (1 - rho) ** 3
    return np.log(numerator / denominator)

def glrt(k, mean_estimation, variance, rho, sampling_ratio):
    """Return the smallest pairwise statistic between the leader and the rest.

    The leader is ``argmax(mean_estimation)``. For every other alternative
    ``i`` the statistic is::

        (mean[i] - mean[best])**2 / (2 * (own + other))

    where ``own`` and ``other`` depend on which of the two has the larger
    ``sampling_ratio``:

    * ``sampling_ratio[best] > sampling_ratio[i]``:
      ``own = variance[i] * (1 - rho**2) / ratio[i]`` and
      ``other = variance[best] * (rho * r - 1)**2 / ratio[best]``
    * otherwise:
      ``own = variance[i] * (rho / r - 1)**2 / ratio[i]`` and
      ``other = variance[best] * (1 - rho**2) / ratio[best]``

    with ``r = sqrt(variance[i] / variance[best])``. The leader itself
    contributes ``inf``, so ``k == 1`` returns ``inf``.

    NOTE(review): ``rho`` looks like a correlation coefficient between
    alternatives — confirm with the caller.
    """
    best = np.argmax(mean_estimation)
    stats = np.zeros(k)
    stats[best] = np.inf
    mu_best = mean_estimation[best]
    var_best = variance[best]
    for idx in range(k):
        squared_gap = (mean_estimation[idx] - mu_best) ** 2
        sd_ratio = np.sqrt(variance[idx] / var_best)
        if idx == best:
            continue
        var_idx = variance[idx]
        w_idx = sampling_ratio[idx]
        w_best = sampling_ratio[best]
        if w_best > w_idx:
            own = var_idx * (1 - rho ** 2) / w_idx
            other = var_best * (rho * sd_ratio - 1) ** 2 / w_best
        else:
            own = var_idx * (rho / sd_ratio - 1) ** 2 / w_idx
            other = var_best * (1 - rho ** 2) / w_best
        stats[idx] = squared_gap / (2 * (own + other))
    return np.min(stats)

def update_ratio(alternative_count):
    """Normalise per-alternative counts into fractions of the total.

    Parameters
    ----------
    alternative_count : array-like
        Per-alternative counts. Lists are accepted as well as numpy arrays.

    Returns
    -------
    numpy.ndarray
        ``alternative_count / sum(alternative_count)`` when the sum is
        positive. Otherwise a float array of zeros with the same shape, which
        avoids a 0/0 division.
    """
    # asarray lets plain Python lists work; an ndarray passes through unchanged.
    counts = np.asarray(alternative_count)
    total_count = np.sum(counts)
    if total_count > 0:
        return counts / total_count
    return np.zeros_like(counts, dtype=float)

def ratio_figure(ratio_hist, policy, step=30000, opt_ratio=1 / 3):
    """Plot the sampling-ratio history and save it to ``./Figure``.

    Every ``step``-th row of ``ratio_hist`` is drawn at x = row index, on an
    axis labelled "Number of Samples"; column ``i`` becomes the line
    ``omega_{i+1}``. A dashed black line labelled "OPT" is drawn at
    ``opt_ratio``. The figure is written to
    ``./Figure/sampling_ratio_{policy}.pdf`` (creating the directory if
    needed), shown, and then closed.

    Note: this updates the global ``plt.rcParams`` (font sizes and family),
    which also affects later plots.

    Parameters
    ----------
    ratio_hist : array-like, 2-D
        Shape ``(t, k)``: ``t`` rows of ``k`` per-alternative ratios.
    policy : object
        Interpolated into the output file name.
    step : int, default 30000
        Row sub-sampling stride for the plotted markers. Must be >= 1.
    opt_ratio : float, default 1/3
        Height of the "OPT" reference line.

    Raises
    ------
    ValueError
        If ``step`` < 1 or ``ratio_hist`` is not 2-D.
    """
    from pathlib import Path

    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")

    plt.rcParams.update({
        'font.size': 20,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20,
        'legend.fontsize': 20,
        'font.family': 'DejaVu Sans'
    })

    ratio_hist = np.asarray(ratio_hist)
    t, k = ratio_hist.shape

    fig, ax = plt.subplots(figsize=(10, 6), facecolor='white')
    ax.set_facecolor('white')

    linestyles = ['-', '--', '-.']
    # The first three colours match the original palette. The rest are the
    # remaining tab10 colours, so k > 3 no longer raises IndexError.
    colors = ["#d62728", "#9467bd", "#1f77b4", "#ff7f0e", "#2ca02c",
              "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]

    x = np.arange(0, t, step)
    for i in range(k):
        ax.plot(
            x,
            ratio_hist[::step, i],
            # Doubled braces emit a literal {..} so multi-digit indices stay
            # in the subscript (e.g. $\omega_{10}$, not $\omega_10$).
            label=fr'$\omega_{{{i + 1}}}$',
            marker='s',
            markersize=10,
            color=colors[i % len(colors)],
            linestyle=linestyles[i % len(linestyles)],
            linewidth=2.0
        )

    ax.axhline(y=opt_ratio, color='black', linestyle='--', linewidth=1.2, label="OPT")

    ax.set_title("Change of Sampling Ratio Over Time")
    ax.set_xlabel("Number of Samples")
    ax.set_ylabel("Sampling Ratio")

    ax.minorticks_on()
    ax.set_axisbelow(True)
    ax.grid(which='major', color='gray', linestyle='-', linewidth=0.6, alpha=0.6)
    ax.grid(which='minor', color='lightgray', linestyle='--', linewidth=0.4, alpha=0.4)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(1.0)

    ax.legend(frameon=False, loc='lower right')

    fig.tight_layout()

    out_path = Path("./Figure") / f"sampling_ratio_{policy}.pdf"
    # savefig does not create missing directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=1200, bbox_inches='tight')
    plt.show()
    # Release the figure; otherwise non-interactive backends keep
    # accumulating open figures across calls.
    plt.close(fig)

