from ray.rllib.agents.dqn.r2d2 import R2D2Trainer, R2D2_DEFAULT_CONFIG
from ray.rllib.agents.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_DEFAULT_CONFIG
from ray.rllib.agents.dqn.r2d2_tf_policy import R2D2TFPolicy
from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict



def return_r2d2_workflow(args, PPOTorchCustomPolicy):
    """Build an ``R2D2Trainer`` subclass whose default config has concept keys.

    The stock ``R2D2_DEFAULT_CONFIG`` is merged with a set of additional
    concept-related keys, and a trainer subclass is returned whose
    ``get_default_config`` hands back that merged dictionary.

    Args:
        args: Object exposing an ``include_concepts`` attribute; its value is
            stored under the ``"include_concepts"`` config key.
        PPOTorchCustomPolicy: Not used by this function. The parameter is
            kept so that existing callers passing it keep working.

    Returns:
        The ``CustomR2D2Trainer`` class (not an instance).
    """
    # The extra keys are not part of the stock R2D2 defaults, so the merge is
    # asked to accept unknown keys instead of rejecting them.
    merged_config = Trainer.merge_trainer_configs(
        R2D2_DEFAULT_CONFIG,  # Base R2D2 defaults; extended by the keys below.
        {
            # Concept settings
            "include_concepts": args.include_concepts,
            "use_balanced": False,
            "balanced_beta": -1,
            "balanced_gamma": -1,
            # Config of concepts; no default is provided here.
            "concept_configs": None,
            "concept_loss_coeff": 0.01,
            "loss_type": "focal",
        },
        _allow_unknown_configs=True,
    )

    class CustomR2D2Trainer(R2D2Trainer):
        """R2D2 trainer whose default config is the merged concept config."""

        @classmethod
        @override(DQNTrainer)
        def get_default_config(cls) -> TrainerConfigDict:
            """Return the merged config built by the enclosing factory call."""
            return merged_config

    return CustomR2D2Trainer

