import argparse
import torch
from omegaconf import OmegaConf
from generate_bbox import Inference
from distributed import synchronize
import os 
import torch.multiprocessing as multiprocessing
import datetime 
from ldm.util import bool_flag



def _str_to_bool(value):
    """Return True only when *value* is the string "true" (case-insensitive)."""
    return value.lower() == "true"


def build_arg_parser():
    """Build the command-line parser for the inference entry point.

    Every parsed option is later merged into the OmegaConf config by
    ``main``, so option names double as config keys.

    Returns:
        argparse.ArgumentParser: parser with all supported options registered.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--name", type=str, default="test", help="experiment will be stored in OUTPUT_ROOT/name")
    # parser.add_argument("--seed", type=int, default=123, help="used in sampler")

    # torchrun exports LOCAL_RANK in the environment instead of passing a flag,
    # while torch.distributed.launch passes --local-rank (or --local_rank on
    # older releases). Accept both spellings and default to the env variable so
    # each process selects its own GPU regardless of the launcher.
    parser.add_argument(
        "--local-rank",
        "--local_rank",
        dest="local_rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    parser.add_argument("--yaml_file", type=str, default="configs/evaluate_karlo.yaml", help="paths to base configs.")
    parser.add_argument("--base_learning_rate", type=float, default=5e-5, help="")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="")
    parser.add_argument("--warmup_steps", type=int, default=2500, help="")
    # ``choices`` replaces a post-parse ``assert`` that would be stripped under -O.
    parser.add_argument("--scheduler_type", type=str, default='constant', choices=['cosine', 'constant'], help="cosine or constant")
    parser.add_argument("--workers", type=int, default=4, help="")
    parser.add_argument("--official_ckpt_name", type=str, default="sd-v1-4.ckpt", help="SD ckpt name and it is expected in DATA_ROOT, thus DATA_ROOT/official_ckpt_name must exists")
    parser.add_argument('--inpaint_mode', default=False, type=_str_to_bool, help="Train a GLIGEN model in inpaitning setting")
    parser.add_argument('--randomize_fg_mask', default=False, type=_str_to_bool, help="Only used if inpaint_mode is true. If true, 0.5 chance that fg mask will not be a box but a random mask. See code for details")
    parser.add_argument('--random_add_bg_mask', default=False, type=_str_to_bool, help="Only used if inpaint_mode is true. If true, 0.5 chance add arbitrary mask for the whole image. See code for details")
    parser.add_argument('--enable_ema', default=False, type=_str_to_bool)
    parser.add_argument("--ema_rate", type=float, default=0.9999, help="")
    parser.add_argument("--total_iters", type=int, default=50000, help="")
    parser.add_argument("--disable_inference_in_training", type=_str_to_bool, default=False, help="Do not do inference, thus it is faster to run first a few iters. It may be useful for debugging ")
    parser.add_argument('--wandb_project_name', type=str, default="VideoDirectorGPT", help="the wandb's project name")
    parser.add_argument('--wandb_entity', type=str, help="the entity (team) of wandb's project")
    parser.add_argument("--range_start", type=int, default=0, help="")
    parser.add_argument("--range_end", type=int, default=-1, help="")
    parser.add_argument("--batch_size", type=int, default=1, help="")
    parser.add_argument("--gpu_id", type=int, default=2, help="")
    parser.add_argument("--device", type=str, default='cuda')
    # Restrict to the two modes dispatched in ``main``; previously an unknown
    # mode loaded the model and then silently did nothing.
    parser.add_argument("--mode", type=str, default='single', choices=['single', 'multi'])
    return parser


def main(argv=None):
    """Parse options, build the merged config and run inference.

    Example:
        CUDA_VISIBLE_DEVICES=7 python main_inference_karlo.py --yaml_file=configs/evaluate_karlo_pororo_multi_scene_FIX_ImgTextEmbAfter_alpha01_ref_Sep20.yaml --range_start 5 --range_end -1 --batch_size 1 --mode multi

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.
    """
    # multiprocessing.set_start_method('spawn')

    args = build_arg_parser().parse_args(argv)

    # WORLD_SIZE is read from the environment; more than one process means a
    # distributed run.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    args.distributed = world_size > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            timeout=datetime.timedelta(seconds=5400),
        )
        synchronize()

    config = OmegaConf.load(args.yaml_file)

    # Command-line values (including ``device``) override same-named YAML keys.
    config.update(vars(args))

    # NOTE(review): the GPU count used here has always been hard-coded to 1, so
    # total_batch_size equals batch_size rather than being scaled by
    # WORLD_SIZE -- confirm this is intended.
    n_gpu = 1
    config.total_batch_size = config.batch_size * n_gpu

    if config.inpaint_mode:
        config.model.params.inpaint_mode = True

    config.model.params.use_videoldm = config.use_videoldm

    # Explicit comparison kept on purpose: a None value does not take this branch.
    if config.enable_fuser == False:  # noqa: E712
        config.model.params.fuser_type = None

    # image_size is set to new_image_size divided by 8.
    config.model.params.image_size = config.new_image_size // 8

    inference = Inference(config)
    synchronize()

    if args.mode == 'single':
        inference.inference_single_scene()
    elif args.mode == 'multi':
        inference.inference_multi_scene()


if __name__ == "__main__":
    main()
    
    





