from algorithms.utils.types import TensorflowModel
from algorithms.utils.params import Params


def prediction_net(input_shape: tuple = (4, 1, 8),
                   num_actions: int = 8,
                   num_hidden_weights: int = 64,
                   num_hidden_layers: int = 2,
                   hidden_layer_activation: str = 'selu',
                   hidden_layer_initializer: str = 'lecun_normal',
                   value_activation: str = 'tanh',
                   value_initializer: str = 'glorot_normal',
                   policy_initializer: str = 'glorot_normal',
                   use_batchnorm: bool = False) -> TensorflowModel:
    """Build a dense Keras model with a policy head and a value head.

    The input is flattened and fed through ``num_hidden_layers`` Dense layers
    (each optionally followed by BatchNormalization, then the activation).
    The final hidden representation is shared by two output heads: a policy
    head with ``num_actions`` units and no activation, and a single-unit
    value head.

    Args:
        input_shape: Shape of one input sample, excluding the batch dimension.
        num_actions: Number of units in the policy head.
        num_hidden_weights: Number of units in each hidden Dense layer.
        num_hidden_layers: Number of hidden Dense layers. With 0, the
            flattened input feeds both heads directly.
        hidden_layer_activation: Keras activation applied after each hidden
            layer (after batch normalization when enabled).
        hidden_layer_initializer: Keras kernel initializer for hidden layers.
        value_activation: Keras activation applied to the value head.
        value_initializer: Keras kernel initializer for the value head.
        policy_initializer: Keras kernel initializer for the policy head.
        use_batchnorm: If True, insert a BatchNormalization layer between
            each hidden Dense layer and its activation.

    Returns:
        A ``tf.keras.Model`` whose outputs are ``[policy_head, value_head]``.
    """
    import os
    # Quiet TensorFlow's native logging. Use setdefault so a level the user
    # has already configured (e.g. '0' when debugging) is not clobbered.
    # NOTE: TensorFlow reads this when its runtime loads, so it likely has no
    # effect if TensorFlow was already imported elsewhere in the process.
    os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3')
    import tensorflow as tf

    inputs = tf.keras.Input(shape=input_shape, name='input')
    x = tf.keras.layers.Flatten()(inputs)

    for i in range(num_hidden_layers):
        # Activation is applied as a separate layer so that batch
        # normalization (when enabled) sits between Dense and activation.
        x = tf.keras.layers.Dense(num_hidden_weights,
                                  kernel_initializer=hidden_layer_initializer,
                                  name=f'prediction_hidden{i}')(x)
        if use_batchnorm:
            x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation(hidden_layer_activation)(x)

    # No activation here: the policy head emits raw (unnormalized) outputs.
    policy_head = tf.keras.layers.Dense(num_actions,
                                        kernel_initializer=policy_initializer,
                                        name='choice_policy')(x)
    value_head = tf.keras.layers.Dense(1,
                                       activation=value_activation,
                                       kernel_initializer=value_initializer,
                                       name='value')(x)
    return tf.keras.Model(inputs=inputs, outputs=[policy_head, value_head])


def build_dense(params: Params) -> TensorflowModel:
    """Create the dense prediction network described by ``params``.

    Seven settings are read from ``params`` and forwarded to
    ``prediction_net``. The value-head activation, value-head initializer and
    policy-head initializer are not taken from ``params``, so
    ``prediction_net``'s own defaults apply to them.

    Args:
        params: Configuration object exposing ``pred_input_shape``,
            ``num_actions``, ``num_hidden_weights``,
            ``hidden_layer_activation``, ``hidden_layer_initializer``,
            ``num_hidden_layers`` and ``use_batchnorm``.

    Returns:
        The model returned by ``prediction_net``.
    """
    # Attributes are read in the same order as the keyword list below.
    net_kwargs = dict(
        input_shape=params.pred_input_shape,
        num_actions=params.num_actions,
        num_hidden_weights=params.num_hidden_weights,
        hidden_layer_activation=params.hidden_layer_activation,
        hidden_layer_initializer=params.hidden_layer_initializer,
        num_hidden_layers=params.num_hidden_layers,
        use_batchnorm=params.use_batchnorm,
    )
    return prediction_net(**net_kwargs)
