# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import math
import os
import sys
from typing import Iterable

import torch

import util.misc as utils
from datasets.coco_eval import CocoEvaluator
from datasets.panoptic_eval import PanopticEvaluator
import torch
from torch.nn import functional as F
from torch.nn.modules import loss
from pathlib import Path
import json
import shutil
import pdb

from torch.utils.tensorboard import SummaryWriter

import numpy as np
from hook_functions import *
from KL_interlayer import save_npy, KL, KL_savefile
import xlwt


# Holds the most recent capture made by ``hook``; rebound on every call.
feats = {}


def hook(module, input, output):
    """Record the latest ``input``/``output`` pair in the module-level ``feats``.

    The signature matches a ``(module, input, output)`` hook callback.
    ``module`` is not used. The objects are stored as given (no detach, no
    copy) under the keys ``"input"`` and ``"output"``, replacing whatever was
    captured by the previous call. Returns ``None``.
    """
    global feats
    feats = dict(input=input, output=output)


def get_params_grad(model):
    """Collect the trainable parameters of ``model`` and their gradients.

    Parameters whose ``requires_grad`` is False are skipped.

    Returns:
        ``(params, grads)``: two lists in ``model.parameters()`` order.
        ``grads[i]`` is ``0.`` when ``params[i].grad`` is ``None``; otherwise
        it is ``params[i].grad + 0.``, i.e. a new tensor rather than the
        gradient object itself.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    snapshots = [0. if p.grad is None else p.grad + 0. for p in trainable]
    return trainable, snapshots



def read_json(json_path):
    """Load and return the parsed content of the JSON file at ``json_path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default text encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


def train_one_epoch(args, writer, model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0):
    """Train ``model`` for one pass over ``data_loader``.

    For every batch the weighted criterion loss is back-propagated, gradients
    are optionally clipped to ``max_norm`` (when ``max_norm > 0``) and the
    optimizer is stepped. When ``args.output_dir`` is set and this is the main
    process, the running meter averages are appended to ``log.txt`` in that
    directory and sent to ``writer`` after every step; the step counter
    restarts at 1 on each call.

    The process exits with status 1 if the reduced loss is not finite.

    Returns:
        dict mapping each meter name to its ``global_avg`` after
        synchronising the meters between processes.
    """
    model.train()
    criterion.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    for meter_name, meter_fmt in (('lr', '{value:.6f}'), ('class_error', '{value:.2f}')):
        metric_logger.add_meter(meter_name, utils.SmoothedValue(window_size=1, fmt=meter_fmt))

    batches = metric_logger.log_every(data_loader, 10, 'Epoch: [{}]'.format(epoch))
    for step, (samples, targets) in enumerate(batches, start=1):
        samples = samples.to(device)
        targets = [{name: value.to(device) for name, value in target.items()}
                   for target in targets]

        outputs = model(samples)

        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        # Only the terms that carry a weight contribute to the optimised loss.
        losses = sum(value * weight_dict[name]
                     for name, value in loss_dict.items() if name in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        unscaled = {f'{name}_unscaled': value for name, value in loss_dict_reduced.items()}
        scaled = {name: value * weight_dict[name]
                  for name, value in loss_dict_reduced.items() if name in weight_dict}
        loss_value = sum(scaled.values()).item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value, **scaled, **unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # Per-step snapshot of the running averages, keyed 'step_train_<meter>'.
        log_stats = {f'step_train_{name}': meter.global_avg
                     for name, meter in metric_logger.meters.items()}
        log_stats['step'] = step

        if not (args.output_dir and utils.is_main_process()):
            continue

        with (Path(args.output_dir) / "log.txt").open("a") as f:
            f.write(json.dumps(log_stats) + "\n")
        for key, value in log_stats.items():
            if key in ('test_coco_eval_bbox', 'test_coco_eval_masks'):
                writer.add_scalar('mAP_all', value[0], step)
            elif isinstance(value, dict):
                print(f'{key}:{value}')
            else:
                writer.add_scalar(key, value, step)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}


class DistillationLoss(loss._Loss):
    """Soft-target cross-entropy between a student output and a reference output.

    ``forward`` takes ``(model_output, real_output)``, both NxC tensors (the
    batched matrix product below requires exactly two dimensions). Both are
    normalised inside ``forward`` along ``dim=1``: ``model_output`` with
    ``log_softmax`` and ``real_output`` with ``softmax``.
    """

    def forward(self, model_output, real_output):
        """Return ``-sum_c softmax(real)[c] * log_softmax(model)[c]`` averaged over rows.

        Raises:
            ValueError: if ``real_output`` requires gradients.
        """
        # Averaging is forced on every call, so the summing branch at the end
        # is never taken regardless of how the module was constructed.
        self.size_average = True

        if real_output.requires_grad:
            raise ValueError("real network output should not require gradients.")

        # Shape the operands as (N, 1, C) and (N, C, 1) so that bmm yields one
        # dot product per row.
        reference_prob = F.softmax(real_output, dim=1).unsqueeze(1)
        student_log_prob = F.log_softmax(model_output, dim=1).unsqueeze(2)

        per_row = -torch.bmm(reference_prob, student_log_prob)
        return per_row.mean() if self.size_average else per_row.sum()


class DistillKL(torch.nn.Module):
    """Distilling the Knowledge in a Neural Network.

    Temperature-scaled KL divergence between student logits ``y_s`` and
    teacher logits ``y_t``.

    Args:
        T: softmax temperature applied to both inputs.
    """

    def __init__(self, T=6):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        """Return ``KL(softmax(y_t/T) || softmax(y_s/T)) * T**2 / y_s.shape[0]``.

        The divergence is summed over every element and then divided by the
        size of the first dimension of ``y_s``.
        """
        # NOTE(review): both softmaxes run over dim=1. For 3-D inputs such as
        # (batch, queries, classes) that is not the last axis - confirm this
        # is the intended normalisation axis.
        log_p_s = F.log_softmax(y_s / self.T, dim=1)
        p_t = F.softmax(y_t / self.T, dim=1)
        # reduction='sum' is the supported spelling of the deprecated
        # size_average=False and gives the same value without the warning.
        return F.kl_div(log_p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]


def loss_kd(output, target, teacher_output, args):
    """Blend a soft-target KL term with the hard-label cross-entropy.

    Args:
        output: student logits.
        target: ground-truth class indices passed to ``F.cross_entropy``.
        teacher_output: teacher logits.
        args: object providing ``distill_alpha`` (weight of the soft term)
            and ``temperature`` (softmax temperature).

    Returns:
        ``kl_div(log_softmax(output/T), softmax(teacher_output/T)) * alpha*T*T
        + cross_entropy(output, target) * (1 - alpha)``, with both softmaxes
        taken over ``dim=1``.
    """
    alpha = args.distill_alpha
    T = args.temperature

    student_log_prob = F.log_softmax(output / T, dim=1)
    teacher_prob = F.softmax(teacher_output / T, dim=1)

    # NOTE(review): kl_div runs with its default reduction here; confirm that
    # this (rather than 'batchmean') is the scaling alpha was tuned for.
    soft_term = F.kl_div(student_log_prob, teacher_prob) * (alpha * T * T)
    hard_term = F.cross_entropy(output, target) * (1. - alpha)
    return soft_term + hard_term


class L1Loss(torch.nn.Module):
    """Element-wise absolute error that is never reduced.

    The constructor accepts ``size_average``, ``reduce`` and ``reduction``
    but ignores all three; ``forward`` always uses ``reduction='none'``.
    """

    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super().__init__()

    def forward(self, input, target):
        """Return ``|input - target|`` with the same shape as the inputs."""
        unreduced = F.l1_loss(input, target, reduction='none')
        return unreduced


class HintLoss(torch.nn.Module):
    """Fitnets: hints for thin deep nets, ICLR 2015.

    Mean-squared error between student and teacher features.
    """

    def __init__(self):
        super().__init__()
        self.crit = torch.nn.MSELoss()

    def forward(self, f_s, f_t):
        """Return the MSE between ``f_s`` and ``f_t``.

        When both arguments are dicts, a dict is returned with one MSE value
        per key of ``f_t`` (each key must also be present in ``f_s``).
        Otherwise the two arguments are passed to the MSE criterion directly.
        """
        both_dicts = isinstance(f_s, dict) and isinstance(f_t, dict)
        if not both_dicts:
            return self.crit(f_s, f_t)
        return {key: self.crit(f_s[key], f_t[key]) for key in f_t}


class hubber_loss(torch.nn.Module):
    """Huber loss whose threshold adapts to the mean absolute error of the batch.

    Args:
        delta: multiplier applied to the batch mean absolute error to obtain
            the quadratic/linear switch-over threshold.
    """

    def __init__(self, delta=1.35):
        super(hubber_loss, self).__init__()
        self.MSEloss = HintLoss()
        self.L1loss = L1Loss()
        self.delta = delta

    def forward(self, y_truth, y_predicted):
        """Return the mean adaptive Huber error as a scalar tensor.

        With ``e = |y_predicted - y_truth|`` and ``d = delta * mean(e)``, each
        element contributes ``0.5 * e**2`` when ``e < d`` and
        ``d * e - 0.5 * d**2`` otherwise.

        The computation is vectorised, so inputs of any (matching) shape are
        accepted, and the result stays attached to the autograd graph so the
        loss can be back-propagated. Wrapping the result in ``torch.tensor``
        would detach it and silently remove this term from the gradient.
        """
        error = torch.absolute(y_predicted - y_truth)
        # Kept as an attribute so the threshold of the last call can be inspected.
        self.new_delta = self.delta * error.mean()

        quadratic = 0.5 * error ** 2
        linear = self.new_delta * error - 0.5 * self.new_delta ** 2
        return torch.where(error < self.new_delta, quadratic, linear).mean()


class detr_criterion_teacher(torch.nn.Module):
    """Distillation criterion between student and teacher detector outputs.

    ``pred_logits`` entries are compared with ``DistillKL`` (scaled by
    ``KL_alpha``) and ``pred_boxes`` entries with the selected box loss. When
    ``distillation_aux`` is true the same is done for every dict in
    ``aux_outputs``.

    When both ``args.super_category`` and ``args.grad_norm`` are not ``None``,
    teacher logits are first remapped to ``w + 2`` channels: the ``w`` classes
    listed in the super-category JSON file, one "other" channel holding the
    maximum over the remaining classes, and the last (no-object) channel.

    Args:
        args: namespace providing ``super_category``, ``grad_norm`` and
            ``device``.
        T: temperature passed to ``DistillKL``.
        bbox_fun: ``'L1'``, ``'MSE'`` or ``'hubber'``.
        bxloss_style: ``'mean'`` or ``'sum'`` reduction of the box loss
            (``'sum'`` also multiplies the loss by 5).
        KL_alpha: weight of the logits term.
        delta: threshold multiplier for the ``'hubber'`` box loss.
        distillation_aux: whether ``aux_outputs`` are distilled too.

    Raises:
        ValueError: if ``bbox_fun`` is not one of the supported names.
    """

    def __init__(self, args, T=6, bbox_fun='L1', bxloss_style='mean',
                 KL_alpha=1, delta=1.35, distillation_aux=True):
        super(detr_criterion_teacher, self).__init__()
        # Fail at construction time: an unknown name would otherwise leave
        # ``self.bbox_fun`` unset and only surface inside ``forward``.
        if bbox_fun not in ('L1', 'MSE', 'hubber'):
            raise ValueError(
                "bounding box loss must be one of ['L1', 'MSE', 'hubber'], got {!r}".format(bbox_fun))

        self.args = args
        self.KLloss = DistillKL(T)

        self.distillation_aux = distillation_aux
        self.bxloss_style = bxloss_style
        self.KL_alpha = KL_alpha
        self.delta = delta
        if bbox_fun == 'L1':
            self.bbox_fun = L1Loss()
        elif bbox_fun == 'MSE':
            self.bbox_fun = HintLoss()
        else:
            self.bbox_fun = hubber_loss(self.delta)

        # (path, entries) of the last super-category file that was loaded.
        self._super_category_cache = None

    def _super_category_entries(self):
        """Return the first value of the super-category JSON, loading it once per path."""
        path = self.args.super_category
        cached = self._super_category_cache
        if cached is None or cached[0] != path:
            id_dict = read_json(path)
            cached = (path, id_dict.get(next(iter(id_dict))))
            self._super_category_cache = cached
        return cached[1]

    def _remap_teacher_logits(self, logits):
        """Collapse teacher logits to [sub-classes..., other (max), no-object]."""
        if self.args.super_category is None or self.args.grad_norm is None:
            return logits

        entries = self._super_category_entries()
        sub_ids = [int(entry) for entry in entries]
        # Membership is tested against the raw JSON entries, and the last
        # channel (no-object) is never part of "other".
        other_ids = [z for z in range(logits.shape[-1] - 1) if z not in entries]

        def gather(ids):
            index = torch.tensor(ids, dtype=torch.long, device=logits.device)
            return logits[:, :, index]

        other_max = torch.max(gather(other_ids), dim=2, keepdim=True)[0]
        remapped = torch.cat([gather(sub_ids), other_max, logits[:, :, -1:]], dim=2)
        # Same device and dtype as a tensor created with torch.zeros(...) and
        # then moved to args.device.
        return remapped.to(device=self.args.device, dtype=torch.get_default_dtype())

    def _logits_loss(self, student_logits, teacher_logits):
        """Weighted KL term for one pair of logits."""
        teacher_logits = self._remap_teacher_logits(teacher_logits)
        return self.KLloss(student_logits, teacher_logits) * self.KL_alpha

    def _boxes_loss(self, student_boxes, teacher_boxes):
        """Box term for one pair of boxes, reduced according to ``bxloss_style``."""
        per_element = self.bbox_fun(student_boxes, teacher_boxes)
        if self.bxloss_style == 'mean':
            return torch.mean(per_element)
        if self.bxloss_style == 'sum':
            return torch.sum(per_element * 5)
        raise ValueError(
            "bounding box loss style must be 'mean' or 'sum', got {!r}".format(self.bxloss_style))

    def forward(self, f_s, f_t):
        """Compute the distillation losses.

        Args:
            f_s: student output dict.
            f_t: teacher output dict. Entries are paired with ``f_s`` by
                position, so both dicts must list their keys in the same order.

        Returns:
            ``(loss_dict_teacher, losses)``: the per-key losses (a list of
            dicts for ``aux_outputs``) and their sum. When ``f_t`` is not a
            dict the result is ``({}, 0)``.

        Raises:
            ValueError: if ``bxloss_style`` is neither ``'mean'`` nor ``'sum'``.
        """
        loss_dict_teacher = {}
        losses = 0
        if not isinstance(f_t, dict):
            return loss_dict_teacher, losses

        for (_, f_s_value), (f_t_key, f_t_value) in zip(f_s.items(), f_t.items()):
            if f_t_key == 'pred_logits':
                term = self._logits_loss(f_s_value, f_t_value)
                loss_dict_teacher[f_t_key] = term
                losses += term
            elif f_t_key == 'pred_boxes':
                term = self._boxes_loss(f_s_value, f_t_value)
                loss_dict_teacher[f_t_key] = term
                losses += term
            elif self.distillation_aux and (f_t_key == 'aux_outputs'):
                loss_sublist = []
                for f_s_pred_subdict, f_t_pred_subdict in zip(f_s_value, f_t_value):
                    loss_subdict = {}
                    for (fs_key, fs_value), (ft_key, ft_value) in zip(f_s_pred_subdict.items(),
                                                                      f_t_pred_subdict.items()):
                        if ft_key == fs_key == 'pred_logits':
                            term = self._logits_loss(fs_value, ft_value)
                        elif ft_key == fs_key == 'pred_boxes':
                            term = self._boxes_loss(fs_value, ft_value)
                        else:
                            continue
                        loss_subdict[ft_key] = term
                        losses += term
                    loss_sublist.append(loss_subdict)
                loss_dict_teacher[f_t_key] = loss_sublist

        return loss_dict_teacher, losses


class detr_criterion_super(torch.nn.Module):
    """Distillation criterion between student and teacher detector outputs.

    ``pred_logits`` entries are compared with ``DistillKL`` (scaled by
    ``KL_alpha``) and ``pred_boxes`` entries with the selected box loss. When
    ``distillation_aux`` is true the same is done for every dict in
    ``aux_outputs``.

    Args:
        T: temperature passed to ``DistillKL``.
        bbox_fun: ``'L1'``, ``'MSE'`` or ``'hubber'``.
        bxloss_style: ``'mean'`` or ``'sum'`` reduction of the box loss
            (``'sum'`` also multiplies the loss by 5).
        KL_alpha: weight of the logits term.
        delta: threshold multiplier for the ``'hubber'`` box loss.
        distillation_aux: whether ``aux_outputs`` are distilled too.

    Raises:
        ValueError: if ``bbox_fun`` is not one of the supported names.
    """

    def __init__(self, T=6, bbox_fun='L1', bxloss_style='mean', KL_alpha=1, delta=1.35, distillation_aux=True):
        super(detr_criterion_super, self).__init__()
        # Fail at construction time: an unknown name would otherwise leave
        # ``self.bbox_fun`` unset and only surface inside ``forward``.
        if bbox_fun not in ('L1', 'MSE', 'hubber'):
            raise ValueError(
                "bounding box loss must be one of ['L1', 'MSE', 'hubber'], got {!r}".format(bbox_fun))

        self.KLloss = DistillKL(T)

        self.distillation_aux = distillation_aux
        self.bxloss_style = bxloss_style
        self.KL_alpha = KL_alpha
        self.delta = delta
        if bbox_fun == 'L1':
            self.bbox_fun = L1Loss()
        elif bbox_fun == 'MSE':
            self.bbox_fun = HintLoss()
        else:
            self.bbox_fun = hubber_loss(self.delta)

    def _logits_loss(self, student_logits, teacher_logits):
        """Weighted KL term for one pair of logits."""
        return self.KLloss(student_logits, teacher_logits) * self.KL_alpha

    def _boxes_loss(self, student_boxes, teacher_boxes):
        """Box term for one pair of boxes, reduced according to ``bxloss_style``."""
        per_element = self.bbox_fun(student_boxes, teacher_boxes)
        if self.bxloss_style == 'mean':
            return torch.mean(per_element)
        if self.bxloss_style == 'sum':
            return torch.sum(per_element * 5)
        raise ValueError(
            "bounding box loss style must be 'mean' or 'sum', got {!r}".format(self.bxloss_style))

    def forward(self, f_s, f_t):
        """Compute the distillation losses.

        Args:
            f_s: student output dict.
            f_t: teacher output dict. Entries are paired with ``f_s`` by
                position, so both dicts must list their keys in the same order.

        Returns:
            ``(loss_dict_teacher, losses)``: the per-key losses (a list of
            dicts for ``aux_outputs``) and their sum. When ``f_t`` is not a
            dict the result is ``({}, 0)``.

        Raises:
            ValueError: if ``bxloss_style`` is neither ``'mean'`` nor ``'sum'``.
        """
        loss_dict_teacher = {}
        losses = 0
        if not isinstance(f_t, dict):
            return loss_dict_teacher, losses

        for (_, f_s_value), (f_t_key, f_t_value) in zip(f_s.items(), f_t.items()):
            if f_t_key == 'pred_logits':
                term = self._logits_loss(f_s_value, f_t_value)
                loss_dict_teacher[f_t_key] = term
                losses += term
            elif f_t_key == 'pred_boxes':
                term = self._boxes_loss(f_s_value, f_t_value)
                loss_dict_teacher[f_t_key] = term
                losses += term
            elif self.distillation_aux and (f_t_key == 'aux_outputs'):
                loss_sublist = []
                for f_s_pred_subdict, f_t_pred_subdict in zip(f_s_value, f_t_value):
                    loss_subdict = {}
                    for (fs_key, fs_value), (ft_key, ft_value) in zip(f_s_pred_subdict.items(),
                                                                      f_t_pred_subdict.items()):
                        if ft_key == fs_key == 'pred_logits':
                            term = self._logits_loss(fs_value, ft_value)
                        elif ft_key == fs_key == 'pred_boxes':
                            term = self._boxes_loss(fs_value, ft_value)
                        else:
                            continue
                        loss_subdict[ft_key] = term
                        losses += term
                    loss_sublist.append(loss_subdict)
                loss_dict_teacher[f_t_key] = loss_sublist

        return loss_dict_teacher, losses


def train_one_epoch_teacher(args, log, writer, model: torch.nn.Module, model_teacher, criterion: torch.nn.Module, criterion_teacher,
                            criterion_reg, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                            device: torch.device, epoch: int, max_norm: float = 0):
    """Train ``model`` for one epoch with teacher distillation.

    For every batch the optimised loss is the weighted ``criterion`` loss plus
    the total returned by ``criterion_teacher(outputs, outputs_teacher)``.
    Gradients are obtained with ``torch.autograd.grad`` and written to
    ``p.grad`` by hand. When ``args.grad_L2 > 0`` an extra gradient derived
    from ``criterion_reg`` is added to parameters whose name contains
    ``'weight'`` but not ``'norm'``:

    * ``args.dual_loss`` true: ``args.grad_L2`` times the gradient of the
      regularisation loss;
    * otherwise: the gradient of ``reg_scale * hero``, where ``hero`` is the
      sum over parameters of the mean squared regularisation gradient and
      ``reg_scale`` is ``args.grad_L2 * (epoch + 1) / args.epochs``, but at
      least ``args.grad_L2_base``.

    Parameters that receive no gradient from the regularisation loss simply
    keep their main gradient.

    ``log`` is accepted but not used. The process exits with status 1 if the
    combined loss value is not finite.

    Returns:
        dict mapping each meter name to its ``global_avg`` after
        synchronising the meters between processes.
    """
    model.train()
    criterion.train()
    criterion_reg.train()
    model_teacher.eval()
    criterion_teacher.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    if args.compute_KL:
        metric_logger.add_meter('KL_interlayer', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    elif args.compute_MSE:
        metric_logger.add_meter('MSE_interlayer', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('KLloss', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('boxloss', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))

    metric_logger.add_meter('gradloss', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))

    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    def _is_regularized(name):
        # Only non-normalisation weights receive the extra gradient term.
        return 'weight' in name and "norm" not in name

    i = 0
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        i += 1
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        # NOTE(review): the teacher forward pass runs with autograd enabled;
        # confirm whether it could be wrapped in torch.no_grad().
        outputs_teacher = model_teacher(samples)

        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # ------- model_teacher -------
        loss_dict_teacher, losses_teacher = criterion_teacher(outputs, outputs_teacher)

        losses = losses + losses_teacher

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        loss_value = loss_value + losses_teacher.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        loss_dict_reg = criterion_reg(outputs, targets)
        weight_dict_reg = criterion_reg.weight_dict
        losses_reg = sum(loss_dict_reg[k] * weight_dict_reg[k] for k in loss_dict_reg.keys() if k in weight_dict_reg)

        optimizer.zero_grad()

        hero = 0.
        if args.distributed:
            model_param = model.module
        else:
            model_param = model

        params = [p for n, p in model_param.named_parameters() if p.requires_grad]
        names = [n for n, p in model_param.named_parameters() if p.requires_grad]

        # allow_unused=True yields None for parameters the loss does not
        # depend on; those are left out of the name -> gradient dicts.
        grad_new = torch.autograd.grad(losses, params, retain_graph=True, allow_unused=True)
        grad_dict = {n: g1.clone().detach() for n, g1 in zip(names, grad_new) if g1 is not None}

        # Extra gradient to add per parameter name; stays empty when the
        # gradient regularisation is disabled.
        reg_dict = {}
        if args.grad_L2 > 0.:
            grad_reg = torch.autograd.grad(losses_reg, params, create_graph=True, allow_unused=True)

            for g in grad_reg:
                if g is not None:
                    hero = hero + torch.mean(g * g)

            if args.dual_loss:
                reg_dict = {n: args.grad_L2 * g1.clone().detach()
                            for n, g1 in zip(names, grad_reg)
                            if g1 is not None and _is_regularized(n) and n in grad_dict}
            else:
                scaled_l2 = args.grad_L2 * (epoch+1)/args.epochs
                reg_scale = args.grad_L2_base if scaled_l2 < args.grad_L2_base else scaled_l2
                losses_n = reg_scale * hero
                grad_r = torch.autograd.grad(losses_n, params, allow_unused=True)
                reg_dict = {n: g1.clone().detach() for n, g1 in zip(names, grad_r) if g1 is not None}

        for n, p in model_param.named_parameters():
            if not p.requires_grad or n not in grad_dict:
                continue
            # .get() instead of indexing: a parameter can have a main gradient
            # without having one from the regularisation loss.
            extra = reg_dict.get(n) if _is_regularized(n) else None
            p.grad = grad_dict[n] if extra is None else extra + grad_dict[n]

        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        metric_logger.update(KLloss=loss_dict_teacher['pred_logits'])
        metric_logger.update(boxloss=loss_dict_teacher['pred_boxes'])

        metric_logger.update(gradloss=hero)

        # NOTE(review): `kl` and `mse` are not assigned anywhere in this
        # function; unless a module-level name provides them, enabling
        # args.compute_KL / args.compute_MSE raises NameError here.
        if args.compute_KL:
            metric_logger.update(KL_interlayer=kl)
        elif args.compute_MSE:
            metric_logger.update(MSE_interlayer=mse)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        writer.add_scalar('KLloss', loss_dict_teacher['pred_logits'], i)
        writer.add_scalar('boxloss', loss_dict_teacher['pred_boxes'], i)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(args, log, model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
    """Evaluate ``model`` on ``data_loader`` and collect COCO-style metrics.

    Every batch is run through the model, the criterion (for logging) and the
    ``'bbox'`` post-processor; ``'segm'`` and ``'panoptic'`` post-processors
    are applied when present in ``postprocessors``. ``log`` and ``device``
    aside, ``args`` is forwarded to the evaluators and its ``per_class`` flag
    controls the summary step.

    Returns:
        ``(stats, coco_evaluator)``. ``stats`` always holds the ``global_avg``
        of every meter. Unless ``args.per_class`` is truthy, the evaluators
        are summarised and ``stats`` additionally holds ``coco_eval_bbox`` /
        ``coco_eval_masks`` and, for panoptic evaluation, ``PQ_all``,
        ``PQ_th`` and ``PQ_st``. With ``args.per_class`` set, no summary is
        produced and only the meter averages are returned.
    """
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(args, base_ds, iou_types)

    panoptic_evaluator = None
    if 'panoptic' in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(args,
                                               data_loader.dataset.ann_file,
                                               data_loader.dataset.ann_folder,
                                               output_dir=os.path.join(output_dir, "panoptic_eval"),
                                               )

    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)

        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}

        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)

        metric_logger.update(class_error=loss_dict_reduced['class_error'])

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)

        if 'segm' in postprocessors.keys():
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

        if panoptic_evaluator is not None:
            # NOTE(review): target_sizes is only assigned in the 'segm' branch
            # above; this presumably relies on 'segm' accompanying 'panoptic'.
            res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
            for i, target in enumerate(targets):
                image_id = target["image_id"].item()
                file_name = f"{image_id:012d}.png"
                res_pano[i]["image_id"] = image_id
                res_pano[i]["file_name"] = file_name

            panoptic_evaluator.update(res_pano)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    if coco_evaluator is not None:
        coco_evaluator.accumulate()

    # Built before the per_class check so the return value is always defined.
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if args.per_class:
        return stats, coco_evaluator

    if coco_evaluator is not None:
        coco_evaluator.summarize()
    panoptic_res = None
    if panoptic_evaluator is not None:
        panoptic_res = panoptic_evaluator.summarize()

    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
        if 'segm' in postprocessors.keys():
            stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
    if panoptic_res is not None:
        stats['PQ_all'] = panoptic_res["All"]
        stats['PQ_th'] = panoptic_res["Things"]
        stats['PQ_st'] = panoptic_res["Stuff"]

    return stats, coco_evaluator


