"""Methods for interacting with the environments."""

from __future__ import annotations
from multiprocessing import shared_memory
from optparse import Option

import os
import sys
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import d4rl  # noqa F401 Import needed to register d4rl envs
from d4rl import infos, offline_env
from d4rl.kitchen import kitchen_envs
from d4rl.kitchen.adept_envs.franka import kitchen_multitask_v0
from d4rl.locomotion import ant
from gcsl import envs
# from gcsl.algo import variants
from gcsl.envs import goal_env
import gym
import numpy as np
import tqdm
import torch

from rvs import dataset, policies, util, visualize
from rvs.vector.async_vector_env import AsyncVectorEnv

# Goal-based robotics task ids from Gym.
gym_goal_envs = [
    "FetchPickAndPlace-v1",
    "FetchPush-v1",
    "FetchReach-v1",
    "FetchSlide-v1",
    "HandManipulateBlock-v0",
    "HandManipulateEgg-v0",
    "HandManipulatePen-v0",
    "HandReach-v0",
]

# D4RL Gym-MuJoCo ids are assembled as "<agent>-<dataset>-<version>".
d4rl_agents = ["hopper", "halfcheetah", "ant", "walker2d"]
d4rl_datasets = [
    "random",
    "medium",
    "expert",
    "medium-expert",
    "medium-replay",
    "full-replay",
]
d4rl_versions = ["v0", "v1", "v2"]
# Loop variables are deliberately not called `dataset`, so they cannot be
# confused with the `rvs.dataset` module imported at the top of the file.
d4rl_gym = [
    "-".join((agent_name, dataset_name, version_name))
    for agent_name in d4rl_agents
    for dataset_name in d4rl_datasets
    for version_name in d4rl_versions
]
d4rl_antmaze = [
    "antmaze-umaze-v2",
    "antmaze-umaze-diverse-v2",
    "antmaze-medium-diverse-v2",
    "antmaze-medium-play-v2",
    "antmaze-large-diverse-v2",
    "antmaze-large-play-v2",
]
d4rl_maze2d = [
    "maze2d-open-v0",
    "maze2d-umaze-v1",
    "maze2d-medium-v1",
    "maze2d-large-v1",
]
d4rl_franka = ["kitchen-complete-v0", "kitchen-partial-v0", "kitchen-mixed-v0"]
# Every D4RL env id this module knows how to create and evaluate.
d4rl_env_names = [*d4rl_gym, *d4rl_antmaze, *d4rl_maze2d, *d4rl_franka]

# File names used inside a rollout directory by save_rollouts / load_rollouts.
s_obs_vecs_file = "s_obs_vecs.npy"
s_ach_goal_vecs_file = "s_ach_goal_vecs.npy"
a_vecs_file = "a_vecs.npy"
base_actions_file = "base_actions.npy"


def save_rollouts(
    rollout_dir: str,
    s_obs_vecs: Union[np.ndarray, List[np.ndarray]],
    s_ach_goal_vecs: Union[np.ndarray, List[np.ndarray]],
    a_vecs: Union[np.ndarray, List[np.ndarray]],
    base_actions: Optional[np.ndarray] = None,
) -> None:
    """Write rollout arrays as ``.npy`` files inside ``rollout_dir``.

    Args:
        rollout_dir: Target directory; created if it does not exist yet.
        s_obs_vecs: Observations, stored under ``s_obs_vecs_file``.
        s_ach_goal_vecs: Achieved goals, stored under ``s_ach_goal_vecs_file``.
        a_vecs: Actions, stored under ``a_vecs_file``.
        base_actions: Optional array stored under ``base_actions_file``; no file
            is written for it when it is None.
    """
    os.makedirs(rollout_dir, exist_ok=True)

    to_write = [
        (s_obs_vecs_file, s_obs_vecs),
        (s_ach_goal_vecs_file, s_ach_goal_vecs),
        (a_vecs_file, a_vecs),
    ]
    if base_actions is not None:
        to_write.append((base_actions_file, base_actions))

    for filename, values in to_write:
        np.save(os.path.join(rollout_dir, filename), values)


def load_rollouts(rollout_dir: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Read the rollout arrays stored in ``rollout_dir``.

    Returns:
        The arrays loaded from ``s_obs_vecs_file``, ``s_ach_goal_vecs_file`` and
        ``a_vecs_file``, in that order. The optional base-actions file is not
        read here.
    """
    filenames = (s_obs_vecs_file, s_ach_goal_vecs_file, a_vecs_file)
    observations, achieved_goals, actions = tuple(
        np.load(os.path.join(rollout_dir, filename)) for filename in filenames
    )
    return observations, achieved_goals, actions


def get_total_steps(rollout_dir: str) -> int:
    """Return the number of environment steps stored in ``rollout_dir``.

    The count is the product of the first two dimensions of the saved
    observation array (trajectories times steps per trajectory). The saved
    achieved-goal and action arrays are asserted to share those two dimensions.
    """
    observations, achieved_goals, actions = load_rollouts(rollout_dir)
    obs_shape = observations.shape
    assert obs_shape[0] == achieved_goals.shape[0] == actions.shape[0]
    assert obs_shape[1] == achieved_goals.shape[1] == actions.shape[1]

    return obs_shape[0] * obs_shape[1]


def generate_random_rollouts(
    env: GCSLToGym,
    rollout_dir: str,
    total_steps: int,
    max_episode_steps: int,
    use_base_actions: bool = False,
) -> None:
    """Roll out uniformly sampled actions and save the resulting trajectories.

    Each episode runs for exactly ``max_episode_steps`` steps (the env's done
    flag is ignored) and records one more observation / achieved goal than
    actions, because the reset state is recorded too.

    Args:
        env: Environment whose reset/step observations are dicts with
            "observation" and "achieved_goal" entries.
        rollout_dir: Directory handed to ``save_rollouts``.
        total_steps: Requested number of steps; rounded up to whole episodes.
        max_episode_steps: Number of steps taken in every episode.
        use_base_actions: If True, store ``env.base_actions[a]`` for each
            sampled action ``a`` instead of ``a`` itself.

    Raises:
        ValueError: If ``use_base_actions`` is set but the env has no
            ``base_actions`` attribute.
    """
    base_actions = getattr(env, "base_actions", None)
    if base_actions is None and use_base_actions:
        raise ValueError(
            "use_base_actions == True but env.base_actions doesn't exist")

    obs_trajs, goal_trajs, action_trajs = [], [], []

    num_episodes = int(np.ceil(total_steps / max_episode_steps))
    for _ in tqdm.trange(num_episodes, desc="Generating env episodes"):
        state = env.reset()
        obs_traj = [state["observation"]]
        goal_traj = [state["achieved_goal"]]
        sampled_actions = []
        for _step in range(max_episode_steps):
            action = env.action_space.sample()
            state, _, _, _ = env.step(action)
            sampled_actions.append(action)
            obs_traj.append(state["observation"])
            goal_traj.append(state["achieved_goal"])

        actions = np.array(sampled_actions)
        obs_trajs.append(np.array(obs_traj))
        goal_trajs.append(np.array(goal_traj))
        # With base actions, the sampled actions are used as indices into them.
        action_trajs.append(base_actions[actions] if use_base_actions else actions)

    save_rollouts(
        rollout_dir,
        obs_trajs,
        goal_trajs,
        action_trajs,
        base_actions=base_actions,
    )


def is_kitchen_env(env: gym.Env) -> bool:
    """Return True if the env's spec id is one of the D4RL Franka kitchen ids."""
    env_id = env.spec.id
    return env_id in d4rl_franka


def is_antmaze_env(env: gym.Env) -> bool:
    """Return True if the env's spec id is one of the D4RL AntMaze ids.

    Objects exposing an ``env_fns`` attribute are inspected through a fresh env
    built by calling the first factory in that list; anything else is inspected
    directly.
    """
    # NOTE(review): the env created from env_fns[0] is not closed afterwards.
    probe = env.env_fns[0]() if hasattr(env, 'env_fns') else env
    return probe.spec.id in d4rl_antmaze


def render_env(env: gym.Env, mode="human") -> Union[np.ndarray, None]:
    """Render ``env``, using a dedicated code path for D4RL kitchen envs.

    Kitchen envs are rendered by calling ``KitchenTaskRelaxV1.render`` directly
    on the env; every other env uses its own ``render`` method.
    """
    if not is_kitchen_env(env):
        return env.render(mode=mode)

    kitchen_render = kitchen_multitask_v0.KitchenTaskRelaxV1.render
    return kitchen_render(env, mode=mode)


def get_action_from_policy(
    policy: Union[policies.RvS, Callable[[np.ndarray, np.ndarray], np.ndarray]],
    obs: np.ndarray,
    goal: np.ndarray,
    deterministic: bool = True,
) -> np.ndarray:
    """Query an action from either an RvS-style policy or a plain callable.

    Args:
        policy: Object with a ``get_action(obs, goal, deterministic)`` method, or
            a callable invoked as ``policy(obs, goal)``.
        obs: Observation passed to the policy.
        goal: Conditioning input passed to the policy.
        deterministic: Forwarded to ``get_action``; ignored for plain callables.

    Returns:
        Whatever the policy returns for ``(obs, goal)``.
    """
    # Look the method up first rather than wrapping the call in
    # `except AttributeError`: that would also swallow AttributeErrors raised
    # *inside* get_action and silently re-run the policy as a plain callable.
    # pytype: disable=attribute-error
    get_action = getattr(policy, "get_action", None)
    if get_action is None:
        return policy(obs, goal)

    return get_action(obs, goal, deterministic)


def rollout_and_render(
    policy: Union[policies.RvS, Callable[[np.ndarray, np.ndarray], np.ndarray]],
    env: gym.Env,
    max_episode_steps: int,
    fixed_goal: Optional[np.ndarray] = None,
    dynamic_kitchen_goal: bool = False,
) -> List[np.ndarray]:
    """Run one episode with ``policy`` and collect an RGB frame per step.

    Args:
        policy: Policy queried through ``get_action_from_policy``.
        env: Environment to roll out in.
        max_episode_steps: Step cap; a falsy value (0 / None) means no cap other
            than ``sys.maxsize``.
        fixed_goal: If given (and ``dynamic_kitchen_goal`` is False), the raw env
            observation is fed to the policy together with this goal.
        dynamic_kitchen_goal: If True, the goal is recomputed every step with
            ``get_dynamic_kitchen_goal``. Takes precedence over ``fixed_goal``.

    Returns:
        The frame rendered after reset followed by one frame per step taken. The
        episode stops on ``done`` or when ``info["is_success"]`` is truthy.
    """
    step_cap = max_episode_steps if max_episode_steps else sys.maxsize

    obs_dict = env.reset()
    frames = [render_env(env, mode="rgb_array")]
    for _ in range(step_cap):
        if dynamic_kitchen_goal:
            obs = obs_dict
            goal = get_dynamic_kitchen_goal(env, obs)
        elif fixed_goal is None:
            # Goal-env style dict observation carrying its own desired goal.
            obs, goal = obs_dict["observation"], obs_dict["desired_goal"]
        else:
            obs, goal = obs_dict, fixed_goal

        action = get_action_from_policy(policy, obs, goal)
        obs_dict, _, done, info = env.step(action)
        frames.append(render_env(env, mode="rgb_array"))

        episode_over = done or info.get("is_success", False)
        if episode_over:
            break

    return frames  # pytype: disable=bad-return-type


def get_valid_kitchen_subtasks(env: kitchen_envs.KitchenBase) -> List[str]:
    """List the subtask names accepted for a D4RL Kitchen env.

    The result is "all", then "random", then every entry of
    ``env.TASK_ELEMENTS`` in its original order.
    """
    return ["all", "random", *env.TASK_ELEMENTS]


def get_kitchen_goal(
    env: kitchen_envs.KitchenBase,
    subtask: str = "microwave",
    render_goal_frame: bool = False,
) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
    """Build the goal vector (and optionally a goal frame) for a kitchen subtask.

    Args:
        env: The kitchen env. It is reset twice here, and again when
            ``subtask == "random"``.
        subtask: One of the names from ``get_valid_kitchen_subtasks``.
        render_goal_frame: Only used for ``subtask == "random"``, where it is
            forwarded to ``get_random_kitchen_goal``.

    Returns:
        A tuple of the goal (the first 30 observation entries with the subtask
        targets written in, concatenated with the entries from index 30 onward
        of a reset observation) and the goal frame, which is None unless a
        random goal was rendered.

    Raises:
        ValueError: If ``subtask`` is not a valid choice.
    """
    allowed_subtasks = get_valid_kitchen_subtasks(env)
    if subtask not in allowed_subtasks:
        raise ValueError(
            f"Received subtask == {subtask} which is not among the valid choices: "
            f"{allowed_subtasks}",
        )

    # Two resets are compared to check which observation entries never change:
    # exactly indices 30..59 are expected to be identical.
    first_obs = env.reset()
    second_obs = env.reset()
    (unchanged_indices,) = np.where(first_obs - second_obs == 0)
    assert np.all(unchanged_indices == np.arange(30, 60))
    static_part = first_obs[30:]

    if subtask == "random":
        goal, goal_frame = get_random_kitchen_goal(
            env,
            render_goal_frame=render_goal_frame,
        )
        return np.concatenate((goal, static_part), axis=0), goal_frame

    # Start from the reset state and overwrite only the requested elements.
    goal = first_obs[:30]
    tasks = env.TASK_ELEMENTS if subtask == "all" else [subtask]
    for task in tasks:
        element_indices = kitchen_envs.OBS_ELEMENT_INDICES[task]
        element_targets = kitchen_envs.OBS_ELEMENT_GOALS[task]
        goal[element_indices] = element_targets

    return np.concatenate((goal, static_part), axis=0), None


def get_random_kitchen_goal(
    env: kitchen_envs.KitchenBase,
    random_horizon: int = 50,
    render_goal_frame: bool = False,
) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
    """Reach a goal state by acting randomly from a fresh reset.

    Args:
        env: The env to reset and step.
        random_horizon: Maximum number of random actions; stops early on done.
        render_goal_frame: If True, also render the final state as an RGB array.

    Returns:
        The first 30 entries of the last observation, and the rendered frame
        (None when rendering was not requested).
    """
    last_obs = env.reset()
    for _ in range(random_horizon):
        last_obs, _, done, _ = env.step(env.action_space.sample())
        if done:
            break

    goal_frame = None
    if render_goal_frame:
        goal_frame = render_env(env, mode="rgb_array")

    return last_obs[:30], goal_frame


def get_dynamic_kitchen_goal(
    env: kitchen_envs.KitchenBase,
    obs: np.ndarray,
) -> np.ndarray:
    """Return a goal targeting the first kitchen subtask that is still pending.

    The goal starts as a copy of ``obs[:30]``. The first entry of
    ``env.TASK_ELEMENTS`` that also appears in ``env.tasks_to_complete`` has its
    target values written into that copy; if none is pending the copy is left
    untouched. ``obs[30:]`` is appended unchanged, and ``obs`` is not modified.
    """
    goal = np.copy(obs[:30])
    static_part = obs[30:]

    pending_subtask = next(
        (task for task in env.TASK_ELEMENTS if task in env.tasks_to_complete),
        None,
    )
    if pending_subtask is not None:
        element_indices = kitchen_envs.OBS_ELEMENT_INDICES[pending_subtask]
        element_targets = kitchen_envs.OBS_ELEMENT_GOALS[pending_subtask]
        goal[element_indices] = element_targets

    return np.concatenate((goal, static_part), axis=0)


def sample_elite_steps(
    dataset: Dict[str, np.ndarray],
    elite_property: str = "length",
    elite_traj_fraction: float = 0.2,
    elite_step_fraction: float = 0.2,
    samples: int = 200,
    reverse: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Choose steps from the demonstrations based on a property of the trajectories.

    Args:
        dataset: The dataset of trajectories to use.
        elite_property: Which property of trajectories to use to select the best
            trajectories: "length" or "reward".
        elite_traj_fraction: If this value is, e.g., 0.2, then sample from the top 20%
            of trajectories.
        elite_step_fraction: If this value is, e.g., 0.2, then sample from the last 20%
            of timesteps.
        samples: How many total steps to return.
        reverse: If true, sample from the worst trajectories (rather than the best).

    Returns:
        A sampled array of observations and a corresponding array of actions.

    Raises:
        ValueError: If an invalid elite property is given.
    """
    starts, ends, lengths = util.extract_traj_markers(dataset)

    if elite_property == "length":
        sorted_indices = np.argsort(lengths)
    elif elite_property == "reward":
        rewards = visualize.get_episode_rewards(
            dataset["rewards"], starts, ends)
        sorted_indices = np.argsort(rewards)
    else:
        raise ValueError(
            f"elite_property must be 'length' or 'reward', got {elite_property!r}",
        )
    if reverse:
        # Flip the ascending order so the lowest-ranked trajectories sit last.
        sorted_indices = sorted_indices[::-1]

    num_elites = int(np.ceil(elite_traj_fraction * len(lengths)))
    # The elite set has always been the last `num_elites - 1` entries of the
    # ranking (the original slice was `[:-num_elites:-1]`); that count is kept.
    # It is clamped to at least one trajectory because the old slice returned
    # an empty set for num_elites == 1 (np.random.choice then raises) and
    # almost every trajectory for num_elites == 0.
    num_selected = max(num_elites - 1, 1)
    elite_indices = sorted_indices[::-1][:num_selected]
    elite_index = np.random.choice(elite_indices, size=samples)

    elite_start = starts[elite_index]
    # np.random.rand draws from [0, 1), so the proportional time lies in
    # [1 - elite_step_fraction, 1) and the floored offset stays below the
    # trajectory length.
    elite_proportional_time = (
        1 - elite_step_fraction + np.random.rand(samples) * elite_step_fraction
    )
    elite_relative_time = np.floor(
        elite_proportional_time * lengths[elite_index],
    ).astype(int)
    elite_id = elite_start + elite_relative_time

    return dataset["observations"][elite_id], dataset["actions"][elite_id]


def sample_cumulative_reward(
    policy: Union[policies.RvS, Callable[[np.ndarray, np.ndarray], np.ndarray]],
    env: offline_env.OfflineEnv,
    goals: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
    trajectory_samples: int = 2000,
    return_goals: bool = False,
    dynamic_kitchen_goal: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, Union[np.ndarray, List[np.ndarray], None]]]:
    """Roll out the policy and record the cumulative reward of every episode.

    Args:
        policy: Policy queried through ``get_action_from_policy``.
        env: Environment to evaluate in.
        goals: One conditioning goal per episode. When omitted (and dynamic
            goals are off) the goals are drawn with ``sample_elite_steps`` from
            ``env.get_dataset()``.
        trajectory_samples: Number of episodes to run.
        return_goals: If True, also return the goals that were used.
        dynamic_kitchen_goal: If True, the goal is recomputed after reset and
            after every step with ``get_dynamic_kitchen_goal``.

    Returns:
        The array of episode returns, or ``(returns, goals)`` when
        ``return_goals`` is True.
    """
    if goals is None and not dynamic_kitchen_goal:
        goals, _ = sample_elite_steps(
            env.get_dataset(), samples=trajectory_samples)
    if goals is not None:
        assert (
            len(goals) == trajectory_samples
        ), "Must have exactly one goal for each trajectory sample"

    episode_returns = []
    for episode in tqdm.trange(trajectory_samples, desc="Sampling trajectory rewards"):
        observation = env.reset()
        if dynamic_kitchen_goal:
            goal = get_dynamic_kitchen_goal(env, observation)
        else:
            goal = goals[episode]

        episode_return = 0
        done = False
        while not done:
            action = get_action_from_policy(policy, observation, goal)
            observation, reward, done, _ = env.step(action)
            if dynamic_kitchen_goal:
                goal = get_dynamic_kitchen_goal(env, observation)
            episode_return += reward
        episode_returns.append(episode_return)

    returns = np.array(episode_returns)
    return (returns, goals) if return_goals else returns


def sample_with_reward_conditioning(
    policy: Union[policies.RvS, Callable[[np.ndarray, np.ndarray], np.ndarray]],
    env: gym.Env,
    reward_target: Union[int, float],
    return_model=None,
    return_quantile=None,
    discount_factor: float = 1.0,
    trajectory_samples: int = 200,
    deterministic: bool = True,
    reward_preprocessing: str = "none",
) -> np.ndarray:
    """Evaluate cumulative reward with RvS-R (reward-conditioned) learning.

    Args:
        policy: Policy conditioned on a one-element target array.
        env: Environment to evaluate in.
        reward_target: Initial target for every episode. When the return model
            is active it is instead added to the model's prediction each step.
        return_model: Optional model called on the batched observation; only
            used when ``return_quantile`` is given as well.
        return_quantile: Start index of the slice taken along axis 1 of the
            model output after averaging over its last axis.
        discount_factor: The running target is divided by this after each step.
        trajectory_samples: Number of episodes to run.
        deterministic: Forwarded to the policy.
        reward_preprocessing: Method name passed to ``preprocess_reward`` and
            ``preprocess_target``.

    Returns:
        Array holding the undiscounted, unprocessed return of every episode.
        The mean is also printed.
    """
    use_return_model = return_model is not None and return_quantile is not None

    episode_returns = []
    for _ in tqdm.trange(trajectory_samples, desc="Sampling trajectory rewards"):
        episode_return = 0
        running_target = reward_target
        observation = env.reset()
        # Never assigned anything else below, so the goal-augmented branches
        # are currently inactive.
        target_goal = None
        done = False
        while not done:
            if use_return_model:
                model_input = (
                    np.concatenate((observation, target_goal), -1)
                    if target_goal
                    else observation
                )
                with torch.no_grad():
                    # NOTE(review): hard-coded .cuda() requires a GPU.
                    batch = torch.Tensor(model_input).cuda().unsqueeze(0)
                    prediction = return_model(batch).mean(-1)
                    tail_mean = prediction[:, return_quantile:].mean(-1)
                    running_target = tail_mean.cpu().item() + reward_target

            if target_goal:
                goal = np.concatenate([target_goal, running_target], axis=1)
            else:
                goal = np.array([running_target])

            action = get_action_from_policy(
                policy, observation, goal, deterministic=deterministic)
            observation, reward, done, _ = env.step(action)

            # Subtract what was just earned, undo the discount, then clip /
            # post-process the remaining target.
            running_target -= preprocess_reward(reward, reward_preprocessing)
            running_target /= discount_factor
            running_target = preprocess_target(
                running_target, reward_preprocessing)
            episode_return += reward
        episode_returns.append(episode_return)

    total_rewards = np.array(episode_returns)
    print(np.mean(total_rewards))
    return total_rewards


def preprocess_reward(reward, method):
    """Transform a step reward before it is subtracted from the running target.

    "none" and "conservative" return the reward unchanged; "antmaze" returns
    ``reward - 1.0``.

    Raises:
        ValueError: If ``method`` is not one of the three names above.
    """
    if method == "antmaze":
        return reward - 1.0
    if method not in ("none", "conservative"):
        raise ValueError(f"Unknown reward preprocessing method: {method}")
    return reward


def preprocess_target(target, method):
    """Post-process the running reward target after each update.

    "none" and "conservative" return the target unchanged; "antmaze" caps it
    at zero from above via ``np.minimum(target, 0.)``.

    Raises:
        ValueError: If ``method`` is not one of the three names above.
    """
    if method == "antmaze":
        return np.minimum(target, 0.)
    if method not in ("none", "conservative"):
        raise ValueError(f"Unknown target preprocessing method: {method}")
    return target


def sample_with_reward_conditioning_multicpu(
    policy: Union[policies.RvS, Callable[[np.ndarray, np.ndarray], np.ndarray]],
    env_lambda: Callable[[], gym.Env],
    reward_target: Union[int, float],
    return_model=None,
    return_quantile=None,
    discount_factor: float = 1.0,
    trajectory_samples: int = 200,
    deterministic: bool = True,
    num_cpu: int = 20,
    reward_preprocessing: str = "none",
) -> np.ndarray:
    """Evaluate cumulative reward with RvS-R learning on a vectorized env.

    ``num_cpu`` env copies (seeded 0..num_cpu-1) are stepped together through an
    ``AsyncVectorEnv``. ``trajectory_samples // num_cpu`` batches are run, so a
    remainder of requested trajectories is dropped.

    Args:
        policy: Policy conditioned on a ``(num_cpu, 1)`` target array.
        env_lambda: Factory creating one env copy.
        reward_target: Initial target for every episode. When the return model
            is active it is instead added to the model's prediction each step.
        return_model: Optional model called on the batched observations; only
            used when ``return_quantile`` is given as well.
        return_quantile: Start index of the slice taken along axis 1 of the
            model output after averaging over its last axis.
        discount_factor: The running target is divided by this after each step.
        trajectory_samples: Total number of episodes requested.
        deterministic: Forwarded to the policy.
        num_cpu: Number of env copies stepped in parallel.
        reward_preprocessing: Method name passed to ``preprocess_reward`` and
            ``preprocess_target``.

    Returns:
        Array holding the undiscounted, unprocessed return of every episode.
        The mean is also printed.
    """
    total_reward_vec = []
    use_return_model = return_model is not None and return_quantile is not None

    def make_env_copy(seed):
        def fn():
            env = env_lambda()
            np.random.seed(seed)
            env.np_random.seed(seed)
            return env
        return fn

    env = AsyncVectorEnv(
        [make_env_copy(seed) for seed in range(num_cpu)], shared_memory=False)

    # The vector env is closed in `finally` so it is released even when a
    # rollout raises (assumes AsyncVectorEnv keeps gym's `close()` API).
    try:
        for _ in tqdm.trange(
            trajectory_samples // num_cpu, desc="Sampling trajectory rewards"
        ):
            total_reward = np.zeros((num_cpu,))
            running_target = reward_target
            observation = env.reset()
            if not use_return_model:
                # Explicit float dtype: an integer reward_target would otherwise
                # produce an int array, and the in-place `-=` / `/=` below
                # cannot cast their float results back into it.
                running_target = np.full(
                    (len(observation), 1), reward_target, dtype=float)
            done = np.zeros((num_cpu,))
            i = 0
            # Never assigned anything else below, so the goal-augmented
            # branches are currently inactive.
            target_goal = None
            while not np.all(done) and i < 999:  # TODO: remove hard-coded max length
                if use_return_model:
                    if target_goal:
                        s = np.concatenate((observation, target_goal), -1)
                    else:
                        s = observation
                    with torch.no_grad():
                        # NOTE(review): hard-coded .cuda() requires a GPU.
                        prediction = return_model(torch.Tensor(s).cuda()).mean(-1)
                        tail_mean = prediction[:, return_quantile:].mean(
                            -1, keepdim=True)
                        running_target = tail_mean.cpu().numpy() + reward_target
                i += 1
                if target_goal:
                    goal = np.concatenate([target_goal, running_target], axis=1)
                else:
                    goal = running_target
                action = get_action_from_policy(
                    policy, observation, goal, deterministic=deterministic)
                observation, reward, new_done, _ = env.step(action)

                # Subtract what was just earned, undo the discount, then clip /
                # post-process the remaining per-env targets.
                running_target -= (preprocess_reward(reward,
                                   reward_preprocessing)[:, None])
                running_target /= discount_factor
                running_target = preprocess_target(
                    running_target, reward_preprocessing)

                # Rewards from sub-envs already marked done are masked out.
                total_reward += reward * (1 - done)
                done = np.maximum(done, new_done)
            total_reward_vec.extend(total_reward)
    finally:
        env.close()

    total_rewards = np.array(total_reward_vec)
    print(f'Reward target: {reward_target}: {np.mean(total_rewards)}')
    return total_rewards


def find_elite_goals(
    policy: Union[policies.RvS, Callable[[np.ndarray, np.ndarray], np.ndarray]],
    env: offline_env.OfflineEnv,
    trajectory_samples: int = 200,
    num_elites: int = 6,
) -> np.ndarray:
    """Random search over sampled goals, keeping those with the highest return.

    Args:
        policy: Policy to evaluate.
        env: Environment to evaluate in.
        trajectory_samples: Number of goals to try; raised to ``num_elites`` if
            smaller.
        num_elites: How many of the best goals to return.

    Returns:
        The goals of the ``num_elites`` highest-return rollouts, best first.
    """
    num_rollouts = max(trajectory_samples, num_elites)
    rewards, goals = sample_cumulative_reward(
        policy,
        env,
        trajectory_samples=num_rollouts,
        return_goals=True,
    )

    best_first = np.argsort(rewards)[::-1]
    return goals[best_first[:num_elites]]


def evaluate_goals(
    policy: Union[policies.RvS, Callable[[np.ndarray, np.ndarray], np.ndarray]],
    env: offline_env.OfflineEnv,
    goals: Iterable[np.ndarray],
    trajectory_samples: int = 200,
) -> List[np.ndarray]:
    """Measure the cumulative return obtained when conditioning on each goal.

    Every goal is repeated ``trajectory_samples`` times and passed to
    ``sample_cumulative_reward``; the per-goal return arrays are collected in
    the order the goals were given.
    """
    return [
        sample_cumulative_reward(
            policy,
            env,
            goals=[goal] * trajectory_samples,
            trajectory_samples=trajectory_samples,
        )
        for goal in goals
    ]  # pytype: disable=bad-return-type


def sample_hitting_times(
    policy: Union[policies.RvS, Callable[[np.ndarray, np.ndarray], np.ndarray]],
    env: GCSLToGym,
    max_episode_steps: int,
    hitting_time_samples: int = 2000,
) -> np.ndarray:
    """Monte Carlo estimation of the hitting time distribution.

    Each sample is the number of steps taken in one episode before ``done`` is
    reported or ``info["is_success"]`` is truthy, capped at
    ``max_episode_steps``.
    """
    samples = []
    for _ in tqdm.trange(hitting_time_samples, desc="Sampling hitting times"):
        obs_dict = env.reset()
        steps_taken = 0
        # As env is wrapped in gym.wrappers.TimeLimit, may not need `max_episode_steps`
        while steps_taken < max_episode_steps:
            steps_taken += 1
            action = get_action_from_policy(
                policy,
                obs_dict["observation"],
                obs_dict["desired_goal"],
            )
            obs_dict, _, done, info = env.step(action)
            if done or info.get("is_success", False):
                break
        samples.append(steps_taken)

    return np.array(samples)


def sample_episode_performance(
    policy,
    env: Union[GCSLToGym, offline_env.OfflineEnv],
    env_name: str,
    max_episode_steps: int,
    traj_samples: int = 2000,
    kitchen_subtask: str = "all",
) -> np.ndarray:
    """Sample the performance metric that matches the kind of environment.

    * Env names starting with "kitchen": cumulative reward, conditioned either
      on dynamically updated goals (``kitchen_subtask == "dynamic"``) or on the
      fixed goal of the requested subtask.
    * Other names in ``d4rl_env_names``: cumulative reward with default goals.
    * Anything else: hitting times, capped at ``max_episode_steps``.
    """
    if env_name.startswith("kitchen"):
        if kitchen_subtask == "dynamic":
            return sample_cumulative_reward(
                policy,
                env,
                trajectory_samples=traj_samples,
                dynamic_kitchen_goal=True,  # pytype: disable=bad-return-type
            )

        goal, _ = get_kitchen_goal(env, subtask=kitchen_subtask)
        # One copy of the same goal per sampled trajectory.
        repeated_goals = np.repeat(goal[np.newaxis], traj_samples, axis=0)
        return sample_cumulative_reward(
            policy,
            env,
            goals=repeated_goals,
            trajectory_samples=traj_samples,  # pytype: disable=bad-return-type
        )

    if env_name in d4rl_env_names:
        return sample_cumulative_reward(
            policy,
            env,
            trajectory_samples=traj_samples,  # pytype: disable=bad-return-type
        )

    return sample_hitting_times(
        policy,
        env,
        max_episode_steps,
        hitting_time_samples=traj_samples,
    )


def get_reward_targets(
    env: Union[offline_env.OfflineEnv, gym.wrappers.TimeLimit],
    env_name: str,
    reward_fractions: List[float],
    targets: str = "of expert",
    average_reward_to_go: bool = True,
) -> List[float]:
    """Convert reward fractions into absolute reward targets.

    Each fraction is mapped linearly onto ``[reward_min, reward_max]``, where
    the bounds depend on ``targets``.

    Args:
        env: The env under consideration.
        env_name: The name of the env.
        reward_fractions: Which reward fractions to convert into cumulative reward.
        targets: 'of demos' takes the bounds from the min / max reward-to-go of
            ``env.get_dataset()``; 'of expert' uses 0 and 1 for names containing
            "antmaze" and the D4RL reference scores otherwise.
        average_reward_to_go: If True, use average reward per timestep. Else, use
            cumulative reward for the whole trajectory.

    Returns:
        A list of reward targets that correspond to the given reward fractions.

    Raises:
        ValueError: If an invalid option for targets is specified.
    """
    if targets not in ("of demos", "of expert"):
        raise ValueError("targets must be 'of demos' or 'of expert'")

    if targets == "of demos":
        reward_to_go = dataset.reward_to_go(
            env.get_dataset(),
            average=average_reward_to_go,
        )
        reward_min, reward_max = np.min(reward_to_go), np.max(reward_to_go)
    else:
        if "antmaze" in env_name:
            reward_min, reward_max = 0, 1
        else:
            reward_min = infos.REF_MIN_SCORE[env_name]
            reward_max = infos.REF_MAX_SCORE[env_name]
        if average_reward_to_go:
            # Reference scores are per episode; convert them to per-step values.
            reward_min /= env._max_episode_steps
            reward_max /= env._max_episode_steps

    reward_span = reward_max - reward_min
    return [reward_min + reward_span * frac for frac in reward_fractions]


def eval_reward_conditioning(
    policy: Union[policies.RvS, Callable[[np.ndarray, np.ndarray], np.ndarray]],
    env: offline_env.OfflineEnv,
    env_lambda: Callable[[], offline_env.OfflineEnv],
    env_name: str,
    reward_fractions: List[float],
    reward_targets: Optional[List[float]],
    discount_factor: float,
    trajectory_samples: int = 200,
    average_reward_to_go: bool = True,
    targets: str = "of expert",
    return_quantile=None,
    num_cpu: int = 1,
    deterministic: bool = True,
    return_model=None,
    reward_preprocessing: str = "none",
) -> List[np.ndarray]:
    """Evaluate RvS-R (reward-conditioned) learning for each reward target.

    Args:
        policy: Policy to evaluate.
        env: Environment used for single-process evaluation.
        env_lambda: Env factory used when ``num_cpu > 1``.
        env_name: Must be one of ``d4rl_env_names``.
        reward_fractions: Converted with ``get_reward_targets`` when
            ``reward_targets`` is None.
        reward_targets: Explicit targets; take precedence over the fractions.
        discount_factor: Forwarded to the running-target samplers.
        trajectory_samples: Episodes per target.
        average_reward_to_go: If True, each target is held fixed for the whole
            episode via ``evaluate_goals``; the remaining options below are
            not used on that path. If False, the target is updated every step.
        targets: Forwarded to ``get_reward_targets``.
        return_quantile: Forwarded to the running-target samplers.
        num_cpu: Use the vectorized sampler when greater than 1.
        deterministic: Forwarded to the running-target samplers.
        return_model: Forwarded to the running-target samplers.
        reward_preprocessing: Forwarded to the running-target samplers.

    Returns:
        One array of episode returns per reward target.

    Raises:
        NotImplementedError: If ``env_name`` is not a known D4RL env.
    """
    if env_name not in d4rl_env_names:
        raise NotImplementedError

    if reward_targets is None:
        reward_targets = get_reward_targets(
            env,
            env_name,
            reward_fractions,
            targets=targets,
            average_reward_to_go=average_reward_to_go
        )
        print("Reward targets:", reward_targets)

    if average_reward_to_go:
        # A constant target is just a one-element goal vector.
        constant_goals = [np.array([target]) for target in reward_targets]
        return evaluate_goals(
            policy,
            env,
            constant_goals,
            trajectory_samples=trajectory_samples,
        )

    shared_kwargs = dict(
        trajectory_samples=trajectory_samples,
        deterministic=deterministic,
        reward_preprocessing=reward_preprocessing,
    )
    reward_vecs = []
    for reward_target in reward_targets:
        if num_cpu > 1:
            reward_vec = sample_with_reward_conditioning_multicpu(
                policy,
                env_lambda,
                reward_target,
                return_model,
                return_quantile,
                discount_factor,
                num_cpu=num_cpu,
                **shared_kwargs,
            )
        else:
            reward_vec = sample_with_reward_conditioning(
                policy,
                env,
                reward_target,
                return_model,
                return_quantile,
                discount_factor,
                **shared_kwargs,
            )
        reward_vecs.append(reward_vec)

    return reward_vecs


def eval_d4rl_antmaze(
    policy: Union[policies.RvS, Callable[[np.ndarray, np.ndarray], np.ndarray]],
    env: ant.AntMazeEnv,
    trajectory_samples: int = 200,
) -> np.ndarray:
    """Evaluate cumulative reward in AntMaze, conditioning on the env's target goal.

    The env must use the "sparse" reward type. During every episode the first
    two observation entries are asserted to match ``env.get_xy()``.

    Returns:
        Array with the cumulative reward of each of the ``trajectory_samples``
        episodes.
    """
    assert env.reward_type == "sparse"

    episode_returns = []
    for _ in tqdm.trange(trajectory_samples, desc="Sampling trajectory rewards"):
        observation = env.reset()
        # The target is read after reset and kept for the whole episode.
        goal = np.array(env.target_goal)
        episode_return = 0
        done = False
        while not done:
            assert np.all(np.isclose(observation[:2], env.get_xy()))
            action = get_action_from_policy(policy, observation, goal)
            observation, reward, done, _ = env.step(action)
            episode_return += reward
        episode_returns.append(episode_return)

    return np.array(episode_returns)


def random_hitting_times(
    env: GCSLToGym,
    max_episode_steps: int,
    hitting_time_samples: int = 2000,
) -> np.ndarray:
    """Sample hitting times of a policy that acts uniformly at random.

    Serves as a reference point; the mean hitting time is printed and the full
    array of samples is returned.
    """

    def _random_policy(obs, goal):
        # Ignores its inputs; every action is a fresh sample.
        return env.action_space.sample()

    hitting_times = sample_hitting_times(
        _random_policy,
        env,
        max_episode_steps,
        hitting_time_samples=hitting_time_samples,
    )
    mean_hitting_time = hitting_times.mean()
    print(f"Random hitting time: {mean_hitting_time:.3f}")

    return hitting_times


def seed_env(env: gym.Env, seed: Optional[int]) -> None:
    """Seed the environment together with its action and observation spaces.

    Args:
        env: The environment to seed.
        seed: The seed to apply. If None, one seed is drawn with
            ``np.random.randint(2**31 - 1)`` and shared by the env and both
            spaces.
    """
    if seed is None:
        seed = np.random.randint(2**31 - 1)

    for seedable in (env, env.action_space, env.observation_space):
        seedable.seed(seed)


def create_env(
    env_name: str,
    max_episode_steps: int,
    discretize: bool,
    discrete_clusters=256,
    seed: Optional[int] = None,
) -> Union[GCSLToGym, offline_env.OfflineEnv]:
    """Create and seed a GCSL or D4RL environment by name.

    Args:
        env_name: The name of the environment.
        max_episode_steps: The number of episode steps before the time limit runs
            out. Only forwarded to ``create_gcsl_env``; D4RL envs come straight
            from ``gym.make``.
        discretize: For GCSL envs, forwarded to ``create_gcsl_env``. For D4RL
            envs, wrap the env in ``util.DiscretizeWrapper`` when True.
        discrete_clusters: Passed as ``k`` to ``util.DiscretizeWrapper``.
        seed: A random seed for the environment.

    Returns:
        The created environment.

    Raises:
        ValueError: If the environment name is not in GCSL or D4RL.
    """
    is_gcsl = env_name in envs.env_names
    if not is_gcsl and env_name not in d4rl_env_names:
        raise ValueError("Please provide a GCSL or D4RL env name.")

    if is_gcsl:
        env = create_gcsl_env(env_name, max_episode_steps, discretize)
    else:
        env = gym.make(env_name)
        if discretize:
            env = util.DiscretizeWrapper(env, k=discrete_clusters)

    seed_env(env, seed)
    return env


def create_gcsl_env(
    env_name: str,
    max_episode_steps: int,
    discretize: bool,
) -> GCSLToGym:
    """Create a GCSL environment wrapped for Gym-style goal-conditioned use.

    The raw GCSL env is wrapped in ``GCSLToGym`` (using the env's
    "goal_threshold" parameter) and then in ``gym.wrappers.TimeLimit``. The env
    parameters, the wrapped env and its spaces are printed along the way.

    Args:
        env_name: Name understood by ``gcsl.envs``.
        max_episode_steps: Step limit given to the ``TimeLimit`` wrapper.
        discretize: Currently has no effect (see the disabled code below).
    """
    raw_env = envs.create_env(env_name)
    env_params = envs.get_env_params(env_name)
    print(env_params)

    # if discretize:
    #     env = variants.discretize_environment(env, env_params)

    goal_env_wrapper = GCSLToGym(
        raw_env, goal_threshold=env_params["goal_threshold"])
    wrapped_env = gym.wrappers.TimeLimit(
        goal_env_wrapper, max_episode_steps=max_episode_steps)

    print(wrapped_env)
    print(wrapped_env.observation_space)
    print(wrapped_env.action_space)

    return wrapped_env


class GCSLToGym(gym.ObservationWrapper):
    """Expose a gcsl.envs.goal_env.GoalEnv through a gym.GoalEnv-style interface.

    Observations become dicts with "observation", "achieved_goal" and
    "desired_goal" entries.

    Note that this wrapper deviates slightly from the intended use of the
    gym.GoalEnv interface: here achieved_goal and desired_goal are the
    _observations_ of the goal, whereas gym.GoalEnv intends them to be the
    _states_ of the goal.
    """

    def __init__(self, env: goal_env.GoalEnv, goal_threshold: float = 0.05):
        """Wrap a GCSL env into a gym.GoalEnv.

        Args:
            env: The GCSL env to wrap.
            goal_threshold: The distance required to reach the goal.
        """
        super().__init__(env)
        self.observation_space = gym.spaces.Dict(
            {
                "observation": env.observation_space,
                "achieved_goal": env.goal_space,
                "desired_goal": env.goal_space,
            },
        )
        self.goal_threshold = goal_threshold

        # Filled in by reset() / observation().
        self.current_state = None
        self.desired_goal_state = None
        self.desired_goal = None

    def reset(self) -> Dict[str, np.ndarray]:
        """Sample a new desired goal, then reset the wrapped environment."""
        goal_state = self.env.sample_goal()
        self.desired_goal_state = goal_state
        self.desired_goal = self.env.extract_goal(goal_state)
        return super().reset()

    def step(
        self,
        action: Union[int, np.ndarray],
    ) -> Tuple[
        Dict[str, np.ndarray],
        Union[int, np.int64, np.float64],
        np.bool_,
        Dict[str, Any],
    ]:
        """Step the wrapped env and report done once the goal is reached.

        The wrapped env's own done flag is discarded; done is True when the
        distance between the current state and the desired goal state is below
        ``goal_threshold``.
        """
        observation, reward, _, info = super().step(action)
        # current_state is refreshed in observation(), which the parent
        # wrapper is expected to invoke while handling step().
        goal_reached = (
            self.env.goal_distance(self.current_state, self.desired_goal_state)
            < self.goal_threshold
        )
        return observation, reward, goal_reached, info

    def observation(self, state: np.ndarray) -> Dict[str, np.ndarray]:
        """Remember the raw state and turn it into the goal-env observation dict."""
        self.current_state = state
        return {
            "observation": self.env.observation(state),
            "achieved_goal": self.env.extract_goal(state),
            "desired_goal": self.desired_goal,
        }
