import warnings
import numpy as np
from util import sub_solutions, undersampled, track, esr_optratio, asr_equations, seqsample, fws_ratio

# Installs a catch-all "ignore" filter at import time: merely importing this
# module suppresses every warning category, from every module, process-wide
# (unless a filter added later overrides it).
# NOTE(review): presumably intended to hide NumPy divide/invalid RuntimeWarnings
# raised by the ratio computations below — confirm; a narrower filter
# (specific category/module) would avoid masking unrelated warnings.
warnings.filterwarnings("ignore")

def USR(alternative_count):
    """Return the index triple of the smallest entry of ``alternative_count``.

    Args:
        alternative_count: 3-D NumPy array. The result is unpacked into
            exactly three indices, so any other rank raises ``ValueError``.

    Returns:
        Tuple ``(next_task, next_arm, next_cons)`` locating the minimum
        entry. On ties the first minimum in flattened (row-major) order wins.
    """
    # argmin with no axis works on the flattened array; map that flat
    # position back to one index per dimension.
    flat_pos = np.argmin(alternative_count)
    task_idx, arm_idx, cons_idx = np.unravel_index(flat_pos, alternative_count.shape)
    return task_idx, arm_idx, cons_idx

def ESR(dist, s, k, m, n0, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count):
    """Choose the next ``(task, arm, cons)`` index triple to sample (ESR rule).

    The choice is made in three stages, the first that applies wins:

    1. Warm-up: if any entry of ``alternative_count`` is below ``n0``,
       return the index of the smallest entry.
    2. Undersampled: if ``undersampled(...)`` flags any entry, return the
       index of the smallest count among the flagged entries.
    3. Tracking: compute a target ratio with ``esr_optratio(...)`` and let
       ``track(...)`` pick the index.

    Args:
        dist, s, k, m, feasibility, phi, F, X_mean, X_variance, opt_solution:
            Forwarded unchanged to the ``util`` helpers; their meaning is
            defined there.
        n0: Threshold every entry of ``alternative_count`` must reach before
            stages 2 and 3 are considered; also forwarded to ``undersampled``.
        alternative_count: 3-D NumPy array of per-alternative counts
            (presumably indexed ``(task, arm, constraint)`` — confirm).

    Returns:
        Tuple ``(next_task, next_arm, next_cons)``.
    """
    # Stage 1: warm-up. Nothing beyond the raw counts is needed here, so the
    # undersampling mask is deliberately not computed on this path.
    if np.any(alternative_count < n0):
        flat_pos = np.argmin(alternative_count, axis=None)
        next_task, next_arm, next_cons = np.unravel_index(flat_pos, alternative_count.shape)
        return next_task, next_arm, next_cons

    # NOTE(review): deferring this call assumes undersampled() has no side
    # effects — confirm against util.
    total_sample = np.sum(alternative_count)
    undersample = undersampled(s, k, m, n0, total_sample, alternative_count)

    # Stage 2: among flagged entries, take the least-sampled one. Unflagged
    # entries are masked with +inf so argmin can never select them.
    if np.any(undersample):
        under_count = np.where(undersample, alternative_count, np.inf)
        flat_pos = np.argmin(under_count, axis=None)
        next_task, next_arm, next_cons = np.unravel_index(flat_pos, under_count.shape)
        return next_task, next_arm, next_cons

    # Stage 3: follow the ratio produced by the ESR helper.
    optimal_ratio = esr_optratio(dist, s, k, m, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count)
    next_task, next_arm, next_cons = track(optimal_ratio, alternative_count)
    return next_task, next_arm, next_cons

def ASR(s, k, m, n0, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count):
    """Choose the next ``(task, arm, cons)`` index triple to sample (ASR rule).

    The choice is made in three stages, the first that applies wins:

    1. Warm-up: if any entry of ``alternative_count`` is below ``n0``,
       return the index of the smallest entry.
    2. Undersampled: if ``undersampled(...)`` flags any entry, return the
       index of the smallest count among the flagged entries.
    3. Tracking: compute a target ratio with ``asr_equations(...)`` and let
       ``track(...)`` pick the index.

    Args:
        s, k, m, feasibility, phi, F, X_mean, X_variance, opt_solution:
            Forwarded unchanged to the ``util`` helpers; their meaning is
            defined there.
        n0: Threshold every entry of ``alternative_count`` must reach before
            stages 2 and 3 are considered; also forwarded to ``undersampled``.
        alternative_count: 3-D NumPy array of per-alternative counts
            (presumably indexed ``(task, arm, constraint)`` — confirm).

    Returns:
        Tuple ``(next_task, next_arm, next_cons)``.
    """
    # Stage 1: warm-up. Only the raw counts matter here, so the
    # undersampling mask is deliberately not computed on this path.
    if np.any(alternative_count < n0):
        flat_pos = np.argmin(alternative_count, axis=None)
        next_task, next_arm, next_cons = np.unravel_index(flat_pos, alternative_count.shape)
        return next_task, next_arm, next_cons

    # NOTE(review): deferring this call assumes undersampled() has no side
    # effects — confirm against util.
    total_sample = np.sum(alternative_count)
    undersample = undersampled(s, k, m, n0, total_sample, alternative_count)

    # Stage 2: among flagged entries, take the least-sampled one. Unflagged
    # entries are masked with +inf so argmin can never select them.
    if np.any(undersample):
        under_count = np.where(undersample, alternative_count, np.inf)
        flat_pos = np.argmin(under_count, axis=None)
        next_task, next_arm, next_cons = np.unravel_index(flat_pos, under_count.shape)
        return next_task, next_arm, next_cons

    # Stage 3: follow the ratio produced by the ASR helper.
    optimal_ratio = asr_equations(s, k, m, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count)
    next_task, next_arm, next_cons = track(optimal_ratio, alternative_count)
    return next_task, next_arm, next_cons

def SEQSR(dist, s, k, m, n0, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count):
    """Choose the next ``(task, arm, cons)`` index triple to sample (SEQSR rule).

    The choice is made in three stages, the first that applies wins:

    1. Warm-up: if any entry of ``alternative_count`` is below ``n0``,
       return the index of the smallest entry.
    2. Undersampled: if ``undersampled(...)`` flags any entry, return the
       index of the smallest count among the flagged entries.
    3. Otherwise delegate the choice entirely to ``seqsample(...)``.

    Args:
        dist, s, k, m, feasibility, phi, F, X_mean, X_variance, opt_solution:
            Forwarded unchanged to the ``util`` helpers; their meaning is
            defined there.
        n0: Threshold every entry of ``alternative_count`` must reach before
            stages 2 and 3 are considered; also forwarded to ``undersampled``.
        alternative_count: 3-D NumPy array of per-alternative counts
            (presumably indexed ``(task, arm, constraint)`` — confirm).

    Returns:
        Tuple ``(next_task, next_arm, next_cons)``.
    """
    # Stage 1: warm-up. Only the raw counts matter here, so the
    # undersampling mask is deliberately not computed on this path.
    if np.any(alternative_count < n0):
        flat_pos = np.argmin(alternative_count, axis=None)
        next_task, next_arm, next_cons = np.unravel_index(flat_pos, alternative_count.shape)
        return next_task, next_arm, next_cons

    # NOTE(review): deferring this call assumes undersampled() has no side
    # effects — confirm against util.
    total_sample = np.sum(alternative_count)
    undersample = undersampled(s, k, m, n0, total_sample, alternative_count)

    # Stage 2: among flagged entries, take the least-sampled one. Unflagged
    # entries are masked with +inf so argmin can never select them.
    if np.any(undersample):
        under_count = np.where(undersample, alternative_count, np.inf)
        flat_pos = np.argmin(under_count, axis=None)
        next_task, next_arm, next_cons = np.unravel_index(flat_pos, under_count.shape)
        return next_task, next_arm, next_cons

    # Stage 3: the sequential helper returns the index triple directly.
    next_task, next_arm, next_cons = seqsample(dist, s, k, m, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count)
    return next_task, next_arm, next_cons

def FWSR(dist, s, k, m, n0, feasibility, phi, F, X_mean, X_variance, opt_solution, alternative_count, fws_hist):
    """Choose the next ``(task, arm, cons)`` index triple to sample (FWSR rule).

    While any entry of ``alternative_count`` is below ``n0`` the least-sampled
    entry is chosen and a uniform ratio array is reported. Afterwards the
    target ratio comes from ``fws_ratio(...)`` and the entry with the largest
    target-to-current ratio quotient is chosen.

    Args:
        dist, feasibility, phi, F, X_mean, X_variance, opt_solution, fws_hist:
            Forwarded unchanged to ``fws_ratio``; their meaning is defined
            there.
        s, k, m: Forwarded to ``fws_ratio`` and used to size the uniform
            ratio array, which has shape ``(s, k, m + 1)``.
        n0: Threshold every entry of ``alternative_count`` must reach before
            ``fws_ratio`` is consulted.
        alternative_count: 3-D NumPy array of per-alternative counts
            (presumably of shape ``(s, k, m + 1)`` — confirm).

    Returns:
        Tuple ``(next_task, next_arm, next_cons, track_ratio)`` where
        ``track_ratio`` is the ratio array used for this decision.
    """
    if np.any(alternative_count < n0):
        # Warm-up: sample the least-visited entry and report equal weights.
        flat_pos = np.argmin(alternative_count, axis=None)
        next_task, next_arm, next_cons = np.unravel_index(flat_pos, alternative_count.shape)
        uniform_ratio = np.full((s, k, m + 1), 1.0 / (s * k * (m + 1)))
        return next_task, next_arm, next_cons, uniform_ratio

    # Empirical sampling proportions so far.
    observed_ratio = alternative_count / np.sum(alternative_count)
    target_ratio = fws_ratio(
        dist, s, k, m, n0, feasibility, phi, F, X_mean, X_variance,
        opt_solution, alternative_count, observed_ratio, fws_hist,
    )
    # The entry whose target share most exceeds its observed share goes next.
    gap = target_ratio / observed_ratio
    flat_pos = np.argmax(gap, axis=None)
    next_task, next_arm, next_cons = np.unravel_index(flat_pos, gap.shape)
    return next_task, next_arm, next_cons, target_ratio