
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip

import argparse
import os
import json
import torch
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.engine.inference import inference_AL
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize, get_rank
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
from tqdm import tqdm
import numpy as np


try:
    from apex import amp
except ImportError:
    raise ImportError('Use APEX for mixed precision via apex.amp')


import numpy as np
from maskrcnn_benchmark.config import cfg
import torch
import math


def get_IBM(select_images, del_num, feature_path='./features/image_match.pth'):
    """Drop the ``del_num`` most redundant images from a candidate set.

    Loads per-image feature vectors from ``feature_path``, keeps the rows
    picked by ``select_images`` and computes all pairwise squared Euclidean
    distances between them. Every unordered pair ``(i, j)`` with ``i < j``
    (positions within ``select_images``) is visited from closest to farthest,
    and the later member ``j`` of each pair is treated as the redundant one.
    Candidates are ranked by the first time they occur as a redundant member.
    Candidates that never do are appended at the end, since they are the
    least redundant. The first ``del_num`` entries of that ranking are then
    discarded.

    Args:
        select_images: integer indices (array-like) of the candidate rows in
            the feature file.
        del_num: number of leading (most redundant) candidates to discard;
            0 keeps every candidate.
        feature_path: file read with ``torch.load``. It is expected to hold
            an iterable of equally-shaped tensors, one row per image.

    Returns:
        list[int]: surviving candidates as *positions within*
        ``select_images`` (not the values stored in it), ordered from most
        to least redundant.
    """
    feats = torch.load(feature_path)
    feats = np.stack([item.cpu().detach().numpy() for item in feats])
    feats = feats[select_images]
    num = feats.shape[0]

    # Squared distances via |a|^2 + |b|^2 - 2 a.b
    sq_norms = np.sum(feats ** 2, axis=1)
    dist = sq_norms[:, None] + sq_norms[None, :] - 2 * feats.dot(feats.T)

    # Use the strict upper triangle. This excludes self-distances explicitly
    # and visits each unordered pair exactly once, instead of assuming the
    # diagonal sorts first and symmetric entries sort next to each other.
    # Both assumptions fail under float noise or duplicate/equidistant
    # features. A stable sort makes ties deterministic.
    rows, cols = np.triu_indices(num, k=1)
    order = np.argsort(dist[rows, cols], kind='stable')
    redundant = cols[order]

    # Unique redundant candidates, ordered by their first appearance.
    uniq, first_seen = np.unique(redundant, return_index=True)
    ranking = uniq[np.argsort(first_seen)].tolist()

    # Candidates never marked redundant are kept rather than silently lost.
    seen = set(ranking)
    ranking += [idx for idx in range(num) if idx not in seen]

    return ranking[del_num:]

def get_RPG():
    """Build the binary relation-prior graph between object classes.

    Uses the ground-truth relationships of the currently labeled images
    (``image_index_<k>.npy``, where ``k`` is ``cfg.AL_each_epoch`` when
    ``cfg.AL == 2`` and ``cfg.AL - 1`` otherwise) to fill a 151x151 0/1
    matrix ``A``. ``A[s, o] = 1`` when some labeled relation has subject
    class ``s`` and object class ``o``. The returned graph marks class pairs
    connected in one hop (``A[s, o] = 1``) or in two hops (``A[s, m] = 1`` and
    ``A[m, o] = 1`` for some class ``m``).

    Files read from the working directory: ``image_index_<k>.npy``,
    ``image_index_train_full.npy``, ``gt_classes.npy`` and
    ``./gt_relationships.npy``.

    Returns:
        numpy.ndarray: float array of shape (150, 150) with 0/1 entries.
        Row/column ``c - 1`` corresponds to class ``c``; class 0 is dropped.
    """
    if int(cfg.AL) == 2:
        get_AL = cfg.AL_each_epoch
    else:
        get_AL = int(cfg.AL) - 1
    label_image = np.load('image_index_' + str(get_AL) + '.npy').tolist()
    full_image = np.load('image_index_train_full.npy').tolist()
    gt_classes = np.load('gt_classes.npy', allow_pickle=True)
    rel = np.load('./gt_relationships.npy', allow_pickle=True)

    # Map image id -> first position in the full list. This gives the same
    # result as list.index, but with one pass instead of one scan per id.
    first_pos = {}
    for pos, img in enumerate(full_image):
        first_pos.setdefault(img, pos)
    label_image_index = [first_pos[x] for x in label_image]

    statistic_rel = np.zeros((151, 151))
    for idx in label_image_index:
        class_ = gt_classes[idx]
        # Only the first two entries of each relation row are read: box
        # indices of subject and object, mapped to classes through class_.
        for pair in rel[idx]:
            statistic_rel[class_[pair[0]], class_[pair[1]]] = 1

    # (A @ A)[s, o] counts intermediate classes m with A[s, m] = A[m, o] = 1,
    # so "> 0" is exactly two-hop reachability. The counts are small
    # integers, so the float product is exact.
    two_hop = (statistic_rel.dot(statistic_rel) > 0).astype(statistic_rel.dtype)

    RPGraph = np.maximum(two_hop[1:, 1:], statistic_rel[1:, 1:])
    return RPGraph


def main():
    """Run one active-learning selection round.

    Builds and loads the model from the config, runs ``inference_AL`` over
    the loader(s) from ``make_data_loader(cfg, mode="train")`` and scores
    each image by summing its ``pred_uncertainty`` field. When ``cfg.RPG`` is
    set, each relation's uncertainty is first weighted by the
    ``get_RPG()`` prior for its (subject, object) predicted labels.
    ``get_IBM`` then discards ``remove_image_num`` near-duplicates from the
    highest-scoring not-yet-labeled images.

    Reads ``image_index_<k>.npy`` and ``image_index_train_full.npy`` from the
    working directory. Writes ``image_uncertainty<AL-1>.npy`` and
    ``image_index_<AL>.npy``, the latter holding the previously labeled ids
    plus the newly selected ones, sorted.
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        # Default is the relation config this script is built around.
        default="./Scene-Graph-Benchmark/configs/e2e_relation_X_101_32_8_FPN_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    # Honor --config-file instead of silently using a hard-coded path.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("maskrcnn_benchmark", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    # Initialize mixed-precision if necessary
    use_mixed_precision = cfg.DTYPE == 'float16'
    amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)

    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints",)
    if cfg.MODEL.RELATION_ON:
        iou_types = iou_types + ("relations", )
    if cfg.MODEL.ATTRIBUTE_ON:
        iou_types = iou_types + ("attributes", )
    # No output folders are created, so inference_AL receives None for each.
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST

    # NOTE(review): the zip below assumes make_data_loader returns a list of
    # loaders (one per cfg.DATASETS.TEST entry) for mode="train" -- confirm.
    data_loaders_val = make_data_loader(cfg, mode="train", is_distributed=distributed)

    for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
        pre_uncertainty = inference_AL(
                                cfg,
                                model,
                                data_loader_val,
                                dataset_name=dataset_name,
                                iou_types=iou_types,
                                box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
                                device=cfg.MODEL.DEVICE,
                                expected_results=cfg.TEST.EXPECTED_RESULTS,
                                expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                                output_folder=output_folder,
                            )

        # Per-image score: sum of relation uncertainties, optionally weighted
        # by the relation-prior graph for the predicted (subject, object) labels.
        image_uncertainty = []
        if cfg.RPG:
            RPGraph = get_RPG()
        for i in tqdm(range(len(pre_uncertainty))):
            pre_uncer = pre_uncertainty[i].get_field('pred_uncertainty').detach().cpu().numpy()
            pred_labels = pre_uncertainty[i].get_field('pred_labels').detach().cpu().numpy()
            rel_pair_idxs = pre_uncertainty[i].get_field('rel_pair_idxs').detach().cpu().numpy()
            if cfg.RPG:
                rel_ = np.array([pred_labels[rel_pair_idxs[:, 0].tolist()].tolist(),
                                 pred_labels[rel_pair_idxs[:, 1].tolist()].tolist()])
                # RPGraph drops class 0, so class c lives at row/column c - 1.
                p_RPGraph = RPGraph[(rel_[0] - 1).tolist(), (rel_[1] - 1).tolist()]
                pre_uncer = p_RPGraph * pre_uncer
            image_uncertainty.append(np.sum(pre_uncer))
        np.save('image_uncertainty' + str(int(cfg.AL) - 1) + '.npy', np.array(image_uncertainty))

        if int(cfg.AL) == 2:
            get_AL = cfg.AL_each_epoch
        else:
            get_AL = int(cfg.AL) - 1
        image_index = np.load('image_index_' + str(get_AL) + '.npy').tolist()
        image_index_train_full = np.load('image_index_train_full.npy').tolist()

        # Set/dict lookups keep this linear in the dataset size. The result
        # matches the original `not in list` / `list.index` (first occurrence).
        labeled = set(image_index)
        first_pos = {}
        for pos, img in enumerate(image_index_train_full):
            first_pos.setdefault(img, pos)
        remain_image = [val for val in image_index_train_full if val not in labeled]
        remain_image_index = [first_pos[x] for x in remain_image]
        remain_u = np.array(image_uncertainty)[np.array(remain_image_index, dtype=np.int64)]

        select_image_num = int((cfg.AL_each_epoch) * len(image_index_train_full))
        # Evaluates to 0 for cfg.AL < 18.
        remove_image_num = int(math.pow(int(cfg.AL) * 0.01, 4) * 1000)

        # Over-select by remove_image_num, then let get_IBM drop near-duplicates.
        candidates = np.argsort(-remain_u)[0:(select_image_num + remove_image_num)]
        # NOTE(review): get_IBM indexes ./features/image_match.pth with these
        # positions into remain_image. If that file holds one row per entry of
        # image_index_train_full, it should receive
        # np.array(remain_image_index)[candidates] instead -- confirm.
        kept = get_IBM(candidates, remove_image_num)
        # get_IBM returns positions *within* `candidates`. Map them back to
        # positions in remain_image; using them directly would pick
        # remain_image[0..k-1] and discard the uncertainty ranking.
        select_index = candidates[np.asarray(kept, dtype=np.int64)]

        select_data = np.array(remain_image)[select_index]
        select_data = sorted(image_index + select_data.tolist())
        np.save('image_index_' + str(int(cfg.AL)) + '.npy', select_data)


if __name__ == "__main__":
    # Run the selection round only when executed as a script, not on import.
    main()