import os
import torch
import torch.nn as nn
import torchvision.models as models

class L_Color(nn.Module):
    """Colour-constancy loss computed on per-image channel means.

    The spatial mean of each of the three channels (named R, G, B below) is
    taken per image. The per-image penalty is
    ``sqrt(((R-G)^2)^2 + ((R-B)^2)^2 + ((G-B)^2)^2)``, and the result is the
    mean of that penalty over the batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the scalar colour-constancy loss for a 4-D tensor ``x``.

        Raises:
            ValueError: if ``x`` is not 4-D or does not have exactly three
                channels along dim 1 (both come from tuple unpacking).
        """
        # Unpacking the shape acts as a rank check: any non-4-D input raises
        # ValueError. The sizes themselves are not needed.
        _, _, _, _ = x.shape

        # Average over dims 2 and 3, keeping them as size-1 dims.
        channel_means = x.mean(dim=(2, 3), keepdim=True)
        # split(1) yields one slice per channel; this unpack requires exactly 3.
        red, green, blue = channel_means.split(1, dim=1)

        d_rg = (red - green) ** 2
        d_rb = (red - blue) ** 2
        d_gb = (blue - green) ** 2
        per_image = (d_rg ** 2 + d_rb ** 2 + d_gb ** 2) ** 0.5

        return per_image.mean()

class L_TV(nn.Module):
    
    def __init__(self,TVLoss_weight=1):
        super(L_TV,self).__init__()

    def forward(self,x):
        h_x = x.size(2)
        w_x = x.size(3)
        h_tv = torch.mean(torch.abs(x[:,:,1:,:]-x[:,:,:h_x-1,:]))
        w_tv = torch.mean(torch.abs(x[:,:,:,1:]-x[:,:,:,:w_x-1]))
        return h_tv + w_tv

class GanLoss():
    """Generator-side adversarial loss.

    Scores discriminator outputs on generated samples against an all-ones
    ("real") target.

    Args:
        gan_type: ``'vanilla'`` (binary cross-entropy on raw logits) or
            ``'lsgan'`` (mean squared error against 1). Any other value is
            accepted here, but calling the instance then raises.
    """

    def __init__(self, gan_type):
        self.gan_type = gan_type
        if self.gan_type == 'vanilla':
            # BCEWithLogitsLoss fuses sigmoid + BCE using the log-sum-exp
            # trick. It stays accurate and keeps non-zero gradients for
            # saturated logits, whereas sigmoid followed by BCELoss clamps
            # the log term at -100.
            self.criterion = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.criterion = nn.MSELoss()

    def __call__(self, fake_valid):
        """Return the scalar generator loss for discriminator outputs ``fake_valid``.

        For ``'vanilla'``, ``fake_valid`` is treated as raw logits (no sigmoid).

        Raises:
            ValueError: if ``gan_type`` is neither ``'vanilla'`` nor ``'lsgan'``.
        """
        if self.gan_type not in ('vanilla', 'lsgan'):
            raise ValueError('no such type of gan')
        # ones_like already matches fake_valid's device and dtype. A forced
        # .cuda() failed on CPU and moved the target to the wrong GPU when
        # fake_valid was not on the default device.
        target = torch.ones_like(fake_valid)
        return self.criterion(fake_valid, target)

class DLoss():
    """Discriminator-side adversarial loss.

    Scores discriminator outputs on generated samples against 0 and on real
    samples against 1, then averages the two terms.

    Args:
        gan_type: ``'vanilla'`` (binary cross-entropy on raw logits) or
            ``'lsgan'`` (mean squared error). Any other value is accepted
            here, but calling the instance then raises.
    """

    def __init__(self, gan_type):
        self.gan_type = gan_type
        if self.gan_type == 'vanilla':
            # BCEWithLogitsLoss fuses sigmoid + BCE using the log-sum-exp
            # trick. It stays accurate and keeps non-zero gradients for
            # saturated logits, whereas sigmoid followed by BCELoss clamps
            # the log term at -100.
            self.criterion = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.criterion = nn.MSELoss()

    def __call__(self, fake_valid, real_valid):
        """Return ``(loss(fake_valid, 0) + loss(real_valid, 1)) / 2``.

        For ``'vanilla'``, both inputs are treated as raw logits (no sigmoid).

        Raises:
            ValueError: if ``gan_type`` is neither ``'vanilla'`` nor ``'lsgan'``.
        """
        if self.gan_type not in ('vanilla', 'lsgan'):
            raise ValueError('no such type of gan')
        # zeros_like / ones_like already match each input's device and
        # dtype. A forced .cuda() failed on CPU and on non-default GPUs.
        fake_loss = self.criterion(fake_valid, torch.zeros_like(fake_valid))
        real_loss = self.criterion(real_valid, torch.ones_like(real_valid))
        return (fake_loss + real_loss) / 2