import math

import torch

from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn.functional as F
import torch.nn as nn
        
    

class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

    Depending on ``model_type`` the layer builds up to three channels from the
    node features ``input``:

    * ``adj_low``  channel: ``adj_low  @ (input @ weight_low)``
    * ``adj_high`` channel: ``adj_high @ (input @ weight_high)``
    * MLP channel:          ``input @ weight_mlp``

    Supported ``model_type`` values are ``'mlp'``, ``'sgc'``, ``'gcn'``,
    ``'mfgcn'`` and ``'mfsgc'``; the last two mix the three channels with
    per-row attention weights (see ``attention``).

    Args:
        in_features: number of columns of the feature matrix passed to forward().
        out_features: number of columns produced by the layer.
        model_type: which forward variant to run (see above).
        output_layer: stored as ``self.output_layer``; not read by this class.
    """

    def __init__(self, in_features, out_features, model_type, output_layer = 0):
        super().__init__()
        self.in_features, self.out_features, self.output_layer, self.model_type = in_features, out_features, output_layer, model_type
        self.low_bn_att, self.high_bn_att, self.mlp_bn_att = nn.BatchNorm1d(out_features), nn.BatchNorm1d(out_features), nn.BatchNorm1d(out_features)

        self.low_act, self.high_act, self.mlp_act = nn.ELU(alpha=3), nn.ELU(alpha=3), nn.ELU(alpha=3)

        # Latest attention weights computed by forward() for 'mfgcn'/'mfsgc'.
        self.att_low, self.att_high, self.att_mlp = 0, 0, 0

        # Parameters are allocated directly on the GPU when one is available
        # (BatchNorm/ELU submodules above are still created on the CPU and follow
        # only when the whole module is moved, e.g. with .cuda()).
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        def _param(*shape):
            # Uninitialised float32 storage; real values are set in reset_parameters().
            return Parameter(torch.empty(*shape, dtype=torch.float32, device=device))

        # Registration order is kept identical to the original so that
        # parameters()/state_dict() ordering is unchanged.
        self.weight_low = _param(in_features, out_features)
        self.weight_high = _param(in_features, out_features)
        self.weight_mlp = _param(in_features, out_features)
        self.att_vec_low = _param(out_features, 1)
        self.att_vec_high = _param(out_features, 1)
        self.att_vec_mlp = _param(out_features, 1)
        # Learnable negative-side scales for param_relu (used by attention_BN).
        self.low_param = _param(1)
        self.high_param = _param(1)
        self.mlp_param = _param(1)
        # 3x3 mixing matrix applied to the stacked channel scores in attention_BN.
        self.att_vec = _param(3, 3)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable parameter in place.

        Weights are drawn from U(-1/sqrt(out_features), 1/sqrt(out_features)),
        attention vectors from U(-1, 1) (their second dimension is 1), the 3x3
        ``att_vec`` from U(-1/sqrt(3), 1/sqrt(3)), the BatchNorm layers are reset,
        and the ``*_param`` scalars are set to 1.
        """
        stdv = 1. / math.sqrt(self.weight_mlp.size(1))
        std_att = 1. / math.sqrt(self.att_vec_mlp.size(1))
        std_att_vec = 1. / math.sqrt(self.att_vec.size(1))
        # Draw order is unchanged so seeded runs reproduce the same initialisation.
        self.weight_low.data.uniform_(-stdv, stdv)
        self.weight_high.data.uniform_(-stdv, stdv)
        self.att_vec_high.data.uniform_(-std_att, std_att)
        self.att_vec_low.data.uniform_(-std_att, std_att)

        self.weight_mlp.data.uniform_(-stdv, stdv)
        self.att_vec_mlp.data.uniform_(-std_att, std_att)

        self.low_bn_att.reset_parameters()
        self.high_bn_att.reset_parameters()
        self.mlp_bn_att.reset_parameters()
        self.att_vec.data.uniform_(-std_att_vec, std_att_vec)
        # In-place fill keeps each scalar on its current device and dtype
        # (rebinding .data to torch.FloatTensor([1.0]) moved it to CPU float32).
        self.low_param.data.fill_(1.0)
        self.high_param.data.fill_(1.0)
        self.mlp_param.data.fill_(1.0)

    def param_relu(self, input, param):
        """ELU-style activation with a learnable negative-side scale.

        Computes ``max(input, 0) + param * (exp(min(input, 0)) - 1)`` on the same
        device and dtype as ``input``.
        """
        zero = input.new_zeros(1)
        return torch.max(input, zero) + param * (torch.exp(torch.min(input, zero)) - 1)

    @staticmethod
    def _detached_row_norm(x):
        # Per-row L2 norm as a column vector, detached from autograd; the 1e-16
        # keeps later divisions by it finite for all-zero rows.
        return torch.norm(x, dim=1).detach()[:, None] + 1e-16

    def attention(self, output_low, output_high, output_mlp):
        """Compute per-row mixing weights for the three channels.

        Each channel is row-normalised (with gradients detached), projected by
        its ``att_vec_*`` and passed through ELU(alpha=5).

        Returns:
            Three column tensors ``(att_low, att_high, att_mlp)``. For 'mfsgc' they
            are a softmax across channels whose logits are scaled by the sum of
            the three row norms; for 'mfgcn' they are independent sigmoids.

        Raises:
            ValueError: if ``model_type`` is neither 'mfsgc' nor 'mfgcn'.
        """
        low_norm = self._detached_row_norm(output_low)
        high_norm = self._detached_row_norm(output_high)
        mlp_norm = self._detached_row_norm(output_mlp)

        att_low = F.elu(torch.mm(output_low.detach() / low_norm, self.att_vec_low), alpha=5)
        att_high = F.elu(torch.mm(output_high.detach() / high_norm, self.att_vec_high), alpha=5)
        att_mlp = F.elu(torch.mm(output_mlp.detach() / mlp_norm, self.att_vec_mlp), alpha=5)

        if self.model_type == 'mfsgc':
            # Temperature is the reciprocal of the summed norms, so dividing by T
            # scales each row's logits by that sum.
            T = 1 / (low_norm + high_norm + mlp_norm)
            att = torch.softmax(torch.cat([att_low, att_high, att_mlp], 1) / T, 1)
            return att[:, 0][:, None], att[:, 1][:, None], att[:, 2][:, None]
        if self.model_type == 'mfgcn':
            return torch.sigmoid(att_low), torch.sigmoid(att_high), torch.sigmoid(att_mlp)
        raise ValueError(f"attention() does not support model_type={self.model_type!r}")

    def attention_BN(self, output_low, output_high, output_mlp):
        """Alternative attention using learnable-ELU scores (not called by forward()).

        Returns:
            Three column tensors that form a softmax across the channels.

        Raises:
            ValueError: if ``model_type`` is neither 'mfsgc' nor 'mfgcn'.
        """
        # NOTE(review): the BatchNorm outputs are never used below; the calls are
        # kept only because, in training mode, they still update the running
        # statistics. They may have been meant to feed the projections — confirm.
        self.low_bn_att(output_low.detach())
        self.high_bn_att(output_high.detach())
        self.mlp_bn_att(output_mlp.detach())

        att_low = self.param_relu(torch.mm(output_low, self.att_vec_low), self.low_param)
        att_high = self.param_relu(torch.mm(output_high, self.att_vec_high), self.high_param)
        att_mlp = self.param_relu(torch.mm(output_mlp, self.att_vec_mlp), self.mlp_param)
        scores = torch.cat([att_low, att_high, att_mlp], 1)

        if self.model_type == 'mfsgc':
            # Same norm-based temperature as attention(); these norms used to be
            # referenced here without ever being defined (NameError).
            T = 1 / (self._detached_row_norm(output_low)
                     + self._detached_row_norm(output_high)
                     + self._detached_row_norm(output_mlp))
            att = torch.softmax(scores / T, 1)
        elif self.model_type == 'mfgcn':
            T = 3
            att = torch.softmax(torch.mm(torch.sigmoid(scores), self.att_vec) / T, 1)
        else:
            raise ValueError(f"attention_BN() does not support model_type={self.model_type!r}")
        return att[:, 0][:, None], att[:, 1][:, None], att[:, 2][:, None]

    def forward(self, input, adj_low, adj_high):
        """Apply the layer.

        Args:
            input: 2-D feature matrix with ``in_features`` columns.
            adj_low: matrix left-multiplied onto the ``weight_low`` projection.
                The 'mf*' variants use ``torch.spmm`` — presumably a sparse
                adjacency; confirm with the caller.
            adj_high: matrix left-multiplied onto the ``weight_high`` projection
                (only read by 'mfgcn' and 'mfsgc').

        Returns:
            Tensor with ``out_features`` columns.

        Raises:
            ValueError: for an unsupported ``model_type``.
        """
        model_type = self.model_type
        if model_type == 'mlp':
            return F.relu(torch.mm(input, self.weight_mlp))
        if model_type in ('sgc', 'gcn'):
            return torch.mm(adj_low, torch.mm(input, self.weight_low))
        if model_type == 'mfgcn':
            output_low = torch.spmm(adj_low, F.relu(torch.mm(input, self.weight_low)))
            output_high = torch.spmm(adj_high, F.relu(torch.mm(input, self.weight_high)))
            output_mlp = F.relu(torch.mm(input, self.weight_mlp))

            self.att_low, self.att_high, self.att_mlp = self.attention(output_low, output_high, output_mlp)
            return 3 * (self.att_low * output_low + self.att_high * output_high + self.att_mlp * output_mlp)
        if model_type == 'mfsgc':
            output_low = torch.spmm(adj_low, torch.mm(input, self.weight_low))
            output_high = torch.spmm(adj_high, torch.mm(input, self.weight_high))
            output_mlp = torch.mm(input, self.weight_mlp)

            self.att_low, self.att_high, self.att_mlp = self.attention(F.relu(output_low), F.relu(output_high), F.relu(output_mlp))
            # NOTE(review): output_high / att_high are computed but not part of the
            # returned mix ('mfgcn' above uses all three channels). Left unchanged;
            # confirm whether this is an intentional ablation.
            return 3 * (self.att_low * output_low + self.att_mlp * output_mlp)
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"
