from alignment import H4ArgumentParser, ModelArguments, DataArguments, RDPOConfig, get_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset, DatasetDict, concatenate_datasets, load_from_disk
from multiprocess import set_start_method
import torch
from trl.trainer.utils import pad_to_length
from huggingface_hub import Repository
import os
import numpy as np
import random
import time
import argparse

##########################################################
# Command-line arguments, reference model and tokenizer
##########################################################
# With the "spawn" start method used below, worker processes re-import this
# module, so this section presumably runs again in every worker (each worker
# then holds its own copy of the reference model).

# Hub id of the reference (SFT) model; used for both model and tokenizer.
REF_MODEL_NAME = "alignment-handbook/zephyr-7b-sft-full"

parser = argparse.ArgumentParser(
    description="Add reference-model log-probabilities (logpiref{i} columns) "
                "to a dataset saved with datasets.save_to_disk.")
parser.add_argument("--dataset", type=str, required=True,
                    help="Path of the input dataset (read with datasets.load_from_disk).")
parser.add_argument("--output", type=str, required=True,
                    help="Output directory for the scored dataset.")
parser.add_argument("--generation", type=int, default=4,
                    help="Number of candidate responses per row (columns resp0 .. resp{N-1}).")

args = parser.parse_args()
# Fail at parse time rather than deep inside dataset.map().
if args.generation < 1:
    parser.error("--generation must be a positive integer")

dataset_dir = args.dataset
output_dir = args.output
n_generation = args.generation

ref_model = AutoModelForCausalLM.from_pretrained(
    REF_MODEL_NAME, torch_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(REF_MODEL_NAME)

is_encoder_decoder = ref_model.config.is_encoder_decoder


def get_pi(input_list, probs, num_return_sequences):
    """Pick, per group, the lowest-scoring item and a random other item.

    ``input_list`` and ``probs`` are flat sequences laid out as consecutive
    groups of ``num_return_sequences`` aligned entries. For every group this
    selects the entry with the smallest score (the first one on ties, as
    ``np.argmin`` does) and one entry drawn uniformly at random from the other
    positions of the same group.

    Args:
        input_list: Flat sequence of candidates; its length must be a
            multiple of ``num_return_sequences``.
        probs: Flat sequence of scores aligned element-wise with
            ``input_list``.
        num_return_sequences: Group size. Must be at least 2 so that a
            "random other" entry exists.

    Returns:
        A pair ``(min_elements, random_elements)`` of lists with one entry per
        group. Entries are taken from a NumPy array, so strings come back as
        ``numpy.str_``.

    Raises:
        ValueError: If ``num_return_sequences`` < 2, if either input cannot be
            reshaped into groups of that size, or if the two inputs describe a
            different number of groups.
    """
    if num_return_sequences < 2:
        raise ValueError(
            "get_pi needs at least 2 candidates per group to pick a random "
            f"non-minimal one; got num_return_sequences={num_return_sequences}.")

    input_array = np.array(input_list).reshape(-1, num_return_sequences)
    probs_array = np.array(probs).reshape(-1, num_return_sequences)
    if input_array.shape != probs_array.shape:
        raise ValueError(
            f"input_list has {input_array.shape[0]} groups but probs has "
            f"{probs_array.shape[0]}.")

    min_elements = []
    random_elements = []
    for row_answers, row_probs in zip(input_array, probs_array):
        min_index = int(np.argmin(row_probs))
        min_elements.append(row_answers[min_index])

        # All positions except the minimum, in ascending order, so the random
        # draw sees the same candidate list as a remove()-based construction.
        other_indices = [k for k in range(num_return_sequences) if k != min_index]
        random_elements.append(row_answers[random.choice(other_indices)])

    return min_elements, random_elements


def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = tokenizer.pad_token_id,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens with a value of
            label_pad_token_id are ignored. Shape: (batch_size, sequence_length). Not modified.
        average_log_prob: If True, return the average log probability per (non-masked) token.
            Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
        label_pad_token_id: The label pad token id. Note that the default is read from the
            module-level tokenizer when this function is defined. Every label equal to this id is
            masked, including real tokens that happen to share the id.
        is_encoder_decoder: Whether the model is an encoder-decoder model. If False, labels are
            shifted left by one so that position t is scored by the logits at position t-1.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the
        given labels under the given logits. With average_log_prob=True, a row whose labels are
        all padding yields nan (0 / 0).

    Raises:
        ValueError: If the batch/sequence dimensions of logits and labels differ.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError(
            "Logits (batch and sequence length dim) and labels must have the same shape.")

    if not is_encoder_decoder:
        # Causal LM: the token at position t is predicted by the logits at t-1.
        labels = labels[:, 1:]
        logits = logits[:, :-1, :]

    # Always work on a private copy: the masked write below must never reach
    # the caller's tensor (the encoder-decoder path used to modify it in place).
    labels = labels.clone()
    loss_mask = labels != label_pad_token_id

    # Dummy in-vocabulary index so gather() is valid; masked out below.
    labels[labels == label_pad_token_id] = 0

    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    masked_sum = (per_token_logps * loss_mask).sum(-1)
    if average_log_prob:
        return masked_sum / loss_mask.sum(-1)
    return masked_sum


def gpu_computation(batch, rank):
    """Score every candidate response in ``batch`` under the reference model.

    Used as a ``datasets.Dataset.map`` function with ``batched=True`` and
    ``with_rank=True``. For each row it:

    1. rebuilds the chat prompt from ``batch["chosen"]``: every message except
       the last, with an empty system message prepended when the conversation
       does not start with one;
    2. appends each response ``resp0`` .. ``resp{n_generation-1}``;
    3. runs ``ref_model`` on the concatenations;
    4. stores the per-token average log-probability of response ``i`` in a
       new column ``logpiref{i}``.

    NOTE(review): the labels are the full ``input_ids`` with only padding
    masked, so the average also covers the prompt tokens, not just the
    response tokens. Confirm this is intended before treating the values as
    log pi_ref(y | x).

    Args:
        batch: Dict of columns; must contain "chosen", "prompt" and
            "resp0" .. "resp{n_generation-1}".
        rank: Worker rank supplied by ``datasets`` (may be None). It selects
            the CUDA device ``rank % device_count``; without CUDA the CPU is
            used.

    Returns:
        ``batch`` with the ``logpiref{i}`` columns added.
    """
    prompts = []
    for conversation in batch["chosen"]:
        # Everything except the final message (presumably the chosen reply)
        # forms the prompt; slicing copies, so the dataset row is untouched.
        prompt_messages = conversation[:-1]
        if conversation[0]["role"] != "system":
            prompt_messages = [{"role": "system", "content": ""}] + prompt_messages
        prompts.append(tokenizer.apply_chat_template(
            prompt_messages, tokenize=False))

    # Row-major order (row i, then responses 0..n_generation-1) so that the
    # reshape(-1, n_generation) below puts response j of row i at [i, j].
    batch_size = len(batch["prompt"])
    inputs = [
        prompts[i] + batch[f"resp{j}"][i]
        for i in range(batch_size)
        for j in range(n_generation)
    ]

    if torch.cuda.is_available():
        device = f"cuda:{(rank or 0) % torch.cuda.device_count()}"
    else:
        # No GPU: fall back to CPU instead of dividing by a zero device count.
        device = "cpu"
    ref_model.to(device)

    tokens = tokenizer(inputs,
                       padding=True,
                       max_length=1024,
                       truncation=True,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        # Take only the logits and skip the KV cache, so the model output
        # object and its cached key/value tensors are not kept alive.
        logits = ref_model(input_ids=tokens["input_ids"],
                           attention_mask=tokens["attention_mask"],
                           use_cache=False).logits
        batch_logps = get_batch_logps(logits=logits,
                                      labels=tokens["input_ids"],
                                      average_log_prob=True,
                                      label_pad_token_id=tokenizer.pad_token_id,
                                      is_encoder_decoder=is_encoder_decoder,
                                      )

    # The model runs in bfloat16, and NumPy has no bfloat16 dtype, so
    # converting a bf16 tensor directly fails. Upcast first; this is a no-op
    # for float32.
    logps = batch_logps.float().cpu().numpy().reshape(-1, n_generation)
    for i in range(n_generation):
        batch[f"logpiref{i}"] = logps[:, i]

    # Drop every reference to the large device tensors before asking the
    # caching allocator to release unused blocks.
    del tokens, logits, batch_logps
    torch.cuda.empty_cache()

    return batch


if __name__ == "__main__":
    # Workers touch CUDA, which cannot be re-initialised in forked children,
    # so they must be started with "spawn".
    set_start_method("spawn")

    train_dataset = load_from_disk(dataset_dir)

    new_train_dataset = train_dataset.map(
        gpu_computation,
        batched=True,
        batch_size=8,
        with_rank=True,
        # One worker per GPU; at least one so map() still gets a valid
        # process count on a machine without CUDA devices.
        num_proc=max(1, torch.cuda.device_count()),
    )

    # Save exactly where --output points. Prefixing "./" turned an absolute
    # path like /data/out into the relative path .//data/out.
    new_train_dataset.save_to_disk(output_dir)
