import torch
import torch.nn as nn
import numpy as np

def loss_zero_check(loss):
    """
    Check whether a loss value is numerically zero
    # Arguments
        loss: scalar loss value (e.g. the result of `loss.item()`)
    # Return
        True if `-eps <= loss <= +eps`, where `eps` is the machine epsilon of
        Python `float`; False otherwise (including when `loss` is NaN)
    """
    tolerance = np.finfo(float).eps
    return bool(-tolerance <= loss <= tolerance)

def loss_rescaling_factor(first_loss, second_loss):
    """
    Calculate loss rescaling factor

    The factor is the largest power of ten that does not exceed the ratio
    `|first_loss| / |second_loss|`, i.e. `10 ** floor(log10(ratio))`, so
    `|second_loss| * factor` lies in `(|first_loss| / 10, |first_loss|]`.

    # Arguments
        first_loss: first loss term, the one that has the aimed scale
            (e.g. `loss_autoencoder.item()`)
        second_loss: second loss term (e.g. `loss_som.item()`)
    # Return
        The power of ten described above, or 0.0 when either loss is within
        machine epsilon of zero
    """

    if loss_zero_check(first_loss) or loss_zero_check(second_loss):
        return 0.0
    # The zero check above already rules out division by zero, so no additive
    # epsilon is needed in the denominator. Adding one biased the ratio
    # downwards and could drop the result by a whole decade (e.g. when the true
    # ratio is an exact power of ten, or when second_loss is only a few eps).
    # Magnitudes are used so a negative loss gives a finite scale instead of
    # the NaN that log10 returns for a negative argument.
    ratio = abs(first_loss) / abs(second_loss)
    return np.power(10, np.floor(np.log10(ratio)))


class WeightedMSELoss(nn.Module):
    r"""Creates a criterion that measures the Weighted MSE Loss
        between the target and the output.

        The element-wise loss is ``weights * (input - target) ** 2``, where
        ``weights`` is combined with the squared error by ordinary broadcasting.

        Args:
            reduction (string, optional): Specifies the reduction to apply to the output:
                ``'none'``: no reduction will be applied,
                ``'mean'``: the sum of the output will be divided by the number of
                elements in the output,
                ``'sum'``: the output will be summed.

                 Default: ``'mean'``

        Raises:
            ValueError: if ``reduction`` is not one of the values above.

        Shape:
            - Input: :math:`(N, *)` where :math:`*` means, any number of additional dimensions
            - Target: :math:`(N, *)`, same shape as the input
            - Weights: broadcastable with ``input - target``. Broadcasting aligns
              trailing dimensions, so a 1-D weight vector is matched against the
              *last* dimension of the input, not against :math:`N`; per-sample
              weights must be given as :math:`(N, 1, ...)`.
            - Output: a scalar for ``'mean'`` / ``'sum'``; the broadcast shape of
              ``weights`` and ``input - target`` for ``'none'``.

        If the input has no elements, a zero scalar (``requires_grad=True``, on
        ``weights.device``) is returned whatever the reduction; this avoids the
        NaN that ``torch.mean`` produces for an empty tensor.
        """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, reduction='mean'):
        super(WeightedMSELoss, self).__init__()

        # Previously any unrecognised value silently behaved like 'sum'.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"{reduction!r} is not a valid value for reduction; "
                f"expected one of {self._REDUCTIONS}")
        self.reduction = reduction

    def forward(self, input, target, weights):
        # ``numel`` works for 0-d tensors (where ``len`` raises TypeError) and
        # also treats inputs such as (N, 0) as empty.
        if input.numel() == 0:
            return torch.tensor(0., dtype=torch.float, requires_grad=True, device=weights.device)

        ret = weights * (input - target) ** 2

        if self.reduction == 'mean':
            ret = torch.mean(ret)
        elif self.reduction == 'sum':
            ret = torch.sum(ret)

        return ret


class ReconstructionLoss(nn.Module):
    """ Computes the reconstruction loss function over specified input-output pairs.

    The stored ``inputs`` are passed through a model and
    ``loss_fn(reconstruction, target)`` is evaluated. The result is returned as
    a Python float, so this object serves as an evaluation metric and cannot be
    back-propagated through.

    Args:
        loss_fn: callable ``loss_fn(prediction, target)`` returning a
            single-element tensor.
        inputs: tensor passed to ``model.forward``.
        target: tensor the model's reconstruction is compared against.
    """

    _MODEL_TYPES = ("combined", "autoencoder")

    def __init__(self, loss_fn, inputs: torch.Tensor, target: torch.Tensor):
        super().__init__()
        self.loss_fn = loss_fn
        self.inputs = inputs
        self.target = target

    def __call__(self, model, model_type="combined"):
        """Evaluate the reconstruction loss of ``model`` on the stored data.

        Args:
            model: module whose ``forward`` is called on ``self.inputs``.
            model_type: ``"combined"`` when ``forward`` returns a sequence whose
                first element is the reconstruction; ``"autoencoder"`` when it
                returns the reconstruction directly.

        Returns:
            The loss as a Python float.

        Raises:
            ValueError: if ``model_type`` is not one of the values above
                (previously this silently returned 0.0).
        """
        if model_type not in self._MODEL_TYPES:
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected one of {self._MODEL_TYPES}")

        # The value is returned via ``.item()`` and never back-propagated, so
        # there is no need to record an autograd graph for this forward pass.
        with torch.no_grad():
            output = model.forward(self.inputs)
            if model_type == "combined":  # Necessary when model returns more than one value
                output = output[0]
            return self.loss_fn(output, self.target).item()


class SOMLoss(nn.Module):
    """ Computes the som loss function over specified input-output pairs.

    The result is returned as a Python float, so this object serves as an
    evaluation metric and cannot be back-propagated through.

    Args:
        loss_fn: callable ``loss_fn(a, b, weights)`` returning a single-element
            tensor.
        som_input_size: row width used to reshape the SOM node weights with
            ``view(-1, som_input_size)``.
        inputs: tensor passed to ``model.forward``.
        target: loss target for ``"autoencoder"`` models; also passed as a
            second ``forward`` argument for ``"combined"`` models when ``semi``
            is true.
        weights: loss weights used for ``"autoencoder"`` models.
        semi: whether a ``"combined"`` model's ``forward`` also receives
            ``target``.
    """

    _MODEL_TYPES = ("combined", "autoencoder")

    def __init__(self, loss_fn, som_input_size,
                 inputs: torch.Tensor, target: torch.Tensor, weights: torch.Tensor, semi=False):

        super().__init__()
        self.loss_fn = loss_fn
        self.inputs = inputs
        self.target = target
        self.weights = weights
        self.som_input_size = som_input_size
        self.semi = semi

    def __call__(self, model, model_type="combined"):
        """Evaluate the SOM loss of ``model`` on the stored data.

        Args:
            model: module whose ``forward`` is called on ``self.inputs``
                (plus ``self.target`` for combined models when ``semi`` is true).
            model_type: ``"combined"`` when ``forward`` returns
                ``(encoded, decoded, som_output)`` with ``som_output`` a
                5-element sequence whose first three items are the samples, the
                node weights and the relevances; ``"autoencoder"`` when
                ``forward`` returns a single tensor that is scored against
                ``self.target`` with ``self.weights``.

        Returns:
            The loss as a Python float.

        Raises:
            ValueError: if ``model_type`` is not one of the values above
                (previously this silently returned 0.0).
        """
        if model_type not in self._MODEL_TYPES:
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected one of {self._MODEL_TYPES}")

        # The value is returned via ``.item()`` and never back-propagated, so
        # there is no need to record an autograd graph for this forward pass.
        with torch.no_grad():
            if model_type == "autoencoder":
                return self.loss_fn(model.forward(self.inputs), self.target, self.weights).item()

            # Combined model: forward returns (encoded, decoded, som_output);
            # only the SOM output is needed here.
            if self.semi:
                _, _, som_output = model.forward(self.inputs, self.target)
            else:
                _, _, som_output = model.forward(self.inputs)

            samples_high_at, weights_unique_nodes_high_at, relevances, _, _ = som_output

            # Reshape to rows of width som_input_size (presumably one row per
            # SOM node — confirm against the model).
            weights_unique_nodes_high_at = weights_unique_nodes_high_at.view(-1, self.som_input_size)

            return self.loss_fn(samples_high_at, weights_unique_nodes_high_at, relevances).item()