import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions
from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization
from tensorflow.keras.layers import Dense, AveragePooling2D, LeakyReLU, Flatten

from models.base import BaseModel

class ResBlock(Layer):
    """Pre-activation bottleneck residual block.

    Applies three (BatchNormalization -> leaky ReLU -> Conv2D) stages:
    a 1x1 conv down to ``inner_dim`` channels, a 3x3 conv ('same' padding)
    at ``inner_dim`` channels, and a 1x1 conv back up to ``outer_dim``
    channels. The result is added to the block input.

    Args:
        outer_dim: Number of output channels. The residual addition requires
            the input to already have ``outer_dim`` channels.
        inner_dim: Number of bottleneck channels.
    """

    def __init__(self, outer_dim, inner_dim):
        # An empty name makes Keras fall back to an auto-generated layer name.
        super(ResBlock, self).__init__(name='')

        # Attribute names are kept stable: Keras tracks sub-layers (and their
        # checkpointed weights) by these names.
        self.bn2a = BatchNormalization()
        self.conv2a = Conv2D(inner_dim, 1)

        self.bn2b = BatchNormalization()
        self.conv2b = Conv2D(inner_dim, 3, padding='same')

        self.bn2c = BatchNormalization()
        self.conv2c = Conv2D(outer_dim, 1)

    def call(self, input_tensor, training=None):
        """Run the block.

        Args:
            input_tensor: Feature map accepted by ``Conv2D`` (presumably
                NHWC, Keras' default) with ``outer_dim`` channels.
            training: Train/eval switch forwarded to every
                BatchNormalization layer (Python bool or boolean tensor).
                ``None`` keeps Keras' default behaviour. Keras containers
                such as ``Sequential`` only pass this through because it is
                declared in this signature.

        Returns:
            ``input_tensor`` plus the residual branch output.
        """
        x = self.bn2a(input_tensor, training=training)
        x = tf.nn.leaky_relu(x)
        x = self.conv2a(x)

        x = self.bn2b(x, training=training)
        x = tf.nn.leaky_relu(x)
        x = self.conv2b(x)

        x = self.bn2c(x, training=training)
        x = tf.nn.leaky_relu(x)
        x = self.conv2c(x)

        return x + input_tensor

def get_classifier(image_size, dim, num_classes):
    """Build an unbuilt ``tf.keras.Sequential`` downsampling classifier.

    The network has ``int(log2(image_size))`` stages. Each stage is
    Conv2D(width, 1) -> ResBlock(width, width // 2) -> AveragePooling2D(2),
    and the channel width doubles after every stage. The head is
    Flatten -> Dense(final width) -> LeakyReLU -> Dense(num_classes).
    The last Dense has no activation.

    Args:
        image_size: Spatial size of the input; sets the number of stages.
        dim: Channel width of the first stage.
        num_classes: Number of output units.

    Returns:
        A ``tf.keras.Sequential`` model.
    """
    num_stages = int(np.log2(image_size))
    width = dim
    stack = []
    for _ in range(num_stages):
        # Instantiate in the same order every time: Keras derives default
        # layer/variable names from creation order.
        stack += [Conv2D(width, 1), ResBlock(width, width // 2), AveragePooling2D(2)]
        width *= 2

    # NOTE(review): LeakyReLU() uses Keras' default slope (0.3), while ResBlock
    # uses tf.nn.leaky_relu's default (0.2) -- confirm the mismatch is intended.
    head = [Flatten(), Dense(width), LeakyReLU(), Dense(num_classes)]
    return tf.keras.Sequential(stack + head)


class Classifier(BaseModel):
    """Classifier over masked images, built in a TF1 graph.

    The network input is ``x * b`` concatenated with the mask ``b`` along the
    last axis, where ``x`` is the rescaled image.
    """

    def __init__(self, hps):
        super(Classifier, self).__init__(hps)

    def build_net(self):
        """Create placeholders, the classifier graph, outputs and loss.

        Sets on ``self``:
            x, y, b, is_training: input placeholders (image, int64 label,
                mask with the image's shape, bool train/eval switch).
            logits, prob, ent, pred, acc, xent: outputs for the masked input
                (softmax probs, categorical entropy, argmax prediction,
                per-example 0/1 accuracy, per-example cross-entropy).
            metric: per-example 0/1 accuracy when the mask is all ones.
            loss: batch-mean cross-entropy (also logged as a summary).
        """
        with tf.variable_scope('classifier', reuse=tf.AUTO_REUSE):
            self.x = tf.placeholder(tf.float32, [None]+self.hps.image_shape)
            self.y = tf.placeholder(tf.int64, [None])
            self.b = tf.placeholder(tf.float32, [None]+self.hps.image_shape)
            self.is_training = tf.placeholder(tf.bool)

            # Maps 0 -> -1 and 255 -> 1 (presumably 8-bit pixel values fed
            # as floats -- confirm against the data pipeline).
            x = (self.x / 255. - 0.5) * 2.0

            image_size = self.hps.image_shape[0]
            num_classes = self.hps.num_classes
            dim = self.hps.dim

            classifier = get_classifier(image_size, dim, num_classes)

            # Keras models take the train/eval switch as ``training``.
            # ``Sequential.call`` has no ``is_training`` parameter, so the
            # previous keyword raised a TypeError instead of reaching the
            # BatchNormalization layers.
            inp = tf.concat([x*self.b, self.b], axis=-1)
            self.logits = classifier(inp, training=self.is_training)
            self.prob = tf.nn.softmax(self.logits)
            self.ent = tfd.Categorical(logits=self.logits).entropy()
            self.pred = tf.argmax(self.logits, axis=1)
            self.acc = tf.cast(tf.equal(self.pred, self.y), tf.float32)
            self.xent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y)

            # metric: same classifier (shared weights) on the fully observed
            # input, i.e. with an all-ones mask.
            b = tf.ones_like(self.b)
            inp = tf.concat([x*b, b], axis=-1)
            logits = classifier(inp, training=self.is_training)
            pred = tf.argmax(logits, axis=1)
            acc = tf.cast(tf.equal(pred, self.y), tf.float32)
            self.metric = acc

            # loss
            # NOTE(review): tf.keras BatchNormalization moving-average updates
            # are exposed on ``classifier.updates``. In TF1 graph mode they may
            # not be in tf.GraphKeys.UPDATE_OPS -- confirm BaseModel's train
            # op runs them.
            self.loss = tf.reduce_mean(self.xent)
            tf.summary.scalar('loss', self.loss)