import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class ExtracterLoss(nn.Module):
    """Pixel-space MSE plus a frozen VGG16 perceptual loss.

    ``forward`` returns ``MSE(extracted_glare, GT) + vgg_weight * vgg_loss``.
    Here ``vgg_loss`` sums the MSE between VGG16 feature maps of the two
    inputs, taken at two depths: after ``features[:10]`` and after the
    remainder of ``features``.

    The VGG16 backbone is ImageNet-pretrained, frozen
    (``requires_grad = False``) and in eval mode. Gradients still flow
    through it back to the inputs.

    NOTE(review): inputs are fed to VGG16 as-is, with no ImageNet mean/std
    normalization. Confirm the callers' value range matches what the
    pretrained weights expect.
    """

    def __init__(self, vgg_weight: float = 0.125):
        """Build the frozen VGG16 feature extractor.

        Args:
            vgg_weight: Multiplier applied to the perceptual (VGG) term in
                ``forward``. Defaults to 0.125, the previously hard-coded
                value.
        """
        super().__init__()

        # Prefer the weights-enum API (torchvision >= 0.13), where the
        # boolean ``pretrained`` flag is deprecated. IMAGENET1K_V1 is the
        # checkpoint that ``pretrained=True`` loads.
        try:
            weights = models.VGG16_Weights.IMAGENET1K_V1
        except AttributeError:
            # Older torchvision releases have no weights enums.
            self.vgg16 = models.vgg16(pretrained=True)
        else:
            self.vgg16 = models.vgg16(weights=weights)

        # Freeze the backbone: it is a fixed feature extractor, not trained.
        for param in self.vgg16.parameters():
            param.requires_grad = False
        self.vgg16.eval()

        # Split the convolutional trunk into a shallow and a deep stage.
        # This lets vgg_loss compare features at both depths while running
        # the trunk only once per input. With torchvision's standard VGG16
        # layout, index 10 falls just after the second max-pool.
        self.vgg_2 = self.vgg16.features[:10]
        self.vgg_5 = self.vgg16.features[10:]
        self.mseloss = nn.MSELoss()
        self.vgg_weight = vgg_weight

    def pixel_loss(self, img1, img2):
        """Return the mean squared error between ``img1`` and ``img2``."""
        return self.mseloss(img1, img2)

    def vgg_loss(self, img1, img2):
        """Return the summed feature-space MSE at the shallow and deep VGG stages.

        ``img1`` and ``img2`` are presumably (N, 3, H, W) image batches of
        equal shape. VGG16's first convolution takes 3 channels; confirm
        against callers.
        """
        shallow1 = self.vgg_2(img1)
        shallow2 = self.vgg_2(img2)
        # The deep stage continues from the shallow features rather than
        # re-running the full trunk on the raw images.
        deep1 = self.vgg_5(shallow1)
        deep2 = self.vgg_5(shallow2)
        shallow_loss = F.mse_loss(shallow1, shallow2)
        deep_loss = F.mse_loss(deep1, deep2)
        return shallow_loss + deep_loss

    def forward(self, extracted_glare, GT):
        """Return ``pixel_loss + vgg_weight * vgg_loss`` for the given pair.

        Args:
            extracted_glare: Predicted tensor.
            GT: Target tensor with the same shape as ``extracted_glare``.
        """
        # NOTE: an earlier revision tiled ``extracted_glare`` to 3 channels
        # here (torch.cat along dim=1). That is no longer done, so inputs
        # presumably arrive 3-channel already.
        pixel_loss_value = self.pixel_loss(extracted_glare, GT)
        vgg_loss_value = self.vgg_loss(extracted_glare, GT) * self.vgg_weight
        return pixel_loss_value + vgg_loss_value