import tensorflow as tf

from config import LEARNING_RATE, LOSS_FN, EPSILON

from tf_agents.utils import common

from custom_dqn_agent import DqnAgent
from utils import log


def init_agent(tf_env, q_net):
    """Build, initialize and return a DQN agent for ``tf_env``.

    Hyperparameters are read from the imported ``config`` values
    ``LEARNING_RATE`` (Adam learning rate), ``LOSS_FN`` (name of the TD-error
    loss) and ``EPSILON`` (``epsilon_greedy``); ``target_update_period`` is
    fixed at 10.

    Args:
        tf_env: Environment whose ``time_step_spec()`` and ``action_spec()``
            are passed to the agent (presumably a TF-Agents ``TFEnvironment``
            -- confirm against callers).
        q_net: Network passed to the agent as ``q_network``.

    Returns:
        The ``DqnAgent`` instance, after ``initialize()`` has been called.

    Raises:
        ValueError: If ``LOSS_FN`` is not one of the supported loss names.
    """
    this_func = "init_agent"

    # Resolve the configured loss name up front. Without an explicit error,
    # an unknown name would leave ``loss_fn`` unbound and surface as an
    # UnboundLocalError at agent construction, far from the real cause.
    loss_fns = {
        "element_wise_squared_loss": common.element_wise_squared_loss,
        "element_wise_huber_loss": common.element_wise_huber_loss,
    }
    try:
        loss_fn = loss_fns[LOSS_FN]
    except KeyError:
        raise ValueError(
            f"Unsupported LOSS_FN {LOSS_FN!r}; expected one of {sorted(loss_fns)}"
        ) from None

    # Shared step counter handed to the agent as ``train_step_counter``.
    train_step_counter = tf.Variable(0)

    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

    agent = DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        target_update_period=10,
        epsilon_greedy=EPSILON,
        td_errors_loss_fn=loss_fn,
        train_step_counter=train_step_counter,
    )

    agent.initialize()
    log(this_func, f"{agent.collect_data_spec}")
    return agent
