import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Residual unit: a conv/norm branch summed with a 1x1 projection of the input.

    The ``main`` branch is conv3x3 -> InstanceNorm -> ReLU -> conv3x3 ->
    InstanceNorm. The ``conv1x1`` shortcut is a bias-free 1x1 conv mapping
    ``dim_in`` to ``dim_out`` channels so the two branches can be added even
    when the widths differ. Spatial size is preserved (stride 1, padding 1 on
    the 3x3 convs, no padding needed on the 1x1).

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels.

    NOTE(review): the InstanceNorm layers use ``track_running_stats=True``,
    so in ``eval()`` mode they normalize with running statistics rather than
    per-instance statistics. Changing that would alter both behavior and the
    state_dict keys, so it is left untouched here.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        def conv_norm(channels_in):
            # One bias-free 3x3 conv followed by an affine InstanceNorm.
            return [
                nn.Conv2d(channels_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
            ]

        # Indices 0..4 of this Sequential define the parameter names
        # (main.0.weight, main.1.running_mean, ...), so the layer order matters.
        self.main = nn.Sequential(
            *conv_norm(dim_in),
            nn.ReLU(inplace=True),
            *conv_norm(dim_out),
        )
        # Projection shortcut, applied even when dim_in == dim_out.
        self.conv1x1 = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        shortcut = self.conv1x1(x)
        residual = self.main(x)
        return shortcut + residual


def upsample(ch_coarse, ch_fine):
    """Build a 2x spatial upsampler: bias-free transposed conv followed by ReLU.

    With kernel 4, stride 2 and padding 1, ``ConvTranspose2d`` maps an
    ``H x W`` feature map to exactly ``2H x 2W``.

    Args:
        ch_coarse: channels of the low-resolution input.
        ch_fine: channels of the upsampled output.

    Returns:
        ``nn.Sequential`` with the transposed conv at index 0 and ReLU at index 1.
    """
    deconv = nn.ConvTranspose2d(
        ch_coarse, ch_fine, kernel_size=4, stride=2, padding=1, bias=False
    )
    activation = nn.ReLU()
    return nn.Sequential(deconv, activation)


class ComGenerator(nn.Module):
    """U-Net style image generator built from residual blocks.

    Encoder: a 3x3 conv followed by four (2x max-pool, ResidualBlock) stages.
    The channel width doubles at each stage, from ``dim_out`` to
    ``16 * dim_out``.

    Decoder: mirrors the encoder. Each stage upsamples 2x with a transposed
    conv and concatenates the matching encoder feature map along the channel
    axis. It then fuses the result with a ResidualBlock that halves the width.

    Output head: a final 3x3 conv + Tanh maps back to ``out_channels``
    channels at the input resolution.

    Args:
        dim_in: number of channels of the input image.
        dim_out: base channel width of the first encoder stage.
        isJPEG: stored as ``self.isJPEG``; not read anywhere in this class.
            NOTE(review): presumably consumed by external training/eval code
            — confirm before removing.
        out_channels: number of channels of the generated image. Defaults to
            3, the value that was previously hard-coded.
    """

    def __init__(self, dim_in=3, dim_out=32, isJPEG=False, out_channels=3):
        super(ComGenerator, self).__init__()
        self.isJPEG = isJPEG

        # The single-element nn.Sequential wrappers below determine parameter
        # names (e.g. ``conv2.0.main.0.weight``). Removing them would break
        # loading previously saved state_dicts, so they are kept as-is, in
        # the original construction order.

        # Encoder: dim_out -> 2x -> 4x -> 8x -> 16x channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
        )
        self.conv2 = nn.Sequential(
            ResidualBlock(dim_out, dim_out*2),
        )
        self.conv3 = nn.Sequential(
            ResidualBlock(dim_out*2, dim_out*4),
        )
        self.conv4 = nn.Sequential(
            ResidualBlock(dim_out*4, dim_out*8),
        )
        self.conv5 = nn.Sequential(
            ResidualBlock(dim_out*8, dim_out*16),
        )

        # Decoder fusion blocks. Each input is (upsampled, skip) concatenated,
        # hence twice the width of its output.
        self.conv4m = nn.Sequential(
            ResidualBlock(dim_out*16, dim_out*8),
        )
        self.conv3m = nn.Sequential(
            ResidualBlock(dim_out*8, dim_out*4),
        )
        self.conv2m = nn.Sequential(
            ResidualBlock(dim_out*4, dim_out*2),
        )
        self.conv1m = nn.Sequential(
            ResidualBlock(dim_out*2, dim_out),
        )

        # Output head; Tanh bounds the result to [-1, 1].
        self.conv0 = nn.Sequential(
            nn.Conv2d(dim_out, out_channels, 3, 1, 1),
            nn.Tanh()
        )

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample54 = upsample(dim_out*16, dim_out*8)
        self.upsample43 = upsample(dim_out*8, dim_out*4)
        self.upsample32 = upsample(dim_out*4, dim_out*2)
        self.upsample21 = upsample(dim_out*2, dim_out)

    def forward(self, x):
        """Generate an image from ``x``.

        Args:
            x: 4-D input batch with ``dim_in`` channels on axis 1 (the skip
                connections concatenate along axis 1). Height and width must
                both be divisible by 16.

        Returns:
            Tensor with ``out_channels`` channels and the same batch size,
            height and width as ``x``. Values lie in [-1, 1] (Tanh output).

        Raises:
            ValueError: if ``x`` is not 4-D or its height/width is not a
                multiple of 16. Such inputs previously failed deep inside the
                network with an opaque size-mismatch error.
        """
        if x.dim() != 4:
            raise ValueError(
                f"ComGenerator expects a 4-D (N, C, H, W) input, got {x.dim()}-D")
        height, width = x.shape[-2], x.shape[-1]
        # The encoder halves H and W four times (2**4 == 16). Each decoder
        # upsample doubles the size exactly, so the concatenations only line
        # up when both dimensions are divisible by 16.
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(
                f"ComGenerator input height and width must be multiples of 16, "
                f"got {height}x{width}")

        # Encoder path; each stage sees a 2x-downsampled map.
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(self.max_pool(conv1_out))
        conv3_out = self.conv3(self.max_pool(conv2_out))
        conv4_out = self.conv4(self.max_pool(conv3_out))
        conv5_out = self.conv5(self.max_pool(conv4_out))

        # Decoder path: upsample, concatenate the encoder skip on the channel
        # axis, then fuse.
        conv5m_out = torch.cat((self.upsample54(conv5_out), conv4_out), 1)
        conv4m_out = self.conv4m(conv5m_out)
        conv4m_out_ = torch.cat((self.upsample43(conv4m_out), conv3_out), 1)
        conv3m_out = self.conv3m(conv4m_out_)
        conv3m_out_ = torch.cat((self.upsample32(conv3m_out), conv2_out), 1)
        conv2m_out = self.conv2m(conv3m_out_)
        conv2m_out_ = torch.cat((self.upsample21(conv2m_out), conv1_out), 1)
        conv1m_out = self.conv1m(conv2m_out_)
        conv0_out = self.conv0(conv1m_out)

        return conv0_out
