import torch
import torch.nn as nn
from kpn.network import KernelConv
import kpn.utils as kpn_utils
import numpy as np


class BaseNetwork(nn.Module):
    """Base ``nn.Module`` that adds a recursive weight-initialisation helper."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def init_weights(self, init_type='normal', gain=0.02):
        """Re-initialise the parameters of this module and all sub-modules.

        Modules are selected by class name:

        * names containing ``Conv`` or ``Linear`` (and owning a weight) get
          their weight drawn according to ``init_type`` and their bias, if
          any, set to zero;
        * names containing ``BatchNorm2d`` get weight ~ N(1, gain) and a zero
          bias.

        Args:
            init_type: One of ``'normal'``, ``'xavier'``, ``'kaiming'`` or
                ``'orthogonal'``. Any other value leaves the conv/linear
                weights untouched (biases are still zeroed).
            gain: Standard deviation for ``'normal'`` and the BatchNorm
                weights; gain factor for ``'xavier'`` and ``'orthogonal'``.
                Not used by ``'kaiming'``.
        """
        def init_func(m):
            classname = m.__class__.__name__
            is_conv_or_linear = 'Conv' in classname or 'Linear' in classname

            # ``weight`` may exist but be None, so test the value, not just
            # the attribute's presence.
            if is_conv_or_linear and getattr(m, 'weight', None) is not None:
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)

                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias.data, 0.0)

            elif 'BatchNorm2d' in classname:
                # BatchNorm2d(affine=False) has weight = bias = None; skip
                # those instead of failing on ``None.data``.
                if getattr(m, 'weight', None) is not None:
                    nn.init.normal_(m.weight.data, 1.0, gain)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)


class Branch(nn.Module):
    """Holds the layers of one encoder / residual / decoder branch.

    The class only builds and registers sub-modules; it defines no
    ``forward``. ``Generator.forward`` calls the stages one by one and
    concatenates skip features in between, which is why several stages take
    twice as many input channels as the previous stage produces.

    Args:
        config: Accepted for interface compatibility; not used here.
        residual_blocks: Number of dilated ``ResnetBlock`` modules stacked in
            ``middle``.
    """

    def __init__(self, config=None, residual_blocks=8):
        super().__init__()

        def norm_relu(channels):
            # Non-affine instance norm followed by an in-place ReLU.
            return [nn.InstanceNorm2d(channels, track_running_stats=False), nn.ReLU(True)]

        def rgb_head():
            # 7x7 conv (reflection padded) from 64*2 channels to 3 channels.
            return nn.Sequential(
                nn.ReflectionPad2d(3),
                nn.Conv2d(64 * 2, 3, kernel_size=7, padding=0),
            )

        # ---- encoder -------------------------------------------------------
        # 6-channel input, resolution preserved.
        self.encoder0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(6, 64, kernel_size=7, padding=0),
            *norm_relu(64)
        )

        # Stride-2 downsampling.
        self.encoder1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            *norm_relu(128)
        )

        # Input is two concatenated 128-channel feature maps.
        self.encoder2 = nn.Sequential(
            nn.Conv2d(128 + 128, 256, kernel_size=4, stride=2, padding=1),
            *norm_relu(256)
        )

        # ---- residual trunk ------------------------------------------------
        self.middle = nn.Sequential(
            *[ResnetBlock(256, 2) for _ in range(residual_blocks)]
        )

        # ---- decoder (inputs doubled by the skip concatenation) -------------
        self.decoder_0 = nn.Sequential(
            nn.ConvTranspose2d(256 * 2, 128, kernel_size=4, stride=2, padding=1),
            *norm_relu(128)
        )

        self.decoder_1 = nn.Sequential(
            nn.ConvTranspose2d(128 * 2, 64, kernel_size=4, stride=2, padding=1),
            *norm_relu(64)
        )

        self.decoder_20 = rgb_head()  # for inpainting image
        self.decoder_21 = rgb_head()  # for confidence map

        # ---- 1x1 convs that predict per-pixel filter kernels -----------------
        kernel_size = 3
        taps = kernel_size ** 2

        # Feature level: 256 channels, ``taps`` coefficients per channel.
        self.kernels_1 = nn.Conv2d(256, 256 * taps, kernel_size=1, stride=1, padding=0)
        self.kernels_2 = nn.Conv2d(256, 256 * taps, kernel_size=1, stride=1, padding=0)

        # for inpainting image
        self.kernel_img_1 = nn.Conv2d(3, 3 * taps, kernel_size=1, stride=1, padding=0)
        self.kernel_img_2 = nn.Conv2d(3, 3 * taps, kernel_size=1, stride=1, padding=0)

        # for confidence map
        self.kernel_map_1 = nn.Conv2d(3, 3 * taps, kernel_size=1, stride=1, padding=0)
        self.kernel_map_2 = nn.Conv2d(3, 3 * taps, kernel_size=1, stride=1, padding=0)

class Generator(BaseNetwork):
    """Two-input generator that shares a single ``Branch`` between both inputs.

    Both inputs run through the same layers. At three points (bottleneck
    features, image head, confidence head) a 1x1 conv predicts filter kernels
    from one stream and ``KernelConv`` applies them to the *other* stream.

    NOTE(review): ``KernelConv`` comes from ``kpn.network`` and is opaque
    here; it presumably applies the predicted 3x3 kernels per pixel -- confirm
    against that module.
    """

    def __init__(self, config=None, residual_blocks=8, init_weights=True, logger=None):
        """Build the shared branch and the kernel-application module.

        Args:
            config: Forwarded to ``Branch``.
            residual_blocks: Number of residual blocks in the branch trunk.
            init_weights: When true, run ``init_weights()`` with its defaults.
            logger: Accepted for interface compatibility; not used here.
        """
        super(Generator, self).__init__()

        self.branch_1 = Branch(config=config, residual_blocks=residual_blocks)

        self.kernel_pred = KernelConv(kernel_size=[3], sep_conv=False, core_bias=False)

        if init_weights:
            self.init_weights()

    def forward(self, input1, input2):
        """Run both inputs through the shared branch.

        The shape notes in the comments below are the original author's and
        assume 6-channel 256x256 inputs.

        Args:
            input1: First input tensor.
            input2: Second input tensor.

        Returns:
            An 8-tuple ``(x1, x2, c1, c2, feature_kernels, image_kernels,
            map_kernels, features)``:

            * ``x1``, ``x2`` -- image outputs after ``tanh``.
            * ``c1``, ``c2`` -- confidence outputs, ``tanh(.) * 2``.
            * three 2-tuples with the predicted kernels at the feature, image
              and confidence level.
            * ``features`` -- list of ``(stream1, stream2)`` tuples recorded
              after each stage, in execution order (values before the final
              ``tanh``).
        """

        features = []

        # in: (bs, 6, 256, 256) out: (bs, 64, 256, 256)
        x1 = self.branch_1.encoder0(input1)
        x2 = self.branch_1.encoder0(input2)
        features.append((x1, x2))
        e0_out1, e0_out2 = x1, x2  # kept for the skip connection into the heads

        # in: (bs, 64, 256, 256) out: (bs, 128, 128, 128)
        x1_ = self.branch_1.encoder1(x1)
        x2_ = self.branch_1.encoder1(x2)
        features.append((x1_, x2_))
        e1_out1, e1_out2 = x1_, x2_  # kept for the skip connection into decoder_1

        # Each stream sees the other stream's features first, then its own.
        # in: (bs, 128*2, 128, 128) out: (bs, 256, 64, 64)
        x1 = self.branch_1.encoder2(torch.cat([x2_, x1_], dim=1))
        x2 = self.branch_1.encoder2(torch.cat([x1_, x2_], dim=1))
        features.append((x1, x2))
        e2_out1, e2_out2 = x1, x2  # kept for the skip connection into decoder_0

        # Kernels for stream 1 are predicted from stream 2 and vice versa.
        # in: (bs, 256, 64, 64) out: (bs, 2304, 64, 64)
        kernel_1 = self.branch_1.kernels_1(x2)
        kernel_2 = self.branch_1.kernels_2(x1)

        # in: () out: (bs, 256, 64, 64)
        x1 = self.kernel_pred(x1, kernel_1, white_level=1.0, rate=1)
        x2 = self.kernel_pred(x2, kernel_2, white_level=1.0, rate=1)
        features.append((x1, x2))

        # in: (bs, 256, 64, 64) out: (bs, 256, 64, 64)
        x1 = self.branch_1.middle(x1)
        x2 = self.branch_1.middle(x2)
        features.append((x1, x2))

        # in: (bs, 256*2, 64, 64) out: (bs, 128, 128, 128)
        x1 = self.branch_1.decoder_0(torch.cat([x1, e2_out1], dim=1))
        x2 = self.branch_1.decoder_0(torch.cat([x2, e2_out2], dim=1))
        features.append((x1, x2))

        # in: (bs, 128*2, 128, 128) out: (bs, 64, 256, 256)
        x1_ = self.branch_1.decoder_1(torch.cat([x1, e1_out1], dim=1))
        x2_ = self.branch_1.decoder_1(torch.cat([x2, e1_out2], dim=1))
        features.append((x1_, x2_))

        ### get inpainting output ###

        # in: (bs, 64*2, 256, 256) out: (bs, 3, 256, 256)
        x1 = self.branch_1.decoder_20(torch.cat([x1_, e0_out1], dim=1))
        x2 = self.branch_1.decoder_20(torch.cat([x2_, e0_out2], dim=1))
        features.append((x1, x2))

        # in: (bs, 3, 256, 256) out: (bs, 27, 256, 256)
        kernel_img_1 = self.branch_1.kernel_img_1(x2)
        kernel_img_2 = self.branch_1.kernel_img_2(x1)

        # in: () out: (bs, 3, 256, 256)
        x1 = self.kernel_pred(x1, kernel_img_1, white_level=1.0, rate=1)
        x2 = self.kernel_pred(x2, kernel_img_2, white_level=1.0, rate=1)
        features.append((x1, x2))

        # Image outputs are squashed to (-1, 1).
        x1 = torch.tanh(x1)
        x2 = torch.tanh(x2)

        ### get confidence output ###

        # Same decoder_1 features as the image head, separate output conv.
        # in: (bs, 64*2, 256, 256) out: (bs, 3, 256, 256)
        c1 = self.branch_1.decoder_21(torch.cat([x1_, e0_out1], dim=1))
        c2 = self.branch_1.decoder_21(torch.cat([x2_, e0_out2], dim=1))
        features.append((c1, c2))

        # in: (bs, 3, 256, 256) out: (bs, 27, 256, 256)
        kernel_map_1 = self.branch_1.kernel_map_1(c2)
        kernel_map_2 = self.branch_1.kernel_map_2(c1)

        # in: () out: (bs, 3, 256, 256)
        c1 = self.kernel_pred(c1, kernel_map_1, white_level=1.0, rate=1)
        c2 = self.kernel_pred(c2, kernel_map_2, white_level=1.0, rate=1)
        features.append((c1, c2))

        # Confidence outputs are squashed to (-2, 2).
        c1 = torch.tanh(c1) * 2
        c2 = torch.tanh(c2) * 2

        return x1, x2, c1, c2, (kernel_1, kernel_2), (kernel_img_1, kernel_img_2), (kernel_map_1, kernel_map_2), features

    def get_backbone_parameters(self):
        """Return every parameter that belongs to neither output head.

        A parameter is excluded when its qualified name contains
        ``decoder_20``, ``decoder_21``, ``kernel_img`` or ``kernel_map``.
        """
        exclude_name = ('decoder_20', 'decoder_21', 'kernel_img', 'kernel_map')
        return [
            para for name, para in self.named_parameters()
            if not any(ex in name for ex in exclude_name)
        ]

    def get_inpaint_head_parameters(self):
        """Return parameters whose name contains ``decoder_20`` or ``kernel_img``."""
        return [
            para for name, para in self.named_parameters()
            if 'decoder_20' in name or 'kernel_img' in name
        ]

    def get_confidence_head_parameters(self):
        """Return parameters whose name contains ``decoder_21`` or ``kernel_map``."""
        return [
            para for name, para in self.named_parameters()
            if 'decoder_21' in name or 'kernel_map' in name
        ]

    def save_feature(self, x, name):
        """Write tensor ``x`` to ``./result/<name>`` with ``np.save``.

        The tensor is detached first so that tensors which require grad can
        be converted, and the target directory is created when missing.

        Args:
            x: Tensor to dump (any device).
            name: File name, relative to ``./result/``.
        """
        import os  # local import: not part of this file's import header

        path = './result/{}'.format(name)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # .numpy() refuses tensors that require grad, hence detach().
        np.save(path, x.detach().cpu().numpy())

class EdgeGenerator(BaseNetwork):
    """Encoder / residual / decoder network with a single-channel sigmoid output.

    Args:
        residual_blocks: Number of dilated ``ResnetBlock`` modules in the
            middle section.
        use_spectral_norm: Wrap the strided (transposed) convolutions and the
            residual blocks with spectral normalisation.
        init_weights: When true, run ``init_weights()`` with its defaults.
    """

    def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True):
        super().__init__()

        def sn(layer):
            # Apply spectral norm only when requested.
            return spectral_norm(layer, use_spectral_norm)

        def norm_relu(width):
            return [nn.InstanceNorm2d(width, track_running_stats=False), nn.ReLU(True)]

        # Layers are collected flat so the Sequential indices are unchanged.
        down = [nn.ReflectionPad2d(3), sn(nn.Conv2d(3, 64, kernel_size=7, padding=0))]
        down += norm_relu(64)
        down.append(sn(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)))
        down += norm_relu(128)
        down.append(sn(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)))
        down += norm_relu(256)
        self.encoder = nn.Sequential(*down)

        trunk = [
            ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm)
            for _ in range(residual_blocks)
        ]
        self.middle = nn.Sequential(*trunk)

        up = [sn(nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1))]
        up += norm_relu(128)
        up.append(sn(nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)))
        up += norm_relu(64)
        # The output conv is never spectrally normalised.
        up += [nn.ReflectionPad2d(3), nn.Conv2d(64, 1, kernel_size=7, padding=0)]
        self.decoder = nn.Sequential(*up)

        if init_weights:
            self.init_weights()

    def forward(self, x):
        """Return ``sigmoid(decoder(middle(encoder(x))))``."""
        hidden = self.middle(self.encoder(x))
        return torch.sigmoid(self.decoder(hidden))


class Discriminator(BaseNetwork):
    """Five-stage convolutional discriminator that also exposes its activations.

    Args:
        in_channels: Number of channels of the input tensor.
        use_sigmoid: Apply a sigmoid to the last stage's output.
        use_spectral_norm: Wrap every conv with spectral normalisation; the
            convs are then created without bias.
        init_weights: When true, run ``init_weights()`` with its defaults.
    """

    def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True):
        super().__init__()
        self.use_sigmoid = use_sigmoid

        def stage(c_in, c_out, stride, activate=True):
            # 4x4 conv (padding 1), optionally followed by LeakyReLU(0.2).
            conv = nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1,
                             bias=not use_spectral_norm)
            layers = [spectral_norm(conv, use_spectral_norm)]
            if activate:
                layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        # The first stage is registered under two names, ``conv1`` and
        # ``features`` (same module object), exactly as before.
        stem = stage(in_channels, 64, 2)
        self.conv1 = stem
        self.features = stem

        self.conv2 = stage(64, 128, 2)
        self.conv3 = stage(128, 256, 2)
        self.conv4 = stage(256, 512, 1)
        self.conv5 = stage(512, 1, 1, activate=False)

        if init_weights:
            self.init_weights()

    def forward(self, x):
        """Return ``(outputs, [conv1, conv2, conv3, conv4, conv5])``.

        ``outputs`` is the last stage's activation, passed through a sigmoid
        when ``use_sigmoid`` is set; the list holds every stage's activation.
        """
        activations = []
        hidden = x
        for block in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            hidden = block(hidden)
            activations.append(hidden)

        outputs = torch.sigmoid(hidden) if self.use_sigmoid else hidden
        return outputs, activations


class ResnetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convs plus an identity skip.

    Args:
        dim: Channel count of both input and output.
        dilation: Dilation (and matching reflection padding) of the first
            conv; the second conv always uses dilation 1.
        use_spectral_norm: Wrap both convs with spectral normalisation; the
            convs are then created without bias.
    """

    def __init__(self, dim, dilation=1, use_spectral_norm=False):
        super().__init__()
        with_bias = not use_spectral_norm

        def conv3x3(rate):
            # Build the conv and wrap it right away (keeps creation order).
            conv = nn.Conv2d(dim, dim, kernel_size=3, padding=0, dilation=rate, bias=with_bias)
            return spectral_norm(conv, use_spectral_norm)

        layers = [
            nn.ReflectionPad2d(dilation),
            conv3x3(dilation),
            nn.InstanceNorm2d(dim, track_running_stats=False),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            conv3x3(1),
            nn.InstanceNorm2d(dim, track_running_stats=False),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + conv_block(x)``."""
        residual = self.conv_block(x)
        return x + residual


def spectral_norm(module, mode=True):
    """Optionally wrap ``module`` with ``nn.utils.spectral_norm``.

    Args:
        module: Layer to wrap.
        mode: When truthy the layer is passed to ``nn.utils.spectral_norm``;
            otherwise it is returned untouched.

    Returns:
        The result of ``nn.utils.spectral_norm(module)`` or ``module`` itself.
    """
    return nn.utils.spectral_norm(module) if mode else module