import torch

from models.flow_policy import FlowPolicy
from models.flow_model import MLP
from models.policy import Actor


def load_policy(config, device='cpu', path=None):
    """Construct an ``Actor`` policy from ``config``, optionally restoring weights.

    The actor is built from ``config.env.state_dim``, ``config.env.action_dim``,
    ``config.policy.hidden_dim``, ``config.env.max_action`` and
    ``config.policy.std`` (passed positionally, in that order).

    Args:
        config: Configuration object exposing the ``env`` and ``policy``
            attribute groups listed above.
        device: Target device passed both to ``torch.load`` (as
            ``map_location``) and to ``Module.to``.
        path: Optional checkpoint file holding a state dict for the actor.
            When ``None``, the freshly constructed (untrained) actor is
            returned.

    Returns:
        The ``Actor`` instance moved to ``device``.

    NOTE(review): ``torch.load`` is called without ``weights_only``. Depending
    on the installed torch version this may unpickle arbitrary objects, so
    only pass trusted checkpoint files.
    """
    env_cfg = config.env
    policy_cfg = config.policy
    actor = Actor(
        env_cfg.state_dim,
        env_cfg.action_dim,
        policy_cfg.hidden_dim,
        env_cfg.max_action,
        policy_cfg.std,
    )

    # No checkpoint: return the newly initialized network as-is.
    if path is None:
        print('Policy loaded from scratch.')
        return actor.to(device)

    # Remap stored tensors onto ``device`` at load time, then restore them.
    checkpoint = torch.load(path, map_location=device)
    actor.load_state_dict(checkpoint)
    print('Policy loaded from path:', path)
    return actor.to(device)

