from datasets import load_dataset, DatasetDict, concatenate_datasets, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import torch
import os
import numpy as np
import random
import time
import argparse

####################
# Command-line setup
####################
# Arguments are parsed BEFORE the tokenizer is loaded so that `--help` and
# argument errors exit right away instead of waiting on the tokenizer load.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--dataset", type=str,
    help="Directory of the input dataset, read with datasets.load_from_disk.")
parser.add_argument(
    "--output", type=str,
    help="Directory the processed dataset is saved to, "
         "relative to the current working directory.")
parser.add_argument(
    "--impsam", action='store_true',
    help="Score each candidate with logpiref - logpi instead of logpiref alone.")
parser.add_argument(
    "--selectmax", action='store_true',
    help="Pair the lowest-scoring candidate with the highest-scoring one "
         "instead of with a randomly chosen other candidate.")
parser.add_argument(
    "--generation", type=int, default=4,
    help="Number of candidate responses (resp0, resp1, ...) stored per example.")

args = parser.parse_args()

# Module-level settings read by select_response and the __main__ section.
dataset_dir = args.dataset
output_dir = args.output
impsam = args.impsam
selectmax = args.selectmax

n_generation = args.generation

tokenizer = AutoTokenizer.from_pretrained(
    "alignment-handbook/zephyr-7b-sft-full")


def select_min_and_max_elements(input_list, probs, num_return_sequences):
    """Return the lowest- and highest-scoring element of every group.

    ``input_list`` and ``probs`` are flat, parallel sequences in which each
    consecutive run of ``num_return_sequences`` entries forms one group.

    Args:
        input_list: Flat sequence of candidates, grouped consecutively.
        probs: Flat sequence of scores, one per candidate, in the same order.
        num_return_sequences: Number of candidates in each group.

    Returns:
        A ``(min_elements, max_elements)`` tuple of lists with one entry per
        group. Ties are resolved in favour of the earliest candidate, as done
        by ``np.argmin`` / ``np.argmax``.

    Raises:
        ValueError: If a sequence length is not a multiple of
            ``num_return_sequences`` (raised by the reshape).
    """
    answers = np.array(input_list).reshape(-1, num_return_sequences)
    scores = np.array(probs).reshape(-1, num_return_sequences)

    min_elements = []
    max_elements = []
    for row, row_answers in enumerate(answers):
        row_scores = scores[row]
        min_elements.append(row_answers[np.argmin(row_scores)])
        max_elements.append(row_answers[np.argmax(row_scores)])

    return min_elements, max_elements


def select_min_and_random_elements(input_list, probs, num_return_sequences):
    """Return each group's lowest-scoring element and a random other element.

    ``input_list`` and ``probs`` are flat, parallel sequences in which each
    consecutive run of ``num_return_sequences`` entries forms one group.

    Args:
        input_list: Flat sequence of candidates, grouped consecutively.
        probs: Flat sequence of scores, one per candidate, in the same order.
        num_return_sequences: Number of candidates in each group.

    Returns:
        A ``(min_elements, random_elements)`` tuple of lists with one entry
        per group. The random element is drawn with ``random.choice`` from
        the positions other than the minimum, so it is never the same
        position as the minimum.

    Raises:
        ValueError: If a sequence length is not a multiple of
            ``num_return_sequences`` (raised by the reshape).
        IndexError: If a group holds a single candidate, because nothing is
            left to draw from once the minimum is excluded.
    """
    answers = np.array(input_list).reshape(-1, num_return_sequences)
    scores = np.array(probs).reshape(-1, num_return_sequences)

    min_elements = []
    random_elements = []
    for row, row_answers in enumerate(answers):
        min_index = np.argmin(scores[row])
        min_elements.append(row_answers[min_index])

        # Candidates stay in ascending position order so that a seeded
        # `random` module yields the same draws from run to run.
        candidates = [
            idx for idx in range(len(row_answers)) if idx != min_index]
        random_elements.append(row_answers[random.choice(candidates)])

    return min_elements, random_elements


def select_response(batch):
    """Pick one pair of candidate responses for every example in a batch.

    Each example carries ``n_generation`` candidates in the columns
    ``resp0`` .. ``resp{n_generation - 1}``. Candidate ``j`` is scored with
    ``logpiref{j}``, or with ``logpiref{j} - logpi{j}`` when ``impsam`` is
    set. The lowest-scoring candidate is written to ``resp0``. ``resp1``
    receives the highest-scoring candidate when ``selectmax`` is set, and
    otherwise a randomly chosen candidate other than the lowest-scoring one.

    Args:
        batch: Column-oriented batch (column name -> list of values), as
            handed over by ``Dataset.map(..., batched=True)``. The ``prompt``
            column is only used to determine the batch size.

    Returns:
        The same ``batch`` object with ``resp0`` and ``resp1`` replaced by
        the selected responses. No column is dropped here: any further
        ``resp*`` column and all ``logpi*`` / ``logpiref*`` columns are left
        unchanged.
    """
    batch_size = len(batch["prompt"])

    # Flatten to example-major order (all candidates of example 0, then of
    # example 1, ...), which is the grouping the selection helpers expect.
    responses = []
    batch_logps = []
    for i in range(batch_size):
        for j in range(n_generation):
            responses.append(batch[f"resp{j}"][i])
            score = batch[f"logpiref{j}"][i]
            if impsam:
                score = score - batch[f"logpi{j}"][i]
            batch_logps.append(score)

    if selectmax:
        selector = select_min_and_max_elements
    else:
        selector = select_min_and_random_elements
    response_a, response_b = selector(responses, batch_logps, n_generation)

    batch["resp0"] = response_a
    batch["resp1"] = response_b

    return batch


if __name__ == "__main__":
    # Fail fast: without this check a missing --output would only surface
    # when building the save path, after the whole dataset has been mapped.
    if dataset_dir is None or output_dir is None:
        parser.error("--dataset and --output are both required")

    # The candidate dataset is read from local disk.
    train_dataset = load_from_disk(dataset_dir)

    new_train_dataset = train_dataset.map(
        select_response, batched=True, batch_size=8)

    # The "./" prefix makes the output location relative to the current
    # working directory.
    new_train_dataset.save_to_disk("./" + output_dir)
