import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TransformerConv

# Baseline 1: GCN model (TransformerConv encoder, two residual GCN processors, linear decoder)
class Canos_Model(nn.Module):
    """Graph model: TransformerConv encoder, two residual GCN processors, linear decoder.

    Args:
        input_dim: Number of input node features.
        hidden_dim: Width of the hidden node embeddings. Must be divisible by
            ``num_heads``: the encoder concatenates ``num_heads`` heads of size
            ``hidden_dim // num_heads``, and the GCN processors expect exactly
            ``hidden_dim`` input features.
        edge_dim: Number of edge features consumed by the encoder.
        output_dim: Number of output features produced by the decoder.
        num_heads: Number of attention heads in the TransformerConv encoder.
        num_layers: Currently unused -- the processor depth is fixed at two GCN
            layers. Kept only for signature compatibility with existing callers.

    Raises:
        ValueError: If ``num_heads`` is not positive, or ``hidden_dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(self, input_dim, hidden_dim, edge_dim, output_dim, num_heads=4, num_layers=4):
        super().__init__()
        # Validate up front: a mismatch would otherwise only surface as a shape
        # error inside forward() (or as a ZeroDivisionError for num_heads == 0).
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            )

        # NOTE: attribute names (including the "processer" spelling) are the
        # state_dict keys used by save()/load(); renaming them would break
        # previously saved checkpoints.
        self.encoder = TransformerConv(input_dim, hidden_dim // num_heads, heads=num_heads, edge_dim=edge_dim)
        self.processer1 = GCNConv(hidden_dim, hidden_dim)
        self.processer2 = GCNConv(hidden_dim, hidden_dim)
        # Only the second processor is followed by a LayerNorm.
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.decoder = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, edge_attr):
        """Run encoder -> processors -> decoder on one graph (or batch of graphs).

        Args:
            x: Node feature tensor, presumably (num_nodes, input_dim) -- TODO confirm.
            edge_index: Graph connectivity, shared by the encoder and both processors.
            edge_attr: Edge features; used by the encoder only.

        Returns:
            The decoder output, presumably (num_nodes, output_dim).
        """
        h = self.encoder(x, edge_index, edge_attr)
        # Processor 1: GCN -> ReLU, plus a residual connection (no normalization).
        h = F.relu(self.processer1(h, edge_index)) + h
        # Processor 2: GCN -> LayerNorm -> ReLU, plus a residual connection.
        h = F.relu(self.norm2(self.processer2(h, edge_index))) + h
        return self.decoder(h)

    def save(self, model_path):
        """Save this model's state_dict (weights only) to ``model_path``."""
        torch.save(self.state_dict(), model_path)
        print(f"Model saved at {model_path}")

    def load(self, model_path, map_location=None):
        """Load weights saved by :meth:`save` into this model in place.

        Args:
            model_path: Path of the checkpoint written by :meth:`save`.
            map_location: Forwarded to ``torch.load``. Defaults to the device of
                this model's parameters (or CPU if it has none), so a checkpoint
                saved on GPU can be loaded on a CPU-only machine.
        """
        if map_location is None:
            first_param = next(self.parameters(), None)
            map_location = first_param.device if first_param is not None else "cpu"
        # weights_only=True restricts unpickling to tensors/primitive containers.
        state_dict = torch.load(model_path, map_location=map_location, weights_only=True)
        self.load_state_dict(state_dict)
        print(f"Model loaded from {model_path}")