import torch
from torch_geometric.datasets import MD17
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SchNet
from torch.optim import Adam

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the revised MD17 dataset, using aspirin as the example molecule.
# NOTE(review): hard-coded, user-specific path — consider making it configurable.
datadir = '/data/NFS/potato/username/MD17'
dataset = MD17(root=datadir, name="revised aspirin")

# Draw `split` frames at random and divide them 80/10/10 into train/val/test.
# Frames are shuffled first because MD17 samples come from molecular-dynamics
# trajectories, where neighbouring frames are likely to be near-duplicates;
# slicing contiguous ranges would leak them across the splits. A fixed seed
# keeps the split reproducible between runs.
split = 1000
n_train = int(0.8 * split)
n_train_val = int(0.9 * split)  # end index of the validation range
perm = torch.randperm(len(dataset), generator=torch.Generator().manual_seed(42))[:split]
train_dataset = dataset[perm[:n_train]]
val_dataset = dataset[perm[n_train:n_train_val]]
# Bounded by `split`: the previous open-ended slice made the test set contain
# every remaining frame of the full dataset instead of 10% of `split`.
test_dataset = dataset[perm[n_train_val:]]

# DataLoaders (only the training set is reshuffled every epoch)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

# Initialize SchNet model
# NOTE(review): MD17 energies likely carry a large constant offset; consider
# passing SchNet's `mean`/`std` (or `atomref`) so the network regresses a
# normalised target — verify against the data before relying on raw MSE values.
model = SchNet(hidden_channels=128, num_filters=128, num_interactions=6, cutoff=5.0, num_gaussians=50)
model = model.to(device)

# Optimizer
optimizer = Adam(model.parameters(), lr=1e-3)

# Loss function: mean squared error between predicted and reference energies
loss_fn = torch.nn.MSELoss()

# Training loop
def train():
    """Run one optimisation epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn`` and ``device``.

    Returns:
        float: mean per-graph MSE over the training set for this epoch. Each
        batch is weighted by its number of graphs, so a smaller final batch
        does not skew the average.
    """
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # SchNet yields one value per graph with shape [num_graphs, 1]. Flatten
        # it so MSELoss compares element-wise instead of broadcasting against a
        # 1-D target into a [num_graphs, num_graphs] matrix.
        pred = model(batch.z, batch.pos, batch.batch).view(-1)
        # PyG's MD17 stores the reference energy as ``energy``; ``y`` is unset.
        target = batch.energy.view(-1)
        loss = loss_fn(pred, target)
        loss.backward()
        optimizer.step()
        # loss is a per-graph mean; rescale to a sum so the epoch mean is exact.
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(train_loader.dataset)

# Validation loop
def validate(loader=val_loader):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        loader: a PyG ``DataLoader`` to evaluate on; defaults to the
            module-level validation loader.

    Returns:
        float: mean per-graph MSE over ``loader.dataset`` (each batch weighted
        by its number of graphs).
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            # Flatten SchNet's [num_graphs, 1] output to match the 1-D target.
            pred = model(batch.z, batch.pos, batch.batch).view(-1)
            # PyG's MD17 stores the reference energy as ``energy``; ``y`` is unset.
            target = batch.energy.view(-1)
            total_loss += loss_fn(pred, target).item() * batch.num_graphs
    return total_loss / len(loader.dataset)

# Fit the model for a fixed number of epochs, logging both losses each epoch.
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    # Tuple elements evaluate left to right: train first, then validate.
    train_loss, val_loss = train(), validate()
    print(f"Epoch {epoch:02d}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

# Test the model
def test(loader=test_loader):
    """Evaluate ``model`` on the held-out split without tracking gradients.

    Args:
        loader: a PyG ``DataLoader`` to evaluate on; defaults to the
            module-level test loader.

    Returns:
        float: mean per-graph MSE over ``loader.dataset`` (each batch weighted
        by its number of graphs).
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            # Flatten SchNet's [num_graphs, 1] output to match the 1-D target.
            pred = model(batch.z, batch.pos, batch.batch).view(-1)
            # PyG's MD17 stores the reference energy as ``energy``; ``y`` is unset.
            target = batch.energy.view(-1)
            total_loss += loss_fn(pred, target).item() * batch.num_graphs
    return total_loss / len(loader.dataset)

# Final held-out evaluation, run once after training has finished.
test_loss = test()
print(f"Test Loss: {test_loss:.4f}")
