from dqn import DQN, HomoDQNLearnt
import wandb
from config import get_config
import numpy as np

# Entry-point script: build the configured DQN agent, train it, then replay
# the recorded per-episode lengths to Weights & Biases.
config = get_config()

# Pick the agent class up front so that an unsupported configuration fails
# with a clear message *before* a wandb run is created. Previously the
# "symmetry enabled but not learned" combination left `dqn` unassigned and
# surfaced later as a NameError.
if not config["using_symmetry"]:
    agent_cls = DQN
elif config["learned_symmetry"]:
    agent_cls = HomoDQNLearnt
else:
    raise NotImplementedError(
        "config['using_symmetry'] is enabled but config['learned_symmetry'] "
        "is not; this script only provides a learned-symmetry agent "
        "(HomoDQNLearnt)."
    )

wandb.init(project="DQN cartpole", config=config)

# The agent is constructed after wandb.init, matching the original ordering,
# in case the agent itself talks to wandb during construction or training.
dqn = agent_cls(config)
dqn.learn_q_network_weights()

# Log one record per finished episode together with the running total of
# environment steps taken up to and including that episode.
total_steps = 0
for episode_length in dqn.episode_lengths:
    total_steps += episode_length
    wandb.log({"episode length": episode_length, "environment steps": total_steps})
