import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K


def downsample(filters, size, apply_batchnorm=True):
    """
    Downsampling block for the GAN: stride-2 convolution, optional batch
    normalisation, then LeakyReLU.

    Takes a tensor of shape (H, W, C) and produces a tensor of shape
    (ceil(H/2), ceil(W/2), filters) -- the output channel count is set by
    ``filters``, it is not carried over from the input.
    Borrowed from https://www.tensorflow.org/tutorials/generative/pix2pix

    Args:
        filters: number of output channels of the convolution.
        size: convolution kernel size.
        apply_batchnorm: if True, insert BatchNormalization between the
            convolution and the activation.

    Returns:
        A ``Sequential`` model implementing the block.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    result = Sequential()
    # NOTE: use_bias=False is kept even when apply_batchnorm is False, so in
    # that configuration the block has no additive offset before LeakyReLU.
    result.add(Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        result.add(BatchNormalization())
    result.add(LeakyReLU())
    return result

"""---------------------------------------------------------------------"""

def upsample(filters, size, apply_dropout=False, dropout_rate=0.2):
    """
    Upsampling block for the GAN: stride-2 transposed convolution, batch
    normalisation, optional dropout, then ReLU.

    Takes a tensor of shape (H, W, C) and produces a tensor of shape
    (2H, 2W, filters) -- the output channel count is set by ``filters``,
    it is not carried over from the input.
    Borrowed from https://www.tensorflow.org/tutorials/generative/pix2pix

    Args:
        filters: number of output channels of the transposed convolution.
        size: transposed-convolution kernel size.
        apply_dropout: if True, add a Dropout layer after batch norm.
        dropout_rate: rate of that Dropout layer (only used when
            ``apply_dropout`` is True). Defaults to the previously
            hard-coded 0.2.

    Returns:
        A ``Sequential`` model implementing the block.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    # Same class as tf.keras.Sequential; use the imported name like the rest of the file.
    result = Sequential()
    result.add(Conv2DTranspose(filters, size, strides=2, padding='same',
                               kernel_initializer=initializer, use_bias=False))
    result.add(BatchNormalization())
    if apply_dropout:
        result.add(Dropout(dropout_rate))
    result.add(ReLU())

    return result

"""---------------------------------------------------------------------"""

def base_discriminator(bnorm=False):
    """
    Build a single-input discriminator for (56, 56, 1) maps.

    Five stride-2 ``downsample`` blocks (spatial size 56 -> 28 -> 14 -> 7 -> 4
    -> 2) are followed by ``Flatten``, ``Dense(256, relu)`` and a final
    ``Dense(1)`` with no activation, so the model emits one unbounded logit
    per example.

    Args:
        bnorm: whether blocks 2-5 use batch normalisation.

    Returns:
        A ``Sequential`` model.
    """
    # NOTE(review): the first block relies on downsample's default
    # apply_batchnorm=True, so it always has batch norm regardless of
    # ``bnorm``; confirm this is intended.
    stack = [Input((56, 56, 1)), downsample(32, 4)]
    stack.extend(downsample(n_filters, 4, bnorm) for n_filters in (128, 256, 256, 256))
    stack.extend([Flatten(), Dense(256, activation="relu"), Dense(1, activation=None)])
    return Sequential(stack)

"""---------------------------------------------------------------------"""

def base_pairwise_discriminator(bnorm=False):
    """
    Build a Siamese discriminator over a pair of (56, 56, 1) maps.

    Both inputs are passed through the same (weight-shared) trunk: five
    ``downsample`` blocks (the first always with batch norm, the rest only
    when ``bnorm`` is True) followed by ``Flatten``. Two heads read the
    trunk features:

    * rank head: ``Dense(1)`` on ``z_i - z_j``, giving one ranking logit per pair;
    * adversarial head: ``Dense(256, relu) -> Dense(1)`` applied to each map
      separately; the two outputs are concatenated along axis 0 (the batch
      axis), so all i-logits come before all j-logits.

    Returns:
        ``Model([d_i, d_j], [a_ij, r_ij])``.
    """
    trunk = Sequential(
        [Input((56, 56, 1)), downsample(32, 4)]
        + [downsample(n_filters, 4, bnorm) for n_filters in (128, 256, 256, 256)]
        + [Flatten()]
    )

    rank_head = Dense(1, activation=None)
    advr_head = Sequential([Dense(256, activation="relu"), Dense(1, activation=None)])

    d_i = Input((56, 56, 1))
    d_j = Input((56, 56, 1))
    # Calling the same trunk object twice shares its weights between the two inputs.
    z_i, z_j = trunk(d_i), trunk(d_j)

    r_ij = rank_head(subtract([z_i, z_j]))
    # Axis 0 stacks the per-image adversarial logits along the batch dimension.
    a_ij = concatenate([advr_head(z_i), advr_head(z_j)], 0)
    return Model([d_i, d_j], [a_ij, r_ij])
    


"""---------------------------------------------------------------------"""

def resnet_with_outputs(outputs, weights):
    """
    Wrap a headless ResNet50V2 (224x224x3 input) so it returns selected
    intermediate activations.

    Args:
        outputs: names of the ResNet50V2 layers whose outputs to expose,
            in the order they should be returned.
        weights: forwarded to ``ResNet50V2(weights=...)``.

    Returns:
        A ``Model`` from the ResNet input to the listed layer outputs.
    """
    backbone = ResNet50V2(include_top=False, pooling="avg", weights=weights, input_shape=(224,224,3))
    taps = [backbone.get_layer(layer_name).output for layer_name in outputs]
    return Model(backbone.inputs, taps)


"""---------------------------------------------------------------------"""

def multi_source_upsample(**kwargs):
    """
    Build the decoder that fuses four feature maps into a 56x56 map.

    Inputs (in this order) have shapes (56, 56, 64), (28, 28, 128),
    (14, 14, 256) and (7, 7, 2048). The 7x7 map is first projected to 256
    channels by a 1x1 conv. The decoder then repeatedly doubles the
    resolution with ``upsample`` (dropout enabled). At each step it adds a
    4x4-conv projection of the skip feature at the matching scale, giving
    14x14x256, then 28x28x128, then 56x56x64. A 32-channel conv and a
    final sigmoid conv follow.

    Keyword Args:
        spatial_dropout: if present, a ``SpatialDropout2D`` with this rate is
            applied to each of the four input feature maps.
        power: if present, ``CosineAttention(power)`` is applied before the
            final prediction conv.
        Any other keyword arguments are ignored.

    Returns:
        A ``Model`` mapping ``[feat56, feat28, feat14, feat7]`` to a
        (56, 56, 1) sigmoid output.
    """
    feat56, feat28, feat14, feat7 = Input([56,56,64]), Input([28,28,128]), Input([14,14,256]), Input([7,7,2048])
    inputs = [feat56, feat28, feat14, feat7]

    if "spatial_dropout" in kwargs:
        rate = kwargs["spatial_dropout"]
        print("--+| Adding Spatial Dropout at rate: {}".format(rate))
        # One independent dropout layer per source, created in input order.
        feat56_, feat28_, feat14_, feat7_ = [SpatialDropout2D(rate)(f) for f in inputs]
    else:
        feat56_, feat28_, feat14_, feat7_ = inputs

    x = Conv2D(256, 1, padding="SAME", activation="relu")(feat7_)
    # Coarse-to-fine: each step doubles x's resolution and adds the skip at that scale.
    for n_filters, skip in zip([256, 128, 64], [feat14_, feat28_, feat56_]):
        x_up = upsample(n_filters, 4, True)(x)
        skip_proj = Conv2D(n_filters, 4, padding="SAME", activation="relu")(skip)
        x = add([skip_proj, x_up])
    x = Conv2D(32, 4, padding="SAME", activation="relu")(x)
    if "power" in kwargs:
        # NOTE(review): CosineAttention is not defined or imported in this file
        # and does not appear to be a tf.keras layer, so this branch raises
        # NameError unless the name is provided elsewhere -- confirm its source.
        x = CosineAttention(kwargs["power"])(x)

    dmap_pred = Conv2D(1, 4, padding="SAME", activation="sigmoid")(x)
    return Model(inputs, dmap_pred)

"""---------------------------------------------------------------------"""

def density_map_generator(weights, **kwargs):
    """
    Assemble the full image -> density-map model.

    A ResNet50V2 backbone (``weights`` forwarded to it) exposes four
    intermediate activations. These are fed, in order, to the
    ``multi_source_upsample`` decoder; ``kwargs`` are passed to that
    decoder unchanged. The four tapped layers are presumably at
    56/28/14/7 resolution to match the decoder's inputs -- verify if the
    tap names change.

    Returns:
        A ``Model`` from a (224, 224, 3) image to the decoder's output.
    """
    tap_names = ["conv2_block3_1_relu", "conv3_block4_1_relu",
                 "conv4_block6_1_relu", "post_relu"]
    backbone = resnet_with_outputs(tap_names, weights)
    decoder = multi_source_upsample(**kwargs)

    image = Input([224,224,3])
    features = backbone(image)
    return Model(image, decoder(features))

"""---------------------------------------------------------------------"""
class SumLayer(Layer):
    """Keras layer that sums its input along ``axis``.

    Wraps ``tf.keras.backend.sum`` so the reduction can be used as a layer in
    the functional API.

    Args:
        axis: axis (or axes) to reduce over.
        keepdims: if True, keep the reduced axis with length 1.
        **kwargs: standard ``Layer`` arguments (``name``, ``dtype``,
            ``trainable``, ...), forwarded to the base class.
    """

    def __init__(self, axis, keepdims=False, **kwargs):
        # Forwarding **kwargs lets Layer.from_config (which passes name/dtype/
        # trainable back in) re-create this layer when a model is loaded.
        super(SumLayer, self).__init__(**kwargs)
        self.axis = axis
        self.keepdims = keepdims

    def call(self, inputs):
        return tf.keras.backend.sum(inputs, axis=self.axis, keepdims=self.keepdims)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(SumLayer, self).get_config()
        config.update({"axis": self.axis, "keepdims": self.keepdims})
        return config

"""---------------------------------------------------------------------"""
def build_rank_head():
    """
    Build a model that scores the relative totals of two density maps.

    Each (56, 56, 1) map is flattened and summed into one value per example
    (kept as shape (batch, 1)). The difference ``c_i - c_j`` then goes
    through a single ``Dense(1)`` unit with no activation, giving one logit
    per pair.

    Returns:
        ``Model([dmap_i, dmap_j], logit)``.
    """
    def total(dmap):
        return SumLayer(1, True)(Flatten()(dmap))

    dmap_i, dmap_j = Input([56, 56, 1]), Input([56, 56, 1])
    count_i = total(dmap_i)
    count_j = total(dmap_j)

    logit = Dense(1, activation=None)(subtract([count_i, count_j]))
    return Model([dmap_i, dmap_j], logit)

