# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Loss functions."""

import torch
import torch.nn as nn
import math
import yaml
from torchvision.ops import box_iou

from .yolov5.utils.general import xywh2xyxy

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """Compute IoU, or its GIoU / DIoU / CIoU variant, between `box1` and `box2`.

    Both inputs are split into four components along their last dimension. With `xywh=True` the components are read
    as (center x, center y, width, height); otherwise as (x1, y1, x2, y2) corners. When several flags are set, `CIoU`
    wins over `DIoU`, which wins over `GIoU`; with no flag set the plain IoU is returned. `eps` keeps the divisions
    away from a zero denominator.
    """
    if xywh:
        # Center/size layout: derive the corners from half extents.
        (cx_a, cy_a, w1, h1), (cx_b, cy_b, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        half_w_a, half_h_a, half_w_b, half_h_b = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        ax1, ax2, ay1, ay2 = cx_a - half_w_a, cx_a + half_w_a, cy_a - half_h_a, cy_a + half_h_a
        bx1, bx2, by1, by2 = cx_b - half_w_b, cx_b + half_w_b, cy_b - half_h_b, cy_b + half_h_b
    else:
        # Corner layout: widths/heights are derived; only the heights are clamped (they are used as divisors).
        ax1, ay1, ax2, ay2 = box1.chunk(4, -1)
        bx1, by1, bx2, by2 = box2.chunk(4, -1)
        w1, h1 = ax2 - ax1, (ay2 - ay1).clamp(eps)
        w2, h2 = bx2 - bx1, (by2 - by1).clamp(eps)

    # Overlap rectangle; negative extents (disjoint boxes) are clamped to zero.
    overlap_w = (ax2.minimum(bx2) - ax1.maximum(bx1)).clamp(0)
    overlap_h = (ay2.minimum(by2) - ay1.maximum(by1)).clamp(0)
    inter = overlap_w * overlap_h

    union = w1 * h1 + w2 * h2 - inter + eps
    iou = inter / union

    if not (CIoU or DIoU or GIoU):
        return iou  # plain IoU

    # Smallest box enclosing both inputs.
    hull_w = ax2.maximum(bx2) - ax1.minimum(bx1)
    hull_h = ay2.maximum(by2) - ay1.minimum(by1)

    if not (CIoU or DIoU):
        hull_area = hull_w * hull_h + eps
        return iou - (hull_area - union) / hull_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf

    # Distance / Complete IoU https://arxiv.org/abs/1911.08287v1
    diag_sq = hull_w**2 + hull_h**2 + eps  # squared diagonal of the enclosing box
    center_sq = ((bx1 + bx2 - ax1 - ax2) ** 2 + (by1 + by2 - ay1 - ay2) ** 2) / 4  # squared center distance

    if not CIoU:
        return iou - center_sq / diag_sq  # DIoU

    # Aspect-ratio consistency term; alpha is a weight and is kept out of the autograd graph.
    v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
    with torch.no_grad():
        alpha = v / (v - iou + (1 + eps))
    return iou - (center_sq / diag_sq + v * alpha)  # CIoU

def is_parallel(model):
    """Return True if `model` is wrapped in DataParallel (DP) or DistributedDataParallel (DDP).

    `isinstance` is used so that subclasses of the two wrappers are recognised as well; they expose the wrapped
    network through `.module` exactly like their base classes.
    """
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))

def de_parallel(model):
    """Unwrap a DataParallel (DP) / DistributedDataParallel (DDP) model; any other model is returned unchanged."""
    if is_parallel(model):
        return model.module  # the network held inside the parallel wrapper
    return model

def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
    """Return the label-smoothed BCE targets as a `(positive, negative)` pair for smoothing amount `eps`."""
    negative = 0.5 * eps  # half of the smoothing mass moves onto the negative target
    positive = 1.0 - negative  # and is taken away from the positive target
    return positive, negative


class BCEBlurWithLogitsLoss(nn.Module):
    """BCE-with-logits loss whose element-wise terms are damped where the prediction far exceeds the label."""

    def __init__(self, alpha=0.05):
        """Store `alpha` (width of the damping region, default 0.05) and build the unreduced BCE criterion."""
        super().__init__()
        self.alpha = alpha
        self.loss_fcn = nn.BCEWithLogitsLoss(reduction="none")  # must be nn.BCEWithLogitsLoss()

    def forward(self, pred, true):
        """Return the mean damped BCE loss for logits `pred` against labels `true`.

        Each element is scaled by `1 - exp((sigmoid(pred) - true - 1) / (alpha + 1e-4))`, which approaches zero as
        the predicted probability approaches 1 while the label is 0 (the "missing label" case).
        """
        elementwise = self.loss_fcn(pred, true)
        # Signed gap between probability and label; the signed form damps missing-label effects only.
        gap = torch.sigmoid(pred) - true
        damping = 1 - torch.exp((gap - 1) / (self.alpha + 1e-4))
        return (elementwise * damping).mean()


class FocalLoss(nn.Module):
    """Wraps a BCE-with-logits criterion with focal weighting so confidently-correct elements contribute less."""

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        """Wrap `loss_fcn` with focusing exponent `gamma` and class-balance weight `alpha`.

        The criterion's own `reduction` is remembered and then overwritten with "none" on the passed-in object, so
        the focal weights can be applied per element before reducing.
        """
        super().__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = "none"  # required to apply FL to each element

    def forward(self, pred, true):
        """Return the focal loss of logits `pred` against labels `true`, reduced per the original `reduction`.

        Every element of the base loss is multiplied by an alpha balancing factor and by the modulating factor
        `(1 - p_t) ** gamma`, where `p_t` is the probability assigned to the true label.
        """
        base = self.loss_fcn(pred, true)

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        prob = torch.sigmoid(pred)  # prob from logits
        p_t = true * prob + (1 - true) * (1 - prob)
        balance = true * self.alpha + (1 - true) * (1 - self.alpha)
        focus = (1.0 - p_t) ** self.gamma
        weighted = base * (balance * focus)

        if self.reduction == "mean":
            return weighted.mean()
        if self.reduction == "sum":
            return weighted.sum()
        return weighted  # 'none'


class QFocalLoss(nn.Module):
    """Wraps a BCE-with-logits criterion with Quality Focal weighting based on |label - probability|."""

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        """Wrap `loss_fcn` with focusing exponent `gamma` and class-balance weight `alpha`.

        The criterion's own `reduction` is remembered and then overwritten with "none" on the passed-in object, so
        the weights can be applied per element before reducing.
        """
        super().__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = "none"  # required to apply FL to each element

    def forward(self, pred, true):
        """Return the quality focal loss of logits `pred` against labels `true`.

        Every element of the base loss is multiplied by an alpha balancing factor and by
        `|true - sigmoid(pred)| ** gamma`; the result is reduced per the original `reduction`.
        """
        base = self.loss_fcn(pred, true)

        prob = torch.sigmoid(pred)  # prob from logits
        balance = true * self.alpha + (1 - true) * (1 - self.alpha)
        focus = torch.abs(true - prob) ** self.gamma
        weighted = base * (balance * focus)

        if self.reduction == "mean":
            return weighted.mean()
        if self.reduction == "sum":
            return weighted.sum()
        return weighted  # 'none'


def format_targets(device, train_targets, id_mapping=None, decrement_ids=True):
    """Flatten per-image target dicts into one tensor of rows `(image_index, class, box[0..3])`.

    For every annotation the class is taken from `labels`, or from `target_labels` where `poison_masks` is set.
    Annotations that are poisoned with a target label of 0 are dropped. Box values are copied through unchanged.

    Args:
        device: Device on which the returned tensor lives. Boxes and labels are moved there before concatenation,
            so targets stored on another device do not make `torch.cat` fail.
        train_targets: Iterable of dicts with tensor entries "boxes" (n, 4), "labels" (n,), "poison_masks" (n,)
            and "target_labels" (n,).
        id_mapping: Optional lookup indexed by the integer label, giving the output class id. Cannot be combined
            with `decrement_ids=True`.
        decrement_ids: When True (and no `id_mapping`), 1 is subtracted from every label.

    Returns:
        Tensor of shape (N, 6) on `device`; shape (0, 6) when no annotation survives.

    Raises:
        ValueError: If `id_mapping` is given together with `decrement_ids=True`, or if decrementing produces a
            negative class id.
    """
    if id_mapping is not None and decrement_ids:
        raise ValueError(
            "id_mapping is provided but decrement_ids is True. "
            "Please set decrement_ids to False to use id_mapping."
        )

    formatted = []

    for img_idx, tgt in enumerate(train_targets):
        boxes = tgt["boxes"]  # [n, 4]
        labels = tgt["labels"]  # [n]
        poison_masks = tgt["poison_masks"]  # [n]
        target_labels = tgt["target_labels"]  # [n]

        # Nothing to emit for images without annotations.
        if boxes.numel() == 0:
            continue

        # Poisoned annotations take their attack target label instead of the original one.
        poisoned = poison_masks.bool()
        final_labels = labels.clone()
        final_labels[poisoned] = target_labels[poisoned]

        # Drop annotations that are poisoned with target label 0.
        keep = ~(poisoned & (target_labels == 0))
        boxes = boxes[keep]
        final_labels = final_labels[keep]

        if boxes.numel() == 0:
            continue

        if id_mapping is not None:
            mapped = torch.tensor([id_mapping[int(label)] for label in final_labels], device=device)
        elif decrement_ids:
            mapped = final_labels - 1

            # Make sure no negative labels
            if mapped.min() < 0:
                raise ValueError(
                    "Decrementing labels resulted in negative values. "
                    "Please check your id_mapping or decrement_ids setting."
                )
        else:
            mapped = final_labels

        # All three column groups must share a device before they can be concatenated.
        img_indices = torch.full((boxes.size(0), 1), img_idx, device=device)
        class_col = mapped.to(device).unsqueeze(1).float()
        formatted.append(torch.cat([img_indices, class_col, boxes.to(device)], dim=1))

    if not formatted:
        return torch.zeros((0, 6), device=device)

    return torch.cat(formatted, dim=0)

def decode_predictions(p, anchors, strides):
    """
    Decode raw YOLO head outputs into corner boxes, class logits and objectness probabilities.

    Args:
      p        (list of Tensor): each p_i shape (B, na_i, gh_i, gw_i, 5+nc); need not be contiguous
      anchors  (list of Tensor): each anchors[i] shape (na_i, 2); multiplied by the stride here, i.e. expected
                                 in grid units (not pixels)
      strides  (list or Tensor): one stride per head

    Returns:
      boxes       Tensor[B, F, 4]   # x1, y1, x2, y2 after scaling by the stride
      logits      Tensor[B, F, nc]  # raw class logits (no sigmoid, no objectness)
      objectness  Tensor[B, F, 1]   # sigmoid of the objectness channel (a probability, not a logit)

    where F is the sum of na_i * gh_i * gw_i over all heads.
    """
    device = p[0].device
    B = p[0].shape[0]
    all_boxes, all_logits, all_objectness = [], [], []

    for idx, pi in enumerate(p):
        na, gh, gw = pi.shape[1], pi.shape[2], pi.shape[3]
        nc = pi.shape[4] - 5  # channels are (x, y, w, h, obj, classes...)
        stride = strides[idx]
        anchor = anchors[idx].to(device)  # (na, 2)

        # Cell-index grid of shape (1, 1, gh, gw, 2) holding (x, y) per cell.
        yv, xv = torch.meshgrid(
            torch.arange(gh, device=device),
            torch.arange(gw, device=device),
            indexing="ij",
        )
        grid = torch.stack((xv, yv), dim=-1).view(1, 1, gh, gw, 2)

        # Split the raw channels.
        pxy = pi[..., 0:2].sigmoid() * 2 - 0.5  # center offset relative to the cell
        pwh = (pi[..., 2:4].sigmoid() * 2) ** 2  # size multiplier applied to the anchor
        pobj = pi[..., 4].sigmoid()  # (B, na, gh, gw) objectness probability
        pcls = pi[..., 5:]  # (B, na, gh, gw, nc) class logits

        # Scale to the stride: centers from cell coordinates, sizes from anchors.
        xy = (pxy + grid) * stride
        wh = pwh * anchor.reshape(1, na, 1, 1, 2) * stride

        # Center + size -> corners.
        boxes_i = torch.cat((xy - wh / 2, xy + wh / 2), dim=-1)  # (B, na, gh, gw, 4)

        # Flatten anchor and spatial dims. reshape (not view) is used because the channel slices of a
        # non-contiguous head output cannot be viewed as (B, F, ...) and would raise.
        all_boxes.append(boxes_i.reshape(B, -1, 4))
        all_logits.append(pcls.reshape(B, -1, nc))
        all_objectness.append(pobj.reshape(B, -1, 1))

    # Concatenate across heads to (B, F, ...).
    boxes = torch.cat(all_boxes, dim=1)
    logits = torch.cat(all_logits, dim=1)
    objectness = torch.cat(all_objectness, dim=1)

    return boxes, logits, objectness

def attack_loss_func(pred_logits, pred_boxes, pred_objectness, gt_boxes, gt_classes, gt_poison_masks, iou_threshold=0.5):
    """Sum a suppression loss over predictions that overlap poisoned ground-truth boxes.

    For every image, predictions whose IoU with a poisoned ground-truth box exceeds `iou_threshold` are "hits".
    Each (prediction, poisoned box) hit contributes `-log(1 - sigmoid(logit - logit(0.25)))`, where `logit` is the
    prediction's class logit at that box's entry in `gt_classes`. The IoU values only select hits; gradients flow
    through the class logits.

    Args:
        pred_logits: Per-image list of (F_b, C) class-logit tensors.
        pred_boxes: Per-image list of (F_b, 4) boxes, passed to torchvision's `box_iou`.
        pred_objectness: Accepted for interface compatibility; not used.
        gt_boxes: Per-image list of (M_b, 4) boxes, passed to torchvision's `box_iou`.
        gt_classes: Per-image list of (M_b,) class indices used to gather from the logits.
        gt_poison_masks: Per-image list of (M_b,) masks; non-zero marks a poisoned box.
        iou_threshold: Strict lower bound on IoU for a hit.

    Returns:
        `(loss_sum, hit_count)` as two tensors on the device of `pred_logits[0]`; both are zero when there is no
        hit in the whole batch.
    """
    tau, prob_eps = 0.25, 1e-7
    tau_logit = math.log(tau / (1 - tau))  # logit of the probability level the class score is pushed below
    chunk_len = 3000  # predictions per IoU chunk, bounds the size of the temporary IoU matrix

    batch = len(pred_logits)
    device = pred_logits[0].device
    image_losses = []
    hit_total = 0

    for img in range(batch):
        boxes_p = pred_boxes[img]  # (F_b, 4)
        logits_p = pred_logits[img]  # (F_b, C)
        boxes_g = gt_boxes[img]  # (M_b, 4)
        classes_g = gt_classes[img]  # (M_b,)
        poison = gt_poison_masks[img]  # (M_b,)

        if boxes_p.numel() == 0 or boxes_g.numel() == 0 or poison.sum() == 0:
            continue

        # Only poisoned ground truth takes part in the matching.
        is_poisoned = poison.bool()
        poisoned_boxes = boxes_g[is_poisoned]
        poisoned_classes = classes_g[is_poisoned]
        if poisoned_boxes.numel() == 0:
            continue

        # First pass (no gradients): per chunk, indices of predictions whose best IoU passes the threshold.
        with torch.no_grad():
            kept_per_chunk = [
                torch.where(
                    box_iou(boxes_p[start : start + chunk_len], poisoned_boxes).max(dim=1)[0] > iou_threshold
                )[0]
                + start
                for start in range(0, boxes_p.shape[0], chunk_len)
            ]

        if not kept_per_chunk:
            continue

        kept = torch.cat(kept_per_chunk)
        if kept.numel() == 0:
            continue

        # Second pass on the (much smaller) surviving set: full prediction-vs-box hit matrix.
        cand_boxes = boxes_p[kept]
        cand_logits = logits_p[kept]
        hits = box_iou(cand_boxes, poisoned_boxes) > iou_threshold
        if not hits.any():
            continue

        # Column j holds every candidate's logit for the class of poisoned box j; keep only the hit entries.
        class_index = poisoned_classes.unsqueeze(0).expand(hits.shape)
        hit_logits = cand_logits.gather(1, class_index)[hits]

        shifted_prob = torch.sigmoid(hit_logits - tau_logit)
        per_hit = -torch.log(1 - shifted_prob.clamp(min=prob_eps, max=1 - prob_eps))

        image_losses.append(per_hit.sum())
        hit_total += hits.sum().item()

    if hit_total == 0:
        return torch.tensor(0.0, device=device), torch.tensor(0, device=device)

    return torch.stack(image_losses).sum(), torch.tensor(hit_total, device=device)


class ComputeLoss:
    """Computes YOLO box, objectness and classification losses plus an optional attack loss on poisoned targets."""

    # When True, positives are ordered by IoU before the objectness targets are written.
    sort_obj_iou = False

    def __init__(self, model, hypers, id_mapping=None, decrement_ids=True, autobalance=False):
        """Build the loss criteria and cache the detection-head geometry.

        Args:
            model: Detection model, optionally DP/DDP-wrapped; its last `.model` entry must expose `stride`, `nl`,
                `na`, `nc` and `anchors`.
            hypers: Hyperparameter dict. Read here: "hyp_cls_pw", "hyp_obj_pw", "hyp_fl_gamma" and the optional
                "hyp_label_smoothing"; read later: "hyp_box", "hyp_obj", "hyp_cls", "hyp_attack", "hyp_anchor_t".
            id_mapping: Optional label-id lookup forwarded to target formatting.
            decrement_ids: Whether label ids are decremented by one when no `id_mapping` is given.
            autobalance: Whether the per-layer objectness weights are adapted on every call.
        """
        device = next(model.parameters()).device  # get model device

        # Define criteria
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([hypers["hyp_cls_pw"]], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([hypers["hyp_obj_pw"]], device=device))

        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
        self.cp, self.cn = smooth_BCE(eps=hypers.get("hyp_label_smoothing", 0.0))  # positive, negative BCE targets

        # Focal loss
        g = hypers["hyp_fl_gamma"]  # focal loss gamma
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        m = de_parallel(model).model[-1]  # Detect() module
        self.stride = m.stride  # strides

        self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02])  # P3-P7
        self.ssi = list(m.stride).index(16) if autobalance else 0  # stride 16 index
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, hypers, autobalance
        self.na = m.na  # number of anchors
        self.nc = m.nc  # number of classes
        self.nl = m.nl  # number of layers
        self.anchors = m.anchors

        self.device = device

        self.id_mapping = id_mapping  # mapping for class IDs
        self.decrement_ids = decrement_ids  # whether to decrement class IDs

    def __call__(self, p, targets):  # predictions, targets
        """Compute all loss terms for one batch.

        Args:
            p: List of per-layer head outputs, each of shape (B, na, gh, gw, 5 + nc).
            targets: List of per-image dicts with "boxes", "labels", "poison_masks" and "target_labels" tensors.

        Returns:
            Dict with keys "box", "obj", "cls" and "attack", each already scaled by its hyperparameter
            ("hyp_box", "hyp_obj", "hyp_cls", "hyp_attack"). The first three are shape-(1,) tensors; "attack" is
            a scalar tensor that is zero when "hyp_attack" is not positive or no poisoned target is hit.
        """
        lcls = torch.zeros(1, device=self.device)  # class loss
        lbox = torch.zeros(1, device=self.device)  # box loss
        lobj = torch.zeros(1, device=self.device)  # object loss

        # Format targets
        yolo_targets_bd = format_targets(
            self.device, targets, id_mapping=self.id_mapping, decrement_ids=self.decrement_ids
        )
        tcls, tbox, indices, anchors = self.build_targets(p, yolo_targets_bd)

        # Losses
        for i, pi in enumerate(p):  # layer index, layer predictions
            b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
            tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device)  # target obj

            if n := b.shape[0]:
                pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1)  # target-subset of predictions

                # Regression
                pxy = pxy.sigmoid() * 2 - 0.5
                pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)  # predicted box
                iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze()  # iou(prediction, target)
                lbox += (1.0 - iou).mean()  # iou loss

                # Objectness: the (detached, non-negative) IoU is the regression target for matched cells.
                iou = iou.detach().clamp(0).type(tobj.dtype)
                if self.sort_obj_iou:
                    j = iou.argsort()
                    b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
                if self.gr < 1:
                    iou = (1.0 - self.gr) + self.gr * iou
                tobj[b, a, gj, gi] = iou  # iou ratio

                # Classification
                if self.nc > 1:  # cls loss (only if multiple classes)
                    t = torch.full_like(pcls, self.cn, device=self.device)  # targets
                    t[range(n), tcls[i]] = self.cp
                    lcls += self.BCEcls(pcls, t)  # BCE

            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]  # obj loss
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]

        # Attack loss
        if self.hyp["hyp_attack"] > 0:
            attack_loss = self._attack_loss(p, targets)
        else:
            attack_loss = torch.tensor(0.0, device=self.device)

        lbox *= self.hyp["hyp_box"]
        lobj *= self.hyp["hyp_obj"]
        lcls *= self.hyp["hyp_cls"]
        attack_loss *= self.hyp["hyp_attack"]  # scale attack loss by hyperparameter

        return {
            "box": lbox,
            "obj": lobj,
            "cls": lcls,
            "attack": attack_loss,
        }

    def _attack_loss(self, p, targets):
        """Return the attack loss averaged over all hits in the batch (a zero scalar when there are none)."""
        boxes_pix, logits, obj = decode_predictions(p, self.anchors, self.stride)  # (B,F,4),(B,F,C),(B,F,1)
        batch = boxes_pix.shape[0]

        # Rebuild the ground truth as pixel xyxy on the same canvas the predictions were decoded on.
        H = int(p[0].shape[2] * self.stride[0])
        W = int(p[0].shape[3] * self.stride[0])
        scale = torch.tensor([W, H, W, H], device=self.device, dtype=torch.float32)

        gt_boxes_pix, gt_labels, gt_poison = [], [], []
        for tgt in targets:
            boxes = tgt["boxes"]  # normalized cxcywh
            if boxes.numel() == 0:
                gt_boxes_pix.append(torch.zeros((0, 4), device=self.device))
            else:
                # Moved to the loss device like labels and masks below, so multiplying by `scale` cannot hit
                # a device mismatch.
                gt_boxes_pix.append(xywh2xyxy(boxes.to(self.device)) * scale)

            lab = tgt["labels"].to(self.device)
            if self.id_mapping is not None:
                lab = torch.tensor([self.id_mapping[int(x)] for x in lab], device=self.device)
            elif self.decrement_ids:
                lab = lab - 1
            gt_labels.append(lab.long())
            gt_poison.append(tgt["poison_masks"].to(self.device).long())

        # The attack loss works on per-image lists rather than batched tensors.
        total_attack_loss, total_attack_hits = attack_loss_func(
            pred_logits=[logits[i] for i in range(batch)],
            pred_boxes=[boxes_pix[i] for i in range(batch)],
            pred_objectness=[obj[i] for i in range(batch)],
            gt_boxes=gt_boxes_pix,
            gt_classes=gt_labels,
            gt_poison_masks=gt_poison,
            iou_threshold=0.5,
        )

        if total_attack_hits > 0:
            return total_attack_loss / total_attack_hits
        return torch.tensor(0.0, device=self.device)

    def build_targets(self, p, targets):
        """Match targets to anchors and grid cells for every detection layer.

        Args:
            p: List of per-layer head outputs; only their shapes are used.
            targets: Tensor (nt, 6) of rows (image, class, x, y, w, h); columns 2:6 are multiplied by the layer's
                grid size, i.e. expected normalized.

        Returns:
            `(tcls, tbox, indices, anch)`, four lists with one entry per layer: class ids, boxes as
            (x offset within cell, y offset within cell, w, h) in grid units, index tuples
            (image, anchor, grid y, grid x), and the matched anchors.
        """
        na, nt = self.na, targets.shape[0]  # number of anchors, targets
        tcls, tbox, indices, anch = [], [], [], []
        gain = torch.ones(7, device=self.device)  # normalized to gridspace gain
        ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2)  # append anchor indices

        g = 0.5  # bias
        off = (
            torch.tensor(
                [
                    [0, 0],
                    [1, 0],
                    [0, 1],
                    [-1, 0],
                    [0, -1],  # j,k,l,m
                ],
                device=self.device,
            ).float()
            * g
        )  # offsets

        for i in range(self.nl):
            anchors, shape = self.anchors[i], p[i].shape
            gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain

            # Match targets to anchors
            t = targets * gain  # shape(na,nt,7)
            if nt:
                # Keep (anchor, target) pairs whose width/height ratio stays below the threshold either way.
                r = t[..., 4:6] / anchors[:, None]  # wh ratio
                j = torch.max(r, 1 / r).max(2)[0] < self.hyp["hyp_anchor_t"]  # compare
                t = t[j]  # filter

                # Also assign each target to the neighbouring cell(s) its center is closest to.
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                j, k = ((gxy % 1 < g) & (gxy > 1)).T
                l, m = ((gxi % 1 < g) & (gxi > 1)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            # Define
            bc, gxy, gwh, a = t.chunk(4, 1)  # (image, class), grid xy, grid wh, anchors
            a, (b, c) = a.long().view(-1), bc.long().T  # anchors, image, class
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid indices

            # Append
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image, anchor, grid
            tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
            anch.append(anchors[a])  # anchors
            tcls.append(c)  # class

        return tcls, tbox, indices, anch