from .dark_room import DarkRoom, Cfg as DRCfg, refill_cfg as DR_refill_cfg
from .dark_key_to_door import (
    DarkKeyToDoor,
    Cfg as DKTDCfg,
    refill_cfg as DKTD_refill_cfg,
)
from typing import Literal
import jax.numpy as jnp
from .idp import InvertedPendulum, refill_cfg as CP_refill_cfg
from .reacher import Reacher, refill_cfg as R_refill_cfg, Cfg as ReacherCfg
from envs.wrappers import VmapWrapper, TimeoutWrapper
from utils.jax_util import extract_in_out_trees, build_merge_trees
import jax.random as jrnd
import jax


def get_env_cfg(name: str, env_kind: Literal["eval", "train"], cfg, dcfg):
    """Build the environment-specific config object for environment `name`.

    Args:
        name: One of "dark_room", "dark_key_to_door", "cartpole", "reacher".
        env_kind: "eval" selects the ``eval_*`` entries of `dcfg`; any other
            value selects the ``train_*`` entries.
        cfg: Run config; ``cfg.env.grid_size`` and ``cfg.env.kind`` are read
            for the grid environments, ``cfg.env.kind`` for "reacher".
        dcfg: Mapping holding the goal / candidate collections
            ("eval_goals", "train_goals", "eval_candidates",
            "train_candidates").

    Returns:
        A `DRCfg`, `DKTDCfg` or `ReacherCfg` instance, or ``None`` for
        "cartpole".

    Raises:
        ValueError: If `name` is not a supported environment.
    """
    is_eval = env_kind == "eval"

    if name == "cartpole":
        # Cartpole has no config object; this returns before the print below.
        return None

    if name == "dark_room":
        goal_source = dcfg["eval_goals"] if is_eval else dcfg["train_goals"]
        env_cfg = DRCfg(
            size=cfg.env.grid_size,
            goal_candidates=jnp.array(list(goal_source)),
            kind=cfg.env.kind,
        )
    elif name == "dark_key_to_door":
        if is_eval:
            candidate_source = dcfg["eval_candidates"]
        else:
            candidate_source = dcfg["train_candidates"]
        env_cfg = DKTDCfg(
            size=cfg.env.grid_size,
            all_candidates=jnp.array(list(candidate_source)),
            variant=cfg.env.kind,
        )
    elif name == "reacher":
        env_cfg = ReacherCfg(variant=cfg.env.kind)
    else:
        raise ValueError(f"unsupported env: {name}")

    print(f"refilled env cfg: {env_cfg}")
    return env_cfg


def get_env_refill_cfg(name: str):
    """Return the ``refill_cfg`` object imported for environment `name`.

    Args:
        name: One of "dark_room", "dark_key_to_door", "cartpole", "reacher".

    Returns:
        The ``refill_cfg`` imported from the matching environment module
        ("cartpole" maps to the one imported from ``.idp``).

    Raises:
        ValueError: If `name` is not a supported environment.
    """
    # Guard-style returns keep the lookup lazy: only the matched name is read.
    if name == "dark_room":
        return DR_refill_cfg
    if name == "dark_key_to_door":
        return DKTD_refill_cfg
    if name == "cartpole":
        return CP_refill_cfg
    if name == "reacher":
        return R_refill_cfg
    raise ValueError(f"unsupported env: {name}")


def make_env(name: str, cfg, dcfg, kind: Literal["eval", "train"], key):
    """Construct a batched environment for `name` together with its state.

    The base environment is wrapped in `TimeoutWrapper` (limit
    ``cfg.env.episode_length``) and then `VmapWrapper`; its ``step`` and
    ``reset`` attributes are replaced by `jax.jit`-compiled versions.

    Args:
        name: One of "dark_room", "dark_key_to_door", "cartpole", "reacher".
        cfg: Run config; reads ``cfg.env.episode_length``,
            ``cfg.num_eval_envs`` and ``cfg.num_train_envs`` (plus whatever
            `get_env_cfg` reads).
        dcfg: Forwarded to `get_env_cfg`.
        kind: "eval" or "train"; selects the number of environments and,
            for "cartpole"/"reacher", which extra array is passed to
            ``make``.
        key: JAX PRNG key; split into one key per environment plus one
            sampling key.

    Returns:
        Tuple ``(env, env_state, env_tree)``: the wrapped environment, the
        result of its ``make`` call, and the first tree returned by
        `extract_in_out_trees`.

    Raises:
        ValueError: If `name` is not a supported environment.
    """
    # Resolve the constructor first so an unknown name fails before any work.
    if name == "dark_room":
        build = DarkRoom
    elif name == "dark_key_to_door":
        build = DarkKeyToDoor
    elif name == "cartpole":

        def build(_ignored_cfg):
            # InvertedPendulum takes no config; the argument is discarded.
            return InvertedPendulum()

    elif name == "reacher":
        build = Reacher
    else:
        raise ValueError(f"unsupported env: {name}")

    base_env = build(get_env_cfg(name, kind, cfg, dcfg))
    timed_env = TimeoutWrapper(base_env, cfg.env.episode_length)

    _env_tree, out_tree = extract_in_out_trees(timed_env.gen_tree())
    env = VmapWrapper(timed_env, _env_tree, out_tree)
    env.step = jax.jit(env.step)
    env.reset = jax.jit(env.reset)

    # One PRNG key per environment; the remainder is split once more to get
    # the key used for eval-time sampling below.
    if kind == "eval":
        n_envs = cfg.num_eval_envs
    else:
        n_envs = cfg.num_train_envs
    key, *env_keys = jrnd.split(key, 1 + n_envs)
    _, sample_key = jrnd.split(key)

    if name not in ("cartpole", "reacher"):
        return env, env.make(jnp.array(env_keys)), _env_tree

    # cartpole / reacher take a second (num_envs, 2) array in `make`.
    # NOTE(review): presumably a per-env [low, high] range — confirm against
    # the environments' `make` implementations.
    if kind == "train":
        extra_keys = jnp.ones((cfg.num_train_envs, 2)) * jnp.array([0.9, 1.1])
    else:
        # Per-env draw from {0, 1}, repeated across both columns: nonzero
        # rows get the first pair, zero rows the second.
        coin = jrnd.randint(sample_key, (cfg.num_eval_envs, 1), 0, 2)
        selector = jnp.tile(coin, (1, 2))
        first_pair = jnp.ones((cfg.num_eval_envs, 2)) * jnp.array([0.80, 0.85])
        second_pair = jnp.ones((cfg.num_eval_envs, 2)) * jnp.array([1.15, 1.2])
        extra_keys = jnp.where(selector, first_pair, second_pair)

    env_state = env.make(jnp.array(env_keys), extra_keys)
    return env, env_state, _env_tree
