import tensorflow as tf
import numpy as np
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import py_driver
from tf_agents.networks import q_network
from tf_agents.policies import py_tf_eager_policy
from tf_agents.specs import tensor_spec

from coordinator.CoordinatorEnv import CoordinatorEnv
from tf_agents.environments import utils, tf_py_environment
from tf_agents.utils import common

from tf_agents.networks import sequential
from tf_agents.train.utils import strategy_utils
from replay_buffer import init_reverb_replay_buffer
from coordinator.coordinator_eval import extensive_evaluate
from checkpoint import save_policy, restore_policy
from utils import parse_arg


def init_q_net(env, *, fc_layer_params=(64, 32), vocab_size=7, embedding_dim=20):
    """Build the Q-network used by the DQN agent.

    Architecture: a shared ``Embedding`` layer maps each integer entry of the
    observation to a vector. Those vectors are averaged over axis 1. The
    result then passes through ReLU ``Dense`` layers (one per entry of
    ``fc_layer_params``). A final linear ``Dense`` layer emits one Q-value
    per discrete action.

    Args:
        env: Python environment exposing ``action_spec()`` and a ``players``
            attribute. ``players`` is passed as the embedding's
            ``input_length`` (presumably the observation length — confirm
            against ``CoordinatorEnv``).
        fc_layer_params: Units of each hidden ``Dense`` layer, in order.
            Defaults to ``(64, 32)``.
        vocab_size: ``input_dim`` of the embedding. Observation entries must
            be integers in ``[0, vocab_size)``. Defaults to 7.
        embedding_dim: Length of each embedding vector. Defaults to 20.

    Returns:
        A ``tf_agents.networks.sequential.Sequential`` Q-network.
    """
    action_tensor_spec = tensor_spec.from_spec(env.action_spec())
    # Bounded discrete action spec: valid actions are minimum..maximum inclusive.
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

    def dense_layer(num_units):
        """Return a ReLU Dense layer with He-normal (fan_in, scale=2) init."""
        return tf.keras.layers.Dense(
            num_units,
            activation=tf.keras.activations.relu,
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=2.0, mode='fan_in', distribution='truncated_normal'))

    emb = tf.keras.layers.Embedding(vocab_size, embedding_dim, input_length=env.players)
    # Average the per-entry embeddings over axis 1 so the following Dense
    # layers see one fixed-size vector per observation.
    mean_emb = tf.keras.layers.Lambda(lambda x: tf.keras.backend.mean(x, axis=1))
    dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
    # Linear output head with one unit per action. Small uniform weights and a
    # constant -0.2 bias keep the initial Q-values close to -0.2.
    q_values_layer = tf.keras.layers.Dense(
        num_actions,
        activation=None,
        kernel_initializer=tf.keras.initializers.RandomUniform(
            minval=-0.03, maxval=0.03),
        bias_initializer=tf.keras.initializers.Constant(-0.2))
    return sequential.Sequential([emb, mean_emb, *dense_layers, q_values_layer])



NUM_ITERATIONS = 15  # outer collect/train rounds
COLLECT_STEPS_PER_ITERATION = 1000  # env steps gathered by the driver per round
TRAIN_STEPS_PER_ITERATION = 5000  # gradient steps per round
SAMPLE_BATCH_SIZE = 64


def _build_agent(tf_env, py_env):
    """Create and initialize the DQN agent under the default distribution strategy."""
    # Positional args are (tpu, use_gpu); both False selects the default strategy.
    strategy = strategy_utils.get_strategy(False, False)
    with strategy.scope():
        q_net = init_q_net(py_env)
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
        agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=optimizer,
            target_update_period=1,
            epsilon_greedy=0.3)
        agent.initialize()
    # Compile the train step into a graph; in TF2 it otherwise runs eagerly on
    # every one of the many thousands of training steps.
    agent.train = common.function(agent.train)
    return agent


def _train(agent, py_env, observer, iterator):
    """Alternate experience collection and gradient updates for NUM_ITERATIONS rounds."""
    # Build the driver and its tf.function-wrapped policy once. The policy
    # reads the live Q-network variables, so later weight updates are still
    # used without re-creating (and re-tracing) it every round.
    collect_driver = py_driver.PyDriver(
        py_env,
        py_tf_eager_policy.PyTFEagerPolicy(
            agent.collect_policy, use_tf_function=True),
        [observer],
        max_steps=COLLECT_STEPS_PER_ITERATION)

    for it in range(NUM_ITERATIONS):
        print(f"ITER: {it}")
        collect_driver.run(py_env.reset())
        train_loss = None
        for _ in range(TRAIN_STEPS_PER_ITERATION):
            experience, _unused_info = next(iterator)
            train_loss = agent.train(experience).loss
        if train_loss is not None:
            print(f"loss: {float(train_loss)}")


def main():
    """Train a DQN agent on CoordinatorEnv, optionally save it, then evaluate.

    Reads ``players``, ``save_dir`` and ``load_dir`` from ``parse_arg()``.
    If ``load_dir`` is set, the restored policy is evaluated instead of the
    freshly trained one.
    """
    opt = parse_arg()
    print(opt)

    players = opt.players
    py_env = CoordinatorEnv(players)
    utils.validate_py_environment(py_env, episodes=5)

    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    agent = _build_agent(tf_env, py_env)

    observer, replay_buffer = init_reverb_replay_buffer(agent, "uniform_table")
    # num_steps=2: each sampled item holds two consecutive steps, which the
    # DQN agent needs to form (s, a, r, s') transitions.
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=12,
        sample_batch_size=SAMPLE_BATCH_SIZE,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    _train(agent, py_env, observer, iterator)

    if opt.save_dir:
        print(f"SAVE policy to {opt.save_dir}")
        save_policy(opt.save_dir, agent.policy)

    print("EVAL")
    if opt.load_dir:
        print(f"LOAD policy from {opt.load_dir}")
        policy = restore_policy(opt.load_dir)
    else:
        policy = agent.policy
    extensive_evaluate(tf_env, py_env, policy, players)


if __name__ == "__main__":
    main()
