# Copyright (c) 2021 Microsoft Corporation. Licensed under the MIT license.

import cv2
import os.path as op
import argparse
import json
import time

from scene_graph_benchmark.scene_parser import SceneParser
from scene_graph_benchmark.AttrRCNN import AttrRCNN
from maskrcnn_benchmark.data.transforms import build_transforms
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.config import cfg
from scene_graph_benchmark.config import sg_cfg
from maskrcnn_benchmark.data.datasets.utils.load_files import \
    config_dataset_file
from maskrcnn_benchmark.data.datasets.utils.load_files import load_labelmap_file
from maskrcnn_benchmark.utils.miscellaneous import mkdir
import torch

from tools.demo.detect_utils import detect_objects_on_single_image, detect_objects_on_batch_images
from tools.demo.visual_utils import draw_bb, draw_rel
def init_detector_model(args):
    """Build the detector described by ``args`` and load its weights.

    Merges ``sg_cfg`` into the global ``cfg``, then applies ``args.config_file``
    and ``args.opts`` on top, and freezes ``cfg``. Because the global ``cfg`` is
    frozen here, this function presumably cannot be called twice in the same
    process (NOTE(review): confirm against yacs semantics).

    Args:
        args: parsed CLI namespace; must provide ``config_file`` and ``opts``.

    Returns:
        ``(model, transforms)``. ``model`` is in eval mode on
        ``cfg.MODEL.DEVICE`` with ``cfg.MODEL.WEIGHT`` loaded. ``transforms``
        is the inference-time transform pipeline.

    Raises:
        ValueError: if ``cfg.MODEL.META_ARCHITECTURE`` is not a supported
            architecture.
    """
    # sg_cfg introduces keys that the base cfg does not know about, so new
    # keys are allowed only for this merge.
    cfg.set_new_allowed(True)
    cfg.merge_from_other_cfg(sg_cfg)
    cfg.set_new_allowed(False)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    mkdir(output_dir)

    arch = cfg.MODEL.META_ARCHITECTURE
    if arch == "SceneParser":
        model = SceneParser(cfg)
    elif arch == "AttrRCNN":
        model = AttrRCNN(cfg)
    else:
        # Without this branch `model` would be unbound below.
        raise ValueError(
            "Unsupported MODEL.META_ARCHITECTURE: {!r} "
            "(expected 'SceneParser' or 'AttrRCNN')".format(arch)
        )
    model.to(cfg.MODEL.DEVICE)
    model.eval()

    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    checkpointer.load(cfg.MODEL.WEIGHT)
    transforms = build_transforms(cfg, is_train=False)
    return model, transforms

def detect_img(model, transforms, img_file):
    """Run single-image detection on the image stored at ``img_file``.

    Args:
        model: detector, e.g. as returned by ``init_detector_model``.
        transforms: inference transforms, e.g. from ``init_detector_model``.
        img_file: path of the image to read with OpenCV.

    Returns:
        Whatever ``detect_objects_on_single_image`` returns for the image.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_file``.
    """
    cv2_img = cv2.imread(img_file)
    if cv2_img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("Could not read image: {}".format(img_file))
    return detect_objects_on_single_image(model, transforms, cv2_img)

def detect_batch_imgs(model, transforms, img_files):
    """Run batched detection on several image files and return only the boxes.

    Args:
        model: detector, e.g. as returned by ``init_detector_model``.
        transforms: inference transforms, e.g. from ``init_detector_model``.
        img_files: iterable of image paths to read with OpenCV.

    Returns:
        The first of the three values returned by
        ``detect_objects_on_batch_images``. Class and score outputs are
        discarded.

    Raises:
        FileNotFoundError: if OpenCV cannot read any of ``img_files``.
    """
    cv2_imgs = []
    for img_file in img_files:
        cv2_img = cv2.imread(img_file)
        if cv2_img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError("Could not read image: {}".format(img_file))
        cv2_imgs.append(cv2_img)
    boxes, _classes, _scores = detect_objects_on_batch_images(
        model, transforms, cv2_imgs)
    return boxes

def convert_boxes_to_alpha(boxes, mask_transform, h, w, device="cuda"):
    """Rasterize boxes into binary masks, transform them and stack the result.

    Each box is read as ``(x0, y0, x1, y1)``: rows ``y0:y1`` and columns
    ``x0:x1`` of an ``h`` x ``w`` zero mask are set to 1. Coordinates are
    truncated with ``int`` and clamped at 0, so negative values do not wrap
    around via Python's negative indexing. Values past ``h``/``w`` are
    naturally cut off by slicing.

    Args:
        boxes: iterable of boxes indexable at positions 0..3.
        mask_transform: callable applied to each ``(1, h, w)`` float mask.
        h: mask height.
        w: mask width.
        device: device the half-precision masks are moved to. Defaults to
            ``"cuda"``, which is the original behavior.

    Returns:
        A half-precision tensor on ``device``: the transformed masks, each
        given a leading dimension and concatenated along dim 0. There is
        one entry per box.

    Raises:
        RuntimeError: from ``torch.cat`` when ``boxes`` is empty.
    """
    region_alphas = []
    for box in boxes:
        x0, y0, x1, y1 = (max(int(box[i]), 0) for i in range(4))
        region_mask = torch.zeros(h, w)
        region_mask[y0:y1, x0:x1] = 1
        region_alpha = mask_transform(region_mask.unsqueeze(0))
        region_alphas.append(region_alpha.half().to(device).unsqueeze(dim=0))
    return torch.cat(region_alphas)

def main():
    """CLI entry point: detect objects in one image and print the detection time."""
    parser = argparse.ArgumentParser(description="Object Detection Demo")
    parser.add_argument("--config_file", metavar="FILE",
                        help="path to config file")
    # main() reads args.img_file below; without this argument the script
    # crashed with AttributeError.
    parser.add_argument("--img_file", metavar="FILE", required=True,
                        help="image path")
    parser.add_argument("--labelmap_file", metavar="FILE",default='VG-SGG-dicts-vgoi6-clipped.json',
                        help="labelmap file to select classes for visualizatioin")
    parser.add_argument("--save_file", required=False, type=str, default=None,
                        help="filename to save the proceed image")
    parser.add_argument("--visualize_attr", action="store_true",
                        help="visualize the object attributes")
    parser.add_argument("--visualize_relation", action="store_true",
                        help="visualize the relationships")
    parser.add_argument("opts", default=None, nargs=argparse.REMAINDER,
                        help="Modify config options using the command-line")
    # NOTE(review): --labelmap_file, --save_file and --visualize_* are parsed
    # but not used anywhere in this script.

    args = parser.parse_args()

    # Read the image before the (slow) model build so a bad path fails fast.
    # cv2.imread returns None instead of raising on failure.
    cv2_img = cv2.imread(args.img_file)
    if cv2_img is None:
        parser.error("Image: {} does not exist or is not readable".format(args.img_file))

    model, transforms = init_detector_model(args)

    t1 = time.perf_counter()
    dets = detect_objects_on_single_image(model, transforms, cv2_img)
    t2 = time.perf_counter()
    print("detect cost time:", t2 - t1)



if __name__ == "__main__":
    # Run the CLI only when the file is executed as a script; importing the
    # module exposes the helper functions without side effects.
    main()
