import jax
import jax.numpy as jnp
import jax.random as jrnd
import flax.struct as fstruct
import chex
from typing import Literal, Dict, Any, Tuple, cast
from dataclasses import dataclass
from utils.type import EnvStep
from omegaconf import DictConfig


@dataclass
class Cfg:
    """Hyper-parameters read by ``QLearningM.__init__``."""

    # Number of discrete actions; becomes the size of the last Q-table axis
    # and the exclusive upper bound when sampling a random action.
    action_dim: int
    # Side length of the grid; used for both of the first two Q-table axes.
    grid_size: int
    # Number of updates over which epsilon decays linearly from 1 to 0.
    total_steps: int
    # Learning rate applied to the temporal-difference error.
    alpha: float
    # Discount factor applied to the best next-state value.
    gamma: float


@fstruct.dataclass
class QLearningState:
    """Learner state passed into and returned from ``QLearningM`` methods."""

    # Q-value table; ``QLearningM.make`` creates it with shape
    # (grid_size, grid_size, 2, action_dim).
    q_tables: jnp.ndarray
    # Scalar counter, starts at 0 and is incremented by one per ``update``
    # call; ``make_action`` uses it to decay epsilon.
    update_times: jnp.ndarray


class QLearningM:
    """Tabular Q-learning agent with a linearly decaying epsilon-greedy policy.

    The Q-table has shape ``(grid_size, grid_size, 2, action_dim)``: the
    leading axes are indexed by the components of an observation and the last
    axis by the action.  The meaning of the size-2 axis is determined by the
    environment's observation layout, which is not defined in this file.
    """

    def __init__(
        self,
        cfg: Cfg,
    ) -> None:
        """Copy the hyper-parameters out of ``cfg``.

        Args:
            cfg: Provides the action count, the grid side length (used for
                both grid axes), the epsilon decay horizon, the learning rate
                ``alpha`` and the discount factor ``gamma``.
        """
        self.n_acts = cfg.action_dim
        self.n_x = cfg.grid_size
        self.n_y = cfg.grid_size
        self.total_steps = cfg.total_steps
        self.alpha = cfg.alpha
        self.gamma = cfg.gamma

    def gen_tree(self):
        """Return a scalar zero array.

        NOTE(review): presumably a placeholder pytree expected by the caller;
        confirm against the call site.
        """
        return jnp.array(0)

    def make(self, keys):
        """Create a fresh learner state.

        Args:
            keys: PRNG key used to initialise the Q-table.

        Returns:
            A ``QLearningState`` whose Q-table is filled with uniform random
            values and whose update counter is a scalar zero.
        """
        table_shape = (self.n_x, self.n_y, 2, self.n_acts)
        return QLearningState(
            q_tables=jrnd.uniform(keys, table_shape),
            update_times=jnp.zeros(()),
        )

    def make_action(self, key, algo_state: QLearningState, obs: jnp.ndarray):
        """Choose an action with an epsilon-greedy rule.

        Epsilon is ``max(1 - update_times / total_steps, 0)``, i.e. it decays
        linearly from 1 to 0 as updates accumulate.

        Args:
            key: PRNG key for the exploration draw and the random action.
            algo_state: Current learner state.
            obs: Observation whose components index the leading Q-table axes;
                the shape assertion below only passes when it supplies one
                index per leading axis.

        Returns:
            A scalar action index.
        """
        # The key is split three ways and the first sub-key is deliberately
        # discarded; only the two remaining sub-keys are consumed.
        _, explore_key, action_key = jrnd.split(key, 3)

        explore_draw = jrnd.uniform(explore_key, ())
        random_action = jrnd.randint(
            action_key, (), minval=0, maxval=self.n_acts
        )

        action_values = algo_state.q_tables[tuple(obs)]
        chex.assert_shape(action_values, (self.n_acts,))
        greedy_action = action_values.argmax(-1)

        epsilon = jnp.maximum(
            1.0 - algo_state.update_times / self.total_steps,
            0,
        )

        # Strict comparison: the uniform draw lies in [0, 1), so `<` explores
        # with probability exactly epsilon and never once epsilon reaches 0
        # (with `<=` a draw of exactly 0 would still explore).
        return jax.lax.cond(
            explore_draw < epsilon,
            lambda: random_action,
            lambda: greedy_action,
        )

    def update(self, keys: jax.Array, algo_state: QLearningState, transition: EnvStep):
        """Apply one Q-learning update for a single transition.

        Args:
            keys: PRNG key; not used by this method.
            algo_state: Current learner state.
            transition: Supplies ``obs``, ``acts``, ``rwds``, ``terms`` and
                ``next_obs`` for one step.

        Returns:
            A new state with the updated Q-table and the update counter
            incremented by one.
        """
        q_tab = algo_state.q_tables
        rews = transition.rwds
        terms = transition.terms

        state_idx = tuple(transition.obs)
        next_state_idx = tuple(transition.next_obs)
        # The action has to be appended to the index tuple.  Writing
        # `q_tab[state_idx, act]` would pass the whole tuple as the index for
        # the first axis instead of one index per axis.
        state_action_idx = state_idx + (transition.acts,)

        current_q = q_tab[state_action_idx]
        next_best_q = q_tab[next_state_idx].max(-1)
        chex.assert_size((current_q, next_best_q), 1)

        # Bootstrapped target; `(1 - terms)` drops the next-state value when
        # the transition is terminal.
        td_target = rews + self.gamma * (1 - terms) * next_best_q
        q_tab = q_tab.at[state_action_idx].add(
            self.alpha * (td_target - current_q)
        )

        # On a terminal transition, zero every action value stored for the
        # next state; otherwise multiply by 1, which leaves them unchanged.
        keep_next = jax.lax.cond(terms, lambda: 0, lambda: 1)
        q_tab = q_tab.at[next_state_idx].multiply(keep_next)

        return algo_state.replace(
            q_tables=q_tab, update_times=algo_state.update_times + 1
        )


def refill_cfg(cfg: DictConfig) -> Tuple[DictConfig, Dict[str, Any]]:
    """Fill in a missing ``episodes_to_train`` and derive training frame counts.

    The default episode count depends on the ``(cfg.env.kind,
    cfg.env.size_kind)`` pair and is only written when the field is still
    marked as missing in ``cfg``.

    Args:
        cfg: Configuration that is read and, when needed, modified in place.

    Returns:
        ``cfg`` itself together with a dict holding ``total_training_times``
        (episodes times episode length) and ``frames_per_split`` (that total
        divided by ``cfg.split_nums``).

    Raises:
        ValueError: If the kind / size-kind pair is not one of the supported
            combinations.
        AssertionError: If the total is not a multiple of ``cfg.split_nums``.
    """
    from omegaconf.omegaconf import OmegaConf

    # Supported environment combinations and their default episode budgets.
    episode_defaults = (
        (["normal", "normal"], 150),
        (["normal", "large"], 200),
        (["quick", "large"], 400),
        (["dense", "large"], 350),
    )

    env_key = [
        cast(Literal["normal", "quick", "dense"], cfg.env.kind),
        cast(Literal["normal", "large"], cfg.env.size_kind),
    ]

    for known_key, episodes in episode_defaults:
        if env_key == known_key:
            # Only fill the value in; an explicitly configured one wins.
            if OmegaConf.is_missing(cfg, "episodes_to_train"):
                cfg.episodes_to_train = episodes
            break
    else:
        raise ValueError(
            f"unsupport env config: {cfg.env.kind}, {cfg.env.size_kind}"
        )

    total_frames = cfg.episodes_to_train * cfg.env.episode_length
    assert total_frames % cfg.split_nums == 0

    derived = {
        "total_training_times": total_frames,
        "frames_per_split": int(total_frames / cfg.split_nums),
    }
    return cfg, derived
