from alignment import H4ArgumentParser, ModelArguments, DataArguments, RDPOConfig, get_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset, DatasetDict, concatenate_datasets, load_from_disk
from multiprocess import set_start_method
import torch
from trl.trainer.utils import pad_to_length
from huggingface_hub import Repository, snapshot_download
import numpy as np
import random
from vllm import LLM, SamplingParams
import argparse
import time
from accelerate import Accelerator
import os

parser = argparse.ArgumentParser(
    description="Generate one vLLM response per example for one shard of a dataset.")

# Every argument except --index is used unconditionally below (paths, shard
# arithmetic, sampling temperature), so omitting one would only fail later with
# an unclear error; make argparse reject it up front instead.
parser.add_argument("--dataset", type=str, required=True,
                    help="Path of a dataset saved with `save_to_disk`.")
parser.add_argument("--model", type=str, required=True,
                    help="Model name or path for vLLM `LLM` and the tokenizer.")
parser.add_argument("--output", type=str, required=True,
                    help="Output path prefix; the shard is saved as <output>_mini_<part>.")
parser.add_argument("--part", type=int, required=True,
                    help="Zero-based index of the shard to process.")
parser.add_argument("--total", type=int, required=True,
                    help="Total number of shards the dataset is split into.")
parser.add_argument("--temperature", type=float, required=True,
                    help="Sampling temperature.")
parser.add_argument("--index", type=int,
                    help="Suffix of the new response column (resp<index>).")

args = parser.parse_args()

dataset_dir = args.dataset
model_dir = args.model
output_dir = args.output
n_part = args.part
# Number of completions requested per prompt (only the first is kept).
n_generation = 1
total_part = args.total
gen_temp = args.temperature
gen_index = args.index

# vLLM downloads model weights into ~/.cache/vllm.
home_directory = os.path.expanduser("~")
cache_path = os.path.join(home_directory, ".cache/vllm")

model = LLM(model_dir, download_dir=cache_path)

tokenizer = AutoTokenizer.from_pretrained(model_dir)


def generate_response_vllm(model, tokenizer, dataset, temperature, index):
    """Sample one response per example and attach it as a new column.

    For every row, the last message of ``dataset['chosen']`` is dropped and the
    remaining conversation is rendered with the tokenizer's chat template
    (with the generation prompt appended). All prompts are then generated in a
    single ``model.generate`` call.

    Args:
        model: vLLM ``LLM`` instance used for generation.
        tokenizer: Hugging Face tokenizer providing ``apply_chat_template`` and
            ``eos_token``.
        dataset: ``datasets.Dataset`` with a ``chosen`` column; each entry is
            presumably a list of chat messages ending with the assistant reply
            (TODO confirm against the data pipeline).
        temperature: Sampling temperature passed to ``SamplingParams``.
        index: Suffix of the new column name, ``f"resp{index}"``.

    Returns:
        A new dataset with an extra ``resp{index}`` column holding the
        whitespace-stripped text of the first completion for each row.
    """
    with torch.inference_mode():
        sampling_params = SamplingParams(
            n=n_generation,
            best_of=n_generation,
            temperature=temperature,
            top_p=1.0,
            max_tokens=2048,
            stop=tokenizer.eos_token,  # also stop on the EOS string itself
            skip_special_tokens=True,
        )
        # Drop the final message (presumably the chosen assistant turn) so the
        # model produces its own reply for that position.
        chat_prompts = [
            tokenizer.apply_chat_template(
                chosen_message[:-1], tokenize=False, add_generation_prompt=True)
            for chosen_message in dataset['chosen']
        ]

        # An empty shard (e.g. more shards than rows) needs no engine call.
        responses = model.generate(chat_prompts, sampling_params) if chat_prompts else []

        # Only the first completion of each request is kept.
        resp_list = [response.outputs[0].text.strip() for response in responses]
        dataset = dataset.add_column(f"resp{index}", resp_list)
    return dataset


if __name__ == "__main__":
    # Validate the shard spec before touching the (possibly large) dataset:
    # total_part == 0 would divide by zero, and an out-of-range part would
    # yield an invalid or silently wrong row range.
    if total_part < 1:
        raise ValueError(f"--total must be >= 1, got {total_part}")
    if not 0 <= n_part < total_part:
        raise ValueError(
            f"--part must be in [0, {total_part - 1}], got {n_part}")

    existing_train_dataset = load_from_disk(dataset_dir)
    n_rows = len(existing_train_dataset)

    # Equal-size shards; the last shard also takes the leftover rows.
    interval = n_rows // total_part
    start = interval * n_part
    end = n_rows if n_part == total_part - 1 else interval * (n_part + 1)
    existing_train_dataset = existing_train_dataset.select(range(start, end))

    new_dataset = generate_response_vllm(
        model, tokenizer, existing_train_dataset,
        temperature=gen_temp, index=gen_index)

    # Build the path directly: prefixing "./" turned an absolute --output such
    # as "/data/out" into the relative ".//data/out". Relative paths resolve
    # exactly as before.
    new_dataset.save_to_disk(f"{output_dir}_mini_{n_part}")
