import logging
import numpy as np

from src.rllib.utils.sgd import standardized
from src.rllib.agents import with_common_config
from src.rllib.agents.maml.maml_tf_policy import MAMLTFPolicy
from src.rllib.agents.maml.maml_torch_policy import MAMLTorchPolicy
from src.rllib.agents.trainer_template import build_trainer
from src.rllib.evaluation.metrics import get_learner_stats
from src.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
    STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.execution.metric_ops import CollectMetrics
from src.rllib.evaluation.metrics import collect_metrics
from src.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.util.iter import from_actors

# Module-level logger; used below to warn when a policy has no learner stats.
logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
# MAML-specific default settings, passed through `with_common_config`.
DEFAULT_CONFIG = with_common_config({
    # If true, use the Generalized Advantage Estimator (GAE)
    # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
    "use_gae": True,
    # GAE(lambda) parameter.
    "lambda": 1.0,
    # Initial coefficient for the KL divergence term.
    "kl_coeff": 0.0005,
    # Size of batches collected from each worker.
    "rollout_fragment_length": 200,
    # Create an actual env on the local worker (worker-idx=0).
    # `validate_config` rejects a value of False for this key.
    "create_env_on_driver": True,
    # Step size of SGD.
    "lr": 1e-3,
    "model": {
        # Share layers for value function.
        "vf_share_layers": False,
    },
    # Coefficient of the value function loss.
    "vf_loss_coeff": 0.5,
    # Coefficient of the entropy regularizer (`validate_config` requires
    # a value >= 0).
    "entropy_coeff": 0.0,
    # PPO clip parameter.
    "clip_param": 0.3,
    # Clip param for the value function. Note that this is sensitive to the
    # scale of the rewards. If your expected V is large, increase this.
    "vf_clip_param": 10.0,
    # If specified, clip the global norm of gradients by this amount.
    "grad_clip": None,
    # Target value for the KL divergence.
    "kl_target": 0.01,
    # Whether to rollout "complete_episodes" or "truncate_episodes".
    # `validate_config` only accepts "complete_episodes" for MAML.
    "batch_mode": "complete_episodes",
    # Which observation filter to apply to the observation.
    "observation_filter": "NoFilter",
    # Number of inner adaptation steps for the MAML algorithm
    # (`validate_config` requires at least 1).
    "inner_adaptation_steps": 1,
    # Number of MAML steps per meta-update iteration (PPO steps), i.e. how
    # many times the local worker learns on the same meta-batch per update.
    "maml_optimizer_steps": 5,
    # Inner adaptation step size.
    "inner_lr": 0.1,
    # Use the meta-env template: when True, tasks are sampled from the local
    # worker's env and one task is assigned to each remote worker.
    "use_meta_env": True,

    # Deprecated keys:
    # Share layers for value function. If you set this to True, it's important
    # to tune vf_loss_coeff.
    # Use config.model.vf_share_layers instead.
    "vf_share_layers": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# yapf: enable


# @mluo: TODO
def set_worker_tasks(workers, use_meta_env):
    """Sample one task per remote worker and assign it to that worker's envs.

    Tasks are sampled from the first env of the local worker via
    ``sample_tasks(n)``, where ``n`` is the number of remote workers.
    Remote worker ``i`` then has ``set_task(tasks[i])`` applied to each of
    its envs.

    Args:
        workers: Worker set exposing ``local_worker()`` and
            ``remote_workers()``.
        use_meta_env: If falsy, this function does nothing.
    """
    if not use_meta_env:
        return

    remote_workers = workers.remote_workers()
    n_tasks = len(remote_workers)
    tasks = workers.local_worker().foreach_env(lambda x: x)[
        0].sample_tasks(n_tasks)
    for i, worker in enumerate(remote_workers):
        # Bind this worker's task as a default argument so the callable does
        # not depend on the loop variable at call time (late-binding closure)
        # and carries only its own task instead of the whole task list.
        worker.foreach_env.remote(
            lambda env, task=tasks[i]: env.set_task(task))


class MetaUpdate:
    """Callable that performs one meta-update on the local worker.

    Instances are applied to ``(samples, adapt_metrics_dict)`` tuples. Each
    call trains the local worker on ``samples`` several times, pushes the
    new weights to the workers, re-assigns worker tasks, updates the
    policies' KL terms, and returns a metrics report.
    """

    def __init__(self, workers, maml_steps, metric_gen, use_meta_env):
        """Store the collaborators needed for the meta-update.

        Args:
            workers: Worker set holding the local and remote workers.
            maml_steps: Number of times to learn on each meta-batch.
            metric_gen: Callable invoked with ``None`` to produce the
                result dict that is reported.
            use_meta_env: Forwarded to ``set_worker_tasks``.
        """
        self.workers = workers
        self.maml_optimizer_steps = maml_steps
        self.metric_gen = metric_gen
        self.use_meta_env = use_meta_env

    def __call__(self, data_tuple):
        """Run the meta-update and return the merged metrics dict.

        Args:
            data_tuple: Indexable whose item 0 is the sample batch and whose
                item 1 is a dict of adaptation metrics.

        Returns:
            The dict produced by ``metric_gen``, updated with the
            adaptation metrics.
        """
        samples, adapt_metrics_dict = data_tuple[0], data_tuple[1]

        # Account for the freshly sampled steps before training.
        shared = _get_shared_metrics()
        shared.counters[STEPS_SAMPLED_COUNTER] += samples.count

        # Repeated learning passes over the same batch; only the stats of
        # the last pass are kept.
        for _ in range(self.maml_optimizer_steps):
            fetches = self.workers.local_worker().learn_on_batch(samples)
        learner_stats = get_learner_stats(fetches)

        # Broadcast the updated meta policy, then hand out new tasks.
        self.workers.sync_weights()
        set_worker_tasks(self.workers, self.use_meta_env)

        def refresh_kl(policy, policy_id):
            # Stats must be keyed by policy id, never flat.
            assert "inner_kl" not in learner_stats, (
                "inner_kl should be nested under policy id key",
                learner_stats)
            if policy_id not in learner_stats:
                logger.warning(
                    "No data for {}, not updating kl".format(policy_id))
                return
            policy_stats = learner_stats[policy_id]
            assert "inner_kl" in policy_stats, (learner_stats, policy_id)
            policy.update_kls(policy_stats["inner_kl"])

        self.workers.local_worker().foreach_trainable_policy(refresh_kl)

        # Publish learner info and the trained-steps counter.
        shared = _get_shared_metrics()
        shared.info[LEARNER_INFO] = learner_stats
        shared.counters[STEPS_TRAINED_COUNTER] += samples.count

        report = self.metric_gen(None)
        report.update(adapt_metrics_dict)
        return report


def post_process_metrics(adapt_iter, workers, metrics):
    """Record the current episode-reward statistics into ``metrics``.

    The max/mean/min episode rewards collected from the remote workers are
    stored under their plain names when ``adapt_iter`` is 0 (or negative),
    and under names suffixed with ``_adapt_<adapt_iter>`` otherwise.

    Args:
        adapt_iter: Index of the adaptation step the data belongs to.
        workers: Worker set; only its remote workers are queried.
        metrics: Dict to update in place.

    Returns:
        The same ``metrics`` dict that was passed in.
    """
    if adapt_iter > 0:
        suffix = "_adapt_" + str(adapt_iter)
    else:
        suffix = ""

    # Only the remote workers collect data.
    collected = collect_metrics(remote_workers=workers.remote_workers())

    for stat in ("episode_reward_max", "episode_reward_mean",
                 "episode_reward_min"):
        metrics[stat + suffix] = collected[stat]

    return metrics


def inner_adaptation(workers, samples):
    """Have every remote worker learn on its own batch.

    Remote worker ``k`` is sent ``samples[k]`` through a remote
    ``learn_on_batch`` call; the results of those calls are not collected.

    Args:
        workers: Worker set whose remote workers are adapted.
        samples: Indexable of batches, addressed by remote-worker position.
    """
    for position, remote_worker in enumerate(workers.remote_workers()):
        own_batch = samples[position]
        remote_worker.learn_on_batch.remote(own_batch)


def execution_plan(workers, config):
    """Build the MAML training dataflow.

    Remote-worker rollouts are grouped into meta-batches that span
    ``config["inner_adaptation_steps"] + 1`` sampling rounds (one round
    before adaptation plus one after each inner adaptation step). Each
    meta-batch is then passed to a ``MetaUpdate`` instance.

    Args:
        workers: Worker set providing the local and remote workers.
        config: Trainer config dict. Keys read here: ``use_meta_env``,
            ``metrics_smoothing_episodes``, ``collect_metrics_timeout``,
            ``inner_adaptation_steps`` and ``maml_optimizer_steps``.

    Returns:
        The iterator obtained by applying ``MetaUpdate`` to every
        meta-batch.
    """
    # Sync workers with meta policy
    workers.sync_weights()

    # Samples and sets worker tasks
    use_meta_env = config["use_meta_env"]
    set_worker_tasks(workers, use_meta_env)

    # Metric Collector
    metric_collect = CollectMetrics(
        workers,
        min_history=config["metrics_smoothing_episodes"],
        timeout_seconds=config["collect_metrics_timeout"])

    # Iterator for Inner Adaptation Data gathering (from pre->post adaptation)
    inner_steps = config["inner_adaptation_steps"]

    def inner_adaptation_steps(itr):
        """Group sampling rounds into meta-batches, adapting in between.

        Yields ``(batch, metrics)`` tuples once ``inner_steps + 1`` rounds
        have been buffered.
        """
        # NOTE(review): each item of `itr` is presumably a list holding one
        # batch per remote worker (it is indexed by worker position in
        # `inner_adaptation`) -- confirm against `batch_across_shards`.
        buf = []  # All batches of the current meta-batch, in arrival order.
        split = []  # Per round: list of the `count` of each batch.
        metrics = {}
        for samples in itr:

            # Processing Samples (Standardize Advantages)
            split_lst = []
            for sample in samples:
                sample["advantages"] = standardized(sample["advantages"])
                split_lst.append(sample.count)

            buf.extend(samples)
            split.append(split_lst)

            # Round 0 is pre-adaptation; round k follows k adaptation steps.
            adapt_iter = len(split) - 1
            metrics = post_process_metrics(adapt_iter, workers, metrics)
            if len(split) > inner_steps:
                # All rounds gathered: emit one concatenated batch and attach
                # the per-round batch sizes under the "split" key.
                out = SampleBatch.concat_samples(buf)
                out["split"] = np.array(split)
                buf = []
                split = []

                # Reporting Adaptation Rew Diff
                ep_rew_pre = metrics["episode_reward_mean"]
                ep_rew_post = metrics["episode_reward_mean_adapt_" +
                                      str(inner_steps)]
                metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre
                yield out, metrics
                # Start the next meta-batch with a fresh metrics dict.
                metrics = {}
            else:
                # More rounds needed: adapt each remote worker on its own
                # batch before the next round of samples is consumed.
                inner_adaptation(workers, samples)

    rollouts = from_actors(workers.remote_workers())
    rollouts = rollouts.batch_across_shards()
    rollouts = rollouts.transform(inner_adaptation_steps)

    # Metaupdate Step
    train_op = rollouts.for_each(
        MetaUpdate(workers, config["maml_optimizer_steps"], metric_collect,
                   use_meta_env))
    return train_op


def get_policy_class(config):
    """Pick the MAML policy class matching ``config["framework"]``.

    Returns:
        ``MAMLTorchPolicy`` when the framework is ``"torch"``; otherwise
        ``MAMLTFPolicy``.
    """
    use_torch = config["framework"] == "torch"
    return MAMLTorchPolicy if use_torch else MAMLTFPolicy


def validate_config(config):
    """Check the trainer config for settings MAML does not support.

    Args:
        config: Trainer config dict. Keys read: ``num_gpus``,
            ``inner_adaptation_steps``, ``maml_optimizer_steps``,
            ``entropy_coeff``, ``batch_mode``, ``num_workers`` and
            ``create_env_on_driver``.

    Raises:
        ValueError: On the first unsupported setting found, checked in the
            order listed above.
    """
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for MAML!")
    if config["inner_adaptation_steps"] <= 0:
        raise ValueError("Inner Adaptation Steps must be >=1!")
    if config["maml_optimizer_steps"] <= 0:
        # Zero is rejected by the check above, so the stated bound is >=1.
        raise ValueError("PPO steps for meta-update needs to be >=1!")
    if config["entropy_coeff"] < 0:
        raise ValueError("`entropy_coeff` must be >=0.0!")
    if config["batch_mode"] != "complete_episodes":
        raise ValueError("`batch_mode`=truncate_episodes not supported!")
    if config["num_workers"] <= 0:
        raise ValueError("Must have at least 1 worker/task!")
    # Only an explicit False is rejected here.
    if config["create_env_on_driver"] is False:
        raise ValueError("Must have an actual Env created on the driver "
                         "(local) worker! Set `create_env_on_driver` to True.")


# MAML trainer assembled by `build_trainer` from the pieces defined in this
# module: the default config, the policy selector (TF policy by default,
# torch policy when `framework` is "torch"), the execution plan, and the
# config validator.
MAMLTrainer = build_trainer(
    name="MAML",
    default_config=DEFAULT_CONFIG,
    default_policy=MAMLTFPolicy,
    get_policy_class=get_policy_class,
    execution_plan=execution_plan,
    validate_config=validate_config)
