#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch 
import torch.nn as nn
import torch.nn.functional as F

class mpsn_layer(nn.Module):
    """One message-passing layer over a fixed pair of aggregation matrices.

    For every row ``i`` of the input feature matrix the layer builds two
    messages. One is aggregated through ``L1l`` ("down") and one through
    ``L1u`` ("up"). Each is produced by a two-layer MLP applied to the
    concatenation ``[x_i, (A @ x)_i]``. The two messages are concatenated and
    passed through a final two-layer update MLP.

    NOTE(review): the names (``l1l``/``l1u``, "order 1", MPSN) suggest these
    are the lower/upper adjacency of first-order simplices (edges) of a
    simplicial complex -- confirm against the caller.
    """

    def __init__(self, F_in, F_out, l1l, l1u, sigma_update, agg, sigma):
        """Build the layer.

        Args:
            F_in: number of input features per row of ``x_in``.
            F_out: number of output features per row.
            l1l: square ``(N, N)`` matrix used for the "down" aggregation.
                Only its sparsity pattern matters: every non-zero entry is
                treated as 1. The caller's tensor is not modified.
            l1u: square ``(N, N)`` matrix used for the "up" aggregation,
                binarised the same way as ``l1l``.
            sigma_update: callable nonlinearity applied between the two
                linear layers of every MLP used in ``forward``.
            agg: ``'mean'`` row-normalises both binarised matrices with the
                L1 norm, so each non-empty row sums to 1. Any other value
                keeps the plain 0/1 matrices, so neighbour features are
                summed.
            sigma: stored as ``self.sigma``; not used by ``forward``.
        """
        super().__init__()
        self.F_in = F_in
        self.F_out = F_out

        # Binarise *copies* so the caller's matrices are left untouched.
        lower = torch.as_tensor(l1l).clone()
        upper = torch.as_tensor(l1u).clone()
        lower[lower != 0] = 1
        upper[upper != 0] = 1

        if agg == 'mean':
            # eps only matters for all-zero rows, which stay zero. Any
            # non-empty binarised row has an L1 norm >= 1.
            lower = F.normalize(lower, p=1, eps=1e-3)
            upper = F.normalize(upper, p=1, eps=1e-3)

        # Registering as buffers makes the matrices follow .to(device) /
        # .cuda() / .double() together with the parameters.
        # persistent=False keeps them out of state_dict, so checkpoints saved
        # before this change still load.
        self.register_buffer('L1l', lower, persistent=False)
        self.register_buffer('L1u', upper, persistent=False)

        self.sigma_update = sigma_update  # nonlinearity inside every MLP below
        # NOTE(review): described as the final-layer nonlinearity, but forward()
        # never applies it -- confirm whether callers apply it themselves.
        self.sigma = sigma

        # "Down" message MLP: [x, L1l @ x] (2*F_in) -> F_out -> F_out.
        self.mlp_1d_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_1d_2 = nn.Linear(F_out, F_out)
        # "Up" message MLP: [x, L1u @ x] (2*F_in) -> F_out -> F_out.
        self.mlp_1u_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_1u_2 = nn.Linear(F_out, F_out)
        # Update MLP: [m_down, m_up] (2*F_out) -> F_out -> F_out.
        self.mlp_1_update = nn.Linear(F_out * 2, F_out)
        self.mlp_2_update = nn.Linear(F_out, F_out)
        print("created MPSN layers")

    def forward(self, x_in):
        """Apply one round of down/up message passing.

        Args:
            x_in: 2-D feature matrix of shape ``(N, F_in)``. Its rows must be
                aligned with the rows/columns of ``L1l`` and ``L1u``.

        Returns:
            Tensor of shape ``(N, F_out)``. No nonlinearity is applied after
            the last linear layer (``self.sigma`` is not used here).
        """
        # Down messages: own features concatenated with aggregated features
        # of the rows linked through L1l.
        m_down = torch.cat([x_in, self.L1l @ x_in], 1)
        m_down = self.mlp_1d_2(self.sigma_update(self.mlp_1d_1(m_down)))

        # Up messages: same construction through L1u.
        m_up = torch.cat([x_in, self.L1u @ x_in], 1)
        m_up = self.mlp_1u_2(self.sigma_update(self.mlp_1u_1(m_up)))

        combined = torch.cat([m_down, m_up], 1)
        return self.mlp_2_update(self.sigma_update(self.mlp_1_update(combined)))