import sys
import logging

from LLMProxy.option import (
    parse_training_args
)
from LLMProxy.trainer import Trainer, DDPTrainer, AccelerateTrainer, DeepSpeedTrainer
from LLMProxy.distributed_utils import distributed_init


# Record layout: "timestamp | level | logger name | message".
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"

# Root logging setup, executed at import time: INFO and above go to stdout
# using the layout above.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
)

# Logger used by this script; it carries the fixed name "Training" rather
# than the module's ``__name__``.
logger = logging.getLogger("Training")


def main():
    """Run a training job with ``DeepSpeedTrainer``.

    Steps, in order:

    1. ``parse_training_args()`` yields three argument groups
       (distributed, model, training).
    2. ``distributed_init`` is called with the distributed group.
    3. A ``DeepSpeedTrainer`` is built from all three groups and its
       ``train()`` method is invoked, with an INFO log line emitted
       before and after.
    """
    dist_cfg, model_cfg, train_cfg = parse_training_args()

    # Distributed initialisation happens before the trainer is constructed.
    distributed_init(dist_cfg)

    trainer_kwargs = {
        "train_args": train_cfg,
        "model_args": model_cfg,
        "dist_args": dist_cfg,
    }
    runner = DeepSpeedTrainer(**trainer_kwargs)

    logger.info("Start Training")
    runner.train()
    logger.info("Done")
    

# Only start training when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()