from alignment import H4ArgumentParser, ModelArguments, DataArguments, RDPOConfig, get_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset, DatasetDict, concatenate_datasets
from multiprocess import set_start_method
import torch
from trl.trainer.utils import pad_to_length
# Number of "_part{index}" shards that the combining loop loads and merges.
multi_part: int = 3
# Not referenced in this script; presumably the row count per shard used by the
# script that produced the shards (inferred from the name only — TODO confirm).
length_per_part: int = 20_000

###################################
# Parse arguments and load datasets
###################################
parser = H4ArgumentParser((ModelArguments, DataArguments, RDPOConfig))
model_args, data_args, training_args = parser.parse()

if __name__ == "__main__":
    # Merge the "<repo>_part{index}" shards of the "train_prefs" split into one
    # dataset and push it to the repo named by data_args.dataset_mixer["updated"].
    #
    # All of the work lives under the __main__ guard. Previously the guard only
    # wrapped an unused list, and the download/push ran whenever the module was
    # imported.
    #
    # NOTE: an earlier revision tried to skip this step when the merged dataset
    # already existed on the Hub. That check never worked: it did
    # `raise("No need to preprocess again!")`, and raising a str is itself a
    # TypeError, which the surrounding bare `except:` swallowed. The check
    # therefore always fell through to merging, and it remains disabled.
    repo_id = data_args.dataset_mixer["updated"]

    if multi_part < 1:
        # Without at least one shard there is nothing to concatenate or push.
        raise ValueError(f"multi_part must be at least 1, got {multi_part}")

    print("Combining all the parts.")
    # Load every shard (in index order), then concatenate once instead of
    # re-concatenating the growing result on every iteration.
    part_datasets = [
        load_dataset(
            repo_id + f"_part{index}",
            split="train_prefs",
            download_mode="force_redownload",
            # Equivalent to the deprecated ignore_verifications=True
            # (deprecated in favour of verification_mode in datasets 2.9.1).
            verification_mode="no_checks",
        )
        for index in range(multi_part)
    ]
    combined_train_dataset = concatenate_datasets(part_datasets)
    print(combined_train_dataset.num_rows)
    combined_train_dataset.push_to_hub(repo_id, split="train_prefs", private=False)
    