import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from dataset import load_data
from model import GraphProp
import numpy as np
from sklearn.preprocessing import StandardScaler


# Build the network, then move its parameters to the selected device.
model = GraphProp()

# Fall back to CPU when CUDA is unavailable, so model.to(device) does not
# fail on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


# Presumably yields (data, target) batches, as unpacked by the training loop.
data_loader = load_data()

# Mean-squared-error regression loss; Adam with PyTorch's default
# hyperparameters (no arguments besides the parameters are passed).
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

def min_max_normalize(tensor):
    """Rescale *tensor* to [0, 1] independently along dim 0.

    The minimum and maximum are taken over dim 0 (kept as a size-1 dim so
    they broadcast back), so each position along the remaining dims is
    scaled with its own min/max.

    Where max == min (a constant feature, or a tensor with a single row
    along dim 0), the plain formula would compute 0/0 and yield NaN. Such
    zero ranges are replaced by 1, so those entries map to 0 instead.

    Args:
        tensor: Tensor to normalize; reduced over dim 0.

    Returns:
        A tensor of the same shape with values in [0, 1].
    """
    tensor_min = tensor.min(dim=0, keepdim=True)[0]  # Minimum value per feature
    tensor_max = tensor.max(dim=0, keepdim=True)[0]  # Maximum value per feature
    value_range = tensor_max - tensor_min
    # Guard against division by zero for constant features.
    value_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )
    return (tensor - tensor_min) / value_range

epoches = 50  # Number of passes over data_loader.

best_loss = []  # NOTE(review): never read or appended to in this script.


for epoch in range(epoches):
    model.train()  # Make sure training-mode behavior (if any layers use it) is active.
    epoch_loss = 0.0
    n_batches = 0
    for data, target in data_loader:
        data = data.to(device)

        # NOTE(review): targets are min-max scaled per batch, so the same raw
        # target can map to different values depending on which other samples
        # share its batch. Consider normalizing with dataset-level statistics.
        target = min_max_normalize(target)

        target = target.to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()  # .item() already returns a Python float.
        n_batches += 1

    # Average over the batches actually seen. This does not rely on
    # len(data_loader), and max(..., 1) avoids ZeroDivisionError when the
    # loader yields nothing.
    print("loss", epoch, epoch_loss / max(n_batches, 1))

# Pickles the whole module object. Loading it requires GraphProp to be
# importable; saving model.state_dict() would be a more portable alternative.
torch.save(model, './model.pt')