import numpy as np
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler
import copy
import csv
import os
import sys
# Make CUDA float tensors the default tensor type for everything created after
# this line; the script therefore needs a CUDA-capable GPU to run.
torch.set_default_tensor_type('torch.cuda.FloatTensor')



# Command-line arguments: <learning_rate> <tot_epochs>.
# A missing or malformed argument raises IndexError / ValueError here.
learning_rate = float(sys.argv[1])
tot_epochs = int(sys.argv[2])

# Convert images to tensors and normalize the single channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,), (0.5,))])

# Download (if not already present) the FashionMNIST train and test splits
# into Fashion_Mnist_data/ and wrap them in loaders with batches of 1024.
# NOTE(review): the test loader is shuffled as well; ordering should not
# matter for the accuracy computed from it -- confirm shuffle is intended.
# NOTE(review): with a CUDA default tensor type, some PyTorch versions expect
# a CUDA generator when shuffle=True -- confirm on the target version.
trainset = datasets.FashionMNIST('Fashion_Mnist_data/', download = True, train = True, transform = transform)
testset = datasets.FashionMNIST('Fashion_Mnist_data/', download = True, train = False, transform = transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 1024, shuffle = True)
testloader = torch.utils.data.DataLoader(testset, batch_size = 1024, shuffle = True)


class Feedforward(nn.Module):
    """Four-layer fully connected classifier (784 -> 256 -> 128 -> 64 -> 10).

    The three hidden layers use ReLU activations; the output layer produces
    log-probabilities over 10 classes via ``log_softmax``.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this fixed order so that parameter
        # initialization under a given seed and state_dict keys stay stable.
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of inputs.

        ``x`` is reshaped to ``(batch, -1)`` before the first linear layer,
        so each sample must contain 784 values in total.
        """
        batch_size = x.shape[0]
        hidden = x.view(batch_size, -1)

        # Hidden stack: linear layer followed by ReLU, three times.
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))

        logits = self.fc4(hidden)
        return F.log_softmax(logits, dim=1)

   
# Negative log-likelihood loss: Feedforward.forward already applies
# log_softmax, so the model outputs are log-probabilities as NLLLoss expects.
criterion = nn.NLLLoss()


def SGD(net, ep, lr, mom, train_loader=None, test_loader=None, loss_fn=None):
    """Train ``net`` with SGD + momentum and evaluate it after every epoch.

    Args:
        net: ``nn.Module`` to train; its parameters are updated in place.
            Batches are moved to the device the parameters live on.
        ep: Number of epochs to run.
        lr: Learning rate passed to ``torch.optim.SGD``.
        mom: Momentum passed to ``torch.optim.SGD``.
        train_loader: Iterable of ``(images, labels)`` training batches.
            Defaults to the module-level ``trainloader``.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
            Defaults to the module-level ``testloader``.
        loss_fn: Callable ``(outputs, labels) -> scalar loss tensor``.
            Defaults to the module-level ``criterion``.

    Returns:
        dict with two lists holding one entry per epoch:
        ``'train_loss'`` (mean batch loss) and ``'test_acc'`` (percentage of
        correctly classified evaluation samples). An entry is NaN when the
        corresponding loader yielded no batches.
    """
    # Fall back to the script-level objects so existing four-argument calls
    # keep working unchanged.
    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader
    if loss_fn is None:
        loss_fn = criterion

    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=mom)

    # Follow the model's device instead of hard-coding .cuda(); for a model
    # that lives on the GPU this is the same transfer as before.
    device = next(net.parameters()).device
    was_training = net.training

    test_acc = []
    train_loss = []

    for epoch in range(ep):
        # --- training pass ---
        net.train()
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            # Call the module itself (not .forward) so registered hooks run.
            loss = loss_fn(net(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # --- evaluation pass ---
        # eval() disables train-only layers such as dropout; no_grad() avoids
        # building an autograd graph that would never be used.
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard against empty loaders instead of crashing on a zero division.
        accuracy = 100 * correct / total if total else float('nan')
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        test_acc.append(accuracy)
        train_loss.append(epoch_loss)

        print("Epoch: {}/{}..".format(epoch+1, ep),
              "Training loss: {:.3f}..".format(epoch_loss),
              "Test Accuracy: {:.3f}".format(accuracy))

    # Leave the model in the mode the caller handed it over in.
    net.train(was_training)

    return {'train_loss': train_loss, 'test_acc': test_acc}


    
# Build a fresh network and train it with momentum 0.9, using the learning
# rate and epoch count taken from the command line.
net_s = Feedforward()
s = SGD(net_s, ep = tot_epochs, lr = learning_rate, mom = 0.9)
# Per-epoch metrics returned by SGD: test accuracy and mean training loss.
sgd_test = s['test_acc']
sgd_train = s['train_loss']






# Persist the per-epoch metrics as single-row CSV files; the learning rate is
# embedded in the file names so runs with different rates do not overwrite
# each other.
outfile_test_accuracy_sgd = 'outputs/acc_sgd_FF_fashion_lr' + str(learning_rate) + '.csv'
outfile_train_loss_sgd = 'outputs/loss_sgd_FF_fashion_lr' + str(learning_rate) + '.csv'

# Create the output directory first: without it, open() raises
# FileNotFoundError and the results of the whole training run are lost.
os.makedirs(os.path.dirname(outfile_test_accuracy_sgd), exist_ok=True)

# newline='' lets the csv module control line endings itself (otherwise
# extra blank lines can appear on platforms that translate '\n').
with open(outfile_test_accuracy_sgd, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(sgd_test)

with open(outfile_train_loss_sgd, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(sgd_train)



    
