import logging

from src.rllib.agents.trainer import with_common_config
from src.rllib.agents.trainer_template import build_trainer
from src.rllib.contrib.bandits.agents.policy import BanditPolicy

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
UCB_CONFIG = with_common_config({
    # Run without remote workers by default.
    "num_workers": 0,
    # Only PyTorch is supported so far.
    "framework": "torch",

    # Learn online, one step at a time.
    "rollout_fragment_length": 1,
    "train_batch_size": 1,

    # Bandits can't afford one timestep per iteration: the metrics
    # collection overhead makes that extremely slow. With this setting the
    # agent is trained 100 times in one RLlib iteration.
    "timesteps_per_iteration": 100,

    "exploration_config": {
        "type": "src.rllib.contrib.bandits.exploration.UCB",
    },
})
# __sphinx_doc_end__
# yapf: enable

# Trainer class produced by the trainer template: registered under the name
# "LinUCB", defaulting to UCB_CONFIG and using BanditPolicy as its policy.
LinUCBTrainer = build_trainer(
    name="LinUCB",
    default_config=UCB_CONFIG,
    default_policy=BanditPolicy,
)
