import tensorflow as tf
import numpy as np

from tgan_sr.transformer.base import TransformerEncoder
from tgan_sr.transformer import positional_encoding


class TransformerCritic(tf.keras.layers.Layer):
    """Transformer critic that scores a (soft) token sequence.

    The input is embedded with a dense layer, scaled by sqrt(d_embed_enc),
    offset by a fixed positional encoding, passed through dropout and then a
    shared ``TransformerEncoder`` trunk. Two optional extra encoders branch
    off the trunk output: a "class" branch and a "critic" branch (a branch
    with 0 layers is the identity). Both branches go through the same 3-unit
    ``final_projection``; output channel 0 is taken from the class branch and
    channels 1-2 from the critic branch, then averaged over axis 1.

    Args:
        params: hyperparameter dict (copied, the caller's dict is not
            mutated). Keys read here: ``gan_critic_class_layers``,
            ``gan_critic_critic_layers``, ``d_embed_enc``,
            ``max_encode_length``, ``dtype``, ``dropout``. The dict is also
            forwarded to ``TransformerEncoder``; the trunk uses whatever
            ``num_layers`` the caller supplied.
        sigmoid: if True, the final projection uses a sigmoid activation,
            otherwise it is linear.
    """

    def __init__(self, params, sigmoid=False):
        super().__init__()
        params = params.copy()
        self.params = params
        self.transformer_encoder = TransformerEncoder(params)
        if params['gan_critic_class_layers'] > 0:
            class_enc_params = params.copy()
            class_enc_params['num_layers'] = params['gan_critic_class_layers']
            self.class_encoder = TransformerEncoder(class_enc_params)
        else:
            self.class_encoder = None
        if params['gan_critic_critic_layers'] > 0:
            critic_enc_params = params.copy()
            critic_enc_params['num_layers'] = params['gan_critic_critic_layers']
            # Build from this branch's own params: each branch has its own
            # depth, and class_enc_params does not exist when the class
            # branch is disabled.
            self.critic_encoder = TransformerEncoder(critic_enc_params)
        else:
            self.critic_encoder = None
        # Shared by both branches; ``call`` picks channel 0 from the class
        # branch and channels 1-2 from the critic branch.
        self.final_projection = tf.keras.layers.Dense(3, activation=('sigmoid' if sigmoid else None))
        self.embed = tf.keras.layers.Dense(params['d_embed_enc'], kernel_initializer='uniform') # identical
        self.stdpe = positional_encoding.positional_encoding(params['max_encode_length'], params['d_embed_enc'], dtype=params['dtype'])
        self.dropout = tf.keras.layers.Dropout(params['dropout'])

    def call(self, x, positive_mask, training=False):
        """Score ``x``; returns one 3-vector per batch element.

        Args:
            x: input tensor, indexed as (batch, seq, features) — rank 3
                assumed; TODO confirm with callers.
            positive_mask: mask of valid (non-padding) positions; it is
                reshaped to [batch, 1, 1, seq], so it must hold batch*seq
                elements.
            training: forwarded to dropout and all encoders.

        Returns:
            Tensor of shape (batch, 3): [class score, critic scores...],
            averaged over the sequence axis.
        """
        # Index the dynamic shape rather than tuple-unpacking it: iterating
        # a symbolic tf.Tensor is not allowed inside tf.function/graph mode.
        shape = tf.shape(x)
        batch_size, seq_len = shape[0], shape[1]
        # 1.0 at padded positions; the [batch, 1, 1, seq] layout presumably
        # broadcasts against attention logits inside TransformerEncoder.
        padding_mask = tf.reshape(tf.cast(tf.logical_not(positive_mask), tf.float32), [batch_size, 1, 1, seq_len])
        pe = self.stdpe[:, :seq_len, :]

        x = self.embed(x)
        x *= tf.math.sqrt(tf.cast(self.params['d_embed_enc'], self.params['dtype']))
        x += pe
        x = self.dropout(x, training=training)

        x = self.transformer_encoder(x, padding_mask, training=training)
        if self.class_encoder is not None:
            x_class = self.class_encoder(x, padding_mask, training=training)
        else:
            x_class = x
        if self.critic_encoder is not None:
            x_critic = self.critic_encoder(x, padding_mask, training=training)
        else:
            x_critic = x
        y_class = self.final_projection(x_class)
        y_critic = self.final_projection(x_critic)
        y = tf.concat([y_class[:, :, :1], y_critic[:, :, 1:]], axis=-1)
        # NOTE(review): this mean runs over every position, padding included;
        # a mask-weighted mean may be intended — confirm before changing.
        y = tf.reduce_mean(y, axis=1)
        return y


class NullProjector(tf.keras.layers.Layer):
    """Placeholder projector that only supports ``infer``.

    ``call`` is intentionally unimplemented. ``infer`` decodes a soft
    distribution by argmax and reports its mean masked entropy.
    """

    def __init__(self):
        super().__init__()

    def call(self, whatever, positive_mask, x, training=False):
        raise NotImplementedError

    def infer(self, softs, positive_mask, training=False):
        """Decode ``softs`` by argmax and compute its mean masked entropy.

        Args:
            softs: probabilities over the last axis — presumably
                (batch, seq, vocab); TODO confirm.
            positive_mask: mask of valid positions, matching ``softs`` without
                its last axis; axis 1 is summed over as the sequence axis.
            training: unused; kept for interface compatibility.

        Returns:
            Tuple ``(argmax over the last axis, None, soft_entropy)``, where
            ``soft_entropy`` is the per-sequence average entropy over valid
            positions, averaged over the batch.
        """
        w = tf.cast(positive_mask, tf.float32)
        # xlogy(p, p) == p * log(p), but is 0 (not NaN) where p == 0.
        ent_per_pos = -tf.reduce_sum(tf.math.xlogy(softs, softs), axis=-1)
        # Normalise by the number of valid positions per row; divide_no_nan
        # makes a fully-masked row contribute 0 instead of NaN.
        n_valid = tf.reduce_sum(w, axis=1, keepdims=True)
        per_seq = tf.reduce_sum(tf.math.divide_no_nan(ent_per_pos * w, n_valid), axis=1)
        soft_entropy = tf.reduce_mean(per_seq)
        return tf.argmax(softs, axis=-1), None, soft_entropy



class TransformerGenerator(tf.keras.layers.Layer):
    """Transformer encoder that maps a noise sequence ``z`` to per-position logits.

    Args:
        params: hyperparameter dict (copied). ``num_layers`` is overridden by
            ``gan_generator_layers``; also reads ``max_encode_length``,
            ``d_embed_enc``, ``dtype``, ``dropout`` and ``input_vocab_size``.
        proc_logits_fn: optional callable applied to the final logits. When
            ``None`` (the default) the raw logits are returned.
    """

    def __init__(self, params, proc_logits_fn=None):
        super().__init__()
        params = params.copy()
        params['num_layers'] = params['gan_generator_layers']
        self.params = params
        self.proc_logits_fn = proc_logits_fn
        self.transformer_encoder = TransformerEncoder(params)
        self.stdpe = positional_encoding.positional_encoding(params['max_encode_length'], params['d_embed_enc'], dtype=params['dtype'])
        self.dropout = tf.keras.layers.Dropout(params['dropout'])
        self.embedz = tf.keras.layers.Dense(params['d_embed_enc'], kernel_initializer='glorot_uniform')
        self.final_proj = tf.keras.layers.Dense(params['input_vocab_size'])

    def call(self, z, positive_mask, training=False):
        """Map noise ``z`` to per-position outputs.

        Args:
            z: noise tensor, indexed as (batch, seq, d_z) — rank 3 assumed;
                TODO confirm with callers.
            positive_mask: mask of valid positions; reshaped to
                [batch, 1, 1, seq], so it must hold batch*seq elements.
            training: accepted for Keras compatibility but NOT used: dropout
                and the encoder are always run with ``training=True``.

        Returns:
            ``final_proj`` output (``input_vocab_size`` units per position),
            passed through ``proc_logits_fn`` when one was provided.
        """
        # Index the dynamic shape rather than tuple-unpacking it: iterating
        # a symbolic tf.Tensor is not allowed inside tf.function/graph mode.
        shape = tf.shape(z)
        batch_size, seq_len = shape[0], shape[1]
        # 1.0 at padded positions, in the layout TransformerEncoder receives.
        padding_mask = tf.reshape(tf.cast(tf.logical_not(positive_mask), tf.float32), [batch_size, 1, 1, seq_len])

        x = self.embedz(z)
        x *= tf.math.sqrt(tf.cast(self.params['d_embed_enc'], self.params['dtype']))
        pe = self.stdpe[:, :seq_len, :]

        x += pe
        # keep: training is hard-coded to True on purpose (presumably so
        # dropout stays active as a noise source at generation time).
        x = self.dropout(x, training=True) # keep

        x = self.transformer_encoder(x, padding_mask, training=True) # keep
        y = self.final_proj(x)
        # proc_logits_fn defaults to None; calling it unconditionally would
        # raise TypeError for generators built without one.
        if self.proc_logits_fn is not None:
            y = self.proc_logits_fn(y)

        return y