from __future__ import annotations

from typing import Literal

import gymnasium as gym
import numpy as np
from gymnasium.wrappers import TimeLimit
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE

from pc_rl.envs.metaworld.pointcloud_obs import MetaworldPointCloudObservations
from pc_rl.envs.wrappers import AxisAlignedCrop, VoxelGridDownsampling

from . import MetaworldAddRenderingToInfoWrapper


def build(
    task: str,
    observation_type: Literal["pointcloud"],
    max_episode_steps: int,
    max_expected_num_points: int | None = None,
    voxel_grid_size: float | None = None,
    add_rendering_to_info: bool = False,
) -> gym.Env:
    """Build a goal-observable Metaworld v2 env with point-cloud observations.

    Args:
        task: Task name, used as a key into
            ``ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE`` to select the env class.
        observation_type: Observation modality. Only ``"pointcloud"`` is
            supported.
        max_episode_steps: Desired episode length. Values below the 500-step
            limit hard-coded into the Metaworld envs are enforced with a
            ``TimeLimit`` wrapper. Larger values cannot be honoured (the
            hard-coded limit still applies), so a warning is emitted.
        max_expected_num_points: Forwarded unchanged to
            ``MetaworldPointCloudObservations``.
        voxel_grid_size: If not None, a ``VoxelGridDownsampling`` step with
            this voxel size is appended after the crop.
        add_rendering_to_info: If True, wrap the env with
            ``MetaworldAddRenderingToInfoWrapper``.

    Returns:
        The wrapped environment.

    Raises:
        NotImplementedError: If ``observation_type`` is not ``"pointcloud"``.
            This is checked before any env is instantiated.
    """
    import warnings

    # Default time limit hard-coded into the metaworld envs.
    metaworld_max_path_length = 500

    # Validate up front so that an unsupported observation type does not
    # instantiate an env that is then never closed.
    if observation_type != "pointcloud":
        raise NotImplementedError(observation_type)

    if max_episode_steps > metaworld_max_path_length:
        warnings.warn(
            f"max_episode_steps={max_episode_steps} exceeds the "
            f"{metaworld_max_path_length}-step time limit hard-coded into the "
            f"Metaworld envs; the hard-coded limit will still apply.",
            stacklevel=2,
        )

    env_cls = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[task]
    env = env_cls()

    # The x/y bounds and the upper z bound are +/-1000, i.e. presumably wide
    # enough that only the lower z bound of 0.01 actually removes points.
    # NOTE(review): likely meant to strip the floor/table plane -- confirm.
    post_processing = [
        AxisAlignedCrop(
            min_bound=np.array([-1000, -1000, 0.01]),
            max_bound=np.array([1000, 1000, 1000]),
        )
    ]
    if voxel_grid_size is not None:
        post_processing.append(
            VoxelGridDownsampling(voxel_grid_size=voxel_grid_size)
        )

    env = MetaworldPointCloudObservations(
        env,
        max_expected_num_points=max_expected_num_points,
        # only use topview, corner, corner2, and corner3
        exclude_cameras=["behindGripper", "gripperPOV"],
        depth_cutoff=50,
        post_processing=post_processing,
    )

    if add_rendering_to_info:
        env = MetaworldAddRenderingToInfoWrapper(env)

    # Only wrap when the requested limit is stricter than the built-in one.
    if max_episode_steps < metaworld_max_path_length:
        env = TimeLimit(env, max_episode_steps=max_episode_steps)

    return env
