import jax
import jax.numpy as jnp
import flax.struct as fstruct
import jax.random as jrnd
import chex
from utils.jax_util import GenericEnvState
import flax.core as fcore
from typing import Literal
from dataclasses import dataclass
from omegaconf import DictConfig
from typing import Dict, Tuple, Any, cast
from itertools import product
import math
import numpy as np


@dataclass
class Cfg:
    """Constructor arguments for ``DarkRoom``.

    Attributes:
        size: side length of the square grid.
        goal_candidates: array of candidate goal cells; ``DarkRoom.make``
            samples one entry along its leading axis for each new episode.
        kind: reward variant, one of ``"dense"``, ``"quick"`` or ``"normal"``.
    """

    size: int  # grid side length
    goal_candidates: jnp.ndarray  # assumed shape (num_goals, 2) -- TODO confirm
    kind: Literal["dense", "quick", "normal"]  # selects the reward function


@fstruct.dataclass
class Sys:
    """Hidden simulator state of one ``DarkRoom`` episode (a Flax struct pytree)."""

    goal: jnp.ndarray  # goal cell, sampled from the goal candidates in ``make``
    pos: jnp.ndarray  # current agent cell; also returned as the observation


# Concrete state type for DarkRoom: the project's generic env-state container
# (fields sys/obs/acts/rwds/terms/truncs/infos) parameterised by ``Sys``.
EnvState = GenericEnvState[Sys]


class DarkRoom:
    """Grid-world "dark room" environment implemented with JAX.

    The agent starts in the centre cell ``(size // 2, size // 2)`` of a
    ``size x size`` grid and must reach a goal cell that ``make`` samples from
    ``cfg.goal_candidates``. The observation is only the agent's position;
    the goal itself is never observed.

    Reward variants (``cfg.kind``):
      * ``"normal"``: 1.0 when the agent stands on the goal, else 0.0.
      * ``"quick"``: a constant -1.0 on every step.
      * ``"dense"``: the negative Manhattan distance to the goal.
    In every variant ``terms`` is 1 when the agent stands on the goal and 0
    otherwise.
    """

    # Position delta per action id (row index == action id):
    # 0: noop, 1: left, 2: up, 3: right, 4: down.
    _ACT_DELTAS = np.array(
        [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]],
        dtype=np.int32,
    )

    def __init__(
        self,
        cfg: Cfg,
    ) -> None:
        """Store grid size, goal candidates and reward kind from ``cfg``."""
        self.size = cfg.size
        self.init_pos = jnp.array([self.size // 2, self.size // 2], dtype=jnp.int32)
        self.goal_candidates = cfg.goal_candidates
        self.kind = cfg.kind
        assert self.kind in ["normal", "dense", "quick"]

    def gen_tree(self):
        """Return an ``EnvState`` whose leaves are all placeholder scalars.

        It has the same tree structure as real states. It is presumably used
        as a structural template for pytree utilities (confirm with callers).
        Leaf values and dtypes are not meaningful.
        """
        return EnvState(
            sys=Sys(
                goal=jnp.array(0),
                pos=jnp.array(0),
            ),
            obs=jnp.array(0),
            acts=jnp.array(0),
            rwds=jnp.array(0),
            terms=jnp.array(0),
            truncs=jnp.array(0),
            infos=fcore.freeze(dict()),
        )

    def make(self, key: jnp.ndarray):
        """Sample a goal and return a freshly reset environment state.

        Args:
            key: JAX PRNG key. One split is used to sample the goal and the
                other is passed on to ``reset``.

        Returns:
            An ``EnvState`` with the agent at the start cell.
        """
        key, sys_key = jrnd.split(key)
        sys = Sys(
            goal=jrnd.choice(sys_key, self.goal_candidates),
            pos=self.init_pos,
        )
        _state = EnvState(
            sys=sys,
            obs=jnp.zeros_like(self._get_obs(sys), dtype=jnp.int32),
            acts=jnp.array(0),
            rwds=jnp.zeros(()),
            # Use the same (default float) dtype that ``step`` produces for
            # ``terms``. Every leaf then keeps its type across make/reset/step,
            # which lax.scan / lax.while_loop carries require.
            terms=jnp.zeros(()),
            truncs=jnp.array(0),
            infos=fcore.freeze(dict()),
        )

        return self.reset(key, _state)

    def reset(self, keys: jax.Array, env_state: EnvState):
        """Move the agent back to the start cell and clear ``terms``/``truncs``.

        The goal, ``acts`` and ``rwds`` are left untouched. ``keys`` is not
        used by this env; presumably it is accepted to match a shared env API.
        """
        _new_sys = env_state.sys.replace(pos=self.init_pos)
        return env_state.replace(
            sys=_new_sys,
            obs=self._get_obs(_new_sys),
            terms=jnp.zeros_like(env_state.terms),
            truncs=jnp.zeros_like(env_state.truncs),
        )

    def step(self, env_state: EnvState, act: jnp.ndarray):
        """
        Acts: 0: noop, 1: left, 2: up, 3: right, 4: down

        Any other action value is treated as a noop. ``act`` must hold exactly
        one element; a scalar and a shape-``(1,)`` array are both accepted.
        The new position is clipped to the grid. ``acts`` stores ``act``
        unchanged, and ``rwds``/``terms`` come from ``_calc_reward``.
        """
        _pos = env_state.sys.pos

        chex.assert_size(act, 1)

        # Vectorised action lookup. Compare the action against every valid id.
        # If nothing matches (id outside 0..4), argmax returns row 0, the noop.
        # A scalar view of ``act`` keeps this shape-independent.
        _act = jnp.reshape(act, ())
        _idx = jnp.argmax(_act == jnp.arange(self._ACT_DELTAS.shape[0]))
        _delta = jnp.asarray(self._ACT_DELTAS, dtype=_pos.dtype)[_idx]
        chex.assert_shape(_delta, (2,))

        _new_pos = jnp.clip(_pos + _delta, 0, self.size - 1)

        _goal = env_state.sys.goal

        rwds, terms = self._calc_reward(_new_pos, _goal)
        _new_sys = env_state.sys.replace(pos=_new_pos)

        return env_state.replace(
            sys=_new_sys,
            obs=self._get_obs(_new_sys),
            acts=act,
            rwds=rwds,
            terms=terms,
        )

    def _calc_reward(self, new_pos: jnp.ndarray, goal: jnp.ndarray):
        """Return ``(reward, termination_flag)`` for arriving at ``new_pos``.

        The termination flag is a default-float scalar: 1 when
        ``new_pos == goal``, else 0.

        Raises:
            ValueError: if ``self.kind`` is not a supported reward kind.
        """
        chex.assert_equal_shape((new_pos, goal))
        at_goal = jnp.all(goal == new_pos)
        terms = jnp.where(at_goal, jnp.ones(()), jnp.zeros(()))

        if self.kind == "normal":
            rwds = jnp.where(at_goal, jnp.array(1.0), jnp.array(0.0))
        elif self.kind == "quick":
            # Constant per-step penalty, even on the goal step.
            rwds = jnp.array(-1.0, dtype=jnp.float32)
        elif self.kind == "dense":
            # Negative Manhattan distance to the goal.
            rwds = -jnp.sum(jnp.abs(new_pos - goal)).astype(jnp.float32)
        else:
            raise ValueError(f"unsupported DarkRoom kind: {self.kind}")
        return rwds, terms

    def _get_obs(self, sys: Sys):
        """The observation is the agent position; the goal stays hidden."""
        return sys.pos

    @property
    def state_dim(self):
        """Observation size: a 2-D grid position."""
        return 2

    @property
    def act_dim(self):
        """Not provided by this env; always ``None``."""
        return None


def refill_cfg(cfg: DictConfig) -> Tuple[DictConfig, Dict[str, Any]]:
    from omegaconf.omegaconf import OmegaConf

    match [
        cast(Literal["normal", "quick", "dense"], cfg.env.kind),
        cast(Literal["normal", "large"], cfg.env.size_kind),
    ]:
        case ["normal", "normal"]:
            if OmegaConf.is_missing(cfg.env, "grid_size"):
                cfg.env.grid_size = 9
            if OmegaConf.is_missing(cfg.env, "episode_length"):
                cfg.env.episode_length = 20
        case ["normal", "large"]:
            if OmegaConf.is_missing(cfg.env, "grid_size"):
                cfg.env.grid_size = 13
            if OmegaConf.is_missing(cfg.env, "episode_length"):
                cfg.env.episode_length = 50
        case ["quick", "large"]:
            if OmegaConf.is_missing(cfg.env, "grid_size"):
                cfg.env.grid_size = 13
            if OmegaConf.is_missing(cfg.env, "episode_length"):
                cfg.env.episode_length = 50
        case ["dense", "large"]:
            if OmegaConf.is_missing(cfg.env, "grid_size"):
                cfg.env.grid_size = 15
            if OmegaConf.is_missing(cfg.env, "episode_length"):
                cfg.env.episode_length = 50
        case _:
            raise ValueError(
                f"unsupport env config: {cfg.env.kind}, {cfg.env.size_kind}"
            )

    dcfg = {}

    dcfg["all_goals"] = set(
        product(list(range(cfg.env.grid_size)), list(range(cfg.env.grid_size)))
    ) - set([(cfg.env.grid_size // 2, cfg.env.grid_size // 2)])

    _train_idx = np.random.choice(
        len(dcfg["all_goals"]),
        math.floor(
            len(dcfg["all_goals"]) * 0.8,
        ),
        replace=False,
    ).tolist()

    dcfg["train_goals"] = set()
    _all_goals = list(dcfg["all_goals"])
    for _idx in _train_idx:
        dcfg["train_goals"].add(_all_goals[_idx])

    dcfg["eval_goals"] = dcfg["all_goals"] - dcfg["train_goals"]

    assert dcfg["train_goals"].intersection(dcfg["eval_goals"]) == set()
    assert dcfg["all_goals"] != set()
    assert dcfg["train_goals"] != set()
    assert dcfg["eval_goals"] != set()

    dcfg["obs_is_concrete"] = True
    dcfg["obs_nums"] = cfg.env.grid_size**2
    dcfg["act_is_concrete"] = True
    dcfg["act_nums"] = cfg.env.action_size
    dcfg["rwd_is_concrete"] = True
    match cfg.env.kind:
        case "normal":
            dcfg["rwd_nums"] = 2
        case "quick":
            dcfg["rwd_nums"] = 1
        case "dense":
            dcfg["rwd_nums"] = 2 * (cfg.env.grid_size - 1) + 1
        case _:
            raise ValueError(f"unsupported env.kind: {cfg.env.kind}")

    return cfg, dcfg
