import datetime
import argparse, importlib
from pytorch_lightning import seed_everything

import torch
import torch.distributed as dist

def setup_dist(local_rank):
    if dist.is_initialized():
        return
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group('nccl', init_method='env://')


def get_dist_info():
    if dist.is_available():
        initialized = dist.is_initialized()
    else:
        initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


if __name__ == '__main__':
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    parser = argparse.ArgumentParser()
    parser.add_argument("--module", type=str, help="module name", default="inference")
    parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0)
    args, unknown = parser.parse_known_args()
    inference_api = importlib.import_module(args.module, package=None)

    inference_parser = inference_api.get_parser()
    inference_args, unknown = inference_parser.parse_known_args()

    seed_everything(inference_args.seed)
    setup_dist(args.local_rank)
    torch.backends.cudnn.benchmark = True
    rank, gpu_num = get_dist_info()

    print("@CoLVDM Inference [rank%d]: %s"%(rank, now))
    inference_api.run_inference(inference_args, gpu_num, rank)