from copy import copy
from typing import Union

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils import degree

from greatx.nn.models.surrogate import Surrogate
from greatx.utils import remove_edges


class UniversalDefense(torch.nn.Module):
    r"""Base class for graph universal defense
    from the `"Graph Universal Adversarial Defense"
    <https://arxiv.org/abs/2204.09803>`_ paper (arXiv'22)

    """
    def __init__(self, device: str = "cpu"):
        super().__init__()
        self.device = torch.device(device)
        self._anchors = None

    def forward(self, data: Data, target_nodes: Union[int, Tensor],
                k: int = 50, symmetric: bool = True) -> Data:
        """Return the defended graph with defensive perturbation
        performed on.

        Parameters
        ----------
        data : a graph represented as PyG-like data instance
            the graph where the defensive perturbation performed on
        target_nodes : Union[int, Tensor]
            the target nodes where the defensive perturbation performed on
        k : int
            the number of anchor nodes in the defensive perturbation,
            by default 50
        symmetric : bool
            Determine whether the resulting graph is forcibly symmetric,
            by default True

        Returns
        -------
        Data: PyG-like data
            the defended graph with defensive perturbation
            performed on the target nodes
        """
        data = copy(data)
        data.edge_index = remove_edges(data.edge_index,
                                       self.removed_edges(target_nodes, k),
                                       symmetric=symmetric)
        return data

    def removed_edges(self, target_nodes: Union[int, Tensor],
                      k: int = 50) -> Tensor:
        """Return edges to remove with the defensive perturbation performed on
        on the target nodes

        Parameters
        ----------
        target_nodes : Union[int, Tensor]
            the target nodes where the defensive perturbation
            performed on
        k : int
            the number of anchor nodes in the defensive perturbation,
            by default 50

        Returns
        -------
        Tensor, shape [2, k]
            the edges to remove with the defensive perturbation performed on
            on the target nodes
        """
        row = torch.as_tensor(target_nodes, device=self.device).view(-1)
        col = self.anchors(k)
        row, col = row.repeat_interleave(k), col.repeat(row.size(0))

        return torch.stack([row, col], dim=0)

    def anchors(self, k: int = 50) -> Tensor:
        """Return the top-k anchor nodes

        Parameters
        ----------
        k : int, optional
            the number of anchor nodes in the defensive perturbation,
            by default 50

        Returns
        -------
        Tensor
            the top-k anchor nodes
        """
        assert k > 0
        return self._anchors[:k].to(self.device)

    def patch(self, k=50) -> Tensor:
        """Return the universal patch of the defensive perturbation

        Parameters
        ----------
        k : int, optional
            the number of anchor nodes in the defensive perturbation,
            by default 50

        Returns
        -------
        Tensor
            the 0-1 (boolean) universal patch where 1
            denotes the edges to be removed.
        """
        _patch = torch.zeros(self.num_nodes, dtype=torch.bool,
                             device=self.device)
        _patch[self.anchors(k=k)] = True
        return _patch


class GUARD(UniversalDefense, Surrogate):
    r"""Implementation of Graph Universal Adversarial
    Defense (GUARD) from the `"Graph Universal
    Adversarial Defense"
    <https://arxiv.org/abs/2204.09803>`_ paper (arXiv'22)

    Parameters
    ----------
    data : Data
        the PyG-like input data
    alpha : float, optional
        the scale factor for node degree, by default 2
    batch_size : int, optional
        the batch size for computing node influence, by default 512
    device : str, optional
        the device where the method running on, by default "cpu"

    Example
    -------
    .. code-block:: python

        surrogate = GCN(num_features, num_classes, bias=False, acts=None)
        surrogate_trainer = Trainer(surrogate, device=device)
        ckp = ModelCheckpoint('guard.pth', monitor='val_acc')
        trainer.fit(data, mask=(splits.train_nodes,
                    splits.val_nodes), callbacks=[ckp])
        trainer.evaluate(data, splits.test_nodes)

        guard = GUARD(data, device=device)
        guard.setup_surrogate(surrogate, data.y[splits.train_nodes])
        target_node = 1
        perturbed_data = ... # Other PyG-like Data
        guard(perturbed_data, target_node, k=50)
    """
    def __init__(self, data: Data, alpha: float = 2, batch_size: int = 512,
                 device: str = "cpu"):
        super().__init__(device=device)
        self.data = data
        self.alpha = alpha
        self.batch_size = batch_size
        self.influence_score = None
        self.deg = degree(data.edge_index[0], num_nodes=data.num_nodes,
                          dtype=torch.float)

    @torch.no_grad()
    def setup_surrogate(self, surrogate: torch.nn.Module,
                        victim_labels: Tensor) -> "GUARD":
        from greatx.nn.models.supervised import GCN, SGC

        Surrogate.setup_surrogate(self, surrogate=surrogate, freeze=True,
                                  required=(SGC, GCN))
        W = None
        for para in self.surrogate.parameters():
            if para.ndim == 1:
                continue
            if W is None:
                W = para.detach()
            else:
                W = para.detach() @ W

        W = self.data.x.to(self.device) @ W.t()
        d = self.deg.clamp(min=1).to(self.device)

        loader = DataLoader(victim_labels, pin_memory=False,
                            batch_size=self.batch_size, shuffle=False)

        w_max = W.max(1).values
        incluence = 0.
        for y in loader:
            incluence += W[:, y].sum(1)
        incluence = (w_max - incluence / victim_labels.size(0)) / \
            d.pow(self.alpha)  # node importance
        self._anchors = torch.argsort(incluence, descending=True)
        self.influence_score = incluence
        return self


class DegreeGUARD(UniversalDefense):
    r"""Implementation of Graph Universal Defense
    based on node degrees from the `"Graph Universal
    Adversarial Defense"
    <https://arxiv.org/abs/2204.09803>`_ paper (arXiv'22)

    Parameters
    ----------
    data : Data
        the PyG-like input data
    descending : bool, optional
        whether the degree of chosen nodes are in descending order,
        by default False
    device : str, optional
        the device where the method running on, by default "cpu"

    Example
    -------
    .. code-block:: python

        data = ... # PyG-like Data
        guard = DegreeGUARD(data))
        target_node = 1
        perturbed_data = ... # Other PyG-like Data
        guard(perturbed_data, target_node, k=50)
    """
    def __init__(self, data: Data, descending: bool = False,
                 device: str = "cpu"):
        super().__init__(device=device)
        deg = degree(data.edge_index[0], num_nodes=data.num_nodes,
                     dtype=torch.float)
        self._anchors = torch.argsort(deg, descending=descending)


class RandomGUARD(UniversalDefense):
    r"""Implementation of Graph Universal Defense
    based on random choices from the `"Graph Universal
    Adversarial Defense"
    <https://arxiv.org/abs/2204.09803>`_ paper (arXiv'22)

    Parameters
    ----------
    data : Data
        the PyG-like input data
    device : str, optional
        the device where the method running on, by default "cpu"

    Example
    -------
    .. code-block:: python

        data = ... # PyG-like Data
        guard = RandomGUARD(data)
        target_node = 1
        perturbed_data = ... # Other PyG-like Data
        guard(perturbed_data, target_node, k=50)
    """
    def __init__(self, data: Data, device: str = "cpu"):
        super().__init__(device=device)
        self._anchors = torch.randperm(data.num_nodes, device=self.device)
