from env import ImageEnv
from rl.dqn import ImagePoserDQN
from utils.config import Config
from utils.load_stats import load_stats
from stable_baselines3.common.callbacks import CheckpointCallback
import wandb
import os
        
# --- Resume settings ---------------------------------------------------------
# Edit these values to continue from a different checkpoint. They must describe
# the previous run (its name and the step count of the checkpoint to load).
USER = os.getenv("USER")  # None when $USER is unset; paths then contain "None"
EXPERIMENT_NAME = "run_20250909-045133"
STEPS = 140

# Root directory of the experiment being continued.
EXP_DIR = "/".join(("/datasets/uig/results", str(USER), EXPERIMENT_NAME))

# Artifacts written by the previous run at STEPS steps: model weights,
# replay buffer and the statistics snapshot.
CHECKPOINT_PATH = EXP_DIR + f"/checkpoints/dqn_model_{STEPS}_steps.zip"
REPLAY_BUFFER_PATH = EXP_DIR + f"/checkpoints/dqn_model_replay_buffer_{STEPS}_steps.pkl"
STATS_PATH = EXP_DIR + f"/stats/{STEPS}_model_stats.json"


# Restore the statistics snapshot of the previous run. load_stats yields seven
# values, unpacked here in the order it provides them.
(
    model_score_sums,
    model_usage_counts,
    model_usage_type_counts,
    model_usage_type_score_sums,
    step_times,
    global_step_counter,
    idx_to_start,
) = load_stats(STATS_PATH)

# Restored values that are forwarded unchanged to the new Config.
_restored_stats = {
    "model_usage_counts": model_usage_counts,
    "model_usage_type_counts": model_usage_type_counts,
    "model_usage_type_score_sums": model_usage_type_score_sums,
    "model_score_sums": model_score_sums,
    "step_times": step_times,
    "global_step_counter": global_step_counter,
    "idx_to_start": idx_to_start,
}

# Build the configuration for the continued run: same dataset, stats written
# back into the original experiment's stats directory.
config = Config(
    dataset_path="./data/train.jsonl",
    stats_dir=f"{EXP_DIR}/stats",
    notes=f"DQN continuing training on old checkpoint from {EXPERIMENT_NAME} at {STEPS} steps",
    **_restored_stats,
)

# Record the full configuration in the run log.
config.logger.info(config)

# Re-attach to the W&B run of the experiment being continued, identified by
# using the experiment name as the run id.
#
# resume="allow" is the documented mode for resuming a run by id: if a run with
# this id exists it is resumed, otherwise a new run with that id is created.
# The previous boolean form (resume=True) is a deprecated alias for "auto",
# whose documented purpose is recovering the last crashed run on this machine
# rather than resuming a specific run by id.
wandb.init(
    project=config.wandb_project,
    id=EXPERIMENT_NAME,
    dir=f"{EXP_DIR}/wandb",
    resume="allow",
    sync_tensorboard=True,  # mirror TensorBoard logs to W&B
)

# Recreate the environment from the resumed configuration, then restore the
# agent from the checkpoint and load its saved replay buffer.
env = ImageEnv(config)
model = ImagePoserDQN.load(path=CHECKPOINT_PATH, env=env)
model.load_replay_buffer(REPLAY_BUFFER_PATH)

# Keep writing checkpoints (including the replay buffer) into the original
# experiment's checkpoint directory, using the same "dqn_model" prefix as the
# checkpoint that was just loaded.
checkpoint_callback = CheckpointCallback(
    save_freq=config.log_interval,
    save_path=f"{EXP_DIR}/checkpoints",
    name_prefix="dqn_model",
    save_replay_buffer=True,
)

# Number of further timesteps to train: the gap between the checkpoint's step
# count and 1000.
remaining_timesteps = 1000 - STEPS

# reset_num_timesteps=False asks the model not to restart its timestep counter
# for this learn() call, so training continues from the loaded state.
model.learn(
    total_timesteps=remaining_timesteps,
    log_interval=config.log_interval,
    callback=checkpoint_callback,
    reset_num_timesteps=False,
)

# Persist the final model under the directory reported by the new config.
final_model_path = config.experiment_dir + "/model"
model.save(final_model_path)

# Close the W&B run.
wandb.finish()