from alignment import H4ArgumentParser, ModelArguments, DataArguments, DPOConfig
from transformers import AutoTokenizer
import datasets
datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory='.': True
from datasets import load_dataset
import torch
from vllm import LLM, SamplingParams

@torch.no_grad()
def generate_response_vllm(model, tokenizer, dataset):
    """Generate one reference response per row and attach it to the dataset.

    For every conversation in the ``chosen`` column, the final message is
    dropped (presumably the assistant reply -- confirm against the dataset
    schema) and the remaining messages are rendered with the tokenizer's chat
    template, with the generation prompt appended. The rendered prompts are
    sent to ``model.generate`` in a single call.

    Args:
        model: Object exposing ``generate(prompts, sampling_params)``; the
            caller in this file passes a ``vllm.LLM``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and
            ``eos_token``.
        dataset: Dataset with a ``chosen`` column of message lists and an
            ``add_column`` method.

    Returns:
        The result of ``dataset.add_column("reference_response", ...)``, where
        each value is the whitespace-stripped text of the first output
        produced for the corresponding prompt.
    """
    with torch.inference_mode():
        # temperature=0 with top_p=1.0; generation stops at the EOS token or
        # after 1024 new tokens, whichever comes first.
        decoding = SamplingParams(
            temperature=0,
            top_p=1.0,
            max_tokens=1024,
            stop=tokenizer.eos_token,
            skip_special_tokens=True,
        )
        prompts = [
            tokenizer.apply_chat_template(
                conversation[:-1], tokenize=False, add_generation_prompt=True
            )
            for conversation in dataset['chosen']
        ]
        generations = model.generate(prompts, decoding)
        reference_texts = [
            generation.outputs[0].text.strip() for generation in generations
        ]
    return dataset.add_column("reference_response", reference_texts)

def main():
    """Generate reference responses for both dataset splits and upload them.

    Parses ``ModelArguments``, ``DataArguments`` and ``DPOConfig`` with
    ``H4ArgumentParser``, loads the model named by
    ``model_args.model_name_or_path`` into vLLM, and for the first two entries
    of ``data_args.dataset_splits`` loads that split of
    ``dataset_mixer["original"]``, adds a ``reference_response`` column via
    ``generate_response_vllm``, and pushes the result to the Hub.

    ``dataset_mixer["updated"]`` must have the form ``<repo>_iter<suffix>``:
    ``<repo>`` is the target Hub dataset and ``<suffix>`` is appended to the
    split names ``train_prefs`` / ``test_prefs``.

    Any failure prints the full traceback and exits with status 1, so a
    calling pipeline can tell that no (or only partial) data was uploaded.
    """
    # Local imports keep this block self-contained.
    import sys
    import traceback

    parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig))
    model_args, data_args, _ = parser.parse()
    if isinstance(data_args.dataset_mixer, str):
        # NOTE(security): eval() executes arbitrary code found in the
        # command-line/config value. Only run this with trusted input; if the
        # value is always a plain dict literal, ast.literal_eval is the safe
        # replacement.
        data_args.dataset_mixer = eval(data_args.dataset_mixer)

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    try:
        ref_model = LLM(
            model=model_args.model_name_or_path,
            tokenizer=model_args.model_name_or_path,
            gpu_memory_utilization=0.8,
            swap_space=1,
            tensor_parallel_size=torch.cuda.device_count(),
            trust_remote_code=True,
            dtype="auto",
        )

        updated_spec = data_args.dataset_mixer["updated"]
        # rpartition splits on the LAST "_iter", so repository names that
        # themselves contain "_iter" no longer break the unpacking.
        updated_dataset_name, separator, iter_str = updated_spec.rpartition("_iter")
        if not separator:
            raise ValueError(
                f"dataset_mixer['updated'] must contain '_iter', got {updated_spec!r}"
            )

        # Index 0 of dataset_splits is the train split, index 1 the test split.
        for split_prefix, split_index in (("train_prefs", 0), ("test_prefs", 1)):
            original_dataset = load_dataset(
                data_args.dataset_mixer["original"],
                split=data_args.dataset_splits[split_index],
            )
            new_dataset = generate_response_vllm(ref_model, tokenizer, original_dataset)
            new_dataset.push_to_hub(
                updated_dataset_name, private=False, split=split_prefix + iter_str
            )
    except Exception:
        # Surface the whole traceback and signal failure through the exit
        # code instead of printing only the message and exiting with 0.
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
