"""Utils for minibatch SGD across multiple RLlib policies."""

import numpy as np
import logging
from collections import defaultdict
import random

from src.rllib.evaluation.metrics import LEARNER_STATS_KEY
from src.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, \
    MultiAgentBatch

logger = logging.getLogger(__name__)


def averaged(kv, axis=None):
    """Collapse every value list of a dictionary into a single entry.

    A list whose first element is ``None`` or a dict is not averaged;
    its first element is returned unchanged. Every other list is reduced
    with ``np.mean``.

    Args:
        kv (dict): dictionary mapping keys to non-empty lists of values.
        axis (int or None): forwarded to ``np.mean``. The default ``None``
            averages over all elements of the list.

    Returns:
        dictionary with the same keys, each mapped to its reduced value.
    """

    def _reduce(values):
        head = values[0]
        # Only the first element decides whether the list is averaged.
        if head is None or isinstance(head, dict):
            return head
        return np.mean(values, axis=axis)

    return {key: _reduce(values) for key, values in kv.items()}


def standardized(array):
    """Center an array on its mean and scale it by its standard deviation.

    Args:
        array (np.ndarray): Array of values to normalize.

    Returns:
        the centered array divided by its standard deviation, where the
        divisor is never allowed to be smaller than 1e-4.
    """
    centered = array - array.mean()
    # Floor the divisor so (near-)constant input is not divided by ~0.
    scale = max(1e-4, array.std())
    return centered / scale


def minibatches(samples, sgd_minibatch_size, shuffle=True):
    """Generate minibatches cut out of a sample batch.

    When `sgd_minibatch_size` is falsy (e.g. 0 or None) the batch is
    yielded once, unsplit. Otherwise the slice boundaries are taken from
    `samples._get_slice_indices()`.

    Note that a batch containing neither "state_in_0" nor "state_out_0"
    has its `shuffle()` method called before slicing, independent of the
    `shuffle` argument.

    Args:
        samples (SampleBatch): batch of samples to split up.
        sgd_minibatch_size (int): size of minibatches to return.
        shuffle (bool): whether to randomize the order in which the
            slices are yielded.

    Yields:
        mini-SampleBatches produced by `samples.slice()`.

    Raises:
        NotImplementedError: if `sgd_minibatch_size` is truthy and
            `samples` is a MultiAgentBatch.
    """
    if not sgd_minibatch_size:
        yield samples
        return

    if isinstance(samples, MultiAgentBatch):
        raise NotImplementedError(
            "Minibatching not implemented for multi-agent in simple mode")

    has_state = "state_in_0" in samples or "state_out_0" in samples
    if not has_state:
        samples.shuffle()

    data_slices, state_slices = samples._get_slice_indices(
        sgd_minibatch_size)

    if len(state_slices) == 0:
        if shuffle:
            random.shuffle(data_slices)
        for start, stop in data_slices:
            yield samples.slice(start, stop)
        return

    # Pair each data slice with its state slice before shuffling so the
    # two stay linked together.
    paired = list(zip(data_slices, state_slices))
    if shuffle:
        random.shuffle(paired)
    for (start, stop), (state_start, state_stop) in paired:
        yield samples.slice(start, stop, state_start, state_stop)


def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
                     sgd_minibatch_size, standardize_fields):
    """Run minibatch SGD for every policy that has data in `samples`.

    A plain SampleBatch is first wrapped into a MultiAgentBatch under
    DEFAULT_POLICY_ID. Policies without an entry in
    `samples.policy_batches` are skipped and do not appear in the result.

    Args:
        samples (SampleBatch|MultiAgentBatch): Batch of samples to
            optimize.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance; its
            `learn_on_batch()` is called once per minibatch.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for
            optimization.
        standardize_fields (list): List of sample field names that should
            be normalized prior to optimization. The normalized values
            are written back into the policy's batch.

    Returns:
        defaultdict mapping each optimized policy id to a dict with the
        keys LEARNER_STATS_KEY, "model" and "custom_metrics". Each entry
        holds the stats reported by `learn_on_batch()`, averaged over all
        minibatches of all SGD epochs.
    """
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)

    stat_keys = (LEARNER_STATS_KEY, "model", "custom_metrics")

    fetches = defaultdict(dict)
    for policy_id in policies.keys():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in standardize_fields:
            batch[field] = standardized(batch[field])

        # One accumulator per reported stats group. They are created once
        # per policy, so values from every epoch end up in the average.
        accumulators = [(key, defaultdict(list)) for key in stat_keys]

        for _ in range(num_sgd_iter):
            for minibatch in minibatches(batch, sgd_minibatch_size):
                wrapped = MultiAgentBatch({
                    policy_id: minibatch
                }, minibatch.count)
                results = local_worker.learn_on_batch(wrapped)[policy_id]
                for key, bucket in accumulators:
                    for name, value in results.get(key, {}).items():
                        bucket[name].append(value)

        for key, bucket in accumulators:
            fetches[policy_id][key] = averaged(bucket)
    return fetches
