import torch
import torch.nn as nn
import torch.nn.functional as F


class ListNetLoss(nn.Module):
    """ListNet top-one cross-entropy loss.

    Both the target scores and the predicted scores are turned into
    probability distributions over the last dimension with a softmax; the
    loss is the cross-entropy ``-sum(P_true * log P_pred)`` over that
    dimension, averaged over all remaining (leading) elements.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the mean ListNet loss.

        Args:
            y_pred: Predicted scores; the softmax is taken over the last
                dimension.
            y_true: Target scores or relevance labels, combined elementwise
                with ``y_pred`` (so presumably the same / a broadcastable
                shape). Integer labels are accepted.

        Returns:
            A scalar tensor with the cross-entropy averaged over all lists.
        """
        # Cast like RankCosineLoss does: softmax is not implemented for
        # integral dtypes, so integer relevance labels would otherwise fail.
        y_pred = y_pred.float()
        y_true = y_true.float()
        p_true = F.softmax(y_true, dim=-1)
        # log_softmax uses the log-sum-exp trick, so it stays exact for large
        # score gaps instead of saturating at log(eps) as log(softmax + eps)
        # does (which also zeroes the gradient for those items).
        log_p_pred = F.log_softmax(y_pred, dim=-1)
        # Apply the 0 * log(q) == 0 convention: entries carrying no target
        # mass (e.g. padding masked with -inf in both tensors) contribute
        # nothing instead of producing 0 * -inf = NaN.
        log_p_pred = torch.where(p_true > 0, log_p_pred, torch.zeros_like(log_p_pred))
        loss = -torch.sum(p_true * log_p_pred, dim=-1)
        return loss.mean()


class RankCosineLoss(nn.Module):
    """Cosine-distance loss between mean-centered score lists.

    Both inputs are centered along the last dimension before the cosine
    similarity is taken, so the similarity is the Pearson correlation between
    predicted and target scores. The loss is ``1 - similarity`` averaged over
    all leading elements: approximately 0 for perfectly correlated lists,
    2 for perfectly anti-correlated ones, and 1 when either list is constant
    (the ``1e-8`` term keeps the 0/0 case at similarity 0).
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the mean centered-cosine distance.

        Args:
            y_pred: Predicted scores; statistics are taken over the last
                dimension.
            y_true: Target scores, combined elementwise with ``y_pred``
                (presumably the same / a broadcastable shape).

        Returns:
            A scalar tensor with ``1 - cosine_similarity`` averaged over all
            lists.
        """
        y_pred = y_pred.float()
        y_true = y_true.float()
        pred_centered = y_pred - y_pred.mean(dim=-1, keepdim=True)
        true_centered = y_true - y_true.mean(dim=-1, keepdim=True)
        numerator = torch.sum(pred_centered * true_centered, dim=-1)
        # vector_norm computes the same value as sqrt(sum(x**2)), but its
        # backward defines the gradient at the zero vector as 0. With sqrt,
        # constant predictions (single-item lists, identical outputs) hit
        # 1 / sqrt(0) in backward and produce NaN gradients.
        pred_norm = torch.linalg.vector_norm(pred_centered, dim=-1)
        true_norm = torch.linalg.vector_norm(true_centered, dim=-1)
        cosine_similarities = numerator / (pred_norm * true_norm + 1e-8)
        return torch.mean(1 - cosine_similarities)


def build_loss(loss_type: str) -> nn.Module:
    """Construct a loss module from its case-insensitive name.

    Args:
        loss_type: ``"mse"``, ``"listnet"`` or ``"rankcosine"``; matched
            after lower-casing.

    Returns:
        A newly constructed ``nn.Module`` for the requested loss.

    Raises:
        ValueError: If the (lower-cased) name is not one of the above.
    """
    key = loss_type.lower()
    # Map names to zero-argument factories so every call builds a fresh
    # module and a class is only referenced once it is actually selected.
    factories = {
        "mse": lambda: nn.MSELoss(reduction="mean"),
        "listnet": lambda: ListNetLoss(),
        "rankcosine": lambda: RankCosineLoss(),
    }
    factory = factories.get(key)
    if factory is None:
        raise ValueError(f"unsupported loss_type: {key}")
    return factory()
