import sys
import tempfile
import time
from collections import ChainMap
from tqdm import tqdm

import numpy as np

import torch

from yolox.utils import gather, is_main_process, postprocess, synchronize, time_synchronized

import argparse
import os
import random
import warnings

import torch
import torch.backends.cudnn as cudnn
from torch.nn.parallel import DistributedDataParallel as DDP

from yolox.core import launch
from yolox.exp import get_exp
from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger
from yolox.data.datasets import VOC_CLASSES

from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.utils.class_weight import compute_sample_weight

# Side length of the confusion matrix / IoU-histogram grid built by pred_to_DCM.
# NOTE(review): presumably the 20 VOC object classes plus one background class,
# matching the 'bg' + VOC_CLASSES tick labels used in plot_DCM — confirm.
num_classes=21

def iou(box1, box2):
    """Return the row-wise intersection-over-union of two box arrays.

    Row ``k`` of ``box1`` is compared with row ``k`` of ``box2`` (NumPy
    broadcasting applies). Columns 0..3 of each array are read as
    ``(xmin, ymin, xmax, ymax)``.

    No guard is applied against a zero union: two degenerate (zero-area)
    boxes yield ``nan``.
    """
    # Corners of the overlap rectangle.
    left = np.maximum(box1[:, 0], box2[:, 0])
    top = np.maximum(box1[:, 1], box2[:, 1])
    right = np.minimum(box1[:, 2], box2[:, 2])
    bottom = np.minimum(box1[:, 3], box2[:, 3])

    # Clamp at zero so disjoint boxes contribute no overlap.
    overlap_w = np.maximum(right - left, 0.)
    overlap_h = np.maximum(bottom - top, 0.)
    overlap = overlap_w * overlap_h

    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    union = area2 + area1 - overlap
    return overlap / union

def pred_to_DCM(path, soft=False, w='balanced'):
    """Build a weighted confusion matrix and per-cell IoU histograms.

    Args:
        path: File readable by ``torch.load`` that holds the 4-tuple
            ``(conf_p, conf_t, conf_score, ovs)``: predicted class ids,
            target class ids, per-class scores (indexed as
            ``conf_score[sample, class]``) and one overlap value per sample.
            ``torch.load`` unpickles its input, so only load trusted files.
        soft: If True, each sample weight is additionally multiplied by the
            score of the predicted class.
        w: Forwarded to ``compute_sample_weight`` as ``class_weight``.

    Returns:
        ``(cm, iou_mat)``. ``cm`` is a ``num_classes x num_classes`` confusion
        matrix (rows: target class, columns: predicted class). ``iou_mat`` is
        a nested list where ``iou_mat[t][p]`` is either ``None`` (no usable
        sample) or ``(heights, bin_edges)``: a 10-bin histogram of the overlaps
        over ``[0, 1]`` scaled so that its tallest bin equals 1.
    """
    # Everything below runs on the CPU, so load straight onto it; this also
    # lets a file written on a GPU machine be read without CUDA.
    conf_p, conf_t, conf_score, ovs = torch.load(path, map_location='cpu')
    conf_p = conf_p.cpu().long()
    conf_t = conf_t.cpu().long()
    ovs = ovs.cpu()

    # Group the overlap values by (target class, predicted class).
    iou_mat = [[[] for _ in range(num_classes)] for _ in range(num_classes)]
    for t, p, overlap in zip(conf_t.tolist(), conf_p.tolist(), ovs.tolist()):
        iou_mat[t][p].append(overlap)

    n_bins = 10  # kept separate from the edge array returned by np.histogram
    for i in range(num_classes):
        for u in range(num_classes):
            samples = iou_mat[i][u]
            if not samples:
                iou_mat[i][u] = None
                continue
            hists, edges = np.histogram(
                np.asarray(samples, dtype=float), bins=n_bins, range=(0, 1)
            )
            peak = hists.max()
            if peak == 0:
                # Every sample fell outside [0, 1] (e.g. a non-finite
                # overlap); there is nothing to draw and normalising would
                # divide by zero.
                iou_mat[i][u] = None
                continue
            iou_mat[i][u] = (hists.astype(float) / peak, edges)

    y_true = conf_t.numpy()
    y_pred = conf_p.numpy()
    sw = compute_sample_weight(class_weight=w, y=y_true)
    if soft:
        # Score assigned to the predicted class of every sample.
        conf_p_soft = conf_score.cpu()[(torch.arange(0, len(conf_p)), conf_p)]
        sw = sw * conf_p_soft.numpy()

    # Pin the label set so the matrix is always num_classes x num_classes,
    # even when some class never occurs in the data.
    labels = list(range(num_classes))
    cm = confusion_matrix(y_true, y_pred, labels=labels, sample_weight=sw)
    return cm, iou_mat

def pcm(cm):
    """Row-normalise a confusion matrix so that every row sums to 1.

    Args:
        cm: 2-D confusion matrix (rows are summed along axis 1).

    Returns:
        The normalised matrix. A floating-point ``ndarray`` is normalised in
        place and returned as the same object. Any other input (for example
        an integer matrix, which cannot store fractions) is first copied to a
        new float array, and that copy is returned.

    Rows whose sum is zero are left unchanged instead of being turned into
    ``nan`` by a 0/0 division.
    """
    if not (isinstance(cm, np.ndarray) and np.issubdtype(cm.dtype, np.floating)):
        # Writing ratios back into an integer array would truncate them to 0/1.
        cm = np.array(cm, dtype=float)
    if cm.size == 0:
        return cm

    row_sums = cm.sum(axis=1, keepdims=True)
    # `where` skips empty rows; with out=cm those cells keep their old value.
    np.divide(cm, row_sums, out=cm, where=row_sums != 0)
    return cm

def plot_DCM(confusion, iou=None, title='', title_size=20, cmap=plt.cm.Greens, iou_color='#9d16e5'):
    """Draw a confusion matrix, optionally overlaid with per-cell IoU curves.

    A new 300-dpi figure is created and left as the current pyplot figure;
    nothing is shown, saved or returned.

    Args:
        confusion: Square matrix indexed as ``confusion[row][col]``.
        iou: Optional nested sequence indexed like ``confusion``. Each entry
            is ``None`` or a pair whose first item holds the curve heights
            and whose second item holds the bin edges (the format produced
            by ``pred_to_DCM``).
        title: Figure title.
        title_size: Font size of the title.
        cmap: Colormap passed to ``imshow``.
        iou_color: Line colour of the IoU curves.
    """
    plt.figure(dpi=300)
    plt.imshow(confusion, cmap=cmap)

    # Axis ticks: a background label followed by the VOC class names.
    tick_positions = range(len(confusion))
    tick_labels = ['bg'] + list(VOC_CLASSES)
    plt.xticks(tick_positions, tick_labels, rotation=90, fontsize=5)
    plt.yticks(tick_positions, tick_labels, fontsize=5)

    plt.colorbar()
    plt.title(title, fontdict={'fontsize': title_size})

    # Print each cell's value; x is the column index and y the row index.
    n_cols = len(confusion)
    for col in range(n_cols):
        for row in range(len(confusion[col])):
            label = '{:.2f}'.format(confusion[row][col])
            plt.text(col - 0.4, row, label, fontsize=2)

    if iou is None:
        return

    # Draw each cell's histogram as a small curve inside that cell.
    for col in range(n_cols):
        for row in range(len(confusion[col])):
            cell = iou[row][col]
            if cell is None:
                continue
            # Left bin edges, shifted so [0, 1] spans the cell horizontally.
            curve_x = cell[1][:-1] - 0.5 + col
            # Heights are flipped (1 - h) before being offset into the cell.
            curve_y = (1 - cell[0]) - 0.5 + row
            plt.plot(curve_x, curve_y, color=iou_color, linewidth=0.2)

# Generate the raw data for the confusion matrix.
def evaluate_prediction(evaluator, data_dict, statistics):
    """Regroup detections per class and run the dataset's DCM evaluation.

    Args:
        evaluator: Object providing ``num_images``, ``num_classes`` and
            ``dataloader`` (with ``batch_size`` and ``dataset.evaluate_DCM``).
        data_dict: Mapping from image index to ``(bboxes, cls, scores)``;
            ``bboxes`` is ``None`` for an image without detections.
        statistics: Indexable holding three scalar tensors: accumulated
            forward time, accumulated NMS time and the number of timed
            iterations.

    Returns:
        On the main process, whatever ``dataset.evaluate_DCM`` returns.
        Every other process returns the placeholder ``(0, 0, None)``.
    """
    if not is_main_process():
        return 0, 0, None

    print("Evaluate in main process...")

    inference_time = statistics[0].item()
    nms_time = statistics[1].item()
    n_samples = statistics[2].item()

    # Average per-image times in milliseconds.
    timed_images = n_samples * evaluator.dataloader.batch_size
    a_infer_time = 1000 * inference_time / timed_images
    a_nms_time = 1000 * nms_time / timed_images

    time_info = ", ".join(
        "Average {} time: {:.2f} ms".format(k, v)
        for k, v in zip(
            ["forward", "NMS", "inference"],
            [a_infer_time, a_nms_time, (a_infer_time + a_nms_time)],
        )
    )
    # The summary used to be built and then dropped; report it.
    print(time_info)

    n_cls = evaluator.num_classes
    n_img = evaluator.num_images
    row_width = 4 + n_cls  # 4 box coordinates + one score column per class

    # all_boxes[class][image] -> array with one row per detection.
    all_boxes = [[[] for _ in range(n_img)] for _ in range(n_cls)]
    for img_num in range(n_img):
        bboxes, cls, scores = data_dict[img_num]
        if bboxes is None:
            for j in range(n_cls):
                all_boxes[j][img_num] = np.empty([0, row_width], dtype=np.float32)
            continue

        # Same for every class of this image, so build it once.
        c_dets = torch.cat((bboxes, scores), dim=1)
        for j in range(n_cls):
            mask_c = cls == j
            if not mask_c.any():
                all_boxes[j][img_num] = np.empty([0, row_width], dtype=np.float32)
                continue
            all_boxes[j][img_num] = c_dets[mask_c].numpy()

        sys.stdout.write(
            "im_eval: {:d}/{:d} \r".format(img_num + 1, n_img)
        )
        sys.stdout.flush()

    return evaluator.dataloader.dataset.evaluate_DCM(all_boxes)

def get_DCM(
        evaluator,
        model,
        distributed=False,
        half=False,
        trt_file=None,
        decoder=None,
        test_size=None,
    ):
    """Run ``model`` over ``evaluator.dataloader`` and return the DCM data.

    Each batch is forwarded through the model, post-processed and converted
    with ``evaluator.convert_to_voc_format``; the collected results are then
    passed to ``evaluate_prediction`` together with the timing statistics.

    Args:
        evaluator: Provides ``dataloader``, ``num_classes``, ``confthre``,
            ``nmsthre`` and ``convert_to_voc_format``.
        model: Network to evaluate; put into eval mode here.
        distributed: If True, results are gathered to and timings reduced on
            rank 0 before evaluation.
        half: If True, the model and the input tensors are cast to half
            precision.
        trt_file: Optional path to a ``torch2trt`` state dict. When given, a
            ``TRTModule`` loaded from it replaces ``model``.
        decoder: Optional callable applied to the raw model output.
        test_size: ``(height, width)`` of the dummy input; only read when
            ``trt_file`` is given.

    Returns:
        The return value of ``evaluate_prediction``.

    Note:
        CUDA tensor types are used unconditionally, so a GPU is required.
    """
    tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor
    model = model.eval()
    if half:
        model = model.half()
    ids = []
    data_list = {}
    # Only the main process shows a progress bar; the others iterate silently.
    progress_bar = tqdm if is_main_process() else iter

    inference_time = 0
    nms_time = 0
    # The last batch is excluded from timing (see loop); never drop below 1 so
    # the later per-sample averages cannot divide by zero.
    n_samples = max(len(evaluator.dataloader) - 1, 1)

    if trt_file is not None:
        from torch2trt import TRTModule

        model_trt = TRTModule()
        model_trt.load_state_dict(torch.load(trt_file))

        # One forward pass of the original model on a dummy input before it is
        # replaced by the TensorRT module.
        # NOTE(review): x uses the default dtype even when half=True — confirm
        # that combination is supported.
        x = torch.ones(1, 3, test_size[0], test_size[1]).cuda()
        model(x)
        model = model_trt

    model.eval()
    for cur_iter, (imgs, target, info_imgs, ids) in enumerate(progress_bar(evaluator.dataloader)):
        with torch.no_grad():
            imgs = imgs.type(tensor_type)
            target = target.type(tensor_type)

            # Skip timing for the last iteration: its batch may be smaller
            # than the others, which would skew the per-batch average.
            is_time_record = cur_iter < len(evaluator.dataloader) - 1
            if is_time_record:
                start = time.time()

            outputs = model(imgs)

            if decoder is not None:
                outputs = decoder(outputs, dtype=outputs.type())

            if is_time_record:
                # presumably waits for pending CUDA work before reading the
                # clock — verify against yolox.utils.time_synchronized
                infer_end = time_synchronized()
                inference_time += infer_end - start

            outputs = postprocess(outputs, evaluator.num_classes, evaluator.confthre, evaluator.nmsthre, all_score=True)

            if is_time_record:
                nms_end = time_synchronized()
                nms_time += nms_end - infer_end

        data_list.update(evaluator.convert_to_voc_format(outputs, info_imgs, ids, all_score=True))

    statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
    if distributed:
        # Merge the per-process result dicts into one mapping and reduce the
        # timing statistics onto rank 0.
        data_list = gather(data_list, dst=0)
        data_list = ChainMap(*data_list)
        torch.distributed.reduce(statistics, dst=0)

    return evaluate_prediction(evaluator, data_list, statistics)

def make_parser():
    """Create the argument parser for the DCM evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` with the model/experiment options, the
        distributed-run options, the boolean evaluation switches, the
        trailing ``opts`` remainder and the ``--DCM_dir`` output directory.
    """
    parser = argparse.ArgumentParser("YOLOX Eval")
    add = parser.add_argument

    add("-expn", "--experiment-name", default=None, type=str)
    add("-n", "--name", default=None, type=str, help="model name")

    # Options for distributed runs.
    add("--dist-backend", type=str, default="nccl", help="distributed backend")
    add("--dist-url", type=str, default=None, help="url used to set up distributed training")
    add("-b", "--batch-size", default=64, type=int, help="batch size")
    add("-d", "--devices", type=int, default=None, help="device for training")
    add("--num_machines", type=int, default=1, help="num of node for training")
    add("--machine_rank", type=int, default=0, help="node rank for multi-node training")

    # Experiment, checkpoint and test-time overrides.
    add("-f", "--exp_file", type=str, default=None, help="pls input your expriment description file")
    add("-c", "--ckpt", type=str, default=None, help="ckpt for eval")
    add("--conf", type=float, default=None, help="test conf")
    add("--nms", type=float, default=None, help="test nms threshold")
    add("--tsize", type=int, default=None, help="test img size")
    add("--seed", type=int, default=None, help="eval seed")

    # Boolean switches; all are off unless given on the command line.
    switches = (
        ("fp16", "Adopting mix precision evaluating."),
        ("fuse", "Fuse conv and bn for testing."),
        ("trt", "Using TensorRT model for testing."),
        ("legacy", "To be compatible with older versions"),
        ("test", "Evaluating on test-dev set."),
        ("speed", "speed test only."),
    )
    for flag, description in switches:
        add("--" + flag, dest=flag, action="store_true", default=False, help=description)

    # Everything after the recognised options is collected for exp.merge.
    add(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )

    # DCM output location.
    add("--DCM_dir", type=str, default='DCM/', help="DCM save directory")

    return parser

def main(exp, args, num_gpu):
    """Evaluate one process's share of the data and save the DCM raw data.

    Called through ``launch``: once directly for a single-GPU run, or once
    per spawned worker when ``num_gpu > 1``.

    Args:
        exp: Experiment object returned by ``get_exp`` (already merged with
            the command-line ``opts``).
        args: Parsed command-line namespace from ``make_parser``.
        num_gpu: Number of GPUs used on this machine.

    Raises:
        ValueError: If a checkpoint has to be loaded but ``--ckpt`` was not
            given.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed testing. This will turn on the CUDNN deterministic setting, "
        )

    is_distributed = num_gpu > 1

    # set environment variables for distributed training
    configure_nccl()
    cudnn.benchmark = True

    rank = get_local_rank()

    # Command-line overrides of the experiment's test settings.
    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()

    evaluator = exp.get_evaluator(args.batch_size, is_distributed, args.test, args.legacy)

    torch.cuda.set_device(rank)
    model.cuda(rank)
    model.eval()

    if not args.speed and not args.trt:
        ckpt_file = args.ckpt
        if ckpt_file is None:
            # Fail with a clear message instead of an obscure torch.load error.
            raise ValueError("--ckpt is required unless --speed or --trt is given")
        print("loading checkpoint from {}".format(ckpt_file))
        loc = "cuda:{}".format(rank)
        ckpt = torch.load(ckpt_file, map_location=loc)
        model.load_state_dict(ckpt["model"])
        print("loaded checkpoint done.")

    if is_distributed:
        model = DDP(model, device_ids=[rank])

    if args.fuse:
        print("Fusing model...")
        model = fuse_model(model)

    # TensorRT inference and a custom decoder are not wired up in this script.
    trt_file = None
    decoder = None

    # start evaluate
    conf_data = get_DCM(evaluator, model, is_distributed, args.fp16, trt_file, decoder, exp.test_size)

    # Non-main processes only receive a placeholder result; writing it would
    # overwrite the real data saved by the main process.
    if is_main_process():
        torch.save(conf_data, os.path.join(args.DCM_dir, 'DCM_YOLOX.pth'))
        print('ok')


if __name__ == '__main__':
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    os.makedirs(args.DCM_dir, exist_ok=True)

    num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
    assert num_gpu <= torch.cuda.device_count()

    # The distributed evaluator and DDP need an initialised process group;
    # launch() sets it up and starts one worker per GPU when num_gpu > 1.
    dist_url = "auto" if args.dist_url is None else args.dist_url
    launch(
        main,
        num_gpu,
        args.num_machines,
        args.machine_rank,
        backend=args.dist_backend,
        dist_url=dist_url,
        args=(exp, args, num_gpu),
    )