import torch
import torch.nn as nn


def init_weights(m):
    """Re-initialise one layer in place; meant to be passed to ``nn.Module.apply``.

    Modules whose exact type is ``nn.Conv2d``, ``nn.Linear`` or
    ``nn.ConvTranspose2d`` receive Xavier-uniform weights and, if they have a
    bias, a bias filled with zeros. Any other module is left untouched.

    Args:
        m: the module to (possibly) re-initialise.
    """
    # Exact type comparison, so subclasses of these layers are not touched.
    if type(m) not in (nn.Conv2d, nn.Linear, nn.ConvTranspose2d):
        return

    nn.init.xavier_uniform_(m.weight)

    bias = m.bias
    if bias is not None:
        bias.data.fill_(0.0)


def activation_func(activation):
    """Return a newly constructed activation module selected by name.

    Supported names are ``'relu'``, ``'leaky_relu'``, ``'selu'``, ``'none'``,
    ``'sigmoid'``, ``'tanh'``, ``'spatialsoftargmax'`` and ``'softmax'``.
    ReLU, LeakyReLU and SELU are created with ``inplace=True``, so they
    overwrite the tensor they are applied to.

    Only the requested module is instantiated; the other candidates (one of
    which owns a parameter) are never built.

    Args:
        activation: name of the activation to build.

    Returns:
        A fresh ``nn.Module`` instance for the requested activation.

    Raises:
        KeyError: if ``activation`` is not one of the supported names.
    """
    # Zero-argument factories: nothing is constructed until one is selected.
    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True),
        'selu': lambda: nn.SELU(inplace=True),
        'none': lambda: nn.Identity(),
        'sigmoid': lambda: nn.Sigmoid(),
        'tanh': lambda: nn.Tanh(),
        'spatialsoftargmax': lambda: SpatialSoftArgmax(normalise=True),
        'softmax': lambda: nn.Softmax(dim=1),
    }

    try:
        factory = factories[activation]
    except KeyError:
        raise KeyError(
            'Unknown activation {!r}; expected one of {}'.format(activation, sorted(factories))) from None

    return factory()


class CoordinateUtils(object):
    """Helpers for building per-pixel coordinate grids."""

    @staticmethod
    def get_image_coordinates(h, w, normalise):
        """Build x and y coordinate grids for an ``h`` x ``w`` image.

        Args:
            h: number of rows.
            w: number of columns.
            normalise: when true, coordinates along each axis are rescaled
                linearly from ``[0, size - 1]`` to ``[-1, 1]``. An axis of
                length 1 has no extent to rescale, so its single coordinate
                is placed at ``0.0`` instead of evaluating 0/0 (NaN).

        Returns:
            Tuple ``(image_x, image_y)`` of float32 CPU tensors, both of shape
            ``(h, w)``. ``image_x[r, c]`` is the x coordinate of column ``c``
            and ``image_y[r, c]`` is the y coordinate of row ``r``.
        """

        def axis_coordinates(size):
            coords = torch.arange(size, dtype=torch.float32)
            if not normalise:
                return coords
            if size > 1:
                return (coords / (size - 1)) * 2 - 1
            # size <= 1: dividing by (size - 1) would be 0/0 for the only pixel.
            return torch.zeros(coords.shape, dtype=torch.float32)

        x_range = axis_coordinates(w)
        y_range = axis_coordinates(h)

        # Replicate the column coordinates down the rows, and the row
        # coordinates across the columns (built as (w, h), then transposed).
        image_x = x_range.unsqueeze(0).repeat_interleave(h, 0)
        image_y = y_range.unsqueeze(0).repeat_interleave(w, 0).t()
        return image_x, image_y


class SpatialSoftArgmax(nn.Module):
    """Spatial soft-argmax: expected (x, y) coordinate of each feature map.

    Each channel of the input is turned into a probability distribution over
    its pixels with a temperature-scaled softmax, and the output is the
    coordinate expectation under that distribution.
    """

    def __init__(self, temperature=None, normalise=False):
        """
        Args:
            temperature: ``None`` to learn the temperature (a trainable
                parameter initialised to 1). Otherwise a number or a tensor
                whose first element is used as a fixed temperature; the
                parameter is created on the tensor's device and is frozen
                (``requires_grad=False``).
            normalise: forwarded to ``CoordinateUtils.get_image_coordinates``
                when building the coordinate grid in ``forward``.
        """
        super().__init__()

        if temperature is None:
            self.temperature = nn.Parameter(torch.ones(1), requires_grad=True)
        else:
            if isinstance(temperature, torch.Tensor):
                device = temperature.device
                # detach() keeps this working for tensors that require grad;
                # reshape(-1) also accepts 0-d tensors.
                value = float(temperature.detach().reshape(-1)[0])
            else:
                device = None
                value = float(temperature)
            # Build the parameter directly on the target device. Calling .to()
            # on a Parameter can return a plain tensor, which would then no
            # longer be registered with the module.
            self.temperature = nn.Parameter(torch.ones(1, device=device) * value, requires_grad=False)

        self.normalise = normalise

    def forward(self, x):
        """Compute the expected coordinates for every feature map.

        Args:
            x: 4-D tensor unpacked as ``(n, c, h, w)``.

        Returns:
            Tensor of shape ``(n, c, 2)``; the last axis holds the expected
            x coordinate followed by the expected y coordinate.
        """
        n, c, h, w = x.size()

        # Softmax over all pixels of each individual feature map.
        # reshape (rather than view) also accepts non-contiguous inputs.
        flat_maps = x.reshape(n * c, h * w)
        attention = nn.functional.softmax(flat_maps / self.temperature, dim=1).reshape(n, c, h, w)

        image_x, image_y = CoordinateUtils.get_image_coordinates(h, w, normalise=self.normalise)

        # (h, w, 2) grid of (x, y) pairs, moved next to the input.
        image_coordinates = torch.cat((image_x.unsqueeze(-1), image_y.unsqueeze(-1)), dim=-1)
        image_coordinates = image_coordinates.to(device=x.device)

        # (n, c, h, w, 1) * (1, h, w, 2), summed over the two spatial axes.
        weighted = attention.unsqueeze(-1) * image_coordinates.unsqueeze(0)
        return torch.sum(weighted, dim=[2, 3])


class DoubleConvBlock(nn.Module):
    """Two stacked ``ConvBlock`` layers followed by one activation.

    The first ``ConvBlock`` uses ``activation`` (and ``dropout``); the second
    has no activation or dropout of its own. The block's own activation
    (``final_activation`` when given, else ``activation``) is applied to the
    output of the pair, optionally after adding the block input (residual
    mode), and is followed by an optional ``Dropout2d``.
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride, batchnorm=True, dropout=0., bias=True, activation="relu",
                 residual=False, final_activation=None, final_batchnorm=None):
        """
        Args:
            ch_in: input channels.
            ch_out: output channels (must equal ``ch_in`` in residual mode).
            kernel_size: odd convolution kernel size.
            stride: stride of the first convolution; in residual mode both
                convolutions use stride 1 and a max-pool of this size/stride
                is applied after the activation instead.
            batchnorm: whether the convolutions are followed by batch norm.
            dropout: dropout probability; values ``<= 0`` disable dropout.
            bias: whether the convolutions have a bias term.
            activation: activation name for the first convolution, and for
                the block output unless ``final_activation`` is given.
            residual: add the block input to the output of the second
                convolution.
            final_activation: optional activation name for the block output.
            final_batchnorm: optional override of ``batchnorm`` for the
                second convolution only.
        """
        super(DoubleConvBlock, self).__init__()
        assert kernel_size % 2. == 1.

        if residual:
            assert ch_in == ch_out
            # Keep the resolution inside the block so the skip can be added;
            # downsample with pooling afterwards.
            self.pooling = nn.MaxPool2d(kernel_size=stride, stride=stride)
            stride = 1
        self.residual = residual

        first_conv = ConvBlock(ch_in, ch_out, kernel_size=kernel_size, stride=stride, batchnorm=batchnorm,
                               dropout=dropout, bias=bias, activation=activation, residual=False)

        second_batchnorm = batchnorm if final_batchnorm is None else final_batchnorm
        second_conv = ConvBlock(ch_out, ch_out, kernel_size=kernel_size, stride=1, batchnorm=second_batchnorm,
                                dropout=False, bias=bias, activation="none", residual=False)

        self.dropout = nn.Dropout2d(dropout, False) if dropout > 0 else None
        self.activation = activation_func(activation if final_activation is None else final_activation)
        self.conv = nn.Sequential(first_conv, second_conv)

        self.apply(init_weights)

    def forward(self, x):
        """Apply the two convolutions, the activation and the optional dropout."""
        if self.residual:
            skip = x
            x = self.conv(x)
            x += skip
            x = self.pooling(self.activation(x))
        else:
            x = self.activation(self.conv(x))

        return x if self.dropout is None else self.dropout(x)


class ConvBlock(nn.Module):
    """Conv2d -> optional BatchNorm2d -> activation -> optional Dropout2d.

    The convolution pads by ``kernel_size // 2``. In residual mode the block
    input is added to the convolution output before the activation, the
    convolution runs at stride 1, and a max-pool of size/stride ``stride``
    follows the activation.
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride,
                 batchnorm=True, dropout=0., bias=True,
                 activation='relu', residual=False):
        """
        Args:
            ch_in: input channels.
            ch_out: output channels (must equal ``ch_in`` in residual mode).
            kernel_size: odd convolution kernel size.
            stride: convolution stride (pooling stride in residual mode).
            batchnorm: append ``BatchNorm2d`` after the convolution.
            dropout: dropout probability; values ``<= 0`` disable dropout.
            bias: whether the convolution has a bias term.
            activation: activation name understood by ``activation_func``.
            residual: add the block input to the convolution output.
        """
        super(ConvBlock, self).__init__()
        assert kernel_size % 2. == 1.

        if residual:
            assert ch_in == ch_out
            self.pooling = nn.MaxPool2d(kernel_size=stride, stride=stride)
            stride = 1
        self.residual = residual

        layers = [nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size, stride=stride,
                            padding=kernel_size // 2, bias=bias)]
        if batchnorm:
            layers.append(nn.BatchNorm2d(ch_out))
        self.conv = nn.Sequential(*layers)

        self.dropout = nn.Dropout2d(dropout, False) if dropout > 0 else None
        self.activation = activation_func(activation)

        self.apply(init_weights)

    def forward(self, x):
        """Run the block on ``x`` and return the result."""
        if self.residual:
            skip = x
            x = self.conv(x)
            x += skip
            x = self.pooling(self.activation(x))
        else:
            x = self.activation(self.conv(x))

        return x if self.dropout is None else self.dropout(x)


class UpConvBlockNonParametric(nn.Module):
    """Upsample by a factor of 2 with ``nn.Upsample``, then convolve.

    The convolution stage is a ``DoubleConvBlock`` when ``double_conv`` is
    set and a single ``ConvBlock`` otherwise, always at stride 1.
    """

    def __init__(self, ch_in, ch_out, kernel_size, batchnorm=True, dropout=0., bias=True,
                 activation='relu', residual=False, double_conv=False, mode='bilinear',
                 double_conv_final_activation=None, double_conv_final_batchnorm=None,
                 ):
        """
        Args:
            ch_in: input channels of the convolution stage.
            ch_out: output channels of the convolution stage.
            kernel_size: kernel size of the convolution stage.
            batchnorm, dropout, bias, activation, residual: forwarded to the
                convolution stage.
            double_conv: use ``DoubleConvBlock`` instead of ``ConvBlock``.
            mode: interpolation mode passed to ``nn.Upsample``.
            double_conv_final_activation: ``final_activation`` of the
                ``DoubleConvBlock`` (only used when ``double_conv`` is set).
            double_conv_final_batchnorm: ``final_batchnorm`` of the
                ``DoubleConvBlock`` (only used when ``double_conv`` is set).
        """
        super(UpConvBlockNonParametric, self).__init__()

        upsample_kwargs = dict(scale_factor=2, mode=mode)
        # align_corners is only passed for the linear interpolation family.
        if mode in ('linear', 'bilinear', 'trilinear'):
            upsample_kwargs['align_corners'] = True
        upsample = nn.Upsample(**upsample_kwargs)

        if double_conv:
            refine = DoubleConvBlock(ch_in=ch_in, ch_out=ch_out, kernel_size=kernel_size, stride=1,
                                     batchnorm=batchnorm, dropout=dropout, bias=bias, activation=activation,
                                     residual=residual, final_activation=double_conv_final_activation,
                                     final_batchnorm=double_conv_final_batchnorm)
        else:
            refine = ConvBlock(ch_in=ch_in, ch_out=ch_out, kernel_size=kernel_size, stride=1,
                               batchnorm=batchnorm, dropout=dropout, bias=bias, activation=activation,
                               residual=residual)

        self.up_conv = nn.Sequential(upsample, refine)

        self.apply(init_weights)

    def forward(self, x):
        """Upsample ``x`` and pass it through the convolution stage."""
        return self.up_conv(x)


class UpConvBlockParametric(nn.Module):
    """Learned 2x upsampling (``nn.ConvTranspose2d``) followed by convolution.

    The transposed convolution maps ``ch_in`` to ``ch_out`` channels with
    stride 2; the following ``DoubleConvBlock`` (when ``double_conv`` is set)
    or ``ConvBlock`` keeps ``ch_out`` channels at stride 1.
    """

    def __init__(self,
                 ch_in, ch_out, kernel_size, batchnorm=True, dropout=0., bias=True,
                 activation='relu', residual=False, double_conv=False, double_conv_final_activation=None,
                 double_conv_final_batchnorm=None,
                 ):
        """
        Args:
            ch_in: input channels of the transposed convolution.
            ch_out: output channels of the whole block.
            kernel_size: kernel size of both stages.
            batchnorm, dropout, bias, activation, residual: forwarded to the
                convolution stage.
            double_conv: use ``DoubleConvBlock`` instead of ``ConvBlock``.
            double_conv_final_activation: ``final_activation`` of the
                ``DoubleConvBlock`` (only used when ``double_conv`` is set).
            double_conv_final_batchnorm: ``final_batchnorm`` of the
                ``DoubleConvBlock`` (only used when ``double_conv`` is set).
        """
        super(UpConvBlockParametric, self).__init__()

        # NOTE(review): the transposed convolution always has a bias; the
        # ``bias`` argument only reaches the convolution stage below.
        # Confirm whether that is intended.
        upsample = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=kernel_size,
                                      stride=2, padding=kernel_size // 2, output_padding=1, bias=True,
                                      dilation=1)

        if double_conv:
            refine = DoubleConvBlock(ch_in=ch_out, ch_out=ch_out, kernel_size=kernel_size, stride=1,
                                     batchnorm=batchnorm, dropout=dropout, bias=bias, activation=activation,
                                     residual=residual, final_activation=double_conv_final_activation,
                                     final_batchnorm=double_conv_final_batchnorm)
        else:
            refine = ConvBlock(ch_in=ch_out, ch_out=ch_out, kernel_size=kernel_size, stride=1,
                               batchnorm=batchnorm, dropout=dropout, bias=bias, activation=activation,
                               residual=residual)

        self.up_conv = nn.Sequential(upsample, refine)

        self.apply(init_weights)

    def forward(self, x):
        """Upsample ``x`` and pass it through the convolution stage."""
        return self.up_conv(x)


class LinearBlock(nn.Module):
    """Linear -> optional BatchNorm1d -> activation -> optional Dropout."""

    def __init__(self, in_features, out_features, batchnorm=True, dropout=0., bias=True,
                 activation='relu'):
        """
        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            batchnorm: append ``BatchNorm1d`` after the linear layer.
            dropout: dropout probability; values ``<= 0`` disable dropout.
            bias: whether the linear layer has a bias term.
            activation: activation name understood by ``activation_func``.
        """
        super(LinearBlock, self).__init__()

        layers = [nn.Linear(in_features=in_features, out_features=out_features, bias=bias)]
        if batchnorm:
            layers.append(nn.BatchNorm1d(out_features))
        self.linear = nn.Sequential(*layers)

        self.dropout = nn.Dropout(dropout, False) if dropout > 0 else None
        self.activation = activation_func(activation)

        self.apply(init_weights)

    def forward(self, x):
        """Run the block on ``x`` and return the result."""
        x = self.activation(self.linear(x))
        return x if self.dropout is None else self.dropout(x)


class EncoderFlexible(nn.Module):
    """Chain of ``ConvBlock``/``DoubleConvBlock`` layers described by lists.

    Layer ``i`` maps ``channels[i]`` to ``channels[i + 1]`` channels. All
    layers but the last share the "hidden" settings and use ``kernels[i]``
    and ``strides[i]``; the last layer uses ``kernels[-1]``, ``strides[-1]``
    and the ``final_*`` settings.
    """

    def __init__(self, channels, kernels, strides, activation='relu', final_activation='none', bias=True,
                 batchnorm=True, final_batchnorm=False, dropout=0., final_dropout=0., residuals=False,
                 double_conv=False, double_convs_second_activation=None):
        """
        Args:
            channels: channel counts, one more entry than there are layers.
            kernels: kernel size per layer.
            strides: stride per layer.
            activation: activation name of the hidden layers.
            final_activation: activation name applied at the end of the last
                layer.
            bias: whether convolutions have a bias term.
            batchnorm: batch norm in the hidden layers.
            final_batchnorm: batch norm in the last layer.
            dropout: dropout probability of the hidden layers.
            final_dropout: dropout probability of the last layer.
            residuals: make a layer residual whenever its input and output
                channel counts match; a message is printed for layers where
                they do not.
            double_conv: build ``DoubleConvBlock`` layers instead of
                ``ConvBlock`` layers.
            double_convs_second_activation: ``final_activation`` of the
                hidden ``DoubleConvBlock`` layers.
        """
        super(EncoderFlexible, self).__init__()
        layers = []

        # Hidden layers: everything except the last channel transition.
        for idx in range(len(channels) - 2):
            ch_in, ch_out = channels[idx], channels[idx + 1]

            use_skip = False
            if residuals:
                if ch_in == ch_out:
                    use_skip = True
                else:
                    print('Layer {} is not be a residual layer (in_channels={} != out_channels={})'.format(
                        idx, ch_in, ch_out))

            if double_conv:
                block = DoubleConvBlock(ch_in=ch_in, ch_out=ch_out,
                                        kernel_size=kernels[idx], stride=strides[idx],
                                        batchnorm=batchnorm, dropout=dropout, bias=bias,
                                        activation=activation, residual=use_skip,
                                        final_activation=double_convs_second_activation)
            else:
                block = ConvBlock(ch_in=ch_in, ch_out=ch_out,
                                  kernel_size=kernels[idx], stride=strides[idx],
                                  batchnorm=batchnorm, dropout=dropout, bias=bias,
                                  activation=activation, residual=use_skip)
            layers.append(block)

        # Last layer: its own batch norm / dropout / activation settings.
        last_in, last_out = channels[-2], channels[-1]

        use_skip = False
        if residuals:
            if last_in == last_out:
                use_skip = True
            else:
                print('Last layer is not a residual layer (in_channels={} != out_channels={})'.format(
                    last_in, last_out))

        if double_conv:
            last_block = DoubleConvBlock(ch_in=last_in, ch_out=last_out,
                                         kernel_size=kernels[-1], stride=strides[-1],
                                         batchnorm=final_batchnorm, dropout=final_dropout, bias=bias,
                                         activation=activation, final_activation=final_activation,
                                         residual=use_skip)
        else:
            last_block = ConvBlock(ch_in=last_in, ch_out=last_out,
                                   kernel_size=kernels[-1], stride=strides[-1],
                                   batchnorm=final_batchnorm, dropout=final_dropout, bias=bias,
                                   activation=final_activation, residual=use_skip)
        layers.append(last_block)

        self.modulelist = nn.ModuleList(layers)

        self.apply(init_weights)

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for layer in self.modulelist:
            x = layer(x)
        return x


class MLPFlexible(nn.Module):
    """Multi-layer perceptron assembled from ``LinearBlock`` layers.

    Layer ``i`` maps ``neurons[i]`` to ``neurons[i + 1]`` features. All
    layers but the last use the hidden settings; the last layer uses the
    ``final_*`` settings.
    """

    def __init__(self, neurons, activation="relu",
                 final_activation="none", dropout=0., bias=True, batchnorm=False, final_batchnorm=False,
                 final_dropout=0.):
        """
        Args:
            neurons: layer widths, one more entry than there are layers;
                also stored unchanged as ``self.layer_dims``.
            activation: activation name of the hidden layers.
            final_activation: activation name of the last layer.
            dropout: dropout probability of the hidden layers.
            bias: whether the linear layers have a bias term.
            batchnorm: batch norm in the hidden layers.
            final_batchnorm: batch norm in the last layer.
            final_dropout: dropout probability of the last layer.
        """
        super().__init__()

        self.layer_dims = neurons

        hidden_layers = [
            LinearBlock(neurons[idx], neurons[idx + 1], batchnorm=batchnorm, dropout=dropout, bias=bias,
                        activation=activation)
            for idx in range(len(neurons) - 2)
        ]
        output_layer = LinearBlock(neurons[-2], neurons[-1], batchnorm=final_batchnorm,
                                   dropout=final_dropout, bias=bias, activation=final_activation)

        self.modulelist = nn.ModuleList(hidden_layers + [output_layer])

        self.apply(init_weights)

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for layer in self.modulelist:
            x = layer(x)
        return x


class DecoderFlexible(nn.Module):
    """Stack of 2x up-convolution blocks sized for encoder skip connections.

    The first block takes ``encoder_channels[-1]`` channels. Block ``i >= 1``
    takes ``reversed(encoder_channels)[i] + decoder_channels[i - 1]``
    channels, i.e. the previous block's output plus one encoder feature map.
    Block ``i`` outputs ``decoder_channels[i]`` channels.
    """

    def __init__(self, encoder_channels, decoder_channels, kernels, activation='relu',
                 bias=True, parametric=True, batchnorm=True, dropout=0.,
                 residuals=False, double_conv=False,
                 ):
        """
        Args:
            encoder_channels: channel counts of the matching encoder; must
                have exactly one more entry than ``decoder_channels``.
            decoder_channels: output channels of each decoder block.
            kernels: kernel size of each decoder block (same length as
                ``decoder_channels``).
            activation: activation name used in every block.
            bias: whether the convolution stages have a bias term.
            parametric: use ``UpConvBlockParametric`` (transposed
                convolution) instead of ``UpConvBlockNonParametric``.
            batchnorm: batch norm in every block.
            dropout: dropout probability in every block.
            residuals: make a block residual whenever its input and output
                channel counts match; a message is printed for blocks where
                they do not.
            double_conv: use double convolutions inside the blocks.

        Raises:
            ValueError: if the three lists do not have compatible lengths.
        """
        # Explicit check instead of an assert, so it also runs under
        # ``python -O``. ValueError subclasses Exception, so callers that
        # catch Exception keep working.
        if not (len(encoder_channels) == len(decoder_channels) + 1 == len(kernels) + 1):
            raise ValueError(
                'In order to upsample to the original image resolution, encoder_channels must have exactly '
                'one more entry than decoder_channels and kernels (got {} encoder channels, {} decoder '
                'channels and {} kernels)'.format(len(encoder_channels), len(decoder_channels), len(kernels)))

        super(DecoderFlexible, self).__init__()
        upconvs = []

        # Encoder channels from deepest to shallowest, computed once.
        skip_channels = list(reversed(encoder_channels))
        decoder_input_channels = [encoder_channels[-1]] + [
            skip_channels[i + 1] + decoder_channels[i] for i in range(len(decoder_channels) - 1)]

        decoder_output_channels = decoder_channels

        for i, ch_in in enumerate(decoder_input_channels):
            ch_out = decoder_output_channels[i]

            _residual = False
            if residuals:
                if ch_in == ch_out:
                    _residual = True
                else:
                    print(
                        'Layer {} is not be a residual layer (in_channels={} != out_channels={})'.format(
                            i, ch_in, ch_out))

            if parametric:
                upconvs.append(
                    UpConvBlockParametric(ch_in=ch_in, ch_out=ch_out,
                                          kernel_size=kernels[i], batchnorm=batchnorm,
                                          dropout=dropout,
                                          bias=bias, activation=activation, residual=_residual,
                                          double_conv=double_conv))
            else:
                upconvs.append(UpConvBlockNonParametric(ch_in=ch_in,
                                                        ch_out=ch_out,
                                                        kernel_size=kernels[i], batchnorm=batchnorm,
                                                        dropout=dropout, bias=bias, activation=activation,
                                                        residual=_residual, double_conv=double_conv))

        self.modulelist = nn.ModuleList(upconvs)

    def forward(self, x):
        """Pass ``x`` through every block in order, without skip features.

        NOTE(review): nothing is concatenated here, while blocks after the
        first are sized for a concatenated skip input (see ``__init__``).
        Presumably callers that need skips drive ``self.modulelist``
        themselves; confirm before relying on this method with more than
        one block.
        """
        for layer in self.modulelist:
            x = layer(x)
        return x


class EncoderDecoder(nn.Module):
    """Encoder/decoder network with skip connections and optional heads.

    The encoder uses 3x3 convolutions with stride 2; the decoder uses 3x3
    up-convolution blocks and receives the encoder feature maps through
    channel-wise concatenation. Optional stages after the decoder: a branch
    that re-injects the network input (``concat_rgb``), extra
    "post-upsampling" convolutions, and a 1x1 output convolution.
    """

    def __init__(self, encoder_channels, decoder_channels, parametric_decoder=False, activations='relu', bias=False,
                 batchnorm=True, dropout=0., residuals=False, double_conv=False, use_1x1_conv=True, concat_rgb=True,
                 final_activation='sigmoid', final_encoder_activation='relu',
                 double_convs_second_activation=None, num_output_channels=1, post_upsampling_convs=None):
        """
        Args:
            encoder_channels: encoder channel counts; the first entry is the
                number of channels of the network input.
            decoder_channels: output channels of each decoder block. The
                sequence is only read, never modified.
            parametric_decoder: use transposed convolutions in the decoder.
            activations: activation name used throughout the network.
            bias: whether convolutions have a bias term.
            batchnorm: batch norm throughout the network.
            dropout: dropout probability (the decoder itself is built with
                dropout 0).
            residuals: use residual blocks wherever a block's input and
                output channel counts match.
            double_conv: use double convolutions in encoder and decoder.
            use_1x1_conv: append a 1x1 convolution producing
                ``num_output_channels`` channels.
            concat_rgb: convolve the network input, concatenate it with the
                decoder output and fuse both with another convolution.
            final_activation: activation name applied to the final output.
            final_encoder_activation: activation name of the last encoder
                layer.
            double_convs_second_activation: forwarded to the encoder.
            num_output_channels: output channels of the 1x1 convolution.
            post_upsampling_convs: optional iterable of
                ``(channels, kernel_size, stride)`` triples describing extra
                ``ConvBlock`` layers applied after the decoder.
        """
        super(EncoderDecoder, self).__init__()

        encoder_kernels = [3] * (len(encoder_channels) - 1)
        encoder_strides = [2] * (len(encoder_channels) - 1)
        decoder_kernels = [3] * (len(decoder_channels))

        def residual_allowed(ch_in, ch_out):
            # Residual blocks require matching channel counts, so only
            # enable them where that holds.
            return bool(residuals) and ch_in == ch_out

        self.encoder = EncoderFlexible(channels=encoder_channels, kernels=encoder_kernels,
                                       strides=encoder_strides, activation=activations,
                                       final_activation=final_encoder_activation, bias=bias,
                                       batchnorm=batchnorm, final_batchnorm=batchnorm, dropout=dropout,
                                       final_dropout=dropout, residuals=residuals,
                                       double_conv=double_conv,
                                       double_convs_second_activation=double_convs_second_activation)

        self.decoder = DecoderFlexible(encoder_channels=encoder_channels, decoder_channels=decoder_channels,
                                       kernels=decoder_kernels, activation=activations,
                                       bias=bias, parametric=parametric_decoder, batchnorm=batchnorm, dropout=0.,
                                       residuals=residuals, double_conv=double_conv, )

        # Channel count of the tensor flowing through the stages below.
        # Tracked locally so the caller's decoder_channels is never mutated.
        head_channels = decoder_channels[-1]

        # forward() feeds the raw network input to rgb_conv, so its input
        # width is the encoder's input width (3 for RGB images).
        input_channels = encoder_channels[0]

        # Attribute name (including its spelling) is kept for compatibility.
        self.concat_rbg = concat_rgb
        if self.concat_rbg:
            self.rgb_conv = DoubleConvBlock(ch_in=input_channels, ch_out=head_channels, kernel_size=3, stride=1,
                                            batchnorm=batchnorm, dropout=dropout, bias=bias,
                                            activation=activations,
                                            residual=residual_allowed(input_channels, head_channels))

            self.final_rgb_conv = DoubleConvBlock(ch_in=2 * head_channels, ch_out=head_channels,
                                                  kernel_size=3, stride=1, batchnorm=batchnorm, dropout=dropout,
                                                  bias=bias, activation=activations,
                                                  residual=residual_allowed(2 * head_channels, head_channels))

        self.post_upsampling_convs = None
        if post_upsampling_convs is not None:
            post_layers = []
            for channels, kernel_size, stride in post_upsampling_convs:
                post_layers.append(
                    ConvBlock(head_channels, channels, kernel_size=kernel_size, stride=stride,
                              batchnorm=batchnorm, dropout=dropout, activation=activations, bias=bias,
                              residual=residual_allowed(head_channels, channels)))
                head_channels = channels

            self.post_upsampling_convs = nn.ModuleList(post_layers)

        self.use_1x1_conv = use_1x1_conv
        if self.use_1x1_conv:
            self.conv_1x1 = nn.Conv2d(head_channels, num_output_channels, kernel_size=1, padding=0, stride=1,
                                      bias=bias)

        self.final_activation = activation_func(final_activation)

        self.apply(init_weights)

    def forward(self, x):
        """Run the network.

        Args:
            x: network input; assumed to be ``(N, encoder_channels[0], H, W)``
                with ``H`` and ``W`` such that every upsampled tensor matches
                its encoder counterpart -- TODO confirm the size constraint.

        Returns:
            Tuple ``(output, x)``: the result of ``final_activation`` and the
            tensor it was applied to.
        """
        # Network input followed by the output of every encoder layer.
        forward_features = [x]
        for layer in self.encoder.modulelist:
            x = layer(x)
            forward_features.append(x)

        # Deepest features first; the raw input ends up last.
        forward_features.reverse()

        decoder_layers = self.decoder.modulelist
        x = decoder_layers[0](x)

        # Each later decoder block sees the previous output concatenated
        # with the encoder feature map of the same depth.
        for i in range(1, len(decoder_layers)):
            x = torch.cat([x, forward_features[i]], dim=1)
            x = decoder_layers[i](x)

        if self.concat_rbg:
            rgb_features = self.rgb_conv(forward_features[-1])
            x = torch.cat([x, rgb_features], dim=1)
            x = self.final_rgb_conv(x)

        if self.post_upsampling_convs:
            for layer in self.post_upsampling_convs:
                x = layer(x)

        if self.use_1x1_conv:
            x = self.conv_1x1(x)

        # NOTE(review): if final_activation is an in-place module, x is
        # overwritten here and both returned tensors hold the activated
        # values -- confirm whether that is acceptable for such activations.
        output = self.final_activation(x)

        return output, x
