import jax
import jax.numpy as jnp
import jax.random as jrnd
import flax.struct as fstruct
import chex
from typing import Literal, Dict, Any, Tuple, cast
from utils.type import EnvStep
from dataclasses import dataclass
from omegaconf import DictConfig


@dataclass()
class Cfg:
    """Hyper-parameters consumed by :class:`QLearning`.

    ``grid_size`` is used for both axes of the (square) state grid.
    """

    action_dim: int  # number of discrete actions (last axis of the Q-table)
    grid_size: int  # side length of the square state grid
    total_steps: int  # number of updates over which epsilon decays from 1 to 0
    alpha: float  # step size of the Q-learning update
    gamma: float  # discount factor applied to the bootstrapped next-state value


@fstruct.dataclass
class QLearningState:
    """Immutable pytree carrying the learner's state; updated via ``.replace``."""

    # Q-values indexed as [x, y, action] (built in QLearning.make).
    q_tables: jax.Array
    # Count of updates applied so far; QLearning uses it for epsilon decay.
    update_times: jax.Array


class QLearning:
    """Tabular epsilon-greedy Q-learning over a square grid of discrete states.

    The Q-table has shape ``(grid_size, grid_size, action_dim)``; observations
    are used directly as integer coordinates into its first two axes.
    Exploration probability decays linearly from 1 to 0 over ``total_steps``
    updates.
    """

    def __init__(
        self,
        cfg: Cfg,
    ) -> None:
        self.n_acts = cfg.action_dim
        self.n_x = cfg.grid_size
        self.n_y = cfg.grid_size
        self.total_steps = cfg.total_steps
        self.alpha = cfg.alpha
        self.gamma = cfg.gamma

    def gen_tree(self):
        """Return the constant scalar ``jnp.array(0)``.

        NOTE(review): presumably a placeholder required by a shared algorithm
        interface — confirm against callers.
        """
        return jnp.array(0)

    def make(self, keys):
        """Build the initial state.

        Args:
            keys: PRNG key used to draw the initial Q-values.

        Returns:
            A ``QLearningState`` whose Q-table is sampled uniformly from
            [0, 1) and whose update counter is a scalar zero.
        """
        return QLearningState(
            q_tables=jrnd.uniform(keys, (self.n_x, self.n_y, self.n_acts)),
            # Alternative: zero init via jnp.zeros((self.n_x, self.n_y, self.n_acts)).
            # NOTE: jnp.zeros(()) is float32 by default; a float32 counter stops
            # incrementing past 2**24, which would stall epsilon decay if
            # total_steps were ever larger than that.
            update_times=jnp.zeros(()),
        )

    def make_action(self, key, algo_state: QLearningState, obs: jnp.ndarray):
        """Choose an action for ``obs`` with linearly decaying epsilon-greedy.

        Args:
            key: PRNG key for the exploration draw and the random action.
            algo_state: current Q-table and update counter.
            obs: integer grid coordinates indexing the first two table axes.

        Returns:
            A scalar integer action in ``[0, n_acts)``: uniformly random with
            probability epsilon, otherwise the greedy ``argmax`` action.
        """
        # Split into 3 and drop the first so the RNG streams stay unchanged.
        _, exp_key, rand_act_key = jrnd.split(key, 3)

        coin = jrnd.uniform(exp_key, ())
        rand_act = jrnd.randint(rand_act_key, (), minval=0, maxval=self.n_acts)
        greedy_act = algo_state.q_tables[tuple(obs)].argmax(-1)

        epsilon = jnp.maximum(
            1.0 - algo_state.update_times / self.total_steps,
            0,
        )

        # Strict ``<``: ``uniform`` samples [0, 1), so ``<=`` would still explore
        # on a draw of exactly 0.0 after epsilon has decayed to 0.
        return jnp.where(coin < epsilon, rand_act, greedy_act)

    def update(self, keys: jax.Array, algo_state: QLearningState, transition: EnvStep):
        """Apply one tabular Q-learning update for a single transition.

        ``Q(s, a) += alpha * (r + gamma * (1 - term) * max_a' Q(s', a') - Q(s, a))``;
        afterwards, if the transition is terminal, every Q-value of ``s'`` is
        reset to zero.

        Args:
            keys: unused by this method.
            algo_state: current Q-table and update counter.
            transition: one unbatched transition exposing ``obs``, ``acts``,
                ``rwds``, ``terms`` and ``next_obs`` (the size-1 assertion
                below rejects batched inputs).

        Returns:
            A new state with the updated table and ``update_times`` + 1.
        """
        q_tab = algo_state.q_tables
        state, act, rews, terms, next_state = (
            transition.obs,
            transition.acts,
            transition.rwds,
            transition.terms,
            transition.next_obs,
        )

        # Index with one flat tuple of scalars. ``q_tab[state_idx, act]`` or
        # ``q_tab[next_state_idx, :]`` would treat ``state_idx`` as a single
        # fancy index along axis 0 and not select one cell.
        state_idx = tuple(state)
        next_state_idx = tuple(next_state)

        q_tab_state_val = q_tab[state_idx + (act,)]
        q_tab_next_state_val = q_tab[next_state_idx + (slice(None),)].max(-1)
        chex.assert_size((q_tab_state_val, q_tab_next_state_val), 1)

        q_tab = q_tab.at[state_idx + (act,)].add(
            self.alpha
            * (rews + self.gamma * (1 - terms) * q_tab_next_state_val - q_tab_state_val)
        )
        # Zero the terminal next state's Q-values. ``jnp.where`` accepts a bool
        # or numeric ``terms`` (nonzero = terminal); ``lax.cond`` would raise a
        # TypeError on a float predicate.
        q_tab = q_tab.at[next_state_idx + (slice(None),)].mul(
            jnp.where(terms, 0, 1)
        )

        return algo_state.replace(
            q_tables=q_tab, update_times=algo_state.update_times + 1
        )


def refill_cfg(cfg: DictConfig) -> Tuple[DictConfig, Dict[str, Any]]:
    from omegaconf.omegaconf import OmegaConf

    dcfg = {}

    match [
        cast(Literal["normal", "quick", "dense"], cfg.env.kind),
        cast(Literal["normal", "large"], cfg.env.size_kind),
    ]:
        case ["normal", "normal"]:
            if OmegaConf.is_missing(cfg, "episodes_to_train"):
                cfg.episodes_to_train = 125
        case ["normal", "large"]:
            if OmegaConf.is_missing(cfg, "episodes_to_train"):
                cfg.episodes_to_train = 150
        case ["quick", "large"]:
            if OmegaConf.is_missing(cfg, "episodes_to_train"):
                cfg.episodes_to_train = 250
        case ["dense", "large"]:
            if OmegaConf.is_missing(cfg, "episodes_to_train"):
                cfg.episodes_to_train = 175
        case _:
            raise ValueError(
                f"unsupport env config: {cfg.env.kind}, {cfg.env.size_kind}"
            )

    dcfg["total_training_times"] = cfg.episodes_to_train * cfg.env.episode_length
    assert dcfg["total_training_times"] % cfg.split_nums == 0

    dcfg["frames_per_split"] = int(dcfg["total_training_times"] / cfg.split_nums)

    return cfg, dcfg
