from centralized_verification.configuration import TestingLimits
from centralized_verification.shields.centralized_shield import CentralizedShieldOracle
from centralized_verification.train import train_loop, test_loop
from centralized_verification.training_state import maybe_load_from_checkpoint
from experiments.utils.basic_pentagon_config import make_basic_pentagon_config_iql

if __name__ == '__main__':
    # Script entry point: train a Q-learning agent setup on the basic pentagon
    # environment with a centralized shield oracle, then evaluate it.
    # NOTE(review): "iql" in the config factory name presumably means
    # independent Q-learning; confirm against experiments.utils.
    experiment_name = "q_learning_pentagon_centralized_shield_var_start_1"

    # Start from the shared pentagon config and swap in the centralized shield.
    # `_replace` returns a new config with only the shield field changed,
    # assuming the config is a NamedTuple.
    base_config = make_basic_pentagon_config_iql(experiment_name)
    shielded_config = base_config._replace(shield=CentralizedShieldOracle())

    # Look up any saved training state stored under this run name and hand it
    # to the trainer. Presumably this is None when no checkpoint exists; verify
    # in centralized_verification.training_state.
    resume_state = maybe_load_from_checkpoint(experiment_name)
    train_loop(shielded_config, resume_state)

    # Evaluate the trained setup over 100 episodes, each capped at 500 steps.
    evaluation_limits = TestingLimits(max_episode_len=500, num_episodes=100)
    test_loop(shielded_config.to_test_config(evaluation_limits))
