import numpy as np
import random
from tqdm import tqdm
from scipy.stats import entropy
from scipy.special import softmax
import helper
from functools import partial

# Fast implementation of Bradley-Terry softmax
# No need to fit the parameters since we only need their softmax -> Can be estimated using win/loss ratios
def bt_softmax_fast(wins, losses, draws, alpha=1e-8, beta=1e-8):
    """Approximate the softmax over Bradley-Terry strengths from match counts.

    Instead of fitting the model, each competitor's strength is estimated by
    the smoothed ratio ``(2*wins + draws + alpha) / (2*losses + draws + beta)``.
    One extra reference competitor is appended as the last column; its
    adjusted wins and losses both equal the number of annotated queries ``T``,
    so its ratio is ``(T + alpha) / (T + beta)``.

    Args:
        wins, losses, draws: Array-likes of shape ``(n,)`` or ``(B, n)``
            holding per-competitor counts. A 1-D input is treated as a batch
            containing a single row.
        alpha: Smoothing constant added to every numerator.
        beta: Smoothing constant added to every denominator.

    Returns:
        Array of shape ``(B, n + 1)`` whose rows sum to 1. The result is
        2-D even when the inputs were 1-D.
    """
    wins = np.asarray(wins)
    losses = np.asarray(losses)
    draws = np.asarray(draws)
    if wins.ndim == 1:
        wins = wins[None, :]
        losses = losses[None, :]
        draws = draws[None, :]

    batch_size = wins.shape[0]

    # Number of annotated queries, taken from the first competitor of EACH
    # row so that batches whose rows have different totals are handled
    # correctly (previously only row 0 was consulted). Like the original,
    # this assumes every competitor within a row shares the same total.
    totals = (wins[:, 0] + draws[:, 0] + losses[:, 0]).reshape(batch_size, 1)
    reference = np.ones((batch_size, 1)) * totals

    wins_adj = np.concatenate((2 * wins + draws, reference), axis=1)
    losses_adj = np.concatenate((2 * losses + draws, reference), axis=1)
    ratio = (wins_adj + alpha) / (losses_adj + beta)
    return ratio / ratio.sum(axis=1, keepdims=True)

def _bradley_terry_one_iter(s, num_test_samples, num_metrics, num_models):
    """Run one seeded selection trial and record the leading model after each label.

    A pool of ``num_test_samples`` query indices is drawn at random. At every
    step the pooled query whose weak-judge outcome would yield the lowest
    metric-weighted posterior entropy is selected, its column of
    ``helper.JUDGE_SCORES`` is revealed, and the model with the highest mean
    revealed score (ties broken at random) is recorded.

    Args:
        s: Seed applied to both the ``numpy`` and ``random`` global RNGs.
        num_test_samples: Number of query indices to draw and then label.
        num_metrics: Number of weak-judge metrics, i.e. how many slices of
            the last axis of ``helper.WEAK_JUDGE_SCORES`` are used.
        num_models: Number of models; presumably the size of the first axis
            of ``helper.JUDGE_SCORES`` -- TODO confirm against helper.

    Returns:
        A list with one ``[chosen_model, chosen_model_wr]`` pair per labeling
        step. An index one past the judged models refers to the all-zero row
        appended to the scores (presumably the baseline model).
    """
    np.random.seed(s)
    random.seed(s)

    judge = helper.JUDGE_SCORES
    weak_judge = helper.WEAK_JUDGE_SCORES

    # Queries still waiting for a label, and those already revealed (in order).
    pool = random.sample(range(judge.shape[1]), num_test_samples)
    revealed = []

    # Mean score of every model (plus an appended all-zero row) over the full
    # pool, mapped through (x + 1) / 2 to convert it to a win rate.
    pool_scores = np.concatenate(
        (judge[:, pool], np.zeros((1, num_test_samples))), axis=0
    )
    all_models_wr = (pool_scores.mean(axis=1) + 1) / 2

    # Every metric contributes equally to the expected entropy.
    metric_weights = np.ones(num_metrics) / num_metrics

    win_counts = np.zeros(num_models)
    loss_counts = np.zeros(num_models)
    draw_counts = np.zeros(num_models)

    history = []
    for _ in range(num_test_samples):
        # For each pooled query, pretend the weak judge's verdict is the real
        # one and measure how peaked the resulting posterior would be.
        expected_entropy = 0.0
        for metric in range(num_metrics):
            hypothetical = weak_judge[:, pool, metric]
            posterior = bt_softmax_fast(
                win_counts[None, :] + (hypothetical > 0).T,
                loss_counts[None, :] + (hypothetical < 0).T,
                draw_counts[None, :] + (hypothetical == 0).T,
            )
            expected_entropy += metric_weights[metric] * entropy(posterior, axis=1)

        # Pool entries are distinct, so popping by position removes exactly
        # the selected query.
        position = int(expected_entropy.argmin())
        query = pool.pop(position)
        revealed.append(query)

        # Fold the judge's actual verdict for that query into the tallies:
        # positive -> win, negative -> loss, zero -> draw.
        outcome = judge[:, query]
        win_counts += (outcome > 0)
        loss_counts += (outcome < 0)
        draw_counts += (outcome == 0)

        running = np.concatenate(
            (judge[:, revealed], np.zeros((1, len(revealed)))), axis=0
        ).mean(axis=1)

        # Break ties between equally-scored leaders uniformly at random.
        leaders = np.arange(num_models + 1)[running == running.max()]
        chosen_model = np.random.choice(leaders)
        history.append([chosen_model, all_models_wr[chosen_model]])

    return history


def bradley_terry(num_iterations, num_test_samples, num_models, num_metrics, best_model, use_pseudo_annotations=False):
    """Run the Bradley-Terry selection experiment via ``helper._run_parallel``.

    Args:
        num_iterations: Passed to ``helper._run_parallel`` as its first
            argument (presumably the number of seeded trials -- confirm in
            helper).
        num_test_samples: Forwarded to ``_bradley_terry_one_iter``.
        num_models: Forwarded to ``_bradley_terry_one_iter``.
        num_metrics: Forwarded to ``_bradley_terry_one_iter``.
        best_model: Not used by this function; kept in the signature.
        use_pseudo_annotations: Not used by this function; kept in the
            signature.

    Returns:
        Whatever ``helper._run_parallel`` returns; expected shape is
        ``(num_iterations, num_test_samples, 2)``.
    """
    # Bind everything except the seed, which the runner is expected to supply.
    run_one_seed = partial(
        _bradley_terry_one_iter,
        num_test_samples=num_test_samples,
        num_metrics=num_metrics,
        num_models=num_models,
    )
    return helper._run_parallel(num_iterations, run_one_seed, desc="bradley_terry")
