#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch 
import torch.nn as nn
import torch.nn.functional as F

class mpsn_layer(nn.Module):
    """Message-passing layer over features of orders 0, 1 and 2.

    Each order receives one message per neighbourhood operator. A message is
    computed by concatenating the order's own features with the aggregated
    neighbour features (``operator @ features``) and feeding the result to a
    two-layer MLP. The messages of one order are then concatenated and passed
    through a linear update followed by ``sigma_update``.

    Args:
        F_in: Number of input feature columns per order.
        F_out: Number of output feature columns per order.
        b1: Operator mapping order-1 features to order-0 rows
            (presumably the node-edge incidence matrix -- confirm with caller).
        b2: Operator mapping order-2 features to order-1 rows
            (presumably the edge-triangle incidence matrix -- confirm).
        l0: Square operator acting on order-0 features.
        l1l: Square operator acting on order-1 features (used for ``m1d``).
        l1u: Square operator acting on order-1 features (used for ``m1u``).
        l2: Square operator acting on order-2 features.
        sigma_update: Callable nonlinearity used inside the message MLPs and
            after every update layer.
        agg: ``'mean'`` L1-normalises every row of each operator; any other
            value leaves the 0/1 operators as is (sum aggregation).
        sigma: Stored on the module but not applied in ``forward``.

    Only the sparsity pattern of the operators is used: every non-zero entry
    is treated as 1. The caller's tensors are never modified; the layer keeps
    its own copies, registered as non-persistent buffers so that they follow
    ``.to()`` / ``.cuda()`` without adding keys to ``state_dict``.
    """

    def __init__(self, F_in, F_out, b1, b2, l0, l1l, l1u, l2, sigma_update, agg, sigma):
        super().__init__()
        self.F_in = F_in
        self.F_out = F_out

        # Binarize into fresh tensors instead of writing into the arguments.
        b1 = self._binarize(b1)
        b2 = self._binarize(b2)
        operators = (
            ('B1', b1),
            ('B2', b2),
            ('B1t', b1.T),
            ('B2t', b2.T),
            ('L0', self._binarize(l0)),
            ('L1l', self._binarize(l1l)),
            ('L1u', self._binarize(l1u)),
            ('L2', self._binarize(l2)),
        )
        for name, operator in operators:
            if agg == 'mean':
                # Row-wise L1 normalisation (default dim=1): each non-empty
                # row sums to 1, all-zero rows stay zero thanks to eps.
                operator = F.normalize(operator, p=1, eps=1e-3)
            self.register_buffer(name, operator, persistent=False)

        self.sigma_update = sigma_update  # nonlinearity in message update
        self.sigma = sigma  # final layer nonlinearity (unused in forward)

        # NOTE: creation order and attribute names are kept stable so that
        # seeded initialisation and existing checkpoints are unaffected.
        self.mlp_0p_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_0p_2 = nn.Linear(F_out, F_out)
        self.mlp_0_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_0_2 = nn.Linear(F_out, F_out)
        self.mlp_0_update = nn.Linear(F_out * 2, F_out)

        self.mlp_1n_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_1n_2 = nn.Linear(F_out, F_out)
        self.mlp_1d_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_1d_2 = nn.Linear(F_out, F_out)
        self.mlp_1u_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_1u_2 = nn.Linear(F_out, F_out)
        self.mlp_1p_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_1p_2 = nn.Linear(F_out, F_out)
        self.mlp_1_update = nn.Linear(F_out * 4, F_out)

        self.mlp_2n_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_2n_2 = nn.Linear(F_out, F_out)
        self.mlp_2_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_2_2 = nn.Linear(F_out, F_out)
        self.mlp_2_update = nn.Linear(F_out * 2, F_out)
        print("created MPSN layers")

    @staticmethod
    def _binarize(operator):
        """Return a new tensor of the same dtype with every non-zero set to 1."""
        return (operator != 0).to(operator.dtype)

    def _message(self, first, second, own, aggregated):
        """Run the two-layer message MLP on ``[own | aggregated]``."""
        joined = torch.cat([own, aggregated], 1)
        return second(self.sigma_update(first(joined)))

    def forward(self, x_in):
        """Compute updated features for all three orders.

        Args:
            x_in: Sequence ``(x0, x1, x2)`` of 2-D feature tensors, one per
                order, each with ``F_in`` columns.

        Returns:
            Tuple ``(y0, y1, y2)`` with the same row counts as the inputs and
            ``F_out`` columns each.
        """
        x0, x1, x2 = x_in

        # order 0: messages from order 1 (via B1) and from order 0 (via L0)
        m0p = self._message(self.mlp_0p_1, self.mlp_0p_2, x0, self.B1 @ x1)
        m0 = self._message(self.mlp_0_1, self.mlp_0_2, x0, self.L0 @ x0)
        y0 = self.sigma_update(self.mlp_0_update(torch.cat([m0, m0p], 1)))

        # order 1: messages from order 0, order 1 (two operators) and order 2
        m1n = self._message(self.mlp_1n_1, self.mlp_1n_2, x1, self.B1t @ x0)
        m1d = self._message(self.mlp_1d_1, self.mlp_1d_2, x1, self.L1l @ x1)
        m1u = self._message(self.mlp_1u_1, self.mlp_1u_2, x1, self.L1u @ x1)
        m1p = self._message(self.mlp_1p_1, self.mlp_1p_2, x1, self.B2 @ x2)
        y1 = self.sigma_update(
            self.mlp_1_update(torch.cat([m1n, m1d, m1u, m1p], 1))
        )

        # order 2: messages from order 1 (via B2t) and from order 2 (via L2)
        m2n = self._message(self.mlp_2n_1, self.mlp_2n_2, x2, self.B2t @ x1)
        m2 = self._message(self.mlp_2_1, self.mlp_2_2, x2, self.L2 @ x2)
        y2 = self.sigma_update(self.mlp_2_update(torch.cat([m2n, m2], 1)))

        return y0, y1, y2


class mpsn_layer_sc_order_1(nn.Module):
    """Message-passing layer over features of orders 0 and 1 only.

    For each order, every neighbourhood operator yields one message: the
    order's own features are concatenated with the aggregated neighbour
    features (``operator @ features``) and passed through a two-layer MLP.
    The two messages of an order are concatenated and sent through a linear
    update followed by ``sigma_update``.

    Args:
        F_in: Number of input feature columns per order.
        F_out: Number of output feature columns per order.
        b1: Operator mapping order-1 features to order-0 rows
            (presumably the node-edge incidence matrix -- confirm with caller).
        l0: Square operator acting on order-0 features.
        l1l: Square operator acting on order-1 features.
        sigma_update: Callable nonlinearity used inside the message MLPs and
            after every update layer.
        agg: ``'mean'`` L1-normalises every row of each operator; any other
            value leaves the 0/1 operators as is (sum aggregation).
        sigma: Stored on the module but not applied in ``forward``.

    Only the sparsity pattern of the operators is used: every non-zero entry
    is treated as 1. The caller's tensors are never modified; the layer keeps
    its own copies, registered as non-persistent buffers so that they follow
    ``.to()`` / ``.cuda()`` without adding keys to ``state_dict``.
    """

    def __init__(self, F_in, F_out, b1, l0, l1l, sigma_update, agg, sigma):
        super().__init__()
        self.F_in = F_in
        self.F_out = F_out

        # Binarize into fresh tensors instead of writing into the arguments.
        b1 = self._binarize(b1)
        operators = (
            ('B1', b1),
            ('B1t', b1.T),
            ('L0', self._binarize(l0)),
            ('L1l', self._binarize(l1l)),
        )
        for name, operator in operators:
            if agg == 'mean':
                # Row-wise L1 normalisation (default dim=1): each non-empty
                # row sums to 1, all-zero rows stay zero thanks to eps.
                operator = F.normalize(operator, p=1, eps=1e-3)
            self.register_buffer(name, operator, persistent=False)

        self.sigma_update = sigma_update  # nonlinearity in message update
        self.sigma = sigma  # final layer nonlinearity (unused in forward)

        # NOTE: creation order and attribute names are kept stable so that
        # seeded initialisation and existing checkpoints are unaffected.
        self.mlp_0p_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_0p_2 = nn.Linear(F_out, F_out)
        self.mlp_0_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_0_2 = nn.Linear(F_out, F_out)
        self.mlp_0_update = nn.Linear(F_out * 2, F_out)

        self.mlp_1n_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_1n_2 = nn.Linear(F_out, F_out)
        self.mlp_1d_1 = nn.Linear(F_in * 2, F_out)
        self.mlp_1d_2 = nn.Linear(F_out, F_out)
        self.mlp_1_update = nn.Linear(F_out * 2, F_out)
        print("created MPSN layers")

    @staticmethod
    def _binarize(operator):
        """Return a new tensor of the same dtype with every non-zero set to 1."""
        return (operator != 0).to(operator.dtype)

    def _message(self, first, second, own, aggregated):
        """Run the two-layer message MLP on ``[own | aggregated]``."""
        joined = torch.cat([own, aggregated], 1)
        return second(self.sigma_update(first(joined)))

    def forward(self, x_in):
        """Compute updated features for orders 0 and 1.

        Args:
            x_in: Sequence ``(x0, x1)`` of 2-D feature tensors, one per order,
                each with ``F_in`` columns.

        Returns:
            Tuple ``(y0, y1)`` with the same row counts as the inputs and
            ``F_out`` columns each.
        """
        x0, x1 = x_in

        # order 0: messages from order 1 (via B1) and from order 0 (via L0)
        m0p = self._message(self.mlp_0p_1, self.mlp_0p_2, x0, self.B1 @ x1)
        m0 = self._message(self.mlp_0_1, self.mlp_0_2, x0, self.L0 @ x0)
        y0 = self.sigma_update(self.mlp_0_update(torch.cat([m0, m0p], 1)))

        # order 1: messages from order 0 (via B1t) and from order 1 (via L1l)
        m1n = self._message(self.mlp_1n_1, self.mlp_1n_2, x1, self.B1t @ x0)
        m1d = self._message(self.mlp_1d_1, self.mlp_1d_2, x1, self.L1l @ x1)
        y1 = self.sigma_update(self.mlp_1_update(torch.cat([m1n, m1d], 1)))

        return y0, y1

        