import torch
import torch.nn as nn
import torch.nn.functional as F


class FourierUnit(nn.Module):
    """Apply a learned 1x1 convolution in the 2-D Fourier domain.

    The input is transformed with a real FFT over its last two (spatial) dims.
    The real and imaginary parts are stacked on the channel axis and mixed by a
    1x1 convolution followed by LeakyReLU(0.2). The result is transformed back
    to the original spatial size with an inverse real FFT.

    Args:
        embed_dim: number of input (and output) channels ``c``.
        fft_norm: normalization mode passed to ``torch.fft.rfftn``/``irfftn``.
    """

    def __init__(self, embed_dim, fft_norm='ortho'):
        super().__init__()
        # Real and imaginary parts share the channel axis, hence 2 * embed_dim.
        self.conv_layer = torch.nn.Conv2d(embed_dim * 2, embed_dim * 2, 1, 1, 0)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.fft_norm = fft_norm

    def forward(self, x):
        """Map ``x`` of shape (batch, c, h, w) to a tensor of the same shape."""
        batch = x.shape[0]
        fft_dim = (-2, -1)

        # (batch, c, h, w//2+1) complex -> (batch, c, h, w//2+1, 2) real
        ffted = torch.view_as_real(torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm))
        # -> (batch, c, 2, h, w//2+1) -> (batch, 2c, h, w//2+1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()
        ffted = ffted.view((batch, -1) + ffted.size()[3:])

        ffted = self.relu(self.conv_layer(ffted))  # (batch, 2c, h, w//2+1)

        # (batch, 2c, h, w//2+1) -> (batch, c, 2, h, w//2+1) -> (batch, c, h, w//2+1, 2)
        ffted = ffted.view((batch, -1, 2) + ffted.size()[2:])
        ffted = ffted.permute(0, 1, 3, 4, 2).contiguous()
        ffted = torch.view_as_complex(ffted)  # (batch, c, h, w//2+1) complex

        # Passing the original spatial size lets odd widths round-trip exactly.
        return torch.fft.irfftn(ffted, s=x.shape[-2:], dim=fft_dim, norm=self.fft_norm)


class SpectralTransform(nn.Module):
    """Frequency-domain branch: reduce channels, add a Fourier residual, expand.

    ``x`` is reduced to ``embed_dim // 2`` channels (1x1 conv + LeakyReLU(0.2)).
    The reduced tensor is added to its ``FourierUnit`` response, and the sum is
    projected back to ``embed_dim`` channels with a 1x1 conv.

    Args:
        embed_dim: input and output channel count.
        last_conv: if truthy, a 3x3 conv (embed_dim -> embed_dim) is applied to
            the output and stored as ``self.last_conv``. Otherwise
            ``self.last_conv`` holds the flag value itself.
    """

    def __init__(self, embed_dim, last_conv=False):
        super().__init__()
        hidden = embed_dim // 2

        self.conv1 = nn.Sequential(
            nn.Conv2d(embed_dim, hidden, 1, 1, 0),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.fu = FourierUnit(hidden)
        self.conv2 = torch.nn.Conv2d(hidden, embed_dim, 1, 1, 0)

        # Either the optional refinement conv, or the (falsy) flag as given.
        self.last_conv = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) if last_conv else last_conv

    def forward(self, x):
        reduced = self.conv1(x)
        fused = reduced + self.fu(reduced)
        projected = self.conv2(fused)
        if not self.last_conv:
            return projected
        return self.last_conv(projected)


## Residual Block (RB)
class ResB(nn.Module):
    """Residual block computing ``x + conv3x3(LeakyReLU(conv3x3(x)))``.

    Both 3x3 convolutions keep ``embed_dim`` channels and use stride 1 and
    padding 1. The spatial size is therefore preserved, and the identity skip
    is added directly.

    Args:
        embed_dim: input and output channel count.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(embed_dim, embed_dim, 3, 1, 1),
        )

    def forward(self, x):
        # Implemented as ``forward`` (not an overridden ``__call__``) so that
        # nn.Module hooks run and ``module.forward(x)`` works.
        return self.body(x) + x


class SFB(nn.Module):
    """Spatial-frequency block: fuse a spatial residual branch with a spectral branch.

    ``x`` goes through a ``ResB`` (spatial) and a ``SpectralTransform``
    (frequency) in parallel. Their ``embed_dim``-channel outputs are
    concatenated and projected to ``out_channel`` with a 1x1 conv.

    Args:
        embed_dim: input channel count.
        out_channel: output channel count.
        strides: unused. It is accepted so that SFB matches the
            ``block(in, out, strides=...)`` signature used for interchangeable
            blocks in this file.
    """

    def __init__(self, embed_dim, out_channel, strides=1):
        super().__init__()
        self.S = ResB(embed_dim)
        self.F = SpectralTransform(embed_dim)
        self.fusion = nn.Conv2d(embed_dim * 2, out_channel, 1, 1, 0)

    def forward(self, x):
        # Implemented as ``forward`` (not an overridden ``__call__``) so that
        # nn.Module hooks run and ``module.forward(x)`` works.
        s = self.S(x)
        f = self.F(x)
        return self.fusion(torch.cat([s, f], dim=1))


#########################################
class ConvBlock(nn.Module):
    """Two 3x3 conv + LeakyReLU layers plus a 1x1-conv shortcut, summed.

    Args:
        in_channel: input channel count.
        out_channel: output channel count.
        strides: stride applied once, by the first 3x3 conv and by the 1x1
            shortcut, so both paths produce the same spatial size.
    """

    def __init__(self, in_channel, out_channel, strides=1):
        super().__init__()
        self.strides = strides
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1),
            nn.LeakyReLU(inplace=True),
            # Stride 1 here: the shortcut downsamples only once, so striding a
            # second time would make the two paths' spatial sizes disagree.
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0)

    def forward(self, x):
        return self.block(x) + self.conv11(x)


class spade(nn.Module):
    """Modulate ``context`` with a scale and shift predicted from ``reference``.

    Returns ``context * ReLU(conv0(reference)) + ReLU(conv1(reference))``.
    Both predictors are 3x3 convs (padding 1) from ``channel_r`` to
    ``channel_c`` channels.
    """

    def __init__(self, channel_r, channel_c):
        super().__init__()

        def predictor():
            return nn.Sequential(nn.Conv2d(channel_r, channel_c, 3, 1, 1), nn.ReLU())

        self.conv0 = predictor()  # multiplicative term
        self.conv1 = predictor()  # additive term

    def forward(self, reference, context):
        modulated = context * self.conv0(reference)
        shift = self.conv1(reference)
        return modulated + shift


class Spade(nn.Module):
    """Halve the channels of ``context``, then modulate them with ``reference``.

    ``context`` (``channel_c`` channels) is reduced to ``channel_c // 2`` by a
    1x1 conv. It is then passed through ``spade`` using ``reference``
    (``channel_r`` channels) as the guide. A 3-D ``reference`` is treated as
    tokens of shape (B, L, channel_r) and reshaped onto the reduced context's
    (H, W) grid, which requires L == H * W.
    """

    def __init__(self, channel_r, channel_c):
        super().__init__()
        reduced = channel_c // 2
        self.conv0 = nn.Conv2d(channel_c, reduced, 1, 1, 0)
        self.spade = spade(channel_r, reduced)

    def forward(self, reference, context):
        squeezed = self.conv0(context)
        if reference.dim() == 3:
            n, _, rows, cols = squeezed.shape
            # (B, L, C) -> (B, C, L) -> (B, C, H, W)
            reference = reference.transpose(1, 2).view(n, -1, rows, cols)
        return self.spade(reference, squeezed)


class mask_img_input(nn.Module):
    """Mask an image, concatenate a second image, and embed with a 3x3 conv.

    ``flare_mask`` of shape (B, h, w) is bicubically resized to the spatial
    size of ``input`` and multiplied into every channel of ``input``. The
    masked image is concatenated with ``flare`` on the channel axis. The conv
    expects 6 channels in total, so presumably 3 + 3. The conv maps them to
    ``channel`` channels.
    """

    def __init__(self, channel):
        super().__init__()
        self.layer = nn.Conv2d(6, channel, 3, 1, 1)

    def forward(self, input, flare_mask, flare):
        target_size = (input.shape[2], input.shape[3])
        # NOTE(review): bicubic resampling of a binary mask can overshoot [0, 1]
        # when the sizes differ; confirm this is intended.
        resized_mask = F.interpolate(flare_mask.unsqueeze(1), size=target_size, mode='bicubic')
        resized_mask = resized_mask.squeeze(1)
        masked = torch.einsum('b c h w, b h w -> b c h w', input, resized_mask)
        stacked = torch.cat([masked, flare], dim=1)
        return self.layer(stacked)


class PIP(nn.Module):
    """U-Net-style image network whose decoder is modulated by guide features.

    Encoder: ``ConvBlock1`` followed by ``block`` stages 2-5. Each of stages
    1-4 is followed by a learned 4x4 / stride-2 conv that halves H and W, so
    stage 5 runs at 1/16 resolution.

    Decoder: stages 6-9. Each one upsamples 2x with a transposed conv,
    concatenates the matching encoder output, and modulates the result with
    ``Spade`` using one entry of ``guide_fea``. It then refines with ``block``.
    ``conv10`` maps the result to 3 channels.

    Args:
        block: module class for stages 2-9, called as
            ``block(in_ch, out_ch, strides=1)``.
        dim: base channel width of the network.
        dim_r: base channel width of the guide features.
    """

    def __init__(self, block=SFB, dim=32, dim_r=32):
        super().__init__()

        # Stem: masked image + second image (6 channels) -> dim channels.
        self.inpainting = mask_img_input(dim)

        self.dim = dim
        self.dim_r = dim_r

        # Encoder; channel widths dim -> 2dim -> 4dim -> 8dim -> 16dim.
        self.ConvBlock1 = ConvBlock(dim, dim, strides=1)
        self.pool1 = nn.Conv2d(dim, dim, 4, 2, 1)

        self.ConvBlock2 = block(dim, 2 * dim, strides=1)
        self.pool2 = nn.Conv2d(2 * dim, 2 * dim, 4, 2, 1)

        self.ConvBlock3 = block(2 * dim, 4 * dim, strides=1)
        self.pool3 = nn.Conv2d(4 * dim, 4 * dim, 4, 2, 1)

        self.ConvBlock4 = block(4 * dim, 8 * dim, strides=1)
        self.pool4 = nn.Conv2d(8 * dim, 8 * dim, 4, 2, 1)

        self.ConvBlock5 = block(8 * dim, 16 * dim, strides=1)

        # Decoder. Concatenating the skip doubles the channels, and Spade
        # halves them back before the stage's block.
        self.upv6 = nn.ConvTranspose2d(16 * dim, 8 * dim, kernel_size=2, stride=2)
        self.spade6 = Spade(16 * dim_r, 16 * dim)
        self.ConvBlock6 = block(8 * dim, 8 * dim, strides=1)

        self.upv7 = nn.ConvTranspose2d(8 * dim, 4 * dim, kernel_size=2, stride=2)
        self.spade7 = Spade(8 * dim_r, 8 * dim)
        self.ConvBlock7 = block(4 * dim, 4 * dim, strides=1)

        self.upv8 = nn.ConvTranspose2d(4 * dim, 2 * dim, kernel_size=2, stride=2)
        self.spade8 = Spade(4 * dim_r, 4 * dim)
        self.ConvBlock8 = block(2 * dim, 2 * dim, strides=1)

        self.upv9 = nn.ConvTranspose2d(2 * dim, dim, kernel_size=2, stride=2)
        self.spade9 = Spade(2 * dim_r, 2 * dim)
        self.ConvBlock9 = block(dim, dim, strides=1)

        self.conv10 = nn.Conv2d(dim, 3, 3, 1, 1)

    def forward(self, input, guide_fea, flare):
        """Run the network.

        Args:
            input: (B, 3, H, W) image. H and W must be multiples of 16 for the
                skip concatenations to line up.
            guide_fea: indexable with at least four entries. Entry ``i`` is the
                guide for decoder stage ``6 + i``, with 16, 8, 4 and 2 times
                ``dim_r`` channels respectively.
            flare: image whose first three channels (presumably RGB — confirm)
                are combined into a luma map to build the mask.

        Returns:
            (B, 3, H, W) output of ``conv10``.
        """
        luma = flare[:, 0] * 0.299 + flare[:, 1] * 0.587 + flare[:, 2] * 0.114
        # 0 where luma > 0.12 (bright flare), 1 elsewhere (NaN counts as not bright).
        keep_mask = (~(luma > 0.12)).float()
        # NOTE(review): ``input`` (not ``flare``) is passed as the third argument,
        # so the stem sees the masked image concatenated with the unmasked
        # image. Confirm this is intended.
        x = self.inpainting(input, keep_mask, input)

        encoder = (
            (self.ConvBlock1, self.pool1),
            (self.ConvBlock2, self.pool2),
            (self.ConvBlock3, self.pool3),
            (self.ConvBlock4, self.pool4),
        )
        skips = []
        for conv_block, pool in encoder:
            x = conv_block(x)
            skips.append(x)  # pre-pool output, reused by the decoder
            x = pool(x)
        x = self.ConvBlock5(x)

        decoder = (
            (self.upv6, self.spade6, self.ConvBlock6),
            (self.upv7, self.spade7, self.ConvBlock7),
            (self.upv8, self.spade8, self.ConvBlock8),
            (self.upv9, self.spade9, self.ConvBlock9),
        )
        # Skips are consumed deepest-first; decoder stage i uses guide_fea[i].
        for guide_idx, (upsample, modulate, conv_block) in enumerate(decoder):
            x = torch.cat([upsample(x), skips.pop()], 1)
            x = modulate(guide_fea[guide_idx], x)
            x = conv_block(x)

        return self.conv10(x)


class PIP2(nn.Module):
    """U-Net-style image network whose decoder is modulated by guide features.

    Same layout as a 4-level encoder/decoder with ``Spade`` modulation at
    each decoder stage. In this variant the last stage's guide
    (``guide_fea[3]``) has ``4 * dim_r`` channels, not ``2 * dim_r``.

    Encoder: ``ConvBlock1`` followed by ``block`` stages 2-5. Each of stages
    1-4 is followed by a learned 4x4 / stride-2 conv that halves H and W.

    Decoder: stages 6-9. Each one upsamples 2x, concatenates the matching
    encoder output, modulates it with ``Spade``, and refines with ``block``.
    ``conv10`` maps the result to 3 channels.

    Args:
        block: module class for stages 2-9, called as
            ``block(in_ch, out_ch, strides=1)``.
        dim: base channel width of the network.
        dim_r: base channel width of the guide features.
    """

    def __init__(self, block=SFB, dim=32, dim_r=32):
        super().__init__()

        # Stem: masked image + second image (6 channels) -> dim channels.
        self.inpainting = mask_img_input(dim)

        self.dim = dim
        self.dim_r = dim_r

        # Encoder; channel widths dim -> 2dim -> 4dim -> 8dim -> 16dim.
        self.ConvBlock1 = ConvBlock(dim, dim, strides=1)
        self.pool1 = nn.Conv2d(dim, dim, 4, 2, 1)

        self.ConvBlock2 = block(dim, 2 * dim, strides=1)
        self.pool2 = nn.Conv2d(2 * dim, 2 * dim, 4, 2, 1)

        self.ConvBlock3 = block(2 * dim, 4 * dim, strides=1)
        self.pool3 = nn.Conv2d(4 * dim, 4 * dim, 4, 2, 1)

        self.ConvBlock4 = block(4 * dim, 8 * dim, strides=1)
        self.pool4 = nn.Conv2d(8 * dim, 8 * dim, 4, 2, 1)

        self.ConvBlock5 = block(8 * dim, 16 * dim, strides=1)

        # Decoder. Concatenating the skip doubles the channels, and Spade
        # halves them back before the stage's block.
        self.upv6 = nn.ConvTranspose2d(16 * dim, 8 * dim, kernel_size=2, stride=2)
        self.spade6 = Spade(16 * dim_r, 16 * dim)
        self.ConvBlock6 = block(8 * dim, 8 * dim, strides=1)

        self.upv7 = nn.ConvTranspose2d(8 * dim, 4 * dim, kernel_size=2, stride=2)
        self.spade7 = Spade(8 * dim_r, 8 * dim)
        self.ConvBlock7 = block(4 * dim, 4 * dim, strides=1)

        self.upv8 = nn.ConvTranspose2d(4 * dim, 2 * dim, kernel_size=2, stride=2)
        self.spade8 = Spade(4 * dim_r, 4 * dim)
        self.ConvBlock8 = block(2 * dim, 2 * dim, strides=1)

        self.upv9 = nn.ConvTranspose2d(2 * dim, dim, kernel_size=2, stride=2)
        # Differs from the 2*dim_r-guide variant: the last guide has 4*dim_r channels.
        self.spade9 = Spade(4 * dim_r, 2 * dim)
        self.ConvBlock9 = block(dim, dim, strides=1)

        self.conv10 = nn.Conv2d(dim, 3, 3, 1, 1)

    def forward(self, input, guide_fea, flare):
        """Run the network.

        Args:
            input: (B, 3, H, W) image. H and W must be multiples of 16 for the
                skip concatenations to line up.
            guide_fea: indexable with at least four entries. Entry ``i`` is the
                guide for decoder stage ``6 + i``, with 16, 8, 4 and 4 times
                ``dim_r`` channels respectively.
            flare: image whose first three channels (presumably RGB — confirm)
                are combined into a luma map to build the mask.

        Returns:
            (B, 3, H, W) output of ``conv10``.
        """
        luma = flare[:, 0] * 0.299 + flare[:, 1] * 0.587 + flare[:, 2] * 0.114
        # 0 where luma > 0.12 (bright flare), 1 elsewhere (NaN counts as not bright).
        keep_mask = (~(luma > 0.12)).float()
        # NOTE(review): ``input`` (not ``flare``) is passed as the third argument,
        # so the stem sees the masked image concatenated with the unmasked
        # image. Confirm this is intended.
        x = self.inpainting(input, keep_mask, input)

        encoder = (
            (self.ConvBlock1, self.pool1),
            (self.ConvBlock2, self.pool2),
            (self.ConvBlock3, self.pool3),
            (self.ConvBlock4, self.pool4),
        )
        skips = []
        for conv_block, pool in encoder:
            x = conv_block(x)
            skips.append(x)  # pre-pool output, reused by the decoder
            x = pool(x)
        x = self.ConvBlock5(x)

        decoder = (
            (self.upv6, self.spade6, self.ConvBlock6),
            (self.upv7, self.spade7, self.ConvBlock7),
            (self.upv8, self.spade8, self.ConvBlock8),
            (self.upv9, self.spade9, self.ConvBlock9),
        )
        # Skips are consumed deepest-first; decoder stage i uses guide_fea[i].
        for guide_idx, (upsample, modulate, conv_block) in enumerate(decoder):
            x = torch.cat([upsample(x), skips.pop()], 1)
            x = modulate(guide_fea[guide_idx], x)
            x = conv_block(x)

        return self.conv10(x)
