from typing import Callable, Tuple

import numpy as np
import torch


def get_trans_observations_fns(env_id: str) -> Tuple[Callable, Callable]:
    """Build a forward/inverse pair of observation transforms for ``env_id``.

    If ``env_id`` contains ``"maze2d"``, the forward transform does two things:
    it swaps columns 0<->1 and 2<->3, and it maps each of those four values
    ``v`` to ``(v - 1) * 4``. The inverse swaps them back and maps ``v`` to
    ``v / 4 + 1``. For every other ``env_id``, only columns 0 and 1 are
    swapped, and that swap is its own inverse.

    Both returned callables write into ``observations`` in place through
    indexed assignment, and they also return it. Indexing is along axis 1
    (``observations[:, ...]``), so input with at least two dimensions is
    assumed.

    Args:
        env_id: Environment identifier. Only the presence of the substring
            ``"maze2d"`` is inspected.

    Returns:
        ``(trans_observation_fn, inv_trans_observation_fn)``.
    """
    if "maze2d" not in env_id:
        def trans_observation_fn(observations):
            # Exchange the first two columns. The advanced-index read on the
            # right-hand side is a copy, so the swap does not alias.
            observations[:, [0, 1]] = observations[:, [1, 0]]
            return observations

        def inv_trans_observation_fn(observations):
            # A column swap undoes itself.
            observations[:, [0, 1]] = observations[:, [1, 0]]
            return observations

        return trans_observation_fn, inv_trans_observation_fn

    # maze2d: destination columns paired with the columns they are read from.
    target_cols = [0, 1, 2, 3]
    source_cols = [1, 0, 3, 2]

    def trans_observation_fn(observations):
        shifted = observations[:, source_cols] - 1
        observations[:, target_cols] = shifted * 4
        return observations

    def inv_trans_observation_fn(observations):
        scaled = observations[:, source_cols] / 4
        observations[:, target_cols] = scaled + 1
        return observations

    return trans_observation_fn, inv_trans_observation_fn


def trans_action_fn(actions):
    """Reverse the column order of ``actions`` along axis 1 and negate it.

    Accepts NumPy arrays and torch tensors and returns a new object of the
    same kind; the input is not modified. Because the values are multiplied
    by the float ``-1.``, integer inputs come back as floating point.

    Args:
        actions: ``np.ndarray`` or ``torch.Tensor`` with at least two
            dimensions. Axis 1 is the one that gets reversed.

    Returns:
        The reversed, negated actions, in the same container type as the input.

    Raises:
        TypeError: If ``actions`` is neither an ``np.ndarray`` nor a
            ``torch.Tensor``.
    """
    if isinstance(actions, np.ndarray):
        return actions[:, ::-1] * -1.
    if isinstance(actions, torch.Tensor):
        # torch tensors do not support negative-step slicing, hence torch.flip.
        return torch.flip(actions, dims=(1, )) * -1.
    # Fail loudly instead of implicitly returning None.
    raise TypeError(
        "trans_action_fn expects an np.ndarray or torch.Tensor, "
        f"got {type(actions).__name__}")


def inv_trans_action_fn(actions):
    """Undo the action transform: reverse columns along axis 1 and negate.

    Reversing axis 1 and negating are each their own inverse, so this applies
    exactly the same operation as the forward action transform. Accepts NumPy
    arrays and torch tensors and returns a new object of the same kind; the
    input is not modified. Because the values are multiplied by the float
    ``-1.``, integer inputs come back as floating point.

    Args:
        actions: ``np.ndarray`` or ``torch.Tensor`` with at least two
            dimensions. Axis 1 is the one that gets reversed.

    Returns:
        The reversed, negated actions, in the same container type as the input.

    Raises:
        TypeError: If ``actions`` is neither an ``np.ndarray`` nor a
            ``torch.Tensor``.
    """
    if isinstance(actions, np.ndarray):
        return actions[:, ::-1] * -1.
    if isinstance(actions, torch.Tensor):
        # torch tensors do not support negative-step slicing, hence torch.flip.
        return torch.flip(actions, dims=(1, )) * -1.
    # Fail loudly instead of implicitly returning None.
    raise TypeError(
        "inv_trans_action_fn expects an np.ndarray or torch.Tensor, "
        f"got {type(actions).__name__}")
