# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
r"""
Basic training script for PyTorch
"""

# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (no not reorder)
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip

import argparse
import os

import numpy as np
import torch
from maskrcnn_benchmark.config import cfg, try_to_find
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.solver import make_lr_scheduler
from maskrcnn_benchmark.solver import make_optimizer
from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.engine.trainer import do_train
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, all_gather
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.metric_logger import (MetricLogger, TensorboardLogger)
from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config
import random
from maskrcnn_benchmark.utils.amp import autocast, GradScaler

from pathlib import Path
from tqdm import tqdm
from collections import defaultdict

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ["TOKENIZERS_PARALLELISM"] = "false"


def tuning_highlevel_override(cfg,):
    if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE == "vision_query":
        cfg.MODEL.BACKBONE.FREEZE = True
        cfg.MODEL.FPN.FREEZE = True
        cfg.MODEL.RPN.FREEZE = True if not cfg.VISION_QUERY.QUERY_FUSION else False
        cfg.MODEL.LINEAR_PROB = False
        cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER = False
        cfg.MODEL.LANGUAGE_BACKBONE.FREEZE = False
        cfg.MODEL.DYHEAD.USE_CHECKPOINT = False # Disable checkpoint
        cfg.VISION_QUERY.ENABLED = True
    if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE == "vs_with_txt_enc":
        cfg.MODEL.BACKBONE.FREEZE = True
        cfg.MODEL.FPN.FREEZE = True
        cfg.MODEL.RPN.FREEZE = True if not cfg.VISION_QUERY.QUERY_FUSION else False
        cfg.MODEL.LINEAR_PROB = False
        cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER = False
        cfg.MODEL.LANGUAGE_BACKBONE.FREEZE = False
        cfg.MODEL.DYHEAD.USE_CHECKPOINT = False # Disable checkpoint
        cfg.VISION_QUERY.ENABLED = True


def train(cfg, local_rank, distributed, use_tensorboard=False, resume=False):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.GROUNDINGDINO.enabled:
        if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE == "vision_query":
            for key, p in model.named_parameters():
                if not ('pre_select' in key or 'qv_layer' in key):
                    p.requires_grad = False
    else:
        if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE == "vision_query":
            if model.language_backbone is not None:
                for key, p in model.language_backbone.named_parameters():
                    if not ('pre_select' in key or 'qv_layer' in key):
                        p.requires_grad = False
            if cfg.VISION_QUERY.QUERY_FUSION:
                if model.rpn is not None:
                    for key, p in model.rpn.named_parameters():
                        if not ('pre_select' in key or 'qv_layer' in key):
                            p.requires_grad = False

    if cfg.MODEL.BACKBONE.RESET_BN:
        for name, param in model.named_buffers():
            if 'running_mean' in name:
                torch.nn.init.constant_(param, 0)
            if 'running_var' in name:
                torch.nn.init.constant_(param, 1)

    if cfg.SOLVER.GRAD_CLIP > 0:
        clip_value = cfg.SOLVER.GRAD_CLIP
        for p in filter(lambda p: p.grad is not None, model.parameters()):
            p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=0  # <TODO> Sample data from resume is disabled, due to the conflict with max_epoch
    )

    if cfg.TEST.DURING_TRAINING or cfg.SOLVER.USE_AUTOSTEP:
        data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
        data_loaders_val = data_loaders_val[0]
    else:
        data_loaders_val = None

    if cfg.GROUNDINGDINO.enabled:
        pass
    else:
        if cfg.MODEL.BACKBONE.FREEZE:
            for p in model.backbone.body.parameters():
                p.requires_grad = False

        if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
            print("LANGUAGE_BACKBONE FROZEN.")
            for p in model.language_backbone.body.parameters():
                p.requires_grad = False

        if cfg.MODEL.FPN.FREEZE:
            for p in model.backbone.fpn.parameters():
                p.requires_grad = False
        if cfg.MODEL.RPN.FREEZE:
            for p in model.rpn.parameters():
                p.requires_grad = False
    
    # if cfg.SOLVER.PROMPT_PROBING_LEVEL != -1:
    #     if cfg.SOLVER.PROMPT_PROBING_LEVEL == 1:
    #         for p in model.parameters():
    #             p.requires_grad = False

    #         for p in model.language_backbone.body.parameters():
    #             p.requires_grad = True

    #         for name, p in model.named_parameters():
    #             if p.requires_grad:
    #                 print(name, " : Not Frozen")
    #             else:
    #                 print(name, " : Frozen")
    #     else:
    #         assert(0)

    for n, p in model.named_parameters():
        if 'lora_' in n:
            p.requires_grad = True
            print("set {} learnable".format(n))
    
    for n, p in model.named_parameters():
        if 'vision_query_linear' in n:
            p.requires_grad = True
            print("set {} learnable".format(n))
    
    for n, p in model.named_parameters():
        if 'vision_query_gate' in n:
            p.requires_grad = True
            print("set {} learnable".format(n))
    
    for n, p in model.named_parameters():
        if ('vlfuse_visual' in n) or ('vlfuse_lang' in n):
            p.requires_grad = True
            print("set {} learnable".format(n))

    optimizer = make_optimizer(cfg, model)
    print('Making scheduler')
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        print('Distributing model')
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            broadcast_buffers=cfg.MODEL.BACKBONE.USE_BN,
            find_unused_parameters=cfg.SOLVER.FIND_UNUSED_PARAMETERS
        )
        print('Done')

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    print('Making checkpointer')
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    if resume and cfg.OUTPUT_DIR != "OUTPUT":
        if not os.path.exists(cfg.OUTPUT_DIR):
            load_weight=cfg.MODEL.WEIGHT
        else:
            checkpoint_list=[name for name in os.listdir(cfg.OUTPUT_DIR) if name.endswith('.pth') and 'final' not in name and 'resume' not in name]
            if len(checkpoint_list)==0:
                load_weight=cfg.MODEL.WEIGHT
                resume=False
            else:
                max_bits=len(checkpoint_list[0].split('.')[0].split('_')[-1])
                iter_list=[int(name.split('.')[0].split('_')[-1]) for name in checkpoint_list]
                max_iter=str(max(iter_list)).zfill(max_bits)
                resume_weight_name='model_'+max_iter+'.pth'
                load_weight=str(Path(cfg.OUTPUT_DIR, resume_weight_name))
    else:
        load_weight=cfg.MODEL.WEIGHT

    print('Loading checkpoint')
    extra_checkpoint_data = checkpointer.load(try_to_find(load_weight))
    arguments.update(extra_checkpoint_data)

    # enable resume
    data_loader.batch_sampler.start_iter = arguments["iteration"] + 1 if resume else 0


    # data_loader = make_data_loader(
    #     cfg,
    #     is_train=True,
    #     is_distributed=distributed,
    #     start_iter=arguments["iteration"] + 1 if resume else 0  # <TODO> Sample data from resume is disabled, due to the conflict with max_epoch
    # )

    # if cfg.TEST.DURING_TRAINING or cfg.SOLVER.USE_AUTOSTEP:
    #     data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
    #     data_loaders_val = data_loaders_val[0]
    # else:
    #     data_loaders_val = None
    
    if cfg.DATASETS.FEW_SHOT:
        arguments["dataset_ids"] = data_loader.dataset.ids

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if use_tensorboard:
        meters = TensorboardLogger(
            log_dir=cfg.OUTPUT_DIR,
            start_iter=arguments["iteration"],
            delimiter="  "
        )
    else:
        meters = MetricLogger(delimiter="  ")
    
    if is_main_process():
        for name, p in model.named_parameters():
            if p.requires_grad:
                print(name, " : Not Frozen")
            else:
                print(name, " : Frozen")

    do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        data_loaders_val,
        meters,
    )

    return model

def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

def extract_query(cfg):
    if cfg.DATASETS.FEW_SHOT:
        assert cfg.DATASETS.FEW_SHOT == cfg.VISION_QUERY.MAX_QUERY_NUMBER, 'To extract the right query instances, set VISION_QUERY.MAX_QUERY_NUMBER = DATASETS.FEW_SHOT.'
    # if cfg.num_gpus > 1:
    #     max_query_number = cfg.VISION_QUERY.MAX_QUERY_NUMBER
    #     cfg.defrost()
    #     cfg.VISION_QUERY.MAX_QUERY_NUMBER = int(cfg.VISION_QUERY.MAX_QUERY_NUMBER/cfg.num_gpus)
    #     cfg.freeze()

    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)


    checkpointer = DetectronCheckpointer(
        cfg, model
    )
    checkpointer.load(try_to_find(cfg.MODEL.WEIGHT))

    data_loader = make_data_loader(
        cfg,
        is_train=False,
        is_cache=True,
        is_distributed= cfg.num_gpus > 1,
    )
    assert isinstance(data_loader, list) and len(data_loader)==1
    data_loader=data_loader[0]

    # if cfg.VISION_QUERY.CUSTOM_DATA_IDS is not None:
    #     data_loader.dataset.ids = cfg.VISION_QUERY.CUSTOM_DATA_IDS

    if cfg.num_gpus > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[cfg.local_rank], output_device=cfg.local_rank,
            broadcast_buffers=cfg.MODEL.BACKBONE.USE_BN,
            find_unused_parameters=cfg.SOLVER.FIND_UNUSED_PARAMETERS
        )

    query_images=defaultdict(list)
    _iterator = tqdm(data_loader)
    # _iterator = data_loader # for debug
    model.eval()
    for i, batch in enumerate(_iterator):
        images, targets, *_ = batch
        if cfg.num_gpus > 1:
            query_images = model.module.extract_query(images.to(device), targets, query_images)
        else:
            query_images = model.extract_query(images.to(device), targets, query_images)
    
    if cfg.num_gpus > 1:
        ## not stable when using all_gather, easy to OOM.
        # all_query_images = all_gather(query_images)
        # if is_main_process():
        #     accumulated_query_images = defaultdict(list)
        #     for r, query_images_dict in enumerate(all_query_images):
        #         print('accumulating results: {}/{}'.format(r, len(all_query_images)))
        #         for label, feat in query_images_dict.items():
        #             num_queries=len(accumulated_query_images[label])
        #             if num_queries >= cfg.VISION_QUERY.MAX_QUERY_NUMBER:
        #                 continue
        #             if num_queries==0:
        #                 accumulated_query_images[label] = feat.to(device)
        #             else:
        #                 accumulated_query_images[label] = torch.cat([accumulated_query_images[label].to(device), feat.to(device)])

        #     save_name = 'MODEL/{}_query_{}_pool{}_{}{}_multi-node.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0] , cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION ,'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME)
        #     print('saving to ', save_name)
        #     torch.save(accumulated_query_images, save_name)
        if cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH != '':
            raise NotImplementedError
        global_rank = get_rank()
        save_name = 'MODEL/{}_query_{}_pool{}_{}{}_rank{}.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0] , cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION ,'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME, global_rank)
        print('saving to ', save_name)
        torch.save(query_images, save_name)
    else:
        if cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH != '':
            save_name = cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH
        else:
            save_name = 'MODEL/{}_query_{}_pool{}_{}{}.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0] , cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION ,'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME)
        print('saving to ', save_name)
        torch.save(query_images, save_name)
    # if cfg.num_gpus > 1:
    #     # 
    #     world_size = torch.distributed.dist.get_world_size()
    #     if is_main_process():
    #         query_images_list = []
    #         for r in range(world_size):
    #             saved_path = 'MODEL/{}_query_{}_pool{}_{}{}_rank{}.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0] , cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION ,'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME, r)
    #             query_images_list.append(torch.load(saved_path, map_location='cpu'))
            
    #         for s in query_images_list

    
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )

    parser.add_argument("--use-tensorboard",
                        dest="use_tensorboard",
                        help="Use tensorboardX logger (Requires tensorboardX installed)",
                        action="store_true",
                        default=False
                        )

    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    parser.add_argument("--save_original_config", action="store_true")
    parser.add_argument("--disable_output_distributed", action="store_true")
    parser.add_argument("--override_output_dir", default=None)
    parser.add_argument("--custom_shot_and_epoch_and_general_copy", default=None, type=str)
    parser.add_argument("--resume", action="store_true", default=False)
    parser.add_argument("--extract_query", action="store_true", default=False)
    parser.add_argument(
        "--task_config",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "--additional_model_config",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )



    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        import datetime
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="gloo", init_method="env://",
            timeout=datetime.timedelta(0, 7200)
        )
    
    if args.disable_output_distributed:
        setup_for_distributed(args.local_rank <= 0)

    cfg.local_rank = args.local_rank
    cfg.num_gpus = num_gpus

    cfg.merge_from_file(args.config_file)
    if args.task_config:
        cfg.merge_from_file(args.task_config)
    if args.additional_model_config:
        cfg.merge_from_file(args.additional_model_config)
    cfg.merge_from_list(args.opts)
    # specify output dir for models
    if args.override_output_dir:
        cfg.OUTPUT_DIR = args.override_output_dir
    tuning_highlevel_override(cfg)
    cfg.freeze()

    seed = cfg.SOLVER.SEED + args.local_rank
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info(args)
    logger.info("Using {} GPUs".format(num_gpus))

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml')
    logger.info("Saving config into: {}".format(output_config_path))
    # save overloaded model config in the output directory
    if args.save_original_config:
        import shutil
        shutil.copy(args.config_file, os.path.join(cfg.OUTPUT_DIR, 'config_original.yml'))
    
    save_config(cfg, output_config_path)

    if args.extract_query:
        extract_query(cfg)
    else:
        # if cfg.DATASETS.FEW_SHOT and cfg.VISION_QUERY.ENABLED:
        #     max_query_number = int(os.path.basename(cfg.VISION_QUERY.QUERY_BANK_PATH).split('_')[-2])
        #     assert cfg.DATASETS.FEW_SHOT == max_query_number, 'You should first extract corresponding few-shot query instances.'
        #     assert max_query_number >= cfg.VISION_QUERY.NUM_QUERY_PER_CLASS
        model = train(cfg=cfg,
                    local_rank=args.local_rank,
                    distributed=args.distributed,
                    use_tensorboard=args.use_tensorboard, 
                    resume=args.resume)


if __name__ == "__main__":
    main()
