import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
import scipy.sparse as sp 
from utils import *

class Edge_Evaluator(nn.Module):
    """Score a given set of edges and relax the scores into soft weights.

    Node features go through a linear layer followed by ReLU. Each edge is
    scored by a second linear layer applied to the concatenated embeddings
    of its two endpoints, and the score is pushed through a noisy,
    temperature-scaled sigmoid (see ``gumbel_sampling``).

    Args:
        in_dim: Size of the input node features.
        sparse: If true, ``forward`` returns a graph built with
            ``dgl.graph``; otherwise it returns a dense
            ``(nnodes, nnodes)`` tensor.
        device: Device on which the noise and the output are placed.
        hidden_dim: Size of the node embeddings.
        temperature: Divisor applied to the noisy scores before the sigmoid.
        bias: Margin keeping the uniform samples away from 0 and 1 so both
            logarithms in ``gumbel_sampling`` stay finite.
    """

    def __init__(self, in_dim, sparse, device, hidden_dim=128, temperature=1.0, bias=0.0001):
        super().__init__()

        self.sparse = sparse

        self.mlp = Linear(in_dim, hidden_dim)
        self.edge_mlp = Linear(hidden_dim * 2, 1)

        self.temperature = temperature
        self.bias = bias
        self.device = device

    def get_edge_weight(self, embeddings, edges):
        """Return one raw score per edge, independent of endpoint order.

        Args:
            embeddings: 2-D tensor of node embeddings, indexed by node id
                along the first axis.
            edges: Pair of index tensors; ``edges[0]`` and ``edges[1]`` hold
                the two endpoints of every edge.

        Returns:
            1-D tensor with the mean of the scores obtained for the two
            concatenation orders ``(u, v)`` and ``(v, u)``.
        """
        src = embeddings[edges[0]]
        dst = embeddings[edges[1]]
        forward_score = self.edge_mlp(torch.cat((src, dst), dim=1)).flatten()
        backward_score = self.edge_mlp(torch.cat((dst, src), dim=1)).flatten()
        return (forward_score + backward_score) / 2

    def gumbel_sampling(self, edges_weights_raw):
        """Perturb raw scores with logistic noise and squash them into (0, 1).

        Args:
            edges_weights_raw: Tensor of raw edge scores.

        Returns:
            ``sigmoid((noise + edges_weights_raw) / temperature)`` with all
            size-1 dimensions squeezed out; a single score therefore comes
            back as a 0-dim tensor.
        """
        # Map uniform samples from [0, 1) into (bias, 1 - bias].
        eps = (self.bias - (1 - self.bias)) * torch.rand(edges_weights_raw.size()) + (1 - self.bias)
        # log(eps) - log(1 - eps) turns a uniform sample into logistic noise.
        gate_inputs = torch.log(eps) - torch.log(1 - eps)
        gate_inputs = gate_inputs.to(self.device)
        gate_inputs = (gate_inputs + edges_weights_raw) / self.temperature
        return torch.sigmoid(gate_inputs).squeeze()

    def forward(self, x, edges):
        """Compute soft weights for ``edges`` and pack them into a graph.

        Args:
            x: Node feature tensor; ``x.shape[0]`` is the number of nodes.
            edges: Pair of index tensors holding the endpoints of each edge.

        Returns:
            When ``self.sparse`` is false, a dense ``(nnodes, nnodes)``
            tensor that is zero everywhere except at ``[edges[0], edges[1]]``,
            where it holds the sampled weights. Otherwise a graph created by
            ``dgl.graph`` with the weights stored in ``edata['w']``.
        """
        nnodes = x.shape[0]
        embeddings = F.relu(self.mlp(x))
        edges_weights_raw = self.get_edge_weight(embeddings, edges)
        # gumbel_sampling squeezes its output, which collapses a single edge
        # to a 0-dim tensor; restore the per-edge axis so there is always
        # exactly one entry per edge.
        weights = self.gumbel_sampling(edges_weights_raw).reshape(-1)

        if not self.sparse:
            # Allocate directly on the target device and in the dtype of the
            # weights so the indexed assignment below never mixes dtypes.
            W = torch.zeros(nnodes, nnodes, dtype=weights.dtype, device=self.device)
            W[edges[0], edges[1]] = weights
        else:
            # NOTE(review): `dgl` is not imported explicitly in this file;
            # presumably it is provided by `from utils import *` - confirm.
            W = dgl.graph((edges[0], edges[1]), num_nodes=nnodes, device=self.device)
            W.edata['w'] = weights

        return W
