"""Utils for minibatch SGD across multiple RLlib policies."""

import logging
import numpy as np
import random

from policy.sample_batch import SampleBatch, MultiAgentBatch
from utils.metrics.learner_info import LearnerInfoBuilder

logger = logging.getLogger(__name__)


def standardized(array: np.ndarray):
    """Center an array on its mean and scale it by its standard deviation.

    The mean and standard deviation are computed over all elements of
    `array` (no axis is given). The divisor is floored at 1e-4, so an array
    whose standard deviation is below that floor is only centered and
    divided by 1e-4 rather than blown up by a near-zero divisor.

    Args:
        array (np.ndarray): Array of values to normalize.

    Returns:
        New array with zero mean and, unless the 1e-4 floor applies, unit
        standard deviation. The input array is not modified.
    """
    centered = array - array.mean()
    spread = array.std()
    # Keep 1e-4 as the first argument: it is the lower bound of the divisor.
    divisor = max(1e-4, spread)
    return centered / divisor


def minibatches(samples: SampleBatch, sgd_minibatch_size: int, shuffle: bool = True):
    """Generate minibatch slices of a sample batch.

    Args:
        samples: SampleBatch to split up.
        sgd_minibatch_size: Size passed to the batch's slice-index
            computation. A falsy value (e.g. 0 or None) disables
            minibatching: `samples` itself is yielded once, unchanged.
        shuffle: Whether to randomize the order in which the slices are
            yielded. Independently of this flag, a batch that contains
            neither "state_in_0" nor "state_out_0" is shuffled in place
            via `samples.shuffle()` before it is sliced.

    Yields:
        Either `samples` itself (minibatching disabled) or the results of
        `samples.slice(...)` for each computed index range.

    Raises:
        NotImplementedError: If minibatching is requested for a
            MultiAgentBatch. As this is a generator, the error surfaces on
            the first iteration rather than at call time.
    """
    # No minibatch size given -> hand back the whole batch as-is.
    if not sgd_minibatch_size:
        yield samples
        return

    if isinstance(samples, MultiAgentBatch):
        raise NotImplementedError(
            "Minibatching not implemented for multi-agent in simple mode"
        )

    # Batches without any state columns get shuffled in place up front.
    carries_state = "state_in_0" in samples or "state_out_0" in samples
    if not carries_state:
        samples.shuffle()

    data_slices, state_slices = samples._get_slice_indices(sgd_minibatch_size)

    has_state_slices = len(state_slices) != 0
    if has_state_slices:
        # Pair every data range with its state range so that shuffling
        # keeps the two linked together.
        work_items = list(zip(data_slices, state_slices))
    else:
        work_items = data_slices

    if shuffle:
        random.shuffle(work_items)

    for item in work_items:
        if has_state_slices:
            (start, stop), (state_start, state_stop) = item
            yield samples.slice(start, stop, state_start, state_stop)
        else:
            start, stop = item
            yield samples.slice(start, stop)


def do_minibatch_sgd(
    samples,
    policies,
    local_worker,
    num_sgd_iter,
    sgd_minibatch_size,
    standardize_fields,
):
    """Execute minibatch SGD.

    For every policy in `policies` that has a batch in `samples`, the
    requested fields are standardized, and the batch is then passed
    `num_sgd_iter` times through `minibatches()`, calling
    `local_worker.learn_on_batch()` on each minibatch.

    Note: the standardized fields are written back into the per-policy
    batch object itself (`batch[field] = ...`).

    Args:
        samples (SampleBatch): Batch of samples to optimize.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for
            optimization. A falsy value (0 or None) means "no
            minibatching": each epoch trains on the whole policy batch.
        standardize_fields (list): List of sample field names that should be
            normalized prior to optimization.

    Returns:
        The learner info produced by `LearnerInfoBuilder.finalize()` from
        the per-minibatch `learn_on_batch` results that were collected.

    Raises:
        ValueError: If a recurrent policy is given a (non-zero)
            `sgd_minibatch_size` smaller than its model's `max_seq_len`.
    """

    # Handle everything as if multi-agent.
    samples = samples.as_multi_agent()

    # Use LearnerInfoBuilder as a unified way to build the final
    # results dict from `learn_on_loaded_batch` call(s).
    # This makes sure results dicts always have the same structure
    # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
    # tf vs torch).
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    for policy_id, policy in policies.items():
        if policy_id not in samples.policy_batches:
            continue

        # Check to make sure that the sgd_minibatch_size is not smaller
        # than max_seq_len otherwise this will cause indexing errors while
        # performing sgd when using a RNN or Attention model.
        # A falsy size is skipped: `minibatches()` then yields the whole
        # batch without slicing, so there is nothing to validate.
        # The check runs before standardization so that a rejected call
        # does not leave this policy's batch modified.
        if sgd_minibatch_size and policy.is_recurrent():
            max_seq_len = policy.config["model"]["max_seq_len"]
            if max_seq_len > sgd_minibatch_size:
                raise ValueError(
                    "`sgd_minibatch_size` ({}) cannot be smaller than "
                    "`max_seq_len` ({}).".format(sgd_minibatch_size, max_seq_len)
                )

        batch = samples.policy_batches[policy_id]
        for field in standardize_fields:
            batch[field] = standardized(batch[field])

        for _ in range(num_sgd_iter):
            for minibatch in minibatches(batch, sgd_minibatch_size):
                # Wrap the single-policy minibatch so the worker only
                # updates this policy, then pick out its results.
                results = (
                    local_worker.learn_on_batch(
                        MultiAgentBatch({policy_id: minibatch}, minibatch.count)
                    )
                )[policy_id]
                learner_info_builder.add_learn_on_batch_results(results, policy_id)

    learner_info = learner_info_builder.finalize()
    return learner_info
