from alignment import H4ArgumentParser, ModelArguments, DataArguments, RDPOConfig, get_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset, DatasetDict, concatenate_datasets, load_from_disk
from multiprocess import set_start_method
import torch
from trl.trainer.utils import pad_to_length
from huggingface_hub import Repository, snapshot_download
import numpy as np
import random
from vllm import LLM, SamplingParams
import argparse
import time
from accelerate import Accelerator
import os

# Command-line interface.  Every option is required: each value is consumed
# unconditionally further down, so a missing one would otherwise surface later
# as an obscure TypeError (in most cases only after the model has been loaded).
parser = argparse.ArgumentParser(
    description="Sample responses with vLLM for one shard of a saved dataset.")

parser.add_argument("--dataset", type=str, required=True,
                    help="directory of the dataset to load with load_from_disk")
parser.add_argument("--model", type=str, required=True,
                    help="model name or path handed to vllm.LLM")
parser.add_argument("--output", type=str, required=True,
                    help="output prefix; the shard is saved to "
                         "./<output>_mini_<part>")
parser.add_argument("--part", type=int, required=True,
                    help="zero-based index of the shard to process")
parser.add_argument("--total", type=int, required=True,
                    help="total number of shards the dataset is split into")
parser.add_argument("--generation", type=int, required=True,
                    help="number of responses to sample per prompt")

args = parser.parse_args()

# Module-level settings read by the generation function and the entry point.
dataset_dir = args.dataset
model_dir = args.model
output_dir = args.output
n_part = args.part
n_generation = args.generation
total_part = args.total

# Weights are downloaded into ~/.cache/vllm.
home_directory = os.path.expanduser("~")
cache_path = os.path.join(home_directory, ".cache/vllm")

model = LLM(model_dir, download_dir=cache_path)

# NOTE(review): the tokenizer (and therefore the chat template) comes from a
# fixed checkpoint rather than from --model; confirm this matches the models
# this script is run with.
tokenizer = AutoTokenizer.from_pretrained(
    "alignment-handbook/zephyr-7b-sft-full")


def generate_response_vllm(model, tokenizer, dataset):
    """Sample responses for every row of ``dataset`` and attach them as columns.

    For each entry of ``dataset['chosen']`` the last message is dropped
    (presumably the reference assistant reply -- confirm against the data)
    and the remaining messages are rendered with the tokenizer's chat
    template, with the generation prompt appended.  All prompts are then sent
    to ``model.generate`` in a single call.

    The number of samples per prompt is read from the module-level
    ``n_generation``.

    Args:
        model: object exposing ``generate(prompts, sampling_params)``; the
            script passes a ``vllm.LLM`` instance.
        tokenizer: tokenizer providing ``apply_chat_template`` and
            ``eos_token``.
        dataset: dataset with a ``chosen`` column of message lists and an
            ``add_column`` method.

    Returns:
        The dataset with ``n_generation`` extra columns ``resp0``,
        ``resp1``, ...; column ``resp{i}`` holds the whitespace-stripped text
        of the i-th sampled output for each row.
    """
    with torch.inference_mode():
        sampling_params = SamplingParams(
            n=n_generation,
            best_of=n_generation,
            temperature=0.7,
            top_p=1.0,
            max_tokens=2048,
            stop=tokenizer.eos_token,
            skip_special_tokens=True,
        )
        # Build one prompt per row from the conversation minus its final turn.
        chat_prompts = [
            tokenizer.apply_chat_template(
                chosen_message[:-1], tokenize=False, add_generation_prompt=True)
            for chosen_message in dataset['chosen']
        ]
        responses = model.generate(chat_prompts, sampling_params)

        # Transpose from per-prompt outputs to one list per sample index, so
        # each list can become a dataset column.
        resp_list = [[response.outputs[i].text.strip()
                      for response in responses] for i in range(n_generation)]
    for i in range(n_generation):
        dataset = dataset.add_column(f"resp{i}", resp_list[i])
    return dataset


if __name__ == "__main__":
    # Load the whole dataset, then keep only the contiguous slice that belongs
    # to this shard.
    full_dataset = load_from_disk(dataset_dir)
    shard_size = len(full_dataset) // total_part
    first_row = shard_size * n_part
    if n_part == total_part - 1:
        # The final shard also takes the rows left over by the integer
        # division above.
        last_row = len(full_dataset)
    else:
        last_row = shard_size * (n_part + 1)
    shard = full_dataset.select(range(first_row, last_row))

    # Generate the responses and persist the augmented shard under a
    # per-part directory name.
    augmented_shard = generate_response_vllm(model, tokenizer, shard)
    augmented_shard.save_to_disk("./" + output_dir + f"_mini_{n_part}")
