import torch

from AbstractClass.AbstractMetaStructure import AbstractLossFunc
from torch import nn


class AutomaticWeightedLoss(nn.Module):
    """Automatically weighted multi-task loss.

    Every loss ``L_i`` passed to :meth:`forward` is paired with a weight
    ``p_i`` and contributes ``0.5 / p_i**2 * L_i + log(1 + p_i**2)`` to the
    returned sum.

    Args:
        num: Number of losses. Only used when ``weights`` is ``None``; the
            weights are then a learnable ``nn.Parameter`` initialised to ones.
        weights: Optional fixed weights, one per loss. When given, the weights
            are not trainable.
        device: Optional device on which the weight tensor is created.

    Example:
        awl = AutomaticWeightedLoss(2)
        loss_sum = awl(loss1, loss2)
    """

    def __init__(self, num=2, weights=None, device=None):
        # Initialise nn.Module before touching any attribute so that
        # parameter/buffer registration is fully set up.
        super(AutomaticWeightedLoss, self).__init__()
        self.weights = weights
        self.learnable = weights is None
        if self.learnable:
            # Honour `device` here too (it used to be ignored in this branch).
            self.params = torch.nn.Parameter(torch.ones(num, device=device))
        else:
            # as_tensor + clone copies the data without the warning that
            # torch.tensor() emits when `weights` is already a tensor.
            fixed = torch.as_tensor(weights, device=device).detach().clone()
            # A non-persistent buffer follows module.to()/.cuda()/.double()
            # but stays out of state_dict, so existing checkpoints still load.
            self.register_buffer("params", fixed, persistent=False)

    def forward(self, *x):
        """Return the weighted sum of the given losses.

        Args:
            *x: The individual losses, in the same order as the weights.

        Returns:
            The combined loss; the integer ``0`` when no loss is passed.

        Raises:
            IndexError: If more losses are passed than there are weights.
        """
        loss_sum = 0
        for i, loss in enumerate(x):
            squared = self.params[i] ** 2
            loss_sum += 0.5 / squared * loss + torch.log(1 + squared)
        return loss_sum


# Only for weight prediction.
class LossFunc_for_Simple_AAE_and_ARE(AbstractLossFunc):
    """Loss that merges ARE and AAE through an ``AutomaticWeightedLoss``.

    ARE is ``mean(abs((pred - y) / y))`` and AAE is ``mean(abs(pred - y))``.
    """

    def __init__(self):
        super().__init__()
        self.auto_weighted_loss = AutomaticWeightedLoss(num=2)
        # Created for parity with the sibling losses; forward() does not use it.
        self.mse_func = torch.nn.MSELoss()

    def forward(self, weight_pred, weight_y):
        """Return the automatically weighted sum of ARE and AAE."""
        residual = weight_pred - weight_y
        relative_err = (residual / weight_y).abs().mean()
        absolute_err = residual.abs().mean()
        return self.auto_weighted_loss(relative_err, absolute_err)


# Only for weight prediction.
class LossFunc_for_Simple_MSE_and_ARE(AbstractLossFunc):
    """Loss that merges ARE and MSE through an ``AutomaticWeightedLoss``.

    ARE is ``mean(abs((pred - y) / y))``; MSE comes from ``torch.nn.MSELoss``.
    """

    def __init__(self):
        super().__init__()
        self.auto_weighted_loss = AutomaticWeightedLoss(num=2)
        self.mse_func = torch.nn.MSELoss()

    def forward(self, weight_pred, weight_y):
        """Return the automatically weighted sum of ARE and MSE."""
        relative_err = ((weight_pred - weight_y) / weight_y).abs().mean()
        squared_err = self.mse_func(weight_pred, weight_y)
        return self.auto_weighted_loss(relative_err, squared_err)

class LossFunc_for_Simple_MSE_and_AAE(AbstractLossFunc):
    """Loss that merges AAE and MSE through an ``AutomaticWeightedLoss``.

    AAE is ``mean(abs(pred - y))``; MSE comes from ``torch.nn.MSELoss``.
    """

    def __init__(self):
        super().__init__()
        self.auto_weighted_loss = AutomaticWeightedLoss(num=2)
        self.mse_func = torch.nn.MSELoss()

    def forward(self, weight_pred, weight_y):
        """Print both components, then return their weighted sum."""
        absolute_err = (weight_pred - weight_y).abs().mean()
        squared_err = self.mse_func(weight_pred, weight_y)
        # Progress output: each .item() call converts a scalar tensor to a float.
        print('loss AAE:', absolute_err.item(), 'mse', squared_err.item())
        return self.auto_weighted_loss(absolute_err, squared_err)


class LossFunc_for_OutWeight_AAE_and_ARE(AbstractLossFunc):
    """Loss that merges smoothed ARE and AAE via ``AutomaticWeightedLoss``.

    ARE is ``mean(abs((pred - y) / (y + 0.5)))`` and AAE is
    ``mean(abs(pred - y))``.
    """

    def __init__(self):
        super().__init__()
        self.auto_weighted_loss = AutomaticWeightedLoss(num=2)
        # Created for parity with the sibling losses; forward() does not use it.
        self.mse_func = torch.nn.MSELoss()

    def forward(self, weight_pred, weight_y):
        """Print both components, then return their weighted sum."""
        residual = weight_pred - weight_y
        # The 0.5 offset in the denominator smooths the relative error.
        relative_err = (residual / (weight_y + 0.5)).abs().mean()
        absolute_err = residual.abs().mean()
        print('loss ARE:', relative_err.item(), 'aae', absolute_err.item())
        return self.auto_weighted_loss(relative_err, absolute_err)

# The ARE metric is smoothed: 0.5 is added to the denominator.
class LossFunc_for_OutWeight_MSE_and_ARE(AbstractLossFunc):
    """Loss that merges smoothed ARE and MSE via ``AutomaticWeightedLoss``.

    ARE is ``mean(abs((pred - y) / (y + 0.5)))``; MSE comes from
    ``torch.nn.MSELoss``.
    """

    def __init__(self):
        super().__init__()
        self.auto_weighted_loss = AutomaticWeightedLoss(num=2)
        self.mse_func = torch.nn.MSELoss()

    def forward(self, weight_pred, weight_y):
        """Print both components, then return their weighted sum."""
        # The 0.5 offset in the denominator smooths the relative error.
        smoothed_target = weight_y + 0.5
        relative_err = ((weight_pred - weight_y) / smoothed_target).abs().mean()
        squared_err = self.mse_func(weight_pred, weight_y)
        print('loss ARE:', relative_err.item(), 'mse', squared_err.item())
        return self.auto_weighted_loss(relative_err, squared_err)

class LossFunc_for_OutWeight_MSE_and_ARE_and_AAE(AbstractLossFunc):
    """Loss that merges smoothed ARE, MSE and AAE via ``AutomaticWeightedLoss``.

    ARE is ``mean(abs((pred - y) / (y + 0.5)))``, MSE comes from
    ``torch.nn.MSELoss`` and AAE is ``mean(abs(pred - y))``.
    """

    def __init__(self):
        super().__init__()
        # Three weights: one each for ARE, MSE and AAE.
        self.auto_weighted_loss = AutomaticWeightedLoss(num=3)
        self.mse_func = torch.nn.MSELoss()

    def forward(self, weight_pred, weight_y):
        """Print the three components, then return their weighted sum."""
        residual = weight_pred - weight_y
        # The 0.5 offset in the denominator smooths the relative error.
        relative_err = (residual / (weight_y + 0.5)).abs().mean()
        squared_err = self.mse_func(weight_pred, weight_y)
        absolute_err = residual.abs().mean()
        print(
            'loss ARE:', relative_err.item(),
            'mse', squared_err.item(),
            'aae', absolute_err.item(),
        )
        return self.auto_weighted_loss(relative_err, squared_err, absolute_err)

class LossFunc_exist_for_BCE(AbstractLossFunc):
    """Thin wrapper that applies ``torch.nn.BCELoss`` to a prediction."""

    def __init__(self):
        super().__init__()
        self.bce_func = torch.nn.BCELoss()

    def forward(self, pred, y):
        """Return the binary cross-entropy between ``pred`` and ``y``."""
        return self.bce_func(pred, y)
