import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from utils.config import argparser
from utils.util import CustomModelCheckpoint
from task.audio_cls_mmac import AudioClassification
from task.video_align import VideoAlignment
from task.video_align_bbox import ObjectVideoAlignment
from task.supervised_phase_cls import PhaseClassification


def _build_task(opts):
    """Instantiate the LightningModule that implements ``opts.task``.

    Args:
        opts: Parsed command-line namespace. ``opts.task`` selects the task,
            and the whole namespace is forwarded to the task constructor.

    Returns:
        A ``VideoAlignment``, ``ObjectVideoAlignment`` or
        ``PhaseClassification`` instance.

    Raises:
        NotImplementedError: If ``opts.task`` does not name a supported task.
    """
    task_name = opts.task
    if task_name == 'align':
        return VideoAlignment(opts)
    # Substring (not exact) match: any task name containing 'align_bbox'
    # is routed here -- presumably to allow variants of the bbox task;
    # TODO confirm which names callers actually pass.
    if 'align_bbox' in task_name:
        return ObjectVideoAlignment(opts)
    if task_name == 'phase_cls':
        return PhaseClassification(opts)
    raise NotImplementedError(
        f"Unsupported task {task_name!r}; expected 'align', a name "
        "containing 'align_bbox', or 'phase_cls'."
    )


def main():
    """Build the task selected on the command line and train it.

    Reads the module-level ``args`` namespace. That name is only bound when
    this file is executed as a script (inside the ``__main__`` guard), so
    calling ``main()`` after a plain import raises ``NameError``.

    Raises:
        NotImplementedError: If ``args.task`` does not name a supported task.
    """
    task = _build_task(args)

    # Checkpoint every `args.save_every` epochs. The file name is templated on
    # the epoch number. save_top_k=-1 presumably keeps every checkpoint, as in
    # Lightning's ModelCheckpoint (assumes CustomModelCheckpoint subclasses it).
    custom_checkpoint_callback = CustomModelCheckpoint(
        every_n_epochs=args.save_every,
        filename="{epoch}",
        save_top_k=-1,
    )

    trainer = Trainer(
        # NOTE(review): `gpus=` is deprecated in Lightning 1.7 and removed in
        # 2.0; switch to `devices=args.num_gpus` when upgrading.
        gpus=args.num_gpus,
        accelerator="gpu",
        # strategy="ddp",
        callbacks=[custom_checkpoint_callback],
        max_epochs=args.epochs,
        # Skip Lightning's pre-training sanity validation pass.
        num_sanity_val_steps=0,
        # limit_train_batches=1, #debug
        default_root_dir=args.output_dir,
        fast_dev_run=args.fast_dev_run,
    )

    trainer.fit(task)


if __name__ == '__main__':
    # Must be bound at module scope under the name `args`:
    # main() reads this global directly instead of taking a parameter.
    args = argparser.parse_args()
    # Force merge_all on regardless of what was passed on the command line.
    # Its meaning is defined by the task classes -- TODO confirm.
    setattr(args, 'merge_all', True)
    main()
