import os
import tensorflow as tf
from tf_agents.policies import policy_saver
from tf_agents.utils import common


def save_policy(save_dir, policy):
    """Export ``policy`` as a TensorFlow SavedModel under ``save_dir``.

    Args:
        save_dir: Destination directory for the exported SavedModel.
        policy: The TF-Agents policy to export via ``PolicySaver``.
    """
    policy_saver.PolicySaver(policy).save(save_dir)
    # Report only after save() has returned, so a failed export never
    # prints a success message.
    print(f"Saved policy to {save_dir}")


def restore_policy(load_dir):
    """Load and return the SavedModel stored in ``load_dir``.

    Args:
        load_dir: Directory containing a SavedModel, for example one written
            by ``save_policy``.

    Returns:
        The object produced by ``tf.saved_model.load``.
    """
    return tf.saved_model.load(load_dir)


def save_checkpointer(save_dir, agent, log=True):
    """Build a single-slot checkpointer for ``agent`` rooted at ``save_dir``.

    ``initialize_or_restore`` is called before returning, so any checkpoint
    already present in ``save_dir`` is restored into the agent.

    Args:
        save_dir: Checkpoint directory; only the newest checkpoint is kept.
        agent: TF-Agents agent to track, together with ``agent.policy``.
        log: When True, print a message before the checkpointer is built.

    Returns:
        The ``common.Checkpointer`` instance bound to ``save_dir``.
    """
    if log:
        print(f"Init save checkpointer from {save_dir}")

    tracked_objects = {"agent": agent, "policy": agent.policy}
    checkpointer = common.Checkpointer(
        ckpt_dir=save_dir,
        max_to_keep=1,
        **tracked_objects,
    )
    checkpointer.initialize_or_restore()
    return checkpointer


def load_checkpointer(load_dir, agent, log=True):
    """Create a checkpointer on ``load_dir`` and restore ``agent`` from it.

    Args:
        load_dir: Directory holding the checkpoint(s) to restore from.
        agent: TF-Agents agent to restore, together with ``agent.policy``.
        log: When True, print a message before restoring.

    Returns:
        The ``common.Checkpointer`` instance bound to ``load_dir``.

    Warns:
        UserWarning: If ``load_dir`` contains no checkpoint. In that case
            there is nothing to restore, and the agent keeps whatever
            variable values it already had.
    """
    import warnings

    if log:
        print(f"Init load checkpointer from {load_dir} and load agent")
    # Named ``checkpointer`` so the local does not shadow this function.
    checkpointer = common.Checkpointer(
        ckpt_dir=load_dir, max_to_keep=1, agent=agent, policy=agent.policy
    )
    # initialize_or_restore() does not raise when no checkpoint exists, so
    # surface that case explicitly; otherwise a "loaded" agent could be
    # silently untrained.
    if tf.train.latest_checkpoint(load_dir) is None:
        warnings.warn(
            f"No checkpoint found in {load_dir}; agent was not restored.",
            stacklevel=2,
        )
    checkpointer.initialize_or_restore()
    return checkpointer
