#!/usr/bin/env python3.7

import math

import torch
from torch import nn
from torch import Tensor


def random_weights_init(m):
    """Randomly (re-)initialise the parameters of a single layer, in place.

    Intended to be handed to ``nn.Module.apply``, which calls it once for
    every sub-module of a network.

    * 2D/3D convolutions and transposed convolutions: the weight is drawn
      with Xavier (Glorot) normal initialisation; the bias is left untouched.
    * 2D/3D batch normalisation: the scale is drawn from a normal
      distribution with mean 1.0 and std 0.02, the shift is set to 0.
    * Any other module type is ignored.

    The 3D layer types are included so that volumetric networks built from
    ``nn.Conv3d`` / ``nn.ConvTranspose3d`` are initialised as well instead of
    being silently skipped.

    Args:
        m: the module to initialise.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Conv3d, nn.ConvTranspose3d)
    norm_types = (nn.BatchNorm2d, nn.BatchNorm3d)

    if isinstance(m, conv_types):
        nn.init.xavier_normal_(m.weight.data)
    elif isinstance(m, norm_types):
        # With affine=False the layer has no learnable scale/shift (both are
        # None), so there is nothing to initialise.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)


class Dummy(nn.Module):
    """Tiny 2D network: one strided convolution followed by one transposed convolution.

    The convolution (kernel 2, stride 2) maps ``in_dim`` channels to a fixed
    width of 10; the transposed convolution (kernel 3, stride 2, padding 1,
    output_padding 1) maps those 10 channels to ``out_dim``.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()

        hidden = 10  # channel count between the two layers

        self.down = nn.Conv2d(in_dim, hidden, kernel_size=2, stride=2)
        self.up = nn.ConvTranspose2d(hidden, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, input: Tensor) -> Tensor:
        """Run ``input`` through ``down`` and then ``up``."""
        reduced = self.down(input)
        return self.up(reduced)

    def init_weights(self, *args, **kwargs):
        """Apply ``random_weights_init`` to every sub-module; extra arguments are ignored."""
        self.apply(random_weights_init)


# Second name bound to the very same class.
Dimwit = Dummy


class Dummy3D(nn.Module):
    """Tiny 3D network: one strided convolution followed by one transposed convolution.

    The convolution (kernel 2, stride 2) maps ``in_dim`` channels to a fixed
    width of 10; the transposed convolution (kernel 3, stride 2, padding 1,
    output_padding 1) maps those 10 channels to ``out_dim``.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()

        hidden = 10  # channel count between the two layers

        self.down = nn.Conv3d(in_dim, hidden, kernel_size=2, stride=2)
        self.up = nn.ConvTranspose3d(hidden, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, input: Tensor) -> Tensor:
        """Run ``input`` through ``down`` and then ``up``."""
        reduced = self.down(input)
        return self.up(reduced)

    def init_weights(self, *args, **kwargs):
        """Apply ``random_weights_init`` to every sub-module; extra arguments are ignored."""
        self.apply(random_weights_init)


# Second name bound to the very same class.
Dimwit3D = Dummy3D


class UNet(nn.Module):
    """2D U-Net assembled from the building blocks of ``models.unet``.

    Layout: three encoder stages (widths ``nG``, ``2*nG``, ``4*nG``), a bridge
    of width ``8*nG``, and three decoder stages. Each decoder stage receives
    the up-sampled features concatenated (along the channel axis) with the
    encoder features of the matching level, followed by a 1x1 convolution
    producing ``out_dim`` channels.

    Args:
        in_dim: number of channels of the input.
        out_dim: number of channels of the output.
        nG: base width; every other layer width is a multiple of it.
    """

    def __init__(self, in_dim: int, out_dim: int, nG=64) -> None:
        super().__init__()

        # Function-scope import: the project blocks are only needed when a
        # UNet is actually instantiated.
        from models.unet import (convBatch,
                                 residualConv,
                                 upSampleConv)

        def stage(n_in, n_out, **first_block_kwargs):
            # Two stacked convBatch blocks; only the first one changes the
            # width and optionally receives extra arguments (e.g. stride).
            return nn.Sequential(convBatch(n_in, n_out, **first_block_kwargs),
                                 convBatch(n_out, n_out))

        # Encoder
        self.conv0 = stage(in_dim, nG)
        self.conv1 = stage(nG * 1, nG * 2, stride=2)
        self.conv2 = stage(nG * 2, nG * 4, stride=2)

        # Bottom of the "U"
        self.bridge = nn.Sequential(convBatch(nG * 4, nG * 8, stride=2),
                                    residualConv(nG * 8, nG * 8),
                                    convBatch(nG * 8, nG * 8))

        # Decoder: the input width of each stage is the up-sampled width plus
        # the width of the skip connection it is concatenated with.
        self.deconv1 = upSampleConv(nG * 8, nG * 8)
        self.conv5 = stage(nG * 12, nG * 4)  # 8*nG (deconv1) + 4*nG (conv2)
        self.deconv2 = upSampleConv(nG * 4, nG * 4)
        self.conv6 = stage(nG * 6, nG * 2)  # 4*nG (deconv2) + 2*nG (conv1)
        self.deconv3 = upSampleConv(nG * 2, nG * 2)
        self.conv7 = stage(nG * 3, nG * 1)  # 2*nG (deconv3) + nG (conv0)
        self.final = nn.Conv2d(nG, out_dim, kernel_size=1)

    def forward(self, input):
        """Compute the network output; ``input`` is cast to float first."""
        features = input.float()

        # Contracting path; every level is kept for its skip connection.
        enc0 = self.conv0(features)
        enc1 = self.conv1(enc0)
        enc2 = self.conv2(enc1)

        bottom = self.bridge(enc2)

        # Expanding path: up-sample, concatenate with the skip, convolve.
        dec = self.deconv1(bottom)
        dec = self.conv5(torch.cat((dec, enc2), dim=1))
        dec = self.deconv2(dec)
        dec = self.conv6(torch.cat((dec, enc1), dim=1))
        dec = self.deconv3(dec)
        dec = self.conv7(torch.cat((dec, enc0), dim=1))

        return self.final(dec)

    def init_weights(self, *args, **kwargs):
        """Apply ``random_weights_init`` to every sub-module; extra arguments are ignored."""
        self.apply(random_weights_init)


class UNet3D(nn.Module):
    """3D U-Net assembled from the building blocks of ``models.unet_3d``.

    Layout: three encoder stages (widths ``nG``, ``2*nG``, ``4*nG``), a bridge
    of width ``8*nG``, and three decoder stages. Each decoder stage receives
    the up-sampled features concatenated (along the channel axis) with the
    encoder features of the matching level, followed by a 1x1x1 convolution
    producing ``nout`` channels.

    Args:
        nin: number of channels of the input.
        nout: number of channels of the output.
        nG: base width; every other layer width is a multiple of it.
    """

    def __init__(self, nin: int, nout: int, nG=64):
        super().__init__()

        # Function-scope import: the project blocks are only needed when a
        # UNet3D is actually instantiated.
        from models.unet_3d import (convBatch,
                                    residualConv,
                                    upSampleConv)

        # Encoder
        self.conv0 = nn.Sequential(convBatch(nin, nG),
                                   convBatch(nG, nG))
        self.conv1 = nn.Sequential(convBatch(nG * 1, nG * 2, stride=2),
                                   convBatch(nG * 2, nG * 2))
        self.conv2 = nn.Sequential(convBatch(nG * 2, nG * 4, stride=2),
                                   convBatch(nG * 4, nG * 4))

        # Bottom of the "U"
        self.bridge = nn.Sequential(convBatch(nG * 4, nG * 8, stride=2),
                                    residualConv(nG * 8, nG * 8),
                                    convBatch(nG * 8, nG * 8))

        # Decoder: the input width of each stage is the up-sampled width plus
        # the width of the skip connection it is concatenated with.
        self.deconv1 = upSampleConv(nG * 8, nG * 8)
        self.conv5 = nn.Sequential(convBatch(nG * 12, nG * 4),  # 8*nG + 4*nG
                                   convBatch(nG * 4, nG * 4))
        self.deconv2 = upSampleConv(nG * 4, nG * 4)
        self.conv6 = nn.Sequential(convBatch(nG * 6, nG * 2),  # 4*nG + 2*nG
                                   convBatch(nG * 2, nG * 2))
        self.deconv3 = upSampleConv(nG * 2, nG * 2)
        self.conv7 = nn.Sequential(convBatch(nG * 3, nG * 1),  # 2*nG + nG
                                   convBatch(nG * 1, nG * 1))

        # The output head must be a 3D convolution: nn.Conv2d rejects the 5D
        # (N, C, D, H, W) feature maps of a volumetric network.
        # NOTE(review): assumes the blocks from models.unet_3d are 3D layers,
        # as the module name suggests — confirm against models/unet_3d.py.
        self.final = nn.Conv3d(nG, nout, kernel_size=1)

    def forward(self, input):
        """Compute the network output; ``input`` is cast to float first."""
        input = input.float()

        # Contracting path; every level is kept for its skip connection.
        x0 = self.conv0(input)
        x1 = self.conv1(x0)
        x2 = self.conv2(x1)

        bridge = self.bridge(x2)

        # Expanding path: up-sample, concatenate with the skip, convolve.
        y0 = self.deconv1(bridge)
        y1 = self.deconv2(self.conv5(torch.cat((y0, x2), dim=1)))
        y2 = self.deconv3(self.conv6(torch.cat((y1, x1), dim=1)))
        y3 = self.conv7(torch.cat((y2, x0), dim=1))

        return self.final(y3)

    def init_weights(self, *args, **kwargs):
        """Apply ``random_weights_init`` to every sub-module; extra arguments are ignored."""
        self.apply(random_weights_init)


class ENet(nn.Module):
    """ENet-style 2D encoder/decoder assembled from the blocks of ``models.enet``.

    Structure, as wired in ``forward``:

    * Initial block: a stride-2 convolution and a 2x2 max-pooling of the input,
      concatenated along the channel axis to ``n_kernels`` channels.
    * Three groups of bottleneck blocks. The first block of groups one and two
      also returns pooling indices, which the decoder re-uses for unpooling.
    * Decoder: two max-unpooling stages, each concatenated with encoder
      features of the matching level, then ``upSampleConv`` and a 1x1
      convolution producing ``out_dim`` channels.

    Args:
        in_dim: number of input channels; must satisfy ``0 < in_dim < n_kernels``
            so that the initial block can reach ``n_kernels`` channels.
        out_dim: number of output channels.

    Raises:
        ValueError: if ``in_dim`` is outside the supported range.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.projecting_factor = 4
        self.n_kernels = 16

        if not 0 < in_dim < self.n_kernels:
            raise ValueError(f"in_dim must be in [1, {self.n_kernels - 1}], got {in_dim}")

        from models.enet import (BottleNeckDownSampling,
                                 BottleNeckNormal,
                                 BottleNeckDownSamplingDilatedConv,
                                 BottleNeckNormal_Asym,
                                 BottleNeckDownSamplingDilatedConvLast,
                                 BottleNeckUpSampling,
                                 upSampleConv)

        # Short aliases; the values are exactly the two attributes set above.
        k = self.n_kernels
        pf = self.projecting_factor

        # Initial
        # The conv output is concatenated with the max-pooled input (in_dim
        # channels), and the first bottleneck is built for k channels, so the
        # conv must contribute k - in_dim of them (15 when in_dim == 1).
        self.conv0 = nn.Conv2d(in_dim, k - in_dim, kernel_size=3, stride=2, padding=1)
        self.maxpool0 = nn.MaxPool2d(2, return_indices=True)

        # First group
        self.bottleNeck1_0 = BottleNeckDownSampling(k, pf, k * 4)
        self.bottleNeck1_1 = BottleNeckNormal(k * 4, k * 4, pf, 0.01)
        self.bottleNeck1_2 = BottleNeckNormal(k * 4, k * 4, pf, 0.01)
        self.bottleNeck1_3 = BottleNeckNormal(k * 4, k * 4, pf, 0.01)
        self.bottleNeck1_4 = BottleNeckNormal(k * 4, k * 4, pf, 0.01)

        # Second group
        # NOTE(review): the last argument of the *DilatedConv blocks (2, 4, 8,
        # 16) is presumably the dilation rate — confirm in models/enet.py.
        self.bottleNeck2_0 = BottleNeckDownSampling(k * 4, pf, k * 8)
        self.bottleNeck2_1 = BottleNeckNormal(k * 8, k * 8, pf, 0.1)
        self.bottleNeck2_2 = BottleNeckDownSamplingDilatedConv(k * 8, pf, k * 8, 2)
        self.bottleNeck2_3 = BottleNeckNormal_Asym(k * 8, k * 8, pf, 0.1)
        self.bottleNeck2_4 = BottleNeckDownSamplingDilatedConv(k * 8, pf, k * 8, 4)
        self.bottleNeck2_5 = BottleNeckNormal(k * 8, k * 8, pf, 0.1)
        self.bottleNeck2_6 = BottleNeckDownSamplingDilatedConv(k * 8, pf, k * 8, 8)
        self.bottleNeck2_7 = BottleNeckNormal_Asym(k * 8, k * 8, pf, 0.1)
        self.bottleNeck2_8 = BottleNeckDownSamplingDilatedConv(k * 8, pf, k * 8, 16)

        # Third group
        self.bottleNeck3_1 = BottleNeckNormal(k * 8, k * 8, pf, 0.1)
        self.bottleNeck3_2 = BottleNeckDownSamplingDilatedConv(k * 8, pf, k * 8, 2)
        self.bottleNeck3_3 = BottleNeckNormal_Asym(k * 8, k * 8, pf, 0.1)
        self.bottleNeck3_4 = BottleNeckDownSamplingDilatedConv(k * 8, pf, k * 8, 4)
        self.bottleNeck3_5 = BottleNeckNormal(k * 8, k * 8, pf, 0.1)
        self.bottleNeck3_6 = BottleNeckDownSamplingDilatedConv(k * 8, pf, k * 8, 8)
        self.bottleNeck3_7 = BottleNeckNormal_Asym(k * 8, k * 8, pf, 0.1)
        self.bottleNeck3_8 = BottleNeckDownSamplingDilatedConvLast(k * 8, pf, k * 4, 16)

        # ### Decoding path ####
        # Unpooling 1
        self.unpool_0 = nn.MaxUnpool2d(2)

        self.bottleNeck_Up_1_0 = BottleNeckUpSampling(k * 8, pf, k * 4)
        self.PReLU_Up_1 = nn.PReLU()

        self.bottleNeck_Up_1_1 = BottleNeckNormal(k * 4, k * 4, pf, 0.1)
        self.bottleNeck_Up_1_2 = BottleNeckNormal(k * 4, k, pf, 0.1)

        # Unpooling 2
        self.unpool_1 = nn.MaxUnpool2d(2)
        self.bottleNeck_Up_2_1 = BottleNeckUpSampling(k * 2, pf, k)
        self.bottleNeck_Up_2_2 = BottleNeckNormal(k, k, pf, 0.1)
        self.PReLU_Up_2 = nn.PReLU()

        # Unpooling Last
        self.deconv3 = upSampleConv(k, k)

        # out_025 and out_05 are not used by forward(); they are kept so the
        # set of registered parameters (and thus the state_dict) is unchanged.
        self.out_025 = nn.Conv2d(k * 8, out_dim, kernel_size=3, stride=1, padding=1)
        self.out_05 = nn.Conv2d(k, out_dim, kernel_size=3, stride=1, padding=1)
        self.final = nn.Conv2d(k, out_dim, kernel_size=1)

    def forward(self, input):
        """Compute the network output for ``input``."""
        # Initial block: strided conv and max-pooled input, side by side.
        conv_0 = self.conv0(input)
        maxpool_0, _ = self.maxpool0(input)  # these pooling indices are not needed
        outputInitial = torch.cat((conv_0, maxpool_0), dim=1)

        # First group (indices_1 is re-used by the second unpooling)
        bn1_0, indices_1 = self.bottleNeck1_0(outputInitial)
        bn1_1 = self.bottleNeck1_1(bn1_0)
        bn1_2 = self.bottleNeck1_2(bn1_1)
        bn1_3 = self.bottleNeck1_3(bn1_2)
        bn1_4 = self.bottleNeck1_4(bn1_3)

        # Second group (indices_2 is re-used by the first unpooling)
        bn2_0, indices_2 = self.bottleNeck2_0(bn1_4)
        bn2_1 = self.bottleNeck2_1(bn2_0)
        bn2_2 = self.bottleNeck2_2(bn2_1)
        bn2_3 = self.bottleNeck2_3(bn2_2)
        bn2_4 = self.bottleNeck2_4(bn2_3)
        bn2_5 = self.bottleNeck2_5(bn2_4)
        bn2_6 = self.bottleNeck2_6(bn2_5)
        bn2_7 = self.bottleNeck2_7(bn2_6)
        bn2_8 = self.bottleNeck2_8(bn2_7)

        # Third group
        bn3_1 = self.bottleNeck3_1(bn2_8)
        bn3_2 = self.bottleNeck3_2(bn3_1)
        bn3_3 = self.bottleNeck3_3(bn3_2)
        bn3_4 = self.bottleNeck3_4(bn3_3)
        bn3_5 = self.bottleNeck3_5(bn3_4)
        bn3_6 = self.bottleNeck3_6(bn3_5)
        bn3_7 = self.bottleNeck3_7(bn3_6)
        bn3_8 = self.bottleNeck3_8(bn3_7)

        # #### Deconvolution Path ####
        #  First block #
        unpool_0 = self.unpool_0(bn3_8, indices_2)

        # The up-sampling bottleneck sees the unpooled features concatenated
        # with the output of the first encoder group.
        bn_up_1_0 = self.bottleNeck_Up_1_0(torch.cat((unpool_0, bn1_4), dim=1))

        up_block_1 = self.PReLU_Up_1(unpool_0 + bn_up_1_0)

        bn_up_1_1 = self.bottleNeck_Up_1_1(up_block_1)
        bn_up_1_2 = self.bottleNeck_Up_1_2(bn_up_1_1)

        #  Second block #
        unpool_1 = self.unpool_1(bn_up_1_2, indices_1)

        # Same scheme, this time concatenating with the initial block output.
        bn_up_2_1 = self.bottleNeck_Up_2_1(torch.cat((unpool_1, outputInitial), dim=1))

        bn_up_2_2 = self.bottleNeck_Up_2_2(bn_up_2_1)

        up_block_2 = self.PReLU_Up_2(unpool_1 + bn_up_2_2)

        unpool_12 = self.deconv3(up_block_2)

        return self.final(unpool_12)

    def init_weights(self, *args, **kwargs):
        """Apply ``random_weights_init`` to every sub-module; extra arguments are ignored."""
        self.apply(random_weights_init)


class ResidualUNet(nn.Module):
    """2D U-Net variant whose skip connections are averaged, not concatenated.

    Built from the blocks of ``models.residualunet``: four encoder levels
    (widths 32, 64, 128, 256), a bridge of width 512, and four decoder levels.
    At each decoder level the up-sampled features and the encoder features of
    the same level are averaged before the next block.

    Args:
        in_dim: number of channels of the input.
        out_dim: number of channels of the output (stored as ``final_out_dim``;
            note that the ``out_dim`` attribute holds the base width instead).
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        base_width = 32
        self.in_dim = in_dim
        self.out_dim = base_width
        self.final_out_dim = out_dim

        # One activation instance per path, shared by all blocks of that path.
        encoder_act = nn.LeakyReLU(0.2, inplace=True)
        decoder_act = nn.ReLU()

        from models.residualunet import (Conv_residual_conv,
                                         maxpool,
                                         conv_decod_block)

        w = self.out_dim

        # Encoder
        self.down_1 = Conv_residual_conv(self.in_dim, w, encoder_act)
        self.pool_1 = maxpool()
        self.down_2 = Conv_residual_conv(w, w * 2, encoder_act)
        self.pool_2 = maxpool()
        self.down_3 = Conv_residual_conv(w * 2, w * 4, encoder_act)
        self.pool_3 = maxpool()
        self.down_4 = Conv_residual_conv(w * 4, w * 8, encoder_act)
        self.pool_4 = maxpool()

        # Bridge between Encoder-Decoder
        self.bridge = Conv_residual_conv(w * 8, w * 16, encoder_act)

        # Decoder
        self.deconv_1 = conv_decod_block(w * 16, w * 8, decoder_act)
        self.up_1 = Conv_residual_conv(w * 8, w * 8, decoder_act)
        self.deconv_2 = conv_decod_block(w * 8, w * 4, decoder_act)
        self.up_2 = Conv_residual_conv(w * 4, w * 4, decoder_act)
        self.deconv_3 = conv_decod_block(w * 4, w * 2, decoder_act)
        self.up_3 = Conv_residual_conv(w * 2, w * 2, decoder_act)
        self.deconv_4 = conv_decod_block(w * 2, w, decoder_act)
        self.up_4 = Conv_residual_conv(w, w, decoder_act)

        self.out = nn.Conv2d(w, self.final_out_dim, kernel_size=3, stride=1, padding=1)

        print(f"Initialized {self.__class__.__name__} succesfully")

    def forward(self, input):
        """Compute the network output for ``input``."""
        # Encoding path; every level is kept for its skip connection.
        enc_1 = self.down_1(input)
        enc_2 = self.down_2(self.pool_1(enc_1))
        enc_3 = self.down_3(self.pool_2(enc_2))
        enc_4 = self.down_4(self.pool_3(enc_3))

        features = self.bridge(self.pool_4(enc_4))

        # Decoding path: up-sample, average with the encoder features of the
        # same level, then refine.
        features = self.up_1((self.deconv_1(features) + enc_4) / 2)
        features = self.up_2((self.deconv_2(features) + enc_3) / 2)
        features = self.up_3((self.deconv_3(features) + enc_2) / 2)
        features = self.up_4((self.deconv_4(features) + enc_1) / 2)

        return self.out(features)

    def init_weights(self, *args, **kwargs):
        """Re-initialise all ``Conv2d`` and ``BatchNorm2d`` sub-modules in place.

        Convolution weights are drawn from a zero-mean normal distribution
        with std ``sqrt(2 / (kernel_h * kernel_w * out_channels))``; batch-norm
        scales are set to 1 and shifts to 0. Extra arguments are ignored.
        """
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
                continue
            if not isinstance(layer, nn.Conv2d):
                continue
            fan_out = layer.kernel_size[0] * layer.kernel_size[1] * layer.out_channels
            layer.weight.data.normal_(0, math.sqrt(2. / fan_out))
