import warnings
import numpy as np
import random
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_scipy_sparse_matrix

import dgl
import os

# Install a global filter that ignores every warning raised from here on.
warnings.filterwarnings("ignore")
# Small positive constant. It is not referenced in the visible part of this
# file; presumably an epsilon for numerical stability -- TODO confirm.
EOS = 1e-10
    
def setup_seed(seed):
    """Seed every random number generator used here for reproducible runs.

    Seeds Python's ``random`` module, NumPy, PyTorch (CPU and all CUDA
    devices) and DGL, exports ``PYTHONHASHSEED``, and configures cuDNN for
    deterministic execution.

    Args:
        seed: Integer seed applied to all generators.
    """
    # NOTE: hash randomisation is fixed at interpreter start-up, so this only
    # influences Python processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select a different convolution algorithm on each
    # run, which defeats reproducibility; it has to be switched off.
    torch.backends.cudnn.benchmark = False

    dgl.seed(seed)
    dgl.random.seed(seed)
    

def split_batch(init_list, batch_size):
    """Split ``init_list`` into consecutive chunks of ``batch_size`` items.

    Every full chunk is returned as a new ``list``. If the length is not a
    multiple of ``batch_size``, the leftover items are appended as a final
    chunk obtained by slicing ``init_list`` directly.

    Args:
        init_list: Sized, sliceable sequence to partition.
        batch_size: Number of items per full chunk.

    Returns:
        A list of chunks in the original order.
    """
    # Passing the *same* iterator ``batch_size`` times makes zip pull
    # consecutive items into each tuple; an incomplete tail is dropped by zip.
    stream = iter(init_list)
    batches = [list(chunk) for chunk in zip(*[stream] * batch_size)]

    remainder = len(init_list) % batch_size
    if remainder != 0:
        batches.append(init_list[-remainder:])
    return batches


def gaussian_sim(X, Y, gamma=0.5):
    """Return the pairwise Gaussian (RBF) similarity between rows of X and Y.

    Entry ``(i, j)`` of the result is ``exp(-gamma * ||X[i] - Y[j]||^2)``.

    Args:
        X: 2-D tensor whose rows are the first set of vectors.
        Y: 2-D tensor whose rows are the second set of vectors; must have the
            same number of columns as ``X``.
        gamma: Scale applied to the squared distance inside the exponential.

    Returns:
        Tensor of shape ``(X.shape[0], Y.shape[0])``.
    """
    x_sq = (X ** 2).sum(dim=1, keepdim=True)
    y_sq = (Y ** 2).sum(dim=1, keepdim=True)
    # Squared Euclidean distance via ||x||^2 - 2 x.y + ||y||^2.
    sq_dist = x_sq - 2 * torch.mm(X, Y.T) + y_sq.T
    # Floating-point cancellation can push the expanded form slightly below
    # zero, which would yield similarities greater than 1.
    sq_dist = torch.clamp(sq_dist, min=0)
    return torch.exp(-gamma * sq_dist)


def cos_sim(a, b, eps=1e-8):
    """Compute the cosine similarity between every row of ``a`` and ``b``.

    Each row is divided by its L2 norm, with the norm floored at ``eps`` so
    that all-zero rows do not cause a division by zero. The result is the
    matrix product of the normalised ``a`` with the transposed normalised
    ``b``.

    Args:
        a: 2-D tensor.
        b: 2-D tensor with the same number of columns as ``a``.
        eps: Lower bound applied to the row norms.

    Returns:
        Tensor of shape ``(a.shape[0], b.shape[0])``.
    """
    def _unit_rows(mat):
        # Column vector of row norms, floored element-wise at eps.
        lengths = mat.norm(dim=1).unsqueeze(1)
        floor = eps * torch.ones_like(lengths)
        return mat / torch.max(lengths, floor)

    a_unit = _unit_rows(a)
    b_unit = _unit_rows(b)
    return torch.mm(a_unit, b_unit.transpose(0, 1))


def edge_index_to_adj(edge_index, device):
    """Build a dense adjacency tensor from an edge index.

    The edge index is turned into a SciPy sparse matrix with
    ``to_scipy_sparse_matrix``, densified, wrapped as a tensor and moved to
    ``device``.

    Args:
        edge_index: Edge index accepted by ``to_scipy_sparse_matrix``
            (presumably a ``(2, num_edges)`` tensor -- confirm with callers).
        device: Target device for the returned tensor.

    Returns:
        Dense adjacency matrix as a tensor on ``device``.
    """
    dense = to_scipy_sparse_matrix(edge_index).toarray()
    return torch.from_numpy(dense).to(device)


def gumbel_sampling(weights, temperature=1.0, bias=0.0 + 0.0001, device='cuda'):
    """Draw a relaxed (sigmoid) sample for each entry of ``weights``.

    Uniform noise is drawn on the CPU with ``torch.rand``, rescaled into the
    range between ``bias`` and ``1 - bias``, mapped through
    ``log(u) - log(1 - u)``, moved to ``device`` and added to ``weights``.
    The sum is divided by ``temperature`` and passed through a sigmoid.

    Args:
        weights: Tensor of logits; must be addable to a tensor on ``device``.
        temperature: Divisor applied to the noisy logits.
        bias: Margin keeping the uniform noise away from 0 and 1 so both
            logarithms stay finite.
        device: Device the noise is moved to before it is added to ``weights``.

    Returns:
        Tensor of values in ``[0, 1]`` with singleton dimensions squeezed out.
    """
    upper = 1 - bias
    uniform = torch.rand(weights.size())
    # Affine map of [0, 1) noise into the (bias, 1 - bias] range.
    eps = (bias - upper) * uniform + upper

    noise = (torch.log(eps) - torch.log(1 - eps)).to(device)
    noisy_logits = (noise + weights) / temperature
    return torch.sigmoid(noisy_logits).squeeze()

def _sample_non_loop_pairs(nnodes, npairs):
    """Draw ``npairs`` random node pairs and drop those that are self-loops."""
    drawn = np.random.choice(nnodes, size=npairs * 2, replace=True)
    drawn = torch.from_numpy(drawn.reshape((2, npairs)))
    return drawn[:, drawn[0, :] != drawn[1, :]]


def generate_random_node_pairs(nnodes, nedges, device, backup=300):
    """Sample ``nedges`` random node pairs that are not self-loops.

    Node ids are drawn uniformly with replacement from ``range(nnodes)`` using
    NumPy's global random state. ``nedges + backup`` pairs are drawn up front
    so that enough remain once self-loops are removed; if that is still not
    enough, further pairs are drawn until ``nedges`` are available. Duplicate
    pairs are not removed.

    Args:
        nnodes: Number of nodes; ids are sampled from ``0 .. nnodes - 1``.
        nedges: Number of pairs to return.
        device: Device the resulting tensor is moved to.
        backup: Extra pairs drawn per sampling round to compensate for
            discarded self-loops.

    Returns:
        Integer tensor of shape ``(2, nedges)`` on ``device``. With fewer than
        two nodes no non-loop pair exists, so the result may have fewer
        columns.
    """
    pairs = _sample_non_loop_pairs(nnodes, nedges + backup)

    # Top up when the self-loop filter removed more than ``backup`` columns.
    # With a single node every pair is a self-loop, so retrying cannot help.
    while pairs.shape[1] < nedges and nnodes > 1:
        shortfall = nedges - pairs.shape[1]
        extra = _sample_non_loop_pairs(nnodes, shortfall + max(backup, 1))
        pairs = torch.cat((pairs, extra), dim=1)

    return pairs[:, :nedges].to(device)

