import torch
import torch.nn as nn
import torch.nn.functional as F
from .builder import LOSSES
import math


@LOSSES.register_module()
class ConcertoCELoss(nn.Module):
    """Cross-entropy between ``target`` weights and softmaxed ``pred`` logits.

    The loss is ``-sum(target * log_softmax(pred))`` taken over the last
    dimension. No reduction is applied to the remaining dimensions, so the
    caller receives one value per leading index.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the un-reduced cross-entropy.

        Args:
            pred: Logits; normalised with ``log_softmax`` over the last dim.
            target: Weights multiplied element-wise with the log-probabilities
                (presumably a soft-label distribution -- confirm with caller).

        Returns:
            Tensor with the last dimension summed out.
        """
        log_probs = F.log_softmax(pred, dim=-1)
        weighted = target * log_probs
        return -weighted.sum(dim=-1)


@LOSSES.register_module()
class ConcertoCosLoss(nn.Module):
    """Cosine-similarity loss with a learnable temperature and bias.

    Computes ``(1 - (cos(target, pred) * exp(t_prime) + b)) * loss_scale``
    where the cosine similarity is taken over the last dimension. Both
    ``t_prime`` and ``b`` are scalar ``nn.Parameter`` objects.

    Args:
        loss_scale: Constant multiplier applied to the final loss.
        bias_init: Initial value of the learnable bias ``b``.
        t_init: Initial value of ``t_prime``; the effective temperature is
            ``exp(t_prime)``.
    """

    def __init__(self, loss_scale=10, bias_init=0, t_init=0):
        super().__init__()
        self.t_prime = nn.Parameter(torch.tensor(t_init, dtype=torch.float32))
        self.b = nn.Parameter(torch.tensor(bias_init, dtype=torch.float32))
        self.cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
        self.loss_scale = loss_scale

    def forward(self, pred, target):
        """Return the scaled cosine loss without reducing leading dims.

        Args:
            pred: Predicted features; compared along the last dimension.
            target: Reference features broadcast-compatible with ``pred``.

        Returns:
            Tensor with the last dimension reduced by the cosine similarity.
        """
        # Exponentiating keeps the effective temperature strictly positive.
        temperature = torch.exp(self.t_prime)
        adjusted_similarity = self.cos(target, pred) * temperature + self.b
        return (1 - adjusted_similarity) * self.loss_scale


@LOSSES.register_module()
class ConcertoSigLoss(nn.Module):
    """Pairwise log-sigmoid loss with a learnable temperature and bias.

    Pairwise logits are ``target @ pred.T``, rescaled to
    ``logits * exp(t_prime) + b``. The loss for each row is the negative sum
    of ``logsigmoid`` over that row's logits. No label matrix is applied, so
    every pair contributes as a positive.

    Args:
        bias_init: Initial value of the learnable bias ``b``.
        t_init: Initial value of ``t_prime``; the effective temperature is
            ``exp(t_prime)``.
    """

    def __init__(self, bias_init=0, t_init=0):
        super().__init__()
        self.t_prime = nn.Parameter(torch.tensor(t_init, dtype=torch.float32))
        self.b = nn.Parameter(torch.tensor(bias_init, dtype=torch.float32))

    def forward(self, pred, target):
        """Return the per-row loss.

        Args:
            pred: Predicted features (assumed 2-D ``[n, d]`` -- confirm with
                caller; ``.T`` is only well defined for 2-D tensors).
            target: Reference features with a matching feature dimension.

        Returns:
            Tensor with the last dimension of the pairwise logits summed out.
        """
        pairwise = torch.matmul(target, pred.T)  # [n, n] for 2-D inputs
        # Exponentiating keeps the effective temperature strictly positive.
        temperature = self.t_prime.exp()
        scaled = pairwise * temperature + self.b
        return -F.logsigmoid(scaled).sum(dim=-1)


@LOSSES.register_module()
class ConcertoSigLearnableLoss(nn.Module):
    """Pairwise sigmoid loss with diagonal positives and learnable scale/bias.

    Pairwise logits ``target @ pred.T`` are rescaled to
    ``logits * exp(t_prime) + b``. Diagonal entries are labelled ``+1`` and
    all other entries ``-1``; the loss is the negative sum of
    ``logsigmoid(labels * logits)`` over the whole matrix.

    Args:
        bias_init: Initial value of the learnable bias ``b``.
        t_init: Initial value of ``t_prime``; the effective temperature is
            ``exp(t_prime)``.
    """

    def __init__(self, bias_init=-10, t_init=math.log(10)):
        super().__init__()
        # nn.Parameter registers both scalars as trainable by default.
        self.t_prime = nn.Parameter(torch.tensor(t_init, dtype=torch.float32))
        self.b = nn.Parameter(torch.tensor(bias_init, dtype=torch.float32))

    def forward(self, pred, target):
        """Return the summed pairwise loss as a scalar tensor.

        Args:
            pred: Predicted features (assumed 2-D ``[n, d]`` -- confirm with
                caller).
            target: Reference features; must give a square ``[n, n]`` logit
                matrix so the identity-based labels line up.

        Returns:
            0-dim tensor holding the total loss over all pairs.
        """
        pairwise = torch.matmul(target, pred.T)  # [n, n]
        # Exponentiating keeps the effective temperature strictly positive.
        temperature = self.t_prime.exp()
        scaled = pairwise * temperature + self.b  # [n, n]

        # +1 on the diagonal (matched pairs), -1 everywhere else.
        identity = torch.eye(
            scaled.size(0), device=scaled.device, dtype=scaled.dtype
        )
        signs = identity * 2 - 1
        return -F.logsigmoid(signs * scaled).sum()
