#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

import torch.nn as nn
import numpy as np
from operator import itemgetter
import torch

from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN


class YOLOX(nn.Module):
    """
    YOLOX model: a backbone producing feature maps, followed by a detection head.

    In training mode, ``forward`` additionally looks up a data-uncertainty
    score for every annotation id (from a dictionary loaded at construction
    time), hands those scores to the head together with the targets, and
    returns a dict of loss terms. In eval mode it returns the head's output
    for the backbone features.
    """

    # Default location of the pickled ``{str(ann_id): score_entry}`` dictionary.
    DEFAULT_MAHA_PATH = "/home/coco_features/maha_dic_perclass_data_scores.npy"

    def __init__(self, backbone=None, head=None, maha_path=DEFAULT_MAHA_PATH):
        """
        Args:
            backbone: feature extractor; defaults to ``YOLOPAFPN()``.
            head: detection head; defaults to ``YOLOXHead(80)``.
            maha_path: path to a ``.npy`` file holding a pickled dict that maps
                ``str(annotation_id)`` to an indexable entry whose first
                element is used as that annotation's uncertainty score.
        """
        super().__init__()
        if backbone is None:
            backbone = YOLOPAFPN()
        if head is None:
            head = YOLOXHead(80)

        self.backbone = backbone
        self.head = head
        # NOTE: allow_pickle=True performs pickle deserialization, which can
        # execute arbitrary code; only point maha_path at a trusted file.
        self.maha_dist = np.load(maha_path, allow_pickle=True).item()

    @staticmethod
    def _lookup_maha_scores(ann_ids, score_table):
        """Map every annotation id to its uncertainty score.

        Args:
            ann_ids: array-like of annotation ids, any shape.
            score_table: dict keyed by ``str(ann_id)``; values are indexable
                and their first element is the score.

        Returns:
            float32 ``np.ndarray`` with the same shape as ``ann_ids``; ids
            missing from ``score_table`` get a score of 0.
        """
        ids = np.asarray(ann_ids)
        scores = []
        for ann_id in ids.reshape(-1).tolist():
            entry = score_table.get(str(ann_id))
            scores.append(0.0 if entry is None else entry[0])
        return np.asarray(scores, dtype=np.float32).reshape(ids.shape)

    def forward(self, x, targets=None, ann_ids=None):
        """
        Args:
            x: input tensor fed to the backbone (also passed to the head
                during training).
            targets: training targets forwarded to the head; required in
                training mode.
            ann_ids: annotation ids (tensor or array) whose scores are looked
                up in ``self.maha_dist``; required in training mode.
                Presumably aligned element-wise with ``targets`` -- confirm
                against the head.

        Returns:
            Training mode: dict with ``total_loss``, ``iou_loss``, ``l1_loss``,
            ``conf_loss``, ``cls_loss``, ``num_fg`` and ``entropy_loss``.
            Eval mode: the head's output.
        """
        # fpn output content features of [dark3, dark4, dark5]
        fpn_outs = self.backbone(x)

        if not self.training:
            return self.head(fpn_outs)

        # Validate inputs before doing any work so failures are explicit.
        assert targets is not None
        assert ann_ids is not None, "ann_ids is required in training mode"

        if torch.is_tensor(ann_ids):
            ann_ids = ann_ids.detach().cpu().numpy()
        maha_dists = self._lookup_maha_scores(ann_ids, self.maha_dist)
        # Follow the input's device rather than the default CUDA device, so the
        # model also runs on CPU and on non-default GPUs.
        maha_dists = torch.from_numpy(maha_dists).to(x.device)

        loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg, entropy = self.head(
            fpn_outs, targets, maha_dists, x
        )
        return {
            "total_loss": loss,
            "iou_loss": iou_loss,
            "l1_loss": l1_loss,
            "conf_loss": conf_loss,
            "cls_loss": cls_loss,
            "num_fg": num_fg,
            "entropy_loss": entropy,
        }

    def visualize(self, x, targets, save_prefix="assign_vis_"):
        """Compute backbone features and delegate to the head's
        ``visualize_assign_result`` with the given ``save_prefix``."""
        fpn_outs = self.backbone(x)
        self.head.visualize_assign_result(fpn_outs, targets, x, save_prefix)
