
from typing import Any, Dict, Optional
import torch
import torch.nn as nn

class MultiObjectiveLoss(nn.Module):
    """Combine a (weighted) cross-entropy loss with auxiliary loss terms.

    The total loss computed by ``forward`` is::

        loss = CE(outputs["node_logits"], targets)
               + lambda_fid * extras["fid_loss"]
               + lambda_sp  * extras["sp_loss"]

    Auxiliary terms that are absent from ``extras`` (or set to ``None``)
    contribute zero.

    Args:
        lambda_fid: Weight applied to the ``"fid_loss"`` auxiliary term.
        lambda_sim: Stored on the module as ``self.lambda_sim``; not used by
            ``forward``.
        lambda_sp: Weight applied to the ``"sp_loss"`` auxiliary term.
        class_weights: Optional per-class weights forwarded to
            ``nn.CrossEntropyLoss(weight=...)`` to counter class imbalance.
    """

    def __init__(
        self,
        lambda_fid: float = 0.5,
        lambda_sim: float = 0.0,
        lambda_sp: float = 0.01,
        class_weights: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        # Weighted cross-entropy to handle class imbalance.
        self.ce = nn.CrossEntropyLoss(weight=class_weights)
        self.lambda_fid = lambda_fid
        # NOTE(review): lambda_sim is accepted and stored, but forward() adds no
        # similarity term. Confirm whether a "sim" loss was meant to be wired in.
        self.lambda_sim = lambda_sim
        self.lambda_sp = lambda_sp

    def forward(
        self,
        outputs: Dict[str, torch.Tensor],
        targets: torch.Tensor,
        extras: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute the combined loss and its individual components.

        Args:
            outputs: Model outputs; must contain ``"node_logits"``, which is
                passed to the cross-entropy criterion together with ``targets``.
            targets: Class targets in whatever form ``nn.CrossEntropyLoss``
                accepts for ``node_logits``.
            extras: Optional auxiliary losses keyed by ``"fid_loss"`` and
                ``"sp_loss"``. Missing keys, ``None`` values, or
                ``extras=None`` are treated as zero.

        Returns:
            Dict with ``"loss"`` (weighted total), ``"loss_cls"``,
            ``"loss_fid"`` and ``"loss_sp"``.
        """
        if extras is None:
            extras = {}
        node_logits = outputs["node_logits"]
        device = node_logits.device
        loss_cls = self.ce(node_logits, targets)
        loss_fid = self._extra_or_zero(extras, "fid_loss", device)
        loss_sp = self._extra_or_zero(extras, "sp_loss", device)
        total = loss_cls + self.lambda_fid * loss_fid + self.lambda_sp * loss_sp
        return {"loss": total, "loss_cls": loss_cls, "loss_fid": loss_fid, "loss_sp": loss_sp}

    @staticmethod
    def _extra_or_zero(
        extras: Dict[str, torch.Tensor], key: str, device: torch.device
    ) -> torch.Tensor:
        """Return ``extras[key]``, or a 0-dim zero tensor on ``device`` if missing/None."""
        value = extras.get(key)
        # Only allocate the fallback when it is actually needed; passing it as
        # the .get() default would build a tensor on every call.
        if value is None:
            return torch.zeros((), device=device)
        return value
