# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import math
import os.path

import cv2
import numpy as np
import sys
from typing import Iterable
import matplotlib.pyplot as plt

import torch
import torch.distributed as dist

import util.misc as utils
from util.metric import metric



def _positive_precision_recall(pred_scores, label_scores):
    """Return (precision, recall) of the "score > 0" decision for one batch.

    A prediction/label counts as positive when its value is strictly greater
    than zero.  A ratio whose denominator is zero (no positive predictions,
    or no positive labels) is undefined and is reported as ``np.nan`` instead
    of triggering a NumPy divide warning.
    """
    pred_pos = pred_scores > 0
    label_pos = label_scores > 0
    # (pred == label) & pred  is the same as  pred & label: true positives.
    true_pos = np.logical_and(pred_pos, label_pos).sum()
    n_pred_pos = pred_pos.sum()
    n_label_pos = label_pos.sum()
    precision = true_pos / n_pred_pos if n_pred_pos > 0 else np.nan
    recall = true_pos / n_label_pos if n_label_pos > 0 else np.nan
    return precision, recall


def _mean_ignoring_nan(values):
    """Average ``values`` while skipping NaN entries; NaN if nothing is left."""
    arr = np.asarray(values, dtype=float)
    arr = arr[~np.isnan(arr)]
    return np.mean(arr) if arr.size else np.nan


@torch.no_grad()
def test(model, criterion, postprocessors, data_loader, device, args):
    """Run the model over ``data_loader`` and print losses and metrics.

    Args:
        model: callable invoked as ``model(imgs)``; its result must be a
            mapping containing the key ``'pred_cls'``.
        criterion: callable invoked as ``criterion(outputs, label_cls,
            label_lines, label_junction_map, label_line_map)`` returning a
            dict of losses; must expose ``weight_dict``.
        postprocessors: callable invoked as ``postprocessors(outputs,
            samples)`` returning an ``(output, target)`` pair of lists.
        data_loader: iterable of dict batches providing the keys ``'img'``,
            ``'label_cls'``, ``'label_lines'``, ``'label_junction_map'`` and
            ``'label_line_map'``.
        device: device the batch tensors are moved to.
        args: only ``args.resume`` is read, and only when the local
            ``visualization`` switch is enabled.

    Returns:
        Always ``0``; all results are reported through ``print``.

    Note:
        The value printed as ``cls_acc`` is TP / (predicted positives), i.e.
        the precision of the ``> 0`` decision, averaged over batches.
        Batches for which a ratio is undefined (zero denominator) are left
        out of the corresponding average instead of turning it into NaN.
    """
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    output_list = []
    target_list = []
    cls_acc = []
    cls_recall = []

    # Manual debugging switch: when enabled, only prepares an output folder.
    visualization = False
    if visualization:
        out_dir = os.path.dirname(args.resume).replace("out", "rank_cmp")  # show/show-fm/show-norecall/showPre/rank_cmp
        if not os.path.exists(out_dir + "/"):
            os.makedirs(out_dir + "/")

    for samples in metric_logger.log_every(data_loader, 10, header):
        imgs = samples['img'].to(device)
        label_cls = samples['label_cls'].to(device)
        label_lines = samples['label_lines'].to(device)
        label_junction_map = [m.to(device) for m in samples['label_junction_map']]
        label_line_map = [m.to(device) for m in samples['label_line_map']]

        outputs = model(imgs)

        # Convert to NumPy once per batch instead of once per use.
        pred_cls_np = outputs['pred_cls'].cpu().numpy()
        label_cls_np = label_cls.cpu().numpy()
        batch_precision, batch_recall = _positive_precision_recall(pred_cls_np, label_cls_np)
        cls_acc.append(batch_precision)
        cls_recall.append(batch_recall)

        loss_dict = criterion(outputs, label_cls, label_lines, label_junction_map, label_line_map)
        weight_dict = criterion.weight_dict

        # Reduce losses over all processes for logging purposes.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}

        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)

        output, target = postprocessors(outputs, samples)
        output_list += output
        target_list += target

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    result = metric(output_list, target_list)
    print(result)

    print("cls_acc:", _mean_ignoring_nan(cls_acc))
    print("cls_recall:", _mean_ignoring_nan(cls_recall))

    return 0

