import torch
import torch.nn as nn
import torch.nn.functional as F


def jsd_loss(y1_log, y2_log):
    """Return the symmetrised KL divergence between two log-distributions.

    Both arguments are log-probabilities (``log_target=True`` is used for
    each call).  The result is the mean of the two directed KL terms,
    ``(KL(y2 || y1) + KL(y1 || y2)) / 2``, each reduced with
    ``'batchmean'`` (summed, then divided by the size of dimension 0).

    NOTE(review): despite the name, this is not the mixture-based
    Jensen-Shannon divergence; it averages the two directed KL terms.
    """
    kl_options = dict(reduction='batchmean', log_target=True)

    # F.kl_div(input, target) measures KL(target || input).
    kl_y2_from_y1 = F.kl_div(y1_log, y2_log, **kl_options)
    kl_y1_from_y2 = F.kl_div(y2_log, y1_log, **kl_options)

    return (kl_y2_from_y1 + kl_y1_from_y2) / 2


# Learnable Label Smoothing (LLS) loss module
class LearnableLabelSmoothing(nn.Module):
    """Label-smoothing loss whose smoothing distribution is learned per class.

    A ``K x K`` parameter ``q_matrix`` stores, for every class, the logits of
    the distribution that is mixed into the one-hot target.  The smoothed
    target is ``(1 - alpha) * one_hot(y) + alpha * softmax(q_matrix[y])`` and
    the loss is ``jsd_loss`` between the log of that target and the
    log-softmax of the model logits.

    Args:
        K: Number of classes (size of both axes of ``q_matrix``).
        alpha: Weight given to the learned distribution in the target.
    """

    def __init__(self, K, alpha=0.1):
        super().__init__()
        self.K = K
        self.alpha = alpha
        # Zero logits: every row is uniform after the softmax in forward().
        # nn.Parameter already defaults to requires_grad=True.
        self.q_matrix = nn.Parameter(torch.zeros(K, K))

    def forward(self, logits, y):
        """Compute the loss for a batch of predictions.

        Args:
            logits: Unnormalised class scores; the class axis is the last
                one (presumably of size ``K`` -- confirm against callers).
            y: Integer class indices in ``[0, K)``.  Any index shape is
                accepted as long as ``logits`` has the shape of ``y`` plus
                a trailing class axis.

        Returns:
            The value returned by ``jsd_loss`` for the smoothed targets and
            the predicted log-probabilities.
        """
        y_pred_log = F.log_softmax(logits, dim=-1)

        y1hot = F.one_hot(y, num_classes=self.K).float()

        # Turn the selected rows of q_matrix into probabilities.  Normalise
        # over the class axis (always the last one after indexing) rather
        # than a fixed dim=1, so 0-dim and multi-dimensional ``y`` are
        # handled correctly; for 1-D ``y`` this is identical to dim=1.
        neg_vals = F.softmax(self.q_matrix[y], dim=-1)

        # Mix the hard label with the learned distribution.
        y_tgt = (1 - self.alpha) * y1hot + self.alpha * neg_vals
        # The epsilon keeps the log finite when an entry of y_tgt is zero
        # (e.g. alpha == 0 leaves a pure one-hot target).
        y_tgt_log = torch.log(y_tgt + 1e-6)

        return jsd_loss(y_tgt_log, y_pred_log)
