import numpy as np
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler
import copy
import csv
import os
import sys
torch.set_default_tensor_type('torch.cuda.FloatTensor')


# Command-line arguments: learning rate and number of training epochs.
if len(sys.argv) < 3:
    sys.exit('usage: {} LEARNING_RATE EPOCHS'.format(sys.argv[0]))
learning_rate = float(sys.argv[1])
tot_epochs = int(sys.argv[2])

# Convert images to tensors and map pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
trainset = datasets.FashionMNIST('MNIST_data/', download=True, train=True, transform=transform)
testset = datasets.FashionMNIST('MNIST_data/', download=True, train=False, transform=transform)
# With CUDA as the default tensor type, the shuffling sampler's torch.randperm
# allocates on the GPU; newer PyTorch versions then reject the CPU generator the
# sampler creates by default. Supplying a CUDA generator keeps shuffling working.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True,
                                          generator=torch.Generator(device='cuda'))
# Evaluation order does not affect accuracy, so the test set is not shuffled
# (which also avoids the generator/device issue above).
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)


class CNN(nn.Module):
    """Two-stage convolutional classifier that outputs 10 logits per image.

    Each stage is Conv2d(5x5, padding=2) -> BatchNorm2d -> ReLU -> MaxPool2d(2).
    The final linear layer expects 32 * 7 * 7 flattened features, which is what
    a single-channel 28x28 input yields after the two 2x2 pooling steps.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Build one conv -> batch-norm -> ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def __init__(self):
        super().__init__()
        # The attribute names (layer1, layer2, fc) become the state_dict keys,
        # so they must stay stable for saved checkpoints.
        self.layer1 = self._conv_block(1, 16)
        self.layer2 = self._conv_block(16, 32)
        self.fc = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the image batch ``x``."""
        features = self.layer2(self.layer1(x))
        flattened = features.view(features.size(0), -1)
        return self.fc(flattened)

# Loss used by the training loop: softmax cross-entropy on the raw logits,
# averaged over the batch ('mean' is also the library default).
criterion = nn.CrossEntropyLoss(reduction='mean')


def Adam(net, ep, lr, train_loader=None, test_loader=None, loss_fn=None):
    """Train ``net`` with the Adam optimizer, evaluating it after every epoch.

    Args:
        net: model to train; its parameters are updated in place.
        ep: number of training epochs.
        lr: Adam learning rate.
        train_loader: iterable of ``(images, labels)`` training batches.
            Defaults to the module-level ``trainloader``.
        test_loader: iterable of ``(images, labels)`` evaluation batches.
            Defaults to the module-level ``testloader``.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar loss.
            Defaults to the module-level ``criterion``.

    Returns:
        dict with ``'train_loss'`` (mean training loss per epoch) and
        ``'test_acc'`` (test accuracy in percent per epoch). Each is a list of
        length ``ep``. A value is NaN when the corresponding loader is empty.
        After at least one epoch the model is left in evaluation mode.
    """
    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader
    if loss_fn is None:
        loss_fn = criterion

    # Send batches to wherever the model lives rather than hard-coding .cuda().
    device = next(net.parameters()).device
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    test_acc = []
    train_loss = []

    for epoch in range(ep):
        # Training mode: BatchNorm normalizes with batch statistics and updates
        # its running averages.
        net.train()
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(images)  # call the module (not .forward) so hooks run
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Evaluation mode + no_grad: BatchNorm uses its running statistics, the
        # test data does not leak into them, and no autograd graph is built.
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        accuracy = 100 * correct / total if total else float('nan')
        test_acc.append(accuracy)
        train_loss.append(epoch_loss)

        print("Epoch: {}/{}..".format(epoch+1, ep),
              "Training loss: {:.3f}..".format(epoch_loss),
              "Test Accuracy: {:.3f}".format(accuracy))

    return {'train_loss': train_loss, 'test_acc': test_acc}


    
# Train a fresh CNN with the command-line hyperparameters.
net_ad = CNN()
ad = Adam(net_ad, ep=tot_epochs, lr=learning_rate)
adam_test = ad['test_acc']
adam_train = ad['train_loss']

# Results are written into outputs/. Create it up front so that a missing
# directory does not discard the whole training run with a FileNotFoundError.
os.makedirs('outputs', exist_ok=True)

outfile_test_accuracy_adam = 'outputs/acc_adam_' + str(learning_rate) + '.csv'
outfile_train_loss_adam = 'outputs/loss_adam_' + str(learning_rate) + '.csv'

# Each CSV holds a single row with one value per epoch. newline='' is required
# by the csv module so row terminators are not translated (e.g. on Windows).
for path, row in ((outfile_test_accuracy_adam, adam_test),
                  (outfile_train_loss_adam, adam_train)):
    with open(path, 'w', newline='') as f:
        csv.writer(f).writerow(row)


