import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F

import numpy as np


# GCN basic operation
class GraphConv(nn.Module):
    """Two-stage graph convolution summed over edge types.

    For every edge type ``e`` (the leading axis of the weight tensors) the layer
    computes::

        h = adj[e] @ x @ weight[e] (+ x @ self_weight[e]) (+ bias)
        y = relu(h) @ weight2[e]   (+ x @ self_weight[e]) (+ bias2)

    and finally sums ``y`` over dim 1 (the edge-type axis).

    Parameters are allocated *uninitialized* (like ``torch.FloatTensor(shape)``);
    the owning model is responsible for initializing them (``GcnEncoderGraph``
    does so with Xavier-uniform weights and zero biases).

    Args:
        input_dim: size of the incoming node feature vectors.
        output_dim: size of the produced node feature vectors.
        add_self: add the self-loop term ``x @ self_weight`` after both stages.
        normalize_embedding: stored only; the normalization step is disabled.
        dropout: dropout probability applied to the input when > 0.001.
        bias: whether to create ``bias``/``bias2``.
        gpu: allocate all parameters on CUDA instead of the CPU.
        att: re-weight the adjacency by a learned similarity ``(xW)(xW)^T``.
        edge_dim: number of edge types (leading dim of the weight tensors).
    """

    def __init__(self, input_dim, output_dim, add_self=False, normalize_embedding=False,
                 dropout=0.0, bias=True, gpu=True, att=False, edge_dim=1):
        super(GraphConv, self).__init__()
        self.att = att
        self.add_self = add_self
        self.dropout = dropout
        if dropout > 0.001:
            self.dropout_layer = nn.Dropout(p=dropout)
        self.normalize_embedding = normalize_embedding
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Every parameter goes on the same device. Previously weight2/bias2
        # stayed on the CPU when gpu=True, so forward() failed unless the whole
        # module was moved afterwards.
        device = torch.device('cuda') if gpu else torch.device('cpu')

        def new_param(*shape):
            # Uninitialized float32 storage, equivalent to torch.FloatTensor(*shape).
            return nn.Parameter(torch.empty(*shape, dtype=torch.float32, device=device))

        # Registration order is kept identical to preserve state_dict/parameters() order.
        self.weight = new_param(edge_dim, input_dim, output_dim)
        self.weight2 = new_param(edge_dim, output_dim, output_dim)
        if add_self:
            self.self_weight = new_param(edge_dim, input_dim, output_dim)
        if att:
            self.att_weight = new_param(input_dim, input_dim)
        if bias:
            self.bias = new_param(output_dim)
            self.bias2 = new_param(output_dim)
        else:
            self.bias = None
            self.bias2 = None

    def forward(self, x, adj):
        """Apply the layer.

        Args:
            x: node features; a 3-D ``x`` gets a singleton axis inserted at
               dim 1 so it broadcasts over the edge-type axis of ``adj``
               (presumably ``(batch, num_nodes, input_dim)`` — TODO confirm).
            adj: adjacency; presumably ``(batch, edge_dim, num_nodes, num_nodes)``
               (the result is summed over dim 1 as the edge-type axis) — TODO confirm.

        Returns:
            ``(y, adj)``: ``y`` is the output summed over edge types; ``adj`` is
            the adjacency actually used (attention-reweighted when ``att``).
        """
        if self.dropout > 0.001:
            x = self.dropout_layer(x)

        if self.att:
            x_att = torch.matmul(x, self.att_weight)
            att = x_att @ x_att.permute(0, 2, 1)  # (batch, N, N) similarity
            if adj.dim() == 4:
                # adj carries an edge-type axis at dim 1: broadcast the attention
                # over it rather than aligning batch against edge type.
                att = att.unsqueeze(1)
            adj = adj * att

        if x.dim() == 3:
            x = x.unsqueeze(1)
        y = torch.matmul(adj, x)
        if y.dim() < 4:  # expand to batch, edge_type, num_nodes, embedding_dim
            y = torch.unsqueeze(y, 1)

        # First stage.
        y = torch.matmul(y, self.weight)
        if self.add_self:
            # Same self-loop term is reused after the second stage.
            self_emb = torch.matmul(x, self.self_weight)
            y += self_emb
        if self.bias is not None:
            y = y + self.bias

        # Second stage (MLP on top of the aggregated features).
        y = F.relu(y)
        y = torch.matmul(y, self.weight2)
        if self.add_self:
            y += self_emb
        if self.bias2 is not None:
            y = y + self.bias2

        # Sum over edge types.
        y = torch.sum(y, dim=1)
        return y, adj


class GcnEncoderGraph(nn.Module):
    """Graph-level GCN encoder with a prediction head.

    Stacks ``num_layers`` GraphConv layers (``conv_first``, ``num_layers - 2``
    layers in ``conv_block``, ``conv_last``), max-pools each layer's output over
    dim 1, and feeds the pooled vectors (concatenated when ``concat``) to
    ``pred_model``.

    Options are read from ``args`` (``gpu``, ``method``, ``bias``,
    ``edge_dim``). When ``args`` is ``None`` the encoder uses the CPU, no
    attention, biases enabled and a single edge type.
    """

    def __init__(self, input_dim, hidden_dim, embedding_dim, label_dim, num_layers,
                 pred_hidden_dims=None, concat=True, bn=True, dropout=0.0, add_self=False, args=None):
        super(GcnEncoderGraph, self).__init__()
        # None replaces the shared mutable [] default.
        if pred_hidden_dims is None:
            pred_hidden_dims = []
        self.args = args
        self.concat = concat
        self.bn = bn
        self.num_layers = num_layers
        self.num_aggs = 1

        # `args` defaults to None, but it used to be dereferenced before the
        # None check (args.gpu / args.method / args.edge_dim), so omitting it
        # always crashed. Use explicit fallbacks instead.
        if args is None:
            self.bias = True
            self.gpu = False
            self.att = False
            edge_dim = 1
        else:
            self.bias = args.bias
            self.gpu = args.gpu
            self.att = args.method == 'att'
            edge_dim = args.edge_dim

        self.conv_first, self.conv_block, self.conv_last = self.build_conv_layers(
            input_dim, hidden_dim, embedding_dim, num_layers,
            add_self, normalize=False, dropout=dropout, edge_dim=edge_dim)
        self.act = nn.ReLU()
        self.label_dim = label_dim

        if concat:
            self.pred_input_dim = hidden_dim * (num_layers - 1) + embedding_dim
        else:
            self.pred_input_dim = embedding_dim
        self.pred_model = self.build_pred_layers(self.pred_input_dim, pred_hidden_dims,
                                                 label_dim, num_aggs=self.num_aggs)

        # GraphConv allocates its parameters uninitialized; initialize them here
        # (Xavier-uniform weights with ReLU gain, zero biases).
        relu_gain = nn.init.calculate_gain('relu')
        for m in self.modules():
            if isinstance(m, GraphConv):
                init.xavier_uniform_(m.weight.data, gain=relu_gain)
                init.xavier_uniform_(m.weight2.data, gain=relu_gain)
                if m.att:
                    init.xavier_uniform_(m.att_weight.data, gain=relu_gain)
                if m.add_self:
                    init.xavier_uniform_(m.self_weight.data, gain=relu_gain)
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
                    init.constant_(m.bias2.data, 0.0)

    def build_conv_layers(self, input_dim, hidden_dim, embedding_dim, num_layers, add_self,
                          normalize=False, dropout=0.0, edge_dim=1):
        """Create ``(conv_first, conv_block, conv_last)``.

        Only the middle ``conv_block`` layers receive ``dropout``; the first
        and last layers are built without it.
        """
        conv_first = GraphConv(input_dim=input_dim, output_dim=hidden_dim, add_self=add_self,
                               normalize_embedding=normalize, bias=self.bias, gpu=self.gpu, att=self.att,
                               edge_dim=edge_dim)
        conv_block = nn.ModuleList(
            [GraphConv(input_dim=hidden_dim, output_dim=hidden_dim, add_self=add_self,
                       normalize_embedding=normalize, dropout=dropout, bias=self.bias, gpu=self.gpu,
                       edge_dim=edge_dim, att=self.att)
             for _ in range(num_layers - 2)])
        conv_last = GraphConv(input_dim=hidden_dim, output_dim=embedding_dim, add_self=add_self,
                              normalize_embedding=normalize, bias=self.bias, gpu=self.gpu, att=self.att,
                              edge_dim=edge_dim)
        return conv_first, conv_block, conv_last

    def build_pred_layers(self, pred_input_dim, pred_hidden_dims, label_dim, num_aggs=1):
        """Build the prediction head.

        Returns a single ``nn.Linear`` when ``pred_hidden_dims`` is empty,
        otherwise an ``nn.Sequential`` of Linear + ``self.act`` pairs ending in
        a Linear to ``label_dim``. Prints the input/label dims in the
        single-layer case.
        """
        pred_input_dim = pred_input_dim * num_aggs
        if len(pred_hidden_dims) == 0:
            print(pred_input_dim)
            print(label_dim)
            return nn.Linear(pred_input_dim, label_dim)

        pred_layers = []
        for pred_dim in pred_hidden_dims:
            pred_layers.append(nn.Linear(pred_input_dim, pred_dim))
            pred_layers.append(self.act)
            pred_input_dim = pred_dim
        pred_layers.append(nn.Linear(pred_input_dim, label_dim))
        return nn.Sequential(*pred_layers)

    def build_test_pred_layers(self, pred_input_dim, label_dim):
        """Hand-built, parameter-free head used for debugging.

        Compares the first and second halves of the (single-sample) embedding
        and squashes the absolute difference into a two-class output.
        """
        print("Using test")
        half = pred_input_dim // 2
        # Zero-initialized: previously the matrix was uninitialized memory with
        # only the +1/-1 entries set.
        mat = torch.zeros(half, pred_input_dim)
        for i in range(half):
            mat[i, i] = 1
            mat[i, half + i] = -1

        def pred_model(x):
            x = x[0].t()
            intermediate = mat @ x
            intermediate = 10 * torch.abs(torch.sum(intermediate, dim=0))
            logit = torch.pow(intermediate + 0.45, 10) * torch.reciprocal(torch.pow(intermediate + 0.45, 10) + 0.5)
            pred = torch.stack((logit, 1 - logit))
            return torch.unsqueeze(pred.t(), 0)
        return pred_model

    def construct_mask(self, max_nodes, batch_num_nodes):
        ''' For each num_nodes in batch_num_nodes, the first num_nodes entries of the
        corresponding row are 1's, and the rest are 0's (to be masked out).
        Dimension of mask: [batch_size x max_nodes x 1]

        The mask is moved to CUDA only when the encoder was built with gpu=True.
        '''
        out_tensor = torch.zeros(len(batch_num_nodes), max_nodes)
        for i, num in enumerate(batch_num_nodes):
            out_tensor[i, :int(num)] = 1
        mask = out_tensor.unsqueeze(2)
        return mask.cuda() if self.gpu else mask

    def apply_bn(self, x):
        ''' Batch normalization of 3D tensor x over its dim-1 channels.

        NOTE(review): a brand-new BatchNorm1d is created on every call, so its
        affine parameters are never trained and no running statistics persist;
        a fresh module is in training mode, so it always normalizes with the
        current batch statistics. Kept as-is: making it persistent would change
        the model's parameters and state_dict.
        '''
        bn_module = nn.BatchNorm1d(x.size()[1])
        if self.gpu:
            bn_module = bn_module.cuda()
        return bn_module(x)

    def gcn_forward(self, x, adj, conv_first, conv_block, conv_last, embedding_mask=None):
        ''' Perform forward prop with graph convolution.

        Every layer's output is kept and concatenated along dim 2; all but the
        last layer are followed by ``self.act`` (and batch norm when ``self.bn``).

        Returns:
            ``(x_tensor, adj_att)``: the concatenated embeddings
            ([batch_size x num_nodes x self.pred_input_dim], multiplied by
            ``embedding_mask`` when given) and the adjacency returned by
            ``conv_last``.
        '''
        x, _ = conv_first(x, adj)
        x = self.act(x)
        if self.bn:
            x = self.apply_bn(x)
        x_all = [x]
        for conv in conv_block:
            x, _ = conv(x, adj)
            x = self.act(x)
            if self.bn:
                x = self.apply_bn(x)
            x_all.append(x)

        x, adj_att = conv_last(x, adj)
        x_all.append(x)
        x_tensor = torch.cat(x_all, dim=2)
        if embedding_mask is not None:
            x_tensor = x_tensor * embedding_mask
        self.embedding_tensor = x_tensor
        return x_tensor, adj_att

    def forward(self, x, adj, batch_num_nodes=None, **kwargs):
        """Encode the graphs and return ``(ypred, adj)``.

        Each layer's output is max-pooled over dim 1 (plus sum-pooled when
        ``num_aggs == 2``); the pooled vectors are concatenated (``concat``) or
        only the last one is used, then fed to ``pred_model``. The node mask
        built from ``batch_num_nodes`` is stored on ``self.embedding_mask`` but
        not applied here. ``adj`` is returned unchanged.
        """
        # Node count is the last (square) adjacency dim; dim 1 is the edge-type
        # axis when adj is 4-D.
        max_num_nodes = adj.size(-1)
        if batch_num_nodes is not None:
            self.embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes)
        else:
            self.embedding_mask = None

        x, adj_att = self.conv_first(x, adj)
        x = self.act(x)
        if self.bn:
            x = self.apply_bn(x)
        out_all = []
        out, _ = torch.max(x, dim=1)
        out_all.append(out)
        for i in range(self.num_layers - 2):
            x, adj_att = self.conv_block[i](x, adj)
            x = self.act(x)
            if self.bn:
                x = self.apply_bn(x)
            out, _ = torch.max(x, dim=1)
            out_all.append(out)
            if self.num_aggs == 2:
                out = torch.sum(x, dim=1)
                out_all.append(out)
        x, adj_att = self.conv_last(x, adj)
        out, _ = torch.max(x, dim=1)
        out_all.append(out)
        if self.num_aggs == 2:
            out = torch.sum(x, dim=1)
            out_all.append(out)
        if self.concat:
            output = torch.cat(out_all, dim=1)
        else:
            output = out
        self.embedding_tensor = output
        ypred = self.pred_model(output)
        return ypred, adj

    def loss(self, pred, label, type='softmax'):
        """Classification loss.

        Args:
            pred: logits, first dim is the batch.
            label: integer class labels.
            type: ``'softmax'`` (mean cross-entropy) or ``'margin'``
                (``MultiLabelMarginLoss`` on a one-hot target).

        Raises:
            ValueError: for any other ``type`` (previously returned ``None``).
        """
        if type == 'softmax':
            # reduction='mean' is the non-deprecated spelling of size_average=True.
            return F.cross_entropy(pred, label, reduction='mean')
        elif type == 'margin':
            batch_size = pred.size()[0]
            # Build the target on pred's device instead of unconditionally on CUDA.
            label_onehot = torch.zeros(batch_size, self.label_dim, dtype=torch.long, device=pred.device)
            label_onehot.scatter_(1, label.view(-1, 1), 1)
            # NOTE(review): MultiLabelMarginLoss expects class indices padded
            # with -1, not a one-hot vector — confirm this target is intended.
            return torch.nn.MultiLabelMarginLoss()(pred, label_onehot)
        raise ValueError("unknown loss type: {!r}".format(type))

class GcnEncoderMatching(GcnEncoderGraph):
    """Shared GCN encoder for (data, query) graph pairs.

    ``forward`` returns per-node embeddings (all layers concatenated). ``loss``
    compares a data/query embedding pair either with an order-embedding
    penalty (``args.order_embeddings``) or with a classifier on their
    concatenation, optionally passed through a NeuralTensorNetwork first
    (``args.ntn``).
    """

    def __init__(self, input_dim, hidden_dim, embedding_dim, label_dim, num_layers,
                 pred_hidden_dims=None, concat=True, bn=True, dropout=0.0, args=None):
        # None replaces the shared mutable [] default.
        if pred_hidden_dims is None:
            pred_hidden_dims = []
        super(GcnEncoderMatching, self).__init__(input_dim, hidden_dim, embedding_dim, label_dim,
                                                 num_layers, pred_hidden_dims, concat, bn, dropout, args=args)

        self.celoss = nn.CrossEntropyLoss()

        # pred input dimension for individual graph
        pred_input_dim_graph = self.pred_input_dim
        # baseline: concatenation of query and data to determine whether it's a subgraph
        self.pred_input_dim *= 2
        self.pred_model = self.build_pred_layers(self.pred_input_dim, pred_hidden_dims,
                                                 label_dim, num_aggs=self.num_aggs)
        self.order_embeddings = args is not None and args.order_embeddings

        if args is not None and args.ntn:
            self.ntn_layer = NeuralTensorNetwork(pred_input_dim_graph, pred_input_dim_graph, self.pred_input_dim)

    def forward(self, x, adj, batch_num_nodes=None, **kwargs):
        """Return ``(embedding_tensor, adj_att)`` from ``gcn_forward``.

        Padding nodes are zeroed via a mask when ``batch_num_nodes`` is given.
        """
        # Node count is the last (square) adjacency dim; dim 1 is the edge-type
        # axis when adj is 4-D.
        max_num_nodes = adj.size(-1)
        if batch_num_nodes is not None:
            embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes)
        else:
            embedding_mask = None

        self.embedding_tensor, adj_att = self.gcn_forward(x, adj,
                                                          self.conv_first, self.conv_block, self.conv_last,
                                                          embedding_mask)
        return self.embedding_tensor, adj_att

    def pred_order(self, embedding_data, embedding_query):
        '''
        Prediction function for order embeddings.

        A dimension "violates" the order when the query value exceeds the data
        value. A pair is predicted positive (1.0) when fewer than
        ``int(0.1 * emb_size)`` dimensions violate it; the score is one minus
        the fraction of violating dimensions.

        Both outputs stay on the embeddings' device (previously the indicator
        was copied to the CPU, leaving ``scores`` on the CPU even on GPU).

        Returns:
            ``(predictions, scores)``: float predictions reshaped to ``(-1, 1)``
            and scores of shape ``embedding_data.shape[:2]``.
        '''
        DIM_RATIO = 0.1
        _, _, emb_size = embedding_data.shape
        # 1 if violating the order constraint
        indicator = (embedding_query - embedding_data > 0).float()
        num_violations = torch.sum(indicator, dim=2)
        predictions = (num_violations < int(DIM_RATIO * emb_size)).view(-1, 1).float()
        scores = 1 - num_violations / emb_size
        return predictions, scores

    def loss(self, embedding_data, embedding_query, label):
        """Compute predictions, scores and loss for a batch of pairs.

        Returns:
            Order-embedding mode: ``(predictions, scores, loss)`` where the
            loss penalizes squared order violations for positive pairs and
            per-dimension ``max(0, margin - penalty)`` for negative pairs.
            Classifier mode: ``(logits, scores, loss)`` where scores are the
            softmax probability of class 1 and ``loss`` is ``None`` when
            ``label`` is ``None``.
        """
        if self.order_embeddings:
            margin = 0.2
            _, batch_size, _ = embedding_data.shape

            # inference time prediction
            predictions, scores = self.pred_order(embedding_data, embedding_query)

            # Per-pair label broadcast over the embedding dim, placed next to the
            # embeddings (view(1, batch_size, 1) still checks one label per pair).
            pos_filter = label.view(1, batch_size, 1).to(device=embedding_data.device,
                                                         dtype=torch.float32)
            neg_filter = 1 - pos_filter
            # Query should be less than search neighborhood embedding if query is a subgraph
            diff = embedding_query - embedding_data
            penalty = torch.max(diff, torch.zeros_like(diff)) ** 2
            pos_loss = torch.sum(penalty * pos_filter)
            # NOTE(review): the margin is applied per dimension rather than to the
            # total violation — confirm this is the intended negative loss.
            margin_penalty = margin - penalty
            margin_penalty = torch.max(margin_penalty, torch.zeros_like(margin_penalty))
            neg_loss = torch.sum(margin_penalty * neg_filter)
            return predictions, scores, pos_loss + neg_loss

        # [num_embeddings x embedding_dim]
        pred = torch.cat((embedding_data, embedding_query), 2)

        # neural tensor network
        if self.args is not None and self.args.ntn:
            pred = self.ntn_layer(pred)
        pred = self.pred_model(pred)
        # score is the softmax value for class 1 (is subgraph)
        scores = F.softmax(pred, dim=2)[:, :, 1]

        loss = self.celoss(pred.permute(1, 2, 0), label) if label is not None else None

        return pred, scores, loss

class NeuralTensorNetwork(nn.Module):
    """Neural Tensor Network layer over a pair packed into the last dimension.

    The last dim of the input holds ``[x1, x2]``, where ``x1`` is the first
    ``input_dim1`` features and ``x2`` the remainder. The output is
    ``Bilinear(x1, x2) + Linear([x1, x2])`` with ``output_dim`` features. Any
    number of leading (batch) dimensions is accepted.
    """

    def __init__(self, input_dim1, input_dim2, output_dim, bias=True):
        super(NeuralTensorNetwork, self).__init__()
        self.input_dim1 = input_dim1  # this is d1
        self.input_dim2 = input_dim2  # this is d2
        self.output_dim = output_dim  # this is k

        self.bilinear1 = nn.Bilinear(input_dim1, input_dim2, output_dim, bias=bias)
        self.fc1 = nn.Linear(input_dim1 + input_dim2, output_dim, bias=bias)

    def forward(self, x):
        """Return the NTN output for ``x`` of shape ``(*, input_dim1 + input_dim2)``."""
        # Ellipsis indexing: previously hard-coded to 3-D inputs.
        x1 = x[..., :self.input_dim1]
        x2 = x[..., self.input_dim1:]
        # cat((x1, x2), -1) is just x, so feed x to the linear term directly.
        return self.bilinear1(x1, x2) + self.fc1(x)
