import torch
import torch.nn as nn
import argparse


def init_UNet(new_args=None):
    '''
    Build a UNet from a set of default hyper-parameters.

    The defaults are declared through an argparse parser that is parsed
    against an empty argument list, so the process command line is never
    read. Entries of ``new_args`` are then written onto the resulting
    namespace, replacing a default of the same name or adding a new
    attribute.

    new_args: optional mapping (anything exposing ``.items()``) from
        attribute name to value.

    Returns the constructed UNet, moved to the GPU when
    ``torch.cuda.is_available()`` reports one.
    '''
    # (flag, default, help) in declaration order.
    option_table = (
        ('--kernel_size', 3, None),
        ('--padding', 1, None),
        ('--num_kernels', 64, None),
        ('--RF', 90, None),
        ('--num_channels', 1, None),
        ('--bias', False, None),
        # architecture variables UNet
        ('--num_blocks', 3, None),
        ('--num_enc_conv', 2, 'min is 3'),
        ('--num_mid_conv', 2, 'min is 3'),
        ('--num_dec_conv', 2, 'min is 3'),
        ('--pool_window', 2, 'min is 2'),
    )

    parser = argparse.ArgumentParser(description='UNet')
    for flag, default_value, help_text in option_table:
        parser.add_argument(flag, default=default_value, help=help_text)

    # Empty argument list: only the declared defaults are used.
    args = parser.parse_args([])

    if new_args is not None:
        # Caller-supplied values take precedence over the defaults.
        vars(args).update(new_args.items())

    model = UNet(args)
    return model.cuda() if torch.cuda.is_available() else model


class UNet(nn.Module):
    '''
    Encoder / mid-block / decoder convolutional network with skip
    connections.

    The constructor reads these attributes from ``args``: ``kernel_size``,
    ``padding``, ``num_kernels``, ``num_channels``, ``bias``,
    ``num_blocks``, ``num_enc_conv``, ``num_mid_conv``, ``num_dec_conv``
    and ``pool_window``.

    Encoder block ``b`` outputs ``num_kernels * 2**b`` channels, the
    mid-block outputs ``num_kernels * 2**num_blocks`` channels, and each
    decoder block halves the channel count again after concatenating the
    matching encoder output. The last decoder block maps back to
    ``num_channels`` channels.
    '''

    def __init__(self, args):
        '''
        Build all encoder, mid and decoder layers.

        Raises ValueError when ``args.num_blocks`` is below 1, or when
        ``args.num_dec_conv`` is too small for a decoder block (see
        ``init_decoder_block``).
        '''
        super(UNet, self).__init__()

        # The mid-block widths are derived from the deepest encoder block,
        # so at least one block is required.
        if args.num_blocks < 1:
            raise ValueError(
                'num_blocks must be at least 1, got {!r}'.format(
                    args.num_blocks))

        self.pool_window = args.pool_window
        self.num_blocks = args.num_blocks
        # Encoder
        self.encoder = nn.ModuleDict([])
        for b in range(self.num_blocks):
            self.encoder[str(b)] = self.init_encoder_block(b, args)

        # Mid-layers: widths are computed from the index of the deepest
        # encoder block instead of relying on a leftover loop variable.
        deepest = self.num_blocks - 1
        enc_width = args.num_kernels * (2 ** deepest)
        mid_width = args.num_kernels * (2 ** (deepest + 1))
        mid_layers = []
        for l in range(args.num_mid_conv):
            in_width = enc_width if l == 0 else mid_width
            mid_layers.append(nn.Conv2d(in_width, mid_width, args.kernel_size,
                                        padding=args.padding, bias=args.bias))
            mid_layers.append(BF_batchNorm(mid_width))
            mid_layers.append(nn.ReLU(inplace=True))

        self.mid_block = nn.Sequential(*mid_layers)

        # Decoder (built from the deepest block up to block 0)
        self.decoder = nn.ModuleDict([])
        self.upsample = nn.ModuleDict([])
        for b in range(self.num_blocks - 1, -1, -1):
            self.upsample[str(b)], self.decoder[str(
                b)] = self.init_decoder_block(b, args)

    def forward(self, x):
        '''
        Run the network on ``x``.

        x: 4-D tensor with channels on dim 1 (the skip concatenation is
            along ``dim=1`` and the normalisation layers reduce over dims
            0, 2 and 3).

        Returns a tensor with ``num_channels`` channels.

        NOTE(review): with the default ``pool_window=2`` the spatial sizes
        presumably need to be divisible by ``2**num_blocks``, otherwise the
        upsampled tensor and the stored encoder output differ in size and
        ``torch.cat`` fails -- confirm against callers.
        '''
        pool = nn.AvgPool2d(kernel_size=self.pool_window,
                            stride=2, padding=int((self.pool_window-1)/2))
        # Encoder: keep every pre-pooling activation for the skip connections
        unpooled = []
        for b in range(self.num_blocks):
            x_unpooled = self.encoder[str(b)](x)
            x = pool(x_unpooled)
            unpooled.append(x_unpooled)

        # Mid-layers
        x = self.mid_block(x)

        # Decoder: upsample, concatenate the matching encoder output, convolve
        for b in range(self.num_blocks - 1, -1, -1):
            x = self.upsample[str(b)](x)
            x = torch.cat([x, unpooled[b]], dim=1)
            x = self.decoder[str(b)](x)

        return x

    def init_encoder_block(self, b, args):
        '''
        Build encoder block ``b`` as an ``nn.Sequential``.

        Block 0 starts with a convolution from ``num_channels`` followed
        directly by a ReLU (no normalisation); every other convolution in
        any block is followed by BF_batchNorm and a ReLU.
        '''
        width = args.num_kernels * (2 ** b)
        enc_layers = []

        if b == 0:
            # Input convolution: the only one without normalisation.
            enc_layers.append(nn.Conv2d(args.num_channels, width,
                                        args.kernel_size,
                                        padding=args.padding, bias=args.bias))
            enc_layers.append(nn.ReLU(inplace=True))
            first = 1
        else:
            first = 0

        for l in range(first, args.num_enc_conv):
            # Only the first convolution of a deeper block changes the width
            # (it doubles the previous block's channel count).
            in_width = args.num_kernels * (2 ** (b - 1)) if l == 0 else width
            enc_layers.append(nn.Conv2d(in_width, width, args.kernel_size,
                                        padding=args.padding, bias=args.bias))
            enc_layers.append(BF_batchNorm(width))
            enc_layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*enc_layers)

    def init_decoder_block(self, b, args):
        '''
        Build decoder block ``b``.

        Returns ``(upsample, block)``: a stride-2 transposed convolution
        that halves the channel count, and an ``nn.Sequential`` applied to
        the concatenation of the upsampled tensor and the encoder output.
        Block 0 ends with a convolution to ``num_channels`` channels that
        counts towards ``num_dec_conv``.

        Raises ValueError when ``args.num_dec_conv`` is below 2 for block 0
        or below 1 for any other block, since the block would then have no
        convolution able to consume the concatenated input.
        '''
        min_convs = 2 if b == 0 else 1
        if args.num_dec_conv < min_convs:
            raise ValueError(
                'num_dec_conv must be at least {} for decoder block {}, '
                'got {!r}'.format(min_convs, b, args.num_dec_conv))

        width = args.num_kernels * (2 ** b)
        # Block 0 reserves one of its convolutions for the output layer.
        num_hidden = args.num_dec_conv - 1 if b == 0 else args.num_dec_conv

        # Created before the convolutions so that the order in which layers
        # are constructed (and random weights drawn) stays unchanged.
        upsample = nn.ConvTranspose2d(
            2 * width, width, kernel_size=2, stride=2, bias=False)

        dec_layers = []
        for l in range(num_hidden):
            if l == 0:
                # Consumes the concatenation, hence twice the width.
                dec_layers.append(nn.Conv2d(2 * width, width,
                                            kernel_size=args.kernel_size,
                                            padding=args.padding, bias=False))
            else:
                dec_layers.append(nn.Conv2d(width, width, args.kernel_size,
                                            padding=args.padding,
                                            bias=args.bias))
            dec_layers.append(BF_batchNorm(width))
            dec_layers.append(nn.ReLU(inplace=True))

        if b == 0:
            # Output layer: no normalisation or activation after it.
            dec_layers.append(nn.Conv2d(width, args.num_channels,
                                        kernel_size=args.kernel_size,
                                        padding=args.padding, bias=False))

        return upsample, nn.Sequential(*dec_layers)


class BF_batchNorm(nn.Module):
    '''
    Per-channel scaling layer without mean subtraction or additive offset.

    In training mode the input is divided by its per-channel standard
    deviation, computed over dims 0, 2 and 3, and a running average of
    that standard deviation is stored in the ``running_sd`` buffer. In
    eval mode the input is divided by ``running_sd`` instead. In both
    modes the result is multiplied by the learnable ``gammas``.
    '''

    def __init__(self, num_kernels):
        '''
        num_kernels: number of channels; ``running_sd`` and ``gammas`` both
            have shape ``(1, num_kernels, 1, 1)``.
        '''
        super(BF_batchNorm, self).__init__()
        self.register_buffer("running_sd", torch.ones(1, num_kernels, 1, 1))
        # Small random gains, clamped to [-0.025, 0.025].
        g = (torch.randn((1, num_kernels, 1, 1))
             * (2./9./64.)).clamp_(-0.025, 0.025)
        self.gammas = nn.Parameter(g, requires_grad=True)

    def forward(self, x):
        '''
        x: 4-D tensor with channels on dim 1.

        Returns a tensor of the same shape as ``x``.
        '''
        if self.training:
            # Batch statistics are only needed (and only computed) while
            # training; eval mode relies on the stored running value.
            sd_x = torch.sqrt(
                x.var(dim=(0, 2, 3), keepdim=True, unbiased=False) + 1e-05)
            x = x / sd_x.expand_as(x)
            with torch.no_grad():
                # Running average: 0.9 * old + 0.1 * current batch value.
                self.running_sd.copy_(
                    (1-.1) * self.running_sd + .1 * sd_x)
        else:
            x = x / self.running_sd.expand_as(x)

        return x * self.gammas.expand_as(x)
