from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset, DatasetDict, concatenate_datasets
from multiprocess import set_start_method
import torch
from trl.trainer.utils import pad_to_length
from huggingface_hub import Repository, snapshot_download
import numpy as np
import random
from vllm import LLM, SamplingParams
import argparse
import time
from accelerate import Accelerator
import os

# Command-line interface:
#   --dataset  dataset identifier handed to `load_dataset`
#   --output   prefix of the Hub repo the generated shard is pushed to
#   --part     index of the quarter of the dataset this process handles
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str)
parser.add_argument("--output", type=str)
parser.add_argument("--part", type=int)
args = parser.parse_args()

# Number of completions sampled for every prompt.
n_generation = 2
dataset_dir, output_dir, n_part = args.dataset, args.output, args.part

# The vLLM engine and the Hugging Face tokenizer are both loaded from this
# one checkpoint, so prompts are formatted with the model's own chat template.
_MODEL_ID = "alignment-handbook/zephyr-7b-sft-full"

# gpu_memory_utilization=0.7 asks vLLM to use at most 70% of GPU memory.
model = LLM(
    model=_MODEL_ID,
    tokenizer=_MODEL_ID,
    gpu_memory_utilization=0.7,
)

# Disabled: a separate Hugging Face copy of the checkpoint as a reference model.
# ref_model = AutoModelForCausalLM.from_pretrained(
#     "alignment-handbook/zephyr-7b-sft-full",  torch_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)


def generate_response_vllm(model, tokenizer, dataset):
    """Sample completions for every row of ``dataset`` and attach them as columns.

    For each row, the conversation stored under ``chosen`` has its final
    message dropped, the remaining messages are rendered to a prompt string
    with the tokenizer's chat template (with a generation prompt appended),
    and ``n_generation`` completions are sampled from ``model``.

    Args:
        model: Object exposing ``generate(prompts, sampling_params)``; the
            script passes the vLLM ``LLM`` engine created at module level.
        tokenizer: Tokenizer providing ``apply_chat_template`` and
            ``eos_token``.
        dataset: Dataset with a ``chosen`` column holding a list of chat
            messages per row.
            NOTE(review): the last message is presumably the assistant reply
            being replaced -- confirm against the dataset schema.

    Returns:
        The dataset with ``n_generation`` extra columns ``resp0``, ``resp1``,
        ... where ``resp{i}`` holds the whitespace-stripped i-th sampled
        completion for each row.
    """
    # `best_of` is intentionally not passed: when unset, vLLM uses `n`, which
    # is exactly the value the explicit argument used to carry.
    sampling_params = SamplingParams(
        n=n_generation,
        temperature=1.0,
        top_p=1.0,
        max_tokens=1024,
        stop=tokenizer.eos_token,
        skip_special_tokens=True,
    )

    # Drop the final message of each conversation and render what is left as
    # a prompt string that ends with the generation prompt.
    chat_prompts = [
        tokenizer.apply_chat_template(
            chosen_message[:-1], tokenize=False, add_generation_prompt=True)
        for chosen_message in dataset['chosen']
    ]

    with torch.inference_mode():
        responses = model.generate(chat_prompts, sampling_params)

    # `responses` has one entry per prompt, each with `n_generation` outputs;
    # column i collects the i-th output of every prompt.
    for i in range(n_generation):
        column = [response.outputs[i].text.strip() for response in responses]
        dataset = dataset.add_column(f"resp{i}", column)
    return dataset


if __name__ == "__main__":
    # Script entry point: generate responses for one quarter of the
    # `train_prefs` split, in batches, and push the result to the Hub.
    import traceback

    # set_start_method("spawn")

    num_parts = 4     # the split is cut into this many contiguous shards
    batch_size = 100  # rows handed to vLLM per generate call

    # Fail with a usage message instead of an opaque TypeError (missing
    # --part) or silently selecting the wrong rows (negative --part).
    if n_part is None or not 0 <= n_part < num_parts:
        parser.error(
            f"--part must be an integer in [0, {num_parts - 1}], got {n_part!r}")

    # NOTE(review): `ignore_verifications` appears to be deprecated in newer
    # `datasets` releases in favour of `verification_mode` -- confirm against
    # the installed version before upgrading.
    existing_train_dataset = load_dataset(dataset_dir, split="train_prefs",
                                          download_mode="force_redownload", ignore_verifications=True)

    # Shard boundaries: equal-sized quarters, with the last shard also
    # taking the remainder rows left over by the integer division.
    interval = len(existing_train_dataset) // num_parts
    start = interval * n_part
    if n_part == num_parts - 1:
        end = len(existing_train_dataset)
    else:
        end = interval * (n_part + 1)
    existing_train_dataset = existing_train_dataset.select(range(start, end))

    dataset_list = []
    failed_batches = []
    shard_len = len(existing_train_dataset)
    for batch_start in range(0, shard_len, batch_size):
        batch_end = min(batch_start + batch_size, shard_len)
        try:
            mini_dataset = existing_train_dataset.select(
                range(batch_start, batch_end))
            new_mini_dataset = generate_response_vllm(
                model, tokenizer, mini_dataset)
        except Exception as e:
            # Best effort: a failing batch is skipped so the rest of the
            # shard can still be generated, but the failure is reported
            # with its row range and traceback.
            print(f"batch [{batch_start}, {batch_end}) failed: {e}")
            traceback.print_exc()
            failed_batches.append((batch_start, batch_end))
        else:
            dataset_list.append(new_mini_dataset)

    if failed_batches:
        print(f"{len(failed_batches)} batch(es) skipped: {failed_batches}")
    if not dataset_list:
        raise RuntimeError(
            "no batch was generated successfully; nothing to push")

    new_dataset = concatenate_datasets(dataset_list)
    # new_dataset = generate_response_vllm(model, tokenizer, train_dataset)
    new_dataset.push_to_hub(
        output_dir+f"_mini_{n_part}", split="train_prefs", private=False)
    # left_dataset.push_to_hub(output_left, split="train_prefs", private=False)
