from typing import Callable

import torch
from common.dail.models import DAILAgent
from omegaconf import DictConfig

from ours.utils.trans_fn import get_trans_observations_fns, inv_trans_action_fn
from ours.utils.utils import goal_to_task_id, task_id_map_m2p


def make_policy(
    args: DictConfig,
    env_id: str,
    agent: DAILAgent,
    reverse_observations: bool,
    reverse_actions: bool,
    domain_id: int,
    num_task_ids: int,
    act_dim_domain: int,
) -> Callable:
    """Build a ``policy(obs, goal)`` closure around a DAIL agent.

    Args:
        args: Config object; only ``args.device`` is read here.
        env_id: Environment identifier. It selects the observation transform
            and the goal-to-task-id mapping. When it contains ``"maze2d"``,
            its second ``-``-separated field is used as the maze type for
            remapping task ids through ``task_id_map_m2p``.
        agent: Agent providing ``source_policy`` and, for the target domain,
            ``state_map`` and ``action_map``.
        reverse_observations: If true, pass each observation through the
            environment's observation transform before inference.
        reverse_actions: If true, pass the predicted action through
            ``inv_trans_action_fn`` before returning it.
        domain_id: ``0`` runs the source policy directly on the observation;
            ``1`` maps the observation with ``agent.state_map``, runs the
            source policy, then maps the action with ``agent.action_map``.
        num_task_ids: Length of the one-hot task encoding.
        act_dim_domain: Currently unused; kept so existing callers keep
            working.

    Returns:
        A callable ``policy(obs, goal)`` that returns the action as a NumPy
        array with the batch axis removed.

    Raises:
        ValueError: If ``domain_id`` is neither 0 nor 1.
    """
    # Fail at construction time instead of hitting an unbound ``act`` on the
    # first policy call.
    if domain_id not in (0, 1):
        raise ValueError(
            f"domain_id must be 0 (source) or 1 (target), got {domain_id!r}"
        )

    trans_observation_fn, _ = get_trans_observations_fns(env_id)

    # Invariants hoisted out of the per-step closure.
    task_onehots = torch.eye(num_task_ids, device=args.device)
    maze_type = env_id.split("-")[1] if "maze2d" in env_id else None

    def policy(obs, goal):  # obs without goal
        if reverse_observations:
            # The transform is called on a batch of one; unwrap it again.
            obs = trans_observation_fn(obs[None, :])[0]

        task_id = goal_to_task_id(goal, env_id)
        if maze_type is not None:
            task_id = task_id_map_m2p[maze_type][task_id]

        # Task ids are treated as 1-based when selecting the one-hot row.
        # NOTE(review): a task_id of 0 would silently select the last row --
        # confirm goal_to_task_id never returns 0.
        task_id_onehot = task_onehots[task_id - 1].unsqueeze(0)

        obs = torch.as_tensor(
            obs, dtype=torch.float32, device=args.device
        ).unsqueeze(0)

        with torch.inference_mode():
            if domain_id == 0:  # source domain
                policy_input = torch.cat((obs, task_id_onehot), dim=-1)
                act = agent.source_policy(policy_input)
            else:  # target domain (domain_id == 1, validated above)
                source_obs = agent.state_map(obs)
                policy_input = torch.cat((source_obs, task_id_onehot), dim=-1)
                source_act = agent.source_policy(policy_input)
                act = agent.action_map(source_act)

        # Drop only the batch axis: a bare ``squeeze()`` would also collapse a
        # 1-dimensional action to a scalar and break ``act[None, :]`` below.
        act = act.squeeze(0).cpu().numpy()
        if reverse_actions:
            act = inv_trans_action_fn(act[None, :])[0]
        return act

    return policy
