from torch import Tensor
from torch._prims_common import DeviceLikeType

from args import NPGConfig, PPOConfig, QLearningConfig
from mdp.mdp_controller import (
    MDPImageQLearningFAController,
    MDPImageRandController,
    MDPNPGController,
    MDPOptimalController,
    MDPQLearningController,
    MDPTransformerController,
    PPOController,
)
from mdp.mdp_env import MDPController


def get_mdp_algs(
    n_envs: int,
    n_steps: int,
    n_steps_eval: int,
    n_states: int,
    state_dim: int,
    n_actions: int,
    optimal_actions: Tensor | None,
    dpt_policy: MDPTransformerController | None,
    dpt_frozen_policy: MDPTransformerController | None,
    *,
    device: DeviceLikeType | None = None
) -> dict[str, MDPController]:
    """Assemble the name -> controller mapping of MDP algorithms.

    The optional entries ("opt", "dpt", "dpt_frozen") come first, in that
    order, and are only present when their corresponding argument is not
    ``None``. The remaining entries ("img_rand", "npg", "ppo", "ql", "qlfa")
    are always created, each from a freshly default-constructed config where
    one is needed.

    Args:
        n_envs: Forwarded to every controller built here.
        n_steps: Forwarded to every controller built here except "opt".
        n_steps_eval: Forwarded to the "opt" controller in place of
            ``n_steps``.
        n_states: Forwarded to every controller built here except "qlfa".
        state_dim: Forwarded to every controller built here except "qlfa".
        n_actions: Forwarded to every controller built here.
        optimal_actions: When not ``None``, an ``MDPOptimalController`` is
            built from it and stored under "opt".
        dpt_policy: When not ``None``, stored as-is under "dpt".
        dpt_frozen_policy: When not ``None``, stored as-is under
            "dpt_frozen".
        device: Passed as ``device=`` to every controller built here except
            "opt".

    Returns:
        A dict from algorithm name to controller, in the insertion order
        described above.
    """
    # All configs are default-constructed up front; "ql" and "qlfa" each get
    # their own QLearningConfig instance rather than sharing one.
    npg_cfg = NPGConfig()
    ppo_cfg = PPOConfig()
    ql_cfg = QLearningConfig()
    qlfa_cfg = QLearningConfig()

    controllers: dict[str, MDPController] = {}

    # Optional entries: skipped entirely when their source argument is None.
    if optimal_actions is not None:
        # Note: uses n_steps_eval (not n_steps) and receives no device.
        controllers["opt"] = MDPOptimalController(
            optimal_actions, n_envs, n_steps_eval, n_states, state_dim, n_actions
        )
    if dpt_policy is not None:
        controllers["dpt"] = dpt_policy
    if dpt_frozen_policy is not None:
        controllers["dpt_frozen"] = dpt_frozen_policy

    # Always-present entries.
    # NOTE(review): img_rand receives state_dim *before* n_states, unlike the
    # other controllers — presumably matches its constructor; verify against
    # MDPImageRandController's signature.
    controllers["img_rand"] = MDPImageRandController(
        n_envs, n_steps, state_dim, n_states, n_actions, device=device
    )
    controllers["npg"] = MDPNPGController(
        npg_cfg, n_envs, n_steps, n_states, state_dim, n_actions, sample=True, device=device
    )
    controllers["ppo"] = PPOController(
        ppo_cfg, n_envs, n_steps, n_states, state_dim, n_actions, device=device
    )
    controllers["ql"] = MDPQLearningController(
        ql_cfg, n_envs, n_steps, n_states, state_dim, n_actions, device=device
    )
    # The function-approximation variant is built without n_states/state_dim.
    controllers["qlfa"] = MDPImageQLearningFAController(
        qlfa_cfg, n_envs, n_steps, n_actions, device=device
    )
    return controllers
