import os
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

from purchase100 import Purchase100, Purchase_Net


# Train NUM_SHADOW_MODELS independent Purchase_Net models, each on its own
# disjoint SHARD_SIZE-sample slice of the permuted Purchase100 dataset, and
# save each model's state_dict to its own file.

# Placeholder paths -- replace before running.
PERM_PATH = 'YOUR_SAVED_INDICES_PURMUTATION'
SAVE_PATH = "YOUR_SAVE_PATH"

NUM_SHADOW_MODELS = 5
SHARD_SIZE = 10000
TEST_START, TEST_END = 50000, 60000  # held out: not used by any training shard
EPOCHS = 200
BATCH_SIZE = 64
LR = 0.01
WEIGHT_DECAY = 1e-07
LOG_EVERY = 100  # print the mean loss every this many batches


def shadow_model_path(base_path, ind):
    """Return *base_path* with ``_<ind>`` inserted before its extension.

    ``"model.pt"`` becomes ``"model_3.pt"`` for ``ind=3``. Each model gets a
    distinct file instead of all of them overwriting one path.
    """
    root, ext = os.path.splitext(base_path)
    return f"{root}_{ind}{ext}"


def train_one_epoch(net, loader, criterion, optimizer, epoch, log_every=LOG_EVERY):
    """Run one pass of SGD over *loader*, updating *net* in place.

    Args:
        net: model to train.
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function applied as ``criterion(net(inputs), labels)``.
        optimizer: optimizer over ``net``'s parameters.
        epoch: zero-based epoch index, used only for the log line.
        log_every: print the mean loss of the last *log_every* batches this
            often. A partial window left at the end of the epoch is not printed.
    """
    net.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(loader):
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_every))
            running_loss = 0.0


def evaluate(net, loader):
    """Return the top-1 accuracy of *net* over *loader* (0.0 if it is empty).

    Assumes labels are integer class indices, as with ``nn.CrossEntropyLoss``
    hard targets. TODO: confirm against Purchase100.
    """
    net.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            preds = net(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


def main():
    """Train and save every shadow model, reporting held-out accuracy."""
    # NOTE: pickle.load can execute arbitrary code; only load a trusted file.
    with open(PERM_PATH, 'rb') as fp:
        perm = pickle.load(fp)

    dataset = Purchase100(perm=perm)
    testset = Subset(dataset, indices=list(range(TEST_START, TEST_END)))
    # Built once: evaluation needs no shuffling and the test split never changes.
    testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    criterion = nn.CrossEntropyLoss()

    for ind in range(NUM_SHADOW_MODELS):
        shard = range(ind * SHARD_SIZE, (ind + 1) * SHARD_SIZE)
        trainset = Subset(dataset, indices=list(shard))
        trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

        net = Purchase_Net()
        optimizer = optim.SGD(net.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

        for epoch in range(EPOCHS):
            train_one_epoch(net, trainloader, criterion, optimizer, epoch)

        # Save first, so a failure in the evaluation cannot lose the trained model.
        torch.save(net.state_dict(), shadow_model_path(SAVE_PATH, ind))
        print(f'model {ind}: test accuracy {evaluate(net, testloader):.4f}')


# DataLoader workers re-import this module under the "spawn" start method
# (default on Windows/macOS). The guard keeps them from re-running training.
if __name__ == "__main__":
    main()