import numpy as np
import random
from environment import Environment
import concurrent.futures
import time
from get_api_mu import answer_multiple_choice_questions
from C2MAB_V import C2MAB_V
from C2MAB_V_direct import C2MAB_V_direct

def parallel_C2MAB_V(L, K, T, C, log_ind, CB_coefficient, LCB_coefficient, epsilon_num):
    """Run one seeded C2MAB_V experiment and save its results with ``np.savez``.

    Args:
        L: passed to ``Environment`` and embedded in the output filename.
        K: passed to ``C2MAB_V`` and embedded in the output filename.
        T: passed to ``C2MAB_V`` and embedded in the output filename
            (``_times<T>``).
        C: passed to ``Environment`` and embedded in the output filename.
        log_ind: seed for both ``numpy.random`` and ``random``; also passed to
            ``C2MAB_V`` and embedded in the filename (``_seed<log_ind>``).
        CB_coefficient: passed to ``C2MAB_V``; tagged ``_UCB`` in the filename.
        LCB_coefficient: passed to ``C2MAB_V``; tagged ``_LCB`` in the filename.
        epsilon_num: accepted for call compatibility; not used by this function.

    The archive is written with positional arrays, so it contains
    ``arr_0`` = violation_t, ``arr_1`` = reward_t and
    ``arr_2`` = ``time.time() - starttime``, where ``starttime`` is the third
    value returned by ``run()``.
    """
    # Seed both RNG sources before anything below can draw random numbers.
    np.random.seed(log_ind)
    random.seed(log_ind)

    # NOTE(review): `questions` and `cost` are empty placeholders here;
    # presumably `questions` is meant to be loaded from SciQ — confirm.
    questions = []  # from SciQ
    mu = answer_multiple_choice_questions(questions)
    cost = []

    environment = Environment(L, C, mu, cost)
    algorithm = C2MAB_V(K, environment, T, CB_coefficient, log_ind, LCB_coefficient)
    reward_t, violation_t, starttime = algorithm.run()
    # Elapsed wall-clock seconds measured from the timestamp run() handed back.
    runtime = time.time() - starttime

    out_name = (
        f"C2MAB_V_AWC_{K}_{L}_{C}_seed{log_ind}_times{T}"
        f"_UCB{CB_coefficient}_LCB{LCB_coefficient}"
    )
    np.savez(out_name, violation_t, reward_t, runtime)


def parallel_C2MAB_V_direct(L, K, T, C, log_ind, CB_coefficient, LCB_coefficient, epsilon_num):
    """Run one seeded C2MAB_V_direct experiment and save its results with ``np.savez``.

    Args:
        L: passed to ``Environment`` and embedded in the output filename.
        K: passed to ``C2MAB_V_direct`` and embedded in the output filename.
        T: passed to ``C2MAB_V_direct`` and embedded in the output filename
            (``_times<T>``).
        C: passed to ``Environment`` and embedded in the output filename.
        log_ind: seed for both ``numpy.random`` and ``random``; also passed to
            ``C2MAB_V_direct`` and embedded in the filename (``_seed<log_ind>``).
        CB_coefficient: passed to ``C2MAB_V_direct``; tagged ``_UCB`` in the
            filename.
        LCB_coefficient: passed to ``C2MAB_V_direct``; tagged ``_LCB`` in the
            filename.
        epsilon_num: accepted for call compatibility; not used by this function.

    The archive is written with positional arrays, so it contains
    ``arr_0`` = violation_t, ``arr_1`` = reward_t and
    ``arr_2`` = ``time.time() - starttime``, where ``starttime`` is the third
    value returned by ``run()``.
    """
    # Seed both RNG sources before anything below can draw random numbers.
    np.random.seed(log_ind)
    random.seed(log_ind)

    # NOTE(review): `questions` and `cost` are empty placeholders here;
    # presumably `questions` is meant to be loaded from SciQ — confirm.
    questions = []  # from SciQ
    mu = answer_multiple_choice_questions(questions)
    cost = []

    environment = Environment(L, C, mu, cost)
    algorithm = C2MAB_V_direct(K, environment, T, CB_coefficient, log_ind, LCB_coefficient)
    reward_t, violation_t, starttime = algorithm.run()
    # Elapsed wall-clock seconds measured from the timestamp run() handed back.
    runtime = time.time() - starttime

    out_name = (
        f"C2MAB_V_direct_AWC_{K}_{L}_{C}_seed{log_ind}_times{T}"
        f"_UCB{CB_coefficient}_LCB{LCB_coefficient}"
    )
    np.savez(out_name, violation_t, reward_t, runtime)


if __name__ == "__main__":
    # Run every configuration in `configs` once for each seed in
    # choose_seed[n_low:n_low + n_trials], spread over a pool of worker processes.
    n_low = 0
    n_trials = 10
    choose_seed = [503, 507, 600, 601, 602, 603, 604, 607, 608, 609]
    m = 10000  # horizon T passed to every run

    seeds = choose_seed[n_low:n_low + n_trials]
    # Fail before submitting anything rather than part-way through the loop.
    if len(seeds) != n_trials:
        raise ValueError(
            f"choose_seed has only {len(seeds)} seeds from index {n_low}, "
            f"but n_trials={n_trials}"
        )

    # Arguments shared by every run.
    common_kwargs = dict(L=9, K=4, T=m, C=0.45, epsilon_num=8)
    # (worker function, CB_coefficient, LCB_coefficient), in submission order.
    configs = [
        (parallel_C2MAB_V, 0.3, 0.01),
        (parallel_C2MAB_V, 0.3, 0.05),
        (parallel_C2MAB_V, 1, 0.01),
        (parallel_C2MAB_V, 1, 0.05),
        (parallel_C2MAB_V_direct, 0.3, 0.05),
    ]

    failures = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=6) as executor:
        futures = {}
        for seed in seeds:
            for func, cb_coef, lcb_coef in configs:
                future = executor.submit(
                    func,
                    log_ind=seed,
                    CB_coefficient=cb_coef,
                    LCB_coefficient=lcb_coef,
                    **common_kwargs,
                )
                futures[future] = (func.__name__, seed, cb_coef, lcb_coef)

        # An exception raised inside a worker is stored on its Future; unless
        # result() is called it is silently discarded, so check every run.
        for future in concurrent.futures.as_completed(futures):
            name, seed, cb_coef, lcb_coef = futures[future]
            try:
                future.result()
            except Exception as exc:
                failures.append(futures[future])
                print(
                    f"{name} (seed={seed}, UCB={cb_coef}, LCB={lcb_coef}) "
                    f"failed: {exc!r}"
                )

    if failures:
        raise RuntimeError(f"{len(failures)} of {len(futures)} runs failed")