from torch_geometric.nn import GATConv, GCNConv, GCN2Conv, Linear
from torch.nn.functional import mse_loss, relu
from layers.GCN import BaseModel


class Model(BaseModel):
    """Residual graph model whose hidden layers are ``GATConv`` convolutions.

    Construction reads its hyper-parameters from ``configs`` and delegates layer
    stacking to ``BaseModel``. Each hidden layer is applied by ``apply_layer``
    as ``x + relu(GAT(x))``.
    """

    def __init__(self, configs):
        """Build the model from a config object.

        Args:
            configs: object exposing ``seq_len`` (input channels), ``pred_len``
                (output channels), ``d_model`` (hidden width), ``e_layers``
                (number of hidden layers), ``edge_index`` and ``edge_attr``.
        """
        # These are plain values/tensors set before super().__init__().
        # NOTE(review): presumably BaseModel reads some of them (e.g.
        # self.edge_index, which is not passed positionally) — confirm.
        self.in_channels = configs.seq_len
        self.out_channels = configs.pred_len
        self.hidden_channels = configs.d_model
        self.num_hidden = configs.e_layers
        self.param_sharing = False
        self.edge_index = configs.edge_index
        self.edge_weights = configs.edge_attr

        def layer_gen():
            # Factory for one hidden layer: width-preserving GAT, no self-loops.
            return GATConv(self.hidden_channels, self.hidden_channels, add_self_loops=False)

        super().__init__(self.in_channels, self.out_channels, self.hidden_channels, self.num_hidden,
                         self.param_sharing, layer_gen, self.edge_weights)

    def apply_layer(self, layer, x, x_0, edge_index, edge_weights):
        """Apply one GAT layer with a ReLU and a residual connection.

        Args:
            layer: a layer produced by ``layer_gen`` (a ``GATConv``).
            x: rank-3 tensor ``(B, N, D)`` — presumably (batch, nodes,
                features); confirm against BaseModel.
            x_0: unused here; kept for BaseModel's ``apply_layer`` interface.
            edge_index: 2-D edge index whose columns are edges.
            edge_weights: per-edge weights, or ``None``. When 1-D,
                zero-weight edges are removed before the layer is applied.

        Returns:
            Tensor shaped like ``x``: ``x + relu(layer(x))``.
        """
        if edge_weights is not None and edge_weights.dim() == 1:
            # Drop zero-weight edges, filtering the weights with the SAME mask
            # so edge_index and edge_weights stay aligned (one weight per edge).
            keep = edge_weights != 0
            edge_index = edge_index[:, keep]
            edge_weights = edge_weights[keep]
        batch, num_nodes, dim = x.shape
        # Flatten to (B*N, D); reshape also copes with non-contiguous inputs.
        # NOTE(review): edge_index is used as-is against B*N rows. If it only
        # indexes nodes 0..N-1 (one graph), only the first batch element
        # receives messages. Confirm BaseModel offsets it per batch element.
        flat = x.reshape(-1, dim)
        # NOTE(review): GATConv is built without edge_dim, so PyG likely
        # ignores edge_attr here — confirm if weights are meant to matter.
        out = relu(layer(flat, edge_index, edge_weights))
        return x + out.reshape(batch, num_nodes, -1)