import numpy as np
import random
from tqdm import tqdm
from scipy.stats import entropy
import helper
from functools import partial



def _outcome_likelihood(outcomes, eps1, eps2):
    """Return the elementwise likelihood factor for an array of outcomes.

    Each entry is mapped by its sign: negative -> ``eps1``, zero -> ``eps2``,
    positive -> ``1 - eps1 - eps2``. Boolean exponents are used, so every
    factor except the matching one is exactly 1.0 (an entry matching no
    comparison, e.g. NaN, yields 1.0).
    """
    return (eps1 ** (outcomes < 0)
            * eps2 ** (outcomes == 0)
            * (1 - eps1 - eps2) ** (outcomes > 0))


def _llm_selector_one_iter(s, num_test_samples, num_metrics, num_models, eps1, eps2):
    """Run one seeded trial of the entropy-driven selector.

    A pool of ``num_test_samples`` sample columns is drawn from
    ``helper.JUDGE_SCORES``. At each step:

    1. Pick the pooled sample whose weak-judge outcomes lead to the lowest
       metric-weighted entropy of the (hypothetically updated) posterior.
    2. Update the posterior with that sample's judge outcomes.
    3. Record which row has the highest mean judge score over the labeled
       samples so far. Ties are broken with ``np.random.choice``.

    Args:
        s: seed applied to both ``numpy.random`` and ``random``.
        num_test_samples: size of the drawn pool; also the number of steps.
        num_metrics: number of weak-judge metrics, indexed on the last axis
            of ``helper.WEAK_JUDGE_SCORES``; metrics are weighted uniformly.
        num_models: the posterior has ``num_models + 1`` entries. The extra
            entry corresponds to an all-zero row appended to the score
            matrices.
        eps1, eps2: likelihood factors for negative and zero outcomes;
            positive outcomes use ``1 - eps1 - eps2``.

    Returns:
        A list with one ``[chosen_model, chosen_model_win_rate]`` pair per step.
    """
    # Both RNGs are seeded: ``random.sample`` draws the pool and
    # ``np.random.choice`` breaks ties between equally-scored rows.
    np.random.seed(s)
    random.seed(s)

    strong = helper.JUDGE_SCORES
    weak = helper.WEAK_JUDGE_SCORES
    n_candidates = num_models + 1

    pool = random.sample(range(strong.shape[1]), num_test_samples)
    picked = []

    # Per-row win rate over the whole drawn pool, (mean + 1) / 2, with the
    # zero reference row appended.
    padded_pool = np.concatenate(
        (strong[:, pool], np.zeros((1, num_test_samples))), axis=0
    )
    win_rates = (padded_pool.mean(axis=1) + 1) / 2

    belief = np.ones(n_candidates) / n_candidates
    weights = np.ones(num_metrics) / num_metrics
    history = []

    def weighted_entropy(m):
        """Weighted posterior entropy per pooled sample if metric m's weak outcome were observed."""
        outcomes = np.concatenate(
            (weak[:, pool, m], np.zeros((1, len(pool)))), axis=0
        )
        hypothetical = belief[:, None] * _outcome_likelihood(outcomes, eps1, eps2)
        hypothetical /= hypothetical.sum(axis=0)
        return weights[m] * entropy(hypothetical, axis=0)

    for _ in range(num_test_samples):
        # Accumulated metric by metric, in index order.
        expected_entropy = sum(weighted_entropy(m) for m in range(num_metrics))

        chosen_sample = pool.pop(int(expected_entropy.argmin()))
        picked.append(chosen_sample)

        observed = np.concatenate((strong[:, chosen_sample], np.zeros(1)))
        belief *= _outcome_likelihood(observed, eps1, eps2)
        belief /= belief.sum()

        mean_scores = np.concatenate(
            (strong[:, picked], np.zeros((1, len(picked)))), axis=0
        ).mean(axis=1)
        leaders = np.arange(n_candidates)[mean_scores == mean_scores.max()]
        winner = np.random.choice(leaders)
        history.append([winner, win_rates[winner]])

    return history


def llm_selector(num_iterations, num_test_samples, num_models, num_metrics, best_model, all_epsilons, use_pseudo_annotations=False):
    """Run the LLM selector trials for every ``(eps1, eps2)`` setting.

    For each pair in ``all_epsilons``, ``_llm_selector_one_iter`` is bound to
    the fixed arguments and handed to ``helper._run_parallel`` together with
    ``num_iterations``. Presumably that helper invokes the worker once per
    seed; confirm against ``helper``.

    Args:
        num_iterations: forwarded to ``helper._run_parallel``.
        num_test_samples, num_models, num_metrics: forwarded to each trial.
        best_model: reference model index (or one index per iteration) used
            for the accuracy printout. Assumed 1-D of length
            ``num_iterations`` or a scalar; TODO confirm.
        all_epsilons: iterable of ``(eps1, eps2)`` pairs.
        use_pseudo_annotations: when true, print the mean selection accuracy
            for each epsilon setting.

    Returns:
        dict mapping ``(eps1, eps2)`` to the value returned by
        ``helper._run_parallel``. Each trial contributes one
        ``[chosen_model, win_rate]`` pair per step.
    """
    out = {}

    for eps1, eps2 in all_epsilons:
        # ===== parallel section =====
        worker = partial(_llm_selector_one_iter,
                         num_test_samples=num_test_samples,
                         num_metrics=num_metrics,
                         num_models=num_models,
                         eps1=eps1, eps2=eps2)
        results = helper._run_parallel(
            num_iterations,
            worker,
            desc=f"llm_selector eps=({eps1},{eps2})",
        )                   # shape: (num_iterations, num_test_samples, 2)

        out[(eps1, eps2)] = results

        if use_pseudo_annotations:
            # Only column 0 of each [chosen_model, win_rate] record is a model
            # index. Comparing the full 3-D array against best_model[:, None]
            # would align the iteration axis with the sample axis, and would
            # also compare the win-rate column.
            chosen = np.asarray(results)[..., 0]            # (iterations, steps)
            target = np.atleast_1d(np.asarray(best_model))[:, None]
            acc = (chosen == target).mean(axis=0)           # accuracy per step
            print(f"LLM Selector (eps=({eps1},{eps2})) {acc.mean():.3f}")

    return out