import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import global_add_pool, global_mean_pool
from models.GraphHD.gnn import GNN, GNN_graphpred


class ViewLearner_K(torch.nn.Module):
    """Learnable augmenter that scores graph units (nodes or edges) with a GNN.

    A GIN encoder embeds the nodes; a small MLP maps each unit embedding to a
    single logit, and ``edge_weight_cal`` turns those logits into stochastic
    soft keep-weights via a binary-concrete (logistic noise + sigmoid)
    relaxation.

    Args:
        cf: configuration object. Attributes read here: ``n_layer``,
            ``n_hidden``, ``dropout``, ``aug_mode`` (``'node'`` or ``'edge'``),
            ``n_protos`` (underscore-separated ints, e.g. ``"10_20"``) and
            ``compute_dev``.
        level: stored as ``self.level``; not used by the methods of this class.
    """

    def __init__(self, cf, level):
        super(ViewLearner_K, self).__init__()

        # The encoder is always built in 'last' JK mode, so the MLP input width
        # is derived from that same mode rather than from cf.JK (the two used
        # to be able to disagree, e.g. cf.JK == 'concat').
        # Assumes GNN(JK='last') emits n_hidden-wide node embeddings -- TODO
        # confirm against models.GraphHD.gnn.GNN.
        encoder_jk = 'last'
        self.encoder = GNN(num_layer=cf.n_layer, emb_dim=cf.n_hidden, JK=encoder_jk,
                           drop_ratio=cf.dropout, gnn_type='gin')
        self.hidden_dim = cf.n_hidden
        self.num_layers = cf.n_layer
        self.aug_mode = cf.aug_mode
        self.level = level
        self.n_protos = [int(x) for x in cf.n_protos.split('_')]
        self.pretrain_pool = global_mean_pool
        self.device = cf.compute_dev

        if encoder_jk == 'concat':
            gnn_out_dim = self.hidden_dim * self.num_layers
        else:
            gnn_out_dim = self.hidden_dim
        mlp_dim = gnn_out_dim

        # Node mode scores each node embedding; edge mode scores the
        # concatenation of an edge's two endpoint embeddings. Any other
        # aug_mode leaves unit_prob_mlp undefined, so forward() cannot run.
        if self.aug_mode == 'node':
            self.unit_prob_mlp = Sequential(
                Linear(gnn_out_dim, mlp_dim),
                ReLU(),
                Linear(mlp_dim, 1))
        elif self.aug_mode == 'edge':
            self.unit_prob_mlp = Sequential(
                Linear(gnn_out_dim * 2, mlp_dim),
                ReLU(),
                Linear(mlp_dim, 1))

        self.init_emb()

    def init_emb(self):
        """Xavier-uniform init for every Linear in this module tree, zero bias.

        ``self.modules()`` recurses into submodules, so this also re-initialises
        any Linear layers inside ``self.encoder``.
        """
        for m in self.modules():
            if isinstance(m, Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    def forward(self, x, edge_index, edge_attr):
        """Return one stochastic keep-weight per unit.

        Args:
            x: node features passed to the encoder.
            edge_index: 2-row index tensor; row 0 holds source nodes and row 1
                holds destination nodes (only indexed directly in 'edge' mode).
            edge_attr: edge features passed to the encoder.

        Returns:
            1-D tensor with one weight per node (aug_mode 'node') or per
            ``edge_index`` column (aug_mode 'edge').
        """
        node_emb = self.encoder(x, edge_index, edge_attr)

        if self.aug_mode == 'edge':
            src, dst = edge_index[0], edge_index[1]
            unit_emb = torch.cat([node_emb[src], node_emb[dst]], 1)
        else:
            unit_emb = node_emb

        unit_logit = self.unit_prob_mlp(unit_emb)
        return self.edge_weight_cal(unit_logit)

    def edge_weight_cal(self, unit_prob, temperature=1.0, bias=0.0001):
        """Draw binary-concrete (relaxed Bernoulli) keep-weights from logits.

        Args:
            unit_prob: pre-sigmoid scores; forward() passes the MLP output,
                whose last dimension is 1.
            temperature: relaxation temperature; smaller values push the
                weights towards 0/1.
            bias: keeps the uniform sample inside (bias, 1 - bias] so the
                logistic noise stays finite.

        Returns:
            ``sigmoid((logistic_noise + unit_prob) / temperature)`` with the
            trailing singleton dimension removed, values in [0, 1].

        Raises:
            ValueError: if ``unit_prob`` is None.
        """
        if unit_prob is None:
            raise ValueError("unit_prob must not be None")

        # Affine map of torch.rand's [0, 1) sample onto (bias, 1 - bias].
        # Sampled directly on the logits' device (default dtype, as before)
        # instead of on the CPU followed by a copy to self.device.
        eps = (bias - (1 - bias)) * torch.rand(unit_prob.size(), device=unit_prob.device) + (1 - bias)
        # log(u) - log(1 - u) of a uniform u is a standard logistic sample.
        gate_inputs = torch.log(eps) - torch.log(1 - eps)
        gate_inputs = (gate_inputs + unit_prob) / temperature
        # Squeeze only the last dim so a single unit still yields shape (1,),
        # not a 0-d scalar.
        # NOTE(review): open question from the original author -- whether
        # this output needs detaching.
        inv_weight = torch.sigmoid(gate_inputs).squeeze(-1)

        return inv_weight
