from alignment import H4ArgumentParser, ModelArguments, DataArguments, RDPOConfig, get_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset, DatasetDict, concatenate_datasets, load_from_disk
from multiprocess import set_start_method
import torch
from trl.trainer.utils import pad_to_length
from huggingface_hub import Repository, snapshot_download
import numpy as np
import random
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
import argparse
import time
from accelerate import Accelerator
import os

# Command-line interface. Every option is required because the rest of the
# script uses each value unconditionally; omitting one used to surface only as
# an obscure failure after the model had already been loaded.
parser = argparse.ArgumentParser(
    description="Sample several responses per prompt for one shard of a "
                "dataset stored on disk.")

parser.add_argument("--dataset", type=str, required=True,
                    help="directory of a dataset readable by load_from_disk")
parser.add_argument("--model", type=str, required=True,
                    help="LoRA adapter path, or 'original' to sample from "
                         "the base model without an adapter")
parser.add_argument("--output", type=str, required=True,
                    help="output prefix; the shard is saved to "
                         "./<output>_mini_<part>")
parser.add_argument("--part", type=int, required=True,
                    help="zero-based index of the shard to process")
parser.add_argument("--total", type=int, required=True,
                    help="total number of shards the dataset is split into")

args = parser.parse_args()

# Reject impossible shard specifications up front, before the expensive model
# load further down the script.
if args.total < 1:
    parser.error("--total must be at least 1")
if not 0 <= args.part < args.total:
    parser.error("--part must satisfy 0 <= part < total")

dataset_dir = args.dataset
model_dir = args.model
output_dir = args.output
n_part = args.part
total_part = args.total

# 'original' selects the base model; any other value is used as the LoRA
# adapter path.
lora_path = model_dir if model_dir != 'original' else None

# Number of responses sampled for every prompt.
n_generation = 4


###########################
# Load model and tokenizer
###########################
# parser = H4ArgumentParser((ModelArguments, DataArguments, RDPOConfig))
# model_args, data_args, training_args = parser.parse()
# print(model_args)
# print(data_args)
# print(training_args)

# model = AutoModelForCausalLM.from_pretrained(
#     model_args.model_name_or_path,  torch_dtype=torch.bfloat16)
# model = LLM("alignment-handbook/zephyr-7b-sft-full",
#             enable_lora=True, max_lora_rank=64, swap_space=4, tensor_parallel_size=torch.cuda.device_count(),
#             trust_remote_code=True, dtype="auto")
# vLLM downloads the base checkpoint into ~/.cache/vllm.
home_directory = os.path.expanduser("~")
cache_path = os.path.join(home_directory, ".cache/vllm")

# The engine is always created with LoRA support enabled; whether an adapter
# is actually applied is decided per generate() call.
model = LLM(
    "alignment-handbook/zephyr-7b-sft-full",
    enable_lora=True,
    max_lora_rank=64,
    download_dir=cache_path,
)

# Take the tokenizer from the adapter directory when one is given, otherwise
# from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    "alignment-handbook/zephyr-7b-sft-full"
    if model_dir == 'original'
    else model_dir
)


def generate_response_vllm(model, tokenizer, dataset):
    """Sample ``n_generation`` responses per row and attach them as columns.

    The prompt for a row is its ``chosen`` entry with the last element
    removed, rendered through the tokenizer's chat template with a generation
    prompt appended. When the module-level ``model_dir`` is not
    ``'original'``, generation goes through the LoRA adapter at ``lora_path``;
    otherwise the base model is used as is.

    Args:
        model: engine exposing ``generate(prompts, sampling_params, ...)``.
        tokenizer: tokenizer providing ``apply_chat_template`` and
            ``eos_token`` (the latter is used as the stop string).
        dataset: dataset with a ``chosen`` column; each entry is presumably a
            list of chat messages ending with the reference answer.

    Returns:
        The dataset extended with columns ``resp0`` .. ``resp{n-1}`` where
        ``n = n_generation``; column ``i`` holds the whitespace-stripped text
        of the i-th sample for every row.
    """
    with torch.inference_mode():
        sampling_params = SamplingParams(
            n=n_generation,
            best_of=n_generation,
            temperature=0.7,
            top_p=1.0,
            max_tokens=2048,
            stop=tokenizer.eos_token,
            skip_special_tokens=True,
        )

        # Drop the final message of each conversation so the model has to
        # produce it.
        chat_prompts = [
            tokenizer.apply_chat_template(
                chosen_message[:-1], tokenize=False, add_generation_prompt=True)
            for chosen_message in dataset['chosen']
        ]

        if model_dir != 'original':
            responses = model.generate(
                chat_prompts, sampling_params,
                lora_request=LoRARequest("sql_adapter", 1, lora_path))
        else:
            responses = model.generate(chat_prompts, sampling_params)

        # Transpose from per-prompt outputs to one list per sample index so
        # each list can become a dataset column.
        resp_list = [
            [response.outputs[i].text.strip() for response in responses]
            for i in range(n_generation)
        ]

    for i, column in enumerate(resp_list):
        dataset = dataset.add_column(f"resp{i}", column)
    return dataset


def shard_bounds(n_rows, part, total):
    """Return the ``(start, end)`` row range of shard ``part`` out of ``total``.

    Every shard spans ``n_rows // total`` rows; the last shard additionally
    takes the rows left over by the integer division. ``end`` is exclusive.

    Raises:
        ValueError: if ``total`` is below 1 or ``part`` is not in
            ``[0, total)``.
    """
    if total < 1:
        raise ValueError(f"total must be at least 1, got {total}")
    if not 0 <= part < total:
        raise ValueError(f"part must be in [0, {total}), got {part}")
    interval = n_rows // total
    start = interval * part
    end = n_rows if part == total - 1 else start + interval
    return start, end


if __name__ == "__main__":
    # Restrict the on-disk dataset to the shard assigned to this process.
    existing_train_dataset = load_from_disk(dataset_dir)
    start, end = shard_bounds(
        len(existing_train_dataset), n_part, total_part)
    existing_train_dataset = existing_train_dataset.select(range(start, end))

    # Generate in small chunks so that one failing chunk costs at most
    # ``chunk_size`` rows instead of the whole shard.
    chunk_size = 100
    n_rows = len(existing_train_dataset)
    dataset_list = []
    for chunk_start in range(0, n_rows, chunk_size):
        chunk_end = min(chunk_start + chunk_size, n_rows)
        mini_dataset = existing_train_dataset.select(
            range(chunk_start, chunk_end))
        try:
            new_mini_dataset = generate_response_vllm(
                model, tokenizer, mini_dataset)
        except Exception as e:
            # Deliberately best effort: report which rows were lost and carry
            # on with the next chunk.
            print(f"Skipping shard rows [{chunk_start}, {chunk_end}) "
                  f"of part {n_part}: {e!r}")
            continue
        dataset_list.append(new_mini_dataset)

    if not dataset_list:
        # Either the shard was empty or every chunk failed; there is nothing
        # to concatenate or save.
        raise RuntimeError(
            f"No responses were generated for part {n_part} "
            f"({n_rows} rows in shard)")

    new_dataset = concatenate_datasets(dataset_list)
    new_dataset.save_to_disk("./"+output_dir+f"_mini_{n_part}")
