# train/losses.py
"""
Loss computation utilities.
Implements MSE TD loss and optional migration penalty integration (lambda hyperparam).
"""
import torch
import torch.nn.functional as F

def td_mse_loss(q_pred, q_target, mask=None):
    """
    Mean-squared TD error between predicted and target Q-values.

    q_pred: tensor [batch, actions] of predicted Q-values
    q_target: tensor [batch, actions] of target Q-values
    mask: optional boolean mask selecting valid elements; when given, only
        the selected entries of (q_pred - q_target) contribute to the mean
    returns scalar loss; exactly 0 when the mask selects no elements
    """
    if mask is None:
        return F.mse_loss(q_pred, q_target)

    diff = (q_pred - q_target)[mask]
    if diff.numel() == 0:
        # Averaging over an empty selection yields NaN, which would poison
        # the gradients of the whole update. Summing the empty selection
        # gives an exact zero with the same dtype/device that remains
        # connected to the inputs, so backward() still works (zero grads).
        return diff.sum()
    return F.mse_loss(diff, torch.zeros_like(diff))

def apply_migration_penalty(base_reward, migration_count, lambda_mig):
    """
    Penalize a reward in proportion to the number of migrations.

    base_reward: reward before the penalty is applied
    migration_count: number of migrations; converted with float(), so it
        must be something float() accepts
    lambda_mig: penalty weight applied per migration
    returns base_reward - lambda_mig * migration_count
    """
    migrations = float(migration_count)
    penalty = lambda_mig * migrations
    return base_reward - penalty
