import tensorflow as tf

class DefaultTemplate(dict):
    """Default hyper-parameter / graph-construction template.

    Every entry is taken from ``config`` when present and otherwise falls back
    to a built-in default. Because ``dict.__init__(config)`` runs last, every
    key in ``config`` (including keys this template does not define) ends up
    in the template with the caller's value.

    No TensorFlow object is created in ``__init__``. Placeholders, observation
    tensors and the optimizer are built only when the stored
    ``self["create_tensorflow_variables"]`` callable is invoked. They are built
    with ``tf.placeholder``, i.e. the TF1-style graph API.
    """

    @staticmethod
    def _get_or_create(config, key, factory):
        """Return ``config[key]`` if the key is present, else ``factory()``.

        Unlike ``config.get(key, factory())``, the factory runs only when the
        key is missing. As a result, no unused placeholder/optimizer is added
        to the graph when the caller supplies its own object.
        """
        if key in config:
            return config[key]
        return factory()

    def __init__(self, config):
        """Populate defaults from ``config``.

        Args:
            config: mapping of overrides; any key it contains wins over the
                defaults below.
        """
        self["algorithm_scope"] = config.get("algorithm_scope", "PPO")
        self["value_head_count"] = config.get("value_head_count", 1)

        self["use_rnn"] = config.get("use_rnn", False)
        if self["use_rnn"]:
            self["rnn_sequence_length"] = config.get("rnn_sequence_length", 8)
            self["rnn_burn_in_length"] = config.get("rnn_burn_in_length", 2)
            self["memory_size"] = config.get("memory_size", 512)
        branch_count = 1
        get_or_create = self._get_or_create

        def create_tensorflow_variables():
            """Build (or take from ``config``) the TF placeholders, inputs and optimizer."""

            def config_placeholder(key, dtype, shape=None):
                # Every default placeholder is named after its config key.
                return get_or_create(config, key, lambda: tf.placeholder(dtype=dtype, shape=shape, name=key))

            if self["use_rnn"]:
                self["rnn_sequence_length_placeholder"] = config_placeholder("rnn_sequence_length_placeholder", tf.int32)
            self["learning_rate_placeholder"] = config_placeholder("learning_rate_placeholder", tf.float32)
            self["clip_epsilon_placeholder"] = config_placeholder("clip_epsilon_placeholder", tf.float32)
            self["entropy_coefficient_placeholder"] = config_placeholder("entropy_coefficient_placeholder", tf.float32)
            self["value_coefficient_placeholder"] = config_placeholder("value_coefficient_placeholder", tf.float32, shape=[self["value_head_count"]])
            self["value_clip_range_placeholder"] = config_placeholder("value_clip_range_placeholder", tf.float32, shape=[self["value_head_count"]])
            self["optimizer"] = get_or_create(config, "optimizer", lambda: tf.keras.optimizers.Adam(self["learning_rate_placeholder"]))

            if "observation_placeholders" in config and "observation_inputs" in config:
                observation_placeholders = config["observation_placeholders"]
                observation_inputs = config["observation_inputs"]
            else:
                observation_2d_placeholder = tf.placeholder(shape=[None, 4, 84, 84, 1], dtype=tf.float32, name="observation_2d_placeholder")
                # Move axis 1 (size 4; presumably stacked frames) to the end, then
                # fold it with the trailing size-1 axis into a [-1, 84, 84, 4] tensor.
                observation_2d_input = tf.transpose(observation_2d_placeholder, [0, 2, 3, 4, 1])
                observation_2d_input = tf.reshape(observation_2d_input, [-1, 84, 84, 4])
                observation_placeholders = config.get("observation_placeholders", {
                    "observation_2d": observation_2d_placeholder
                })
                observation_inputs = config.get("observation_inputs", {
                    "observation_2d_input": observation_2d_input
                })

            # Copy so the RNN entries added below never mutate dicts owned by the
            # caller's config (which would also accumulate across repeated calls).
            self["observation_placeholders"] = dict(observation_placeholders)
            self["observation_inputs"] = dict(observation_inputs)

            if self["use_rnn"]:
                observation_previous_action_placeholder = tf.placeholder(shape=[None, branch_count], dtype=tf.int32, name="observation_previous_action_placeholder")
                observation_previous_reward_placeholder = tf.placeholder(shape=[None], dtype=tf.float32, name="observation_previous_reward_placeholder")
                # Reshape the reward vector into a [-1, 1] column.
                observation_previous_reward_input = tf.reshape(observation_previous_reward_placeholder, [-1, 1])
                observation_memory_placeholder = tf.placeholder(shape=[None, self["memory_size"]], dtype=tf.float32, name="observation_memory_placeholder")

                self["observation_placeholders"]["observation_previous_action"] = observation_previous_action_placeholder
                self["observation_placeholders"]["observation_previous_reward"] = observation_previous_reward_placeholder
                self["observation_placeholders"]["observation_memory"] = observation_memory_placeholder

                self["observation_inputs"]["observation_previous_action_input"] = observation_previous_action_placeholder
                self["observation_inputs"]["observation_previous_reward_input"] = observation_previous_reward_input
                self["observation_inputs"]["observation_memory_input"] = observation_memory_placeholder

        self["create_tensorflow_variables"] = create_tensorflow_variables
        self["model_define_path"] = config.get("model_define_path", "decision_model.ppo.nn_model.atari.policy_model")
        self["max_gradient_norm"] = config.get("max_gradient_norm", 10)

        # Opt-in value rescaling; the default (False) keeps the identity
        # transform the template has always used.
        self["use_value_rescaling"] = config.get("use_value_rescaling", False)

        def invertible_value_function(x, is_inverse):
            """Identity value transform (used in both directions)."""
            return x

        def rescaled_value_function(x, is_inverse):
            """h(x) = sign(x)(sqrt(|x|+1)-1) + eps*x (Pohlen et al., 2018), or its inverse."""
            epsilon = 0.001
            if is_inverse:
                return tf.math.sign(x) * (tf.square((tf.math.sqrt(1 + 4 * epsilon * (tf.math.abs(x) + 1 + epsilon)) - 1) / (2 * epsilon)) - 1)
            return tf.math.sign(x) * (tf.math.sqrt(tf.math.abs(x) + 1) - 1) + epsilon * x

        default_value_function = rescaled_value_function if self["use_value_rescaling"] else invertible_value_function
        self["invertible_value_functions"] = config.get("invertible_value_functions", [default_value_function] * self["value_head_count"])

        # Caller-supplied keys override everything set above.
        super().__init__(config)