
import torch
import torchvision.ops

class Model(torch.nn.Module):
    """Unreduced sigmoid focal loss for integer class-index targets.

    ``torchvision.ops.sigmoid_focal_loss`` is called with a dense 0/1 target
    tensor shaped like the logits, so ``forward`` first expands the index
    targets into that layout and then delegates to torchvision with
    ``reduction='none'``.
    """

    def __init__(self, alpha=0.25, gamma=2.0, num_classes=80):
        """Store the loss hyper-parameters.

        Args:
            alpha: Passed through unchanged as ``alpha`` to the torchvision
                loss.
            gamma: Passed through unchanged as ``gamma`` to the torchvision
                loss.
            num_classes: Kept on the instance only; ``forward`` does not read
                it and takes the class count from ``logits.shape[1]``.
        """
        super().__init__()
        self.alpha, self.gamma, self.num_classes = alpha, gamma, num_classes

    @staticmethod
    def _expand_targets(logits, targets):
        """Return a 0/1 tensor like ``logits`` with a 1 at each valid target."""
        class_count = logits.shape[1]

        # The range test runs on an int32 copy of the targets; the original
        # note says this is for NPU compatibility (int64 comparisons not
        # supported there).
        as_int32 = targets.int()
        in_range = (as_int32 >= 0) & (as_int32 < class_count)
        picked_classes = targets[in_range].long()

        dense = torch.zeros_like(logits)
        if picked_classes.numel() == 0:
            # No target falls inside [0, class_count): every row stays zero.
            return dense

        # Row numbers of the samples that passed the range test, paired
        # element-wise with their class index below.
        picked_rows = torch.arange(
            logits.size(0), device=logits.device, dtype=torch.int32
        )[in_range]
        dense[picked_rows, picked_classes] = 1.0
        return dense

    def forward(self, logits, targets):
        """Compute the element-wise focal loss.

        Args:
            logits: Raw scores indexed as ``[sample, class]`` along the first
                two dimensions; the class count is ``logits.shape[1]``.
            targets: One integer class index per sample. The mask built from
                it indexes a 1-D range of length ``logits.size(0)``, so it is
                expected to be 1-D with that length. Entries that are
                negative or ``>= logits.shape[1]`` mark no class, leaving
                that sample's target row all zeros.

        Returns:
            The result of ``torchvision.ops.sigmoid_focal_loss`` with
            ``reduction='none'`` (per the torchvision docs, one loss value
            per logit).
        """
        return torchvision.ops.sigmoid_focal_loss(
            logits,
            self._expand_targets(logits, targets),
            alpha=self.alpha,
            gamma=self.gamma,
            reduction='none',
        )

def get_init_inputs():
    """Return the example constructor arguments as a positional list.

    Returns:
        list: ``[0.25, 2.0, 80]``. The values and their order match the
        defaults of ``Model.__init__`` (``alpha``, ``gamma``,
        ``num_classes``); presumably the caller unpacks them into ``Model``.
    """
    alpha, gamma, num_classes = 0.25, 2.0, 80
    return [alpha, gamma, num_classes]

def get_inputs():
    """Build one small random ``[logits, targets]`` example batch.

    Returns:
        list: ``[logits, targets]`` where ``logits`` is a ``(4, 80)`` tensor
        drawn with ``torch.randn`` and ``targets`` is a length-4 tensor of
        integers drawn with ``torch.randint`` from ``[0, 80)``.
    """
    batch_size, class_count = 4, 80
    # Logits are drawn before targets; keeping that order keeps the values
    # reproducible under a fixed seed.
    sample_logits = torch.randn(batch_size, class_count)
    sample_targets = torch.randint(low=0, high=class_count, size=(batch_size,))
    return [sample_logits, sample_targets]
