import numpy as np
import random
from tqdm import tqdm
from scipy.stats import entropy
import helper
from functools import partial



def _confidence_one_iter(s, num_test_samples, num_metrics, num_models):
    """Run one seeded trial of confidence-ordered labeling.

    A random subset of ``num_test_samples`` columns of ``helper.JUDGE_SCORES``
    is drawn. The columns are then "labeled" one at a time, in order of
    increasing entropy of the weak-judge vote distribution, so the samples
    the weak judges agree on most come first. After each new label, the
    row (model) with the highest mean judge score over the labeled columns
    is chosen, with exact ties broken uniformly at random. That choice is
    recorded together with the row's win rate over the whole drawn subset.

    An all-zero row is appended after the rows of ``helper.JUDGE_SCORES``
    so that it competes as an extra candidate with a constant score of 0.

    Args:
        s: Seed applied to both the global ``numpy.random`` and ``random``
            generators. The global state is reseeded.
        num_test_samples: Number of columns to draw. This is also the number
            of labeling steps.
        num_metrics: Number of leading entries along the last axis of
            ``helper.WEAK_JUDGE_SCORES`` that are averaged per sample.
        num_models: Number of candidate rows excluding the appended zero
            row. The candidate index array has length ``num_models + 1`` and
            is masked by the score vector, so this presumably must equal
            ``helper.JUDGE_SCORES.shape[0]`` (TODO confirm).

    Returns:
        A list of ``num_test_samples`` ``[chosen_model, chosen_model_wr]``
        pairs, one per labeling step. ``chosen_model`` is a row index; the
        appended zero row has the last index.
    """
    np.random.seed(s)
    random.seed(s)

    judge = helper.JUDGE_SCORES
    unlabeled = random.sample(range(judge.shape[1]), num_test_samples)

    # Drawn columns plus a trailing all-zero candidate row.
    scores_with_baseline = np.concatenate(
        (judge[:, unlabeled], np.zeros((1, num_test_samples))), axis=0
    )
    all_models_wr = (scores_with_baseline.mean(axis=1) + 1) / 2  # convert to win rate

    # Average the first num_metrics weak-judge scores once. Then discretize
    # each average into a vote: +1 above 1/3, -1 below -1/3, 0 otherwise.
    weak_mean = helper.WEAK_JUDGE_SCORES[:, unlabeled, :num_metrics].mean(axis=2)
    weak_judge_ensemble = (weak_mean > 1/3).astype(int) - (weak_mean < -1/3).astype(int)

    # Per-sample fraction of +1 / -1 / 0 votes, counted along axis 0.
    weak_predictions = np.zeros((num_test_samples, 3))
    weak_predictions[:, 0] = (weak_judge_ensemble == 1).sum(axis=0)
    weak_predictions[:, 1] = (weak_judge_ensemble == -1).sum(axis=0)
    weak_predictions[:, 2] = (weak_judge_ensemble == 0).sum(axis=0)
    weak_predictions /= weak_predictions.sum(axis=1)[:, None]

    # Ascending entropy: the most agreed-upon samples are labeled first.
    entropies_idx = entropy(weak_predictions, axis=1).argsort()

    # Prefix means of the columns in labeling order, computed with a single
    # cumulative sum. This costs O(rows * samples) instead of re-averaging
    # the whole labeled set at every step. For integer-valued scores the
    # result is bit-identical to a per-step .mean(). For general floats it
    # may differ in the last ulp (sequential vs. pairwise summation), which
    # can only matter for exact-tie detection below.
    running_sums = np.cumsum(scores_with_baseline[:, entropies_idx], axis=1)
    prefix_means = running_sums / np.arange(1, num_test_samples + 1)

    candidates = np.arange(num_models + 1)
    result = []
    for step in range(num_test_samples):
        scores = prefix_means[:, step]
        # Break exact ties uniformly at random among the best-scoring rows.
        chosen_model = np.random.choice(candidates[scores == scores.max()])
        result.append([chosen_model, all_models_wr[chosen_model]])

    return result


def confidence(num_iterations, num_test_samples, num_models, num_metrics, best_model):
    """Run repeated confidence-ordered labeling trials and stack their results.

    Each trial is ``_confidence_one_iter`` with the sizes below bound in. The
    bound worker is handed to ``helper._run_parallel`` together with
    ``num_iterations`` and ``desc="confidence"``. ``helper._run_parallel``
    presumably calls the worker once per seed; verify against helper.

    Args:
        num_iterations: Passed as the first argument to
            ``helper._run_parallel``. This is presumably the number of trials.
        num_test_samples: Number of samples drawn, and labeling steps taken,
            per trial.
        num_models: Number of candidate models, excluding the zero baseline row.
        num_metrics: Number of weak-judge metrics averaged per sample.
        best_model: Not used by this function. Kept for signature compatibility.

    Returns:
        ``np.array`` of the value returned by ``helper._run_parallel``.
        Presumably its shape is ``(num_iterations, num_test_samples, 2)``,
        holding ``[chosen_model, chosen_model_wr]`` per step (TODO confirm).
    """
    trial = partial(
        _confidence_one_iter,
        num_metrics=num_metrics,
        num_models=num_models,
        num_test_samples=num_test_samples,
    )
    return np.array(helper._run_parallel(num_iterations, trial, desc="confidence"))