import json
import numpy as np
import matplotlib.pyplot as plt
import helper as helper
from methods import all_methods

################ Arguments ################
# Dataset directory: judge.json and weak_judge.json are read from it; plot.png and results.json are written to it.
data_path = "arena-hard" # options: ["arena-hard","alpacaeval","mt-bench", "flickr30k", "bingo", "medi_qa"]

# Monte-Carlo setup; both values are forwarded to every method in `all_methods`.
num_test_samples = 400 # realization pool size
num_iterations = 100 # number of realizations

# Slices the last axis of the weak-judge scores when building pseudo-annotations,
# and is forwarded to every method except "random".
num_metrics = 10 # number of weak judges

# When True, the module-level `judge_scores` is replaced by the weak-judge mean thresholded
# at +/-1/3, and the flag is forwarded to the llm_selector method.
use_pseudo_annotations = False # set True for hyperparameter search

# hyperparameter pairs (eps_loss, eps_draw) for the algorithm. should be a list of pairs
# (presumably each pair yields one llm_selector curve / results.json entry — confirm in methods.llm_selector)
# for hyperparameter search (e1 in 0.05..0.95, e1 <= e2 < 1 - e1, step 0.05):
# all_epsilons = [(e1/100,e2/100) for e1 in range(5,100,5) for e2 in range(e1,100-e1,5)]
# running a single epsilon pair:
all_epsilons = [(0.2,0.4)]

###########################################

############# Additional Info #############

# In both the judge and weak-judge files, outcomes are represented as:
#   - Loss: -1
#   - Draw:  0
#   - Win:  +1
#
# Shape of judge_scores is (num_models, num_queries)
# Shape of weak_judge_scores is (num_models, num_queries, num_weak_judges)
#
# Results are saved to '{data_path}/results.json' and the identification-probability plot to '{data_path}/plot.png'

###########################################

# Load strong-judge outcomes. The entry whose value is null is the baseline model
# (presumably the reference every other model is compared against); every other
# entry contributes one row of `judge_scores`, in the file's key order.
with open(f"{data_path}/judge.json", 'r') as fp:
    judge_data = json.load(fp)
judge_scores = []
model_names = []
baseline_model = None

for model_name, model_scores in judge_data.items():
    if model_scores is None:
        baseline_model = model_name
    else:
        model_names.append(model_name)
        judge_scores.append(model_scores)

# Fail early with a clear message instead of a NameError when the baseline is printed below.
if baseline_model is None:
    raise ValueError(f"{data_path}/judge.json has no baseline entry (a model whose scores are null)")

judge_scores = np.array(judge_scores)


# Load weak-judge outcomes, again skipping the null (baseline) entry.
# NOTE(review): rows are matched to `model_names` purely by key order, so
# weak_judge.json must list the models in the same order as judge.json.
with open(f"{data_path}/weak_judge.json", 'r') as fp:
    weak_judge_data = json.load(fp)
weak_judge_scores = np.array([scores for scores in weak_judge_data.values() if scores is not None])

helper._init_globals(judge_scores, weak_judge_scores)

print("Models:",model_names + [baseline_model])

num_models = len(model_names) # baseline model not included in `num_models`

# pseudo-annotations for epsilon search: average the first `num_metrics` weak judges
# and map the mean to +1 (> 1/3), -1 (< -1/3) or 0 (draw).
# NOTE(review): this rebinding happens after helper._init_globals() above, so helper
# keeps the real judge scores and nothing later in this file reads `judge_scores`;
# confirm the methods obtain pseudo-annotations via the `use_pseudo_annotations` flag.
if use_pseudo_annotations:
    mean_weak = weak_judge_scores[:, :, :num_metrics].mean(axis=2)
    judge_scores = (mean_weak > 1/3).astype(int) - (mean_weak < -1/3).astype(int)


def plot_results(best_model, random_results, bradley_terry_results, confidence_results, uncertainty_results, most_draws_results, all_llm_selector_results):
    """Plot, summarize and save the results of every model-selection method.

    Each ``*_results`` array is indexed ``[realization, budget_step, field]``:
    field 0 is the model a method selected and field 1 that model's win rate.
    ``all_llm_selector_results`` maps each epsilon configuration to such an array.

    Side effects:
        * saves identification-probability curves to ``{data_path}/plot.png``;
        * prints label-efficiency tables for several win-rate tolerances;
        * prints the accuracy-gap table for the first LLM Selector configuration;
        * writes every results array to ``{data_path}/results.json``.

    Args:
        best_model: 1-D array with one entry per realization (it is broadcast
            against the first axis of each results array).
    """
    # Baselines in the order they are plotted and saved.
    baselines = {
        'random': random_results,
        'bradley_terry': bradley_terry_results,
        'confidence': confidence_results,
        'uncertainty': uncertainty_results,
        'most_draws': most_draws_results,
    }

    def identification_prob(results):
        # Per budget step: fraction of realizations whose selected model is `best_model`.
        return (results[:, :, 0] == best_model[:, None]).mean(axis=0)

    # Plot identification probability
    all_llm_selector_identification_prob = {
        eps: identification_prob(results) for eps, results in all_llm_selector_results.items()
    }
    for name, results in baselines.items():
        plt.plot(identification_prob(results), label=name)
    for eps, prob in all_llm_selector_identification_prob.items():
        plt.plot(prob, label=f'llm_selector (eps={eps})')
    plt.ylim((-0.01, 1.01))
    plt.xlabel('budget')
    plt.ylabel('identification prob')
    plt.legend()
    plt.savefig(f"{data_path}/plot.png")
    # plt.show()

    # Reference win rate: the random baseline's pick at the final budget step.
    best_model_wr = random_results[:, -1, 1]  # win rate of the best model

    def near_best_prob(results, tolerance):
        # Per budget step: fraction of realizations whose selected model's win rate is
        # within `tolerance` of the reference (the extra 1e-4 absorbs float round-off).
        return (results[:, :, 1] >= best_model_wr[:, None] - tolerance - 1e-4).mean(axis=0)

    # Output label efficiency
    for d in [0., 0.01, 0.025, 0.05]:
        print(f"\n##### LABEL EFFICIENCY (delta={d}) #####")
        print_label_efficiency(
            near_best_prob(random_results, d),
            near_best_prob(bradley_terry_results, d),
            near_best_prob(confidence_results, d),
            near_best_prob(uncertainty_results, d),
            near_best_prob(most_draws_results, d),
            {eps: near_best_prob(results, d) for eps, results in all_llm_selector_results.items()},
        )

    # Output accuracy gap for the first LLM Selector configuration only; skip it when
    # there is none, so the results below are still saved.
    if all_llm_selector_results:
        acc_gap_eps = next(iter(all_llm_selector_results))
        print_accuracy_gap(all_llm_selector_identification_prob[acc_gap_eps], best_model_wr, random_results, bradley_terry_results, confidence_results, uncertainty_results, most_draws_results, all_llm_selector_results[acc_gap_eps])

    # Save results (LLM Selector entries first, then the baselines).
    results_out = {
        f'llm_selector (eps={eps})': results.tolist() for eps, results in all_llm_selector_results.items()
    }
    for name, results in baselines.items():
        results_out[name] = results.tolist()

    # Context manager guarantees the file is flushed and closed.
    with open(f'{data_path}/results.json', 'w') as fp:
        json.dump(results_out, fp)


def find_index_where_all_after_greater_than(arr, threshold=0.9):
    """Return the budget index at which ``arr`` settles at (or just below) ``threshold``.

    Scans every position except the last for the first position ``i`` whose value
    is at least ``threshold - 0.01`` and from which every remaining value stays at
    least ``threshold - 0.02``. Returns ``i + 1`` for that position, or
    ``len(arr) - 1`` when no position qualifies.

    NOTE(review): the ``+1`` offset and the fallback value define the "budget"
    convention used by callers; keep them unchanged unless all callers are updated.
    """
    values = np.array(arr)
    fallback = len(values) - 1
    start_floor = threshold - 0.01   # the first settled value must reach this
    tail_floor = threshold - 0.02    # and nothing from there on may drop below this
    for pos in range(fallback):
        if values[pos] >= start_floor and (values[pos:] >= tail_floor).all():
            return pos + 1
    return fallback


def print_label_efficiency(random_identification_prob, bradley_terry_identification_prob, confidence_identification_prob, uncertainty_identification_prob, most_draws_identification_prob, all_llm_selector_identification_prob, threshold=1):
    """Print each LLM Selector's relative label saving over the best baseline.

    For every curve, the budget index at which it settles at ``threshold`` is
    found with ``find_index_where_all_after_greater_than``. The saving is
    ``(best_baseline_idx - selector_idx) / max(best_baseline_idx, selector_idx)``,
    printed as a percentage. Negative values mean the selector needed more labels.

    Args:
        *_identification_prob: per-budget-step probability curves for each baseline.
        all_llm_selector_identification_prob: dict mapping eps -> probability curve.
        threshold: confidence level the curves must reach (default 1, as before).
    """
    print()
    print("##### LABEL EFFICIENCY #####")

    baseline_curves = (
        random_identification_prob,
        bradley_terry_identification_prob,
        confidence_identification_prob,
        uncertainty_identification_prob,
        most_draws_identification_prob,
    )
    # The strongest baseline is the one that settles earliest.
    min_baseline_idx = min(find_index_where_all_after_greater_than(curve, threshold) for curve in baseline_curves)

    for eps, results in all_llm_selector_identification_prob.items():
        idx = find_index_where_all_after_greater_than(results, threshold)
        denom = max(min_baseline_idx, idx)
        # Both indices are 0 for one-step curves; report no saving instead of dividing by zero.
        saving = 100 * ((min_baseline_idx - idx) / denom) if denom else 0.0
        print(f"\tLLM Selector (eps={eps}): {saving:.2f}%")
    print()


def print_accuracy_gap(llm_selector_identification_prob, best_model_wr, random_results, bradley_terry_results, confidence_results, uncertainty_results, most_draws_results, llm_selector_results):
    """Print the 95th-percentile win-rate gap of every method at fixed confidence budgets.

    For each confidence level, the budget index is the point at which the LLM
    Selector's identification curve settles at that level. At that index, each
    method's gap is ``best_model_wr`` minus the win rate of its selected model,
    taken at the 95th percentile (nearest rank) across realizations.
    """
    print()
    print("##### 95-TH PERCENTILE WIN RATE GAP #####")

    # Printed in this order under every confidence level.
    labelled_results = (
        ("LLM Selector", llm_selector_results),
        ("Random", random_results),
        ("Bradley-Terry", bradley_terry_results),
        ("Confidence", confidence_results),
        ("Uncertainty", uncertainty_results),
        ("Most Draws", most_draws_results),
    )

    for level in (0.7, 0.8, 0.9, 1.0):
        budget_idx = find_index_where_all_after_greater_than(llm_selector_identification_prob, level)
        print(f"\tConfidence = {level} (Budget = {budget_idx + 1}):")

        for label, results in labelled_results:
            gap = np.percentile(best_model_wr - results[:, budget_idx, 1], 95, method='nearest')
            print(f"\t\t{label}: {gap:.3f}")
        print()

if __name__ == "__main__":
    # Sizes shared by every selection method.
    base_kwargs = dict(num_iterations=num_iterations, num_test_samples=num_test_samples, num_models=num_models)

    # The random baseline runs first because it also returns `best_model`,
    # which every other method receives.
    random_results, best_model = all_methods["random"](**base_kwargs)

    judged_kwargs = dict(base_kwargs, num_metrics=num_metrics, best_model=best_model)

    all_llm_selector_results = all_methods["llm_selector"](**judged_kwargs, all_epsilons=all_epsilons,
                                                           use_pseudo_annotations=use_pseudo_annotations)
    bradley_terry_results = all_methods["bradley_terry"](**judged_kwargs)
    confidence_results = all_methods["confidence"](**judged_kwargs)
    uncertainty_results = all_methods["uncertainty"](**judged_kwargs)
    most_draws_results = all_methods["most_draws"](**judged_kwargs)

    plot_results(best_model, random_results, bradley_terry_results, confidence_results,
                 uncertainty_results, most_draws_results, all_llm_selector_results)