import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
from pygsp import graphs, filters, reduction
import math
from sklearn.metrics import f1_score

# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class GATLayer(nn.Module):
    """Single attention head whose edge weights are learnable stochastic gates.

    Rather than scoring edges from node features, every edge carries a
    log-alpha value (supplied to ``forward``). That value is mapped to a gate
    in [0, 1] with a stretched-and-clipped ("hard concrete") sigmoid. Each node
    then aggregates the projected features of its in-neighbours, weighted by
    the L1-normalised gates of the incoming edges.

    Args:
        g: graph exposing ``to``/``ndata``/``edata``/``apply_edges``/
            ``update_all`` (presumably a DGL graph -- confirm). It is moved to
            ``device`` once here and reused by every call.
        in_dim: size of the input node features.
        out_dim: size of the projected node features produced by this head.
    """

    # Gate hyper-parameters: temperature (beta) and the stretch interval
    # (gamma, zeta) applied before clipping to [0, 1].
    _BETA = 1 / 4
    _GAMMA = -0.1
    _ZETA = 1.1
    # Uniform noise is kept inside [eps, 1 - eps] so log(u) and log(1 - u)
    # never become infinite.
    _EPS = 1e-6

    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g.to(device)
        # equation (1): shared linear projection of node features.
        self.fc = nn.Linear(in_dim, out_dim, bias=False).to(device)

    def sigma(self, x):
        """Return the logistic sigmoid of ``x``.

        ``torch.sigmoid`` replaces the manual ``1 / (1 + exp(-x))``. The manual
        form overflows ``exp`` for very negative ``x``. It then back-propagates
        NaN (0 * inf) even though the forward value is a clean 0.
        """
        return torch.sigmoid(x)

    def _gate(self, logits):
        """Stretch sigmoid(logits / beta) to (gamma, zeta), then clip to [0, 1]."""
        stretched = self.sigma(logits / self._BETA) * (self._ZETA - self._GAMMA) + self._GAMMA
        return torch.clamp(stretched, min=0.0, max=1.0)

    def edge_attention(self, edges):
        """Edge UDF for equation (2), training mode: sample a noisy gate per edge.

        Returns:
            dict: ``{'e': gates}`` with the same shape as ``edges.data['loga']``.
        """
        loga = edges.data['loga']
        # NumPy's global RNG is kept so np.random.seed still controls the noise.
        # The sample is created directly with loga's dtype and device.
        u = torch.as_tensor(np.random.uniform(0, 1, size=tuple(loga.shape)),
                            dtype=loga.dtype, device=loga.device)
        u = torch.clamp(u, min=self._EPS, max=1 - self._EPS)
        # log(u) - log(1 - u) is logistic noise added to the per-edge log-alpha.
        return {'e': self._gate(torch.log(u) - torch.log1p(-u) + loga)}

    def edge_attention_eval(self, edges):
        """Edge UDF for equation (2), evaluation mode: noise-free gate per edge.

        NOTE(review): the test-time gate divides log-alpha by beta before the
        sigmoid, i.e. sigmoid(loga / beta). Confirm this is intended, since some
        L0 formulations use sigmoid(loga) here.
        """
        return {'e': self._gate(edges.data['loga'])}

    def message_func(self, edges):
        """Message UDF for equations (3) & (4).

        Sends the source node's projected features and the edge's gate. Both
        already live on ``device`` because ``forward`` placed the graph data
        there.
        """
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        """Reduce UDF for equations (3) & (4).

        Normalises the incoming gates to sum to 1 over dim 1, the per-node
        message axis, and returns the gate-weighted sum of source features.
        """
        # equation (3): L1-normalise gates over each node's incoming messages.
        weights = F.normalize(nodes.mailbox['e'], p=1, dim=1)
        # equation (4): weighted aggregation of the neighbours' projections.
        return {'h': torch.sum(weights * nodes.mailbox['z'], dim=1)}

    def forward(self, h, logalpha, train_flag=1):
        """Run one gated-attention pass over ``self.g``.

        Args:
            h: input node features (one row per node).
            logalpha: per-edge gate logits, stored as ``edata['loga']``.
            train_flag: truthy -> sample stochastic gates; falsy -> use the
                deterministic evaluation gates.

        Returns:
            The aggregated node features ``h`` popped from ``self.g.ndata``.
        """
        # equation (1)
        z = self.fc(h.to(device))
        self.g.ndata['z'] = z
        self.g.edata['loga'] = logalpha.to(device)
        # equation (2)
        if train_flag:
            self.g.apply_edges(self.edge_attention)
        else:
            self.g.apply_edges(self.edge_attention_eval)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')


# Multi-head attention runs :math:`K` independent single-head layers, where
# :math:`K` is the number of heads, and merges their outputs. You can use
# concatenation for intermediate layers and averaging for the final layer.
#
# The single-head ``GATLayer`` defined above is the building block for the
# ``MultiHeadsparseGATLayer`` below:

class MultiHeadsparseGATLayer(nn.Module):
    """Run ``num_heads`` gated attention heads and merge their outputs.

    All heads receive the same per-edge ``logalpha`` tensor, so they share the
    edge-gate logits but each has its own feature projection.

    Args:
        g: graph passed to every ``GATLayer`` head.
        in_dim: size of the input node features.
        out_dim: output feature size of each head.
        num_heads: number of heads.
        logalpha: per-edge gate logits shared by all heads.
        merge: ``'cat'`` concatenates head outputs along dim 1 (width
            ``out_dim * num_heads``); any other value averages them element-wise
            over heads (width ``out_dim``).
    """

    def __init__(self, g, in_dim, out_dim, num_heads, logalpha, merge='cat'):
        super(MultiHeadsparseGATLayer, self).__init__()
        self.heads = nn.ModuleList(GATLayer(g, in_dim, out_dim) for _ in range(num_heads))
        self.logalpha = logalpha
        self.merge = merge

    def forward(self, h, train_flag=1):
        """Apply every head to ``h`` and merge the results.

        Device placement of ``h`` and ``logalpha`` is handled inside each
        ``GATLayer.forward``.
        """
        head_outs = [attn_head(h, self.logalpha, train_flag) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        # Average over the head axis only. Without dim=0, torch.mean would
        # collapse every node and feature into a single scalar.
        return torch.mean(torch.stack(head_outs), dim=0)


###########################################################################
# Put everything together
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Now you can define a two-layer sparse GAT model.

class sparseGAT(nn.Module):
    """Two-layer sparse GAT whose layers share one learnable gate logit per edge.

    Args:
        g: graph the model runs on; ``g.number_of_edges()`` sizes ``logalpha``.
        in_dim: size of the input node features.
        hidden_dim: per-head output size of the first layer.
        out_dim: per-head output size of the second layer.
        num_heads: number of heads in each layer.
    """

    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(sparseGAT, self).__init__()
        num_edges = int(g.number_of_edges())
        self.droprate_init = 0.5
        # Create the Parameter directly on `device`. `Parameter(...).to(device)`
        # returns a plain non-leaf tensor when the device differs (CUDA). That
        # tensor is not registered with the module, so the optimizer would never
        # update log-alpha.
        self.logalpha = Parameter(torch.empty(num_edges, 1, device=device))
        # Initialise around logit(1 - droprate_init) with a small spread.
        init_mean = math.log(1 - self.droprate_init) - math.log(self.droprate_init)
        nn.init.normal_(self.logalpha, mean=init_mean, std=1e-2)
        self.layer1 = MultiHeadsparseGATLayer(g, in_dim, hidden_dim, num_heads, self.logalpha).to(device)
        # The second layer's input dimension is hidden_dim * num_heads because
        # layer1 concatenates its head outputs.
        # NOTE(review): layer2 also uses num_heads heads merged by concatenation.
        # The model output therefore has out_dim * num_heads columns. Confirm this
        # is intended; the classic GAT recipe uses a single output head.
        self.layer2 = MultiHeadsparseGATLayer(g, hidden_dim * num_heads, out_dim, num_heads, self.logalpha).to(device)

    def forward(self, h, train_flag=1):
        """Return the second layer's output for node features ``h``.

        ``train_flag`` selects stochastic (truthy) or deterministic (falsy) gates.
        """
        h = self.layer1(h, train_flag)
        h = F.elu(h)
        h = self.layer2(h, train_flag)
        return h


# def evaluate(model, g, features, labels, mask):
#    model.eval().to(device)
#    with torch.no_grad():
#        logits = model(features, train_flag=0)
#        logits = logits[mask]
#        labels = labels[mask].to(device)
#        _, indices = torch.max(logits, dim=1)
#        correct = torch.sum(indices.to(device) == labels)
#        return correct.item() * 1.0 / len(labels)

def evaluate(model, g, features, labels, mask):
    """Score ``model`` on the nodes selected by ``mask`` with micro-averaged F1.

    Args:
        model: module called as ``model(features)``. Rows of its output are
            per-node class scores, and argmax over dim 1 is the prediction.
        g: unused; kept so existing call sites keep working.
        features: input node features passed to ``model``.
        labels: ground-truth class index per node.
        mask: index or boolean mask selecting the nodes to evaluate.

    Returns:
        float: micro-averaged F1 score. For single-label multi-class data this
        equals accuracy.

    Side effects: switches ``model`` to eval mode and moves it to ``device``.
    It does not switch the model back to train mode.
    """
    model.eval().to(device)
    # NOTE(review): for sparseGAT, model(features) uses the default
    # train_flag=1, so edge gates are still sampled stochastically here. The
    # earlier commented-out version passed train_flag=0. Confirm which
    # behaviour evaluation should use.
    with torch.no_grad():
        logits = model(features)[mask]
        y_pred = logits.argmax(dim=1).cpu().numpy()
        y_true = labels[mask].cpu().numpy()
    # scikit-learn's signature is f1_score(y_true, y_pred). The arguments were
    # previously reversed: harmless for 'micro', wrong for e.g. 'macro'.
    return f1_score(y_true, y_pred, average='micro')
