from tqdm import tqdm
import llm_blender
import torch
from datasets import load_dataset, concatenate_datasets
import numpy as np
from transformers import AutoTokenizer
import datasets
import argparse

parser = argparse.ArgumentParser(
    description="Re-rank resp0/resp1 with PairRM and rewrite chosen/rejected "
                "for one contiguous part of a preference dataset.")

parser.add_argument("--dataset", type=str, required=True,
                    help="dataset path or hub id passed to load_dataset")
parser.add_argument("--output", type=str, required=True,
                    help="output prefix; part k is saved as <output>_mini_<k>")
parser.add_argument("--part", type=int, required=True,
                    help="0-based index of the part to process")
parser.add_argument("--total", type=int, required=True,
                    help="total number of parts the dataset is split into")

# Parse and validate before loading the ranker, so that --help or bad
# arguments fail immediately instead of after the PairRM model load.
args = parser.parse_args()
if args.total < 1:
    parser.error("--total must be >= 1")
if not 0 <= args.part < args.total:
    parser.error("--part must satisfy 0 <= part < total")

dataset_dir = args.dataset
output_dir = args.output
n_part = args.part
total_part = args.total

# Pairwise ranker used by rank_responses(); loaded once at module level.
blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")


def rank_responses(dataset):
    """Re-label each row's chosen/rejected replies using the PairRM ranker.

    For every row, the two candidate responses in ``resp0`` and ``resp1`` are
    ranked against ``prompt`` by the module-level ``blender``. The best-ranked
    candidate becomes the new ``chosen`` assistant turn and the worst-ranked
    one becomes the new ``rejected`` assistant turn. The first message of each
    existing ``chosen`` / ``rejected`` conversation is kept as the leading turn.

    Args:
        dataset: a ``datasets.Dataset`` with ``prompt``, ``resp0``, ``resp1``,
            ``chosen`` and ``rejected`` columns. ``chosen``/``rejected`` are
            presumably lists of ``{"content", "role"}`` message dicts; confirm
            against the source dataset.

    Returns:
        A new dataset whose ``chosen`` and ``rejected`` columns are replaced by
        two-message conversations ``[original first message, ranked reply]``.
        Because they are re-added, these columns move to the end of the schema.
    """
    prompts = dataset["prompt"]
    # Candidate pool per row: [resp0, resp1].
    candidates_texts = [
        [opt, rand] for opt, rand in zip(dataset["resp0"], dataset["resp1"])
    ]

    # Only the model call needs autograd disabled.
    with torch.inference_mode():
        rank = blender.rank(prompts, candidates_texts, return_scores=False)

    # This code treats a lower rank value as a better candidate. A stable
    # argsort puts the best candidate first and the worst last. Unlike
    # argmin/argmax, it never returns the same index for both when ranks tie.
    order = np.argsort(np.asarray(rank), axis=1, kind="stable")
    chosen_indices = order[:, 0]
    rejected_indices = order[:, -1]

    # Index the Python lists directly. np.array(candidates_texts) would build a
    # fixed-width unicode matrix sized by the longest response, and it strips
    # trailing NUL characters.
    chosen_texts = [
        cands[i] for cands, i in zip(candidates_texts, chosen_indices)
    ]
    rejected_texts = [
        cands[i] for cands, i in zip(candidates_texts, rejected_indices)
    ]

    # Keep the first message of each original conversation and pair it with
    # the newly ranked reply. Per-row indexing also works when conversations
    # have different lengths, where np.array(...)[:, 0] fails.
    # NOTE(review): the original code noted "-1 for Gemma"; confirm which
    # message index the target chat template needs.
    new_chosen = [
        [conv[0], {"content": text, "role": "assistant"}]
        for conv, text in zip(dataset["chosen"], chosen_texts)
    ]
    new_rejected = [
        [conv[0], {"content": text, "role": "assistant"}]
        for conv, text in zip(dataset["rejected"], rejected_texts)
    ]

    dataset = dataset.remove_columns(["chosen", "rejected"])
    dataset = dataset.add_column("chosen", new_chosen)
    dataset = dataset.add_column("rejected", new_rejected)
    return dataset


if __name__ == "__main__":

    # force_redownload fetches a fresh copy instead of reusing the local cache.
    # verification_mode="no_checks" replaces the deprecated
    # ignore_verifications=True, which newer `datasets` releases reject.
    dataset_opt = load_dataset(dataset_dir, split="train_prefs",
                               download_mode="force_redownload",
                               verification_mode="no_checks")

    # Contiguous slice [start, end) for this part. The last part also takes
    # the rows left over by the integer division.
    interval = len(dataset_opt) // total_part
    start = interval * n_part
    end = (interval * (n_part + 1) if n_part != total_part - 1
           else len(dataset_opt))
    dataset_opt = dataset_opt.select(range(start, end))

    new_dataset = rank_responses(dataset_opt)

    # The raw candidates are now folded into chosen/rejected; drop them.
    new_dataset = new_dataset.remove_columns(["resp0", "resp1"])

    # Use output_dir as given. The old "./" prefix turned absolute paths such
    # as "/data/out" into the relative ".//data/out".
    new_dataset.save_to_disk(f"{output_dir}_mini_{n_part}")
