import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.pool.topk_pool import topk
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp, global_add_pool as gsp

from torch_geometric.nn import GCNConv
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.utils import softmax, degree, to_dense_batch
from torch_geometric.nn.conv import MessagePassing, GCNConv
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import softmax

import math
from math import ceil

def samplek(x, batch, num_samples=3):
    """Sample up to ``num_samples`` distinct elements per group, weighted by ``x``.

    Args:
        x: 1-D tensor of non-negative sampling weights (e.g. probabilities),
            one per element.
        batch: 1-D long tensor assigning each element of ``x`` to a group id.
            The index arithmetic below assumes elements of the same group
            are stored contiguously, in increasing group order.
        num_samples: Number of draws per group (without replacement).

    Returns:
        1-D long tensor of indices into ``x``, ordered by group. A group with
        ``n`` elements contributes at most ``min(num_samples, n)`` indices.
        It contributes fewer only when a padding slot was drawn in place of a
        real element, which requires real weights on the order of the padding
        weight (1e-15) or smaller. Every returned index belongs to the group
        it was drawn for.
    """
    if x.numel() == 0:
        return torch.empty(0, dtype=torch.long, device=x.device)

    # Elements per group; group ids absent from `batch` get a count of zero.
    num_nodes = torch.bincount(batch)
    batch_size = num_nodes.size(0)
    # Sampling without replacement needs at least `num_samples` columns.
    max_num_nodes = max(int(num_nodes.max()), num_samples)

    # Offset of each group's first element in the flat layout
    # (exclusive cumulative sum of the group sizes).
    cum_num_nodes = num_nodes.cumsum(dim=0) - num_nodes

    # Position of every element inside a dense (batch_size, max_num_nodes) grid.
    index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

    # Padding slots get a tiny non-zero weight so that every row stays a
    # valid distribution even for groups with fewer than `num_samples` items.
    dense_x = x.new_full((batch_size * max_num_nodes,), 1e-15)  # x is probability
    dense_x[index] = x
    dense_x = dense_x.view(batch_size, max_num_nodes)

    local_perm = torch.multinomial(dense_x, num_samples)

    # Keep only draws that hit a real element. Relying on "the first k draws
    # are real" breaks when a real weight is <= the padding weight, and would
    # then return an index that points into another group (or out of range).
    valid = local_perm < num_nodes.view(-1, 1)
    perm = local_perm + cum_num_nodes.view(-1, 1)
    return perm[valid]  # Sparse Implementation

class GraphRepresentation(nn.Module):
    """Base module that stores feature sizes and builds dual hypergraphs."""

    def __init__(self, args):
        """Record ``args`` and derive the feature sizes from ``args.hidden_size``."""
        super().__init__()
        hidden = args.hidden_size
        self.args = args
        self.num_node_features = hidden
        self.num_edge_features = 128
        self.nhid = hidden
        self.enhid = hidden

    ### Dual Hypergraph Transformation (DHT)
    def DHT(self, edge_index, batch, add_loops=True):
        """Turn an edge list into the incidence list of its dual hypergraph.

        In the result, row 0 holds original edge ids (the dual's nodes) and
        row 1 holds hyperedge ids (original node ids, plus fresh ids for the
        self-loops when ``add_loops`` is set).

        Args:
            edge_index: ``(2, num_edges)`` tensor of node ids.
            batch: 1-D tensor mapping each original node to its graph id.
            add_loops: When true and at least one edge exists, incidences
                whose hyperedge occurs exactly once are dropped and one new
                single-member hyperedge is appended for every original edge.

        Returns:
            ``(hyperedge_index, edge_batch)`` where ``edge_batch`` is the
            graph id of each original edge, taken from its first endpoint.
        """
        device = edge_index.device
        num_edge = edge_index.size(1)
        edge_ids = torch.arange(num_edge, device=device)

        ### Each edge appears twice in a row, once per endpoint
        dual_nodes = edge_ids.repeat_interleave(2).unsqueeze(0)
        endpoints = edge_index.t().reshape(1, -1)
        hyperedge_index = torch.cat([dual_nodes, endpoints], dim=0).long()

        ### Every even column is the first endpoint of an edge
        edge_batch = batch.index_select(0, hyperedge_index[1, 0::2])

        if not add_loops or hyperedge_index[1].numel() == 0:
            return hyperedge_index, edge_batch

        ### Add self-loops to each node in the dual hypergraph
        hyperedge_ids = hyperedge_index[1]
        occurrences = hyperedge_ids.bincount()
        keep = occurrences[hyperedge_ids] != 1
        first_loop_id = hyperedge_ids.max() + 1
        loops = torch.stack([edge_ids, edge_ids + first_loop_id], dim=0)

        hyperedge_index = torch.cat([hyperedge_index[:, keep], loops], dim=1)
        return hyperedge_index, edge_batch

    def get_scoreconvs(self):
        """Return a ``ModuleList`` with one ``HypergraphConv(self.enhid, 1)``."""
        return nn.ModuleList([HypergraphConv(self.enhid, 1)])

class Model_HyperDrop(GraphRepresentation):
    """Scores every edge (fact) of a graph against a per-graph embedding.

    Nodes are encoded with one GCN layer and edges with one hypergraph
    convolution on the dual hypergraph. Each edge is then represented as
    ``z = MLP([edge, head, tail])`` and scored by a dot product with the
    projected per-graph embedding.
    """

    def __init__(self, args, entity_embed, num_convs):
        """Build the embeddings, convolution stacks and projection MLPs.

        Args:
            args: Configuration object. This class reads ``hidden_size``,
                ``label_map``, ``num_facts`` and ``use_sigmoid_score``.
            entity_embed: Only its ``shape`` is used, to size
                ``self.entity_embed``; its values are not copied here.
            num_convs: Number of layers created in each convolution stack.
        """
        super(Model_HyperDrop, self).__init__(args)
        self.num_convs = num_convs

        self.ehsz = args.hidden_size  # Entity hidden size
        self.rhsz = 128  # Relation hidden size

        self.entity_embed = torch.nn.Embedding(
            entity_embed.shape[0],
            entity_embed.shape[1],
            padding_idx=0,
        )
        # The relation embedding feeds the first hypergraph convolution, so
        # its width is tied to that layer's input size.
        self.relation_embed = torch.nn.Embedding(
            len(args.label_map),
            self.num_edge_features,
        )
        self.hyperconvs = self.get_convs(conv_type='Hyper')
        self.graphconvs = self.get_convs(conv_type="GCN")

        self.z_linear = nn.Sequential(
            nn.Linear(self.rhsz + 2 * self.ehsz, args.hidden_size),
            nn.ReLU(),
            nn.Linear(args.hidden_size, args.hidden_size),
        )
        self.x_linear = nn.Sequential(
            nn.Linear(args.hidden_size, args.hidden_size),
            nn.ReLU(),
            nn.Linear(args.hidden_size, args.hidden_size),
        )
        self.num_facts = args.num_facts  # The number of facts in single z
        self.use_sigmoid_score = args.use_sigmoid_score
        self.args = args

    def forward(self, x, nodes, edge_index, edge_attr, batch, embeddings, k=1):
        """Score all edges and pick ``self.num_facts`` of them per graph.

        Args:
            x: Node features, passed to the first GCN layer.
            nodes: Not used by this method.
            edge_index: ``(2, num_edges)`` tensor of node ids.
            edge_attr: Integer relation ids, one per edge, looked up in
                ``self.relation_embed``. ``None`` falls back to relation
                id 1 for every edge.
            batch: 1-D tensor mapping each node to its graph id.
            embeddings: Tensor indexed by graph id (presumably one sentence
                embedding per graph -- confirm against the caller).
            k: ``1`` selects edges with ``topk``; any other value draws
                ``k`` independent samples with ``samplek``.

        Returns:
            ``(perm_samples, edge_batch, batch_list, probs_samples,
            scores_samples, edge_scores)``. The three ``*_samples`` lists
            hold one tensor per selection round and are empty when the graph
            batch has no edges. ``batch_list`` is the ascending list of graph
            ids that own at least one edge.
        """
        ### Edge feature initialization
        if edge_attr is None:
            # nn.Embedding needs integer ids of shape (num_edges,); a float
            # (num_edges, 1) tensor is rejected by the lookup.
            # NOTE(review): id 1 requires len(args.label_map) >= 2 -- confirm
            # this is the intended default relation.
            edge_attr = torch.ones(
                edge_index.size(1), dtype=torch.long, device=edge_index.device)

        edge_attr = self.relation_embed(edge_attr)

        x = F.relu(self.graphconvs[0](x, edge_index))  # 1-hop node encoding (GCN)

        hyperedge_index, edge_batch = self.DHT(edge_index, batch)
        # EHGNN
        # To distinguish the same relation attached to different node
        edge_attr = F.relu(self.hyperconvs[0](edge_attr, hyperedge_index))  # 1-hop edge encoding (EHGNN)

        # Construct z
        head_embeds = x[edge_index[0]]
        tail_embeds = x[edge_index[1]]
        z = self.z_linear(torch.cat([edge_attr, head_embeds, tail_embeds], dim=-1))

        embeddings = embeddings[edge_batch]  # Sentence Embedding (x)
        embeddings = self.x_linear(embeddings)

        # Per-edge dot product between z and its graph's embedding: f(x,z)
        edge_scores = torch.bmm(z.unsqueeze(1), embeddings.unsqueeze(2)).view(-1)
        if self.use_sigmoid_score:
            edge_probs = torch.sigmoid(edge_scores)
        else:
            edge_probs = softmax(edge_scores, edge_batch)

        perm_samples = []
        probs_samples = []
        scores_samples = []
        # Graph ids that actually own edges (some graphs in the batch may have
        # none). Sorted because set iteration order is hash-dependent and not
        # ascending in general.
        batch_list = sorted(set(edge_batch.tolist()))

        # If there are no edges, then return empty list of samples
        if hyperedge_index[1].numel() == 0:
            return perm_samples, edge_batch, batch_list, probs_samples, scores_samples, edge_scores

        if k == 1:
            perms = [topk(edge_probs, self.num_facts, edge_batch)]
        else:
            # Draw k independent samples of num_facts edges per graph
            perms = [samplek(edge_probs, edge_batch, self.num_facts) for _ in range(k)]

        for perm in perms:
            perm_samples.append(perm)
            probs_samples.append(edge_probs[perm])
            scores_samples.append(edge_scores[perm])

        return perm_samples, edge_batch, batch_list, probs_samples, scores_samples, edge_scores

    def get_convs(self, conv_type='GCN'):
        """Create ``self.num_convs`` convolution layers of the given type.

        Args:
            conv_type: ``'GCN'`` for node convolutions or ``'Hyper'`` for
                hypergraph convolutions over edge features.

        Returns:
            ``nn.ModuleList`` whose first layer maps the input feature size
            to the hidden size and whose remaining layers keep the hidden size.

        Raises:
            ValueError: If ``conv_type`` is neither ``'GCN'`` nor ``'Hyper'``.
        """
        convs = nn.ModuleList()
        for i in range(self.num_convs):
            if conv_type == 'GCN':
                in_dim = self.num_node_features if i == 0 else self.nhid
                conv = GCNConv(in_dim, self.nhid)
            elif conv_type == 'Hyper':
                in_dim = self.num_edge_features if i == 0 else self.rhsz
                conv = HypergraphConv(in_dim, self.rhsz)
            else:
                raise ValueError("Conv Name <{}> is Unknown".format(conv_type))
            convs.append(conv)
        return convs

""" Belows are GNN implementations """
### Hypergraph convolution for message passing on Dual Hypergraph 
class HypergraphConv(MessagePassing):
    """Hypergraph convolution: node -> hyperedge -> node message passing.

    Node features are linearly transformed, aggregated into hyperedges
    (scaled by the inverse hyperedge degree ``B``) and then aggregated back
    into nodes (scaled by the inverse node degree ``D``).

    Args:
        in_channels: Size of the last dimension of the input features.
        out_channels: Output size per head.
        use_attention: Enables attention coefficients per incidence. In this
            mode the hyperedge side is read from the node features
            (``x[hyperedge_index[1]]``), so hyperedge ids must be valid row
            indices of ``x``.
        heads: Number of attention heads; ignored (treated as 1) when
            ``use_attention`` is false.
        concat: Concatenate heads (true) or average them (false); ignored
            (treated as true) when ``use_attention`` is false.
        negative_slope: Slope of the leaky ReLU applied to attention logits.
        dropout: Dropout probability applied to attention coefficients.
        bias: Whether to add a learnable bias to the output.
        **kwargs: Forwarded to ``MessagePassing``; ``aggr`` defaults to
            ``'add'``.
    """

    def __init__(self, in_channels, out_channels, use_attention=False, heads=1,
                 concat=True, negative_slope=0.2, dropout=0, bias=True,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(HypergraphConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_attention = use_attention

        if self.use_attention:
            self.heads = heads
            self.concat = concat
            self.negative_slope = negative_slope
            self.dropout = dropout
            self.weight = Parameter(
                torch.empty(in_channels, heads * out_channels))
            self.att = Parameter(torch.empty(1, heads, 2 * out_channels))
        else:
            self.heads = 1
            self.concat = True
            self.weight = Parameter(torch.empty(in_channels, out_channels))

        if bias:
            # Size the bias from the *effective* head/concat settings so it
            # always matches the output width, even when `heads`/`concat`
            # were passed but are ignored because attention is off.
            bias_dim = self.heads * out_channels if self.concat else out_channels
            self.bias = Parameter(torch.empty(bias_dim))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-initialise the weights and zero the bias."""
        glorot(self.weight)
        if self.use_attention:
            glorot(self.att)
        zeros(self.bias)

    def message(self, x_j, edge_index_i, norm, alpha):
        """Scale each source feature by the target's ``norm`` (and ``alpha``)."""
        out = norm[edge_index_i].view(-1, 1, 1) * x_j.view(-1, self.heads, self.out_channels)

        if alpha is not None:
            out = alpha.view(-1, self.heads, 1) * out
        return out

    def forward(self, x, hyperedge_index, hyperedge_weight=None):
        """Run one hypergraph convolution.

        Args:
            x: Node features whose last dimension is ``in_channels``.
            hyperedge_index: ``(2, num_incidences)`` tensor; row 0 holds node
                indices and row 1 holds hyperedge indices.
            hyperedge_weight: Optional per-hyperedge weights, indexed by
                hyperedge id.

        Returns:
            Tensor with one row per input node and ``heads * out_channels``
            columns when heads are concatenated, or ``out_channels`` columns
            when they are averaged.

        Note:
            ``self.flow`` is switched during the call and is left as
            ``'target_to_source'`` afterwards.
        """
        x = torch.matmul(x, self.weight)
        alpha = None

        if self.use_attention:
            x = x.view(-1, self.heads, self.out_channels)
            x_i, x_j = x[hyperedge_index[0]], x[hyperedge_index[1]]
            alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
            alpha = F.leaky_relu(alpha, self.negative_slope)
            alpha = softmax(alpha, hyperedge_index[0], num_nodes=x.size(0))
            alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # D: inverse node degree (isolated nodes get 0 instead of inf).
        if hyperedge_weight is None:
            D = degree(hyperedge_index[0], x.size(0), x.dtype)
        else:
            D = scatter_add(hyperedge_weight[hyperedge_index[1]],
                            hyperedge_index[0], dim=0, dim_size=x.size(0))
        D = 1.0 / D
        D[D == float("inf")] = 0

        if hyperedge_index.numel() == 0:
            num_edges = 0
        else:
            num_edges = hyperedge_index[1].max().item() + 1
        # B: inverse hyperedge degree (unused hyperedge ids get 0).
        B = 1.0 / degree(hyperedge_index[1], num_edges, x.dtype)
        B[B == float("inf")] = 0
        if hyperedge_weight is not None:
            B = B * hyperedge_weight

        # The first propagate writes one row per hyperedge id, so the feature
        # matrix needs at least `num_edges` rows. Pad along dim 0 whatever
        # the rank of x (2-D without attention, 3-D with attention).
        num_nodes = x.size(0)
        dif = max([num_nodes, num_edges]) - num_nodes  # get size of padding
        x_help = F.pad(x, (0, 0) * (x.dim() - 1) + (0, dif))  # create dif many nodes

        # Nodes -> hyperedges, then hyperedges -> nodes.
        self.flow = 'source_to_target'
        out = self.propagate(hyperedge_index, x=x_help, norm=B, alpha=alpha)
        self.flow = 'target_to_source'
        out = self.propagate(hyperedge_index, x=out, norm=D, alpha=alpha)

        out = out[:num_nodes]  # prune back to original x.size()

        if self.concat is True:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        return out

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
