# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

import torch
import torchvision.transforms as transforms


def _load_mask2former(config_file, weights_file):
    """Build a Mask2Former segmentation expert and its input transform.

    Shared by the 'seg_coco' and 'seg_ade' tasks, which differ only in the
    config and weight files they use.

    Args:
        config_file: Path to the Mask2Former YAML config.
        weights_file: Path to the checkpoint passed as ``MODEL.WEIGHTS``.

    Returns:
        A ``(model, transform)`` tuple; the model is not yet put in eval mode.
    """
    # Mask2Former is wrapped in detectron2,
    # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
    import argparse
    from detectron2.engine.defaults import DefaultPredictor
    from UniDet_eval.experts.segmentation.utils import setup_cfg

    # Build the namespace directly instead of calling parse_args(): parsing
    # sys.argv here would abort the host program on any flag it does not know.
    args = argparse.Namespace(mode='client', port=2)
    args.config_file = config_file
    args.opts = ['MODEL.WEIGHTS', weights_file]

    cfg = setup_cfg(args)
    model = DefaultPredictor(cfg).model
    transform = transforms.Compose([
        transforms.Resize(size=479, max_size=480)
    ])
    return model, transform


def load_expert_model(task=None):
    """Load the expert model for ``task`` together with its preprocessing transform.

    Supported tasks are 'depth', 'seg_coco', 'seg_ade', 'obj_detection',
    'ocr_detection', 'normal' and 'edge'. Heavy, task-specific dependencies
    are imported lazily inside the branch that needs them.

    NOTE(review): the 'depth' and 'obj_detection' branches use paths rooted at
    'UniDet_eval/experts/...', while the remaining branches use 'experts/...'.
    These cannot all resolve from the same working directory -- confirm which
    root is intended.

    Args:
        task: Name of the expert to load.

    Returns:
        A ``(model, transform)`` tuple with the model switched to eval mode.
        For an unsupported task a message is printed and ``(None, None)`` is
        returned.
    """
    if task == 'depth':
        # DPT model is a standard pytorch model class
        from UniDet_eval.experts.depth.models import DPTDepthModel

        model = DPTDepthModel(path='UniDet_eval/experts/expert_weights/dpt_hybrid-midas-501f0c75.pt',
                              backbone="vitb_rn50_384",
                              non_negative=True,
                              enable_attention_hooks=False)
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5)]
        )

    elif task == 'seg_coco':
        model, transform = _load_mask2former(
            'experts/segmentation/configs/coco/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml',
            'experts/expert_weights/model_final_f07440.pkl',
        )

    elif task == 'seg_ade':
        model, transform = _load_mask2former(
            'experts/segmentation/configs/ade20k/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_160k.yaml',
            'experts/expert_weights/model_final_e0c58e.pkl',
        )

    elif task == 'obj_detection':
        # UniDet is wrapped in detectron2,
        # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
        import argparse
        from detectron2.engine.defaults import DefaultPredictor
        from UniDet_eval.experts.obj_detection.utils import setup_cfg

        # Options are set on a plain Namespace rather than parsed from
        # sys.argv, so the host program's own CLI flags cannot break loading.
        args = argparse.Namespace()
        args.mode = 'client'
        args.port = 2
        args.confidence_threshold = 0.5
        args.outpath = None
        args.complex = False

        args.config_file = 'UniDet_eval/experts/obj_detection/configs/Unified_learned_OCIM_R50_6x+2x.yaml'
        args.opts = ['MODEL.WEIGHTS', 'UniDet_eval/experts/expert_weights/Unified_learned_OCIM_R50_6x+2x.pth']

        cfg = setup_cfg(args)
        model = DefaultPredictor(cfg).model
        transform = transforms.Compose([
            transforms.Resize(size=479, max_size=480)
        ])

    elif task == 'ocr_detection':
        from UniDet_eval.experts.ocr_detection.charnet.modeling.model import CharNet
        model = CharNet()
        # map_location='cpu' lets a checkpoint saved on GPU load on a
        # CPU-only host; load_state_dict copies into the model's own tensors.
        model.load_state_dict(torch.load('experts/expert_weights/icdar2015_hourglass88.pth',
                                         map_location='cpu'))
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    elif task == 'normal':
        # NLL-AngMF model is a standard pytorch model class
        import argparse
        from UniDet_eval.experts.normal.models.NNET import NNET
        from UniDet_eval.experts.normal.utils import utils

        # Same values the former ArgumentParser supplied as defaults, without
        # reading sys.argv (architecture: {BN, GN}; pretrained: {nyu, scannet}).
        args = argparse.Namespace(
            mode='client',
            port=2,
            architecture='BN',
            pretrained='scannet',
            sampling_ratio=0.4,
            importance_ratio=0.7,
        )
        model = NNET(args)
        model = utils.load_checkpoint('experts/expert_weights/scannet.pt', model)

        transform = transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    elif task == 'edge':
        # DexiNed model is a standard pytorch model class
        from UniDet_eval.experts.edge.model import DexiNed
        model = DexiNed()
        model.load_state_dict(torch.load('experts/expert_weights/10_model.pth', map_location='cpu'))
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1.0, 1.0, 1.0])
        ])

    else:
        print('Task not supported')
        # Return early: there is no model to switch to eval mode, and calling
        # .eval() on None would raise AttributeError.
        return None, None

    model.eval()
    return model, transform




