from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor

# Small positive constant added to the denominator in `normalize` so that a
# zero row sum does not produce a division by zero.
EOS = 1e-10


def subgraph(subset: Union[Tensor, List[int]], edge_index: Tensor, num_nodes: Optional[int] = None):
    """Return the edges induced by ``subset``, relabelled to ``0..len(subset)-1``.

    Args:
        subset: Nodes to keep, given as a list of ints, an integer index
            tensor, or a boolean mask tensor. For index input, the position
            of a node inside ``subset`` becomes its new label.
        edge_index: Integer tensor whose row 0 holds source nodes and row 1
            holds target nodes.
        num_nodes: Total number of nodes in the original graph. When omitted
            it is inferred from the mask length (boolean ``subset``) or from
            the largest index found in ``edge_index`` and ``subset``.

    Returns:
        The ``edge_index`` restricted to edges whose both endpoints are in
        ``subset``, with node ids rewritten to the new labels.
    """
    device = edge_index.device

    # Lists (as allowed by the signature) have no ``.size``; normalise to a tensor.
    if isinstance(subset, Tensor):
        subset = subset.to(device)
    else:
        subset = torch.as_tensor(subset, dtype=torch.long, device=device)

    # A boolean mask is turned into the indices of its True entries so that
    # the relabelling below sees one new id per selected node.
    if subset.dtype == torch.bool:
        if num_nodes is None:
            num_nodes = subset.size(0)
        subset = subset.nonzero().view(-1)

    if num_nodes is None:
        sizes = [int(t.max()) + 1 for t in (edge_index, subset) if t.numel() > 0]
        num_nodes = max(sizes, default=0)

    node_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    node_mask[subset] = True

    # Lookup table: old node id -> new node id (only meaningful for kept nodes).
    node_idx = torch.zeros(num_nodes, dtype=torch.long, device=device)
    node_idx[subset] = torch.arange(subset.size(0), device=device)

    edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
    return node_idx[edge_index[:, edge_mask]]


def normalize(adj, mode):
    """Scale ``adj`` on both sides by the inverse square root of its row sums.

    Entry ``(i, j)`` of the result is ``adj[i, j] * s[i] * s[j]`` where
    ``s = 1 / (sqrt(adj.sum(dim=1)) + EOS)``.

    Args:
        adj: Adjacency tensor to normalize.
        mode: Accepted for interface compatibility; this implementation does
            not read it.

    Returns:
        The normalized tensor.
    """
    row_sums = adj.sum(dim=1)
    # EOS keeps the division finite when a row sums to zero.
    scale = 1. / (row_sums.sqrt() + EOS)
    left_scaled = scale.unsqueeze(1) * adj
    return left_scaled * scale.unsqueeze(0)


def apply_non_linearity(tensor, non_linearity, i):
    """Apply the activation selected by ``non_linearity`` to ``tensor``.

    Args:
        tensor: Input tensor.
        non_linearity: One of ``'elu'``, ``'relu'`` or ``'none'``.
        i: Scale used only by the ``'elu'`` branch, which computes
            ``elu(tensor * i - i) + 1``.

    Returns:
        The transformed tensor (the input itself for ``'none'``).

    Raises:
        NameError: If ``non_linearity`` is not one of the supported names.
    """
    if non_linearity == 'elu':
        shifted = tensor * i - i
        return F.elu(shifted) + 1
    if non_linearity == 'relu':
        return F.relu(tensor)
    if non_linearity == 'none':
        return tensor
    raise NameError('We dont support the non-linearity yet')


def cal_similarity_graph(node_embeddings):
    """Return the matrix of pairwise dot products between embedding rows.

    Args:
        node_embeddings: 2-D tensor with one embedding per row.

    Returns:
        ``node_embeddings`` multiplied by its own transpose.
    """
    return node_embeddings.mm(node_embeddings.t())


def top_k(raw_graph, K):
    """Keep the ``K`` largest entries along the last dimension and zero the rest.

    Args:
        raw_graph: Dense score tensor.
        K: Number of entries to keep per row; converted with ``int``.

    Returns:
        A tensor shaped like ``raw_graph`` in which only the top-``K``
        entries of each row retain their value. Gradients flow through the
        kept entries of ``raw_graph``; the mask itself carries no gradient.
    """
    _, indices = raw_graph.topk(k=int(K), dim=-1)

    # Build the mask on the input's device instead of forcing CUDA, so the
    # function also works on CPU-only machines and non-default GPUs.
    mask = torch.zeros(raw_graph.shape, device=raw_graph.device)
    mask.scatter_(-1, indices, 1.0)

    return raw_graph * mask


def symmetrize(adj):
    """Return the average of ``adj`` and its transpose.

    Args:
        adj: Square matrix.

    Returns:
        ``(adj + adj.T) / 2``.
    """
    transposed = adj.T
    total = adj + transposed
    return total / 2


def sim(z1, z2):
    """Return pairwise similarities between rows of ``z1`` and rows of ``z2``.

    Both inputs are passed through ``F.normalize`` with its default
    arguments before the matrix product is taken.

    Args:
        z1: 2-D tensor, one vector per row.
        z2: 2-D tensor, one vector per row, same width as ``z1``.

    Returns:
        Tensor whose entry ``(i, j)`` is the dot product of normalized
        ``z1[i]`` and normalized ``z2[j]``.
    """
    unit_a = F.normalize(z1)
    unit_b = F.normalize(z2)
    return unit_a.mm(unit_b.t())


def semi_loss(z1, z2, tau):
    """Compute a per-row contrastive loss between two sets of embeddings.

    For row ``i`` the numerator is ``exp(sim(z1, z2)[i, i] / tau)``. The
    denominator sums ``exp(sim / tau)`` over every row of ``z1`` and ``z2``
    against ``z1[i]``, excluding the self-pair ``(z1[i], z1[i])``.

    Args:
        z1: Embeddings of the first view, one row per sample.
        z2: Embeddings of the second view, aligned row-wise with ``z1``.
        tau: Temperature dividing the similarities before exponentiation.

    Returns:
        1-D tensor holding the negative log ratio for each row.
    """
    intra_view = torch.exp(sim(z1, z1) / tau)
    cross_view = torch.exp(sim(z1, z2) / tau)

    positives = cross_view.diag()
    # Subtract the diagonal so a sample is not contrasted against itself.
    denominator = intra_view.sum(1) + cross_view.sum(1) - intra_view.diag()
    return -torch.log(positives / denominator)


def contrastive_loss(h1, h2, tau):
    """Return the mean of ``semi_loss(h1, h2, tau)`` over all rows.

    Args:
        h1: Embeddings of the first view.
        h2: Embeddings of the second view.
        tau: Temperature forwarded to ``semi_loss``.

    Returns:
        Scalar tensor with the averaged loss.
    """
    return semi_loss(h1, h2, tau).mean()
