from env import ImageEnv
from rl.dqn import ImagePoserDQN
from utils.config import Config
import wandb
import json
import os

# Dataset to evaluate on; can be changed to any other JSONL file.
data_path = "./data/gen_test.jsonl"

# Login name of the invoking user; it is interpolated into the checkpoint path.
USER = os.getenv("USER")
if not USER:
    # Without this guard the checkpoint path would silently contain "None"
    # and the failure would only surface later as a confusing missing file.
    raise RuntimeError(
        "The USER environment variable is not set; it is required to locate "
        "the checkpoint directory."
    )

# CHANGE THESE VARIABLES TO LOAD A DIFFERENT CHECKPOINT
STEPS = 140
EXPERIMENT_NAME = "run_20250906-141640"

# Number of records in the dataset; the evaluation loop runs one episode per
# record. Blank lines (e.g. a stray empty line at the end of the file) are not
# records, so they are excluded from the count.
with open(data_path, "r", encoding="utf-8") as f:
    len_data = sum(1 for line in f if line.strip())

# Build the run configuration in evaluation mode for the selected dataset,
# then record it in the run log.
config = Config(
    notes=f"DQN inference on {data_path} testing",
    eval_mode=True,
    dataset_path=data_path,
)
config.logger.info(config)

# Start the Weights & Biases run using the identifiers carried by the config.
wandb_settings = {
    "project": config.wandb_project,
    "id": config.experiment_name,
    "name": config.wandb_name,
    "dir": config.wandb_dir,
    "config": config.to_dict(),
}
wandb.init(**wandb_settings)

# Environment the loaded model will act in.
env = ImageEnv(config)

# Checkpoint location is derived from the user, experiment name and step count.
checkpoint_path = f"/datasets/uig/results/{USER}/{EXPERIMENT_NAME}/checkpoints/dqn_model_{STEPS}_steps.zip"
model = ImagePoserDQN.load(path=checkpoint_path, env=env)

# Point the loaded model's TensorBoard output at this run's log directory.
model.tensorboard_log = config.tensorboard_log

# Roll out one episode per dataset record with deterministic actions and
# collect per-episode statistics: total reward, length, and mean reward per step.
episode_rewards, episode_lengths, avg_episode_rewards = [], [], []
for _episode in range(len_data):
    obs, _reset_info = env.reset()
    done = False
    total_reward = 0
    steps = 0
    while not done:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _step_info = env.step(action)
        # An episode is over when it terminates OR is truncated; looking only
        # at the termination flag would keep stepping a truncated episode.
        done = terminated or truncated
        total_reward += reward
        steps += 1
    episode_rewards.append(total_reward)
    episode_lengths.append(steps)
    # steps >= 1 here because the inner loop always runs at least once.
    avg_episode_rewards.append(total_reward / steps)

# Persist the environment's images_dict next to the other experiment outputs.
images_dict_path = os.path.join(config.experiment_dir, "images_dict.json")
with open(images_dict_path, "w") as f:
    json.dump(env.images_dict, f, indent=4)

config.logger.info(f"Episode rewards: {episode_rewards}")
config.logger.info(f"Episode lengths: {episode_lengths}")
config.logger.info(f"Average episode rewards: {avg_episode_rewards}")