import numpy as np
import random
from tqdm import tqdm
import helper
from functools import partial



def _most_draws_one_iter(s, num_test_samples, num_metrics, num_models):
    np.random.seed(s)
    random.seed(s)
    result = []
    unlabeled = random.sample(list(range(helper.JUDGE_SCORES.shape[1])), num_test_samples)
    labeled = []
    all_models_wr = (np.concatenate((helper.JUDGE_SCORES[:,unlabeled],np.zeros((1,num_test_samples))), axis=0).mean(axis=1) + 1) / 2  # convert to win rate

    weak_judge_ensemble = (helper.WEAK_JUDGE_SCORES[:,unlabeled,:num_metrics].mean(axis=2) > 1/3).astype(int) - (helper.WEAK_JUDGE_SCORES[:,unlabeled,:num_metrics].mean(axis=2) < -1/3).astype(int)

    draw_counts = (weak_judge_ensemble==0).sum(axis=0)
    draw_counts_idx = draw_counts.argsort()

    for i in range(num_test_samples):
        labeled.append(unlabeled[draw_counts_idx[-i]])
        
        scores = np.concatenate((helper.JUDGE_SCORES[:,labeled],np.zeros((1,i+1))), axis=0).mean(axis=1)
        chosen_model = np.random.choice(np.arange(num_models+1)[scores == scores.max()])
        chosen_model_wr = all_models_wr[chosen_model]

        result.append([chosen_model, chosen_model_wr])
        
    return result  


def most_draws(num_iterations, num_test_samples, num_models, num_metrics, best_model):
    """Run ``num_iterations`` trials of ``_most_draws_one_iter`` via ``helper._run_parallel``.

    Each trial is handed to ``helper._run_parallel`` (with progress label
    ``"most_draws"``). The per-trial outputs it returns are stacked into a
    single NumPy array. Presumably the shape is
    (num_iterations, num_test_samples, 2) if the helper returns results in
    order -- confirm against ``helper``.

    ``best_model`` is accepted for signature compatibility with callers but
    is not used here.
    """
    one_iteration = partial(
        _most_draws_one_iter,
        num_test_samples=num_test_samples,
        num_metrics=num_metrics,
        num_models=num_models,
    )
    per_iteration_results = helper._run_parallel(num_iterations, one_iteration, desc="most_draws")
    return np.array(per_iteration_results)