#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
sys.path.append(".")
sys.path.append("..")

import numpy as np
import torch
import torchmetrics
import torch.nn as nn
from model.mpsn_layer import mpsn_layer

class mpsn(nn.Module):
    """Sequential stack of ``mpsn_layer`` modules.

    One ``mpsn_layer`` is built for every consecutive pair of feature sizes
    in ``[F_in, *F_intermediate, F_out]`` and the layers are chained in an
    ``nn.Sequential``.  Every layer is constructed with the same ``l1l``,
    ``l1u``, ``sigma_update``, ``agg`` and ``sigma`` values.

    Attributes:
        num_features: Feature sizes along the stack, e.g. ``[1, 5, 5, 5, 1]``.
        num_layers: ``len(num_features)``.  Note that this is one more than
            the number of ``mpsn_layer`` modules actually built.
        l1l, l1u, sigma_update, agg, sigma: The constructor arguments of the
            same name, stored as given.
        simplicial_nn: The ``nn.Sequential`` holding the layers.
    """

    def __init__(self, F_in, F_intermediate, F_out, l1l, l1u, sigma_update, agg, sigma, model_name):
        """Build the layer stack.

        Args:
            F_in: Feature size fed to the first layer.
            F_intermediate: Iterable of feature sizes between the first and
                last layer; may be empty, giving a single layer.
            F_out: Feature size produced by the last layer.
            l1l: Passed unchanged to every ``mpsn_layer`` as ``l1l``.
            l1u: Passed unchanged to every ``mpsn_layer`` as ``l1u``.
            sigma_update: Passed unchanged to every ``mpsn_layer``.
            agg: Passed unchanged to every ``mpsn_layer``.
            sigma: Passed unchanged to every ``mpsn_layer``.
            model_name: Accepted for interface compatibility; this class
                neither stores nor uses it.
        """
        super().__init__()
        # Feature sizes along the stack, e.g. [1, 5, 5, 5, 1].
        self.num_features = [F_in, *F_intermediate, F_out]
        # Counts feature sizes, not layers: the stack has num_layers - 1 layers.
        self.num_layers = len(self.num_features)
        self.l1l = l1l
        self.l1u = l1u

        self.sigma_update = sigma_update
        self.agg = agg
        self.sigma = sigma

        # Pair each feature size with its successor: (f0, f1), (f1, f2), ...
        layers = [
            mpsn_layer(
                F_in=f_in,
                F_out=f_out,
                l1l=self.l1l,
                l1u=self.l1u,
                sigma_update=self.sigma_update,
                agg=self.agg,
                sigma=self.sigma,
            )
            for f_in, f_out in zip(self.num_features, self.num_features[1:])
        ]

        self.simplicial_nn = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the final output."""
        return self.simplicial_nn(x)