from .q_learning import (
    QLearning,
    Cfg as QLearningCfg,
    refill_cfg as QLearning_refill_cfg,
)
from .q_learningM import (
    QLearningM,
    Cfg as QLearningMCfg,
    refill_cfg as QLearning_m_refill_cfg,
)
from .xl.xl import refill_cfg as XL_refill_cfg
from .sac import SAC, Cfg as SACCfg, refill_cfg as SAC_refill_cfg
from algos.wrappers import VmapWrapper as AlgoVmapWrapper
from utils.jax_util import extract_in_out_trees, build_merge_trees
import jax.random as jrnd
import jax.numpy as jnp
from algos.ad.ad import AD, Cfg as ADCfg, refill_cfg as AD_refill_cfg


def get_algo_cfg(name: str, cfg, dcfg):
    """Build the algorithm-specific config object for ``name``.

    Args:
        name: Algorithm identifier. Supported values are ``"q_learning"``,
            ``"q_learning_m"``, ``"sac"`` and the sequence-model family
            ``"ad"``, ``"mem"``, ``"xl"``, ``"ed"``.
        cfg: Experiment config; read through attribute access
            (``cfg.env.*``, ``cfg.algo.*``, ``cfg.device``, ...).
        dcfg: Mapping read by key (``"total_training_times"``,
            ``"state_dim"``, ``"act_dim"``, ``"steps_trained"``), depending
            on the algorithm.

    Returns:
        The populated config instance for the requested algorithm.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    if name == "q_learning" or name == "q_learning_m":
        # Both tabular variants take exactly the same keyword arguments;
        # only the config class differs.
        tabular_cfg_cls = QLearningCfg if name == "q_learning" else QLearningMCfg
        algo_cfg = tabular_cfg_cls(
            action_dim=cfg.env.action_size,
            grid_size=cfg.env.grid_size,
            total_steps=dcfg["total_training_times"],
            alpha=cfg.algo.alpha,
            gamma=cfg.algo.gamma,
        )
    elif name == "sac":
        algo_cfg = SACCfg(
            state_dim=dcfg["state_dim"],
            act_dim=dcfg["act_dim"],
            lr=cfg.algo.lr,
            gamma=cfg.algo.gamma,
            tau=cfg.algo.tau,
        )
    elif name in ("ad", "mem", "xl", "ed"):
        # Every sequence-model variant is configured through ADCfg.
        algo_cfg = ADCfg(
            token_dim=cfg.algo.token_dim,
            hidden_dim=cfg.algo.hidden_dim,
            n_layers=cfg.algo.layers,
            device=cfg.device,
            ctx_len=cfg.algo.ctx_len,
            steps_trained=dcfg["steps_trained"],
            lr=cfg.algo.lr,
            num_train_steps=cfg.training_steps,
            clip_grad_norm=cfg.algo.clip_grad_norm,
            num_envs=cfg.num_eval_envs,
        )
    else:
        raise ValueError(f"unsupported algo: {name}")

    print(f"refilled algo cfg: {algo_cfg}")

    return algo_cfg


def make_algo(name: str, cfg, dcfg, key):
    """Instantiate the algorithm called ``name`` and, where applicable, its state.

    Args:
        name: Algorithm identifier (``"q_learning"``, ``"q_learning_m"``,
            ``"sac"``, ``"ad"``, ``"mem"``, ``"xl"`` or ``"ed"``).
        cfg: Experiment config; forwarded to ``get_algo_cfg`` and read here
            for ``cfg.device`` and ``cfg.num_train_envs``.
        dcfg: Mapping forwarded to ``get_algo_cfg`` (and to the ``AD``
            constructor for the sequence-model family).
        key: JAX PRNG key, split into ``cfg.num_train_envs`` sub-keys for the
            vmapped algorithms. Unused for the sequence-model family.

    Returns:
        ``(algo, None)`` for ``"ad"``/``"mem"``/``"xl"``/``"ed"``, otherwise
        ``(wrapped_algo, algo_state)`` where ``wrapped_algo`` is an
        ``AlgoVmapWrapper`` around the algorithm.

    Raises:
        ValueError: If ``name`` is not a supported identifier.
    """
    # The sequence-model variants all share the AD class and take a different
    # construction path below, so decide membership once.
    is_sequence_model = name in ("ad", "mem", "xl", "ed")

    if is_sequence_model:
        algo_cls = AD
    elif name == "q_learning":
        algo_cls = QLearning
    elif name == "q_learning_m":
        algo_cls = QLearningM
    elif name == "sac":
        algo_cls = SAC
    else:
        raise ValueError(f"unsupported algo: {name}")

    # Build the config from the same identifier used to pick the class;
    # previously this used ``cfg.algo.name``, which could pair a class with
    # another algorithm's config whenever the two names differed.
    algo_cfg = get_algo_cfg(name, cfg, dcfg)

    if is_sequence_model:
        algo = algo_cls(algo_cfg, dcfg, cfg)
        # Move the underlying model onto the configured device.
        algo.model = algo.model.to(cfg.device)
        return algo, None

    algo = algo_cls(algo_cfg)

    in_tree, out_tree = extract_in_out_trees(algo.gen_tree())
    # NOTE(review): the tuple/int values are presumably per-method vmap axis
    # specs -- confirm against AlgoVmapWrapper.
    algo = AlgoVmapWrapper(
        algo,
        dict(
            make_action=(0, in_tree, 0),
            update=0,
        ),
        out_tree,
        nums=cfg.num_train_envs,
    )
    # One PRNG sub-key per training environment.
    env_keys = jrnd.split(key, cfg.num_train_envs)
    algo_state = algo.make(jnp.array(env_keys))

    return algo, algo_state


def get_algo_refill_cfg(name: str):
    """Return the ``refill_cfg`` object registered for the algorithm ``name``.

    ``"ad"`` maps to the AD module's ``refill_cfg``; ``"xl"``, ``"ed"`` and
    ``"mem"`` all map to the one imported from the XL module.

    Raises:
        ValueError: If ``name`` is not a supported identifier.
    """
    if name == "q_learning":
        return QLearning_refill_cfg
    if name == "q_learning_m":
        return QLearning_m_refill_cfg
    if name == "sac":
        return SAC_refill_cfg
    if name == "ad":
        return AD_refill_cfg
    if name in ("xl", "ed", "mem"):
        return XL_refill_cfg
    raise ValueError(f"unsupported algo: {name}")
