from modeling.config_unitge import TrainingArguments, ModelArguments, DataArguments
import transformers
import torch
from modeling.unigte_trainer import MyTrainer
from utils.acce_utils import init_seed


def main():
    """Parse command-line arguments, then run training or inference.

    The CLI is parsed into ``ModelArguments``, ``DataArguments`` and
    ``TrainingArguments`` (in that order) via
    ``transformers.HfArgumentParser.parse_args_into_dataclasses``. The seed is
    initialised from ``training_args.seed`` before the trainer is built.

    If ``training_args.inference`` is falsy, the trainer trains and saves to
    ``training_args.output_dir``. Otherwise it runs ``trainer.evaluation()``.

    ``trainer.accelerator.end_training()`` is invoked in a ``finally`` clause,
    so the accelerator teardown also runs when training/evaluation raises. The
    original exception still propagates to the caller.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Seed before constructing the trainer so anything random done in its
    # constructor is covered too (presumably model init / data setup --
    # depends on MyTrainer, TODO confirm).
    init_seed(training_args.seed)

    trainer = MyTrainer(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
    )
    try:
        if not training_args.inference:
            trainer.train(save_directory=training_args.output_dir)
        else:
            trainer.evaluation()
    finally:
        # Run teardown even on failure. Previously an exception skipped it.
        # NOTE(review): `trainer.accelerator` is presumably an
        # `accelerate.Accelerator`; verify against MyTrainer.
        trainer.accelerator.end_training()

if __name__ == '__main__':
    # Script entry point: only run when executed directly, not on import.
    main()