# SPDX-License-Identifier: MIT
from __future__ import annotations
from typing import Optional
import torch
import torch.nn.functional as F


def custom_loss_plus(
    pred_1: torch.Tensor,
    pred_2: torch.Tensor,
    label_1: torch.Tensor,
    label_2: torch.Tensor,
    embedding_1_flat: torch.Tensor,
    embedding_2_flat: torch.Tensor,
    alpha: float = 0.7,
    beta: Optional[float] = None,
    gamma: float = 0.05,
    fc_layer=None,
    eps: float = 1e-8,
    use_l1_for_similarity: bool = True,
):
    """Combine a regression loss, a similarity-preservation loss and an
    optional orthogonality penalty for a pair of inputs.

    Column 0 of each prediction is regressed against its label with a
    Smooth-L1 (Huber, ``beta=0.5``) loss. Columns ``1:`` are treated as an
    output embedding: the cosine similarity between the two output
    embeddings is pushed towards the cosine similarity between the two
    input embeddings. When ``fc_layer`` is given and ``gamma > 0``, the
    penalty ``mse(W @ W.T, I)`` on ``fc_layer.weight`` is added.

    Args:
        pred_1, pred_2: Predictions, indexed as ``pred[:, 0]`` and
            ``pred[:, 1:]`` (so at least 2-D).
        label_1, label_2: Targets for ``pred[:, 0]``. They are flattened
            first, so a ``(B, 1)`` column and a ``(B,)`` vector give the
            same loss.
        embedding_1_flat, embedding_2_flat: Input embeddings, L2-normalized
            along dim 1 before the dot product.
        alpha: Weight of the classification (Smooth-L1) term.
        beta: Weight of the similarity term; defaults to
            ``1 - alpha - gamma``.
        gamma: Weight of the orthogonality term.
        fc_layer: Optional module exposing a 2-D ``weight`` tensor.
        eps: Epsilon passed to ``F.normalize``.
        use_l1_for_similarity: If True the similarity term is a mean
            absolute error, otherwise a mean squared error.

    Returns:
        ``(total_loss, details)``: the weighted scalar loss tensor and a
        human-readable summary string of the individual terms.
    """
    if beta is None:
        beta = 1.0 - alpha - gamma

    # Flatten the labels: a (B, 1) target compared with the (B,) slice
    # pred[:, 0] would otherwise broadcast to (B, B) and give a wrong loss.
    target_1 = label_1.to(dtype=pred_1.dtype).reshape(-1)
    target_2 = label_2.to(dtype=pred_2.dtype).reshape(-1)
    classification_loss = 0.5 * (
        F.smooth_l1_loss(pred_1[:, 0], target_1, beta=0.5)
        + F.smooth_l1_loss(pred_2[:, 0], target_2, beta=0.5)
    )

    # Row-wise cosine similarity of the inputs and of the output embeddings.
    x1 = F.normalize(embedding_1_flat, dim=1, eps=eps)
    x2 = F.normalize(embedding_2_flat, dim=1, eps=eps)
    y1 = F.normalize(pred_1[:, 1:], dim=1, eps=eps)
    y2 = F.normalize(pred_2[:, 1:], dim=1, eps=eps)
    similarity_gap = (x1 * x2).sum(dim=1) - (y1 * y2).sum(dim=1)
    if use_l1_for_similarity:
        similarity_loss = similarity_gap.abs().mean()
    else:
        similarity_loss = similarity_gap.pow(2).mean()

    reg = torch.zeros((), device=pred_1.device, dtype=pred_1.dtype)
    if fc_layer is not None and gamma > 0.0:
        w = fc_layer.weight
        if pred_1.size(1) > 1:
            # NOTE(review): every row of W is regularized, including the
            # one that presumably produces pred[:, 0]; if only the
            # embedding rows were intended, use w[1:] here — confirm.
            gram = w @ w.T
            # Size the identity from the Gram matrix itself so a layer
            # whose output count differs from pred width cannot mismatch.
            identity = torch.eye(gram.size(0), device=w.device, dtype=w.dtype)
            reg = F.mse_loss(gram, identity)
        else:
            reg = torch.zeros((), device=w.device, dtype=w.dtype)

    total_loss = alpha * classification_loss + beta * similarity_loss + gamma * reg
    details = (
        f"classification_loss={classification_loss.item():.4f}\n"
        f" similarity_loss={similarity_loss.item():.4f}\n"
        f" reg={reg.item():.4f}"
    )
    return total_loss, details