"""Creates a Trainer class that runs a random policy.

This exists to allow for collecting data sampled randomly from an action space
using the rllib infrastructure / ensuring the same format as usual.
"""

import logging
from typing import Optional, Type

from ray.rllib.examples.policy.random_policy import RandomPolicy

from ray.rllib.agents import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches, SelectExperiences
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)

# Default configuration for the Random trainer: RLlib's common trainer config
# (via `with_common_config`) combined with the overrides listed below.
DEFAULT_CONFIG = with_common_config(
    {
        # Size of the sample batches collected from each rollout worker.
        "rollout_fragment_length": 200,
    }
)


# pylint: disable=unused-argument
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """Select the policy class used by the Random trainer.

    The result does not depend on ``config``; this trainer always acts with
    ``RandomPolicy``.

    Args:
        config: The trainer config (ignored).

    Returns:
        The ``RandomPolicy`` class.
    """
    chosen_cls = RandomPolicy
    return chosen_cls


def execution_plan(workers: WorkerSet, config: TrainerConfigDict) -> LocalIterator[dict]:
    """Build the data-collection iterator for the Random trainer.

    Samples are gathered from the rollout workers in ``bulk_sync`` mode and
    restricted to the trainable policies. They are then concatenated into
    batches of at least ``config["train_batch_size"]`` steps, counted as
    ``config["multiagent"]["count_steps_by"]`` specifies. Finally they are
    wrapped with standard metrics reporting.

    No learning op is part of this plan; it only collects samples and
    reports metrics.

    Args:
        workers: The set of rollout workers to sample from.
        config: The trainer config.

    Returns:
        The iterator produced by ``StandardMetricsReporting``.
    """
    # Gather samples from every rollout worker synchronously ("bulk_sync").
    sampled = ParallelRollouts(workers, mode="bulk_sync")

    # Keep only the experiences belonging to trainable policies.
    trainable_only = sampled.for_each(SelectExperiences(workers.trainable_policies()))

    # Merge the per-worker SampleBatches until the minimum batch size is met.
    concat_op = ConcatBatches(
        min_batch_size=config["train_batch_size"],
        count_steps_by=config["multiagent"]["count_steps_by"],
    )
    train_batches = trainable_only.combine(concat_op)

    return StandardMetricsReporting(train_batches, workers, config)


# Trainer class assembled from the definitions above. Both `default_policy`
# and `get_policy_class` resolve to RandomPolicy. The execution plan only
# collects samples and reports metrics.
RandomTrainer = build_trainer(
    name="Random",
    default_policy=RandomPolicy,
    get_policy_class=get_policy_class,
    default_config=DEFAULT_CONFIG,
    execution_plan=execution_plan,
)
