from tqdm import tqdm
import llm_blender
import torch
from datasets import load_dataset, concatenate_datasets
import numpy as np
from transformers import AutoTokenizer
import datasets
import argparse

# PairRM ranker, loaded once at import time; rank_responses() reads this global.
blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")


parser = argparse.ArgumentParser(
    description="Re-rank chosen/minpi/random responses with PairRM for one shard "
                "of a dataset's train_prefs split and push the result to the Hub.")

parser.add_argument("--dataset", type=str, required=True,
                    help="Dataset to load (train_prefs split) and re-rank.")
# Default is the literal string "None", which the main block treats as "skip".
parser.add_argument("--dataset_left", type=str, default="None",
                    help="Optional dataset whose train_prefs split is appended to "
                         "the output; pass 'None' (default) to skip.")
parser.add_argument("--output", type=str, required=True,
                    help="Hub repo id prefix; the result is pushed to "
                         "'<output>_mini_<part>'.")
# The main block splits the data into 4 contiguous shards, so only 0-3 are valid.
parser.add_argument("--part", type=int, required=True, choices=range(4),
                    help="Which of the 4 contiguous shards to process (0-3).")

args = parser.parse_args()

dataset_dir = args.dataset
output_dir = args.output
dataset_left = args.dataset_left
n_part = args.part


@torch.no_grad()
def rank_responses(dataset):
    """Re-pick each row's chosen/rejected reply by ranking candidates with PairRM.

    Three candidates are ranked per row by the module-level ``blender``
    (loaded with ``llm-blender/PairRM``):

    - the ``content`` of the second message of the row's ``chosen``
      conversation;
    - the ``minpi`` response;
    - the ``random`` response.

    The candidate with the lowest rank value becomes the new chosen reply and
    the one with the highest rank value becomes the new rejected reply. The
    code treats a lower rank as better.

    Args:
        dataset: Dataset with ``prompt``, ``chosen``, ``rejected``, ``minpi``
            and ``random`` columns. ``chosen``/``rejected`` are lists of
            message dicts, presumably ``[user_msg, assistant_msg]``. TODO
            confirm this; an older "-1 for Gemma" note hints at other layouts.

    Returns:
        The dataset with ``chosen`` and ``rejected`` replaced by two-message
        lists. Each list holds the first message of the original conversation
        followed by ``{"content": <selected text>, "role": "assistant"}``.
        An empty dataset is returned unchanged.
    """
    if len(dataset) == 0:
        # Nothing to rank; np.argmin(..., axis=1) would fail on an empty result.
        return dataset

    prompts = dataset["prompt"]
    chosen_convs = dataset["chosen"]
    rejected_convs = dataset["rejected"]

    # Candidate order per row: [existing chosen reply, minpi, random].
    candidates_texts = [
        [conv[1]["content"], opt, rnd]
        for conv, opt, rnd in zip(chosen_convs, dataset["minpi"], dataset["random"])
    ]

    with torch.inference_mode():
        rank = np.asarray(
            blender.rank(prompts, candidates_texts, return_scores=False))

    chosen_indices = np.argmin(rank, axis=1)
    rejected_indices = np.argmax(rank, axis=1)

    # Index the Python lists directly instead of converting to NumPy. A NumPy
    # unicode array pads every string to the longest response (4 bytes/char)
    # and strips trailing NUL characters. np.array() over the conversations
    # also fails when they have different lengths.
    update_chosen_column = [
        [conv[0], {"content": cands[idx], "role": "assistant"}]
        for conv, cands, idx in zip(chosen_convs, candidates_texts, chosen_indices)
    ]
    update_reject_column = [
        [conv[0], {"content": cands[idx], "role": "assistant"}]
        for conv, cands, idx in zip(rejected_convs, candidates_texts, rejected_indices)
    ]

    dataset = dataset.remove_columns(["chosen", "rejected"])
    dataset = dataset.add_column("chosen", update_chosen_column)
    dataset = dataset.add_column("rejected", update_reject_column)
    return dataset


if __name__ == "__main__":
    # train_prefs is processed in NUM_PARTS contiguous shards. This run handles
    # shard `n_part` and pushes it as "<output>_mini_<n_part>".
    NUM_PARTS = 4

    # Validate before the forced re-download so a bad --part fails fast.
    if n_part is None or not 0 <= n_part < NUM_PARTS:
        parser.error(
            f"--part must be an integer in [0, {NUM_PARTS - 1}], got {n_part!r}")

    # NOTE(review): `ignore_verifications` is deprecated in favor of
    # verification_mode="no_checks" (datasets 2.9.1) and removed in datasets
    # 3.0. Switch once the pinned datasets version allows it.
    dataset_opt = load_dataset(dataset_dir, split="train_prefs",
                               download_mode="force_redownload", ignore_verifications=True)

    interval = len(dataset_opt) // NUM_PARTS
    start = interval * n_part
    # The last shard also takes the len % NUM_PARTS leftover rows.
    end = len(dataset_opt) if n_part == NUM_PARTS - 1 else interval * (n_part + 1)
    dataset_opt = dataset_opt.select(range(start, end))

    new_dataset = rank_responses(dataset_opt)

    # Drop the raw candidate columns; only the re-ranked chosen/rejected pairs remain.
    new_dataset = new_dataset.remove_columns(
        ["resp0", "resp1", "resp2", "resp3", "minpi", "random"])

    # A missing --dataset_left (None) and the literal string "None" both mean
    # "nothing to append".
    if dataset_left is not None and dataset_left != "None":
        dataset_rest = load_dataset(dataset_left, split="train_prefs",
                                    download_mode="force_redownload", ignore_verifications=True)
        new_dataset = concatenate_datasets([new_dataset, dataset_rest])

    new_dataset.push_to_hub(
        output_dir + f"_mini_{n_part}", split="train_prefs", private=False)
