from functools import reduce
import pickle, os
from typing import Tuple

from algorithms.utils.types import TensorflowModel, RepresentationModel, PredictionModel, DynamicsModel
from algorithms.utils.params import Params


def representation_net(
        input_shape: Tuple[int, int, int] = (4, 1, 8),
        output_shape: Tuple[int, int, int] = (4, 1, 8),
        num_hidden_weights: int = 64,
        num_hidden_layers: int = 2,
        num_players: int = 4,
        hidden_layer_activation: str = 'selu',
        hidden_layer_initializer: str = 'lecun_normal',
        hidden_state_activation: str = 'linear',
        hidden_state_initializer: str = 'glorot_normal',
        tau_policy_initializer: str = 'glorot_normal',
        use_batchnorm: bool = False) -> RepresentationModel:
    """
    Builds the keras model for "h", the representation function of MuZero,
    which maps a state feature representation to a hidden state.

    The input is flattened and passed through `num_hidden_layers` blocks of
    Dense(num_hidden_weights) -> optional BatchNormalization -> Activation.
    Two heads are attached to the output of the last hidden block:
      1. 'hidden_state': Dense with prod(output_shape) units, reshaped to
         `output_shape`.
      2. 'tau_policy': Dense with `num_players` units and no activation
         argument.

    Args:
        input_shape: shape of one input sample (without the batch axis).
        output_shape: shape of the hidden state produced per sample.
        num_hidden_weights: units in every hidden Dense layer.
        num_hidden_layers: number of hidden blocks.
        num_players: units of the tau policy head.
        hidden_layer_activation: activation applied after each hidden block.
        hidden_layer_initializer: kernel initializer of the hidden layers.
        hidden_state_activation: activation of the hidden state head.
        hidden_state_initializer: kernel initializer of the hidden state head.
        tau_policy_initializer: kernel initializer of the tau policy head.
        use_batchnorm: insert BatchNormalization between each hidden Dense
            layer and its activation.
    Returns:
        a tf.keras.Model with outputs [hidden_state, tau_policy]
    """
    import os
    # Set before tensorflow is imported; intended to quiet TensorFlow's
    # native (C++) log output.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    import tensorflow as tf
    layers = tf.keras.layers

    # NOTE: layers are created in the same order as they are wired. Keras
    # derives auto-generated layer names and the weight ordering from it.
    net_input = tf.keras.Input(shape=input_shape, name='input')
    features = layers.Flatten()(net_input)

    for layer_idx in range(num_hidden_layers):
        dense = layers.Dense(units=num_hidden_weights,
                             kernel_initializer=hidden_layer_initializer,
                             name='representation_hidden_' + str(layer_idx))
        features = dense(features)
        if use_batchnorm:
            features = layers.BatchNormalization()(features)
        features = layers.Activation(hidden_layer_activation)(features)

    # Number of scalars needed to fill one hidden state of `output_shape`.
    flat_size = 1
    for dim in output_shape:
        flat_size = flat_size * dim

    flat_hidden_state = layers.Dense(units=flat_size,
                                     activation=hidden_state_activation,
                                     kernel_initializer=hidden_state_initializer,
                                     name='hidden_state')(features)
    hidden_state_head = layers.Reshape(target_shape=output_shape)(flat_hidden_state)

    # The tau policy branches off the last hidden block, not the hidden state.
    tau_policy_head = layers.Dense(units=num_players,
                                   kernel_initializer=tau_policy_initializer,
                                   name='tau_policy')(features)

    return tf.keras.Model(inputs=net_input,
                          outputs=[hidden_state_head, tau_policy_head])


def prediction_net(
        input_shape: Tuple[int, int, int] = (4, 1, 8),
        num_actions: int = 8,
        num_hidden_weights: int = 64,
        num_hidden_layers: int = 2,
        hidden_layer_activation: str = 'selu',
        hidden_layer_initializer: str = 'lecun_normal',
        value_activation: str = 'tanh',
        value_initializer: str = 'glorot_normal',
        choice_policy_initializer: str = 'glorot_normal',
        chance_policy_initializer: str = 'glorot_normal',
        use_batchnorm=False) -> PredictionModel:
    """
    Builds the keras model for "f", the prediction function of MuZero.

    The input is flattened and passed through `num_hidden_layers` blocks of
    Dense(num_hidden_weights) -> optional BatchNormalization -> Activation.
    Three heads are attached to the output of the last hidden block:
      1. 'choice_policy': Dense with `num_actions` units (choice player).
      2. 'chance_policy': Dense with `num_actions` units (chance player).
      3. 'value': Dense with a single unit and `value_activation`.
    Neither policy head is given an activation argument.

    Args:
        input_shape: shape of one input sample (without the batch axis).
        num_actions: units of each of the two policy heads.
        num_hidden_weights: units in every hidden Dense layer.
        num_hidden_layers: number of hidden blocks.
        hidden_layer_activation: activation applied after each hidden block.
        hidden_layer_initializer: kernel initializer of the hidden layers.
        value_activation: activation of the value head.
        value_initializer: kernel initializer of the value head.
        choice_policy_initializer: kernel initializer of the choice policy head.
        chance_policy_initializer: kernel initializer of the chance policy head.
        use_batchnorm: insert BatchNormalization between each hidden Dense
            layer and its activation.
    Returns:
        a tf.keras.Model with outputs [choice_policy, chance_policy, value]
    """
    import os
    # Set before tensorflow is imported; intended to quiet TensorFlow's
    # native (C++) log output.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    import tensorflow as tf
    layers = tf.keras.layers

    # NOTE: layers are created in the same order as they are wired. Keras
    # derives auto-generated layer names and the weight ordering from it.
    net_input = tf.keras.Input(shape=input_shape, name='input')
    features = layers.Flatten()(net_input)

    for layer_idx in range(num_hidden_layers):
        # The name prefix has no trailing underscore here; layer names are
        # visible at runtime, so it is reproduced exactly.
        dense = layers.Dense(units=num_hidden_weights,
                             kernel_initializer=hidden_layer_initializer,
                             name='prediction_hidden' + str(layer_idx))
        features = dense(features)
        if use_batchnorm:
            features = layers.BatchNormalization()(features)
        features = layers.Activation(hidden_layer_activation)(features)

    choice_policy_head = layers.Dense(units=num_actions,
                                      kernel_initializer=choice_policy_initializer,
                                      name='choice_policy')(features)
    chance_policy_head = layers.Dense(units=num_actions,
                                      kernel_initializer=chance_policy_initializer,
                                      name='chance_policy')(features)
    value_head = layers.Dense(units=1,
                              activation=value_activation,
                              kernel_initializer=value_initializer,
                              name='value')(features)

    return tf.keras.Model(inputs=net_input,
                          outputs=[choice_policy_head, chance_policy_head, value_head])


def dynamics_net(
        input_shape: Tuple[int, int, int] = (5, 1, 8),
        output_shape: Tuple[int, int, int] = (4, 1, 8),
        num_hidden_layers: int = 2,
        num_players: int = 4,
        num_hidden_weights: int = 64,
        hidden_layer_activation: str = 'selu',
        hidden_layer_initializer: str = 'lecun_normal',
        next_state_initializer: str = 'glorot_normal',
        next_state_activation: str = 'linear',
        tau_policy_initializer: str = 'glorot_normal',
        use_batchnorm=False) -> DynamicsModel:
    """
    Builds the keras model for "g", the dynamics function of MuZero.

    No reward head is created: rewards are omitted for zero-sum games with
    terminal rewards. The input is flattened and passed through
    `num_hidden_layers` blocks of
    Dense(num_hidden_weights) -> optional BatchNormalization -> Activation.
    Two heads are attached to the output of the last hidden block:
      1. 'dynamics_next_state': Dense with prod(output_shape) units,
         reshaped to `output_shape` (the next hidden state).
      2. 'tau_policy': Dense with `num_players` units and no activation
         argument. It expresses the network's belief about which player is
         to play in that state.

    Args:
        input_shape: shape of one input sample (without the batch axis).
        output_shape: shape of the next hidden state produced per sample.
        num_hidden_layers: number of hidden blocks.
        num_players: units of the tau policy head.
        num_hidden_weights: units in every hidden Dense layer.
        hidden_layer_activation: activation applied after each hidden block.
        hidden_layer_initializer: kernel initializer of the hidden layers.
        next_state_initializer: kernel initializer of the next state head.
        next_state_activation: activation of the next state head.
        tau_policy_initializer: kernel initializer of the tau policy head.
        use_batchnorm: insert BatchNormalization between each hidden Dense
            layer and its activation.
    Returns:
        a tf.keras.Model with outputs [next_hidden_state, tau_policy]
    """
    import os
    # Set before tensorflow is imported; intended to quiet TensorFlow's
    # native (C++) log output.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    import tensorflow as tf
    layers = tf.keras.layers

    # NOTE: layers are created in the same order as they are wired. Keras
    # derives auto-generated layer names and the weight ordering from it.
    net_input = tf.keras.Input(shape=input_shape, name='input')
    features = layers.Flatten()(net_input)

    for layer_idx in range(num_hidden_layers):
        dense = layers.Dense(units=num_hidden_weights,
                             kernel_initializer=hidden_layer_initializer,
                             name='dynamics_hidden_' + str(layer_idx))
        features = dense(features)
        if use_batchnorm:
            features = layers.BatchNormalization()(features)
        features = layers.Activation(hidden_layer_activation)(features)

    # Number of scalars needed to fill one hidden state of `output_shape`.
    flat_size = 1
    for dim in output_shape:
        flat_size = flat_size * dim

    flat_next_state = layers.Dense(units=flat_size,
                                   activation=next_state_activation,
                                   kernel_initializer=next_state_initializer,
                                   name='dynamics_next_state')(features)
    next_hidden_state_head = layers.Reshape(target_shape=output_shape)(flat_next_state)

    # The tau policy branches off the last hidden block, not the next state.
    tau_policy_head = layers.Dense(units=num_players,
                                   kernel_initializer=tau_policy_initializer,
                                   name='tau_policy')(features)

    return tf.keras.Model(inputs=net_input,
                          outputs=[next_hidden_state_head, tau_policy_head])


def build_dense(params: Params) -> Tuple[RepresentationModel, PredictionModel, DynamicsModel]:
    """
    Builds the three dense MuZero networks (representation, prediction,
    dynamics) from `params` and optionally initializes them from a pickled
    weights file.

    Only the settings explicitly forwarded below are taken from `params`;
    every other builder argument (e.g. num_players, the value/policy
    initializers) keeps the default of the respective builder function.
    The dynamics network reuses `params.hidden_state_activation` and
    `params.hidden_state_initializer` for its next-state head.

    Weight loading: if `params.load_weights_name` is neither None nor the
    string 'None', it is resolved relative to the current working directory
    and must be a pickle file containing a 3-sequence
    (rep_weights, pred_weights, dyn_weights), each suitable for the
    corresponding model's `set_weights`.

    Args:
        params: configuration object providing the attributes read below.
    Returns:
        the tuple (rep_net, pred_net, dyn_net)
    Raises:
        OSError: if the weights file cannot be opened.
    """
    rep_net = representation_net(input_shape=params.rep_input_shape,
                                 output_shape=params.rep_output_shape,
                                 num_hidden_weights=params.num_hidden_weights,
                                 hidden_layer_activation=params.hidden_layer_activation,
                                 hidden_layer_initializer=params.hidden_layer_initializer,
                                 num_hidden_layers=params.num_hidden_layers,
                                 hidden_state_activation=params.hidden_state_activation,
                                 hidden_state_initializer=params.hidden_state_initializer,
                                 use_batchnorm=params.use_batchnorm)

    pred_net = prediction_net(input_shape=params.pred_input_shape,
                              num_actions=params.num_actions,
                              num_hidden_weights=params.num_hidden_weights,
                              hidden_layer_activation=params.hidden_layer_activation,
                              hidden_layer_initializer=params.hidden_layer_initializer,
                              num_hidden_layers=params.num_hidden_layers,
                              use_batchnorm=params.use_batchnorm)

    dyn_net = dynamics_net(input_shape=params.dyn_input_shape,
                           output_shape=params.dyn_output_shape,
                           num_hidden_weights=params.num_hidden_weights,
                           hidden_layer_activation=params.hidden_layer_activation,
                           hidden_layer_initializer=params.hidden_layer_initializer,
                           num_hidden_layers=params.num_hidden_layers,
                           next_state_activation=params.hidden_state_activation,
                           next_state_initializer=params.hidden_state_initializer,
                           use_batchnorm=params.use_batchnorm)

    # 'None' (the string) is the "no weights" sentinel. A real None is
    # treated the same way instead of failing inside os.path.join.
    weights_name = params.load_weights_name
    if weights_name is not None and weights_name != 'None':
        fpath = os.path.join(os.getcwd(), weights_name)
        print('Loading weights from', fpath)
        # NOTE(security): unpickling can execute arbitrary code. Only load
        # weight files that come from a trusted source.
        # The context manager closes the file handle even if unpickling fails.
        with open(fpath, 'rb') as weights_file:
            rep_weights, pred_weights, dyn_weights = pickle.load(weights_file)
        rep_net.set_weights(rep_weights)
        pred_net.set_weights(pred_weights)
        dyn_net.set_weights(dyn_weights)
        print('Sucessfully loaded weights')

    return rep_net, pred_net, dyn_net
