from typing import Callable

import torch
from omegaconf import DictConfig
from stable_baselines3 import DDPG

from ours.utils.trans_fn import get_trans_observations_fns, inv_trans_action_fn


def make_policy(
    args: DictConfig,
    env_id: str,
    agent: DDPG,
    reverse_observations: bool,
    reverse_actions: bool,
    domain_id: int,
    num_task_ids: int,
    act_dim_domain: int,
) -> Callable:
    """Wrap ``agent.predict`` in a ``policy(obs, goal)`` callable.

    Args:
        args: Not used by this factory; kept in the signature for callers.
        env_id: Passed to ``get_trans_observations_fns`` to look up the
            observation transform for the environment.
        agent: Model whose ``predict`` method produces the action.
        reverse_observations: When true, the observation goes through the
            environment's observation transform before prediction.
        reverse_actions: When true, the predicted action goes through
            ``inv_trans_action_fn`` before being returned.
        domain_id: Not used by this factory.
        num_task_ids: Not used by this factory.
        act_dim_domain: Not used by this factory.

    Returns:
        A function ``policy(obs, goal)`` returning the (optionally
        transformed) action. ``goal`` is accepted but ignored.
    """
    # The lookup yields a pair; only its first element is used here.
    observation_transform, _unused = get_trans_observations_fns(env_id)

    def policy(obs, goal):
        # ``goal`` is part of the calling convention but never read.
        model_input = obs
        if reverse_observations:
            # The transform is applied to a batch of one: add a leading
            # axis, transform, then take the single row back out.
            batched_obs = obs[None, :]
            model_input = observation_transform(batched_obs)[0]

        with torch.inference_mode():
            # ``predict`` returns a sequence whose first item is the action.
            action = agent.predict(model_input)[0]

        if not reverse_actions:
            return action

        # Same batch-of-one round trip for the action transform.
        batched_action = action[None, :]
        return inv_trans_action_fn(batched_action)[0]

    return policy
