import os
from typing import Optional, Tuple

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf

from algorithms.utils.types import TensorflowModel


def representation_net(input_shape: Tuple[int, int, int] = (4, 1, 8),
                       output_filters: int = 4,
                       num_hidden_weights: int = 64,
                       num_hidden_layers: int = 2,
                       hidden_activation: str = 'relu',
                       output_activation: str = 'linear',
                       use_batchnorm: bool = True) -> TensorflowModel:
    """
    Creates a keras model for the representation function of MuZero.

    Transforms a state feature representation into a hidden state. All
    convolutions use ``data_format='channels_first'``, so shapes are laid
    out as (channels, height, width).

    Args:
        input_shape: shape of one input state, excluding the batch dimension.
        output_filters: number of channels of the produced hidden state.
        num_hidden_weights: number of filters in each hidden 3x3 convolution.
        num_hidden_layers: number of hidden conv (+ optional batch-norm) +
            activation stages.
        hidden_activation: activation applied after each hidden convolution.
        output_activation: activation applied to the hidden-state output.
        use_batchnorm: whether to insert BatchNormalization after every
            convolution, including the output one.
    Returns:
        a tf.keras.Model mapping a state to a hidden state with
        ``output_filters`` channels and the same height/width as the input
        (all convolutions use stride 1 and 'same' padding).
    """
    inputs = tf.keras.Input(shape=input_shape, name='input')
    x = inputs

    # NOTE(review): BatchNormalization() defaults to axis=-1 (width here),
    # while the channel axis of these channels_first convs is 1; Keras docs
    # recommend axis=1 in that case. Left unchanged because it alters the
    # shape of saved batch-norm weights — confirm intent.
    for i in range(num_hidden_layers):
        x = tf.keras.layers.Conv2D(num_hidden_weights,
                                   kernel_size=3,
                                   padding='same',
                                   strides=1,
                                   data_format='channels_first',
                                   kernel_initializer='he_uniform',
                                   name=f'representation_hidden_{i}')(x)

        if use_batchnorm:
            x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation(hidden_activation)(x)

    # 1x1 projection down to the hidden-state channel count.
    x = tf.keras.layers.Conv2D(output_filters,
                               kernel_size=1,
                               padding='same',
                               strides=1,
                               data_format='channels_first',
                               kernel_initializer='he_uniform',
                               name='representation_out')(x)
    if use_batchnorm:
        x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation(output_activation)(x)

    return tf.keras.Model(inputs=inputs, outputs=x)


def prediction_net(input_shape: Tuple[int, int, int] = (4, 1, 8),
                   num_classes: int = 8,
                   num_hidden_weights=64,
                   num_hidden_layers=2,
                   hidden_activation='selu',
                   value_activation='tanh',
                   value_kernel_initializer='glorot_normal',
                   policy_kernel_initializer='glorot_normal',
                   use_batchnorm=True):
    """
    Creates a keras model for the prediction function of MuZero.

    A stack of channels_first 3x3 convolutions is flattened and fed into two
    dense heads: a single-unit value head and a policy head.

    Args:
        input_shape: shape of the hidden state, excluding the batch dimension,
            interpreted as (channels, height, width).
        num_classes: number of units in the policy head.
        num_hidden_weights: number of filters per hidden convolution.
        num_hidden_layers: number of hidden conv stages.
        hidden_activation: activation used after every hidden convolution.
        value_activation: activation of the value head.
        value_kernel_initializer: kernel initializer of the value head.
        policy_kernel_initializer: kernel initializer of the policy head.
        use_batchnorm: whether to insert BatchNormalization after each
            hidden convolution.
    Returns:
        a tf.keras.Model with outputs ``[value, policy]``; the policy head
        has no activation.
    """
    state_in = tf.keras.Input(shape=input_shape, name='input')

    # NOTE(review): BatchNormalization() defaults to axis=-1, but the channel
    # axis of these channels_first convs is 1 — confirm this is intended.
    hidden = state_in
    for layer_idx in range(num_hidden_layers):
        conv = tf.keras.layers.Conv2D(num_hidden_weights,
                                      kernel_size=3,
                                      padding='same',
                                      strides=1,
                                      data_format='channels_first',
                                      kernel_initializer='he_uniform',
                                      name=f'prediction_hidden_{layer_idx}')
        hidden = conv(hidden)
        if use_batchnorm:
            hidden = tf.keras.layers.BatchNormalization()(hidden)
        hidden = tf.keras.layers.Activation(hidden_activation)(hidden)

    flat = tf.keras.layers.Flatten()(hidden)

    value_layer = tf.keras.layers.Dense(1,
                                        activation=value_activation,
                                        kernel_initializer=value_kernel_initializer,
                                        name='prediction_value')
    value = value_layer(flat)

    policy_layer = tf.keras.layers.Dense(num_classes,
                                         kernel_initializer=policy_kernel_initializer,
                                         name='prediction_policy')
    policy = policy_layer(flat)

    return tf.keras.Model(inputs=state_in, outputs=[value, policy])


def chance_net(input_shape: Optional[Tuple[int, int, int]] = None,
               num_classes: Optional[int] = None,
               num_hidden_weights: int = 64,
               num_hidden_layers: int = 2,
               hidden_activation: str = 'relu',
               policy_kernel_initializer: str = 'glorot_normal',
               use_batchnorm: bool = True) -> TensorflowModel:
    """
    Creates a keras model for the chance function.

    A stack of channels_first 3x3 convolutions is flattened and fed into a
    single dense policy head with no activation.

    Args:
        input_shape: shape of the input state, excluding the batch dimension,
            interpreted as (channels, height, width). Required.
        num_classes: number of units in the policy head. Required.
        num_hidden_weights: number of filters per hidden convolution.
        num_hidden_layers: number of hidden conv stages.
        hidden_activation: activation used after every hidden convolution.
        policy_kernel_initializer: kernel initializer of the policy head.
        use_batchnorm: whether to insert BatchNormalization after each
            hidden convolution.
    Returns:
        a tf.keras.Model whose output list is ``[policy]``.
    Raises:
        ValueError: if ``input_shape`` or ``num_classes`` is not given.
    """
    # Both defaults are None, which can never produce a valid model; fail
    # early with a clear message instead of an opaque Keras error.
    if input_shape is None:
        raise ValueError('chance_net requires an input_shape')
    if num_classes is None:
        raise ValueError('chance_net requires num_classes')

    inputs = tf.keras.Input(shape=input_shape, name='input')
    x = inputs

    # NOTE(review): the hidden layers reuse the 'prediction_hidden_' prefix
    # (apparently copied from prediction_net). Kept so weights saved under
    # these names still match; consider 'chance_hidden_' for new models.
    # NOTE(review): BatchNormalization() defaults to axis=-1, but the channel
    # axis of these channels_first convs is 1 — confirm this is intended.
    for i in range(num_hidden_layers):
        x = tf.keras.layers.Conv2D(num_hidden_weights,
                                   kernel_size=3,
                                   padding='same',
                                   strides=1,
                                   data_format='channels_first',
                                   kernel_initializer='he_uniform',
                                   name=f'prediction_hidden_{i}')(x)

        if use_batchnorm:
            x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation(hidden_activation)(x)

    x = tf.keras.layers.Flatten()(x)
    policy_head = tf.keras.layers.Dense(num_classes,
                                        kernel_initializer=policy_kernel_initializer,
                                        name='chance_policy')(x)
    return tf.keras.Model(inputs=inputs, outputs=[policy_head])


def dynamics_net(input_shape: Optional[Tuple[int, int, int]] = None,
                 output_filters: int = 4,
                 num_hidden_layers: int = 2,
                 num_hidden_weights: int = 64,
                 hidden_activation: str = 'relu',
                 next_state_activation: str = 'linear',
                 reward_activation: str = 'tanh',
                 reward_initializer: str = 'glorot_normal',
                 use_batchnorm: bool = True) -> TensorflowModel:
    """
    Creates a keras model for the dynamics function of MuZero.

    A shared stack of channels_first 3x3 convolutions feeds two heads:
    a 1x1 convolution producing the next hidden state, and a single-unit
    dense reward head on the flattened features.

    Args:
        input_shape: shape of the input, excluding the batch dimension,
            interpreted as (channels, height, width). Required.
        output_filters: number of channels of the predicted next state.
        num_hidden_layers: number of shared hidden conv stages (may be 0).
        num_hidden_weights: number of filters per hidden convolution.
        hidden_activation: activation used after every hidden convolution.
        next_state_activation: activation of the next-state head.
        reward_activation: activation of the reward head.
        reward_initializer: kernel initializer of the reward head.
        use_batchnorm: whether to insert BatchNormalization after every
            convolution, including the next-state one.
    Returns:
        a tf.keras.Model with outputs ``[next_state, reward]``.
    Raises:
        ValueError: if ``input_shape`` is not given.
    """
    if input_shape is None:
        raise ValueError('dynamics_net requires an input_shape')

    inputs = tf.keras.Input(shape=input_shape, name='input')
    x = inputs

    # NOTE(review): BatchNormalization() defaults to axis=-1, but the channel
    # axis of these channels_first convs is 1 — confirm this is intended.
    for i in range(num_hidden_layers):
        x = tf.keras.layers.Conv2D(num_hidden_weights,
                                   kernel_size=3,
                                   padding='same',
                                   strides=1,
                                   data_format='channels_first',
                                   kernel_initializer='he_uniform',
                                   name=f'dynamics_hidden_{i}')(x)

        if use_batchnorm:
            x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation(hidden_activation)(x)

    # The next-state layer name carries the index of the last hidden layer
    # (it used to read the leftover loop variable). Keep that suffix so
    # weights saved under the existing name still match. Compute it
    # explicitly because the loop variable is undefined when
    # num_hidden_layers <= 0.
    next_state_suffix = str(num_hidden_layers - 1) if num_hidden_layers > 0 else ''
    next_state_head = tf.keras.layers.Conv2D(output_filters,
                                             kernel_size=1,
                                             padding='same',
                                             strides=1,
                                             data_format='channels_first',
                                             kernel_initializer='he_uniform',
                                             name='dynamics_next_state' + next_state_suffix)(x)
    # Honour use_batchnorm here too, like every other convolution.
    if use_batchnorm:
        next_state_head = tf.keras.layers.BatchNormalization()(next_state_head)
    next_state_head = tf.keras.layers.Activation(next_state_activation)(next_state_head)

    x = tf.keras.layers.Flatten()(x)
    reward_head = tf.keras.layers.Dense(1,
                                        activation=reward_activation,
                                        kernel_initializer=reward_initializer,
                                        name='dynamics_reward')(x)

    return tf.keras.Model(inputs=inputs, outputs=[next_state_head, reward_head])


def build_conv(parameters, num_actions):
    """
    Builds the four convolutional networks with one shared configuration.

    Args:
        parameters: mapping that must contain ``'n_points'``. Every state
            plane has width ``n_points + 2``.
        num_actions: number of units in the prediction and chance policy heads.
    Returns:
        a tuple ``(representation, prediction, chance, dynamics)`` of
        tf.keras.Model instances.
    """
    # Channels-first (channels, height, width) layout of the state, which is
    # both the representation input and the hidden state consumed by the
    # prediction and chance nets.
    state_channels = 4
    width = parameters['n_points'] + 2
    state_shape = (state_channels, 1, width)
    # The dynamics net receives one plane more than the state (presumably
    # the encoded action — confirm against the caller).
    dynamics_input_shape = (state_channels + 1, 1, width)

    num_hidden_weights = 256
    num_hidden_layers = 5
    hidden_activation = 'tanh'
    representation_output_activation = 'tanh'
    dynamics_next_state_activation = 'tanh'
    use_batchnorm = True

    # Construction order is kept fixed (representation, prediction, chance,
    # dynamics) because Keras derives default layer names from creation order.
    # output_filters is passed explicitly (same value as the default) so the
    # hidden-state channel count visibly matches state_channels.
    rep_net = representation_net(input_shape=state_shape,
                                 output_filters=state_channels,
                                 num_hidden_weights=num_hidden_weights,
                                 hidden_activation=hidden_activation,
                                 num_hidden_layers=num_hidden_layers,
                                 output_activation=representation_output_activation,
                                 use_batchnorm=use_batchnorm)

    pred_net = prediction_net(input_shape=state_shape,
                              num_classes=num_actions,
                              num_hidden_weights=num_hidden_weights,
                              hidden_activation=hidden_activation,
                              num_hidden_layers=num_hidden_layers,
                              use_batchnorm=use_batchnorm)

    cha_net = chance_net(input_shape=state_shape,
                         num_classes=num_actions,
                         num_hidden_weights=num_hidden_weights,
                         hidden_activation=hidden_activation,
                         num_hidden_layers=num_hidden_layers,
                         use_batchnorm=use_batchnorm)

    dyn_net = dynamics_net(input_shape=dynamics_input_shape,
                           output_filters=state_channels,
                           num_hidden_weights=num_hidden_weights,
                           hidden_activation=hidden_activation,
                           num_hidden_layers=num_hidden_layers,
                           next_state_activation=dynamics_next_state_activation,
                           use_batchnorm=use_batchnorm)

    return rep_net, pred_net, cha_net, dyn_net
