import functools
from typing import Any, Dict, Optional, Type, Union

import dmc2gym
import gymnasium as gym

import torch
from dm_control.suite.wrappers import pixels

from lambda_ac.env.dmc import ExtendedTimeStepWrapper, FrameStackWrapper, drq_dmc_make
from lambda_ac.env.wrappers import (
    BraxWrapper,
    DMGymWrapper,
    NewGymWrapper,
    PyTorchWrapper,
)

# Touch CUDA through torch at import time so that torch initialises its CUDA
# context before jax is imported anywhere in the process.
# NOTE(review): presumably this guards against jax claiming / preallocating the
# GPU before torch gets to it — confirm the exact failure mode. The tensor is
# bound to the module-level name `v`, which keeps it alive for the process.
if torch.cuda.is_available():
    v = torch.ones(1, device="cuda")  # init torch cuda before jax


def setup_environment(
    env_type: str,
    env_id: str,
    n_envs: int,
    device: str = "cpu",
    seed: Optional[int] = None,
    **kwargs,
):
    """Build an environment of the requested family, wrapped for PyTorch use.

    Args:
        env_type: One of ``"brax"``, ``"gym"``, ``"robosuite"``, ``"dmc"`` or
            ``"visual_dmc"``.
        env_id: Environment identifier. For ``"dmc"`` it is parsed as
            ``"<domain>-<task>[-...]"``.
        n_envs: Number of parallel environments. ``"brax"`` forwards it as the
            batch size; ``"gym"``, ``"robosuite"`` and ``"dmc"`` only support
            exactly 1; ``"visual_dmc"`` ignores it.
        device: Torch device handed to the Brax / PyTorch wrappers.
        seed: Optional seed forwarded to the underlying constructors.
        **kwargs: Family-specific options. ``"robosuite"`` reads ``keys``;
            ``"dmc"`` forwards everything to ``dmc2gym.make``;
            ``"visual_dmc"`` requires ``frame_stack`` and ``action_repeat``.

    Returns:
        The constructed, wrapped environment.

    Raises:
        ValueError: If ``env_type`` is unknown, or if ``n_envs != 1`` for a
            family that only supports a single environment.
    """
    # Families that build and wrap their own environments return directly.
    if env_type == "brax":
        return setup_brax_env(env_id, n_envs, device, seed=seed)
    if env_type == "visual_dmc":
        return drq_dmc_make(
            env_id, kwargs["frame_stack"], kwargs["action_repeat"], seed
        )
    if env_type not in ("gym", "robosuite", "dmc"):
        raise ValueError(f"Unknown environment type: {env_type}")

    # The remaining families are single-env only (vectorisation through a
    # VecEnv is currently disabled). Validate before constructing anything so
    # a bad n_envs does not spin up a simulator that is immediately discarded.
    if n_envs != 1:
        raise ValueError(
            f"n_envs must be 1 for env_type {env_type!r} "
            f"(vectorized environments are not supported), got {n_envs}"
        )

    if env_type == "gym":
        base_env = NewGymWrapper(gym.make(env_id), seed)
    elif env_type == "robosuite":
        keys = kwargs.pop("keys", None)
        base_env = setup_robosuite_env(env_id, n_envs, device, seed, keys=keys)
    else:  # "dmc"
        domain, task, *_ = env_id.split("-")
        base_env = dmc2gym.make(
            domain_name=domain, task_name=task, seed=seed, **kwargs
        )
        base_env = NewGymWrapper(base_env, seed)
    return PyTorchWrapper(base_env, device=device)


def setup_brax_env(
    env_id: str, n_envs: int, device: str = "cpu", seed: Optional[int] = None
):
    """Create a batched Brax environment that exchanges torch tensors.

    The Brax env is registered with gymnasium under ``"brax-<env_id>-v0"``
    (once per process). It is instantiated with ``batch_size=n_envs``,
    converted to torch via Brax's ``JaxToTorchWrapper`` and finally wrapped
    in ``BraxWrapper``.

    Args:
        env_id: Brax environment name (e.g. ``"ant"``).
        n_envs: Number of parallel environments (the Brax batch size).
        device: Torch device used by the torch conversion and ``BraxWrapper``.
        seed: Optional seed forwarded to the Brax gym-env constructor; when
            ``None`` the constructor's own default is used.

    Returns:
        A ``BraxWrapper`` around the torch-converted Brax gym environment.
    """
    # Imported lazily (as robosuite is) so brax/jax are only required when a
    # Brax environment is actually requested; these names were previously
    # referenced without ever being imported, raising NameError.
    from brax import envs as brax_envs
    from brax.envs import to_torch

    env_name = "brax-" + env_id + "-v0"
    # Re-registering an existing id is an error in legacy gym and an
    # "overriding environment" warning in gymnasium; register on first use only.
    if env_name not in gym.registry:
        entry_point = functools.partial(brax_envs.create_gym_env, env_name=env_id)
        gym.register(env_name, entry_point=entry_point)

    # Extra keyword arguments to gym.make are passed on to the entry point.
    make_kwargs = {"batch_size": n_envs}
    if seed is not None:
        make_kwargs["seed"] = seed
    # NOTE(review): create_gym_env returns a legacy-gym env; confirm that
    # gymnasium.make's env checker accepts it for the brax version in use.
    gym_env = gym.make(env_name, **make_kwargs)

    # Convert jax arrays at the env boundary into torch tensors on `device`.
    gym_env = to_torch.JaxToTorchWrapper(gym_env, device=device)  # type: ignore

    return BraxWrapper(gym_env, device=device)


def setup_robosuite_env(env_id, n_envs, device, seed, keys=None) -> gym.Env:
    """Create a single robosuite task wrapped as a gym environment.

    Args:
        env_id: ``"<robot>-<task>[-...]"``, e.g. ``"Panda-Lift-v0"``. The first
            two dash-separated fields select the robot and the task; any
            further fields are ignored.
        n_envs: Unused; accepted for signature parity with the other setup
            helpers (exactly one environment is created).
        device: Unused here; device placement is left to the caller.
        seed: Unused; the robosuite constructor is not seeded here.
        keys: Observation keys passed unchanged to robosuite's ``GymWrapper``.

    Returns:
        A robosuite ``GymWrapper`` around the created environment.

    Raises:
        ValueError: If ``env_id`` does not contain both a robot and a task name.
    """
    # Parse before importing robosuite so a malformed id fails fast with a
    # clear message, without needing the robosuite dependency at all.
    # Extra trailing fields are tolerated, consistent with the dmc id parsing.
    parts = env_id.split("-")
    if len(parts) < 2:
        raise ValueError(
            f"robosuite env_id must look like '<robot>-<task>[-...]', got {env_id!r}"
        )
    robot_name, task_name = parts[0], parts[1]

    import robosuite as suite
    from robosuite import load_controller_config
    from robosuite.wrappers import GymWrapper

    # load OSC controller to use for all environments
    controller = load_controller_config(default_controller="OSC_POSE")

    # these arguments are the same for all envs
    config = {
        "controller_configs": controller,
        "horizon": 500,
        "control_freq": 20,
        "reward_shaping": True,
        "reward_scale": 1.0,
        "use_camera_obs": False,
        "ignore_done": False,
        "hard_reset": False,
        # No offscreen renderer: speeds up training. Enable a renderer if you
        # want to visualise rollouts.
        "has_offscreen_renderer": False,
    }

    env = suite.make(
        env_name=task_name,
        robots=robot_name,
        **config,
    )
    return GymWrapper(env, keys)


# def make_vec_env(
#     env_id: gym.Env,
#     n_envs: int = 1,
#     seed: Optional[int] = None,
#     start_index: int = 0,
#     vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
#     vec_env_kwargs: Optional[Dict[str, Any]] = None,
# ) -> VecEnv:
#     """
#     Create a ``VecEnv`` from an already-constructed environment.
#     By default it uses a ``DummyVecEnv`` which is usually faster
#     than a ``SubprocVecEnv``.
#
#     :param env_id: an environment instance (despite the name, not an ID or a
#         class); every rank's factory returns this same instance
#     :param n_envs: the number of environments you wish to have in parallel
#     :param seed: the initial seed for the random number generator; each env
#         is seeded with ``seed + start_index + i``
#     :param start_index: offset added to each env's rank before seeding
#     :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
#     :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
#     :return: The wrapped environment
#     """
#
#     vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
#
#     def make_env(rank):
#         def _init():
#             env = env_id
#             if seed is not None:
#                 env.seed(seed + rank)
#                 env.action_space.seed(seed + rank)
#             return env
#
#         return _init
#
#     if vec_env_cls is None:
#         vec_env_cls = DummyVecEnv
#
#     return vec_env_cls(
#         [make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs
#     )
