import torch.nn
import torch.nn as nn
from debug_tools import format_tensor_size
import torch.nn.functional as F
from GeneralModules import activation


class Block(nn.Module):
    """Basic 1D stage: convolution, then activation, then batch normalisation.

    Args:
        inChannels: number of channels of the incoming signal.
        outChannels: number of channels produced by the convolution; also the
            feature count of the batch norm.
        kernel_size: kernel size of the convolution.
        activation: callable applied to the convolution output.
        stride: stride forwarded to ``nn.Conv1d``.
        padding: padding forwarded to ``nn.Conv1d``.
        dilation: dilation forwarded to ``nn.Conv1d``.
    """

    def __init__(self, inChannels, outChannels, kernel_size, activation, stride=1, padding=1, dilation=1):
        super().__init__()
        conv_options = dict(padding=padding, stride=stride, dilation=dilation)
        self.conv = nn.Conv1d(inChannels, outChannels, kernel_size, **conv_options)
        self.activation = activation
        self.batch_norm = nn.BatchNorm1d(outChannels)

    def forward(self, x):
        """Return ``batch_norm(activation(conv(x)))``."""
        convolved = self.conv(x)
        activated = self.activation(convolved)
        return self.batch_norm(activated)


class BlockTransp(nn.Module):
    """Decoder stage with two operating modes selected by ``up``.

    * ``up=True``: the input is linearly upsampled by a factor of 2, passed
      through a regular ``nn.Conv1d``, batch-normalised and finally activated.
    * ``up=False``: the input goes through an ``nn.ConvTranspose1d`` (with
      ``output_padding`` equal to ``padding``), is activated and then
      batch-normalised.

    Args:
        inChannels: number of channels of the incoming signal.
        outChannels: number of channels produced by the convolution.
        kernel_size: kernel size of the (transposed) convolution.
        activation: callable used as non-linearity.
        stride: stride of the (transposed) convolution.
        padding: padding of the convolution; in transposed mode it is also
            used as ``output_padding``.
        up: mode switch described above.
    """

    def __init__(self, inChannels, outChannels, kernel_size, activation, stride=1, padding=1, up=True):
        super().__init__()
        self.up = up
        if up:
            self.conv = nn.Conv1d(inChannels, outChannels, kernel_size, padding=padding, stride=stride)
        else:
            self.conv = nn.ConvTranspose1d(
                inChannels,
                outChannels,
                kernel_size,
                output_padding=padding,
                padding=padding,
                stride=stride,
            )
        # The upsampler is instantiated in both modes; forward() only uses it when ``up`` is set.
        self.upsample = nn.Upsample(scale_factor=2, mode="linear")
        self.activation = activation
        self.batch_norm = nn.BatchNorm1d(outChannels)

    def forward(self, x):
        """Apply the stage in the mode chosen at construction time."""
        if not self.up:
            # Transposed mode: conv -> activation -> batch norm.
            return self.batch_norm(self.activation(self.conv(x)))
        # Upsampling mode: upsample -> conv -> batch norm -> activation.
        enlarged = self.upsample(x)
        normalised = self.batch_norm(self.conv(enlarged))
        return self.activation(normalised)


class BigBlockHalf2(nn.Module):
    """Wrapper around a single stride-2 ``Block`` (kernel size 3, padding 1).

    Args:
        in_channels: channels of the incoming signal.
        out_channel: channels produced by the block.
        act: activation callable handed to the ``Block``.
    """

    def __init__(self, in_channels, out_channel, act):
        super().__init__()
        self.b1 = Block(in_channels, out_channel, 3, act, stride=2, padding=1)

    def forward(self, x):
        """Run the input through the wrapped block."""
        reduced = self.b1(x)
        return reduced


class BigBlockTransMul2(nn.Module):
    """One ``BlockTransp`` (kernel 3, stride 2, padding 1), optionally followed by a x2 linear upsampling.

    Args:
        in_channels: channels of the incoming signal.
        out_channel: channels produced by the stage.
        act: activation callable handed to the ``BlockTransp``.
        up: forwarded to ``BlockTransp``; when true the block output is
            additionally upsampled by a factor of 2 in ``forward``.
    """

    def __init__(self, in_channels, out_channel, act, up=True):
        super().__init__()
        self.up = up
        self.b1 = BlockTransp(in_channels, out_channel, 3, act, stride=2, padding=1, up=up)
        # Created regardless of ``up``; only used by forward() when ``up`` is set.
        self.upsample = nn.Upsample(scale_factor=2, mode="linear")

    def forward(self, x):
        """Apply the block and, in ``up`` mode, the extra upsampling."""
        out = self.b1(x)
        return self.upsample(out) if self.up else out


class BigBlockHalf4(nn.Module):
    """Encoder stage made of four ``Block`` layers applied in sequence.

    ``b1`` and ``b4`` are stride-2 convolutions (kernel 3, padding 1);
    ``b2`` and ``b3`` are pointwise convolutions (kernel 1, stride 1,
    padding 0). The channel count changes only in ``b1``.

    Args:
        in_channels: channels of the incoming signal.
        out_channels: channels used from ``b1`` onwards.
        act: activation callable shared by all four blocks.
    """

    def __init__(self, in_channels, out_channels, act):
        super().__init__()
        strided = dict(kernel_size=3, stride=2, padding=1, activation=act)
        pointwise = dict(kernel_size=1, stride=1, padding=0, activation=act)
        self.b1 = Block(in_channels, out_channels, **strided)
        self.b2 = Block(out_channels, out_channels, **pointwise)
        self.b3 = Block(out_channels, out_channels, **pointwise)
        self.b4 = Block(out_channels, out_channels, **strided)

    def forward(self, x):
        """Apply ``b1`` through ``b4`` one after the other."""
        for stage in (self.b1, self.b2, self.b3, self.b4):
            x = stage(x)
        return x


class BigBlockTransMul4(nn.Module):
    """Decoder stage made of four ``BlockTransp`` layers applied in sequence.

    ``b1`` and ``b4`` use kernel 3 / stride 2 / padding 1, ``b2`` and ``b3``
    use kernel 1 / stride 1 / padding 0. The channel count stays at
    ``in_channels`` until ``b4``, which maps to ``out_channels``.

    Args:
        in_channels: channels of the incoming signal.
        out_channels: channels produced by the last block.
        act: activation callable shared by all four blocks.
        up: mode flag forwarded to every ``BlockTransp``.
    """

    def __init__(self, in_channels, out_channels, act, up=True):
        super().__init__()
        strided = dict(kernel_size=3, stride=2, padding=1, activation=act, up=up)
        pointwise = dict(kernel_size=1, stride=1, padding=0, activation=act, up=up)
        self.b1 = BlockTransp(in_channels, in_channels, **strided)
        self.b2 = BlockTransp(in_channels, in_channels, **pointwise)
        self.b3 = BlockTransp(in_channels, in_channels, **pointwise)
        self.b4 = BlockTransp(in_channels, out_channels, **strided)

    def forward(self, x):
        """Apply ``b1`` through ``b4`` one after the other."""
        for stage in (self.b1, self.b2, self.b3, self.b4):
            x = stage(x)
        return x


class FCNet(nn.Module):
    """1D fully convolutional encoder/decoder network.

    The encoder is an input ``Block`` followed by five ``BigBlockHalf4``
    stages whose channel width doubles from ``start`` up to ``32 * start``.
    The decoder mirrors it with five ``BigBlockTransMul4`` stages built with
    ``up=True``. The output head depends on the ``atrous`` flag:

    * ``atrous`` set: three dilated ``Block`` branches (dilation 1, 3 and 5,
      no padding, 128 channels each) are applied to the decoder output, their
      results are concatenated along the last axis, projected by a 1x1
      convolution and then symmetrically padded/cropped towards a length of
      2048.
    * otherwise: a stride-2 ``nn.ConvTranspose1d`` maps ``start`` channels to
      ``out_channels``.

    Args:
        in_channels: channels of the input signal.
        out_channels: channels produced by the output layer.
        network_properties: mapping providing the keys ``"atrous"``,
            ``"start"``, ``"activation"`` and ``"retrain"``.
        verbose: debugging switch. When true, ``forward`` prints the tensor
            shape after every stage and then terminates the interpreter.
    """

    def __init__(self, in_channels, out_channels, network_properties, verbose=False):
        super().__init__()
        self.verbose = verbose
        self.atrous = bool(network_properties["atrous"])
        self.start = int(network_properties["start"])
        # ``activation`` is the project helper imported from GeneralModules.
        self.activation = activation(network_properties["activation"])
        self.retrain = int(network_properties["retrain"])
        self.up = True

        # Seeds the global torch RNG right before the layers are created.
        torch.manual_seed(self.retrain)

        # widths[k] == start * 2**k for k = 0..5
        widths = [self.start * 2 ** level for level in range(6)]

        self.inp = Block(in_channels, self.start, kernel_size=3, stride=2, padding=1, activation=self.activation)
        self.conv_list = nn.ModuleList([
            BigBlockHalf4(widths[level], widths[level + 1], self.activation)
            for level in range(5)
        ])
        self.trans_conv_list = nn.ModuleList([
            BigBlockTransMul4(widths[level], widths[level - 1], self.activation, self.up)
            for level in range(5, 0, -1)
        ])

        # The dilated branches are created even when ``atrous`` is off (they are then unused).
        self.atrous_1 = Block(self.start, 128, kernel_size=3, stride=1, padding=0, dilation=1, activation=self.activation)
        self.atrous_3 = Block(self.start, 128, kernel_size=3, stride=1, padding=0, dilation=3, activation=self.activation)
        self.atrous_5 = Block(self.start, 128, kernel_size=3, stride=1, padding=0, dilation=5, activation=self.activation)

        if self.atrous:
            self.out = nn.Conv1d(128, out_channels, kernel_size=1, padding=0, stride=1)
        else:
            self.out = nn.ConvTranspose1d(self.start, out_channels, kernel_size=3, padding=1, stride=2, output_padding=1)

    def _debug_print(self, *items):
        """Print ``items`` only when the network was built with ``verbose=True``."""
        if self.verbose:
            print(*items)

    def forward(self, x):
        """Run the network on ``x`` and return the output with dimension 1 squeezed.

        Raises:
            SystemExit: in verbose mode, after the shapes have been printed
                (verbose mode is a shape-debugging run, not a real forward pass).
        """
        self._debug_print("---------------------------")
        self._debug_print(x.shape)
        x = self.inp(x)
        self._debug_print(x.shape)
        for layer in self.conv_list:
            x = layer(x)
            self._debug_print(x.shape)
        self._debug_print("---------------------------")
        for layer in self.trans_conv_list:
            x = layer(x)
            self._debug_print(x.shape)

        if self.atrous:
            x1 = self.atrous_1(x)
            x2 = self.atrous_3(x)
            x3 = self.atrous_5(x)
            # The branches have different lengths (no padding), so they are
            # joined along the last axis rather than the channel axis.
            x = torch.cat((x1, x2, x3), -1)
            self._debug_print(x1.shape, x2.shape, x3.shape)
            self._debug_print(x.shape)

        x = self.out(x)
        self._debug_print(x.shape)

        if self.atrous:
            # Negative padding crops: trims (or zero-pads) both ends equally
            # so the last dimension ends up at 2048 when the difference is even.
            x_padding = -int((x.shape[-1] - 2048) / 2)
            x = F.pad(x, [x_padding, x_padding])
        self._debug_print(x.shape)

        if self.verbose:
            # ``quit()`` is injected by the ``site`` module for interactive use and
            # may be missing (e.g. ``python -S``); raise SystemExit directly instead.
            raise SystemExit
        return x.squeeze(1)

    def print_size(self):
        """Print the total parameter count and memory footprint; return the parameter count."""
        params = list(self.parameters())
        nparams = sum(p.numel() for p in params)
        nbytes = sum(p.data.element_size() * p.numel() for p in params)

        print(f'Total number of model parameters: {nparams} (~{format_tensor_size(nbytes)})')

        return nparams


class FCNetBurg(nn.Module):
    """Variant of the 1D fully convolutional encoder/decoder using transposed convolutions.

    The encoder is an input ``Block`` followed by five ``BigBlockHalf4``
    stages whose channel width doubles from ``start`` up to ``32 * start``.
    The decoder consists of four ``BigBlockTransMul4`` stages and a final
    ``BigBlockTransMul2`` stage, all built with ``up=False``. The output head
    depends on the ``atrous`` flag:

    * ``atrous`` set: three dilated ``Block`` branches (dilation 1, 3 and 5,
      no padding, 128 channels each) are applied to the decoder output, their
      results are concatenated along the last axis, projected by a 1x1
      convolution and then symmetrically padded/cropped towards a length of
      1024.
    * otherwise: a stride-2 ``nn.ConvTranspose1d`` maps ``start`` channels to
      ``out_channels``.

    Args:
        in_channels: channels of the input signal.
        out_channels: channels produced by the output layer.
        network_properties: mapping providing the keys ``"atrous"``,
            ``"start"``, ``"activation"`` and ``"retrain"``.
        verbose: debugging switch. When true, ``forward`` prints the tensor
            shape after every stage and then terminates the interpreter.
    """

    def __init__(self, in_channels, out_channels, network_properties, verbose=False):
        super().__init__()
        self.verbose = verbose
        self.atrous = bool(network_properties["atrous"])
        self.start = int(network_properties["start"])
        # ``activation`` is the project helper imported from GeneralModules.
        self.activation = activation(network_properties["activation"])
        self.retrain = int(network_properties["retrain"])
        self.up = False

        # Seeds the global torch RNG right before the layers are created.
        torch.manual_seed(self.retrain)

        # widths[k] == start * 2**k for k = 0..5
        widths = [self.start * 2 ** level for level in range(6)]

        self.inp = Block(in_channels, self.start, kernel_size=3, stride=2, padding=1, activation=self.activation)
        self.conv_list = nn.ModuleList([
            BigBlockHalf4(widths[level], widths[level + 1], self.activation)
            for level in range(5)
        ])

        # Four x4 stages (32*start -> 2*start) followed by one x2 stage (2*start -> start).
        decoder = [
            BigBlockTransMul4(widths[level], widths[level - 1], self.activation, self.up)
            for level in range(5, 1, -1)
        ]
        decoder.append(BigBlockTransMul2(widths[1], widths[0], self.activation, self.up))
        self.trans_conv_list = nn.ModuleList(decoder)

        # The dilated branches are created even when ``atrous`` is off (they are then unused).
        self.atrous_1 = Block(self.start, 128, kernel_size=3, stride=1, padding=0, dilation=1, activation=self.activation)
        self.atrous_3 = Block(self.start, 128, kernel_size=3, stride=1, padding=0, dilation=3, activation=self.activation)
        self.atrous_5 = Block(self.start, 128, kernel_size=3, stride=1, padding=0, dilation=5, activation=self.activation)

        if self.atrous:
            self.out = nn.Conv1d(128, out_channels, kernel_size=1, padding=0, stride=1)
        else:
            self.out = nn.ConvTranspose1d(self.start, out_channels, kernel_size=3, padding=1, stride=2, output_padding=1)

    def _debug_print(self, *items):
        """Print ``items`` only when the network was built with ``verbose=True``."""
        if self.verbose:
            print(*items)

    def forward(self, x):
        """Run the network on ``x`` and return the output with dimension 1 squeezed.

        Raises:
            SystemExit: in verbose mode, after the shapes have been printed
                (verbose mode is a shape-debugging run, not a real forward pass).
        """
        self._debug_print("---------------------------")
        self._debug_print(x.shape)
        x = self.inp(x)
        self._debug_print(x.shape)
        for layer in self.conv_list:
            x = layer(x)
            self._debug_print(x.shape)
        self._debug_print("---------------------------")
        for layer in self.trans_conv_list:
            x = layer(x)
            self._debug_print(x.shape)

        if self.atrous:
            x1 = self.atrous_1(x)
            x2 = self.atrous_3(x)
            x3 = self.atrous_5(x)
            # The branches have different lengths (no padding), so they are
            # joined along the last axis rather than the channel axis.
            x = torch.cat((x1, x2, x3), -1)
            self._debug_print(x1.shape, x2.shape, x3.shape)
            self._debug_print(x.shape)

        x = self.out(x)
        self._debug_print(x.shape)

        if self.atrous:
            # Negative padding crops: trims (or zero-pads) both ends equally
            # so the last dimension ends up at 1024 when the difference is even.
            x_padding = -int((x.shape[-1] - 1024) / 2)
            x = F.pad(x, [x_padding, x_padding])
        self._debug_print(x.shape)

        if self.verbose:
            # ``quit()`` is injected by the ``site`` module for interactive use and
            # may be missing (e.g. ``python -S``); raise SystemExit directly instead.
            raise SystemExit
        return x.squeeze(1)

    def print_size(self):
        """Print the total parameter count and memory footprint; return the parameter count."""
        params = list(self.parameters())
        nparams = sum(p.numel() for p in params)
        nbytes = sum(p.data.element_size() * p.numel() for p in params)

        print(f'Total number of model parameters: {nparams} (~{format_tensor_size(nbytes)})')

        return nparams


""" Parts of the U-Net model and 2D Convolution"""


class DoubleConv(nn.Module):
    """Two stacked (3x3 convolution => batch norm => ReLU) stages.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        mid_channels: channels between the two stages; a falsy value
            (``None`` or ``0``) means "use ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            # The convolution carries no bias; it is followed directly by a batch norm.
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``DoubleConv``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pooling = nn.MaxPool2d(2)
        convolutions = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pooling, convolutions)

    def forward(self, x):
        """Pool ``x`` and run the double convolution on the result."""
        reduced = self.maxpool_conv(x)
        return reduced


class Up(nn.Module):
    """Upscaling then double conv, without skip concatenation.

    ``forward`` receives the encoder feature map ``x2`` only to learn the
    target spatial size: the upsampled ``x1`` is padded to that size and then
    convolved on its own (``x2`` is never concatenated).

    Args:
        in_channels: channels of ``x1``.
        out_channels: channels of the returned feature map.
        bilinear: if true, upsample with fixed bilinear interpolation (channel
            count unchanged); otherwise use a learned ``nn.ConvTranspose2d``
            that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            # Interpolation keeps all ``in_channels``; the double conv reduces them.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            # The transposed conv outputs ``in_channels // 2`` channels and forward()
            # adds nothing to them, so the double conv must expect exactly that many.
            # (Building it for ``in_channels`` made this branch fail on every call.)
            self.conv = DoubleConv(in_channels // 2, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of ``x2`` and apply the double conv."""
        x1 = self.up(x1)
        # Size difference along dimensions 2 and 3 between the reference and the upsampled map.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # F.pad takes the last dimension first; any odd remainder goes to the trailing side.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        return self.conv(x1)


class OutConv(nn.Module):
    """Output head: a single 1x1 convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` with the 1x1 convolution."""
        projected = self.conv(x)
        return projected


class _SkipUp(nn.Module):
    """Private decoder stage of ``UNetSkip``: upsample, concatenate the skip tensor, double conv.

    Args:
        in_channels: channels of the concatenated tensor (skip + upsampled).
        out_channels: channels of the returned feature map.
        bilinear: if true, upsample with bilinear interpolation; otherwise
            with an ``nn.ConvTranspose2d`` that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2``, concatenate both and convolve."""
        x1 = self.up(x1)
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        # F.pad takes the last dimension first; any odd remainder goes to the trailing side.
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        # Skip connection: encoder features first, then the upsampled decoder features.
        return self.conv(torch.cat([x2, x1], dim=1))


class UNetSkip(nn.Module):
    """2D U-Net with skip connections and fixed channel widths 64-128-256-512-1024.

    Each decoder stage concatenates the matching encoder feature map with the
    upsampled decoder feature map before its double convolution. The channel
    widths passed to the decoder stages (including the ``factor`` halving in
    bilinear mode) are sized for that concatenation.

    Args:
        n_channels: channels of the input image.
        n_classes: channels of the returned logits.
        net_properties: mapping providing ``"atrous"`` (used here as the
            bilinear-upsampling switch) and ``"retrain"`` (RNG seed).
    """

    def __init__(self, n_channels, n_classes, net_properties):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bool(net_properties["atrous"])
        # Seeds the global torch RNG right before the layers are created.
        torch.manual_seed(net_properties["retrain"])

        # The first stage must emit the 64 channels that ``down1`` consumes
        # (it used to emit 32, which made forward() fail immediately).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # With bilinear upsampling the channel count is not halved by the
        # upsampler, so the bottleneck and decoder outputs are halved instead.
        factor = 2 if self.bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = _SkipUp(1024, 512 // factor, self.bilinear)
        self.up2 = _SkipUp(512, 256 // factor, self.bilinear)
        self.up3 = _SkipUp(256, 128 // factor, self.bilinear)
        self.up4 = _SkipUp(128, 64, self.bilinear)
        # ``up4`` emits 64 channels, so the head must consume 64 (was 32).
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the logits for input ``x``."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def print_size(self):
        """Print the total parameter count and memory footprint; return the parameter count."""
        params = list(self.parameters())
        nparams = sum(p.numel() for p in params)
        nbytes = sum(p.data.element_size() * p.numel() for p in params)

        print(f'Total number of model parameters: {nparams} (~{format_tensor_size(nbytes)})')

        return nparams


class UNet(nn.Module):
    """2D U-Net-shaped encoder/decoder whose widths scale with ``start``.

    The encoder widths are ``start``, ``2*start``, ``4*start``, ``8*start``
    and ``16*start``; the decoder walks the same ladder back down. The
    encoder feature maps are handed to the ``Up`` stages together with the
    decoder tensor.

    Args:
        n_channels: channels of the input image.
        n_classes: channels of the returned logits.
        net_properties: mapping providing ``"atrous"`` (used here as the
            bilinear-upsampling switch), ``"start"`` (base width) and
            ``"retrain"`` (RNG seed).
    """

    def __init__(self, n_channels, n_classes, net_properties):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bool(net_properties["atrous"])
        start = int(net_properties["start"])
        # Seeds the global torch RNG right before the layers are created.
        torch.manual_seed(net_properties["retrain"])

        # Channel widths do not depend on ``bilinear`` in this network, so no
        # halving factor is computed here.
        self.inc = DoubleConv(n_channels, start)
        self.down1 = Down(start, start * 2)
        self.down2 = Down(start * 2, start * 4)
        self.down3 = Down(start * 4, start * 8)
        self.down4 = Down(start * 8, start * 16)
        self.up1 = Up(start * 16, start * 8, self.bilinear)
        self.up2 = Up(start * 8, start * 4, self.bilinear)
        self.up3 = Up(start * 4, start * 2, self.bilinear)
        self.up4 = Up(start * 2, start, self.bilinear)
        self.outc = OutConv(start, n_classes)

    def forward(self, x):
        """Return the logits for input ``x``."""
        # Encoder: keep every intermediate map, the decoder stages need them.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: pair each stage with the encoder map from the same level.
        x = x5
        for stage, encoder_map in ((self.up1, x4), (self.up2, x3), (self.up3, x2), (self.up4, x1)):
            x = stage(x, encoder_map)

        logits = self.outc(x)
        return logits

    def print_size(self):
        """Print the total parameter count and memory footprint; return the parameter count."""
        params = list(self.parameters())
        nparams = sum(p.numel() for p in params)
        nbytes = sum(p.data.element_size() * p.numel() for p in params)

        print(f'Total number of model parameters: {nparams} (~{format_tensor_size(nbytes)})')

        return nparams
