from distutils.command.config import config
import logging
from typing import Type

from ray.rllib.agents import with_common_config
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    ConcatBatches,
    StandardizeFields,
    SelectExperiences,
)
from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.utils.spaces import space_utils
from ray.rllib.utils.deprecation import (
    Deprecated,
    deprecation_warning,
    DEPRECATED_VALUE,
)
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch

from ray.rllib.utils.typing import (
    AgentID,
    EnvCreator,
    EnvInfoDict,
    EnvType,
    EpisodeID,
    PartialTrainerConfigDict,
    PolicyID,
    PolicyState,
    ResultDict,
    SampleBatchType,
    TensorStructType,
    TensorType,
    TrainerConfigDict,
)

from typing import (
    Callable,
    Container,
    DefaultDict,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

logger = logging.getLogger(__name__)


def return_ppo_workflow(args, PPOTorchCustomPolicy):
    """Build a PPO trainer class with a concept-aware default config.

    Args:
        args: Run arguments. Only `args.include_concepts` is read; it is
            stored in the default config under the `"include_concepts"` key.
        PPOTorchCustomPolicy: Custom policy class. NOTE(review): this
            argument is never referenced in this function, so it has no
            effect on the returned trainer. Presumably the caller wires the
            policy in elsewhere (e.g. through the multiagent `policies`
            config); otherwise consider overriding `get_default_policy_class`
            to return it — confirm.

    Returns:
        A `PPOTrainer` subclass whose `get_default_config()` returns the
        config built below, and which adds `compute_single_action_`: a
        variant of `compute_single_action` that forwards the concept-update
        keyword arguments to the policy.
    """
    common_config = with_common_config(
        {
            # Should use a critic as a baseline (otherwise don't use value baseline;
            # required for using GAE).
            "use_critic": True,
            # If true, use the Generalized Advantage Estimator (GAE)
            # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
            "use_gae": True,
            # The GAE (lambda) parameter.
            "lambda": 1.0,
            # Initial coefficient for KL divergence.
            "kl_coeff": 0.2,
            # "rollout_fragment_length" (size of batches collected from each
            # worker) is intentionally not overridden here.
            # Number of timesteps collected for each SGD round. This defines the size
            # of each SGD epoch.
            "train_batch_size": 512,
            # Total SGD batch size across all devices for SGD. This defines the
            # minibatch size within each epoch.
            "sgd_minibatch_size": 10 * 32,
            # Whether to shuffle sequences in the batch when training (recommended).
            "shuffle_sequences": True,
            # Number of SGD iterations in each outer loop (i.e., number of epochs to
            # execute per train batch).
            "num_sgd_iter": 3,
            # Stepsize of SGD.
            "lr": 5e-5,
            # Learning rate schedule.
            "lr_schedule": None,
            # Coefficient of the value function loss. IMPORTANT: you must tune this if
            # you set vf_share_layers=True inside your model's config.
            "vf_loss_coeff": 1.0,
            # Per-concept settings (presumably one entry per concept); these are
            # not read in this file — confirm against the consuming policy/model.
            "concept_lengths": [5, 5, 5],
            "concept_num_agents": [5, 5, 5],
            "model": {
                # Share layers for value function. If you set this to True, it's
                # important to tune vf_loss_coeff.
                "vf_share_layers": False,
            },
            # Coefficient of the entropy regularizer.
            "entropy_coeff": 0.0,
            # Decay schedule for the entropy regularizer.
            "entropy_coeff_schedule": None,
            # PPO clip parameter.
            "clip_param": 0.3,
            # Clip param for the value function. Note that this is sensitive to the
            # scale of the rewards. If your expected V is large, increase this.
            "vf_clip_param": 10.0,
            # If specified, clip the global norm of gradients by this amount.
            "grad_clip": None,
            # Target value for KL divergence.
            "kl_target": 0.01,
            # Whether to rollout "complete_episodes" or "truncate_episodes".
            "batch_mode": "complete_episodes",
            # Which observation filter to apply to the observation.
            "observation_filter": "NoFilter",
            # NOTE(review): RLlib's built-in recurrent models read
            # `model["max_seq_len"]`; this top-level key only has an effect if
            # custom code reads it — confirm.
            "max_seq_len": 5,
            # Deprecated keys:
            # Share layers for value function. If you set this to True, it's important
            # to tune vf_loss_coeff.
            # Use config.model.vf_share_layers instead.
            "vf_share_layers": DEPRECATED_VALUE,
            # Concept settings (consumed outside this file, presumably by the
            # custom policy).
            "include_concepts": args.include_concepts,
            "use_balanced": False,
            # -1 presumably means "unset" — confirm against the consumer.
            "balanced_beta": -1,
            "balanced_gamma": -1,
            # config of concepts, no default
            "concept_configs": None,
            "concept_loss_coeff": 0.01,
            "loss_type": "focal",
        }
    )

    class CustomPPOTrainer(PPOTrainer):
        """`PPOTrainer` with concept-aware defaults and a concept-aware
        single-action API (`compute_single_action_`)."""

        @classmethod
        def get_default_config(cls) -> TrainerConfigDict:
            """Return the config built by the enclosing `return_ppo_workflow`.

            The same dict object is returned on every call (no copy is made
            here), so callers should not mutate it in place.
            """
            return common_config

        def compute_single_action_(
            self,
            observation: Optional[TensorStructType] = None,
            state: Optional[List[TensorStructType]] = None,
            *,
            prev_action: Optional[TensorStructType] = None,
            prev_reward: Optional[float] = None,
            info: Optional[EnvInfoDict] = None,
            input_dict: Optional[SampleBatch] = None,
            policy_id: PolicyID = DEFAULT_POLICY_ID,
            full_fetch: bool = False,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            episode: Optional[Episode] = None,
            unsquash_action: Optional[bool] = None,
            clip_action: Optional[bool] = None,
            # custom args
            concept_update=None,
            do_update=None,
            concepts_to_update=None,
            # Deprecated args.
            unsquash_actions=DEPRECATED_VALUE,
            clip_actions=DEPRECATED_VALUE,
            # Kwargs placeholder for future compatibility.
            **kwargs,
        ) -> Union[
            TensorStructType,
            Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]],
        ]:
            """Compute an action for one policy on the local worker.

            This mirrors `Trainer.compute_single_action`, except that it calls
            the policy's `compute_single_action_` method (which the policy
            class must provide) and forwards `concept_update`, `do_update` and
            `concepts_to_update` to it unchanged.

            Args:
                observation: Single (unbatched) observation from the
                    environment.
                state: List of all RNN hidden (single, unbatched) state tensors.
                prev_action: Single (unbatched) previous action value.
                prev_reward: Single (unbatched) previous reward value.
                info: Env info dict, if any.
                input_dict: An optional SampleBatch that holds all the values
                    for: obs, state, prev_action, and prev_reward, plus maybe
                    custom defined views of the current env trajectory. Provide
                    either `input_dict` or `observation`, not both. Note that
                    its OBS entry is overwritten in place with the
                    preprocessed/filtered observation.
                policy_id: Policy to query (only applies to multi-agent).
                    Default: "default_policy".
                full_fetch: Whether to return the 3-tuple
                    `(action, state_out, extra_fetches)` instead of just the
                    action. The tuple is also returned whenever the policy
                    produces a non-empty output state.
                explore: Whether to apply exploration to the action. Passed
                    through to the policy; None defers to the policy's
                    configured default.
                timestep: The current (sampling) time step.
                episode: This provides access to all of the internal episodes'
                    state, which may be useful for model-based or multi-agent
                    algorithms.
                unsquash_action: Should actions be unsquashed according to the
                    env's/Policy's action space? If None, use the value of
                    self.config["normalize_actions"].
                clip_action: Should actions be clipped according to the
                    env's/Policy's action space? If None, use the value of
                    self.config["clip_actions"]. Ignored when the action is
                    unsquashed.
                concept_update: Passed through unchanged to
                    `policy.compute_single_action_`; meaning defined there.
                do_update: Passed through unchanged to
                    `policy.compute_single_action_`; meaning defined there.
                concepts_to_update: Passed through unchanged to
                    `policy.compute_single_action_`; meaning defined there.
                unsquash_actions: Deprecated alias of `unsquash_action`. If
                    given, a deprecation warning is emitted and this value is
                    used instead of `unsquash_action`.
                clip_actions: Deprecated alias of `clip_action`. If given, a
                    deprecation warning is emitted and this value is used
                    instead of `clip_action`.
            Keyword Args:
                kwargs: forward compatibility placeholder (ignored).
            Returns:
                `(action, state_out, extra_fetches)` if `full_fetch` is True
                or the policy returned a non-empty state (e.g. an RNN-based
                policy); otherwise just the computed action.
            Raises:
                KeyError: If the `policy_id` cannot be found in this Trainer's
                    local worker.
                AssertionError: If `input_dict` is given together with any of
                    `observation`/`prev_action`/`prev_reward`/`state`, or if
                    neither `input_dict` nor `observation` is given.
            """
            # Honor the deprecated aliases declared in the signature instead of
            # silently ignoring them.
            if clip_actions != DEPRECATED_VALUE:
                deprecation_warning(
                    old="CustomPPOTrainer.compute_single_action_(`clip_actions`=...)",
                    new="CustomPPOTrainer.compute_single_action_(`clip_action`=...)",
                    error=False,
                )
                clip_action = clip_actions
            if unsquash_actions != DEPRECATED_VALUE:
                deprecation_warning(
                    old="CustomPPOTrainer.compute_single_action_(`unsquash_actions`=...)",
                    new="CustomPPOTrainer.compute_single_action_(`unsquash_action`=...)",
                    error=False,
                )
                unsquash_action = unsquash_actions

            # Resolve each flag independently from the config. These must be two
            # separate `if`s: chaining them with `elif` would leave `clip_action`
            # as None (i.e. never clip) whenever `unsquash_action` was None.
            if unsquash_action is None:
                unsquash_action = self.config["normalize_actions"]
            if clip_action is None:
                clip_action = self.config["clip_actions"]

            # User provided an input-dict: Assert that `obs`, `prev_a|r`, `state`
            # are all None.
            err_msg = (
                "Provide either `input_dict` OR [`observation`, ...] as "
                "args to Trainer.compute_single_action!"
            )
            if input_dict is not None:
                assert (
                    observation is None
                    and prev_action is None
                    and prev_reward is None
                    and state is None
                ), err_msg
                observation = input_dict[SampleBatch.OBS]
            else:
                assert observation is not None, err_msg

            # Get the policy to compute the action for (in the multi-agent case,
            # Trainer may hold >1 policies).
            policy = self.get_policy(policy_id)
            if policy is None:
                raise KeyError(
                    f"PolicyID '{policy_id}' not found in PolicyMap of the "
                    f"Trainer's local worker!"
                )
            local_worker = self.workers.local_worker()

            # Check the preprocessor and preprocess, if necessary.
            pp = local_worker.preprocessors[policy_id]
            if pp and type(pp).__name__ != "NoPreprocessor":
                observation = pp.transform(observation)
            # update=False: querying an action must not change the filter state.
            observation = local_worker.filters[policy_id](observation, update=False)

            # Input-dict.
            if input_dict is not None:
                input_dict[SampleBatch.OBS] = observation
                action, state, extra = policy.compute_single_action_(
                    input_dict=input_dict,
                    explore=explore,
                    timestep=timestep,
                    episode=episode,
                    concept_update=concept_update,
                    do_update=do_update,
                    concepts_to_update=concepts_to_update,
                )
            # Individual args.
            else:
                action, state, extra = policy.compute_single_action_(
                    obs=observation,
                    state=state,
                    prev_action=prev_action,
                    prev_reward=prev_reward,
                    info=info,
                    explore=explore,
                    timestep=timestep,
                    episode=episode,
                    concept_update=concept_update,
                    do_update=do_update,
                    concepts_to_update=concepts_to_update,
                )

            # If we work in normalized action space (normalize_actions=True),
            # we re-translate here into the env's action space.
            if unsquash_action:
                action = space_utils.unsquash_action(action, policy.action_space_struct)
            # Clip, according to env's action space.
            elif clip_action:
                action = space_utils.clip_action(action, policy.action_space_struct)

            # Return 3-Tuple: Action, states, and extra-action fetches.
            if state or full_fetch:
                return action, state, extra
            # Ensure backward compatibility.
            else:
                return action

    return CustomPPOTrainer

