"""
Proximal Policy Optimization (PPO)
==================================

This file defines the distributed Trainer class for proximal policy optimization.
See `policy.py` for the definition of the policy loss.

Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#ppo
"""

import logging
from typing import Type
import os
import pickle

from agents.ppo.policy import PPOTorchPolicy
from agents.ppo.ppo_tf_policy import PPOTFPolicy
from agents.ppo.config import PPO_DEFAULT_CONFIG
from trainer.trainer import Trainer
from worker.worker_set import WorkerSet
from trainer.rollout_ops import (
    ParallelRollouts,
    ConcatBatches,
    StandardizeFields,
    SelectExperiences,
)
from trainer.train_ops import TrainOneStep, MultiGPUTrainOneStep
from trainer.metric_ops import StandardMetricsReporting
from policy.policy import Policy
from policy.sample_batch import DEFAULT_POLICY_ID
from utils.annotations import override
from utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

from utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


class UpdateKL:
    """Callable that feeds the latest KL statistics back into the policies.

    It is meant to be applied to the learner fetches inside the
    execution_plan function. Every trainable policy must expose an
    `update_kl` method; for PPO this method is added to the Policy via a
    mixin class.
    """

    def __init__(self, workers):
        # Worker set whose local worker holds the trainable policies.
        self.workers = workers

    def __call__(self, fetches):
        def apply_kl(policy, policy_id):
            # Stats are expected per policy id, never at the top level of
            # `fetches`.
            assert LEARNER_STATS_KEY not in fetches, (
                "{} should be nested under policy id key".format(LEARNER_STATS_KEY),
                fetches,
            )

            # Nothing was learned for this policy in this round: only warn.
            if policy_id not in fetches:
                logger.warning("No data for {}, not updating kl".format(policy_id))
                return

            policy_stats = fetches[policy_id][LEARNER_STATS_KEY]
            kl = policy_stats.get("kl")
            assert kl is not None, (fetches, policy_id)
            # Hand the measured KL over to the policy itself.
            policy.update_kl(kl)

        # Only the policies of the local (trainer) worker are updated here.
        local_worker = self.workers.local_worker()
        local_worker.foreach_trainable_policy(apply_kl)


def warn_about_bad_reward_scales(config, result):
    """Logs warnings for suspicious value-loss or reward magnitudes.

    Args:
        config: Trainer config dict; `vf_loss_coeff`, `vf_clip_param` and
            (optionally) `model.vf_share_layers` are read from it.
        result: Training result dict of the current iteration.

    Returns:
        The unmodified `result` object, so this function can be chained
        behind a metrics-reporting step.
    """
    # Results carrying per-policy reward means (multi-agent) are not checked.
    if result["policy_reward_mean"]:
        return result

    # Check 1: value function loss that is very large after scaling.
    info_by_policy = result["info"][LEARNER_INFO]
    if DEFAULT_POLICY_ID in info_by_policy:
        vf_coeff = config["vf_loss_coeff"]
        stats = info_by_policy[DEFAULT_POLICY_ID][LEARNER_STATS_KEY]
        scaled_vf_loss = vf_coeff * stats["vf_loss"]
        policy_loss = stats["policy_loss"]

        shares_layers = config.get("model", {}).get("vf_share_layers")
        if shares_layers and scaled_vf_loss > 100:
            vf_loss_msg = (
                "The magnitude of your value function loss is extremely large "
                "({}) compared to the policy loss ({}). This can prevent the "
                "policy from learning. Consider scaling down the VF loss by "
                "reducing vf_loss_coeff, or disabling vf_share_layers."
            )
            logger.warning(vf_loss_msg.format(scaled_vf_loss, policy_loss))

    # Check 2: mean episode reward far larger than the value clip range.
    # A non-positive clip parameter is treated as an infinite ratio.
    rew_scale = (
        float("inf")
        if config["vf_clip_param"] <= 0
        else round(abs(result["episode_reward_mean"]) / config["vf_clip_param"], 0)
    )
    if rew_scale > 200:
        clip_msg = (
            "The magnitude of your environment rewards are more than "
            "{}x the scale of `vf_clip_param`. "
            "This means that it will take more than "
            "{} iterations for your value "
            "function to converge. If this is not intended, consider "
            "increasing `vf_clip_param`."
        )
        logger.warning(clip_msg.format(rew_scale, rew_scale))

    return result


class PPOTrainer(Trainer):
    """Trainer implementation for Proximal Policy Optimization (PPO).

    Provides the PPO default config, PPO-specific config validation, the
    framework-dependent default policy class, and the execution plan that
    chains sampling, batching, training, KL updates and metrics reporting.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Returns the PPO default config dict (`PPO_DEFAULT_CONFIG`)."""
        return PPO_DEFAULT_CONFIG

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Validates the Trainer's config dict.

        Note that `config` may be modified in place: an integer
        `entropy_coeff` is converted to float and `rollout_fragment_length`
        is auto-adjusted if it does not fit `train_batch_size`.

        Args:
            config (TrainerConfigDict): The Trainer's config to check.

        Raises:
            ValueError: In case something is wrong with the config.
        """
        # Call super's validation method.
        super().validate_config(config)

        if isinstance(config["entropy_coeff"], int):
            config["entropy_coeff"] = float(config["entropy_coeff"])

        # An invalid value is a config error, not a deprecation notice:
        # raise ValueError like every other check in this method.
        if config["entropy_coeff"] < 0.0:
            raise ValueError("entropy_coeff must be >= 0.0")

        # SGD minibatch size must be smaller than train_batch_size (b/c
        # we subsample a batch of `sgd_minibatch_size` from the train-batch for
        # each `sgd_num_iter`).
        # Note: Only check this if `train_batch_size` > 0 (DDPPO sets this
        # to -1 to auto-calculate the actual batch size later).
        if (
            config["train_batch_size"] > 0
            and config["sgd_minibatch_size"] > config["train_batch_size"]
        ):
            raise ValueError(
                "`sgd_minibatch_size` ({}) must be <= "
                "`train_batch_size` ({}).".format(
                    config["sgd_minibatch_size"], config["train_batch_size"]
                )
            )

        # Check for mismatches between `train_batch_size` and
        # `rollout_fragment_length` and auto-adjust `rollout_fragment_length`
        # if necessary.
        # Note: Only check this if `train_batch_size` > 0 (DDPPO sets this
        # to -1 to auto-calculate the actual batch size later).
        # `num_workers == 0` is counted as one worker for this calculation.
        num_workers = config["num_workers"] or 1
        calculated_min_rollout_size = (
            num_workers
            * config["num_envs_per_worker"]
            * config["rollout_fragment_length"]
        )
        if (
            config["train_batch_size"] > 0
            and config["train_batch_size"] % calculated_min_rollout_size != 0
        ):
            new_rollout_fragment_length = config["train_batch_size"] // (
                num_workers * config["num_envs_per_worker"]
            )
            logger.warning(
                "`train_batch_size` ({}) cannot be achieved with your other "
                "settings (num_workers={} num_envs_per_worker={} "
                "rollout_fragment_length={})! Auto-adjusting "
                "`rollout_fragment_length` to {}.".format(
                    config["train_batch_size"],
                    config["num_workers"],
                    config["num_envs_per_worker"],
                    config["rollout_fragment_length"],
                    new_rollout_fragment_length,
                )
            )
            config["rollout_fragment_length"] = new_rollout_fragment_length

        # Episodes may only be truncated (and passed into PPO's
        # `postprocessing_fn`), iff generalized advantage estimation is used
        # (value function estimate at end of truncated episode to estimate
        # remaining value).
        if config["batch_mode"] == "truncate_episodes" and not config["use_gae"]:
            raise ValueError(
                "Episode truncation is not supported without a value "
                "function (to estimate the return at the end of the truncated"
                " trajectory). Consider setting "
                "batch_mode=complete_episodes."
            )

        # Multi-agent mode and multi-GPU optimizer.
        if config["multiagent"]["policies"] and not config["simple_optimizer"]:
            logger.info(
                "In multi-agent mode, policies will be optimized sequentially"
                " by the multi-GPU optimizer. Consider setting "
                "simple_optimizer=True if this doesn't work for you."
            )

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        """Returns the policy class matching `config["framework"]`.

        Args:
            config (TrainerConfigDict): The Trainer's config dict.

        Returns:
            `PPOTorchPolicy` if the framework is "torch", `PPOTFPolicy` for
            any other framework setting.
        """
        if config["framework"] == "torch":
            return PPOTorchPolicy
        else:
            return PPOTFPolicy

    @staticmethod
    @override(Trainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        """Builds the PPO training pipeline as an iterator of result dicts.

        Args:
            workers (WorkerSet): The worker set used for sampling and training.
            config (TrainerConfigDict): The Trainer's config dict.
            **kwargs: Must be empty; PPO takes no additional parameters.

        Returns:
            LocalIterator[dict]: The metrics-reporting iterator, with
            `warn_about_bad_reward_scales` applied to every result.
        """
        assert (
            len(kwargs) == 0
        ), "PPO execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # Collect batches for the trainable policies.
        rollouts = rollouts.for_each(SelectExperiences(workers.trainable_policies()))
        # Concatenate the SampleBatches into one.
        rollouts = rollouts.combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )
        )
        # Standardize advantages.
        rollouts = rollouts.for_each(StandardizeFields(["advantages"]))

        # Perform one training step on the combined + standardized batch.
        if config["simple_optimizer"]:
            train_op = rollouts.for_each(
                TrainOneStep(
                    workers,
                    num_sgd_iter=config["num_sgd_iter"],
                    sgd_minibatch_size=config["sgd_minibatch_size"],
                )
            )
        else:
            train_op = rollouts.for_each(
                MultiGPUTrainOneStep(
                    workers=workers,
                    sgd_minibatch_size=config["sgd_minibatch_size"],
                    num_sgd_iter=config["num_sgd_iter"],
                    num_gpus=config["num_gpus"],
                    _fake_gpus=config["_fake_gpus"],
                )
            )

        # Update KL after each round of training. Only the second element of
        # each train-op output is passed on (presumably the learner fetches
        # keyed by policy id -- confirm against the train ops).
        train_op = train_op.for_each(lambda t: t[1]).for_each(UpdateKL(workers))

        # Warn about bad reward scales and return training metrics.
        return StandardMetricsReporting(train_op, workers, config).for_each(
            lambda result: warn_about_bad_reward_scales(config, result)
        )
    #
    # @override(Trainer)
    # def restore(self, checkpoint_path):
    #     """Restores training state from a given model checkpoint.
    #
    #     These checkpoints are returned from calls to save().
    #
    #     Subclasses should override ``_restore()`` instead to restore state.
    #     This method restores additional metadata saved with the checkpoint.
    #     """
    #     # Maybe sync from cloud
    #     if self.uses_cloud_checkpointing:
    #         self.storage_client.sync_down(self.remote_checkpoint_dir,
    #                                       self.logdir)
    #         self.storage_client.wait()
    #
    #     self.load_checkpoint(checkpoint_path)
    #     self._time_since_restore = 0.0
    #     self._timesteps_since_restore = 0
    #     self._iterations_since_restore = 0
    #     self._restored = True
    #     logger.info("Restored on %s from checkpoint: %s",
    #                 self.get_current_ip(), checkpoint_path)
    #     state = {
    #         "_iteration": 0,
    #         "_timesteps_total": 0,
    #         "_time_total": 0,
    #         "_episodes_total": 0,
    #     }
    #     logger.info("Current state after restoring: %s", state)
    #
    #
    # @override(Trainer)
    # def load_checkpoint(self, checkpoint):
    #     policy_id = list(self.workers.local_worker().policy_map.keys())[0]
    #     load_dict = self.get_policy(policy_id).import_partial_model(checkpoint)
    #     for k, v in load_dict.items():
    #         print("### load var_name: {}, var: {}".format(k,v))
    #     self.saver = tf.train.Saver(load_dict, max_to_keep=0)
    #     # checkpoint = tf.train.latest_checkpoint(checkpoint)
    #     self.saver.restore(self.get_policy(policy_id).get_session(), os.path.join(checkpoint,'model'))
    #     print("restore rl model successfully")
    #     # Sync new weights to remote workers.
    #     self._sync_weights_to_workers(worker_set=self.workers)
    #
    # @override(Trainer)
    # def save_checkpoint(self, checkpoint_dir: str) -> str:
    #     checkpoint_path = os.path.join(
    #         checkpoint_dir, "checkpoint-{}".format(self.iteration)
    #     )
    #     pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))
    #     policy_id = list(self.workers.local_worker().policy_map.keys())[0]
    #     # self.get_policy(policy_id).export_model(checkpoint_dir)
    #     self.get_policy(policy_id).export_checkpoint(checkpoint_dir)
    #
    #     return checkpoint_path