import warnings
from typing import Any, Dict, Optional, Union

import torch
from torch import Tensor

from torch_geometric.explain.algorithm import ExplainerAlgorithm
from .explanation import (
    Explanation,
    GenerativeExplanation,
    HeteroExplanation,
)
from torch_geometric.explain.algorithm.utils import (
    clear_masks,
    set_hetero_masks,
    set_masks,
)
from explain.config import (
    ExplainerConfig,
    ExplanationType,
    MaskType,
    ModelConfig,
    ModelMode,
    ModelReturnType,
    ThresholdConfig,
)
from torch_geometric.typing import EdgeType, NodeType


class Explainer:
    r"""An explainer class for instance-level explanations of Graph Neural
    Networks.

    Args:
        model (torch.nn.Module): The model to explain.
        algorithm (ExplainerAlgorithm): The explanation algorithm.
        explanation_type (ExplanationType or str): The type of explanation to
            compute. The possible values are:

                - :obj:`"model"`: Explains the model prediction.

                - :obj:`"phenomenon"`: Explains the phenomenon that the model
                  is trying to predict.

            In practice, this means that the explanation algorithm will either
            compute their losses with respect to the model output
            (:obj:`"model"`) or the target output (:obj:`"phenomenon"`).
            For :attr:`ExplanationType.generative`, no target is inferred or
            required by the explainer itself; the algorithm is invoked with
            whatever :obj:`target` the caller passed.
        model_config (ModelConfig or dict): The model configuration.
            See :class:`~torch_geometric.explain.config.ModelConfig` for
            available options. This argument is required.
        node_mask_type (MaskType or str, optional): The type of mask to apply
            on nodes. The possible values are (default: :obj:`None`):

                - :obj:`None`: Will not apply any mask on nodes.

                - :obj:`"object"`: Will mask each node.

                - :obj:`"common_attributes"`: Will mask each feature.

                - :obj:`"attributes"`: Will mask each feature across all nodes.

        edge_mask_type (MaskType or str, optional): The type of mask to apply
            on edges. Has the same possible values as :obj:`node_mask_type`.
            (default: :obj:`None`)
        threshold_config (ThresholdConfig, optional): The threshold
            configuration.
            See :class:`~torch_geometric.explain.config.ThresholdConfig` for
            available options. (default: :obj:`None`)
    """
    def __init__(
        self,
        model: torch.nn.Module,
        algorithm: ExplainerAlgorithm,
        explanation_type: Union[ExplanationType, str],
        model_config: Union[ModelConfig, Dict[str, Any]],
        node_mask_type: Optional[Union[MaskType, str]] = None,
        edge_mask_type: Optional[Union[MaskType, str]] = None,
        threshold_config: Optional[ThresholdConfig] = None,
    ):
        # `ExplainerConfig` normalizes the explanation and mask types; the
        # normalized values are what gets stored on the explainer.
        explainer_config = ExplainerConfig(
            explanation_type=explanation_type,
            node_mask_type=node_mask_type,
            edge_mask_type=edge_mask_type,
        )

        self.model = model
        self.algorithm = algorithm

        self.explanation_type = explainer_config.explanation_type
        self.model_config = ModelConfig.cast(model_config)
        self.node_mask_type = explainer_config.node_mask_type
        self.edge_mask_type = explainer_config.edge_mask_type
        self.threshold_config = ThresholdConfig.cast(threshold_config)

        # Hand both configurations to the algorithm before it is first used.
        self.algorithm.connect(explainer_config, self.model_config)

    @torch.no_grad()
    def get_prediction(self, *args, **kwargs) -> Tensor:
        r"""Returns the output of the model on the input graph.

        The model is temporarily switched to evaluation mode and called with
        gradient tracking disabled. Its previous training mode is restored
        afterwards, even if the forward pass raises.

        The model output is returned unchanged. Use :meth:`get_target` to
        convert it into class labels for classification tasks.

        Args:
            *args: Arguments passed to the model.
            **kwargs (optional): Additional keyword arguments passed to the
                model.
        """
        training = self.model.training
        self.model.eval()

        try:
            # Gradients are already disabled by the `torch.no_grad` decorator.
            return self.model(*args, **kwargs)
        finally:
            self.model.train(training)

    def get_masked_prediction(
        self,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        node_mask: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,
        edge_mask: Optional[Union[Tensor, Dict[EdgeType, Tensor]]] = None,
        **kwargs,
    ) -> Tensor:
        r"""Returns the prediction of the model on the input graph with node
        and edge masks applied.

        Args:
            x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input
                node features. If :obj:`node_mask` is given, they are
                multiplied element-wise with it (per key for dictionaries).
            edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]):
                The input edge indices.
            node_mask (Union[torch.Tensor, Dict[NodeType, torch.Tensor]],
                optional): The mask applied to :obj:`x`.
                (default: :obj:`None`)
            edge_mask (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]],
                optional): The mask injected into the model (without applying
                a sigmoid) for the duration of the forward pass.
                (default: :obj:`None`)
            **kwargs (optional): Additional keyword arguments passed to the
                model.
        """
        if node_mask is not None:
            if isinstance(x, Tensor):
                x = node_mask * x
            elif isinstance(x, dict):
                x = {key: value * node_mask[key] for key, value in x.items()}

        try:
            if isinstance(edge_mask, Tensor):
                set_masks(self.model, edge_mask, edge_index,
                          apply_sigmoid=False)
            elif isinstance(edge_mask, dict):
                set_hetero_masks(self.model, edge_mask, edge_index,
                                 apply_sigmoid=False)

            return self.get_prediction(x, edge_index, **kwargs)
        finally:
            # Always remove injected masks so that a failing forward pass
            # does not leave the model in a masked state.
            clear_masks(self.model)

    def __call__(
        self,
        x: Optional[Union[Tensor, Dict[NodeType, Tensor]]],
        edge_index: Optional[Union[Tensor, Dict[EdgeType, Tensor]]],
        *,
        target: Optional[Tensor] = None,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Union[Explanation, HeteroExplanation, GenerativeExplanation]:
        r"""Computes the explanation of the GNN for the given inputs and
        target.

        .. note::

            If you get an error message like "Trying to backward through the
            graph a second time", make sure that the target you provided
            was computed with :meth:`torch.no_grad`.

        Args:
            x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input
                node features of a homogeneous or heterogeneous graph.
            edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]): The
                input edge indices of a homogeneous or heterogeneous graph.
            target (torch.Tensor): The target of the model.
                If the explanation type is :obj:`"phenomenon"`, the target has
                to be provided.
                If the explanation type is :obj:`"model"`, the target should be
                set to :obj:`None` and will get automatically inferred (a
                provided target triggers a warning and is overwritten). For
                classification tasks, the target needs to contain the class
                labels. (default: :obj:`None`)
            index (Union[int, Tensor], optional): The indices in the
                first-dimension of the model output to explain.
                Can be a single index or a tensor of indices.
                If set to :obj:`None`, all model outputs will be explained.
                (default: :obj:`None`)
            **kwargs: additional arguments to pass to the GNN.

        Returns:
            A :class:`GenerativeExplanation` is returned exactly as produced
            by the algorithm. Any other explanation is enriched with the
            model inputs, target, prediction and index, validated, and
            thresholded according to the threshold configuration.

        Raises:
            ValueError: If the explanation type is :obj:`"phenomenon"` and no
                :obj:`target` is given.
        """
        # Choose the `target` depending on the explanation type:
        prediction: Optional[Tensor] = None
        if self.explanation_type == ExplanationType.phenomenon:
            if target is None:
                raise ValueError(
                    f"The 'target' has to be provided for the explanation "
                    f"type '{self.explanation_type.value}'")
        elif self.explanation_type == ExplanationType.model:
            if target is not None:
                warnings.warn(
                    f"The 'target' should not be provided for the explanation "
                    f"type '{self.explanation_type.value}'", stacklevel=2)
            prediction = self.get_prediction(x, edge_index, **kwargs)
            target = self.get_target(prediction)
        # `ExplanationType.generative` neither requires nor infers a target.

        if isinstance(index, int):
            index = torch.tensor([index])

        training = self.model.training
        self.model.eval()

        try:
            explanation = self.algorithm(
                self.model,
                x,
                edge_index,
                target=target,
                index=index,
                **kwargs,
            )
        finally:
            # Restore the training mode even if the algorithm raises.
            self.model.train(training)

        if isinstance(explanation, GenerativeExplanation):
            return explanation

        # Add explainer objectives to the `Explanation` object:
        explanation._model_config = self.model_config
        explanation.prediction = prediction
        explanation.target = target
        explanation.index = index

        # Add model inputs to the `Explanation` object:
        if isinstance(explanation, Explanation):
            explanation._model_args = list(kwargs.keys())
            explanation.x = x
            explanation.edge_index = edge_index

            for key, arg in kwargs.items():  # Add remaining `kwargs`:
                explanation[key] = arg

        elif isinstance(explanation, HeteroExplanation):
            # TODO Add `explanation._model_args`

            assert isinstance(x, dict)
            explanation.set_value_dict('x', x)

            assert isinstance(edge_index, dict)
            explanation.set_value_dict('edge_index', edge_index)

            for key, arg in kwargs.items():  # Add remaining `kwargs`:
                if isinstance(arg, dict):
                    # Keyword arguments are likely named `{attr_name}_dict`
                    # while we only want to assign the `{attr_name}` to the
                    # `HeteroExplanation` object:
                    key = key[:-5] if key.endswith('_dict') else key
                    explanation.set_value_dict(key, arg)
                else:
                    explanation[key] = arg

        explanation.validate_masks()
        return explanation.threshold(self.threshold_config)

    def get_target(self, prediction: Tensor) -> Tensor:
        r"""Returns the target of the model from a given prediction.

        If the model mode is of type :obj:`"binary_classification"`, the
        flattened prediction is thresholded (at :obj:`0` for raw outputs and
        at :obj:`0.5` for probabilities) into class labels.
        If the model mode is of type :obj:`"multiclass_classification"`, the
        index of the largest entry along the last dimension is returned.
        Otherwise, the prediction is returned as it is.

        Raises:
            AssertionError: If the model mode is
                :obj:`"binary_classification"` and the return type is neither
                raw nor probabilities.
        """
        if self.model_config.mode == ModelMode.binary_classification:
            # TODO: Allow customization of the thresholds used below.
            if self.model_config.return_type == ModelReturnType.raw:
                return (prediction > 0).long().view(-1)
            if self.model_config.return_type == ModelReturnType.probs:
                return (prediction > 0.5).long().view(-1)
            # Raised explicitly (rather than via `assert False`) so that the
            # check is not stripped when Python runs with `-O`.
            raise AssertionError(
                f"Unsupported return type "
                f"'{self.model_config.return_type}' for binary "
                f"classification")

        if self.model_config.mode == ModelMode.multiclass_classification:
            return prediction.argmax(dim=-1)

        return prediction


from abc import abstractmethod
from typing import Dict, Optional, Union

import torch
from torch import Tensor
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from .explanation import GenerativeExplanation, ExplanationSetSampler


class XGNNExplainer(ExplainerAlgorithm):
    r"""The XGNN-Explainer model from the `"XGNN: Towards Model-Level
    Explanations of Graph Neural Networks" <https://arxiv.org/abs/2006.02587>`_
    paper for training a graph generator so that the generated graph patterns
    maximize a certain prediction of the model.

    XGNN-Explainer interprets GNN models by training a graph generator, which
    creates graph patterns that maximize the predictive outcome for a specific
    class. This approach provides a model-level explanation, offering insights
    into what input patterns lead to certain predictions by GNNs.

    .. note::

        For an example of using :class:`XGNNExplainer`, see
        `examples/explain/xgnn_explainer.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/explain/xgnn_explainer.py>`_.

    The explainer trains a generative model, iterating over steps to
    progressively build a graph that maximizes the class score of the target
    class. The generative model can be customized and is a key component of the
    explanation process.

    Args:
        epochs (int, optional): The number of epochs for training the
            generative model. (default: :obj:`100`)
        lr (float, optional): The learning rate for training the
            generative model. (default: :obj:`0.01`)
        **kwargs (optional): Additional hyper-parameters to configure the
            training process of the generative model. They are stored
            unchanged in :obj:`self.config`.
    """
    def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs):
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.config = kwargs

    def forward(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[str, Tensor]],
        edge_index: Union[Tensor, Dict[str, Tensor]],
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> GenerativeExplanation:
        r"""Computes the generative explanation for a single class.

        Args:
            model (torch.nn.Module): The model to explain.
            x (Union[Tensor, Dict[str, Tensor]]): The input node features.
                Only checked for being homogeneous; not otherwise used here.
            edge_index (Union[Tensor, Dict[str, Tensor]]): The edge indices.
                Only checked for being homogeneous; not otherwise used here.
            target (Tensor): The target tensor for the explanation. It is
                accepted for interface compatibility and not used by this
                method.
            index (Optional[Union[int, Tensor]], optional): Present only for
                the sake of integration with other classes. Must be
                :obj:`None`. (default: :obj:`None`)
            **kwargs: Must contain :obj:`for_class`, the class for which the
                explanation is generated. All remaining keyword arguments are
                forwarded to :meth:`train_generative_model`.

        Returns:
            GenerativeExplanation: An explanation wrapping the model and the
            result of :meth:`train_generative_model`.

        Raises:
            ValueError: If :obj:`for_class` is missing or :obj:`None`, if
                :obj:`x` or :obj:`edge_index` is a dictionary (heterogeneous
                graph), or if :obj:`index` is not :obj:`None`.
        """
        # Validate 'for_class' argument
        for_class = kwargs.pop('for_class', None)
        if for_class is None:
            raise ValueError("The 'for_class' argument must be provided")

        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        if index is not None:
            raise ValueError(
                f"Index not supported in '{self.__class__.__name__}'")

        generative_model_t = self.train_generative_model(
            model, for_class=for_class, **kwargs)

        return GenerativeExplanation(model=model,
                                     explanation_set=generative_model_t)

    @abstractmethod
    def train_generative_model(self, model, for_class,
                               **kwargs) -> ExplanationSetSampler:
        r"""Abstract method to train the generative model.
        Must be implemented in subclasses.

        Args:
            model: The model to explain.
            for_class: The class for which the explanation is generated.
            **kwargs: Additional keyword arguments.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        # A single message string: passing two comma-separated strings would
        # make the error render as a tuple.
        raise NotImplementedError(
            "The method train_generative_model must be implemented "
            "in subclasses")

    def supports(self) -> bool:
        r"""Returns :obj:`True` unconditionally; this explainer performs no
        configuration checks here.
        """
        return True