from tqdm import tqdm
import llm_blender
import torch
from datasets import load_dataset, concatenate_datasets, load_from_disk
import numpy as np
from transformers import AutoTokenizer
import datasets
import argparse
import os
import json

# Parse the command line before loading the ranker so that ``--help`` and
# argument errors return immediately instead of waiting on the model load.
parser = argparse.ArgumentParser(
    description=(
        "Compute the PairRM win rate of one set of model outputs "
        "against another."
    ),
)

# Both options are left optional at the parser level so that the values
# default to None exactly as before when they are not supplied.
parser.add_argument(
    "--output_1",
    type=str,
    help="Directory name under ./alpaca/ containing the first data.json.",
)
parser.add_argument(
    "--output_2",
    type=str,
    help="Directory name under ./alpaca/ containing the second data.json.",
)

args = parser.parse_args()

op1 = args.output_1
op2 = args.output_2

# Pairwise ranker used by rank_responses to compare the two candidates.
blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")


@torch.inference_mode()
def rank_responses(prompts, response1, response2):
    """Return the fraction of prompts on which ``response1`` is ranked best.

    For every prompt, the responses at the same index in ``response1`` and
    ``response2`` are passed as a two-candidate list to the module-level
    ``blender`` ranker (``blender.rank``).

    Args:
        prompts: Sequence of prompts; its length defines how many pairs are
            compared.
        response1: Candidate responses whose win rate is reported. Must hold
            at least ``len(prompts)`` items; extra items are ignored.
        response2: Opposing candidate responses, same length requirement.

    Returns:
        ``1 - (number of prompts where candidate index 1 is chosen) / n``,
        i.e. the share of prompts where ``response1`` is preferred.

    Raises:
        ValueError: If ``prompts`` is empty or either response sequence has
            fewer entries than ``prompts``.
    """
    ds_size = len(prompts)
    if ds_size == 0:
        # Fail before invoking the ranker; a win rate over zero prompts is
        # undefined.
        raise ValueError("prompts must not be empty")
    if len(response1) < ds_size or len(response2) < ds_size:
        raise ValueError(
            "response1 and response2 must each contain at least "
            f"{ds_size} entries (got {len(response1)} and {len(response2)})"
        )

    # zip stops at the shortest input, which is ``prompts`` after the checks
    # above, so surplus responses are ignored.
    candidates_texts = [
        [first, second]
        for _, first, second in zip(prompts, response1, response2)
    ]
    rank = blender.rank(prompts, candidates_texts, return_scores=False)

    # The candidate with the smallest rank value is treated as the preferred
    # one (assumes rank 1 = best, per llm_blender -- confirm against its docs).
    chosen_indices = np.argmin(rank, axis=1)

    # chosen index is 0 when response1 wins and 1 when response2 wins, so the
    # sum counts response2 wins.
    winrate = 1 - np.sum(chosen_indices) / ds_size

    return winrate


def load_results(run_name, root="./alpaca/"):
    """Load the parsed JSON content of ``<root><run_name>/data.json``.

    Args:
        run_name: Directory name appended verbatim to ``root``.
        root: Path prefix; it is concatenated with ``run_name`` as a plain
            string, so it should end with a path separator.

    Returns:
        Whatever ``json.load`` produces for the file.
    """
    file_path = os.path.join(root + run_name, "data.json")
    with open(file_path, "r", encoding='utf-8') as json_file:
        return json.load(json_file)


if __name__ == "__main__":

    # The options are optional at the parser level; report a usage error
    # instead of failing later with a TypeError on string concatenation.
    if op1 is None or op2 is None:
        parser.error("both --output_1 and --output_2 are required")

    result_1 = load_results(op1)
    result_2 = load_results(op2)

    # Prompts are taken from the first file; rows of the two files are paired
    # by position.
    prompts = [row["instruction"] for row in result_1]
    response1 = [row["output"] for row in result_1]
    response2 = [row["output"] for row in result_2]
    winrate = rank_responses(prompts, response1, response2)
    print(f"win rate of {op1} against {op2}:", winrate)
