from typing import Callable

from medium_rl.alg.ppo import make_ppo_loss_fn, make_ppo_train_step_fn
from medium_rl.alg.sac import make_sac_loss_fn, make_sac_train_step_fn
from medium_rl.alg.tgm import make_tgm_loss_fn, make_tgm_train_step_fn
from medium_rl.config import Config
from medium_rl.envs.sequence_env import SequenceEnv


def make_train_step_fn(env: SequenceEnv, buffer, forward: Callable, policy_fn: Callable, optimizer, cfg: Config):
    """Build the train-step function for the algorithm named by ``cfg.alg.name``.

    The algorithm's loss function is created first from ``cfg`` and then
    handed, together with every other argument, to the matching
    ``make_*_train_step_fn`` factory.

    Args:
        env: Environment forwarded unchanged to the train-step factory.
        buffer: Buffer object forwarded unchanged to the train-step factory.
        forward: Callable forwarded unchanged to the train-step factory.
        policy_fn: Callable forwarded unchanged to the train-step factory.
        optimizer: Optimizer forwarded unchanged to the train-step factory.
        cfg: Configuration; ``cfg.alg.name`` selects the algorithm and the
            whole object is passed to both the loss and train-step factories.

    Returns:
        Whatever the selected ``make_*_train_step_fn`` factory returns.

    Raises:
        ValueError: If ``cfg.alg.name`` is not one of ``"TGM"``, ``"SAC"``
            or ``"PPO"`` (the comparison is case-sensitive).
    """
    alg_name = cfg.alg.name

    if alg_name == "TGM":
        make_loss_fn, make_step_fn = make_tgm_loss_fn, make_tgm_train_step_fn
    elif alg_name == "SAC":
        make_loss_fn, make_step_fn = make_sac_loss_fn, make_sac_train_step_fn
    elif alg_name == "PPO":
        make_loss_fn, make_step_fn = make_ppo_loss_fn, make_ppo_train_step_fn
    else:
        # Fail loudly here instead of silently returning None, which would
        # only surface later when the caller tries to invoke the result.
        raise ValueError(
            f"Unknown algorithm {alg_name!r}; expected one of 'TGM', 'SAC', 'PPO'."
        )

    loss_fn = make_loss_fn(cfg)
    return make_step_fn(env, buffer, forward, policy_fn, loss_fn, optimizer, cfg)
