# Residual Dense Network for Image Super-Resolution
# https://arxiv.org/abs/1802.08797
# modified from: https://github.com/thstkdgus35/EDSR-PyTorch

from argparse import Namespace

import torch
import torch.nn as nn

from models import register


class RDB_Conv(nn.Module):
    '''
    Single dense layer of a residual dense block: a convolution followed by
    ReLU whose result is appended to its own input along the channel axis.
    '''

    def __init__(self, inChannels, growRate, kSize=3):
        '''
        Args:
            inChannels: number of channels of the incoming feature map
            growRate: number of new channels produced by this layer
            kSize: convolution kernel size (default 3)
        '''
        super().__init__()
        # With stride 1 this padding keeps H and W unchanged for odd kernel
        # sizes, which the channel-wise concatenation in forward() relies on.
        pad = (kSize - 1) // 2
        self.conv = nn.Sequential(
            nn.Conv2d(inChannels, growRate, kSize, padding=pad, stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        '''
        Args:
            x: shape (N, inChannels, H, W)
        Return:
            output: shape (N, inChannels + growRate, H, W); the first
                inChannels channels are x itself, the rest are the new features
        '''
        # grown: shape (N, growRate, H, W)
        grown = self.conv(x)
        return torch.cat([x, grown], dim=1)

class RDB(nn.Module):
    '''
    Residual dense block: a chain of RDB_Conv layers whose accumulated
    features are fused back to G0 channels by a 1x1 convolution and added to
    the block input (local residual connection).
    '''

    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        '''
        Args:
            growRate0: G0, number of channels entering and leaving the block
            growRate: G, number of channels added by each RDB_Conv layer
            nConvLayers: C, number of RDB_Conv layers in the block
            kSize: kernel size handed to every RDB_Conv layer (default 3)
        '''
        super(RDB, self).__init__()
        G0 = growRate0
        G = growRate
        C = nConvLayers

        # Each RDB_Conv concatenates its G new channels onto its input, so
        # layer c receives G0 + c*G channels. kSize is passed explicitly so
        # the caller's choice is not replaced by RDB_Conv's default.
        self.convs = nn.Sequential(*[
            RDB_Conv(G0 + c * G, G, kSize) for c in range(C)
        ])

        # Local Feature Fusion
        # 1*1 Conv2D to fuse all channel features per pixel
        self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1)

    def forward(self, x):
        '''
        Args:
            x: shape (N, G0, H, W)
        Return:
            output: shape (N, G0, H, W)
        '''
        # dense: shape (N, G0 + C*G, H, W)
        dense = self.convs(x)
        # self.LFF(dense): shape (N, G0, H, W), then local residual add
        return self.LFF(dense) + x

class RDN(nn.Module):
    '''
    Residual Dense Network.

    Structure built here: two shallow feature-extraction convolutions, D
    residual dense blocks applied one after another, global feature fusion
    over the concatenated outputs of all blocks plus a global residual
    connection, and (optionally) a PixelShuffle-based up-sampling head.
    '''

    def __init__(self, args):
        '''
        Args:
            args: object exposing the following attributes
                scale: list, [0]: int, upsample scale; must be 2, 3 or 4 when
                    no_upsampling is False (it is read in either case)
                G0: the feature/channel dim for RDB layers
                RDNkSize: RDN kernel size; spatial size is preserved only for
                    odd values because padding is (kSize-1)//2
                RDNconfig: 'A' or 'B', for (D, C, G)
                    D: number of RDB blocks, applied sequentially; the output
                        of every block is kept and concatenated for global
                        feature fusion
                    C: number of RDB_Conv layers used in each RDB()
                    G: growRate/out channels used in each RDB(), and hidden
                        dim used in the upsampling process

                    'A': (20, 6, 32)
                    'B': (16, 8, 64)
                n_colors: number of color channels of the input images, also
                    used as the output channels for super-resolution
                no_upsampling:
                    True: do not upsample at the end of RDN, so that it
                        serves as an image encoder (out_dim = G0)
                    False: upsample at the end of RDN; this is the
                        super-resolution baseline (out_dim = n_colors)
        Raises:
            KeyError: if args.RDNconfig is not 'A' or 'B'
            ValueError: if upsampling is requested and scale is not 2, 3 or 4
        '''
        super(RDN, self).__init__()
        self.args = args
        r = args.scale[0]
        G0 = args.G0
        kSize = args.RDNkSize
        pad = (kSize - 1) // 2

        # number of RDB blocks, conv layers, out channels
        RDNconfig_dict = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }
        self.D, C, G = RDNconfig_dict[args.RDNconfig]

        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=pad, stride=1)
        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)

        # Residual dense blocks and dense feature fusion.
        # The configured kernel size is passed on explicitly instead of
        # relying on RDB's default value.
        self.RDBs = nn.ModuleList(
            RDB(growRate0=G0, growRate=G, nConvLayers=C, kSize=kSize)
            for _ in range(self.D)
        )

        # Global Feature Fusion: 1x1 conv over the D concatenated block
        # outputs, followed by a kSize conv
        self.GFF = nn.Sequential(*[
            nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
            nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)
        ])

        if args.no_upsampling:
            self.out_dim = G0
        else:
            self.out_dim = args.n_colors
            # Up-sampling net
            if r == 2 or r == 3:
                # One PixelShuffle(r) stage: G*r*r channels -> G channels at
                # r times the resolution
                self.UPNet = nn.Sequential(*[
                    nn.Conv2d(G0, G * r * r, kSize, padding=pad, stride=1),
                    nn.PixelShuffle(r),
                    nn.Conv2d(G, args.n_colors, kSize, padding=pad, stride=1)
                ])
            elif r == 4:
                # Two consecutive x2 PixelShuffle stages give x4 overall
                self.UPNet = nn.Sequential(*[
                    nn.Conv2d(G0, G * 4, kSize, padding=pad, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2d(G, G * 4, kSize, padding=pad, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2d(G, args.n_colors, kSize, padding=pad, stride=1)
                ])
            else:
                raise ValueError("scale must be 2 or 3 or 4.")

    def forward(self, x):
        '''
        Args:
            x: shape (N, n_colors, H, W)
        Return:
            output:
                args.no_upsampling == True:
                    shape (N, G0, H, W)
                args.no_upsampling == False:
                    shape (N, n_colors, r * H, r * W)
        '''
        # f__1: shape (N, G0, H, W); kept for the global residual connection
        f__1 = self.SFENet1(x)
        # x:    shape (N, G0, H, W)
        x = self.SFENet2(f__1)

        # RDBs_out: a list (self.D) of shape (N, G0, H, W); each block
        # consumes the previous block's output
        RDBs_out = []
        for block in self.RDBs:
            x = block(x)
            RDBs_out.append(x)

        # torch.cat(RDBs_out, 1): shape (N, self.D * G0, H, W)
        # x: shape (N, G0, H, W)
        x = self.GFF(torch.cat(RDBs_out, 1))
        x += f__1

        if self.args.no_upsampling:
            return x
        # self.UPNet(x): shape (N, n_colors, r * H, r * W)
        return self.UPNet(x)


# Factory for the RDN baseline used to upsample images
@register('rdn')
def make_rdn(G0=64, RDNkSize=3, RDNconfig='B',
             scale=2, no_upsampling=False, n_colors=3):
    '''
    Build an RDN from plain keyword arguments.

    The arguments are packed into a Namespace carrying the attributes that
    RDN.__init__ reads. The function is decorated with register('rdn') from
    the project's `models` package (what registration does is not visible in
    this file).

    Args:
        G0: the feature/channel dim for RDB layers
        RDNkSize: RDN kernel size
        RDNconfig: 'A' or 'B'
        scale: int, upsample scale
        no_upsampling: if True, RDN is built without the up-sampling net
        n_colors: number of color channels of the input images
    Return:
        the constructed RDN module
    '''
    args = Namespace(
        G0=G0,
        RDNkSize=RDNkSize,
        RDNconfig=RDNconfig,
        # RDN reads the factor as args.scale[0], so wrap the scalar in a list
        scale=[scale],
        no_upsampling=no_upsampling,
        n_colors=n_colors,
    )
    return RDN(args)
