import torch
import torch.nn.functional as F

class LogLoss:
    """Negative log-likelihood of (head, relation, tail) triples.

    Each component of a triple is scored under its own categorical
    distribution. The triple's log-probability is the sum of the three
    per-component log-probabilities. Optional label smoothing is applied to
    the targets.

    Args:
        prediction_smoothing: ``eps`` used to clamp predicted probabilities into
            ``[eps, 1 - eps]`` before taking the log, so ``log(0)`` is avoided.
        label_smoothing: Probability mass taken from the true class and spread
            uniformly over all classes; ``0`` yields plain one-hot targets.
        number_of_entities: Number of classes for head and tail predictions.
        number_of_relations: Number of classes for relation predictions.
    """

    def __init__(self,
                 prediction_smoothing: float,
                 label_smoothing: float,
                 number_of_entities: int,
                 number_of_relations: int):

        self.prediction_smoothing = prediction_smoothing
        self.label_smoothing = label_smoothing
        self.number_of_entities = number_of_entities
        self.number_of_relations = number_of_relations

    def log_loss(self,
                 predictions: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
                 labels: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Return the summed negative log-probability of a batch of triples.

        Args:
            predictions: ``(head_predictions, relation_predictions,
                tail_predictions)``. Each holds class probabilities along its
                last dimension.
            labels: ``(h, r, t)`` integer class indices.

        Returns:
            Scalar tensor ``-sum(log p(h) + log p(r) + log p(t))`` over all
            labels in the batch.
        """
        h, r, t = labels
        head_predictions, relation_predictions, tail_predictions = predictions

        log_prob_head = self.log_categorical(
            h, head_predictions, self.number_of_entities)
        log_prob_relation = self.log_categorical(
            r, relation_predictions, self.number_of_relations)
        log_prob_tail = self.log_categorical(
            t, tail_predictions, self.number_of_entities)

        # The three components are scored independently, so the triple's
        # log-probability is simply the sum of the component log-probabilities.
        log_prob_triple = log_prob_head + log_prob_relation + log_prob_tail

        return -log_prob_triple.sum()

    def log_categorical(self, x: torch.Tensor, p: torch.Tensor, num_classes: int) -> torch.Tensor:
        """Return ``sum_k target_k * log(p_k)`` for every label in ``x``.

        Adapted from J. Tomczak (DeepGenerativeModelling), with label smoothing
        added. The target is the one-hot encoding of ``x``, smoothed as
        ``(1 - label_smoothing) * one_hot + label_smoothing / num_classes``,
        so it still sums to 1 over the class dimension.

        Args:
            x: Integer class indices.
            p: Class probabilities. They must broadcast against the one-hot
                tensor of shape ``(*x.shape, num_classes)``.
            num_classes: Number of classes (``-1`` lets ``F.one_hot`` infer it).

        Returns:
            Per-label log-probabilities, reduced over the class (last) dimension.
        """
        x_one_hot: torch.Tensor = F.one_hot(x.long(), num_classes=num_classes)
        k = x_one_hot.shape[-1]

        # Spread only `label_smoothing` mass uniformly (not a full 1/K), so the
        # smoothed targets remain a probability distribution and reduce to the
        # exact one-hot vector when label_smoothing == 0.
        targets = (1 - self.label_smoothing) * x_one_hot + self.label_smoothing / k

        eps = self.prediction_smoothing
        # Clamp keeps log() finite for probabilities at exactly 0 or 1.
        log_p: torch.Tensor = targets * torch.log(torch.clamp(p, eps, 1. - eps))

        return torch.sum(log_p, -1)
