"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

from __future__ import print_function

import json
import pathlib
import time

import numpy as np
import os
import torch
from torch import distributed as dist
from torch.nn import functional as F

from examples.torch.common.example_logger import logger
from nncf.torch.utils import is_main_process, is_dist_avail_and_initialized, get_world_size


def str2bool(v):
    """Interpret a string as a boolean flag.

    The comparison is case-insensitive. Only "yes", "true", "t" and "1" map
    to ``True``; every other string maps to ``False``.

    :param v: string to interpret (must provide ``lower()``).
    :return: ``True`` if ``v`` is one of the accepted truthy spellings.
    """
    truthy_spellings = ("yes", "true", "t", "1")
    normalized = v.lower()
    return normalized in truthy_spellings


class Timer:
    """Wall-clock stopwatch that also tracks a running average.

    Call ``tic()`` to start a measurement and ``toc()`` to finish it.
    """

    def __init__(self):
        self.total_time = 0.    # sum of all measured intervals, in seconds
        self.calls = 0          # number of completed ``toc()`` calls
        self.start_time = 0.    # timestamp recorded by the last ``tic()``
        self.diff = 0.          # duration of the most recent interval
        self.average_time = 0.  # ``total_time / calls``

    def tic(self):
        """Record the start of a measured interval."""
        # time.time is used rather than time.clock because time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        """Finish the interval started by the last ``tic()``.

        :param average: if true, return the average over all intervals so
            far; otherwise return the duration of this interval only.
        :return: elapsed time in seconds (average or latest, see above).
        """
        elapsed = time.time() - self.start_time
        self.diff = elapsed
        self.total_time += elapsed
        self.calls += 1
        self.average_time = self.total_time / self.calls
        return self.average_time if average else self.diff

class AverageMeter:
    """Tracks the latest value and the weighted running average of a metric."""

    def __init__(self):
        # Declare the attributes here; ``reset`` gives them their initial values.
        self.val = self.avg = self.sum = self.count = None
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Add a new observation.

        :param val: the observed value (stored as the current value).
        :param n: weight of the observation, e.g. the number of samples
            over which ``val`` was averaged.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def evaluate_detections(box_list, dataset, use_07=True):
    """Compute the mean average precision (mAP) of detections over all classes.

    :param box_list: 2-D array of detections. Column 1 holds the class label,
        which is compared against ``class index + 1`` for each entry of
        ``dataset.classes``; the remaining columns are consumed by ``voc_eval``.
    :param dataset: dataset object providing ``classes`` plus whatever
        ``voc_eval`` requires.
    :param use_07: if true, use the VOC07 11-point AP metric.
    :return: mean of the per-class AP values.
    """
    annotations_cache_dir = os.path.join('cache', 'annotations_cache')
    # The PASCAL VOC metric changed in 2010
    logger.info('VOC07 metric? {}'.format('Yes' if use_07 else 'No'))

    per_class_ap = []
    for class_index, class_name in enumerate(dataset.classes):
        # Labels stored in the detections are offset by one w.r.t. the class index.
        detections_of_class = box_list[box_list[:, 1] == class_index + 1]
        # voc_eval returns (ap, precision, recall); only the AP is needed here.
        ap, _, _ = voc_eval(detections_of_class,
                            dataset,
                            class_name,
                            annotations_cache_dir,
                            ovthresh=0.5,
                            use_07_metric=use_07)
        per_class_ap.append(ap)
        logger.info('AP for {} = {:.4f}'.format(class_name, ap))

    mAp = np.mean(per_class_ap)
    logger.info('Mean AP = {:.4f}'.format(mAp))
    return mAp


def voc_ap(rec, prec, use_07_metric=True):
    """Compute the VOC average precision from recall and precision arrays.

    :param rec: numpy array of recall values.
    :param prec: numpy array of precision values, aligned with ``rec``.
    :param use_07_metric: if true (the default), use the VOC07 11-point
        method; otherwise compute the area under the precision envelope.
    :return: the average precision.
    """
    if use_07_metric:
        # 11 point metric: average the best precision reachable at each of
        # the recall thresholds 0.0, 0.1, ..., 1.0
        ap = 0.
        for threshold in np.arange(0., 1.1, 0.1):
            reachable = rec >= threshold
            best_precision = 0 if np.sum(reachable) == 0 else np.max(prec[reachable])
            ap = ap + best_precision / 11.
        return ap

    # Exact area under the curve: pad both curves with sentinel values first
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # Precision envelope: every entry becomes the maximum of itself and all
    # entries to its right (running maximum taken from the end)
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    # Only the positions where recall changes contribute to the area
    steps = np.where(mrec[1:] != mrec[:-1])[0]

    # Sum (\Delta recall) * precision over those positions
    return np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1])


def voc_eval(class_detections,
             dataset,
             classname,
             cachedir,
             ovthresh=0.5,
             use_07_metric=True):
    """Evaluate the detections of a single class against the ground truth.

    :param class_detections: 2-D array of detections for ``classname``.
        Column 0 is the image index (into ``dataset.get_img_names()``),
        column 2 the confidence and columns 3 onwards the bounding box.
    :param dataset: dataset object providing ``get_img_names()`` and the
        attributes required by the annotation helpers.
    :param classname: name of the class being evaluated.
    :param cachedir: directory holding the cached annotations file.
    :param ovthresh: IoU a detection must exceed to count as a match.
    :param use_07_metric: if true, use the VOC07 11-point AP metric.
    :return: the result of ``compute_detection_metrics``, i.e. the tuple
        ``(ap, precision, recall)``.
    """
    # Ground truth is loaded through the annotation cache kept in ``cachedir``
    gt, imagenames = load_detection_annotations(cachedir, dataset)
    image_bboxes, npos = extract_gt_bboxes(classname, dataset, gt, imagenames)

    image_names = dataset.get_img_names()

    # Visit the detections from the most to the least confident one
    order = np.argsort(class_detections[:, 2])[::-1]
    ranked_boxes = class_detections[:, 3:][order, :]
    ranked_image_ids = [image_names[int(class_detections[k, 0])] for k in order]

    # Mark every detection as a true positive or a false positive
    num_dets = len(ranked_image_ids)
    tp = np.zeros(num_dets)
    fp = np.zeros(num_dets)
    for d, image_id in enumerate(ranked_image_ids):
        record = image_bboxes[image_id]
        det_box = ranked_boxes[d, :].astype(float)
        gt_boxes = record['bbox'].astype(float)

        if gt_boxes.size > 0:
            # best-overlapping ground truth box for this detection
            matched_ind, matched_iou = match_bbox(gt_boxes, det_box)
        else:
            matched_ind, matched_iou = None, -np.inf

        if not matched_iou > ovthresh:
            # no ground truth box overlaps enough
            fp[d] = 1.
            continue

        if record['difficult'][matched_ind]:
            # matches to difficult objects count neither as TP nor as FP
            continue

        if record['det'][matched_ind]:
            # this ground truth box was already claimed by an earlier detection
            fp[d] = 1.
        else:
            tp[d] = 1.
            record['det'][matched_ind] = 1

    return compute_detection_metrics(fp, tp, npos, use_07_metric)


def extract_gt_bboxes(classname, dataset, gt, imagenames):
    """Collect the ground-truth boxes of one class for every image.

    :param classname: name of the class to extract.
    :param dataset: object whose ``classes`` sequence maps a record's
        ``'label_idx'`` to a class name.
    :param gt: mapping from image name to a list of annotation records;
        each record is a dict with ``'label_idx'``, ``'bbox'`` and an
        optional ``'difficult'`` flag (treated as false when absent).
    :param imagenames: iterable of image names to process.
    :return: tuple ``(class_gt, npos)``. ``class_gt`` maps each image name
        to a dict with ``'bbox'`` (array of boxes), ``'difficult'`` (boolean
        array) and ``'det'`` (list of "already detected" flags, all false).
        ``npos`` is the total number of non-difficult objects of the class.
    """
    class_gt = {}
    npos = 0
    for imagename in imagenames:
        objects = [rec for rec in gt[imagename] if dataset.classes[rec['label_idx']] == classname]
        bbox = np.asarray([obj['bbox'] for obj in objects])
        # The builtin ``bool`` is used as dtype: the ``np.bool`` alias was
        # removed from NumPy and raises AttributeError on current releases.
        difficult = np.array([obj.get('difficult', False) for obj in objects]).astype(bool)
        npos += int(np.count_nonzero(~difficult))
        class_gt[imagename] = {
            'bbox': bbox,
            'difficult': difficult,
            'det': [False] * len(objects)
        }
    return class_gt, npos


def load_detection_annotations(cachedir, dataset):
    """Load the dataset annotations through a JSON cache file.

    The cache file is ``annots_<dataset.name>.json`` inside ``cachedir``.
    If it does not exist yet, the main process builds it from
    ``dataset.pull_anno``. When a distributed process group is initialized,
    all processes synchronize on a barrier before reading the file. The
    annotations are always read back from the JSON file, including on the
    process that wrote it.

    :param cachedir: directory of the cache file (created if missing).
    :param dataset: dataset object providing ``name``, ``get_img_names()``
        and ``pull_anno(index)``.
    :return: tuple ``(gt, imagenames)`` with the annotations keyed by image
        name and the list of image names.
    """
    cachefile = os.path.join(cachedir, 'annots_{}.json'.format(dataset.name))
    imagenames = dataset.get_img_names()

    # Only the main process is allowed to create the cache
    if is_main_process() and not os.path.isfile(cachefile):
        annotations = {}
        for index, imagename in enumerate(imagenames):
            # pull_anno returns a pair; the second item is the annotation
            _, annotation = dataset.pull_anno(index)
            annotations[imagename] = annotation

            if index % 100 == 0:
                logger.info('Reading annotation for {:d}/{:d}'.format(
                    index + 1, len(imagenames)))

        logger.info('Saving cached annotations to {:s}'.format(cachefile))
        pathlib.Path(cachedir).mkdir(parents=True, exist_ok=True)
        with open(cachefile, 'w', encoding='utf8') as cache_out:
            json.dump(annotations, cache_out)

    # Make sure the cache has been written before any process reads it
    if is_dist_avail_and_initialized():
        dist.barrier()

    with open(cachefile, 'r', encoding='utf8') as cache_in:
        gt = json.load(cache_in)
    return gt, imagenames


def compute_detection_metrics(fp, tp, n_positives, use_07_metric=True):
    """Turn per-detection TP/FP marks into precision, recall and AP.

    :param fp: array with 1 for every false-positive detection, ordered by
        decreasing confidence.
    :param tp: array with 1 for every true-positive detection, same order.
    :param n_positives: number of ground-truth objects used as the recall
        denominator.
    :param use_07_metric: if true, use the VOC07 11-point AP metric.
    :return: tuple ``(ap, precision, recall)``.
    """
    cumulative_fp = np.cumsum(fp)
    cumulative_tp = np.cumsum(tp)

    recall = cumulative_tp / float(n_positives)
    # Clamp the denominator to avoid dividing by zero in case the first
    # detection matches a difficult ground truth
    smallest_denominator = np.finfo(np.float64).eps
    precision = cumulative_tp / np.maximum(cumulative_tp + cumulative_fp, smallest_denominator)

    return voc_ap(recall, precision, use_07_metric), precision, recall


def match_bbox(gt_boxes, bbox):
    """Find the ground-truth box that overlaps a detection the most.

    :param gt_boxes: 2-D array whose first four columns are
        ``(x0, y0, x1, y1)`` of each ground-truth box.
    :param bbox: sequence ``(x0, y0, x1, y1)`` describing the detection.
    :return: tuple ``(matched_ind, matched_iou)``: the row index of the
        best-overlapping ground-truth box and its intersection-over-union.
    """
    gt_x0, gt_y0, gt_x1, gt_y1 = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2], gt_boxes[:, 3]

    # Intersection rectangle, clamped to zero size when the boxes are disjoint
    inter_w = np.maximum(np.minimum(gt_x1, bbox[2]) - np.maximum(gt_x0, bbox[0]), 0.)
    inter_h = np.maximum(np.minimum(gt_y1, bbox[3]) - np.maximum(gt_y0, bbox[1]), 0.)
    inters = inter_w * inter_h

    # Union = detection area + ground-truth area - intersection
    det_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
    gt_areas = (gt_x1 - gt_x0) * (gt_y1 - gt_y0)
    union = det_area + gt_areas - inters

    overlaps = inters / union
    return np.argmax(overlaps), np.max(overlaps)


def gather_detections(all_detections, samples_per_rank):
    """Collect the detections of all ranks into a single tensor.

    The local detections are zero-padded to ``samples_per_rank`` entries so
    every rank contributes a tensor of the same shape, then exchanged with
    ``dist.all_gather`` on CUDA and concatenated rank by rank.

    :param all_detections: detections produced on this rank. The padding
        spec targets the third-from-last dimension, so this presumably is a
        3-D tensor with samples along dim 0 - TODO confirm against callers.
    :param samples_per_rank: number of samples every rank is padded to.
    :return: CUDA tensor with the concatenated detections of all ranks.
    """
    per_sample_shape = all_detections.shape[1:]
    gathered = [torch.zeros(samples_per_rank, *per_sample_shape).cuda()
                for _ in range(get_world_size())]

    # F.pad lists (before, after) pairs starting from the last dimension
    missing_samples = samples_per_rank - all_detections.size(0)
    padded = F.pad(all_detections, [0, 0, 0, 0, 0, missing_samples])

    dist.all_gather(gathered, padded.cuda())
    return torch.cat(gathered)


def convert_detections(all_detection):
    """Flatten per-image detections into a single numpy array.

    :param all_detection: iterable of per-image tensors of shape
        ``(num_dets, 7)``. Column 2 is the confidence; columns 1 to 6 are
        copied to the output and column 0 is replaced by the image index.
    :return: ``(num_dets x 7)`` array with one row per detection of positive
        confidence: ``(img_ind, label, conf, x0, y0, x1, y1)``. An array of
        shape ``(0, 7)`` is returned when no detection is kept.
    """
    all_boxes = []
    for img_ind, dets in enumerate(all_detection):
        # keep only predictions with positive confidence
        dets = dets[dets[:, 2] > 0.].cpu()

        if dets.size(0) == 0:
            continue

        boxes = np.c_[np.full(dets.size(0), img_ind), dets[:, 1:].numpy()]
        all_boxes.append(boxes)

    if not all_boxes:
        # np.vstack rejects an empty list, so build the empty result directly
        return np.zeros((0, 7))
    return np.vstack(all_boxes)


def predict_detections(data_loader, device, net):
    """Run the network over a data loader and collect all detections.

    :param data_loader: iterable yielding ``(images, targets, heights,
        widths)`` batches; the targets are ignored here.
    :param device: device the images are moved to before inference.
    :param net: callable producing the detections for a batch of images.
        Its output is reshaped to ``(-1, top_k, 7)`` with ``top_k`` taken
        from dimension 2 of the output.
    :return: CPU tensor with the detections of all batches concatenated,
        or ``None`` if the loader produced no batches.
    """
    total_batches = len(data_loader)
    stopwatch = Timer()
    collected = []
    for batch_number, (ims, _gts, hs, ws) in enumerate(data_loader, start=1):
        inputs = ims.to(device)
        heights = inputs.new_tensor(hs).view(-1, 1)
        widths = inputs.new_tensor(ws).view(-1, 1)

        stopwatch.tic()
        detections = net(inputs)
        detections = detections.view(-1, detections.size(2), 7)
        detect_time = stopwatch.toc(average=False)

        # Columns 3/5 hold x coordinates and 4/6 hold y coordinates; scale
        # them by the image width and height respectively.
        for x_column in (3, 5):
            detections[..., x_column] *= widths
        for y_column in (4, 6):
            detections[..., y_column] *= heights

        collected.append(detections.cpu())
        logger.info('Detect for batch: {:d}/{:d} {:.3f}s'.format(batch_number, total_batches, detect_time))

    # None signals that there were no predictions at all
    return torch.cat(collected) if collected else None

def eval_net_loss(data_loader, device, net, criterion):
    """Evaluate the average loss of the network over a data loader.

    :param data_loader: iterable yielding ``(images, targets, _, _)``
        batches, where ``targets`` is a sequence of annotation tensors.
    :param device: device the images and targets are moved to.
    :param net: callable producing the network output for a batch of images.
    :param criterion: callable ``criterion(output, targets)`` returning the
        pair ``(loss_l, loss_c)`` (reported as "Loc Loss" and "Conf Loss").
    :return: sum of the sample-weighted averages of ``loss_l`` and ``loss_c``.
    """
    batch_loss_l = AverageMeter()
    batch_loss_c = AverageMeter()
    batch_loss = AverageMeter()
    t_elapsed = AverageMeter()

    num_batches = len(data_loader)

    # Aim for about 10 lines of reporting, but report at least every batch
    # when there are fewer than 10 batches
    print_freq = max(num_batches // 10, 1)

    timer = Timer()
    for batch_ind, (ims, _gts, _, _) in enumerate(data_loader):
        images = ims.to(device)
        targets = [anno.requires_grad_(False).to(device) for anno in _gts]

        # Start timing before the forward pass so that the reported time
        # covers the forward pass, the loss computation and the loss read-back
        timer.tic()

        # forward
        out = net(images)
        loss_l, loss_c = criterion(out, targets)
        loss = loss_l + loss_c

        batch_size = images.size(0)
        batch_loss_l.update(loss_l.item(), batch_size)
        batch_loss_c.update(loss_c.item(), batch_size)
        batch_loss.update(loss.item(), batch_size)

        t_elapsed.update(timer.toc(average=False))

        if batch_ind % print_freq == 0:
            logger.info('Loss_inference: [{}/{}] || Time: {elapsed.val:.4f}s ({elapsed.avg:.4f}s)'
                        ' || Conf Loss: {conf_loss.val:.3f} ({conf_loss.avg:.3f})'
                        ' || Loc Loss: {loc_loss.val:.3f} ({loc_loss.avg:.3f})'
                        ' || Model Loss: {model_loss.val:.3f} ({model_loss.avg:.3f})'.format(
                            batch_ind, num_batches, elapsed=t_elapsed, conf_loss=batch_loss_c,
                            loc_loss=batch_loss_l, model_loss=batch_loss))

    model_loss = batch_loss_l.avg + batch_loss_c.avg
    return model_loss

def test_net(net, device, data_loader, distributed=False, loss_inference=False, criterion=None):
    """Evaluate a detection network on the dataset behind ``data_loader``.

    BatchNorm2d modules are switched to evaluation mode for the duration of
    the call; their previous training flags are restored afterwards, also
    when the evaluation raises.

    :param net: the network (``torch.nn.Module``) to evaluate.
    :param device: device the input batches are moved to.
    :param data_loader: loader over the evaluation dataset. In distributed
        mode its sampler must provide ``samples_per_rank``.
    :param distributed: if true, gather the detections of all ranks before
        computing the metric.
    :param loss_inference: if ``True``, evaluate the loss with ``criterion``
        instead of the detection mAP.
    :param criterion: loss callable; required when ``loss_inference`` is set.
    :return: the model loss when ``loss_inference`` is ``True``, otherwise
        the mean average precision of the detections.
    :raises NotImplementedError: if loss inference is requested in
        distributed mode.
    :raises ValueError: if loss inference is requested without ``criterion``.
    """
    # Training flag of every BatchNorm2d module before it was put in eval mode
    bn_modules_training_flags = {}

    def put_bn_modules_in_eval_mode(module):
        if isinstance(module, torch.nn.BatchNorm2d):
            bn_modules_training_flags[module] = module.training
            module.eval()

    def restore_bn_module_mode(module):
        if isinstance(module, torch.nn.BatchNorm2d):
            module.training = bn_modules_training_flags[module]

    net.apply(put_bn_modules_in_eval_mode)
    try:
        if loss_inference is True:
            logger.info("Testing... loss function will be evaluated instead of detection mAP")
            if distributed:
                raise NotImplementedError
            if criterion is None:
                raise ValueError("Missing loss inference function (criterion)")
            return eval_net_loss(data_loader, device, net, criterion)

        logger.info("Testing...")
        num_images = len(data_loader.dataset)
        # NOTE(review): predict_detections returns None for an empty loader,
        # which the code below does not handle.
        batch_detections = predict_detections(data_loader, device, net)
        if distributed:
            batch_detections = gather_detections(batch_detections, data_loader.sampler.samples_per_rank)
        # Keep at most one entry per dataset image
        batch_detections = batch_detections[:num_images]
        all_boxes = convert_detections(batch_detections)

        logger.info('Evaluating detections')
        return evaluate_detections(all_boxes, data_loader.dataset)
    finally:
        # Restore the BN modes on every exit path, including exceptions
        net.apply(restore_bn_module_mode)
