import warnings
from copy import copy
from functools import lru_cache
from typing import Optional

import numpy as np
import torch
from torch import Tensor
from torch_geometric.data import Data

from greatx.attack.attacker import Attacker
from greatx.utils import BunchDict, add_edges, remove_edges


class FlipAttacker(Attacker):
    """Adversarial attacker for graph data by flipping edges.

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device on which the attack runs, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be
        :obj:`__class__.__name__`, by default None
    kwargs : additional arguments of :class:`greatx.attack.Attacker`

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Note
    ----
    :class:`greatx.attack.FlipAttacker` is a base class
    for graph modification attacks (GMA).
    """
    def reset(self) -> "FlipAttacker":
        """Restore the attacker to a clean state.

        This method must be called before launching an attack. It
        delegates to the parent ``reset``, invalidates the cached output
        of :meth:`data`, discards every recorded edge/feature flip and
        re-initializes the working degree from a clone of
        :obj:`self._degree`.

        Returns
        -------
        FlipAttacker
            the attacker itself, so that calls can be chained
        """
        super().reset()
        # A previously cached perturbed graph is stale once
        # the recorded flips are thrown away.
        self.data.cache_clear()
        self._removed_edges, self._added_edges = {}, {}
        self._removed_feats, self._added_feats = {}, {}
        # Clone so that later in-place degree updates do not
        # touch the reference degree.
        self.degree = self._degree.clone()
        return self

    def remove_edge(self, u: int, v: int, it: Optional[int] = None):
        """Remove an edge from the graph.

        Parameters
        ----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge
        it : Optional[int], optional
             The iteration that indicates the order of
             the edge being removed, by default None
        """

        if not self._allow_singleton:
            is_singleton_u = self.degree[u] <= 1
            is_singleton_v = self.degree[v] <= 1

            if is_singleton_u or is_singleton_v:
                warnings.warn(
                    f"You are trying to remove an edge ({u}-{v}) "
                    "that would result in singleton nodes. "
                    "If the behavior is not intended, "
                    "please make sure you have set "
                    "`attacker.set_allow_singleton(False)` "
                    "or check your algorithm.", UserWarning)

        self._removed_edges[(u, v)] = it
        self.degree[u] -= 1
        self.degree[v] -= 1

    def add_edge(self, u: int, v: int, it: Optional[int] = None):
        """Add one edge to the graph.

        Parameters
        ----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge
        it : Optional[int], optional
             The iteration that indicates the order of
             the edge being added, by default None
        """
        self._added_edges[(u, v)] = it
        self.degree[u] += 1
        self.degree[v] += 1

    def removed_edges(self) -> Optional[Tensor]:
        """Get all the edges to be removed.
        """
        edges = self._removed_edges
        if edges is None or len(edges) == 0:
            return None

        if torch.is_tensor(edges):
            return edges.to(self.device)

        if isinstance(edges, dict):
            edges = list(edges.keys())

        removed = torch.tensor(
            np.asarray(edges, dtype="int64").T, device=self.device)
        return removed

    def added_edges(self) -> Optional[Tensor]:
        """Get all the edges to be added."""
        edges = self._added_edges
        if edges is None or len(edges) == 0:
            return None

        if torch.is_tensor(edges):
            return edges.to(self.device)

        if isinstance(edges, dict):
            edges = list(edges.keys())

        return torch.tensor(
            np.asarray(edges, dtype="int64").T, device=self.device)

    def edge_flips(self, frac: float = 1.0) -> BunchDict:
        """Get all the edges to be flipped, including edges
        to be added and removed.

        Parameters
        ----------
        frac : float, optional
            the fraction of edge perturbations, i.e.,
            how many of the recorded edge flips are kept,
            counted from the first recorded one.
            Must lie in ``[0, 1]``, by default 1.0

        Returns
        -------
        BunchDict
            a mapping with the entries ``added``, ``removed`` and
            ``all``, where ``all`` is the concatenation of the former
            two along dimension 1. An entry is :obj:`None` when there
            is nothing to report for it.

        Example
        -------
        >>> # Get the edge flips
        >>> attacker.edge_flips()

        >>> # Get the edge flips, with
        >>> # specifying frac
        >>> attacker.edge_flips(frac=0.5)
        """
        assert 0 <= frac <= 1

        def leading(edges):
            # Keep only the first `frac` share of the columns.
            if edges is None:
                return None
            return edges[:, :round(edges.size(1) * frac)]

        added = leading(self.added_edges())
        removed = leading(self.removed_edges())
        return BunchDict(added=added, removed=removed,
                         all=cat(added, removed, dim=1))

    def remove_feat(self, u: int, v: int, it: Optional[int] = None):
        """Remove the feature in a dimension `v` form a node `u`.
        That is, set a dimension of the specific node to zero.

        Parameters
        ----------
        u : int
            the node whose features are to be removed
        v : int
            the dimension of the feature to be removed
        it : Optional[int], optional
            The iteration that indicates the order
            of the features being removed, by default None
        """

        self._removed_feats[(u, v)] = it

    def add_feat(self, u: int, v: int, it: Optional[int] = None):
        """Remove the feature in a dimension `v` form a node `u`.
        That is, set a dimension of the specific node to one.

        Parameters
        ----------
        u : int
            the node whose features are to be added
        v : int
            the dimension of the feature to be added
        it : Optional[int], optional
            The iteration that indicates the order
            of the features being added, by default None
        """
        self._added_feats[(u, v)] = it

    def removed_feats(self) -> Optional[Tensor]:
        """Get all the features to be removed."""
        feats = self._removed_feats
        if feats is None or len(feats) == 0:
            return None

        if isinstance(feats, dict):
            feats = list(feats.keys())

        if torch.is_tensor(feats):
            return feats.to(self.device)

        return torch.tensor(
            np.asarray(feats, dtype="int64").T, device=self.device)

    def added_feats(self) -> Optional[Tensor]:
        """Get all the features to be added."""
        feats = self._added_feats
        if feats is None or len(feats) == 0:
            return None

        if isinstance(feats, dict):
            feats = list(feats.keys())

        if torch.is_tensor(feats):
            return feats.to(self.device)

        return torch.tensor(
            np.asarray(feats, dtype="int64").T, device=self.device)

    def feat_flips(self, frac: float = 1.0) -> BunchDict:
        """Get all the features to be flipped, including features
        to be added and removed.

        Parameters
        ----------
        frac : float, optional
            the fraction of feature perturbations, i.e.,
            how many of the recorded feature flips are kept,
            counted from the first recorded one.
            Must lie in ``[0, 1]``, by default 1.0

        Returns
        -------
        BunchDict
            a mapping with the entries ``added``, ``removed`` and
            ``all``, where ``all`` is the concatenation of the former
            two along dimension 1. An entry is :obj:`None` when there
            is nothing to report for it.

        Example
        -------
        >>> # Get the feature flips
        >>> attacker.feat_flips()

        >>> # Get the feature flips, with
        >>> # specifying frac
        >>> attacker.feat_flips(frac=0.5)
        """
        assert 0 <= frac <= 1

        def leading(feats):
            # Keep only the first `frac` share of the columns.
            if feats is None:
                return None
            return feats[:, :round(feats.size(1) * frac)]

        added = leading(self.added_feats())
        removed = leading(self.removed_feats())
        return BunchDict(added=added, removed=removed,
                         all=cat(added, removed, dim=1))

    @lru_cache(maxsize=1)
    def data(
        self,
        edge_ratio: float = 1.0,
        feat_ratio: float = 1.0,
        coalesce: bool = True,
        symmetric: bool = True,
    ) -> Data:
        """Get the attacked graph denoted by
        PyG-like data instance. Note that this method
        uses LRU cache for efficiency, the computation is
        only excuted at the first call if the input parameters
        were the same.

        Parameters
        ----------
        edge_ratio : float, optional
            the fraction of edge perturbations, i.e.,
            how many perturbed edges are used to
            construct the perturbed graph.
            by default 1.0
        feat_ratio : float, optional
            the fraction of feature perturbations, i.e.,
            how many perturbed features are used to
            construct the perturbed graph.
            by default 1.0
        coalesce : bool, optional
            whether to coalesce the output edges.
        symmetric : bool, optional
            whether the output graph is symmetric, by default True

        Example
        -------
        >>> # Get the perturbed graph, including
        >>> # edge flips and feature flips
        >>> attacker.data()

        >>> # Get the perturbed graph, with
        >>> # specifying edge_ratio
        >>> attacker.data(edge_ratio=0.5)

        >>> # Get the perturbed graph, with
        >>> # specifying feat_ratio
        >>> attacker.data(feat_ratio=0.5)

        Returns
        -------
        Data
            the attacked graph denoted by PyG-like data instance
        """

        data = copy(self.ori_data)
        edge_index = data.edge_index
        edge_weight = data.edge_weight
        assert edge_weight is None, 'weighted graph is not supported now.'

        edge_flips = self.edge_flips(frac=edge_ratio)
        removed = edge_flips['removed']

        if removed is not None:
            edge_index = remove_edges(edge_index, removed, symmetric=symmetric)

        added = edge_flips['added']
        if added is not None:
            edge_index = add_edges(edge_index, added, symmetric=symmetric,
                                   coalesce=coalesce)

        data.edge_index = edge_index

        if edge_weight is not None:
            data.edge_weight = edge_weight

        if self.feature_attack:
            feat = self.feat.detach().clone()
            feat_flips = self.feat_flips(frac=feat_ratio)
            removed = feat_flips['removed']
            if removed is not None:
                feat[removed[0], removed[1]] = 0.

            added = feat_flips['added']
            if added is not None:
                feat[added[0], added[1]] = 1.
            data.x = feat

        return data

    def set_allow_singleton(self, state: bool):
        """Set whether the attacked graph allow singleton node, i.e.,
        zero degree nodes.

        Parameters
        ----------
        state : bool
            the flag to set

        Example
        -------
        >>> attacker.set_allow_singleton(True)
        """

        self._allow_singleton = state

    def is_singleton_edge(self, u: int, v: int) -> bool:
        """Check if the edge is an singleton edge that, if removed,
        would result in a singleton node in the graph.

        Parameters
        -----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge

        Return
        ------
        bool: `True` if the edge is an singleton edge, otherwise `False`.

        Note
        ----
        Please make sure the edge is the one being removed.
        """
        threshold = 1
        # threshold = 2 if the graph has selfloop before
        # otherwise threshold = 1
        if not self._allow_singleton and (self.degree[u] <= threshold
                                          or self.degree[v] <= threshold):
            return True
        return False

    def is_legal_edge(self, u: int, v: int) -> bool:
        """Check whether the edge (u,v) is legal.

        An edge (u,v) is legal if u!=v and edge (u,v) is
        not selected before.

        Parameters
        -----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge

        Returns
        -------
        bool: :obj:`True` if the u!=v and edge (u,v), (v,u) is not selected,
        otherwise :obj:`False`.
        """
        _removed_edges = self._removed_edges
        _added_edges = self._added_edges

        return all((u != v, (u, v) not in _removed_edges, (v, u)
                    not in _removed_edges, (u, v) not in _added_edges, (v, u)
                    not in _added_edges))


def cat(a, b, dim=1):
    """Concatenate two optional tensors along dimension `dim`.

    If one of the inputs is :obj:`None`, the other one is returned
    as is (which is :obj:`None` as well if both are missing);
    otherwise the result of :func:`torch.cat` is returned.
    """
    if a is not None and b is not None:
        return torch.cat([a, b], dim=dim)
    return b if a is None else a
