import os

import wandb
from stable_baselines3.common.callbacks import CheckpointCallback

from env import ImageEnv
from rl.dqn import ImagePoserDQN
from utils.config import Config

# Train an ImagePoserDQN agent on ImageEnv, checkpointing during training and
# logging the run to Weights & Biases.

# Number of environment steps passed to DQN.learn.
# NOTE(review): hard-coded; consider moving this into Config.
TOTAL_TIMESTEPS = 1000

config = Config(
    dataset_path="./data/train.jsonl",
    notes="DQN training",
)
config.logger.info(config)

wandb.init(
    project=config.wandb_project,
    id=config.experiment_name,
    name=config.wandb_name,
    dir=config.wandb_dir,
    config=config.to_dict(),
    sync_tensorboard=True,  # Syncs TensorBoard logs to W&B
)

# Everything after wandb.init runs inside try so the W&B run is always closed
# explicitly: marked failed (exit_code=1) if setup/training raises, finished
# normally otherwise.
try:
    env = ImageEnv(config)
    model = ImagePoserDQN(env=env, config=config)

    # Periodically saves the model (and its replay buffer) into
    # config.checkpoint_dir with the "dqn_model" filename prefix.
    # NOTE(review): SB3 counts `save_freq` in callback calls (environment
    # steps), while `log_interval` in DQN.learn counts episodes; reusing
    # config.log_interval for both mixes units — confirm this is intended.
    checkpoint_callback = CheckpointCallback(
        save_freq=config.log_interval,
        save_path=config.checkpoint_dir,
        name_prefix="dqn_model",
        save_replay_buffer=True,
    )

    model.learn(
        total_timesteps=TOTAL_TIMESTEPS,
        log_interval=config.log_interval,
        callback=checkpoint_callback,
    )

    # Final model goes to <experiment_dir>/model (SB3 appends its own suffix).
    model.save(os.path.join(config.experiment_dir, "model"))
except BaseException:
    # Record the failure in W&B, then propagate the original error.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

