import torch
from torch import nn

from options import TaskConfig


class TaskLossBase(nn.Module):
    """Abstract interface shared by every task loss in this module.

    Subclasses override :meth:`forward` to turn a prediction/target pair
    into a loss tensor. The base implementation only raises.
    """

    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the loss for ``prediction`` against ``target``.

        Args:
            prediction: Model output to be scored.
            target: Ground truth the prediction is compared with.
            **kwargs: Extra keyword arguments; accepted so that all losses
                can be called with one uniform signature.

        Raises:
            NotImplementedError: Always, since the base class defines no loss.
        """
        raise NotImplementedError


class NLLTaskLoss(TaskLossBase):
    """Negative log-likelihood task loss.

    Thin wrapper around ``torch.nn.functional.nll_loss`` with its default
    settings, so ``prediction`` is expected to already hold
    log-probabilities.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Return the NLL loss of ``prediction`` with respect to ``target``.

        Args:
            prediction: Log-probabilities passed to ``nll_loss`` as input.
            target: Class indices passed to ``nll_loss`` as target.
            **kwargs: Ignored.

        Returns:
            The loss tensor produced by ``nll_loss`` with default arguments.
        """
        loss = nn.functional.nll_loss(input=prediction, target=target)
        return loss


class CELoss(TaskLossBase):
    """Cross-entropy task loss, summed rather than averaged.

    NOTE(review): the sibling NLL-based losses in this file use the default
    (mean) reduction while this one sums; presumably intentional, confirm
    against the training code before changing.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Return the summed cross-entropy of ``prediction`` vs ``target``.

        Args:
            prediction: Unnormalised scores (logits) given to cross-entropy.
            target: Targets given to cross-entropy.
            **kwargs: Ignored.

        Returns:
            The cross-entropy loss computed with ``reduction='sum'``.
        """
        return nn.functional.cross_entropy(
            prediction,
            target,
            reduction='sum',
        )


class StableCELoss(TaskLossBase):
    """Cross-entropy variant with an additive epsilon inside the logarithm.

    The loss is ``nll_loss(log(softmax(prediction, dim=1) + eps), target)``.
    Because the probabilities are shifted by ``eps`` before the log, the
    log never sees zero and each per-sample term is bounded above by
    ``-log(eps)`` (for ``eps > 0``). The shifted values are no longer
    normalised log-probabilities, so results differ slightly from a plain
    cross-entropy.

    Args:
        eps: Non-negative constant added to the softmax probabilities before
            the logarithm. Defaults to ``1e-3``.

    Raises:
        ValueError: If ``eps`` is negative.
    """

    def __init__(self, eps: float = 1e-3):
        super().__init__()
        if eps < 0:
            raise ValueError(f"eps must be non-negative, got {eps!r}")
        self.eps = eps

    def forward(self,
                prediction: torch.Tensor,
                target: torch.Tensor,
                **kwargs) -> torch.Tensor:
        """Return the epsilon-shifted NLL loss.

        Args:
            prediction: Unnormalised scores; softmax is taken over ``dim=1``.
            target: Class indices passed to ``nll_loss``.
            **kwargs: Ignored.

        Returns:
            The ``nll_loss`` (default mean reduction) of the shifted
            log-probabilities.
        """
        probabilities = nn.functional.softmax(prediction, dim=1)
        # Shift away from zero so the log stays finite for vanishing
        # probabilities.
        shifted_log_probs = torch.log(probabilities + self.eps)
        return nn.functional.nll_loss(shifted_log_probs, target)


class NoLoss(TaskLossBase):
    """Placeholder loss that always evaluates to zero.

    The returned value is a freshly created constant: it does not depend on
    ``prediction`` or ``target`` beyond taking ``prediction``'s device.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Return a scalar zero tensor on ``prediction``'s device.

        Args:
            prediction: Only its ``device`` attribute is read.
            target: Ignored.
            **kwargs: Ignored.

        Returns:
            A 0-dimensional tensor holding ``0.`` located on
            ``prediction.device``.
        """
        target_device = prediction.device
        return torch.zeros((), device=target_device)


def load_task_loss(task_cfg: TaskConfig) -> TaskLossBase:
    """Instantiate the task loss named by ``task_cfg.task_loss_type``.

    Args:
        task_cfg: Configuration object whose ``task_loss_type`` attribute is
            one of ``"NLLLoss"``, ``"CELoss"``, ``"StableCELoss"`` or
            ``"NoLoss"``.

    Returns:
        A new instance of the selected loss, built with no arguments.

    Raises:
        ValueError: If ``task_loss_type`` is not one of the known names.
    """
    losses = {
        "NLLLoss": NLLTaskLoss,
        "CELoss": CELoss,
        "StableCELoss": StableCELoss,
        "NoLoss": NoLoss,
    }

    loss_type = task_cfg.task_loss_type
    # Only the lookup is guarded: a KeyError raised while constructing the
    # loss must not be misreported as an unknown loss type.
    try:
        loss_cls = losses[loss_type]
    except KeyError:
        raise ValueError(
            f"Unknown Task Loss Type: {loss_type!r}; "
            f"expected one of {sorted(losses)}"
        ) from None
    return loss_cls()

