import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TransformerConv, GATConv, GIN

# GNN baselines: GCN (default), GAT, or GIN backbone, selected via `type`
class GNN_Model(nn.Module):
    """Configurable graph neural network built from PyTorch Geometric layers.

    The ``type`` argument selects the backbone:

    * ``'baselines'`` / ``'deepopf'`` -- a stack of ``GCNConv`` layers;
    * ``'gat'`` -- a stack of ``GATConv`` layers;
    * ``'gin'`` -- PyTorch Geometric's ready-made ``GIN`` model.

    For the stacked variants, GELU is applied after every layer and the
    interior layers (neither first nor last) add a residual connection.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=4, type='baselines'):
        """Build the selected backbone.

        Args:
            input_dim: Width of the input node features.
            hidden_dim: Width of every intermediate representation.
            output_dim: Width of the final node representation.
            num_layers: Number of message-passing layers.
            type: Backbone selector; one of ``'baselines'``, ``'deepopf'``,
                ``'gat'`` or ``'gin'``. (The name shadows the builtin but is
                kept because callers may pass it by keyword.)

        Raises:
            NotImplementedError: If ``type`` is not a supported backbone.
        """
        super().__init__()

        self.num_layers = num_layers
        self.type = type

        if type == 'gin':
            self.model = GIN(input_dim, hidden_dim, num_layers, output_dim)
            return

        # Resolve the convolution class once, before building anything, so an
        # unsupported type is rejected regardless of how many layers were asked for.
        if type in ('baselines', 'deepopf'):
            conv_cls = GCNConv
        elif type == 'gat':
            conv_cls = GATConv
        else:
            raise NotImplementedError(
                f"Unsupported GNN type: {type!r}; expected one of "
                "'baselines', 'deepopf', 'gat', 'gin'"
            )

        self.layers = nn.ModuleList()
        last = num_layers - 1
        for i in range(num_layers):
            # First layer consumes the raw features; the last one emits output_dim.
            in_channels = input_dim if i == 0 else hidden_dim
            out_channels = output_dim if i == last else hidden_dim
            self.layers.append(conv_cls(in_channels, out_channels))

    def forward(self, x, edge_index):
        """Run the network on one graph (or batch of graphs).

        Args:
            x: Node feature tensor passed to the first layer.
            edge_index: Graph connectivity forwarded unchanged to every layer
                (presumably PyG's ``[2, num_edges]`` format -- confirm against
                the caller).

        Returns:
            The node representations produced by the last layer.
        """
        if self.type == 'gin':
            return self.model(x, edge_index)

        last = self.num_layers - 1
        for i, layer in enumerate(self.layers):
            x_res = x
            # NOTE(review): GELU is also applied to the final layer's output;
            # kept as-is so existing checkpoints keep producing the same values.
            x = F.gelu(layer(x, edge_index))
            # Only interior layers map hidden_dim -> hidden_dim, so only there
            # do the shapes line up for a skip connection.
            if 0 < i < last:
                x = x + x_res
        return x

    def save(self, model_path):
        """Write the model's ``state_dict`` to ``model_path``."""
        torch.save(self.state_dict(), model_path)
        print(f"Model saved at {model_path}")

    def load(self, model_path, map_location="cpu"):
        """Load weights previously written by :meth:`save`.

        Args:
            model_path: Path of the checkpoint file.
            map_location: Forwarded to ``torch.load``. The default ``"cpu"``
                lets a checkpoint saved on a GPU be read on a machine without
                that device; ``load_state_dict`` then copies the values into
                this model's existing parameters, so the model stays on
                whatever device it already occupies. Pass ``None`` to restore
                tensors to the device they were saved from.
        """
        # weights_only=True restricts unpickling to tensors and plain containers.
        state = torch.load(model_path, map_location=map_location, weights_only=True)
        self.load_state_dict(state)
        print(f"Model loaded from {model_path}")