import torch
import lpips
from torch import nn
from pytorch_msssim import MS_SSIM, ssim
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torch.nn.modules.loss import _Loss, _WeightedLoss  # noqa
from torch.nn import BCELoss
from torch.nn import functional as F  # noqa
from torchvision.transforms import functional as FF

# Module-level defaults. Neither name is referenced in the visible part of
# this file; NOTE(review): confirm whether other code still relies on them.
# Presumably the default reduction mode name for losses in this module.
_default_reduction = 'mean'
# Presumably a small constant for numerical stability -- TODO confirm usage.
_epsilon = 1e-7

def non_negativity_loss(x, reduction='mean'):
    """Squared penalty on the negative entries of ``x``.

    Each element contributes ``max(-x, 0) ** 2``, so non-negative entries
    cost nothing and negative entries are penalised quadratically.

    Args:
        x: Input tensor. The original note gives its layout as (B, C, H, W);
            the computation itself is element-wise.
        reduction: ``'mean'`` averages the per-element penalties, ``'sum'``
            adds them up, and any other value returns the unreduced
            per-element penalty tensor.

    Returns:
        A scalar tensor for ``'mean'`` / ``'sum'``, otherwise a tensor with
        the same shape as ``x``.
    """
    # Only the negative part survives, flipped to a positive magnitude.
    negative_magnitude = (-x).clamp(min=0)
    squared_penalty = negative_magnitude.pow(2)

    if reduction == 'mean':
        return squared_penalty.mean()
    if reduction == 'sum':
        return squared_penalty.sum()
    return squared_penalty

class LossFunction(nn.Module):
    """Loss function class for multiple loss function."""

    def __init__(self):
        super().__init__()
        self.criterion_mse = nn.MSELoss()
        self.criterion_iqa = lpips.LPIPS(net='vgg')
        self.criterion_iqa_test = lpips.LPIPS(net='vgg')
    def forward(self, output, label, epoch=0, train=True, normalize=True):
        if train:
            IQA_loss = torch.mean(self.criterion_iqa(output, label, normalize=normalize))
        else:
            IQA_loss = torch.mean(self.criterion_iqa_test(output, label, normalize=normalize))
        mix_loss = self.criterion_mse(output, label)
        loss = mix_loss * 1 + IQA_loss * 0.05# + edge_loss * 5
        return loss


def normalize_lambdas(hparams: dict, anchor_key: str = "LAMBDA_IMG") -> dict:
    """Rescale the ``LAMBDA_*`` weights so they sum to 1 with the anchor fixed.

    The value stored under ``anchor_key`` is kept as-is. Every other key that
    starts with ``"LAMBDA_"`` is scaled proportionally so that together they
    fill the remaining ``1 - anchor`` budget. Keys without the ``"LAMBDA_"``
    prefix are not touched.

    Args:
        hparams: Hyper-parameter mapping; modified in place.
        anchor_key: Key whose weight must stay unchanged.

    Returns:
        The same ``hparams`` object, after in-place rescaling.

    Raises:
        KeyError: If ``anchor_key`` is not present in ``hparams``.
        ValueError: If the anchor weight is greater than 1.
    """
    other_keys = [
        k for k in hparams
        if k.startswith("LAMBDA_") and k != anchor_key
    ]
    anchor_val = hparams[anchor_key]

    total_other = sum(hparams[k] for k in other_keys)
    remain = 1.0 - anchor_val

    if remain < 0:
        raise ValueError(f"{anchor_key}={anchor_val} over.")

    # When the other weights sum to zero (e.g. all are disabled) there is no
    # proportion to preserve; leave them untouched instead of dividing by 0.
    if total_other == 0:
        return hparams

    for k in other_keys:
        hparams[k] = hparams[k] / total_other * remain

    return hparams
