import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from dgl.nn.pytorch import GraphConv, GATConv, SGConv


class TGCL(nn.Module):
    """Contrastive regularisation loss computed on a random sample of nodes.

    Embeddings are projected to 200 dimensions (Linear -> BatchNorm1d -> ReLU),
    up to ``size_neg`` nodes are drawn at random, and an exponentiated cosine
    similarity between the sampled nodes is weighted by the Hamming distance
    between their adjacency rows.

    Args:
        d_hid: width of the incoming embeddings (input size of the encoder).
        size_neg: maximum number of nodes sampled per forward pass.
        alpha: weight of this loss in the caller; ``alpha == 0`` disables the
            module and ``forward`` returns the integer ``0``.
    """

    def __init__(self, d_hid, size_neg=500, alpha=0):
        super(TGCL, self).__init__()
        self.encoder = nn.Sequential(nn.Linear(d_hid, 200), nn.BatchNorm1d(200), nn.ReLU(inplace=True))
        self.hid = d_hid
        self.sampling_size = size_neg
        self.alpha = alpha
        self.t = 100  # temperature handed to the similarity helpers
        self.tau = 5  # divisor of the Hamming distance in the negative weight

    def forward(self, emb, adj):
        """Return the contrastive loss for ``emb`` under adjacency ``adj``.

        Args:
            emb: 2-D tensor with one row per node and ``d_hid`` columns.
            adj: adjacency whose ``to_dense()`` result can be indexed by node
                on both axes. NOTE(review): the Hamming formula presumes 0/1
                entries - confirm against the caller.

        Returns:
            The integer ``0`` when ``alpha == 0``; otherwise a scalar tensor,
            averaged over the ``k`` nodes that were actually sampled.
        """
        if self.alpha == 0:
            return 0
        emb = self.encoder(emb)
        selected_nodes = torch.randperm(emb.shape[0])[:self.sampling_size]
        # k is smaller than sampling_size when the graph has fewer nodes.
        k = selected_nodes.shape[0]
        selected_adj = adj.to_dense()[selected_nodes, :]             # (k, N)
        dist = self.hamming_distance_by_matrix(selected_adj)         # (k, k)

        pos_sample = emb[selected_nodes, :]                          # (k, d)
        d = pos_sample.shape[1]
        selected_nodes_pairs = selected_adj[:, selected_nodes].unsqueeze(2).expand(k, k, d)
        # Zero out the embedding of every sampled node adjacent to the anchor.
        neg_sample = (1 - selected_nodes_pairs) * pos_sample.unsqueeze(0).expand(k, k, d)

        max_degree = dist.max() + 1  # +1 keeps pos_weight strictly positive
        pos_weight = 1 - dist / max_degree
        neg_weight = 1 + dist / self.tau

        # Keyword argument is required here: the third positional parameter
        # of both helpers is ``eps``, not ``temperature``.
        pos_dis = self.exp_cosine_similarity_2d_matrix(pos_sample, pos_sample, temperature=self.t) * pos_weight
        neg_dis = self.exp_cosine_similarity(pos_sample, neg_sample, temperature=self.t) * neg_weight
        denominator = neg_dis.sum(2) + pos_dis
        loss = torch.mean(torch.log(denominator / (pos_dis * k)), dim=1).sum(0)
        # Normalise by the real sample count, not the requested one.
        return loss / k

    def hamming_distance_by_matrix(self, labels):
        """Return pairwise Hamming distances between the rows of ``labels``.

        For 0/1 rows, entry ``(i, j)`` counts the positions where row ``i``
        and row ``j`` differ.
        """
        return torch.matmul(labels, (1 - labels).T) + torch.matmul(1 - labels, labels.T)

    def exp_cosine_similarity(self, x1, x2, eps=1e-15, temperature=1):
        """Return ``exp(cos(x1, x2) / temperature)`` for a batched ``x2``.

        Args:
            x1: tensor of shape ``(n, d)``.
            x2: tensor of shape ``(b, m, d)``.
            eps: lower clamp on the product of norms (avoids division by zero).
            temperature: multiplies the clamped norm product in the denominator.

        Returns:
            Tensor of shape ``(b, n, m)``.
        """
        w1 = x1.norm(p=2, dim=1, keepdim=True)
        w2 = x2.norm(p=2, dim=2, keepdim=True)
        return torch.exp(torch.matmul(x1, x2.permute(0, 2, 1)) / ((w1 * w2.permute(0, 2, 1)).clamp(min=eps) * temperature))

    def exp_cosine_similarity_2d_matrix(self, x1, x2, eps=1e-15, temperature=1):
        """Return ``exp(cos(x1, x2) / temperature)`` for 2-D inputs.

        Args:
            x1: tensor of shape ``(n, d)``.
            x2: tensor of shape ``(m, d)``.
            eps: lower clamp on the product of norms (avoids division by zero).
            temperature: multiplies the clamped norm product in the denominator.

        Returns:
            Tensor of shape ``(n, m)``.
        """
        w1 = x1.norm(p=2, dim=1, keepdim=True)
        w2 = x2.norm(p=2, dim=1, keepdim=True)
        return torch.exp(torch.matmul(x1, x2.T) / ((w1 * w2.T).clamp(min=eps) * temperature))

    def euclidean_dist(self, x, y):
        """Return the squared Euclidean distance between every row pair.

        Args:
            x: tensor of shape ``(n, d)``.
            y: tensor of shape ``(m, d)``.

        Returns:
            Tensor of shape ``(n, m)``; note that no square root is taken.
        """
        n = x.size(0)
        m = y.size(0)
        d = x.size(1)
        assert d == y.size(1)
        x = x.unsqueeze(1).expand(n, m, d)
        y = y.unsqueeze(0).expand(n, m, d)
        return torch.pow(x - y, 2).sum(2)


class GAT(nn.Module):
    """Multi-layer graph attention classifier with an auxiliary TGCL loss.

    The network is one input ``GATConv``, ``n_layers - 1`` hidden ``GATConv``
    layers (deep copies of a single template, so they start from identical
    weights), and a linear output layer. Every attention layer uses 4 heads
    whose outputs are flattened to ``4 * d_hid`` features.

    Args:
        n_feat: input feature size.
        d_hid: per-head hidden size.
        n_class: number of output classes.
        dropout: dropout probability applied before the output layer.
        n_layers: total number of attention layers; must be greater than 1.
            Note that the default value of 1 is therefore rejected.
        size_neg: sample size forwarded to ``TGCL``.
        gpu: device index; only used to compute ``use_cuda``.
        alpha: weight of the TGCL loss (0 disables it).
        use_resnet: add a skip connection on every even-indexed hidden layer.

    Raises:
        AssertionError: if ``n_layers`` is not greater than 1.
    """

    def __init__(self, n_feat, d_hid, n_class, dropout, n_layers=1, size_neg=500, gpu=0, alpha=0, use_resnet=True):
        super(GAT, self).__init__()
        # Validate before any layer is built so bad input fails fast; the
        # condition now matches the message (0 and negatives are rejected too).
        assert n_layers > 1, 'The number of hidden layer must be greater than 1'
        self.heads = 4
        self.hid = d_hid
        self.input_layers = GATConv(n_feat, d_hid, self.heads)
        hidden_layers = GATConv(self.heads * d_hid, d_hid, self.heads)
        self.hid_layers = self.clones(hidden_layers, n_layers - 1)
        self.tgcl = TGCL(self.heads * d_hid, size_neg, alpha)
        self.output_layer = nn.Linear(d_hid * self.heads, n_class)
        self.use_cuda = gpu >= 0 and torch.cuda.is_available()
        self.sampling_size = size_neg
        self.alpha = alpha
        self.use_resnet = use_resnet
        self.dropout = nn.Dropout(p=dropout)
        self.dropout_rate = dropout
        self.t = 100
        self.tau = 5

    def clones(self, module, n):
        """Return an ``nn.ModuleList`` holding ``n`` deep copies of ``module``."""
        return nn.ModuleList([deepcopy(module) for _ in range(n)])

    def forward(self, graph, train_idx):
        """Run the network and compute both losses on the training nodes.

        Args:
            graph: object passed to every attention layer; its ``feats``,
                ``adj`` and ``labels`` attributes are read here.
            train_idx: index selecting the nodes used for classification.

        Returns:
            Tuple ``(alpha * tgcl_loss, nll_loss, log_probabilities)`` where
            the last two cover only the nodes in ``train_idx``.
        """
        width = self.hid * self.heads
        # Flatten the per-head outputs into one feature vector per node.
        h = F.elu(self.input_layers(graph, graph.feats).view(-1, width))
        temp_h = 0
        for count, layer in enumerate(self.hid_layers):
            h = F.elu(layer(graph, h).view(-1, width))
            if self.use_resnet and count % 2 == 0:
                # Skip connection from the previous even-indexed layer
                # (the integer 0 for the very first one).
                h = h + temp_h
                temp_h = h
        # The contrastive loss sees the embeddings before dropout.
        loss = self.tgcl(h, graph.adj)
        h = self.dropout(h)
        prediction = F.log_softmax(self.output_layer(h[train_idx]), dim=1)
        clf_loss = F.nll_loss(prediction, graph.labels[train_idx])
        return self.alpha * loss, clf_loss, prediction


class SGC(nn.Module):
    """Stack of ``SGConv`` layers with a linear classifier and a TGCL loss.

    One input ``SGConv`` maps ``n_feat`` to ``d_hid``; it is followed by
    ``n_layers - 1`` hidden ``SGConv`` layers (deep copies of one template)
    and a linear layer producing ``n_class`` scores. Every ``SGConv`` is
    built with ``bias=True`` and ``k=2``.

    Args:
        n_feat: input feature size.
        d_hid: hidden size.
        n_class: number of output classes.
        dropout: dropout probability applied before the output layer.
        n_layers: total number of convolution layers.
        size_neg: sample size forwarded to ``TGCL``.
        gpu: device index; only used to compute ``use_cuda``.
        alpha: weight of the TGCL loss (0 disables it).
        use_resnet: add a skip connection on every even-indexed hidden layer.
    """

    def __init__(self, n_feat, d_hid, n_class, dropout, n_layers=1, size_neg=500, gpu=0, alpha=0, use_resnet=True):
        super().__init__()
        # Sub-modules are created in a fixed order so that parameter order
        # and random initialisation stay reproducible.
        self.input_layers = SGConv(n_feat, d_hid, bias=True, k=2)
        template = self.build_hidden_layer(d_hid)
        self.hid_layers = self.clones(template, n_layers - 1)
        self.output_layer = nn.Linear(d_hid, n_class)
        self.tgcl = TGCL(d_hid, size_neg, alpha)
        self.dropout = nn.Dropout(p=dropout)

        # Plain hyper-parameters.
        self.use_cuda = gpu >= 0 and torch.cuda.is_available()
        self.sampling_size = size_neg
        self.alpha = alpha
        self.dropout_rate = dropout
        self.use_resnet = use_resnet
        self.t = 100
        self.tau = 5

    def clones(self, module, n):
        """Return an ``nn.ModuleList`` holding ``n`` deep copies of ``module``."""
        copies = [deepcopy(module) for _ in range(n)]
        return nn.ModuleList(copies)

    def build_hidden_layer(self, d_hid):
        """Return a fresh ``d_hid`` -> ``d_hid`` ``SGConv`` layer."""
        return SGConv(d_hid, d_hid, bias=True, k=2)

    def forward(self, graph, train_idx):
        """Run the network and compute both losses on the training nodes.

        Args:
            graph: object passed to every convolution; its ``feats``, ``adj``
                and ``labels`` attributes are read here.
            train_idx: index selecting the nodes used for classification.

        Returns:
            Tuple ``(alpha * tgcl_loss, nll_loss, log_probabilities)`` where
            the last two cover only the nodes in ``train_idx``.
        """
        h = self.input_layers(graph, graph.feats)
        skip = 0
        for depth, conv in enumerate(self.hid_layers):
            h = conv(graph, h)
            if self.use_resnet and depth % 2 == 0:
                # Add the output of the previous even-indexed layer
                # (the integer 0 for the very first one) and remember the sum.
                h = h + skip
                skip = h
        # The contrastive loss is evaluated before dropout touches h.
        contrastive = self.tgcl(h, graph.adj)
        dropped = self.dropout(h)
        logits = self.output_layer(dropped[train_idx])
        prediction = F.log_softmax(logits, dim=1)
        clf_loss = F.nll_loss(prediction, graph.labels[train_idx])
        return self.alpha * contrastive, clf_loss, prediction


class GCN(nn.Module):
    """Stack of ``GraphConv`` layers with a linear classifier and a TGCL loss.

    One input ``GraphConv`` maps ``n_feat`` to ``d_hid``; it is followed by
    ``n_layers - 1`` hidden ``GraphConv`` layers (deep copies of one template)
    and a linear layer producing ``n_class`` scores. Every ``GraphConv`` is
    built with ``bias=True`` and a ReLU activation.

    Args:
        n_feat: input feature size.
        d_hid: hidden size.
        n_class: number of output classes.
        dropout: dropout probability applied before the output layer.
        n_layers: total number of convolution layers.
        size_neg: sample size forwarded to ``TGCL``.
        gpu: device index; only used to compute ``use_cuda``.
        alpha: weight of the TGCL loss (0 disables it); defaults to 0.1 here.
        use_resnet: add a skip connection on every even-indexed hidden layer.
    """

    def __init__(self, n_feat, d_hid, n_class, dropout, n_layers=1, size_neg=500, gpu=0, alpha=0.1, use_resnet=True):
        super().__init__()
        # Keep the construction order of sub-modules fixed: it determines
        # parameter order and the sequence of random initialisations.
        self.input_layers = GraphConv(n_feat, d_hid, bias=True, activation=F.relu)
        prototype = self.build_hidden_layer(d_hid)
        self.hid_layers = self.clones(prototype, n_layers - 1)
        self.output_layer = nn.Linear(d_hid, n_class)
        self.tgcl = TGCL(d_hid, size_neg, alpha)
        self.dropout = nn.Dropout(p=dropout)

        # Scalar settings stored for reference.
        self.use_cuda = gpu >= 0 and torch.cuda.is_available()
        self.sampling_size = size_neg
        self.alpha = alpha
        self.dropout_rate = dropout
        self.use_resnet = use_resnet
        self.t = 100
        self.tau = 5

    def clones(self, module, n):
        """Return an ``nn.ModuleList`` holding ``n`` deep copies of ``module``."""
        duplicates = [deepcopy(module) for _ in range(n)]
        return nn.ModuleList(duplicates)

    def build_hidden_layer(self, d_hid):
        """Return a fresh ``d_hid`` -> ``d_hid`` ``GraphConv`` with ReLU."""
        return GraphConv(d_hid, d_hid, bias=True, activation=F.relu)

    def forward(self, graph, train_idx):
        """Run the network and compute both losses on the training nodes.

        Args:
            graph: object passed to every convolution; its ``feats``, ``adj``
                and ``labels`` attributes are read here.
            train_idx: index selecting the nodes used for classification.

        Returns:
            Tuple ``(alpha * tgcl_loss, nll_loss, log_probabilities)`` where
            the last two cover only the nodes in ``train_idx``.
        """
        h = self.input_layers(graph, graph.feats)
        residual = 0
        for index, conv in enumerate(self.hid_layers):
            h = conv(graph, h)
            is_skip_layer = self.use_resnet and not index % 2
            if is_skip_layer:
                # Even-indexed layers receive the running residual
                # (the integer 0 for the very first one) and then update it.
                h = h + residual
                residual = h
        # Evaluate the contrastive term on the un-dropped embeddings first.
        reg_loss = self.tgcl(h, graph.adj)
        h = self.dropout(h)
        scores = self.output_layer(h[train_idx])
        prediction = F.log_softmax(scores, dim=1)
        targets = graph.labels[train_idx]
        clf_loss = F.nll_loss(prediction, targets)
        return self.alpha * reg_loss, clf_loss, prediction
