from alignment import H4ArgumentParser, ModelArguments, DataArguments, RDPOConfig, get_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset, DatasetDict, concatenate_datasets, load_from_disk
from multiprocess import set_start_method
import torch
from trl.trainer.utils import pad_to_length
from huggingface_hub import Repository, snapshot_download
import numpy as np
import random
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
import argparse
import time
from accelerate import Accelerator
import os

# Base checkpoint served by vLLM; a LoRA adapter (if any) is applied on top of it.
BASE_MODEL = "alignment-handbook/zephyr-7b-sft-full"

parser = argparse.ArgumentParser(
    description="Generate one sampled response per prompt for one shard of a "
                "dataset saved with `save_to_disk`, using vLLM.")
parser.add_argument("--dataset", type=str, required=True,
                    help="Path of the input dataset (loaded with load_from_disk).")
parser.add_argument("--model", type=str, required=True,
                    help="LoRA adapter directory (also used to load the "
                         "tokenizer), or 'original' to sample from the base model.")
parser.add_argument("--output", type=str, required=True,
                    help="Output prefix; the shard is saved to <output>_mini_<part>.")
parser.add_argument("--part", type=int, required=True,
                    help="Zero-based index of the shard to process.")
parser.add_argument("--total", type=int, required=True,
                    help="Total number of shards the dataset is split into.")
parser.add_argument("--name", type=int,
                    help="Integer suffix of the generated column (resp<name>).")
parser.add_argument("--download_dir", type=str,
                    default="/home/fkq3712/.cache/vllm",
                    help="Directory vLLM downloads/loads the model weights from.")

args = parser.parse_args()
# Validate before the (slow) model load so a bad invocation fails immediately.
# This also rules out total <= 0, which would otherwise divide by zero later.
if not 0 <= args.part < args.total:
    parser.error(f"--part must satisfy 0 <= part < total "
                 f"(got part={args.part}, total={args.total})")

dataset_dir = args.dataset
model_dir = args.model
output_dir = args.output
n_part = args.part
total_part = args.total
n_name = args.name
# LoRA adapter applied at generation time; None means "use the base model".
lora_path = model_dir if model_dir != 'original' else None

# Completions sampled per prompt (only the first one is kept downstream).
n_generation = 1

# LoRA support is always enabled so the same engine can serve either the base
# model or an adapter of rank <= 64.
model = LLM(BASE_MODEL,
            enable_lora=True, max_lora_rank=64, download_dir=args.download_dir)

# The adapter directory is expected to ship its own tokenizer files.
tokenizer = AutoTokenizer.from_pretrained(
    model_dir if model_dir != 'original' else BASE_MODEL)

def generate_response_vllm(model, tokenizer, dataset):
    """Sample one response per prompt and return ``dataset`` with it as a new column.

    For every row, the ``chosen`` conversation without its final message is
    rendered with the tokenizer's chat template (generation prompt appended).
    All prompts are sent to vLLM in a single batched ``generate`` call. When
    the script was launched with a LoRA adapter (``--model`` other than
    ``'original'``), the adapter at ``lora_path`` is applied to every request.

    Args:
        model: vLLM ``LLM`` engine used for generation.
        tokenizer: Hugging Face tokenizer providing ``apply_chat_template``
            and ``eos_token``.
        dataset: ``datasets.Dataset`` with a ``chosen`` column. Each entry is
            presumably a list of chat-message dicts (it is sliced and passed
            to ``apply_chat_template``) — confirm against the data.

    Returns:
        A new dataset equal to ``dataset`` plus a ``resp{n_name}`` column
        holding the whitespace-stripped text of the first completion per row.
    """
    with torch.inference_mode():
        sampling_params = SamplingParams(
            n=n_generation,
            best_of=n_generation,
            temperature=0.7,
            top_p=1.0,
            max_tokens=2048,
            # Stop on the EOS string so generation ends at the end of the turn.
            stop=tokenizer.eos_token,
            skip_special_tokens=True,
        )
        # Drop the last message (the response being replaced) and ask the
        # template to append the prompt that starts the model's turn.
        chat_prompts = [
            tokenizer.apply_chat_template(
                chosen_message[:-1], tokenize=False, add_generation_prompt=True)
            for chosen_message in dataset['chosen']
        ]

        # None lets vLLM sample from the base model without an adapter.
        lora_request = None
        if model_dir != 'original':
            lora_request = LoRARequest("sql_adapter", 1, lora_path)
        responses = model.generate(
            chat_prompts, sampling_params, lora_request=lora_request)

        # vLLM returns one RequestOutput per prompt, in prompt order; keep the
        # first sampled completion of each.
        resp_list = [response.outputs[0].text.strip()
                     for response in responses]
        dataset = dataset.add_column(f"resp{n_name}", resp_list)
    return dataset


if __name__ == "__main__":
    import traceback

    # Number of rows sent to vLLM per generate() call.
    chunk_size = 100

    full_dataset = load_from_disk(dataset_dir)
    # Contiguous shard `n_part` of `total_part`; the last shard also takes the
    # remainder left over by the integer division.
    interval = len(full_dataset) // total_part
    shard_start = interval * n_part
    shard_end = (len(full_dataset) if n_part == total_part - 1
                 else interval * (n_part + 1))
    existing_train_dataset = full_dataset.select(range(shard_start, shard_end))

    dataset_list = []
    failed_chunks = []
    for start in range(0, len(existing_train_dataset), chunk_size):
        end = min(start + chunk_size, len(existing_train_dataset))
        mini_dataset = existing_train_dataset.select(range(start, end))
        try:
            new_mini_dataset = generate_response_vllm(
                model, tokenizer, mini_dataset)
        except Exception:
            # Best effort: skip a failing chunk instead of aborting the whole
            # run, but record exactly which rows will be missing from the output.
            print(f"Chunk rows [{start}, {end}) of part {n_part} failed; "
                  f"skipping it:")
            traceback.print_exc()
            failed_chunks.append((start, end))
            continue
        dataset_list.append(new_mini_dataset)

    if failed_chunks:
        print(f"WARNING: {len(failed_chunks)} chunk(s) of part {n_part} failed "
              f"and are missing from the output: {failed_chunks}")
    if not dataset_list:
        raise RuntimeError(
            f"Part {n_part} produced no rows (shard size "
            f"{len(existing_train_dataset)}, failed chunks: {failed_chunks}); "
            f"nothing to save.")

    new_dataset = concatenate_datasets(dataset_list)
    # os.path.join (rather than "./" + ...) keeps an absolute --output absolute
    # instead of silently re-rooting it under the working directory.
    new_dataset.save_to_disk(os.path.join(".", f"{output_dir}_mini_{n_part}"))
