import torch.nn as nn

from .ResidualBlocks import ResidualBlock2dConv


def make_res_block_feature_extractor(in_channels, out_channels, kernelsize, stride, padding, dilation, a_val=2.0, b_val=0.3):
    """Build one residual feature-extraction block wrapped in ``nn.Sequential``.

    The shortcut handed to ``ResidualBlock2dConv`` as ``downsample`` is a
    projection: a ``Conv2d`` using the same kernel size, stride, padding and
    dilation as the block, followed by ``BatchNorm2d(out_channels)``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernelsize: convolution kernel size (block and projection).
        stride: convolution stride (block and projection).
        padding: convolution padding (block and projection).
        dilation: convolution dilation (block and projection).
        a_val: forwarded to ``ResidualBlock2dConv`` as ``a``.
        b_val: forwarded to ``ResidualBlock2dConv`` as ``b``.

    Returns:
        ``nn.Sequential`` containing a single ``ResidualBlock2dConv``.
    """
    # The former guard ``(stride != 2) or (in_channels != out_channels)`` left
    # the shortcut as ``None`` in exactly one configuration: stride == 2 with an
    # unchanged channel count. A stride-2 convolution changes the spatial size,
    # so an un-projected residual cannot be added to the block output there.
    # NOTE(review): assumes ResidualBlock2dConv uses its raw input as the
    # residual when ``downsample`` is None -- confirm in ResidualBlocks.py.
    # Every other configuration already received this projection, so building
    # it unconditionally only changes the previously-broken case.
    downsample = nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernelsize,
            padding=padding,
            stride=stride,
            dilation=dilation,
        ),
        nn.BatchNorm2d(out_channels),
    )
    block = ResidualBlock2dConv(
        in_channels, out_channels, kernelsize, stride, padding, dilation, downsample, a=a_val, b=b_val
    )
    return nn.Sequential(block)


def make_res_layers_feature_extractor(args, a=1.0, b=1.0):
    """Stack ``args.num_layers_img`` residual feature-extraction blocks.

    Block ``k`` (0-based) maps ``(k + 1) * args.DIM_img`` input channels to
    ``min(k + 2, args.num_layers_img) * args.DIM_img`` output channels. The
    width therefore grows by ``args.DIM_img`` per block, and the final block
    keeps its width. Kernel size, stride and padding are read from
    ``args.kernelsize_enc_img``, ``args.enc_stride_img`` and
    ``args.enc_padding_img``; dilation is fixed at 1.

    Args:
        args: configuration object exposing the attributes named above.
        a: forwarded to every block as ``a_val``.
        b: forwarded to every block as ``b_val``.

    Returns:
        ``nn.Sequential`` of the blocks in order (empty when
        ``args.num_layers_img`` is 0).
    """
    n_layers = args.num_layers_img

    def _block(idx):
        # Build the idx-th block with its input/output channel counts.
        width = args.DIM_img
        return make_res_block_feature_extractor(
            (idx + 1) * width,
            min(idx + 2, n_layers) * width,
            kernelsize=args.kernelsize_enc_img,
            stride=args.enc_stride_img,
            padding=args.enc_padding_img,
            dilation=1,
            a_val=a,
            b_val=b,
        )

    return nn.Sequential(*[_block(idx) for idx in range(n_layers)])


class FeatureExtractorImg(nn.Module):
    """Image feature extractor: a stem convolution followed by four residual blocks.

    ``conv1`` maps ``image_channels`` to ``DIM_img`` channels (no bias).
    ``resblock1``..``resblock4`` then widen the channels to 2x, 3x, 4x and 5x
    ``DIM_img``. Each uses kernel size 4 and stride 2; the first three use
    padding 1 and the last uses padding 0.

    Args:
        a: forwarded to every residual block as ``a_val``; stored as ``self.a``.
        b: forwarded to every residual block as ``b_val``; stored as ``self.b``.
        image_channels: number of channels in the input image.
        DIM_img: base channel width.
        kernelsize_enc_img: kernel size of ``conv1``.
        enc_stride_img: stride of ``conv1``.
        enc_padding_img: padding of ``conv1``.
    """

    def __init__(self, a, b, image_channels=3, DIM_img=128, kernelsize_enc_img=3, enc_stride_img=2, enc_padding_img=1):
        super().__init__()
        self.a = a
        self.b = b
        self.conv1 = nn.Conv2d(
            image_channels,
            DIM_img,
            kernel_size=kernelsize_enc_img,
            stride=enc_stride_img,
            padding=enc_padding_img,
            dilation=1,
            bias=False,
        )
        # (attribute name, input width multiplier, output width multiplier, padding).
        # Iterated in order so the submodules are created and registered as
        # resblock1, resblock2, resblock3, resblock4.
        block_specs = (
            ("resblock1", 1, 2, 1),
            ("resblock2", 2, 3, 1),
            ("resblock3", 3, 4, 1),
            ("resblock4", 4, 5, 0),
        )
        for attr_name, mult_in, mult_out, pad in block_specs:
            # nn.Module.__setattr__ registers the assigned module as a child.
            setattr(
                self,
                attr_name,
                make_res_block_feature_extractor(
                    mult_in * DIM_img,
                    mult_out * DIM_img,
                    kernelsize=4,
                    stride=2,
                    padding=pad,
                    dilation=1,
                    a_val=self.a,
                    b_val=self.b,
                ),
            )

    def forward(self, x):
        """Apply ``conv1`` and then ``resblock1``..``resblock4`` in sequence."""
        features = self.conv1(x)
        for stage in (self.resblock1, self.resblock2, self.resblock3, self.resblock4):
            features = stage(features)
        return features
