import torch
from torch import nn
from torch.nn import Linear

from aggregation import p_norm_aggregation_w_weights
from res_layer_via_matmul_molecules import ResolventConvLayerViaMatMul_molecules


class ResConvModel_via_matmul(torch.nn.Module):
    """Stack of resolvent conv layers, weighted p-norm pooling, then a linear head.

    Each hidden layer is a ``ResolventConvLayerViaMatMul_molecules`` followed by
    a ReLU. The final node features are pooled with
    ``p_norm_aggregation_w_weights`` and mapped to a single output channel by
    ``self.lin``.

    Args:
        input_dimension: Number of input features per node (width fed to the
            first conv layer, or to ``self.lin`` if there are no conv layers).
        K_minus: Forwarded unchanged to every conv layer; its meaning is
            defined by ``ResolventConvLayerViaMatMul_molecules``.
        zero_order: Forwarded unchanged to every conv layer.
        p: Exponent passed to ``p_norm_aggregation_w_weights``.
        hidden_channel_list: Output width of each conv layer, in order. An
            empty sequence yields no conv layers.
        dropout: Stored as ``self.dropout`` but not used anywhere in this class.
        bias: Forwarded to every conv layer. The final ``Linear`` head always
            uses its default (bias enabled).
    """

    def __init__(self,
                 input_dimension: int, K_minus, zero_order, p: int,
                 hidden_channel_list=(128,),
                 dropout=False, bias=True):
        super().__init__()
        # NOTE(review): kept for backward compatibility only; forward() applies
        # no dropout. Confirm whether dropout was intended.
        self.dropout = dropout
        self.graph_conv_layers = nn.ModuleList()
        self.p = p

        # Layer widths: input -> hidden_1 -> ... -> hidden_k; consecutive pairs
        # give each conv layer's (in, out) channels.
        dimensions = [input_dimension, *hidden_channel_list]
        for c_in, c_out in zip(dimensions[:-1], dimensions[1:]):
            self.graph_conv_layers.append(ResolventConvLayerViaMatMul_molecules(
                c_in, c_out, K_minus=K_minus, zero_order=zero_order,
                bias=bias))

        self.lin = Linear(dimensions[-1], 1)

    def forward(self, x, Z, edge_index, edge_attr, batch, output_hidden_features=False):
        """Run the conv stack, pool the node features, and apply the linear head.

        Args:
            x: Node features fed to the first conv layer.
            Z: Passed as ``weights`` to ``p_norm_aggregation_w_weights``.
            edge_index: Passed to every conv layer.
            edge_attr: Passed to every conv layer.
            batch: Passed to ``p_norm_aggregation_w_weights``. Presumably this
                maps each node to its graph (PyG convention); confirm against
                the helper.
            output_hidden_features: If True, also return the pooled features
                that feed ``self.lin``.

        Returns:
            ``x_out`` when ``output_hidden_features`` is False. Otherwise the
            tuple ``(x_out, x_pre_lin)``, where ``x_pre_lin`` is the pooled
            input to ``self.lin``.
        """
        for conv_layer in self.graph_conv_layers:
            x = conv_layer(x, edge_index, edge_attr).relu()
        x_pre_lin = p_norm_aggregation_w_weights(x=x, weights=Z, batch=batch, p=self.p)
        x_out = self.lin(x_pre_lin)

        if output_hidden_features:
            return x_out, x_pre_lin
        return x_out


