"""
SlateQ (Reinforcement Learning for Recommendation)
==================================================

This file defines the trainer class for the SlateQ algorithm from the
`"Reinforcement Learning for Slate-based Recommender Systems: A Tractable
Decomposition and Practical Methodology" <https://arxiv.org/abs/1905.12767>`_
paper.

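See `src/rllib_extensions/slateq_policy.py` (`SlateQTorchPolicy`) for the
definition of the policy. Currently, only PyTorch is supported. The algorithm
is written and tested for Google's RecSim environment
(https://github.com/google-research/recsim). Optionally, setting
`dice_config` adds an `ImitationModule` step, applied to replayed batches,
alongside training.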
"""

import logging
from typing import List, Type

from src.rllib_extensions.slateq_policy import SlateQTorchPolicy
from src.rllib.agents.trainer import with_common_config
from src.rllib.agents.trainer_template import build_trainer
from src.rllib.evaluation.worker_set import WorkerSet
from src.rllib.examples.policy.random_policy import RandomPolicy
from src.rllib.execution.concurrency_ops import Concurrently
from src.rllib.execution.metric_ops import StandardMetricsReporting
from src.rllib.execution.replay_buffer import LocalReplayBuffer
from src.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from src.rllib.execution.rollout_ops import ParallelRollouts
from src.rllib.execution.train_ops import TrainOneStep
from src.rllib.policy.policy import Policy
from src.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

from src.rllib_extensions.imitation_module import ImitationModule

logger = logging.getLogger(__name__)

# Every SlateQ strategy this trainer knows about; `validate_config` rejects
# any `slateq_strategy` value that is not listed here.
ALL_SLATEQ_STRATEGIES = [
    # Fill slates with randomly selected documents (no training is run).
    "RANDOM",
    # Myopic: pick documents maximizing the user's click probability while
    # ignoring long-term reward, i.e. a zero discount on future rewards.
    "MYOP",
    # SlateQ trained with the SARSA learning algorithm.
    "SARSA",
    # SlateQ trained with the Q-learning algorithm.
    "QL",
]

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Model ===
    # Sizes of the hidden dense layers of the model.
    # NOTE(review): the previous wording described DQN-style dueling
    # advantage/value branches; confirm how SlateQTorchPolicy consumes this.
    "hiddens": [256, 64, 16],

    # Collect whole episodes per rollout batch. validate_config requires this
    # value when slateq_strategy is "SARSA".
    "batch_mode": "complete_episodes",

    # === Deep Learning Framework Settings ===
    # Currently, only PyTorch is supported; validate_config rejects any other
    # framework.
    "framework": "torch",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        "epsilon_timesteps": 10000,  # Timesteps over which to anneal epsilon.
    },
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },

    # Minimum env steps to optimize for per train call. This value does
    # not affect learning, only the length of iterations.
    "timesteps_per_iteration": 1000,

    # === Replay buffer ===
    # Size of the replay buffer. execution_plan builds a single local replay
    # buffer (num_shards=1) of this size.
    "buffer_size": 50000,
    # The number of contiguous environment steps to replay at once. This may
    # be set to greater than 1 to support recurrent models.
    "replay_sequence_length": 1,
    # Whether to LZ4 compress observations.
    # NOTE(review): not read by this file's execution_plan; confirm where it
    # takes effect.
    "compress_observations": False,
    # If set, fixes the ratio
    #   (timesteps replayed from the buffer and trained on) /
    #   (timesteps sampled from the env and stored in the buffer).
    # If None, replay proceeds at the native ratio
    # train_batch_size / rollout_fragment_length.
    # calculate_round_robin_weights converts this setting into round-robin
    # weights for the store and train ops.
    "training_intensity": None,

    # === Optimization ===
    # Learning rate for adam optimizer for the user choice model
    "lr_choice_model": 1e-2,
    # Learning rate for adam optimizer for the q model
    "lr_q_model": 1e-2,
    # Adam epsilon hyper parameter
    "adam_epsilon": 1e-8,
    # If not None, clip gradients during optimization at this value
    "grad_clip": 40,
    # How many sampled steps must be stored in the replay buffer before
    # learning starts (passed to LocalReplayBuffer as learning_starts).
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 1000,
    # Size of a batch sampled from the replay buffer for training (passed to
    # LocalReplayBuffer as replay_batch_size).
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    # NOTE(review): not read by this file's execution_plan (the replay buffer
    # is built without prioritization arguments); confirm it has any effect.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span
    "min_iter_time_s": 1,

    # === SlateQ specific options ===
    # Learning method used by the SlateQ policy. Must be one of
    # ALL_SLATEQ_STRATEGIES: RANDOM, MYOP (myopic), SARSA, QL (Q-Learning).
    # RANDOM uses RandomPolicy and skips training.
    "slateq_strategy": "QL",
    # user/doc embedding size for the recsim environment
    "recsim_embedding_size": 10,
    # Optional config passed to ImitationModule. When not None,
    # execution_plan adds an op that applies ImitationModule to batches
    # replayed from the buffer.
    "dice_config": None,
    # NOTE(review): not read in this file; presumably a Weights & Biases
    # logger handle consumed elsewhere — confirm.
    "wandb_logger": None
})
# __sphinx_doc_end__
# yapf: enable


def validate_config(config: TrainerConfigDict) -> None:
    """Validate a SlateQ trainer config, raising on unsupported settings.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Raises:
        ValueError: If more than one GPU is requested, the framework is not
            PyTorch, ``slateq_strategy`` is not in ``ALL_SLATEQ_STRATEGIES``,
            or the SARSA strategy is combined with a ``batch_mode`` other
            than ``"complete_episodes"``.
    """
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for SlateQ!")
    if config["framework"] != "torch":
        raise ValueError("SlateQ only runs on PyTorch")

    strategy = config["slateq_strategy"]
    if strategy not in ALL_SLATEQ_STRATEGIES:
        raise ValueError(f"Unknown slateq_strategy: {strategy}.")

    # batch_mode is only consulted for SARSA (short-circuit keeps it so).
    sarsa_without_episodes = (
        strategy == "SARSA" and config["batch_mode"] != "complete_episodes")
    if sarsa_without_episodes:
        raise ValueError(
            "For SARSA strategy, batch_mode must be 'complete_episodes'")


def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the SlateQ algorithm. Defines the distributed dataflow.

    Rollouts are stored in a local replay buffer. Batches replayed from that
    buffer are trained on with ``TrainOneStep``. When
    ``config["dice_config"]`` is not None, an additional op applies an
    ``ImitationModule`` to batches replayed from the same buffer. The ops are
    interleaved round-robin with weights from
    ``calculate_round_robin_weights``, so ``training_intensity`` controls how
    often the replay-based ops run relative to the store op. For the RANDOM
    strategy no training happens and only rollouts are reported.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        replay_sequence_length=config["replay_sequence_length"],
    )

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # (1) Generate rollouts and store them in our local replay buffer. Calling
    # next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    # (2) Read and train on experiences from the replay buffer. Every batch
    # returned from the Replay() iterator is passed to TrainOneStep to take a
    # SGD step; optionally an ImitationModule op also consumes replayed
    # batches.
    replay_buffer = Replay(local_buffer=local_replay_buffer)

    # [store_weight, train_weight]; [1, 1] unless training_intensity is set.
    # Previously these weights were hard-coded to 1, which silently ignored
    # the documented `training_intensity` option.
    store_weight, train_weight = calculate_round_robin_weights(config)

    ops = [store_op]
    weights = [store_weight]
    if config["dice_config"] is not None:
        ops.append(
            replay_buffer.for_each(ImitationModule(config["dice_config"])))
        # Run imitation at the same rate as training (1:1, as before).
        weights.append(train_weight)
    ops.append(replay_buffer.for_each(TrainOneStep(workers)))
    weights.append(train_weight)

    if config["slateq_strategy"] == "RANDOM":
        # No training is needed for the RANDOM strategy.
        train_op = rollouts
    else:
        # Interleave the ops round-robin. Only return the output of the
        # TrainOneStep op (always last in `ops`), since training metrics are
        # not available until it runs.
        train_op = Concurrently(
            ops,
            mode="round_robin",
            output_indexes=[len(ops) - 1],
            round_robin_weights=weights)

    return StandardMetricsReporting(train_op, workers, config)


def calculate_round_robin_weights(config: TrainerConfigDict) -> List[float]:
    """Return the round-robin weights ``[store_weight, train_weight]``.

    Without a (truthy) ``training_intensity``, rollout storing and training
    alternate one-for-one. Otherwise the train weight is scaled so that
    (steps replayed / steps sampled) matches ``training_intensity``.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict; reads
            ``training_intensity``, ``train_batch_size`` and
            ``rollout_fragment_length``.

    Returns:
        List[float]: Two weights, for the store op and the train op.
    """
    intensity = config["training_intensity"]
    if not intensity:
        return [1, 1]

    # Replay ratio a single store step + single train step would give,
    # e.g. 32 / 4 -> 8.0.
    native_ratio = (
        config["train_batch_size"] / config["rollout_fragment_length"])
    # training_intensity is expressed as steps_replayed / steps_sampled, so
    # divide out the native ratio to obtain the train weight.
    return [1, intensity / native_ratio]


def get_policy_class(config: TrainerConfigDict) -> Type[Policy]:
    """Pick the Policy class for SlateQTrainer from ``slateq_strategy``.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Type[Policy]: ``RandomPolicy`` for the "RANDOM" strategy, otherwise
            ``SlateQTorchPolicy``.
    """
    if config["slateq_strategy"] != "RANDOM":
        return SlateQTorchPolicy
    return RandomPolicy


# The SlateQ trainer class, assembled from this module's default config,
# config validator, policy picker and execution plan.
SlateQTrainer = build_trainer(
    name="SlateQ",
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config,
    get_policy_class=get_policy_class,
    execution_plan=execution_plan,
)
