import os
import argparse

from mmengine.config import Config
from mmengine.runner import Runner

import custom_datasets  # This and the following should be loaded here because of mmseg module registration
import resclip_segmentor

import torch


def parse_args(argv=None):
    """Parse the evaluation command line.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, as argparse normally does.

    Returns:
        argparse.Namespace with the parsed options. The model-override options
        (``--backbone``, ``--arch``, ``--attn``, ``--std``) and ``--show-dir``
        default to ``''``, which ``main`` treats as "not given".

    Side effect:
        Sets ``os.environ['LOCAL_RANK']`` from ``--local-rank`` when that
        variable is not already present in the environment.
    """
    parser = argparse.ArgumentParser(description='Evaluation with MMSeg')
    # Required: without a config file Config.fromfile() cannot succeed, so
    # fail here with a usage message instead of an opaque file error later.
    parser.add_argument('--config', required=True,
                        help='path to the MMSeg config file')
    parser.add_argument('--backbone', default='',
                        help="override cfg.model['clip_path']")
    parser.add_argument('--arch', default='',
                        help="override cfg.model['arch']")
    parser.add_argument('--attn', default='',
                        help="override cfg.model['attn_strategy']")
    parser.add_argument('--std', default='',
                        help="override cfg.model['gaussian_std'] (parsed as float)")
    parser.add_argument('--work-dir', default='./work_logs/',
                        help='value written to cfg.work_dir')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--show-dir', default='', help='directory to save visualization images')
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args(argv)
    # Only fill LOCAL_RANK from the CLI; an existing environment value wins.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def trigger_visualization_hook(cfg, show_dir):
    """Enable drawing in the visualization hook and send its output to *show_dir*.

    ``cfg`` is modified in place and also returned for convenience.

    Args:
        cfg: Config whose ``default_hooks`` must contain a ``'visualization'``
            entry and which carries a ``visualizer`` mapping.
        show_dir: Directory stored as the visualizer's ``save_dir``.

    Raises:
        RuntimeError: if ``cfg.default_hooks`` has no ``'visualization'`` entry.
    """
    hooks = cfg.default_hooks
    if 'visualization' not in hooks:
        raise RuntimeError(
            'VisualizationHook must be included in default_hooks. refer to usage '
            '"visualization=dict(type=\'VisualizationHook\')"')
    # Switch the hook on first, then point the visualizer at the output folder.
    hooks['visualization']['draw'] = True
    cfg.visualizer['save_dir'] = show_dir
    return cfg


def safe_set_arg(cfg, arg, name, func=lambda x: x):
    """Store ``func(arg)`` under ``cfg.model[name]`` unless *arg* is ``''``.

    An empty string means the option was not supplied, so the value already
    in the config is left untouched. *func* converts the raw value (identity
    by default, e.g. ``float`` for numeric options).
    """
    if arg == '':
        return  # not supplied: keep the config-file value
    cfg.model[name] = func(arg)


def main():
    """Load the config, apply command-line overrides, and run the test loop."""
    args = parse_args()

    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    cfg.work_dir = args.work_dir

    # Each entry: (CLI value, cfg.model key[, converter]). safe_set_arg skips
    # empty CLI values, so settings from the config file are kept.
    model_overrides = (
        (args.backbone, 'clip_path'),
        (args.arch, 'arch'),
        (args.attn, 'attn_strategy'),
        (args.std, 'gaussian_std', float),
    )
    for value, key, *converter in model_overrides:
        safe_set_arg(cfg, value, key, *converter)

    if args.show_dir != '':
        trigger_visualization_hook(cfg, args.show_dir)

    Runner.from_cfg(cfg).test()

if __name__ == '__main__':
    # Evaluation only: run the whole script with autograd disabled so no
    # computation graphs are recorded.
    with torch.no_grad():
        main()
