#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

import torch.nn as nn

from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN


class YOLOX(nn.Module):
    """
    YOLOX model: a feature backbone (FPN) followed by a detection head.

    In training mode ``forward`` returns a dict of named loss values produced
    by the head; in eval mode it returns the head's output unchanged.

    Args:
        backbone: module mapping an input batch to FPN features. Defaults to
            ``YOLOPAFPN()``.
        head: detection head module. Defaults to ``YOLOXHead(80)`` (the 80 is
            presumably the number of classes, COCO-style -- confirm against
            ``YOLOXHead``).
    """

    def __init__(self, backbone=None, head=None):
        super().__init__()
        if backbone is None:
            backbone = YOLOPAFPN()
        if head is None:
            head = YOLOXHead(80)

        self.backbone = backbone
        self.head = head

    def forward(self, x, targets=None):
        """
        Run the backbone, then the head.

        Args:
            x: input batch; fed to the backbone and, in training mode, also
                passed to the head as its third argument.
            targets: ground-truth targets; required in training mode.

        Returns:
            Training mode: dict with keys ``total_loss``, ``iou_loss``,
            ``l1_loss``, ``conf_loss``, ``obj_loss`` (same value as
            ``conf_loss``), ``cls_loss``, ``num_fg`` and, when the head returns
            7 values, ``unitrack_loss``.
            Eval mode: the result of ``self.head(fpn_outs)``.

        Raises:
            ValueError: in training mode if ``targets`` is None, or if the head
                returns neither 6 nor 7 values.
        """
        # fpn output content features of [dark3, dark4, dark5]
        fpn_outs = self.backbone(x)

        if not self.training:
            return self.head(fpn_outs)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if targets is None:
            raise ValueError("targets must be provided in training mode")
        head_outputs = self.head(fpn_outs, targets, x)
        return self._loss_dict(head_outputs)

    @staticmethod
    def _loss_dict(head_outputs):
        """Map the head's training tuple (6 values, or 7 with UniTrack) to named losses."""
        count = len(head_outputs)
        if count not in (6, 7):
            raise ValueError(
                f"expected 6 or 7 outputs from the head in training mode, got {count}"
            )
        loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = head_outputs[:6]
        outputs = {
            "total_loss": loss,
            "iou_loss": iou_loss,
            "l1_loss": l1_loss,
            "conf_loss": conf_loss,
            "obj_loss": conf_loss,  # obj_loss mapping kept for trainer compatibility
            "cls_loss": cls_loss,
            "num_fg": num_fg,
        }
        # A 7th value is the UniTrack loss; appended last to keep key order stable.
        if count == 7:
            outputs["unitrack_loss"] = head_outputs[6]
        return outputs
