from alignment import H4ArgumentParser, ModelArguments, DataArguments, DPOConfig, get_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset, DatasetDict, concatenate_datasets
from multiprocess import set_start_method
import torch
from trl.trainer.utils import pad_to_length

###############
# Load datasets
###############
# NOTE(review): this section runs at import time. With the "spawn" start
# method used by the entry point, worker processes re-import this module, so
# the loading below likely runs once per worker -- confirm that the forced
# re-download is intended there.
parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig))
model_args, data_args, training_args = parser.parse()

source_dataset_name = data_args.dataset_mixer["updated"]
requested_split = data_args.dataset_splits[0]

# Only the probe for previously published results is "best effort". Keeping
# the try body this small means a malformed split spec or a failing source
# download is reported as such instead of being mistaken for "no existing
# dataset" (which would reprocess everything and append duplicates).
try:
    # NOTE(review): `ignore_verifications` is deprecated in recent `datasets`
    # releases in favour of `verification_mode`; if the installed version
    # rejects it, this probe always fails -- check the printed reason.
    existing_train_dataset = load_dataset("ShenaoZ/compare_response", split="train_prefs",
                                          download_mode="force_redownload", ignore_verifications=True)
except Exception as err:  # missing repo/split, network failure, ...
    print('Existing dataset not detected!')
    print(f'Reason: {err!r}')
    # Bind the name explicitly so later code never depends on a NameError.
    existing_train_dataset = None
    original_train_dataset = load_dataset(source_dataset_name, split=requested_split)
else:
    # Resume after the rows that were already processed by moving the start
    # of the requested slice, e.g. "train_prefs[:1000]" with 200 existing
    # rows becomes "train_prefs[200:1000]".
    # Assumes the configured slice has no explicit start index -- TODO confirm.
    existing_rows = existing_train_dataset.num_rows
    split_head, colon, split_tail = requested_split.partition(':')
    if colon:
        remain_split = f"{split_head}{existing_rows}:{split_tail}"
    else:
        # No slice was requested: continue from the first unprocessed row.
        remain_split = f"{requested_split}[{existing_rows}:]"
    original_train_dataset = load_dataset(source_dataset_name, split=remain_split)

# Reference model and tokenizer used by `gpu_computation`.
ref_model = AutoModelForCausalLM.from_pretrained("ShenaoZ/zephyr-7b-dpo-full")
tokenizer = AutoTokenizer.from_pretrained("ShenaoZ/zephyr-7b-dpo-full")
def gpu_computation(batch, rank):
    """Sample one reference-model generation per conversation in ``batch``.

    Uses the module-level ``ref_model``, ``tokenizer`` and ``training_args``.

    Args:
        batch: Batched mapping; ``batch["chosen"]`` holds conversations, each
            a list of message dicts with a ``"role"`` key. The final message
            of every conversation is dropped before prompting (presumably the
            chosen assistant reply -- confirm against the dataset schema).
        rank: Worker rank; ``None`` is treated as 0. The CUDA device is
            picked as ``rank % torch.cuda.device_count()``.

    Returns:
        The same ``batch`` with an added ``"opt_reference_response"`` entry
        containing the decoded ``generate`` outputs.
    """

    def _render_prompt(conversation):
        # Everything except the last message forms the prompt.
        history = conversation[:-1]
        # Prepend an empty system message if the conversation has none.
        if conversation[0]["role"] != "system":
            history = [{"role": "system", "content": ""}, *history]
        return tokenizer.apply_chat_template(history, tokenize=False)

    prompts = [_render_prompt(conversation) for conversation in batch["chosen"]]

    # Spread workers over the available GPUs round-robin by rank.
    gpu_index = (rank or 0) % torch.cuda.device_count()
    device = f"cuda:{gpu_index}"
    ref_model.to(device)

    encoded = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
    max_len = training_args.max_length
    pad_id = tokenizer.pad_token_id

    with torch.no_grad():
        generated = ref_model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=max_len,
            do_sample=True,
            pad_token_id=pad_id,
        )
        # Pad every sequence to the same fixed length before decoding.
        generated = pad_to_length(generated, max_len, pad_id)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        batch["opt_reference_response"] = decoded
    return batch

if __name__ == "__main__":
    # Entry point: generate reference responses on all GPUs, merge them with
    # any previously published rows, and publish the result.

    # Worker processes use CUDA, which requires the "spawn" start method
    # (CUDA cannot be used in forked subprocesses).
    set_start_method("spawn")
    new_train_dataset = original_train_dataset.map(
        gpu_computation,
        batched=True,
        batch_size=4,
        with_rank=True,
        num_proc=torch.cuda.device_count(),  # one process per GPU
    )

    # The loading section may leave `existing_train_dataset` unbound (or
    # None) when nothing has been published yet; that is the only situation
    # in which the merge is skipped.
    try:
        previous_dataset = existing_train_dataset
    except NameError:
        previous_dataset = None

    if previous_dataset is not None:
        # Deliberately not wrapped in a catch-all: if the merge fails (for
        # example on a schema mismatch) the script must stop here, because
        # pushing only the new rows to the same repo/split would discard the
        # previously published ones.
        new_train_dataset = concatenate_datasets([previous_dataset, new_train_dataset])

    new_train_dataset.push_to_hub("ShenaoZ/compare_response", split="train_prefs", private=False)