import tensorflow as tf
from keras.layers import Layer, Conv2D, Conv2DTranspose
from keras import Model


class GeoChannel2D(Layer):
    """Append one positional ("geo") channel to a 4-D feature map.

    The extra channel holds ``x / n_x + y / n_y`` for every spatial position,
    where ``n_y`` and ``n_x`` are read from axes 1 and 2 of the input shape.
    The channel is concatenated on the last axis, so the input is treated as
    ``(batch, n_y, n_x, channels)`` and both spatial sizes must be known when
    the layer is built.

    Args:
        shift: Optional keyword (default ``True``). When true, every call
            draws one random offset per sample, uniformly from ``[0, 2)``,
            and subtracts it from the coordinate channel. When false the
            coordinate channel is appended unchanged.
        **kwargs: Remaining keyword arguments are forwarded to
            ``Layer.__init__`` (for example ``name``, ``dtype``,
            ``trainable``).
    """

    def __init__(self, **kwargs):
        # Pop the layer-specific option first so that only base-class
        # options are forwarded; previously they were silently discarded.
        shift = kwargs.pop("shift", True)
        super(GeoChannel2D, self).__init__(**kwargs)
        self.shift = shift
        self.coordinate = None

    def build(self, input_shape):
        """Precompute the coordinate grid of shape ``(1, n_y, n_x, 1)``."""
        n_y = input_shape[1]
        n_x = input_shape[2]

        # Normalised column / row indices, shaped so that their sum
        # broadcasts to a full (1, n_y, n_x, 1) grid.
        horizontal = tf.reshape(tf.range(0, n_x, 1, dtype=self.dtype) / n_x, (1, 1, n_x, 1))
        vertical = tf.reshape(tf.range(0, n_y, 1, dtype=self.dtype) / n_y, (1, n_y, 1, 1))

        self.coordinate = horizontal + vertical

    def call(self, inputs, *args, **kwargs):
        """Return ``inputs`` with the coordinate channel appended on the last axis."""
        batch_size = tf.shape(inputs)[0]
        # The grid was built in ``self.dtype``, which can differ from the
        # dtype the inputs arrive in (e.g. under a mixed-precision policy);
        # tf.concat needs identical dtypes, so align the grid to the inputs.
        coordinate = tf.cast(self.coordinate, inputs.dtype)

        if self.shift:
            # One random offset per sample, broadcast over the whole grid.
            rand_shift = tf.random.uniform((batch_size, 1, 1, 1), 0., 2., dtype=inputs.dtype)
            geo_channel = tf.subtract(coordinate, rand_shift)
        else:
            geo_channel = tf.broadcast_to(coordinate, (batch_size, *coordinate.shape[1:]))

        return tf.concat([inputs, geo_channel], axis=-1)

    def get_config(self):
        """Return the base layer config extended with ``shift``."""
        config = super(GeoChannel2D, self).get_config()
        config.update({"shift": self.shift})
        return config


class GeoConv2D(Model):
    """``Conv2D`` applied to its input after a ``GeoChannel2D`` channel is appended.

    All convolution arguments are stored on the instance (for ``get_config``)
    and forwarded unchanged to the inner ``Conv2D``.

    Notes:
        * ``data_format`` is forwarded to the convolution only; the
          coordinate layer always reads the spatial sizes from axes 1 and 2
          and concatenates on the last axis.
        * ``activity_regularizer`` is deliberately NOT stored as
          ``self.activity_regularizer``. That attribute is the Keras
          layer-level regularizer of this wrapper, so assigning it made the
          penalty apply twice (once by the wrapper, once by the inner
          convolution). The inner ``Conv2D`` is the only owner now.

    Args:
        filters, kernel_size, strides, padding, data_format, dilation_rate,
        groups, activation, use_bias, kernel_initializer, bias_initializer,
        kernel_regularizer, bias_regularizer, activity_regularizer,
        kernel_constraint, bias_constraint: Passed through to ``Conv2D``.
        shift: Passed to ``GeoChannel2D``.
        **kwargs: Extra layer options. They are forwarded to the inner
            ``Conv2D`` as before; ``name`` additionally names this wrapper,
            and everything except ``name`` is offered to ``GeoChannel2D``.
    """

    def __init__(
            self,
            filters,
            kernel_size,
            strides=(1, 1),
            padding='valid',
            data_format=None,
            dilation_rate=(1, 1),
            groups=1,
            activation=None,
            use_bias=True,
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros',
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            shift=True,
            **kwargs
    ):
        # Honour a user-supplied name on the wrapper itself (it used to be
        # dropped here and only reached the inner convolution).
        super(GeoConv2D, self).__init__(name=kwargs.get("name"))
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.data_format = data_format
        self.dilation_rate = dilation_rate
        self.groups = groups
        self.activation = activation
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        # Kept under a private name so it does not become this wrapper's own
        # activity regularizer (see class notes); only used by get_config().
        self._activity_regularizer_arg = activity_regularizer
        self.kernel_constraint = kernel_constraint
        self.bias_constraint = bias_constraint
        self.shift = shift

        # The coordinate layer never receives ``name`` so the two sub-layers
        # cannot end up sharing one name.
        channel_kwargs = {key: value for key, value in kwargs.items() if key != "name"}
        self.geo_channel_2d = GeoChannel2D(shift=shift, **channel_kwargs)
        self.conv_2d = Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            groups=groups,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )

    def call(self, inputs, **kwargs):
        """Append the coordinate channel, then convolve.

        Extra call keyword arguments are passed to the convolution only.
        """
        return self.conv_2d(self.geo_channel_2d(inputs), **kwargs)

    def get_config(self):
        """Return the constructor arguments as a config dictionary."""
        try:
            config = super(GeoConv2D, self).get_config()
        except NotImplementedError:
            # Some Keras releases do not implement get_config() for
            # subclassed models; fall back to an empty base config.
            config = {}
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "data_format": self.data_format,
            "dilation_rate": self.dilation_rate,
            "groups": self.groups,
            "activation": self.activation,
            "use_bias": self.use_bias,
            "kernel_initializer": self.kernel_initializer,
            "bias_initializer": self.bias_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "bias_regularizer": self.bias_regularizer,
            "activity_regularizer": self._activity_regularizer_arg,
            "kernel_constraint": self.kernel_constraint,
            "bias_constraint": self.bias_constraint,
            "shift": self.shift
        })
        return config


class GeoConv2DTranspose(Model):
    """``Conv2DTranspose`` applied to its input after a ``GeoChannel2D`` channel is appended.

    All convolution arguments are stored on the instance (for ``get_config``)
    and forwarded unchanged to the inner ``Conv2DTranspose``.

    Notes:
        * ``data_format`` is forwarded to the transposed convolution only;
          the coordinate layer always reads the spatial sizes from axes 1
          and 2 and concatenates on the last axis.
        * ``activity_regularizer`` is deliberately NOT stored as
          ``self.activity_regularizer``. That attribute is the Keras
          layer-level regularizer of this wrapper, so assigning it made the
          penalty apply twice (once by the wrapper, once by the inner
          layer). The inner ``Conv2DTranspose`` is the only owner now.

    Args:
        filters, kernel_size, strides, padding, output_padding, data_format,
        dilation_rate, activation, use_bias, kernel_initializer,
        bias_initializer, kernel_regularizer, bias_regularizer,
        activity_regularizer, kernel_constraint, bias_constraint: Passed
            through to ``Conv2DTranspose``.
        shift: Passed to ``GeoChannel2D``.
        **kwargs: Extra layer options. They are forwarded to the inner
            ``Conv2DTranspose`` as before; ``name`` additionally names this
            wrapper, and everything except ``name`` is offered to
            ``GeoChannel2D``.
    """

    def __init__(
            self,
            filters,
            kernel_size,
            strides=(1, 1),
            padding='valid',
            output_padding=None,
            data_format=None,
            dilation_rate=(1, 1),
            activation=None,
            use_bias=True,
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros',
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            shift=True,
            **kwargs
    ):
        # Honour a user-supplied name on the wrapper itself (it used to be
        # dropped here and only reached the inner layer).
        super(GeoConv2DTranspose, self).__init__(name=kwargs.get("name"))
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.output_padding = output_padding
        self.data_format = data_format
        self.dilation_rate = dilation_rate
        self.activation = activation
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        # Kept under a private name so it does not become this wrapper's own
        # activity regularizer (see class notes); only used by get_config().
        self._activity_regularizer_arg = activity_regularizer
        self.kernel_constraint = kernel_constraint
        self.bias_constraint = bias_constraint
        self.shift = shift

        # The coordinate layer never receives ``name`` so the two sub-layers
        # cannot end up sharing one name.
        channel_kwargs = {key: value for key, value in kwargs.items() if key != "name"}
        self.geo_channel_2d = GeoChannel2D(shift=shift, **channel_kwargs)
        self.conv_2d_transpose = Conv2DTranspose(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            output_padding=output_padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )

    def call(self, inputs, **kwargs):
        """Append the coordinate channel, then apply the transposed convolution.

        Extra call keyword arguments are passed to the transposed
        convolution only.
        """
        return self.conv_2d_transpose(self.geo_channel_2d(inputs), **kwargs)

    def get_config(self):
        """Return the constructor arguments as a config dictionary."""
        try:
            config = super(GeoConv2DTranspose, self).get_config()
        except NotImplementedError:
            # Some Keras releases do not implement get_config() for
            # subclassed models; fall back to an empty base config.
            config = {}
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "output_padding": self.output_padding,
            "data_format": self.data_format,
            "dilation_rate": self.dilation_rate,
            "activation": self.activation,
            "use_bias": self.use_bias,
            "kernel_initializer": self.kernel_initializer,
            "bias_initializer": self.bias_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "bias_regularizer": self.bias_regularizer,
            "activity_regularizer": self._activity_regularizer_arg,
            "kernel_constraint": self.kernel_constraint,
            "bias_constraint": self.bias_constraint,
            "shift": self.shift
        })
        return config
