import torch

# Select the compute device: CUDA when PyTorch reports a usable GPU, else the CPU.
# NOTE(review): nothing else in the code visible here consumes `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Download & verify the MNIST dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

# Both splits live under the same 'data' root and convert images with
# ToTensor(); only the training split requests a download of missing files.
train_data = datasets.MNIST(root='data', train=True, transform=ToTensor(), download=True)
test_data = datasets.MNIST(root='data', train=False, transform=ToTensor())

# Print each dataset's summary, then the size of each split's `data` tensor.
print(train_data, test_data, sep='\n')
print(train_data.data.size(), test_data.data.size(), sep='\n')

import matplotlib.pyplot as plt

# Render the first training sample in grayscale, titled with its label,
# and write the figure to a PDF in the working directory.
sample_image, sample_label = train_data.data[0], train_data.targets[0]
plt.imshow(sample_image, cmap='gray')
plt.title('%i' % sample_label)
plt.savefig('train_example.pdf')

# Prepare dataset for training
from torch.utils.data import DataLoader

# Identical loader settings for both splits: batches of 100 samples,
# shuffling enabled, one worker subprocess.
_loader_kwargs = dict(batch_size=100, shuffle=True, num_workers=1)

loaders = dict(
    train=DataLoader(train_data, **_loader_kwargs),
    test=DataLoader(test_data, **_loader_kwargs),
)

# Define the model
# NOTE(review): `model` is presumably a project-local module; the CNN class
# is not visible in this file.
from model import CNN

# Built with default constructor arguments. In the code visible here the model
# is never moved with .to(device), so the `device` chosen above does not
# affect where it lives.
cnn = CNN()

# Loss function
import torch.nn as nn

# Criterion that train() applies to element 0 of the model's return value
# and the batch labels. train() reads it as a module-level global.
loss_func = nn.CrossEntropyLoss()

# Optimization method
from torch import optim

# Adam over all of cnn's parameters with learning rate 0.01. Like loss_func,
# train() reads this as a module-level global rather than as an argument, so
# it must be created before train() is called.
optimizer = optim.Adam(cnn.parameters(), lr=0.01)

# Training procedure
from torch.autograd import Variable
# Number of passes over loaders['train']; passed to train() below.
num_epochs = 10

def train(num_epochs, cnn, loaders):
    """Train ``cnn`` on ``loaders['train']`` for ``num_epochs`` epochs.

    The loss criterion and the optimizer are NOT parameters: this function
    reads the module-level ``loss_func`` and ``optimizer``, so the optimizer
    must have been constructed from ``cnn``'s parameters beforehand.

    Args:
        num_epochs: Number of full passes over the training loader.
        cnn: Model to optimize. Only element 0 of its return value is passed
            to the loss (presumably forward() returns a tuple such as
            ``(logits, features)`` -- confirm against the model definition).
        loaders: Mapping whose ``'train'`` entry is a sized iterable yielding
            ``(images, labels)`` tensor batches.

    Returns:
        None. ``cnn`` is updated in place; a progress line is printed every
        100 steps.
    """
    cnn.train()

    # Follow the model's own placement: batches are moved to the device of
    # its first parameter, so training also works when the caller has put
    # the model on a GPU. A parameter-less model leaves batches untouched.
    first_param = next(cnn.parameters(), None)
    model_device = None if first_param is None else first_param.device

    total_step = len(loaders['train'])

    for epoch in range(num_epochs):
        for step, (images, labels) in enumerate(loaders['train'], start=1):
            # Tensors carry autograd state themselves, so no Variable
            # wrapper is needed; .to() is a no-op when already in place.
            if model_device is not None:
                images = images.to(model_device)
                labels = labels.to(model_device)

            output = cnn(images)[0]
            loss = loss_func(output, labels)

            # Clear gradients left over from the previous step, then
            # backpropagate and apply the update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, step, total_step, loss.item()))

# Run training only when executed as a script. The loaders use a worker
# process (num_workers=1); under the 'spawn' start method each worker
# re-imports this file, and an unguarded train() call would then try to start
# training (and further workers) again inside the worker.
if __name__ == '__main__':
    train(num_epochs, cnn, loaders)

    # Save the trained model
    torch.save(cnn.state_dict(), 'model_trained.pth')
