import copy
from typing import Union, List, Optional

import torch
from torch import Tensor
from torch_geometric.utils.num_nodes import maybe_num_nodes

def _subgraph(subset: Union[Tensor, List[int]], edge_index: Tensor,
             edge_attr: Optional[Tensor] = None, relabel_nodes: bool = False,
             num_nodes: Optional[int] = None, return_edge_mask: bool = False):
    r"""Returns the induced subgraph of :obj:`(edge_index, edge_attr)`
    containing the nodes in :obj:`subset`.

    Args:
        subset (LongTensor, BoolTensor or [int]): The nodes to keep.
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting
            :obj:`edge_index` will be relabeled to hold consecutive indices
            starting from zero. (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        return_edge_mask (bool, optional): If set to :obj:`True`, will return
            the edge mask to filter out additional edge features.
            (default: :obj:`False`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """

    device = edge_index.device

    if isinstance(subset, (list, tuple)):
        subset = torch.tensor(subset, dtype=torch.long, device=device)

    if subset.dtype == torch.bool or subset.dtype == torch.uint8:
        node_mask = subset
        num_nodes = node_mask.size(0)

        if relabel_nodes:
            node_idx = torch.zeros(node_mask.size(0), dtype=torch.long,
                                   device=device)
            node_idx[subset] = torch.arange(subset.sum().item(), device=device)
    else:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        node_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        node_mask[subset] = 1

        if relabel_nodes:
            node_idx = torch.zeros(num_nodes, dtype=torch.long, device=device)
            node_idx[subset] = torch.arange(subset.size(0), device=device)

    edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
    edge_index = edge_index[:, edge_mask]
    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None

    if relabel_nodes:
        edge_index = node_idx[edge_index]

    if return_edge_mask:
        return edge_index, edge_attr, edge_mask
    else:
        return edge_index, edge_attr

def subgraph(data, subset: Tensor):
    r"""Returns the induced subgraph given by the node indices
    :obj:`subset`.

    A shallow copy of :obj:`data` is returned in which :obj:`edge_index` is
    filtered and relabeled, :obj:`num_nodes` is updated, and tensor
    attributes reported as node-level or edge-level by :obj:`data` are
    restricted to the kept nodes or edges. The input object itself is not
    reassigned.

    Args:
        data: The graph object to take the subgraph of. It is expected to
            expose :obj:`edge_index`, :obj:`num_nodes`, iteration over
            :obj:`(key, value)` pairs, item assignment, and
            :obj:`is_node_attr` / :obj:`is_edge_attr`.
        subset (LongTensor, BoolTensor or [int]): The nodes to keep, given
            either as node indices or as a node mask. Masks of dtype
            :obj:`torch.uint8` are interpreted as boolean.
    """

    # Normalize `subset` once so that every use below (edge filtering, node
    # counting, attribute indexing) agrees on whether it is a mask or an
    # index list. Without this, a uint8 mask would be counted via `size(0)`
    # and a plain list would fail on `.dtype`.
    if isinstance(subset, (list, tuple)):
        subset = torch.tensor(subset, dtype=torch.long,
                              device=data.edge_index.device)
    elif subset.dtype == torch.uint8:
        subset = subset.to(torch.bool)

    out = _subgraph(subset, data.edge_index, relabel_nodes=True,
                    num_nodes=data.num_nodes, return_edge_mask=True)
    edge_index, _, edge_mask = out

    if subset.dtype == torch.bool:
        num_nodes = int(subset.sum())
    else:
        num_nodes = subset.size(0)

    # Shallow copy: attributes that are not replaced below stay shared with
    # the input object.
    data = copy.copy(data)

    for key, value in data:
        if key == 'edge_index':
            data.edge_index = edge_index
        elif key == 'num_nodes':
            data.num_nodes = num_nodes
        elif key == 'y':
            # NOTE(review): `y` is always indexed by node here, i.e. it is
            # assumed to hold node-level targets - confirm for graph-level
            # labels.
            data[key] = value[subset]
        elif isinstance(value, Tensor):
            if data.is_node_attr(key):
                data[key] = value[subset]
            elif data.is_edge_attr(key):
                data[key] = value[edge_mask]

    return data