"""
PyTorch policy class used for PPO.
"""
import gym
import logging
from typing import Dict, List, Type, Union
import functools
import numpy as np
from ray.rllib.utils.threading import with_lock
import ray
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.evaluation.postprocessing import (
    compute_gae_for_sample_batch,
    Postprocessing,
)
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import (
    apply_grad_clipping,
    explained_variance,
    sequence_mask,
)
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
from ray.rllib.utils.numpy import convert_to_numpy
from sklearn.metrics import recall_score, f1_score, accuracy_score, precision_score

from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.utils.annotations import override

from postprocessing import compute_concepts

from ray.rllib.utils.typing import (
    AgentID,
    ModelGradients,
    ModelWeights,
    PolicyID,
    PolicyState,
    T,
    TensorType,
    TensorStructType,
    TrainerConfigDict,
)
import tree
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from ray.rllib.utils.spaces.space_utils import (
    get_base_struct_from_space,
    get_dummy_batch_for_space,
    unbatch,
)

from utils import CB_loss


if TYPE_CHECKING:
    from ray.rllib.evaluation import Episode

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class PPOTorchCustomPolicy(PPOTorchPolicy):
    """Custom PPO policy for PyTorch."""

    def __init__(self, observation_space, action_space, config):
        """Create the policy by delegating to the parent constructor.

        No extra state is set up here; the three arguments are handed to the
        next ``__init__`` in the MRO unchanged.

        Args:
            observation_space: Observation space passed to the parent.
            action_space: Action space passed to the parent.
            config: Trainer config dict passed to the parent.
        """
        super(PPOTorchCustomPolicy, self).__init__(
            observation_space,
            action_space,
            config,
        )

    @override(PPOTorchPolicy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        """Post-process a sampled trajectory: GAE, then optional concept data.

        Args:
            sample_batch: The batch to post-process; forwarded to
                ``compute_gae_for_sample_batch``.
            other_agent_batches: Batches of the other agents; forwarded to
                ``compute_gae_for_sample_batch`` and ``compute_concepts``.
            episode: Episode object forwarded to
                ``compute_gae_for_sample_batch``.

        Returns:
            The batch returned by ``compute_gae_for_sample_batch``, passed
            through ``compute_concepts`` as well when
            ``config["env_config"]["name"] == "fort_attack"``.

        Raises:
            ValueError: If the policy config has no ``"include_concepts"`` key.
        """
        # Validate the config before doing any work. An explicit raise is used
        # instead of `assert 0` so the check survives `python -O`.
        if "include_concepts" not in self.config:
            raise ValueError("include_concepts not in config")

        # Do all post-processing always with no_grad().
        # Not using this here will introduce a memory leak
        # in torch (issue #6962).
        # TODO: no_grad still necessary?
        with torch.no_grad():
            batch = compute_gae_for_sample_batch(
                self, sample_batch, other_agent_batches, episode
            )

            # Concept post-processing is only applied to the fort_attack env.
            if self.config["env_config"]["name"] == "fort_attack":
                batch = compute_concepts(
                    self.config,
                    batch,
                    self.config["concept_configs"],
                    other_agent_batches,
                )
            return batch

    @override(PPOTorchPolicy)
    def extra_action_out(self, input_dict, state_batches, model, action_dist=None):
        """Collect the extra model outputs stored with each sampled action.

        Arguments:
            input_dict (dict): Dict of model input tensors (not read here).
            state_batches (list): List of state tensors (not read here).
            model (TorchModelV2): Reference to the model; it must expose
                `value_function()`, `concept_function()` and
                `return_concept()`.
            action_dist (Distribution): Torch Distribution object (not read
                here).

        Returns:
            Dict with three entries: `SampleBatch.VF_PREDS` (the result of
            `model.value_function()`), "concept_outs" (the result of
            `model.concept_function()`) and "concepts_after_softmax" (the
            result of `model.return_concept()`).
        """
        fetches = {}
        fetches[SampleBatch.VF_PREDS] = model.value_function()
        fetches["concept_outs"] = model.concept_function()
        fetches["concepts_after_softmax"] = model.return_concept()
        return fetches

    @override(PPOTorchPolicy)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[ActionDistribution],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Constructs the PPO loss, optionally augmented with a concept loss.

        The PPO part (clipped surrogate, clipped value loss, entropy bonus and
        optional KL penalty) is combined with a supervised loss on the slices
        of `model.concept_function()` described by
        `config["concept_configs"]` when `config["include_concepts"]` is set.
        Per-tower statistics are written to `model.tower_stats`.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distr. class.
            train_batch: The training data. Must contain "concept_targets"
                when concepts are enabled and at least one concept is
                configured.

        Returns:
            The total loss tensor given the input batch.
        """
        logits, state = model(train_batch)
        curr_action_dist = dist_class(logits, model)

        # RNN case: Mask away 0-padded chunks at end of time axis.
        if state:
            B = len(train_batch[SampleBatch.SEQ_LENS])
            max_seq_len = logits.shape[0] // B
            mask = sequence_mask(
                train_batch[SampleBatch.SEQ_LENS],
                max_seq_len,
                time_major=model.is_time_major(),
            )
            mask = torch.reshape(mask, [-1])
            num_valid = torch.sum(mask)

            def reduce_mean_valid(t):
                return torch.sum(t[mask]) / num_valid

        # non-RNN case: No masking.
        else:
            mask = None
            reduce_mean_valid = torch.mean

        prev_action_dist = dist_class(
            train_batch[SampleBatch.ACTION_DIST_INPUTS], model
        )

        logp_ratio = torch.exp(
            curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
            - train_batch[SampleBatch.ACTION_LOGP]
        )
        device = logp_ratio.device

        # Only calculate kl loss if necessary (kl-coeff > 0.0).
        if self.config["kl_coeff"] > 0.0:
            action_kl = prev_action_dist.kl(curr_action_dist)
            mean_kl_loss = reduce_mean_valid(action_kl)
        else:
            mean_kl_loss = torch.tensor(0.0, device=device)

        curr_entropy = curr_action_dist.entropy()
        mean_entropy = reduce_mean_valid(curr_entropy)

        clip_param = self.config["clip_param"]
        advantages = train_batch[Postprocessing.ADVANTAGES]
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param),
        )
        mean_policy_loss = reduce_mean_valid(-surrogate_loss)

        # Compute a value function loss.
        if self.config["use_critic"]:
            value_fn_out = model.value_function()
            vf_loss = torch.pow(
                value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0
            )
            vf_loss_clipped = torch.clamp(vf_loss, 0, self.config["vf_clip_param"])
            mean_vf_loss = reduce_mean_valid(vf_loss_clipped)
        # Ignore the value function.
        else:
            # Use a tensor rather than a Python float: `extra_grad_info`
            # runs `torch.stack` over the "mean_vf_loss" tower stat.
            vf_loss_clipped = mean_vf_loss = torch.tensor(0.0, device=device)

        # Concept loss: one term per configured concept. The per-concept
        # lists are index-aligned: own concepts first, then the ToM groups in
        # `tom_configs` iteration order.
        include_concepts = self.config["include_concepts"]
        concept_loss = 0.0
        concept_losses = []
        accuracies = []
        f1_scores = []
        recalls = []
        precisions = []

        if include_concepts:
            is_guard = self.config["model"]["custom_model_config"]["is_guard"]
            use_balanced = self.config["use_balanced"]
            concept_fn_out = model.concept_function()
            concept_configs = self.config["concept_configs"][
                "guard" if is_guard else "attacker"
            ]

            concept_groups = [
                {
                    "configs": concept_configs.configs,
                    "total_length": concept_configs.total_length,
                }
            ]
            concept_groups.extend(
                tom_group for _, tom_group in concept_configs.tom_configs.items()
            )

            # Column offset of the current group inside the concept output;
            # start_idx/end_idx of a config are relative to its group.
            offset = 0
            for group in concept_groups:
                for config in group["configs"]:
                    col_start = config.start_idx + offset
                    col_end = config.end_idx + offset
                    targets = train_batch["concept_targets"][:, col_start:col_end]
                    outputs = concept_fn_out[:, col_start:col_end]

                    if config.concept_type == "regression":
                        # Mean squared error per row, then mean over the
                        # valid (non-padded) rows.
                        per_row = torch.mean(
                            torch.pow(outputs - targets, 2.0), dim=-1
                        )
                        cur_concept_loss = reduce_mean_valid(per_row)
                        # Classification metrics do not apply; keep the
                        # lists index-aligned with placeholders.
                        acc_i = f1_i = recall_i = precision_i = 0
                    else:
                        # If stateful, mask out the padded part.
                        if mask is not None:
                            targets = targets[mask]
                            outputs = outputs[mask]
                        target_classes = torch.argmax(targets, dim=-1)

                        if use_balanced:
                            cur_concept_loss = CB_loss(
                                target_classes,
                                outputs,
                                torch.sum(targets, dim=0),
                                config.end_idx - config.start_idx,
                                self.config["loss_type"],
                                self.config["balanced_beta"],
                                self.config["balanced_gamma"],
                            )
                        else:
                            cross_entropy = nn.CrossEntropyLoss(reduction="none")
                            cur_concept_loss = torch.mean(
                                cross_entropy(outputs, target_classes)
                            )

                        # Metrics are computed with sklearn on CPU copies.
                        y_true = np.argmax(targets.detach().cpu().numpy(), axis=-1)
                        y_pred = np.argmax(outputs.detach().cpu().numpy(), axis=-1)
                        acc_i = torch.tensor(accuracy_score(y_true, y_pred))
                        f1_i = torch.tensor(
                            f1_score(
                                y_true, y_pred, average="weighted", zero_division=0
                            )
                        )
                        recall_i = torch.tensor(
                            recall_score(
                                y_true, y_pred, average="macro", zero_division=0
                            )
                        )
                        precision_i = torch.tensor(
                            precision_score(
                                y_true, y_pred, average="macro", zero_division=0
                            )
                        )

                    concept_losses.append(cur_concept_loss)
                    accuracies.append(acc_i)
                    f1_scores.append(f1_i)
                    recalls.append(recall_i)
                    precisions.append(precision_i)

                    # A NaN term is left out of the optimized loss (it is
                    # still recorded in the per-concept stats above).
                    if torch.any(torch.isnan(cur_concept_loss)):
                        logger.warning(
                            "Skipping NaN concept loss for %s: %s",
                            config.full_name,
                            cur_concept_loss,
                        )
                    else:
                        concept_loss = concept_loss + cur_concept_loss
                offset += group["total_length"]

        # The concept loss is already reduced, so it is added after the
        # mean over valid timesteps. This equals adding it inside the mean
        # (mean(x + c) == mean(x) + c), so one formula serves both the
        # balanced and the unbalanced setting.
        total_loss = reduce_mean_valid(
            -surrogate_loss
            + self.config["vf_loss_coeff"] * vf_loss_clipped
            - self.entropy_coeff * curr_entropy
        )
        if include_concepts:
            total_loss = total_loss + self.config["concept_loss_coeff"] * concept_loss

        # Add mean_kl_loss (already processed through `reduce_mean_valid`),
        # if necessary.
        if self.config["kl_coeff"] > 0.0:
            total_loss = total_loss + self.kl_coeff * mean_kl_loss

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_policy_loss"] = mean_policy_loss
        model.tower_stats["mean_vf_loss"] = mean_vf_loss
        model.tower_stats["vf_explained_var"] = explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS], model.value_function()
        )
        model.tower_stats["mean_entropy"] = mean_entropy
        model.tower_stats["mean_kl_loss"] = mean_kl_loss

        if include_concepts:
            metric_series = (
                ("acc", accuracies),
                ("f1", f1_scores),
                ("recall", recalls),
                ("precision", precisions),
            )
            # Running index into the per-concept lists built above.
            idx = 0
            for config in concept_configs.configs:
                model.tower_stats[f"{config.full_name}_loss"] = concept_losses[idx]
                if config.concept_type == "classification":
                    for suffix, series in metric_series:
                        model.tower_stats[f"{config.full_name}_{suffix}"] = series[
                            idx
                        ]
                # NOTE(review): this key duplicates "<full_name>_loss" under a
                # name ending in "_"; kept for compatibility - confirm whether
                # anything reads it.
                model.tower_stats[f"{config.full_name}_"] = concept_losses[idx]
                idx += 1

            for _, tom_config in concept_configs.tom_configs.items():
                for config in tom_config["configs"]:
                    model.tower_stats[
                        f"tom_1_{config.full_name}_loss"
                    ] = concept_losses[idx]
                    if config.concept_type == "classification":
                        # NOTE(review): unlike the loss key, the metric keys
                        # carry no "tom_1_" prefix, so they overwrite the
                        # non-ToM entries if a full_name is shared. Kept
                        # as-is because `extra_grad_info` reads these names.
                        for suffix, series in metric_series:
                            model.tower_stats[
                                f"{config.full_name}_{suffix}"
                            ] = series[idx]
                    idx += 1

            # Average over all concepts. The fallback only applies when no
            # concept is configured at all (a single concept used to be
            # reported as 0.0 here).
            model.tower_stats["mean_concept_loss"] = (
                torch.mean(torch.stack(concept_losses))
                if concept_losses
                else torch.tensor(0.0, device=device)
            )

        return total_loss

    @override(PPOTorchPolicy)
    def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Gather the training statistics reported after a gradient update.

        Every tower stat is stacked across towers and averaged. When
        `config["include_concepts"]` is set, the concept stats of the "guard"
        role are added if "good_policy" is listed in
        `config["multiagent"]["policies_to_train"]`, and those of the
        "attacker" role if "adversary_policy" is listed.

        Args:
            train_batch: The training batch (not read here).

        Returns:
            The stats dict, passed through `convert_to_numpy`.
        """

        def tower_mean(stat_name):
            # Mean of one tower stat across all towers.
            return torch.mean(torch.stack(self.get_tower_stats(stat_name)))

        classification_stats = ("acc", "f1", "recall", "precision")

        stats = {"cur_kl_coeff": self.kl_coeff, "cur_lr": self.cur_lr}
        stats["total_loss"] = tower_mean("total_loss")
        stats["policy_loss"] = tower_mean("mean_policy_loss")
        stats["vf_loss"] = tower_mean("mean_vf_loss")
        stats["vf_explained_var"] = tower_mean("vf_explained_var")
        stats["kl"] = tower_mean("mean_kl_loss")
        stats["entropy"] = tower_mean("mean_entropy")
        stats["entropy_coeff"] = self.entropy_coeff

        if self.config["include_concepts"]:
            stats["concept_loss"] = tower_mean("mean_concept_loss")
            trainable_policies = self.config["multiagent"]["policies_to_train"]
            concept_configs = self.config["concept_configs"]

            # NOTE(review): `loss` picks the role from
            # custom_model_config["is_guard"], while the role is picked here
            # from the trainable policy names - confirm both always agree.
            role_by_policy = (
                ("good_policy", "guard"),
                ("adversary_policy", "attacker"),
            )
            for policy_name, role in role_by_policy:
                if policy_name not in trainable_policies:
                    continue
                role_configs = concept_configs[role]

                # The role's own concepts.
                for config in role_configs.configs:
                    loss_key = f"{config.full_name}_loss"
                    stats[loss_key] = tower_mean(loss_key)
                    if config.concept_type == "classification":
                        for stat in classification_stats:
                            metric_key = f"{config.full_name}_{stat}"
                            stats[metric_key] = tower_mean(metric_key)

                # Concepts from `tom_configs`; only the loss key is prefixed.
                for _, tom_config in role_configs.tom_configs.items():
                    for config in tom_config["configs"]:
                        loss_key = f"tom_1_{config.full_name}_loss"
                        stats[loss_key] = tower_mean(loss_key)
                        if config.concept_type == "classification":
                            for stat in classification_stats:
                                metric_key = f"{config.full_name}_{stat}"
                                stats[metric_key] = tower_mean(metric_key)

        return convert_to_numpy(stats)

    def compute_single_action_(
        self,
        obs: Optional[TensorStructType] = None,
        state: Optional[List[TensorType]] = None,
        *,
        prev_action: Optional[TensorStructType] = None,
        prev_reward: Optional[TensorStructType] = None,
        info: Optional[dict] = None,
        input_dict: Optional[SampleBatch] = None,
        episode: Optional["Episode"] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        concept_update=None,
        do_update=None,
        concepts_to_update=None,
        # Kwargs placeholder for future compatibility.
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        """Computes and returns a single (B=1) action value.

        Takes an input dict (usually a SampleBatch) as its main data input.
        This allows for using this method in case a more complex input pattern
        (view requirements) is needed, for example when the Model requires the
        last n observations, the last m actions/rewards, or a combination
        of any of these.
        Alternatively, in case no complex inputs are required, takes a single
        `obs` values (and possibly single state values, prev-action/reward
        values, etc..).

        Args:
            obs: Single observation.
            state: List of RNN state inputs, if any.
            prev_action: Previous action value, if any.
            prev_reward: Previous reward, if any.
            info: Info object, if any.
            input_dict: A SampleBatch or input dict containing the
                single (unbatched) Tensors to compute actions. If given, it'll
                be used instead of `obs`, `state`, `prev_action|reward`, and
                `info`.
            episode: This provides access to all of the internal episode state,
                which may be useful for model-based or multi-agent algorithms.
            explore: Whether to pick an exploitation or
                exploration action
                (default: None -> use self.config["explore"]).
            timestep: The current (sampling) time step.
            concept_update: Forwarded unchanged to
                `compute_actions_from_input_dict_`.
            do_update: Forwarded unchanged to
                `compute_actions_from_input_dict_`.
            concepts_to_update: Forwarded unchanged to
                `compute_actions_from_input_dict_`.

        Keyword Args:
            kwargs: Forward compatibility placeholder.

        Returns:
            Tuple consisting of the action, the list of RNN state outputs (if
            any), and a dictionary of extra features (if any).
        """
        # Build the input-dict used for the call to
        # `self.compute_actions_from_input_dict_()`.
        if input_dict is None:
            input_dict = {SampleBatch.OBS: obs}
            if state is not None:
                for i, s in enumerate(state):
                    input_dict[f"state_in_{i}"] = s
            if prev_action is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action
            if prev_reward is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward
            if info is not None:
                input_dict[SampleBatch.INFOS] = info

        def add_batch_dim(path, value):
            # `tree.map_structure_with_path` passes the path as a tuple of
            # keys, so the top-level "seq_lens" entry is `("seq_lens",)`.
            # Comparing the path against the bare string never matched.
            if path == ("seq_lens",):
                return value
            if torch and isinstance(value, torch.Tensor):
                return value.unsqueeze(0)
            return np.expand_dims(value, 0)

        # Batch all data in input dict (everything except seq_lens).
        input_dict = tree.map_structure_with_path(add_batch_dim, input_dict)

        episodes = None
        if episode is not None:
            episodes = [episode]

        out = self.compute_actions_from_input_dict_(
            input_dict=SampleBatch(input_dict),
            episodes=episodes,
            explore=explore,
            timestep=timestep,
            concept_update=concept_update,
            do_update=do_update,
            concepts_to_update=concepts_to_update,
        )

        # Some policies don't return a tuple, but always just a single action.
        # E.g. ES and ARS.
        if not isinstance(out, tuple):
            single_action = out
            state_out = []
            extra_fetches = {}
        # Normal case: Policy should return (action, state, info) tuple.
        else:
            # Unpacked into `extra_fetches` so the `info` argument is not
            # shadowed by the result.
            batched_action, state_out, extra_fetches = out
            single_action = unbatch(batched_action)
        assert len(single_action) == 1
        single_action = single_action[0]

        # Return action, internal state(s), infos - each with the batch
        # dimension stripped again.
        return (
            single_action,
            [s[0] for s in state_out],
            {k: v[0] for k, v in extra_fetches.items()},
        )

    def compute_actions_from_input_dict_(
        self,
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        concept_update=None,
        do_update=None,
        concepts_to_update=None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Prepare a batched input dict and run the action forward pass.

        Args:
            input_dict: Batched model inputs; must contain "obs" and may
                contain entries whose key starts with "state_in".
            explore: Exploration flag handed to `_compute_action_helper_`.
            timestep: Timestep handed to `_compute_action_helper_`.
            concept_update: Stored in the input dict under "concept_infos"
                and also handed to `_compute_action_helper_`.
            do_update: Stored in the input dict under "do_update" and also
                handed to `_compute_action_helper_`.
            concepts_to_update: Handed to `_compute_action_helper_`.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever `_compute_action_helper_` returns.
        """
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            tensor_dict = self._lazy_tensor_dict(input_dict)
            tensor_dict.set_training(True)

            # Gather the internal state inputs into their own list.
            rnn_states = [
                tensor_dict[key]
                for key in tensor_dict.keys()
                if key.startswith("state_in")
            ]

            # With state inputs present, every row is a sequence of length 1.
            if rnn_states:
                obs = tensor_dict["obs"]
                seq_lens = torch.tensor(
                    [1] * len(obs),
                    dtype=torch.long,
                    device=obs.device,
                )
            else:
                seq_lens = None

            tensor_dict["concept_infos"] = concept_update
            tensor_dict["do_update"] = do_update

            return self._compute_action_helper_(
                tensor_dict,
                rnn_states,
                seq_lens,
                explore,
                timestep,
                concept_update,
                do_update,
                concepts_to_update,
            )

    @with_lock
    def _compute_action_helper_(
        self,
        input_dict,
        state_batches,
        seq_lens,
        explore,
        timestep,
        concept_update,
        do_update,
        concepts_to_update,
    ):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Args:
            input_dict: Batched model inputs; the sampled actions are written
                back into it under `SampleBatch.ACTIONS`.
            state_batches: List of RNN state inputs (may be empty).
            seq_lens: Sequence lengths passed to the model, or None.
            explore: Exploration flag; None means `config["explore"]`.
            timestep: Timestep; None means `self.global_timestep`.
            concept_update: Accepted for signature compatibility; not read
                in this method.
            do_update: Accepted for signature compatibility; not read in
                this method.
            concepts_to_update: Assigned to `self.model.concepts_to_update`
                before the forward pass.

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches,
            converted to numpy.

        Raises:
            NotImplementedError: If the policy has an `action_sampler_fn` or
                an `action_distribution_fn`.
            ValueError: If `self.dist_class` is neither a `functools.partial`
                nor a `TorchDistributionWrapper` subclass.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        self._is_recurrent = state_batches is not None and state_batches != []

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        # Custom sampler / distribution functions are not supported. Raise
        # explicitly: an `assert` is stripped under `python -O`, after which
        # the code below would run with this case silently unhandled.
        if self.action_sampler_fn or self.action_distribution_fn:
            raise NotImplementedError("Not implemented")

        dist_class = self.dist_class
        self.model.concepts_to_update = concepts_to_update
        dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

        if not (
            isinstance(dist_class, functools.partial)
            or issubclass(dist_class, TorchDistributionWrapper)
        ):
            raise ValueError(
                "`dist_class` ({}) not a TorchDistributionWrapper "
                "subclass! Make sure your `action_distribution_fn` or "
                "`make_model_and_action_dist` return a correct "
                "distribution class.".format(dist_class.__name__)
            )
        action_dist = dist_class(dist_inputs, self.model)

        # Get the exploration action from the forward results.
        actions, logp = self.exploration.get_exploration_action(
            action_distribution=action_dist, timestep=timestep, explore=explore
        )

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(
            input_dict, state_batches, self.model, action_dist
        )

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_numpy((actions, state_out, extra_fetches))


# class PPOTorchCustomPolicy(PPOTorchPolicy):
#     """Custom PPO policy for PyTorch."""
#     def __init__(self, observation_space, action_space, config):
#         super().__init__(observation_space, action_space, config)
#     @override(PPOTorchPolicy)
#     def postprocess_trajectory(self,
#                                sample_batch,
#                                other_agent_batches=None,
#                                episode=None):
#         # Do all post-processing always with no_grad().
#         # Not using this here will introduce a memory leak
#         # in torch (issue #6962).
#         # TODO: no_grad still necessary?
#         with torch.no_grad():
#             batch = compute_gae_for_sample_batch(self, sample_batch,
#                                                 other_agent_batches, episode)
#             batch = compute_concepts(batch,
#                             self.config["include_concepts"],
#                             ['relative_orientation','distance_between', 'absolute_position'])
#             return batch
