from torch_geometric.nn import GATConv, GCNConv, GCN2Conv, Linear
from torch.nn.functional import mse_loss, relu
from layers.GCN import BaseModel


class Model(BaseModel):
    """GCN-based model whose hidden layers are ``GCNConv`` convolutions.

    All construction details (layer stacking, input/output projections,
    forward pass) live in ``BaseModel``, which is not visible here. This
    subclass only does two things:
    * it reads its hyper-parameters from ``configs``;
    * it decides which layer type is built and how each layer is applied.
    """

    def __init__(self, configs):
        # Sizes handed to BaseModel. The input window length (seq_len) serves
        # as the input channel count and the forecast horizon (pred_len) as
        # the output channel count.
        self.in_channels, self.out_channels = configs.seq_len, configs.pred_len
        self.hidden_channels, self.num_hidden = configs.d_model, configs.e_layers

        # Fixed off for this model. Its exact meaning is defined by BaseModel
        # (presumably whether hidden layers share weights) — confirm there.
        self.param_sharing = False

        # Graph structure from the config; configs.edge_attr is exposed under
        # the name edge_weights.
        # NOTE(review): edge_index is stored but not passed to BaseModel below;
        # presumably BaseModel reads self.edge_index later — confirm.
        self.edge_index, self.edge_weights = configs.edge_index, configs.edge_attr

        def build_hidden_layer():
            # Square hidden-to-hidden GCN layer with self-loop insertion
            # disabled. self.hidden_channels is looked up at call time,
            # not captured when this function is defined.
            return GCNConv(self.hidden_channels, self.hidden_channels, add_self_loops=False)

        super().__init__(
            self.in_channels,
            self.out_channels,
            self.hidden_channels,
            self.num_hidden,
            self.param_sharing,
            build_hidden_layer,
            self.edge_weights,
        )

    def apply_layer(self, layer, x, x_0, edge_index, edge_weights):
        """Apply one graph layer to ``x`` followed by a ReLU.

        Presumably invoked by ``BaseModel`` once per hidden layer — confirm.

        Args:
            layer: the layer to run (built by ``build_hidden_layer``, i.e. a
                ``GCNConv``, assuming BaseModel uses the generator it was given).
            x: current node features, passed straight to ``layer``.
            x_0: unused here. It is accepted so the signature matches the hook
                BaseModel calls; presumably other variants (e.g. GCN2Conv-based
                ones) use the initial features.
            edge_index: graph connectivity forwarded to ``layer``.
            edge_weights: per-edge weights forwarded to ``layer``.

        Returns:
            ``relu`` of the layer output.
        """
        convolved = layer(x, edge_index, edge_weights)
        return relu(convolved)