import jax
import jax.numpy as jnp
import flax.struct as fstruct
import jax.random as jrnd
import chex
from utils.jax_util import GenericEnvState
import flax.core as fcore
from typing import Literal
from dataclasses import dataclass
from omegaconf import DictConfig
from typing import Dict, Tuple, Any, cast
from itertools import product
import math
import numpy as np


@dataclass
class Cfg:
    """Constructor settings for ``DarkKeyToDoor``."""

    # Side length of the square board; DarkKeyToDoor starts the agent at
    # (size // 2, size // 2).
    size: int
    # Candidate cells for the key and door. DarkKeyToDoor.make samples two of
    # them without replacement along the first axis; step asserts each sampled
    # cell has the same shape as the agent position (presumably an (N, 2) int
    # array — TODO confirm with the code that builds this).
    all_candidates: jnp.ndarray
    # Reward scheme used by DarkKeyToDoor.step.
    variant: Literal["normal", "dense", "quick"]


@fstruct.dataclass
class Sys:
    """Hidden simulator state of ``DarkKeyToDoor`` (not part of the observation except pos/has_key)."""

    # Board cell holding the key (a grid position, NOT a JAX PRNG key).
    key: jnp.ndarray
    # Board cell holding the door.
    door: jnp.ndarray
    # Agent position; clipped to [0, size - 1] per coordinate on every step.
    pos: jnp.ndarray
    # 0 until the agent steps onto the key cell, 1 afterwards.
    has_key: jnp.ndarray


# Environment state specialised to this module's ``Sys`` payload. GenericEnvState
# comes from utils.jax_util; the fields used in this file are sys, obs, acts,
# rwds, terms, truncs and infos.
EnvState = GenericEnvState[Sys]


class DarkKeyToDoor:
    """Grid world where the agent must reach a hidden key cell, then a door cell.

    The board is ``size`` x ``size`` and the agent always starts at
    ``(size // 2, size // 2)``. ``make`` samples two rows of
    ``cfg.all_candidates`` (without replacement) as the door and key cells;
    ``reset`` keeps those cells and only moves the agent back to the start
    without the key. The observation is ``[pos_0, pos_1, has_key]``, so the
    key and door cells are never observed directly.

    Reward per ``cfg.variant``:

    * ``"normal"``: 1 on picking up the key, 1 on reaching the door while
      holding the key, 0 otherwise.
    * ``"dense"``: negative Manhattan distance from the new position to the
      key (if it was not held before the step) or to the door (if it was).
    * ``"quick"``: -1 on every step.

    In every variant, reaching the door while holding the key sets ``terms``
    to 1.
    """

    # Supported values of ``cfg.variant``.
    _KINDS = ("normal", "dense", "quick")
    # Position offsets for acts 1..4 (left, up, right, down); act 0 and any
    # other value leave the position unchanged.
    _MOVES = ((-1, 0), (0, 1), (1, 0), (0, -1))

    def __init__(
        self,
        cfg: Cfg,
    ) -> None:
        """Read board size, key/door candidates and reward variant from ``cfg``.

        Raises:
            ValueError: if ``cfg.variant`` is not one of ``_KINDS``.
        """
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if cfg.variant not in self._KINDS:
            raise ValueError(f"unsupported variant: {cfg.variant!r}")
        self.size = cfg.size
        self.init_pos = jnp.array([self.size // 2, self.size // 2], dtype=jnp.int32)
        self.all_candidates = cfg.all_candidates
        self.kind: Literal["normal", "dense", "quick"] = cfg.variant

    def gen_tree(self):
        """Return an ``EnvState`` with this env's tree structure and every leaf a scalar 0.

        NOTE(review): presumably used by callers as a structure template (e.g.
        to build per-leaf axis specs); confirm against the callers.
        """
        return EnvState(
            sys=Sys(
                key=jnp.array(0),
                pos=jnp.array(0),
                door=jnp.array(0),
                has_key=jnp.array(0),
            ),
            obs=jnp.array(0),
            acts=jnp.array(0),
            rwds=jnp.array(0),
            terms=jnp.array(0),
            truncs=jnp.array(0),
            infos=fcore.freeze(dict()),
        )

    def make(self, rng_key: jnp.ndarray):
        """Sample the door and key cells and return a freshly reset state.

        Door and key are two rows of ``self.all_candidates`` drawn without
        replacement. The other half of the split ``rng_key`` is handed to
        ``reset``.
        """
        rng_key, sample_rnd = jrnd.split(rng_key)

        door, _key = jrnd.choice(sample_rnd, self.all_candidates, (2,), replace=False)
        sys = Sys(
            key=_key,
            door=door,
            has_key=jnp.array(0),
            pos=self.init_pos,
        )
        # obs here is only a correctly shaped placeholder; reset recomputes it.
        _state = EnvState(
            sys=sys,
            obs=jnp.zeros_like(self._get_obs(sys), dtype=jnp.int32),
            acts=jnp.array(0),
            rwds=jnp.zeros(()),
            terms=jnp.array(0),
            truncs=jnp.array(0),
            infos=fcore.freeze(dict()),
        )

        return self.reset(rng_key, _state)

    def reset(self, keys: jax.Array, env_state: EnvState):
        """Move the agent back to the start and drop the key.

        The door/key cells are kept, ``terms``/``truncs`` are zeroed and the
        observation is recomputed. ``keys`` is currently unused, and ``acts``
        / ``rwds`` keep whatever values ``env_state`` already held.
        """
        _new_sys = env_state.sys.replace(pos=self.init_pos, has_key=jnp.array(0))
        return env_state.replace(
            sys=_new_sys,
            obs=self._get_obs(_new_sys),
            terms=jnp.zeros_like(env_state.terms),
            truncs=jnp.zeros_like(env_state.truncs),
        )

    def step(self, env_state: EnvState, act: jnp.ndarray):
        """Apply one action and return the next state.

        Acts: 0: noop, 1: left, 2: up, 3: right, 4: down. Any other value also
        leaves the position unchanged. The new position is clipped to the board.

        Landing on the key cell without the key picks it up; landing on the
        door cell while holding the key sets ``terms`` to 1. ``truncs`` is not
        modified here.
        """
        _pos = env_state.sys.pos

        chex.assert_size(act, 1)

        # Compare the act against each id in turn so any unexpected value
        # falls through to the zero offset.
        _delta = jnp.zeros_like(_pos)
        for _act_id, _move in enumerate(self._MOVES, start=1):
            _delta = jnp.where(
                act == _act_id, jnp.array(_move, dtype=jnp.int32), _delta
            )
        chex.assert_shape(_delta, (2,))

        _new_pos = jnp.clip(_pos + _delta, 0, self.size - 1)

        _door = env_state.sys.door
        _key = env_state.sys.key
        _has_key = env_state.sys.has_key
        # Checked for every variant: a shape mismatch would silently broadcast
        # in the equality tests below.
        chex.assert_equal_shape((_new_pos, _door, _key))

        # Mutually exclusive events: one needs has_key == 0, the other has_key == 1.
        _picked_up = jnp.logical_and(jnp.all(_new_pos == _key), _has_key == 0)
        _opened = jnp.logical_and(jnp.all(_new_pos == _door), _has_key == 1)

        rwd = self._reward(_new_pos, _door, _key, _has_key, _picked_up, _opened)
        term = jnp.where(_opened, jnp.array(1), jnp.array(0))
        # Opening the door already requires has_key == 1, so only a pick-up
        # changes the flag.
        _new_has_key = jnp.where(_picked_up, jnp.array(1), _has_key)

        _new_sys = env_state.sys.replace(pos=_new_pos, has_key=_new_has_key)

        return env_state.replace(
            sys=_new_sys,
            obs=self._get_obs(_new_sys),
            acts=act,
            rwds=rwd,
            terms=term,
        )

    def _reward(self, new_pos, door, key, has_key, picked_up, opened):
        """Scalar float reward for landing on ``new_pos`` under ``self.kind``.

        ``has_key`` is the flag from BEFORE the step; ``picked_up`` / ``opened``
        are the boolean events computed in ``step``.
        """
        if self.kind == "normal":
            # 1 for either milestone event, 0 otherwise.
            return jnp.where(
                jnp.logical_or(picked_up, opened), jnp.ones(()), jnp.zeros(())
            )
        if self.kind == "dense":
            # Negative Manhattan distance to the current target.
            return jnp.where(
                has_key != 0,
                -jnp.abs(new_pos - door).sum().astype(jnp.float32),
                -jnp.abs(new_pos - key).sum().astype(jnp.float32),
            )
        # "quick": constant per-step penalty.
        return jnp.array(-1, dtype=jnp.float32)

    def _get_obs(self, sys: Sys):
        """Observation vector: the agent position followed by the has_key flag."""
        return jnp.concatenate((sys.pos, sys.has_key[jnp.newaxis]), axis=0)

    @property
    def state_dim(self):
        """Length of the observation built by ``_get_obs`` (2 position coords + has_key)."""
        return 3

    @property
    def act_dim(self):
        """Not defined by this class; always ``None``."""
        return None


def refill_cfg(cfg: DictConfig) -> Tuple[DictConfig, Dict[str, Any]]:
    from omegaconf.omegaconf import OmegaConf

    match [
        cast(Literal["normal", "quick", "dense"], cfg.env.kind),
        cast(Literal["normal", "large"], cfg.env.size_kind),
    ]:
        case ["normal", "normal"]:
            if OmegaConf.is_missing(cfg.env, "grid_size"):
                cfg.env.grid_size = 9
            if OmegaConf.is_missing(cfg.env, "episode_length"):
                cfg.env.episode_length = 50
        case ["normal", "large"]:
            if OmegaConf.is_missing(cfg.env, "grid_size"):
                cfg.env.grid_size = 11
            if OmegaConf.is_missing(cfg.env, "episode_length"):
                cfg.env.episode_length = 70
        case ["quick", "large"]:
            if OmegaConf.is_missing(cfg.env, "grid_size"):
                cfg.env.grid_size = 9
            if OmegaConf.is_missing(cfg.env, "episode_length"):
                cfg.env.episode_length = 50
        case ["dense", "large"]:
            if OmegaConf.is_missing(cfg.env, "grid_size"):
                cfg.env.grid_size = 11
            if OmegaConf.is_missing(cfg.env, "episode_length"):
                cfg.env.episode_length = 70
        case _:
            raise ValueError(
                f"unsupport env config: {cfg.env.kind}, {cfg.env.size_kind}"
            )

    dcfg = {}

    dcfg["all_candidates"] = set(
        product(list(range(cfg.env.grid_size)), list(range(cfg.env.grid_size)))
    ) - set([(cfg.env.grid_size // 2, cfg.env.grid_size // 2)])

    dcfg["train_candidates"] = set()
    _all_candidates = list(dcfg["all_candidates"])
    _train_idx = np.random.choice(
        len(dcfg["all_candidates"]),
        math.floor(
            len(dcfg["all_candidates"]) * 0.8,
        ),
        replace=False,
    ).tolist()
    _all_candidates = list(dcfg["all_candidates"])
    for _idx in _train_idx:
        dcfg["train_candidates"].add(_all_candidates[_idx])
    dcfg["eval_candidates"] = dcfg["all_candidates"] - dcfg["train_candidates"]

    assert dcfg["train_candidates"].intersection(dcfg["eval_candidates"]) == set()

    assert (
        dcfg["all_candidates"] != set()
        and dcfg["train_candidates"] != set()
        and dcfg["eval_candidates"] != set()
    )

    dcfg["obs_is_concrete"] = True
    dcfg["obs_nums"] = cfg.env.grid_size**2
    dcfg["act_is_concrete"] = True
    dcfg["act_nums"] = cfg.env.action_size
    dcfg["rwd_is_concrete"] = True
    match cfg.env.kind:
        case "normal":
            dcfg["rwd_nums"] = 2
        case "quick":
            dcfg["rwd_nums"] = 1
        case "dense":
            dcfg["rwd_nums"] = 2 * (cfg.env.grid_size - 1) + 1
        case _:
            raise ValueError(f"unsupported env.kind: {cfg.env.kind}")

    return (cfg, dcfg)
