from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import os
import torch
from video_language_critic.modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
from video_language_critic.util import (
    get_args,
    set_seed_logger,
    init_device,
    freeze_model,
    init_model,
    get_val_test_dataloaders,
    do_train,
    eval_epoch,
)
from torch.distributed.elastic.multiprocessing.errors import record

# Initialise the default distributed process group with the NCCL backend.
# NOTE: this executes at import time, before main() is called, so merely
# importing this module has the same side effect. No init_method, rank or
# world_size is passed here; presumably they are supplied through the
# launcher's environment (e.g. torchrun) -- confirm against the launch scripts.
torch.distributed.init_process_group(backend="nccl")

# NOTE(review): a `global` statement at module scope has no effect. The
# module-level `logger` name is actually bound inside main().
global logger


# @record
def main():
    """Set up a retrieval run and then train or evaluate the model.

    The steps are, in order: obtain the run arguments via ``get_args`` and
    ``set_seed_logger``, publish ``args.logger`` as the module-level
    ``logger``, initialise the device, build the CLIP tokenizer and the model,
    apply ``freeze_model``, and create the validation/test dataloaders.
    Afterwards ``args.do_train`` selects training; otherwise ``args.do_eval``
    selects a single evaluation pass, which only the process with
    ``args.local_rank == 0`` executes.

    Raises:
        ValueError: If ``args.task_type`` is not ``"retrieval"``.
    """
    global logger
    args = get_args()
    args = set_seed_logger(args)
    logger = args.logger
    device, n_gpu = init_device(args, args.local_rank)

    tokenizer = ClipTokenizer()

    # Validate explicitly rather than with `assert`: assertions are removed
    # when Python runs with -O, which would let an unsupported task through.
    if args.task_type != "retrieval":
        raise ValueError(
            "Unsupported task_type {!r}; only 'retrieval' is supported.".format(
                args.task_type
            )
        )
    model = init_model(args, device, n_gpu, args.local_rank)

    # Apply the configured parameter freezing to the freshly built model.
    freeze_model(model, args)

    # Build the evaluation dataloaders. A length of None is treated below as
    # "this split is not available".
    test_dataloader, test_length, val_dataloader, val_length = get_val_test_dataloaders(
        args, tokenizer
    )

    # Only local rank 0 reports the dataset statistics, to avoid duplicate
    # log lines from every process.
    if args.local_rank == 0:
        if test_length is not None:
            logger.info("***** Running test *****")
            logger.info("  Num examples = %d", test_length)
            logger.info("  Batch size = %d", args.batch_size_val)
            logger.info("  Num steps = %d", len(test_dataloader))
        if val_length is not None:
            logger.info("***** Running val *****")
            logger.info("  Num examples = %d", val_length)

    if args.do_train:
        # The return value (previously bound to an unused `best_ckpt` local)
        # is intentionally ignored: per the original note, re-evaluating the
        # best checkpoint is currently handled by do_train itself.
        do_train(
            args,
            tokenizer,
            model,
            device,
            n_gpu,
            val_dataloader,
            test_dataloader,
        )
    elif args.do_eval and args.local_rank == 0:
        # Evaluation-only mode runs on local rank 0 alone, on the split chosen
        # by args.eval_on_val.
        dataloader = val_dataloader if args.eval_on_val else test_dataloader
        eval_epoch(args, model, dataloader, device, n_gpu, save_eval_result=True)


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
