import torch
from torch import nn
import torch.nn.functional as F


class Enhance_Net(nn.Module):
    '''
    Residual block built from two ``Enhance_Unit`` stages, each followed by a
    plain convolution.

    Data flow in ``forward``:
    unit1 -> LeakyReLU -> conv1 -> LeakyReLU -> unit2 -> LeakyReLU -> conv2,
    after which the block input is added back.

    NOTE(review): both units and both convolutions are created from the same
    ``channel_in`` / ``channel_out`` pair and the result is summed with the
    input, so the block looks like it assumes ``channel_in == channel_out``
    -- confirm with callers.
    '''

    def __init__(self, channel_in, channel_out, aux_channel, kernel_size, reduction):
        super().__init__()

        # Both units share one configuration; they are still separate modules
        # with their own parameters.
        unit_args = (channel_in, channel_out, aux_channel, kernel_size, reduction)
        self.Enhance_Unit1 = Enhance_Unit(*unit_args)
        self.Enhance_Unit2 = Enhance_Unit(*unit_args)

        conv_args = (channel_in, channel_out, kernel_size)
        self.conv1 = default_conv(*conv_args)
        self.conv2 = default_conv(*conv_args)

        # In-place activation, reused after every stage except the last conv.
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x, y):
        '''
        :param x: main feature map (B * C * H * W according to the original notes)
        :param y: auxiliary input, handed unchanged to both enhance units
        :return: output of the second convolution plus ``x``
        '''
        stage = self.Enhance_Unit1(x, y)
        stage = self.relu(stage)

        stage = self.conv1(stage)
        stage = self.relu(stage)

        stage = self.Enhance_Unit2(stage, y)
        stage = self.relu(stage)

        # No activation after the last convolution; close the skip connection.
        return self.conv2(stage) + x


class Enhance_Unit(nn.Module):
    '''
    Two-branch unit driven by an auxiliary input ``y``.

    Branch 1 filters every channel of every sample with its own kernel
    predicted from ``y`` by ``Kernel_Gen``, then mixes channels with a 1x1
    convolution. Branch 2 rescales ``x`` with ``CA_layer``. The two branch
    outputs are summed.

    NOTE: ``channels_out`` is accepted for signature compatibility but is not
    used; the 1x1 convolution maps ``channels_in`` to ``channels_in``.
    '''

    def __init__(self, channels_in, channels_out, aux_channel, kernel_size, reduction):
        super(Enhance_Unit, self).__init__()

        self.kernel_gen = Kernel_Gen(aux_channel, channels_in, kernel_size)
        self.conv = default_conv(channels_in, channels_in, 1)
        self.ca = CA_layer(channels_in, aux_channel, reduction)
        self.relu = nn.LeakyReLU(0.1, True)
        self.kernel_size = kernel_size

    def forward(self, x, y):
        '''
        :param x: main feature map: B * C * H * W
        :param y: auxiliary input forwarded to the kernel generator and to
                  the attention branch
        :return: sum of the dynamic-convolution branch and the attention branch
        '''
        b, c, h, w = x.size()

        # branch 1
        kernel = self.kernel_gen(y)
        # Fold the batch into the channel axis so one grouped convolution
        # (groups = B*C) applies a separate kernel to each sample/channel pair.
        # reshape() instead of view(): view() raises on non-contiguous input
        # (e.g. permuted or channels_last tensors); reshape() is identical for
        # contiguous tensors and copies only when it has to.
        folded = x.reshape(1, -1, h, w)
        out = self.relu(F.conv2d(folded, kernel, groups=b * c, padding=(self.kernel_size - 1) // 2))
        # Unfold back to B * C * H * W before the 1x1 channel-mixing conv.
        out = self.conv(out.reshape(b, -1, h, w))

        # branch 2
        out = out + self.ca(x, y)

        return out


class CA_layer(nn.Module):
    '''
    Attention branch: maps the auxiliary input through two bias-free 1x1
    convolutions (hidden width ``256 // reduction``) and a sigmoid, then
    multiplies the feature map by the resulting gate in (0, 1).
    '''

    def __init__(self, channels_in, aux_channel, reduction):
        super(CA_layer, self).__init__()
        self.conv_du = nn.Sequential(
            nn.Conv2d(aux_channel, 256//reduction, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(256 // reduction, channels_in, 1, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x, y):
        '''
        :param x: feature map: B * C * H * W
        :param y: auxiliary representation, either a vector per sample
                  (B * aux_channel) or a map (B * aux_channel * H' * W') whose
                  spatial size must broadcast against ``x``
        :return: ``x`` scaled element-wise by the attention gate
        '''
        # Conv2d rejects 2-D input, so a per-sample vector is expanded to a
        # 1x1 map; the resulting gate then broadcasts over H and W.
        if y.dim() == 2:
            y = y[:, :, None, None]

        att = self.conv_du(y)

        return x * att


class Kernel_Gen(nn.Module):
    '''
    Predicts one ``kernel_size`` x ``kernel_size`` filter per sample and per
    channel from an auxiliary input.

    Pipeline: convolutional encoder ``E`` with global average pooling
    (256 features per sample) -> ``mlp`` -> bias-free ``kernel`` head that
    emits ``channels_in * kernel_size**2`` values per sample.
    '''

    def __init__(self, aux_channel, channels_in, kernel_size):
        super().__init__()

        taps_per_sample = channels_in * kernel_size * kernel_size

        # Encoder: a stride-2 3x3 conv, a second 3x3 conv, then global pooling.
        self.E = nn.Sequential(
            nn.Conv2d(aux_channel, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Linear(256, 256),
        )
        # Kernel head: no bias terms, so the taps are a pure function of the
        # embedding weights.
        self.kernel = nn.Sequential(
            nn.Linear(256, 128, bias=False),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Linear(128, taps_per_sample, bias=False),
        )
        self.kernel_size = kernel_size

    def forward(self, x):
        '''
        :param x: auxiliary input, B * aux_channel * H * W (4-D is required
                  by the BatchNorm2d layers in the encoder)
        :return: kernels shaped (B * channels_in) * 1 * k * k
        '''
        k = self.kernel_size

        pooled = self.E(x)              # B * 256 * 1 * 1
        embedding = pooled.flatten(1)   # B * 256
        embedding = self.mlp(embedding)
        taps = self.kernel(embedding)   # B * (channels_in * k * k)

        # One single-input-channel filter per (sample, channel) pair.
        return taps.view(-1, 1, k, k)


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    '''
    Build a stride-1 ``nn.Conv2d`` padded by ``kernel_size // 2`` on each side.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: integer side length of the square kernel
    :param bias: whether the convolution has a learnable bias
    :return: the configured ``nn.Conv2d`` module
    '''
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )