from alignment import H4ArgumentParser, ModelArguments, DataArguments, RDPOConfig, get_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset, DatasetDict, concatenate_datasets
from multiprocess import set_start_method
import torch
from trl.trainer.utils import pad_to_length
from huggingface_hub import Repository
import os
import numpy as np
import random
import argparse
import llm_blender


#################################################
# Parse CLI arguments, load models and tokenizer
#################################################
# NOTE: this is kept at module level on purpose. `gpu_computation` reads
# `model`, `ref_model`, `tokenizer`, `n_generation` and `is_encoder_decoder`
# as module globals, and the entry point uses the "spawn" start method, so
# worker processes presumably re-import this module to obtain them.
parser = argparse.ArgumentParser(
    description="Sample several responses per prompt and pair them using "
                "reference-model log-probabilities.")
parser.add_argument("--dataset", type=str, required=True,
                    help="Dataset passed to load_dataset(); its "
                         "'train_prefs' split is read.")
parser.add_argument("--model", type=str, required=True,
                    help="Model (and tokenizer) used to generate responses.")
parser.add_argument("--ref_model", type=str, required=True,
                    help="Reference model used to score the generated "
                         "responses.")
parser.add_argument("--output", type=str, required=True,
                    help="Repository id the resulting dataset is pushed to.")
parser.add_argument("--ratio", type=float, required=True,
                    help="Fraction (0 < ratio <= 1) of the leading rows of "
                         "the split to process.")
parser.add_argument("--generation_num", type=int, required=True,
                    help="Number of sampled responses per prompt (>= 2).")

args = parser.parse_args()

# Fail before any model is loaded instead of after the GPU work is done.
if not 0 < args.ratio <= 1:
    parser.error("--ratio must be in the interval (0, 1]")
if args.generation_num < 2:
    parser.error("--generation_num must be at least 2 so that a second, "
                 "different response can be picked for each prompt")

dataset_dir = args.dataset
model_dir = args.model
ref_model_dir = args.ref_model
ratio = args.ratio
output_dir = args.output
n_generation = args.generation_num

model = AutoModelForCausalLM.from_pretrained(
    model_dir, torch_dtype=torch.bfloat16)
ref_model = AutoModelForCausalLM.from_pretrained(
    ref_model_dir, torch_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_dir)
# The pad token id is used for batched tokenisation, as `pad_token_id` during
# generation and to mask positions when scoring; fall back to EOS when the
# tokenizer does not define one.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): batched generation with a decoder-only model normally expects
# tokenizer.padding_side == "left" -- confirm the tokenizer config.

is_encoder_decoder = model.config.is_encoder_decoder


def select_min_and_random_elements(input_list, probs, num_return_sequences):
    """Pick the lowest-scoring candidate and one other random candidate per group.

    ``input_list`` and ``probs`` are flat and parallel; consecutive runs of
    ``num_return_sequences`` entries form one group (e.g. all samples drawn
    for one prompt).

    Args:
        input_list: Flat sequence of candidates, ``groups * num_return_sequences`` long.
        probs: Scores aligned with ``input_list`` (anything ``np.asarray``
            accepts); the smallest score in each group wins.
        num_return_sequences: Group size; must be at least 2.

    Returns:
        ``(min_elements, random_elements)``: two lists with one entry per
        group. ``min_elements[i]`` is the candidate with the smallest score
        in group ``i``; ``random_elements[i]`` is drawn uniformly (via the
        ``random`` module) from the remaining candidates of that group.

    Raises:
        ValueError: If ``num_return_sequences < 2``, if the number of
            candidates is not a multiple of it, or if ``probs`` does not
            provide exactly one score per candidate.
    """
    if num_return_sequences < 2:
        raise ValueError(
            "num_return_sequences must be at least 2 to pick a second, "
            f"different element per group (got {num_return_sequences}).")

    # Index the original objects directly instead of going through a
    # fixed-width NumPy string array, so the elements are returned unchanged.
    candidates = list(input_list)
    if len(candidates) % num_return_sequences != 0:
        raise ValueError(
            f"Got {len(candidates)} candidates, which is not a multiple of "
            f"num_return_sequences={num_return_sequences}.")

    scores = np.asarray(probs)
    if scores.size != len(candidates):
        raise ValueError(
            f"Got {scores.size} scores for {len(candidates)} candidates.")
    scores = scores.reshape(-1, num_return_sequences)

    min_elements = []
    random_elements = []
    for group_idx, group_scores in enumerate(scores):
        offset = group_idx * num_return_sequences
        min_index = int(np.argmin(group_scores))
        other_indices = [
            i for i in range(num_return_sequences) if i != min_index]
        random_index = random.choice(other_indices)

        min_elements.append(candidates[offset + min_index])
        random_elements.append(candidates[offset + random_index])

    return min_elements, random_elements


def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: "int | None" = None,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens with a value of
            label_pad_token_id are ignored. Shape: (batch_size, sequence_length). Never modified.
        average_log_prob: If True, return the average log probability per (non-masked) token.
            Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
            When averaging, a row without any non-masked token yields NaN (0 / 0).
        label_pad_token_id: The label pad token id. When omitted (None), the module-level
            ``tokenizer.pad_token_id`` is looked up at call time.
        is_encoder_decoder: Whether the model is an encoder-decoder model. If False, labels and
            logits are shifted by one position so that position t is scored against token t + 1.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given
        labels under the given logits.

    Raises:
        ValueError: If the batch / sequence dimensions of ``logits`` and ``labels`` differ.
    """
    if label_pad_token_id is None:
        # Resolved lazily so the default follows the tokenizer's current value
        # instead of whatever it was when this function was defined.
        label_pad_token_id = tokenizer.pad_token_id

    if logits.shape[:-1] != labels.shape:
        raise ValueError(
            "Logits (batch and sequence length dim) and labels must have the same shape.")

    if is_encoder_decoder:
        # Clone here as well: the masking below writes into `labels` and must
        # not overwrite the caller's tensor.
        labels = labels.clone()
    else:
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id

    # dummy token so gather() gets a valid index; these positions are zeroed
    # out by loss_mask below
    labels[~loss_mask] = 0

    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    masked_logps = per_token_logps * loss_mask

    if average_log_prob:
        return masked_logps.sum(-1) / loss_mask.sum(-1)
    return masked_logps.sum(-1)


def gpu_computation(batch, rank):
    """Generate responses for a batch of prompts and build a new chosen/rejected pair.

    For every conversation in ``batch["chosen"]`` the prompt is everything but
    the last message. ``n_generation`` responses are sampled per prompt with
    ``model``, each full sequence is scored with the summed log-probability
    under ``ref_model``, and the columns are overwritten:

    * ``batch["chosen"]``   <- decoded sequence with the lowest reference log-probability
    * ``batch["rejected"]`` <- a randomly picked other decoded sequence for that prompt

    Relies on the module-level ``model``, ``ref_model``, ``tokenizer``,
    ``n_generation`` and ``is_encoder_decoder``.

    Args:
        batch: Batch handed over by ``Dataset.map(batched=True)``. Only the
            "chosen" column is read; each entry is a list of message dicts
            with a "role" key.
        rank: Worker rank supplied by ``map(with_rank=True)``; ``None`` is
            treated as 0. Used to pick the CUDA device for this worker.

    Returns:
        The same ``batch`` object with "chosen" and "rejected" replaced by
        lists of decoded strings.
    """
    prompts = []
    for example in batch["chosen"]:
        # The slice is a new list, so the insert below leaves the dataset row
        # untouched.
        prompt_messages = example[:-1]
        # Prepend a system message if the first message is not a system message
        if example[0]["role"] != "system":
            prompt_messages.insert(0, {"role": "system", "content": ""})
        # NOTE(review): no add_generation_prompt is passed, so the templated
        # text ends after the last prompt message -- confirm this is intended.
        prompts.append(tokenizer.apply_chat_template(
            prompt_messages, tokenize=False))

    # One device per worker, wrapping around when there are more workers than
    # GPUs. Fall back to CPU rather than dividing by zero without any GPU.
    n_devices = torch.cuda.device_count()
    if n_devices > 0:
        device = f"cuda:{(rank or 0) % n_devices}"
    else:
        device = "cpu"
    model.to(device)
    ref_model.to(device)

    pad_token_id = tokenizer.pad_token_id
    prompt_token = tokenizer(prompts,
                             padding=True,
                             max_length=1024,
                             truncation=True,
                             return_tensors="pt").to(device)

    with torch.inference_mode():
        responses = model.generate(input_ids=prompt_token["input_ids"],
                                   attention_mask=prompt_token["attention_mask"],
                                   max_new_tokens=1024,
                                   do_sample=True,
                                   num_return_sequences=n_generation,
                                   pad_token_id=pad_token_id,
                                   )
        # Every position holding the pad id is treated as padding.
        attention_mask = (responses != pad_token_id).float()

        ref_output = ref_model(input_ids=responses,
                               attention_mask=attention_mask)
        batch_logps = get_batch_logps(logits=ref_output.logits,
                                      labels=responses,
                                      average_log_prob=False,
                                      label_pad_token_id=pad_token_id,
                                      is_encoder_decoder=is_encoder_decoder,
                                      )
        # The models are loaded in bfloat16, which NumPy cannot represent;
        # upcast before the scores are handed to the NumPy-based selection.
        batch_logps = batch_logps.float().cpu()

        responses_str = tokenizer.batch_decode(
            responses, skip_special_tokens=True)

    # generate() returns the n_generation samples of each prompt
    # consecutively, which is the grouping the selection helper expects.
    response_a, response_b = select_min_and_random_elements(
        responses_str, batch_logps, n_generation)
    batch["chosen"] = response_a
    batch["rejected"] = response_b

    # Drop the large tensors before asking the allocator to release memory.
    del responses, ref_output, attention_mask, prompt_token
    torch.cuda.empty_cache()

    return batch


if __name__ == "__main__":
    # Worker processes must not inherit an initialised CUDA context through
    # fork, hence "spawn".
    set_start_method("spawn")

    # `verification_mode="no_checks"` replaces the deprecated
    # `ignore_verifications=True` (removed from recent `datasets` releases).
    existing_train_dataset = load_dataset(
        dataset_dir,
        split="train_prefs",
        download_mode="force_redownload",
        verification_mode="no_checks",
    )

    # Process only the leading `ratio` fraction of the split; never ask for
    # more rows than the split holds.
    total_rows = len(existing_train_dataset)
    n_selected = min(total_rows, int(total_rows * ratio))
    train_dataset = existing_train_dataset.select(range(n_selected))

    new_dataset = train_dataset.map(
        gpu_computation,
        batched=True,
        batch_size=2,
        with_rank=True,
        # one process per GPU; map() needs at least one process
        num_proc=max(1, torch.cuda.device_count()),
    )
    new_dataset.push_to_hub(output_dir, split="train_prefs", private=False)
