from trainer.base_hparams import common_hparams, OffPolicyRLHyperParameterSpace, OnPolicyRLHyperParameterSpace


class DQNHyperParams(OffPolicyRLHyperParameterSpace):
    """Hyperparameter space for the "dqn" and "ddqn" algorithms.

    Entries read from ``common_hparams`` reuse the shared definition for
    that key; every other entry is set directly in this class.
    """

    _supported_algos = ("dqn", "ddqn")
    _policy_hparams = dict(
        lr=common_hparams["lr"],  # learning rate
        n_step=common_hparams["n_step"],
        target_update_freq=common_hparams["target_update_freq"],
        # Both settings are listed here; presumably the base class treats a
        # list as a set of candidate values -- confirm in base_hparams.
        is_double=[False, True],
        use_dueling=False,
        eps_test=common_hparams["eps_test"],
        eps_train=common_hparams["eps_train"],
        eps_train_final=common_hparams["eps_train_final"],
    )


class PPOHyperParams(OnPolicyRLHyperParameterSpace):
    """Hyperparameter space for the "ppo" algorithm.

    ``lr``, ``n_step`` and ``start_timesteps`` reuse the shared definitions
    in ``common_hparams``; the remaining entries are single values set
    directly in this class.
    """

    _supported_algos = ("ppo",)
    _policy_hparams = dict(
        lr=common_hparams["lr"],  # learning rate
        n_step=common_hparams["n_step"],
        start_timesteps=common_hparams["start_timesteps"],
        gae_lambda=0.95,
        vf_coef=0.5,
        ent_coef=0.001,
        eps_clip=0.1,
        value_clip=False,
        dual_clip=None,
        advantage_normalization=True,
        recompute_advantage=False,
    )


class SACHyperParams(OffPolicyRLHyperParameterSpace):
    """Hyperparameter space for the "sac" algorithm.

    All entries except ``alpha`` reuse the shared definitions in
    ``common_hparams``.
    """

    _supported_algos = ("sac",)
    _policy_hparams = dict(
        lr=common_hparams["lr"],  # learning rate
        n_step=common_hparams["n_step"],
        # NOTE(review): alpha is pinned to 0. If the consumer uses it as an
        # entropy coefficient this disables that term -- confirm intended.
        alpha=0,
        tau=common_hparams["tau"],
        start_timesteps=common_hparams["start_timesteps"],
    )


class LLM_HyperParams(OffPolicyRLHyperParameterSpace):
    """Hyperparameter space for the "llm" algorithm.

    Defines its own ``_general_hparams`` containing only ``seed``, and a
    policy space describing prompt options and the model to use.
    """

    _supported_algos = ("llm",)

    # general parameters
    _general_hparams = dict(
        seed=common_hparams["seed"],
    )

    # policy hyperparameter search space
    _policy_hparams = dict(
        # Both settings are listed here; presumably the base class treats a
        # list as a set of candidate values -- confirm in base_hparams.
        need_summary=[True, False],
        need_meta_info=True,
        num_try=2,
        # Model name together with its context window size.
        llm_mode=dict(
            llm="Qwen2-7B-Instruct",
            context_window=32768,
        ),
    )


class LLM_Instruct_HyperParams(OffPolicyRLHyperParameterSpace):
    """Hyperparameter space for the "llm-instruct" algorithm.

    Defines its own ``_general_hparams`` and a policy space that configures
    two models: an actor model and an instruct model.

    NOTE(review): this class derives from the off-policy base yet reads
    ``onpolicy_step_per_collect`` and sets ``repeat_per_collect`` -- confirm
    that the off-policy base is the intended parent.
    """

    _supported_algos = ("llm-instruct",)

    # general parameters
    _general_hparams = dict(
        seed=common_hparams["seed"],
        batch_size=common_hparams["llm_batch_size"],
        n_step=common_hparams["n_step"],
        start_timesteps=common_hparams["start_timesteps"],
        step_per_collect=common_hparams["onpolicy_step_per_collect"],
        repeat_per_collect=20,
    )

    # policy hyperparameter search space
    _policy_hparams = dict(
        lr=1e-4,
        gae_lambda=0.95,
        ent_coef=0.001,
        kl_coef=0.005,
        eps_clip=0.1,
        group_num=5,
        need_meta_info=True,
        num_try=2,
        # Each *_llm_mode entry pairs a model name with its context window.
        actor_llm_mode=dict(
            llm="Qwen2-7B-Instruct",
            context_window=32768,
        ),
        instruct_llm_mode=dict(
            llm="Qwen2-0.5B-Instruct",
            context_window=32768,
        ),
    )