import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Residual unit made of two 3x3 convolutions with an identity shortcut.

    ``self.block`` is an ``nn.Sequential`` of: reflection pad, 3x3 conv, norm,
    in-place ReLU, reflection pad, 3x3 conv, norm. Every convolution maps
    ``in_features`` channels to ``in_features`` channels, and the 1-pixel
    reflection padding offsets the unpadded 3x3 kernel so the spatial size is
    unchanged.

    Args:
        in_features: channel count of both the input and the output.
        normIN: when true use ``nn.InstanceNorm2d``; otherwise
            ``nn.BatchNorm2d``.
    """

    def __init__(self, in_features, normIN=True):
        super().__init__()

        # Both stages use the same kind of normalisation layer.
        norm_cls = nn.InstanceNorm2d if normIN else nn.BatchNorm2d

        def conv_stage():
            # pad(1) -> conv3x3 -> norm; keeps channels and H x W unchanged.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                norm_cls(in_features),
            ]

        layers = conv_stage()
        layers.append(nn.ReLU(inplace=True))
        layers.extend(conv_stage())  # no activation after the second stage

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution branch."""
        residual = self.block(x)
        return x + residual

# for perceptual performance
class Generator(nn.Module):
    """Encoder/decoder image generator using instance normalisation.

    Layout of ``self.model`` (an ``nn.Sequential``), in order:

    * stem: reflection pad + 7x7 conv from 3 to 64 channels, norm, ReLU;
    * ``num_res`` residual blocks at 64 channels;
    * two stride-2 3x3 convs (64 -> 128 -> 256), each followed by norm, ReLU
      and ``num_res`` residual blocks;
    * two decoder stages (256 -> 128 -> 64), each made of ``num_res`` residual
      blocks, an ``nn.Upsample(scale_factor=2)`` and a 3x3 conv with norm and
      ReLU;
    * ``num_res`` residual blocks at 64 channels;
    * head: reflection pad + 7x7 conv to 3 channels, then ``Tanh``.

    Args:
        num_res: number of residual blocks in every residual stack.
    """

    def __init__(self, num_res):
        super().__init__()

        def res_stack(channels):
            # ``num_res`` residual blocks operating at a fixed channel width.
            return [ResidualBlock(channels) for _ in range(num_res)]

        # Stem: lift the 3-channel input to the base width.
        width = 64
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, width, 7),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
        ]
        layers.extend(res_stack(width))

        # Encoder: each stage halves the resolution and doubles the channels.
        for _ in range(2):
            wider = width * 2
            layers.extend([
                nn.Conv2d(width, wider, 3, stride=2, padding=1),
                nn.InstanceNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            layers.extend(res_stack(wider))
            width = wider

        # Decoder: residual stack first, then upsample and halve the channels.
        for _ in range(2):
            layers.extend(res_stack(width))
            narrower = width // 2
            layers.extend([
                nn.Upsample(scale_factor=2),
                nn.Conv2d(width, narrower, 3, stride=1, padding=1),
                nn.InstanceNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            width = narrower

        layers.extend(res_stack(width))

        # Head: project back to 3 channels and squash with Tanh.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(width, 3, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the generator on a 4-D batch and return a same-sized output.

        The input is padded so the two stride-2 stages divide evenly; the
        padding is cropped from the result so the output keeps the input's
        height and width.
        """
        _, _, height, width = x.shape
        padded = self.check_image_size(x)
        restored = self.model(padded)
        return restored[:, :, :height, :width]

    def check_image_size(self, x):
        """Pad the bottom and right of ``x`` so H and W are multiples of 4.

        Uses ``F.pad`` with its default mode; nothing is added on the top or
        left edges.
        """
        multiple = 2 ** 2  # one factor of 2 per stride-2 encoder stage

        def shortfall(size):
            # Amount needed to reach the next multiple (0 if already aligned).
            return (multiple - size % multiple) % multiple

        _, _, rows, cols = x.size()
        return F.pad(x, (0, shortfall(cols), 0, shortfall(rows)))

# for distortion performance
class GeneratorDT(nn.Module):
    """Encoder/decoder image generator using batch normalisation and a global skip.

    ``self.model`` has the same layout as a stem / two-stage encoder /
    two-stage decoder / head network (64 -> 128 -> 256 -> 128 -> 64 channels)
    with ``num_res`` residual blocks after the stem, after each encoder stage,
    before each decoder stage and before the head. It differs in three ways
    that are visible in this class:

    * every normalisation layer is ``nn.BatchNorm2d`` and residual blocks are
      built with ``normIN=False``;
    * the head ends at the 7x7 conv, with no activation inside ``self.model``;
    * ``forward`` adds the (padded) input to the network output and only then
      applies ``self.tanh``.

    Args:
        num_res: number of residual blocks in every residual stack.
    """

    def __init__(self, num_res):
        super().__init__()

        def res_stack(channels):
            # ``num_res`` batch-norm residual blocks at a fixed channel width.
            return [
                ResidualBlock(channels, normIN=False) for _ in range(num_res)
            ]

        # Stem: lift the 3-channel input to the base width.
        width = 64
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, width, 7),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        ]
        layers.extend(res_stack(width))

        # Encoder: each stage halves the resolution and doubles the channels.
        for _ in range(2):
            wider = width * 2
            layers.extend([
                nn.Conv2d(width, wider, 3, stride=2, padding=1),
                nn.BatchNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            layers.extend(res_stack(wider))
            width = wider

        # Decoder: residual stack first, then upsample and halve the channels.
        for _ in range(2):
            layers.extend(res_stack(width))
            narrower = width // 2
            layers.extend([
                nn.Upsample(scale_factor=2),
                nn.Conv2d(width, narrower, 3, stride=1, padding=1),
                nn.BatchNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            width = narrower

        layers.extend(res_stack(width))

        # Head: no Tanh here -- it is applied in forward() after the input
        # has been added back.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(width, 3, 7),
        ])

        self.model = nn.Sequential(*layers)

        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return ``tanh(model(x) + x)`` cropped to the input's height and width.

        The input is first padded so the two stride-2 stages divide evenly;
        the skip connection uses that padded tensor so shapes match.
        """
        _, _, height, width = x.shape
        padded = self.check_image_size(x)
        summed = self.model(padded) + padded
        activated = self.tanh(summed)
        return activated[:, :, :height, :width]

    def check_image_size(self, x):
        """Pad the bottom and right of ``x`` so H and W are multiples of 4.

        Uses ``F.pad`` with its default mode; nothing is added on the top or
        left edges.
        """
        multiple = 2 ** 2  # one factor of 2 per stride-2 encoder stage

        def shortfall(size):
            # Amount needed to reach the next multiple (0 if already aligned).
            return (multiple - size % multiple) % multiple

        _, _, rows, cols = x.size()
        return F.pad(x, (0, shortfall(cols), 0, shortfall(rows)))

##############################
#        Discriminator
##############################

class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a 1-channel score map.

    ``self.model`` is four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256
    -> 512 channels), each followed by ``LeakyReLU(0.2)``. All but the first
    also have an ``InstanceNorm2d`` between the conv and the activation. The
    stack ends with a one-pixel zero pad on the left and top edges and a 4x4
    convolution down to a single channel.
    """

    def __init__(self):
        super().__init__()

        widths = (3, 64, 128, 256, 512)

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Each stage halves the spatial resolution.
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if stage > 0:
                # The first stage is left un-normalised.
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Asymmetric pad (left=1, right=0, top=1, bottom=0) before the final conv.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(widths[-1], 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the score map produced by ``self.model`` for ``img``."""
        scores = self.model(img)
        return scores
    
class BUM(nn.Module):
    """Basic unsupervised model: two generators and one discriminator.

    Attributes (registered in this order):
        G_AB: generator applied to domain-A images.
        G_BA: generator applied to the output of ``G_AB``.
        D_B: discriminator applied to generated and real domain-B images.

    Args:
        num_res: forwarded to both generators as their ``num_res``.
        model_mode: ``'perc'`` builds both generators as ``Generator``; any
            other value builds them as ``GeneratorDT``.
    """

    def __init__(self, num_res=2, model_mode='perc'):
        super().__init__()

        # Both directions always share the same generator architecture.
        generator_cls = Generator if model_mode == 'perc' else GeneratorDT
        self.G_AB = generator_cls(num_res=num_res)
        self.G_BA = generator_cls(num_res=num_res)
        self.D_B = Discriminator()

    def forward_G(self, imgA):
        """Translate ``imgA`` with ``G_AB`` and map the result back with ``G_BA``.

        Returns:
            ``(fakeB, reconA)``: the translated image and its reconstruction.
        """
        translated = self.G_AB(imgA)
        reconstructed = self.G_BA(translated)
        return translated, reconstructed

    def forward_D_B(self, fakeB, imgB):
        """Score ``fakeB`` and then ``imgB`` with ``D_B``.

        Returns:
            ``(fakeB_valid, imgB_valid)``: discriminator outputs for each input.
        """
        score_fake = self.D_B(fakeB)
        score_real = self.D_B(imgB)
        return score_fake, score_real

    def _set_D_B_requires_grad(self, flag):
        # Shared implementation for freeze_D_B / unfreeze_D_B.
        for param in self.D_B.parameters():
            param.requires_grad = flag

    def freeze_D_B(self):
        """Set ``requires_grad = False`` on every parameter of ``D_B``."""
        self._set_D_B_requires_grad(False)

    def unfreeze_D_B(self):
        """Set ``requires_grad = True`` on every parameter of ``D_B``."""
        self._set_D_B_requires_grad(True)