import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from dataset2 import load_data2
from model2 import Specific
import numpy as np
from sklearn.preprocessing import StandardScaler
from torch_geometric import TUDataset

# --- Data, model, loss and optimizer setup -------------------------------
data_loader = load_data2()
input_size = data_loader.input_dim()
# Passed to Specific as its second argument; presumably the number of output
# classes (the output is fed to CrossEntropyLoss) -- confirm against model2.
target_num = 3

# The device must exist BEFORE the model is moved onto it (it was previously
# assigned one line too late, so `model.to(device)` raised NameError).
# Fall back to CPU so the script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Specific(input_size, target_num)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Adam with PyTorch's default hyper-parameters (no lr given).
optimizer = optim.Adam(model.parameters())

# Number of full passes over data_loader.
epoches = 1000


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the averaged loss.

    Each ``(data, target)`` pair yielded by ``loader`` is moved to ``device``,
    pushed through ``model``, scored with ``criterion`` and used for one
    ``optimizer`` step.

    Args:
        model: the ``nn.Module`` being trained (already on ``device``).
        loader: iterable of ``(data, target)`` tensor pairs that also
            supports ``len()``.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        device: ``torch.device`` the batches are moved to.

    Returns:
        The summed batch losses divided by ``len(loader)``, or ``float('nan')``
        when ``len(loader)`` is 0 (avoids a ZeroDivisionError).
    """
    # Explicit train mode: matters if the model uses dropout/batch-norm and
    # was switched to eval mode elsewhere.
    model.train()
    total_loss = 0.0
    for data, target in loader:
        data = data.to(device)
        target = target.to(device)

        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() already returns a Python float and detaches from the graph.
        total_loss += loss.item()

    # Keep len(loader) as the denominator so the reported number means the
    # same thing as before (NOTE(review): assumes len() counts batches for
    # the object returned by load_data2 -- confirm).
    num_batches = len(loader)
    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches


for epoch in range(epoches):
    epoch_loss = train_one_epoch(model, data_loader, criterion, optimizer, device)
    print("loss", epoch, epoch_loss)

# Pickles the whole module object; torch.load() will need the model2.Specific
# class importable. Saving model.state_dict() would be more portable.
torch.save(model, './model2.pt')