from typing import Callable, Dict

import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig

from .trans_fn import get_trans_observations_fns, inv_trans_action_fn
from .utils import goal_to_task_id, task_id_map_m2p


def make_policy(
    args: DictConfig,
    env_id: str,
    model: nn.Module,
    reverse_observations: bool,
    reverse_actions: bool,
    domain_id: int,
    num_task_ids: int,
    act_dim_domain: int,
) -> Callable:
    """Build a stateful ``policy(obs, goal)`` closure around ``model``.

    The returned callable turns one raw observation and its goal into one
    action by: optionally transforming the observation, zero-padding it to
    ``model.state_dim``, one-hot encoding the task and the domain, running
    the model without gradient tracking, and optionally inverse-transforming
    the predicted action.

    Args:
        args: Configuration object; only ``args.device`` is read here.
        env_id: Environment identifier. It selects the observation transform
            and enables the "point" / "maze2d" special cases via substring
            checks.
        model: Network called as ``model(obs, task_onehot, domain_onehot)``.
            It must expose a ``state_dim`` attribute.
        reverse_observations: If true, pass the observation through the
            transform returned by ``get_trans_observations_fns(env_id)``.
        reverse_actions: If true, pass the predicted action through
            ``inv_trans_action_fn`` before returning it.
        domain_id: Row of a 2x2 identity matrix used as the domain one-hot.
        num_task_ids: Length of the task one-hot vector.
        act_dim_domain: Number of leading action components to keep from
            the model output.

    Returns:
        A function ``policy(obs, goal)`` returning the action for a single,
        unbatched observation. ``obs`` must not include the goal and is
        never modified in place.

    Note:
        The closure remembers the last action it returned and never resets
        it. NOTE(review): if one policy object is reused across episodes,
        the first step of a new episode still sees the previous episode's
        last action -- confirm that callers build a fresh policy per episode.
    """
    trans_observation_fn, _ = get_trans_observations_fns(env_id)

    # env_id is fixed for the lifetime of the closure, so resolve the
    # substring checks once instead of on every step.
    is_point_env = "point" in env_id
    is_maze2d_env = "maze2d" in env_id

    prev_act = None

    def policy(obs, goal):  # obs without goal
        nonlocal prev_act

        if is_point_env and prev_act is not None:
            # Copy first: the original code assigned into the caller's array,
            # silently corrupting its observation buffer (and the underlying
            # array when a slice view was passed in).
            obs = np.array(obs)
            # Entries 2 and 3 are overwritten with the scaled previous action.
            # NOTE(review): presumably these are the velocity components and
            # 4.25 is an action-to-velocity scale -- confirm against the env.
            obs[[2, 3]] = prev_act * 4.25
        if reverse_observations:
            # The transform is called on a batch of one, hence [None, :] / [0].
            obs = trans_observation_fn(obs[None, :])[0]
        if len(obs) < model.state_dim:
            # Right-pad with zeros up to the state size the model expects.
            obs = np.hstack((obs, np.zeros((model.state_dim - len(obs), ))))

        task_id = goal_to_task_id(goal, env_id)
        if is_maze2d_env:
            # maze2d task ids are remapped through a per-maze-type table; the
            # maze type is the second "-"-separated token of env_id.
            maze_type = env_id.split("-")[1]
            task_id = task_id_map_m2p[maze_type][task_id]

        # Task ids are shifted by one to index the identity matrix.
        # NOTE(review): a task_id of 0 would index row -1 and silently pick
        # the last task -- confirm that ids always start at 1.
        task_id_onehot = torch.eye(num_task_ids)[task_id - 1]

        obs = torch.Tensor(obs).unsqueeze(0).to(args.device)
        task_id_onehot = task_id_onehot.unsqueeze(0).to(args.device)
        domain_id_onehot = torch.eye(2)[domain_id].unsqueeze(0).to(args.device)

        with torch.inference_mode():
            # Take the first element of the model output, drop singleton
            # dimensions and keep only this domain's action components.
            act = model(obs, task_id_onehot,
                        domain_id_onehot)[0].squeeze()[:act_dim_domain]

        act = act.cpu().numpy()
        if reverse_actions:
            act = inv_trans_action_fn(act[None, :])[0]
        # Remember the action exactly as returned; the next call feeds it
        # back into the (untransformed) observation for "point" envs.
        prev_act = act
        return act

    return policy
