from utils.reporter import get_reporter, get_reporter_dir
from os import path
import jax
import jax.random as jrnd
import jax.numpy as jnp
from envs import make_env
from algos import make_algo
from typing import cast, Dict, Any, Optional, Literal
from utils.jax_util import extract_in_out_trees, build_merge_trees
from utils.saver import saver
import numpy as np
import flax.struct as fstruct
import torch
from pathlib import Path
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
import logging
from utils.progress import tqdm
from utils.type import EnvStep


@fstruct.dataclass
class HistoryState:
    """Pytree state of a ``History`` replay buffer.

    Allocated by ``History.make`` and updated functionally by
    ``History.enroll`` (which returns a new state via ``replace``).
    The field order defines the pytree leaf order. Saved ``history``
    checkpoints presumably depend on it, so do not reorder the fields
    without checking the saver/restore path.
    """

    # Per-step buffers. History.make gives each a leading buffer_len axis;
    # obs is (buffer_len, state_dim), int32 for kind "discrete", else float32.
    obs: jnp.ndarray
    # Actions: (buffer_len,) when act_dim is None, else (buffer_len, act_dim);
    # int32 for kind "discrete", else float32.
    acts: jnp.ndarray
    # Rewards, float32, shape (buffer_len,).
    rwds: jnp.ndarray
    # Termination flags stored as int32, shape (buffer_len,).
    terms: jnp.ndarray
    # Truncation flags stored as int32, shape (buffer_len,).
    truncs: jnp.ndarray
    # Observation after the step; same layout and dtype as ``obs``.
    next_obs: jnp.ndarray
    # Number of slots written so far, saturating at ``maxlength``.
    added_length: jnp.ndarray
    # Buffer capacity (History.make sets it to buffer_len).
    maxlength: jnp.ndarray
    # Index of the next slot to write; wraps back to 0 at ``maxlength``.
    pointer: jnp.ndarray


class History:
    """Fixed-capacity ring buffer of environment transitions.

    ``make`` allocates empty storage, ``enroll`` writes one transition at the
    current write slot (wrapping to slot 0 after ``buffer_len`` writes), and
    ``sample`` draws a random batch from the slots filled so far.
    """

    # Transition fields present on both HistoryState and EnvStep; they are
    # written per step by ``enroll`` and gathered by ``sample``.
    _STEP_FIELDS = ("obs", "acts", "rwds", "terms", "truncs", "next_obs")
    # Scalar bookkeeping fields of HistoryState.
    _META_FIELDS = ("maxlength", "added_length", "pointer")

    def __init__(
        self,
        buffer_len: int,
        state_dim: int,
        act_dim: Optional[int],
        kind: Literal["discrete", "continuous"],
        batch_size: Optional[int],
    ):
        """Record the buffer geometry.

        Args:
            buffer_len: Number of transition slots.
            state_dim: Length of each stored observation.
            act_dim: Length of each stored action, or ``None`` for scalar actions.
            kind: ``"discrete"`` stores observations/actions as int32; any other
                value stores them as float32.
            batch_size: Number of transitions drawn by ``sample``.
        """
        self.kind = kind
        self.buffer_len = buffer_len
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.batch_size = batch_size

    def gen_tree(self):
        """Return a template ``HistoryState`` used to derive vmap axis trees.

        Transition fields hold ``0`` and bookkeeping fields hold ``nan``.
        ``extract_in_out_trees`` presumably maps these markers to batched vs.
        shared axes (TODO confirm against ``utils.jax_util``).
        """
        template = {name: jnp.array(0) for name in self._STEP_FIELDS}
        template.update({name: jnp.array(jnp.nan) for name in self._META_FIELDS})
        return HistoryState(**template)

    def make(self):
        """Allocate an empty, zero-filled ``HistoryState`` of ``buffer_len`` slots."""
        slots = self.buffer_len
        value_dtype = jnp.float32 if self.kind != "discrete" else jnp.int32
        obs_shape = (slots, self.state_dim)
        act_shape = (slots,) if self.act_dim is None else (slots, self.act_dim)
        return HistoryState(
            obs=jnp.zeros(obs_shape, dtype=value_dtype),
            acts=jnp.zeros(act_shape, dtype=value_dtype),
            rwds=jnp.zeros((slots,), dtype=jnp.float32),
            terms=jnp.zeros((slots,), dtype=jnp.int32),
            truncs=jnp.zeros((slots,), dtype=jnp.int32),
            next_obs=jnp.zeros(obs_shape, dtype=value_dtype),
            maxlength=jnp.array(slots),
            added_length=jnp.array(0),
            pointer=jnp.array(0),
        )

    def enroll(self, state: HistoryState, step: EnvStep):
        """Write ``step`` into slot ``state.pointer`` and advance the pointer.

        The pointer wraps to 0 when it reaches ``state.maxlength`` (so the oldest
        entries get overwritten); ``added_length`` saturates at ``maxlength``.
        """
        slot = state.pointer
        written = {
            name: getattr(state, name).at[slot].set(getattr(step, name))
            for name in self._STEP_FIELDS
        }

        advanced = slot + 1
        next_slot = jax.lax.cond(
            advanced == state.maxlength, lambda: jnp.array(0), lambda: advanced
        )
        grown = jnp.minimum(state.maxlength, state.added_length + 1)

        return state.replace(**written, pointer=next_slot, added_length=grown)

    def sample(self, keys: jax.Array, state: HistoryState) -> EnvStep:
        """Draw ``batch_size`` transitions (with replacement) from the filled slots."""
        picks = jrnd.randint(
            keys, (self.batch_size,), minval=0, maxval=state.added_length
        )
        gathered = {name: getattr(state, name)[picks] for name in self._STEP_FIELDS}
        return EnvStep(**gathered)


def collect_handle(cfg: DictConfig, dcfg: Dict[str, Any]):
    """Train/evaluate an agent while recording its transitions, then dump them.

    Flow:

    1. If ``cfg.prog_kind == "collect"``, build an environment in "eval" mode and
       save its ``sys`` to ``<output_dir>/models/eval_env_sys``.
    2. Build the "train" environment. If ``cfg.prog_kind == "eval"``, replace its
       ``sys`` with the one restored from ``cfg.env_load_path`` and reset it.
    3. Build the algorithm and a per-environment ``History`` buffer of
       ``dcfg["total_training_times"]`` slots.
    4. Evaluate once. Then, for each of ``cfg.split_nums`` splits, run
       ``dcfg["frames_per_split"]`` act/record/update steps and evaluate again.
       If ``cfg.save_algo``, checkpoint ``algo_state`` after each evaluation.
    5. Save the final history plus a torch dataset of
       (obs, acts, rwds, terms, truncs) truncated to the filled slots.

    Args:
        cfg: Hydra/OmegaConf run configuration.
        dcfg: Derived settings; must provide at least ``total_training_times``
            and ``frames_per_split``.
    """
    # Single source of truth for the env families that History treats as
    # "continuous". They explore randomly until ``training_start_timesteps``
    # and update from replayed batches; the others update online from each
    # transition.
    is_continuous = cfg.env.name in ("cartpole", "reacher")

    output_dir = HydraConfig.get().runtime.output_dir
    MODEL_DIR = path.abspath(f"{output_dir}/models/")

    key = jax.random.PRNGKey(cfg.seed)

    if cfg.prog_kind == "collect":
        eval_env_states = {}
        for k in ["eval"]:
            key, env_key = jrnd.split(key)
            _, _env_state, _ = make_env(cfg.env.name, cfg, dcfg, "eval", env_key)

            eval_env_states[k] = _env_state.sys

        saver.save(f"{MODEL_DIR}/eval_env_sys", eval_env_states)
        del eval_env_states

    key, env_key = jrnd.split(key)
    env, train_env_state, _env_tree = make_env(
        cfg.env.name, cfg, dcfg, "train", env_key
    )

    if cfg.prog_kind == "eval":
        load_env_sys = saver.restore(
            cfg.env_load_path,
            item={k: train_env_state.sys for k in ["eval"]},
        )
        train_env_state = train_env_state.replace(sys=load_env_sys["eval"])
        print("load saved sys")
        print(train_env_state.sys)
        key, *reset_keys = jrnd.split(key, 1 + cfg.num_train_envs)
        train_env_state = env.reset(jnp.array(reset_keys), train_env_state)

    key, algo_key = jrnd.split(key)
    algo, algo_state = make_algo(
        cfg.algo.name,
        cfg,
        {
            **dcfg,
            "state_dim": env.unwrapped.state_dim,
            "act_dim": env.unwrapped.act_dim,
        },
        algo_key,
    )

    merge_trees = jax.jit(build_merge_trees(_env_tree))

    hb = History(
        dcfg["total_training_times"],
        env.unwrapped.state_dim,
        env.unwrapped.act_dim,
        kind="continuous" if is_continuous else "discrete",
        batch_size=OmegaConf.select(cfg.algo, "batch_size"),
    )
    # Vectorize the buffer over the training environments; the axis trees come
    # from the template returned by ``gen_tree``.
    _h_tree, _hot = extract_in_out_trees(hb.gen_tree())
    hb.make = jax.vmap(hb.make, out_axes=_hot, axis_size=cfg.num_train_envs)
    hb.enroll = jax.vmap(hb.enroll, in_axes=(_h_tree, 0), out_axes=_h_tree)
    hb.sample = jax.vmap(hb.sample, in_axes=(0, _h_tree))

    history_state = hb.make()

    def step_env_once(carry, t):
        """One scan step: act, step the envs, record, update, reset finished envs."""
        env_state, algo_state, history_state, key = carry

        old_env_obs = env_state.obs

        if is_continuous:
            key, *act_keys = jrnd.split(key, 1 + cfg.num_train_envs)
            key, rnd_act_key = jrnd.split(key)
            # Policy actions after warm-up, uniform random actions before it.
            acts = jax.lax.cond(
                t >= cfg.algo.training_start_timesteps,
                lambda: algo.make_action(
                    jnp.array(act_keys),
                    algo_state,
                    env_state.obs,
                ),
                lambda: jrnd.uniform(
                    rnd_act_key, (cfg.num_train_envs, env_state.acts.shape[-1])
                ),
            )
        else:
            key, *act_key = jrnd.split(key, 1 + cfg.num_train_envs)
            acts = algo.make_action(jnp.array(act_key), algo_state, env_state.obs)

        env_state = env.step(env_state, acts)

        _env_step = EnvStep(
            obs=old_env_obs,
            acts=acts,
            rwds=env_state.rwds,
            terms=env_state.terms,
            truncs=env_state.truncs,
            next_obs=env_state.obs,
        )
        history_state = hb.enroll(
            history_state,
            _env_step,
        )

        if is_continuous:
            key, *algo_update_key = jrnd.split(key, 1 + cfg.num_train_envs)
            key, *sample_key = jrnd.split(key, 1 + cfg.num_train_envs)
            sample_key = jnp.array(sample_key)
            # Replay-based update, only once past the warm-up phase.
            algo_state = jax.lax.cond(
                t >= cfg.algo.training_start_timesteps,
                lambda: algo.update(
                    jnp.array(algo_update_key),
                    algo_state,
                    hb.sample(sample_key, history_state),
                ),
                lambda: algo_state,
            )
        else:
            key, *algo_update_key = jrnd.split(key, 1 + cfg.num_train_envs)
            algo_state = algo.update(
                jnp.array(algo_update_key),
                algo_state,
                _env_step,
            )

        def reset_env(key, env_state):
            """Reset only the envs that terminated or were truncated this step."""
            key, *reset_keys = jrnd.split(key, 1 + cfg.num_train_envs)
            _reset_env_state = env.reset(jnp.array(reset_keys), env_state)
            _need_reset = jnp.logical_or(
                env_state.terms,
                env_state.truncs,
            )
            # Leaves marked None in the env tree receive a constant 0 mask;
            # all others use the per-env reset mask.
            _reset_tree = jax.tree_util.tree_map(
                lambda x: jnp.array(0) if x is None else _need_reset,
                _env_tree,
                is_leaf=lambda x: x is None,
            )
            env_state = merge_trees(
                _reset_tree,
                env_state,
                _reset_env_state,
            )
            return env_state

        key, reset_key = jrnd.split(key)
        env_state = jax.lax.cond(
            jnp.any(jnp.logical_or(env_state.terms, env_state.truncs)),
            lambda c: reset_env(reset_key, c),
            lambda c: c,
            env_state,
        )

        return (env_state, algo_state, history_state, key), None

    def eval_one_step(carry):
        """Advance every eval env one step, accumulating rewards until it first ends."""
        ended, returns, algo_state, key, env_state = carry
        key, *act_key = jrnd.split(key, 1 + cfg.num_train_envs)
        acts = algo.make_action(jnp.array(act_key), algo_state, env_state.obs)

        env_state = env.step(env_state, acts)
        returns = returns + env_state.rwds * (1 - ended)

        _need_reset = jnp.logical_or(
            env_state.terms,
            env_state.truncs,
        )

        ended = jnp.logical_or(ended, _need_reset)

        return ended, returns, algo_state, key, env_state

    @jax.jit
    def eval_all(key, env_state, algo_state):
        """Reset the envs and return each env's return over its first episode."""
        key, *reset_keys = jrnd.split(key, 1 + cfg.num_train_envs)
        _, returns, *_ = jax.lax.while_loop(
            lambda carry: jnp.logical_not(jnp.all(carry[0])),
            eval_one_step,
            (
                jnp.zeros((cfg.num_train_envs,), dtype=jnp.bool_),
                jnp.zeros((cfg.num_train_envs,)),
                algo_state,
                key,
                env.reset(jnp.array(reset_keys), env_state),
            ),
        )

        return returns

    def evaluate_and_report(eval_key, env_state, algo_state):
        """Run ``eval_all`` and log the per-env returns and their mean/std under "eval"."""
        returns = jax.device_get(eval_all(eval_key, env_state, algo_state))
        get_reporter().add_distributions(dict(returns=np.array(returns)), "eval")
        get_reporter().add_scalars(
            dict(returns_mean=returns.mean().item(), returns_std=returns.std().item()),
            "eval",
        )

    key, eval_key = jrnd.split(key)
    evaluate_and_report(eval_key, train_env_state, algo_state)

    saver.save(f"{MODEL_DIR}/env_sys", train_env_state.sys)
    if cfg.save_algo:
        saver.save(f"{MODEL_DIR}/algo_state/0", algo_state)

    frames_per_split = dcfg["frames_per_split"]
    for i in tqdm(range(1, 1 + cfg.split_nums), desc="total training"):

        # training loop: global timesteps [frames_per_split*(i-1), frames_per_split*i)
        key, train_key = jrnd.split(key)
        (train_env_state, algo_state, history_state, _), _ = jax.lax.scan(
            step_env_once,
            (train_env_state, algo_state, history_state, train_key),
            jnp.arange(frames_per_split * (i - 1), frames_per_split * i),
        )

        # eval loop
        key, eval_key = jrnd.split(key)
        evaluate_and_report(eval_key, train_env_state, algo_state)
        if cfg.save_algo:
            saver.save(f"{MODEL_DIR}/algo_state/{i}", algo_state)

    history_state = cast(HistoryState, history_state)
    saver.save(f"{MODEL_DIR}/history", history_state)
    Path(f"{MODEL_DIR}/dataset/").mkdir(exist_ok=True, parents=True)
    # added_length is used as a slice bound, so it must be a scalar; convert
    # it to a host int once instead of once per leaf.
    n_filled = int(history_state.added_length)
    torch.save(
        jax.tree_util.tree_map(
            lambda leaf: np.array(leaf)[:, :n_filled],
            (
                history_state.obs,
                history_state.acts,
                history_state.rwds,
                history_state.terms,
                history_state.truncs,
            ),
        ),
        f"{MODEL_DIR}/dataset/ds.pth",
        pickle_protocol=4,
    )
