import numpy as np
import random
from tqdm import tqdm
import helper
from functools import partial


def _random_one_iter(s, num_test_samples, num_models, judge_scores=None):
    """Run one seeded trial of the random-labeling baseline.

    Draws ``num_test_samples`` distinct sample columns from the judge-score
    matrix, then reveals them one at a time in random order. After each
    reveal, the candidate with the highest mean score over the samples
    revealed so far is selected, with ties broken uniformly at random. Its
    win rate over the full drawn sample set is then recorded.

    An all-zero row is appended to the score matrix as an extra candidate.
    It is presumably a reference model that the scores are measured against
    (TODO confirm). There are therefore ``judge_scores.shape[0] + 1``
    candidates, and the appended one has the last index.

    Args:
        s: Seed applied to both ``numpy.random`` and ``random`` (global RNG
            state is reseeded on every call).
        num_test_samples: Number of sample columns to draw and reveal. Must
            not exceed the number of columns in the score matrix.
        num_models: Expected number of model rows. Kept for interface
            compatibility. Candidate indices are now derived from the score
            matrix itself, so a mismatch no longer raises ``IndexError``.
        judge_scores: Optional 2-D score matrix (rows are models, columns are
            samples). Defaults to ``helper.JUDGE_SCORES``.

    Returns:
        A list of ``num_test_samples`` ``[chosen_model, chosen_model_wr]``
        pairs, one per reveal step.
    """
    np.random.seed(s)
    random.seed(s)

    if judge_scores is None:
        judge_scores = helper.JUDGE_SCORES
    else:
        judge_scores = np.asarray(judge_scores)

    result = []  # pairs of chosen model id and its win rate
    unlabeled = random.sample(list(range(judge_scores.shape[1])), num_test_samples)
    labeled = []

    # Mean score per candidate over the whole drawn set, zero row included.
    # (mean + 1) / 2 is treated as a win rate; this presumes scores lie in
    # [-1, 1] -- TODO confirm.
    all_models_wr = (
        np.concatenate(
            (judge_scores[:, unlabeled], np.zeros((1, num_test_samples))), axis=0
        ).mean(axis=1)
        + 1
    ) / 2

    for i in range(num_test_samples):
        # Keep random.sample(k=1) (not random.choice) so the RNG stream, and
        # therefore seeded results, match earlier runs.
        idx = random.sample(unlabeled, k=1)[0]
        unlabeled.remove(idx)
        labeled.append(idx)

        scores = np.concatenate(
            (judge_scores[:, labeled], np.zeros((1, i + 1))), axis=0
        ).mean(axis=1)

        # Indices of every candidate tied for the best score. These are taken
        # from `scores` itself so they always line up with all_models_wr.
        candidates = np.flatnonzero(scores == scores.max())
        chosen_model = np.random.choice(candidates)
        result.append([chosen_model, all_models_wr[chosen_model]])

    return result
    

def random_method(num_iterations, num_test_samples, num_models):
    """Run the random-labeling baseline over many iterations.

    Each iteration is one call to ``_random_one_iter``, dispatched through
    ``helper._run_parallel``. That helper presumably passes the iteration
    index as the seed argument (TODO confirm). Each call yields
    ``num_test_samples`` ``[chosen_model, chosen_model_wr]`` pairs.

    Returns:
        ``(random_results, best_model)``:

        * ``random_results`` is the collected per-iteration output converted
          to a NumPy array. Its shape is presumably
          ``(num_iterations, num_test_samples, 2)`` if ``_run_parallel``
          returns one result per iteration.
        * ``best_model`` is ``random_results[:, -1, 0]``: the model chosen
          after the final reveal step of each iteration. Because ids and win
          rates share one array, these ids will typically be floats.
    """
    run_once = partial(
        _random_one_iter,
        num_test_samples=num_test_samples,
        num_models=num_models,
    )

    per_iteration = np.array(
        helper._run_parallel(num_iterations, run_once, desc="random_method")
    )

    # Last reveal step, column 0 (model id) -> final pick per iteration.
    final_pick = per_iteration[:, -1, 0]
    return per_iteration, final_pick