import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import copy

import numpy as np

import csv
import sys
import os
# Make the default float tensor type CUDA, so tensors and module parameters created
# afterwards without an explicit device (e.g. the VGG built at the end of the script)
# are allocated on the GPU.
# NOTE(review): torch.set_default_tensor_type is deprecated since PyTorch 2.1 in favour of
# torch.set_default_dtype / torch.set_default_device. With a CUDA default, some PyTorch
# versions also reject DataLoader(shuffle=True) ("Expected a 'cuda' device type for
# generator") — confirm against the installed version.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

###########################


# Layer layouts for each VGG variant: an int is a 3x3 convolution with that many
# output channels (followed by BatchNorm and ReLU); 'M' is a 2x2 max-pool, stride 2.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG-style classifier (with BatchNorm) sized for 32x32 inputs such as CIFAR-10.

    Args:
        vgg_name: key into ``cfg`` selecting the layer layout ('VGG11', 'VGG13',
            'VGG16' or 'VGG19'). An unknown name raises ``KeyError``.
        num_classes: number of output logits (default 10, as before).
        in_channels: number of channels of the input images (default 3).
    """

    def __init__(self, vgg_name, num_classes=10, in_channels=3):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name], in_channels)
        # Every layout ends with 512 channels after five 2x2 pools, so a 32x32
        # input leaves a 512x1x1 feature map for this linear classifier.
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, num_classes)."""
        out = self.features(x)
        # Keep the batch dimension and flatten the rest; unlike .view this also
        # works on non-contiguous feature maps.
        out = torch.flatten(out, 1)
        return self.classifier(out)

    def _make_layers(self, cfg, in_channels=3):
        """Build the convolutional feature extractor described by ``cfg``.

        ``in_channels`` is the channel count of the input images; it is updated
        to each convolution's output width as the stack is built.
        """
        layers = []
        for x in cfg:
            if x == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        # A 1x1 average pool with stride 1 is an identity op; kept so the
        # Sequential's layout is unchanged.
        layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*layers)


# Working directory holding the CIFAR-10 download (data/) and the output CSVs (results/).
# Placeholder path: point it at a real, writable directory before running.
dir_path = '/path/to/working/dir'


def _parse_lr(argv):
    """Return the learning rate passed as the first command-line argument.

    Accepts the ``lr=0.1`` form the script has always used, or a bare ``0.1``.
    Exits with a usage message instead of a traceback when the argument is
    missing or not a number.
    """
    if len(argv) < 2:
        sys.exit('usage: %s lr=<learning rate>' % (argv[0] if argv else 'script'))
    raw = argv[1]
    try:
        # rpartition takes the text after the last '=' (or the whole string if none).
        return float(raw.rpartition('=')[2])
    except ValueError:
        sys.exit('invalid learning rate %r (expected e.g. lr=0.1)' % raw)


sys_lr = _parse_lr(sys.argv)

# Per-channel normalisation constants shared by the train and test pipelines.
# NOTE(review): the std triple is the widely copied (0.2023, 0.1994, 0.2010); the
# per-channel std of the CIFAR-10 training set is usually quoted as roughly
# (0.247, 0.243, 0.261). Left unchanged so results stay comparable — confirm.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop after padding by 4 pixels, random
# horizontal flip, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Test pipeline: no augmentation, same normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(root=dir_path + '/data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root=dir_path + '/data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False)

# CIFAR-10 class names in label-index order (not used elsewhere in the visible script).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

################################


def _sgd_train_epoch(net_c, loader, optimizer_c, criterion_c, device):
    """Run one optimisation pass over ``loader`` with ``net_c`` in training mode."""
    net_c.train()
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_c.zero_grad()
        loss = criterion_c(net_c(inputs), labels)
        loss.backward()
        optimizer_c.step()


def _sgd_evaluate(net_c, loader, device):
    """Return top-1 accuracy of ``net_c`` on ``loader`` in percent (NaN if empty)."""
    # eval() makes BatchNorm normalise with its running statistics and stops the
    # test batches from updating them; without it the reported accuracy depends
    # on test-batch statistics and the test set leaks into the model.
    net_c.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = net_c(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else float('nan')


def SGD(net_c, ep, lr, mom, train_loader=None, test_loader=None, device='cuda'):
    """Train ``net_c`` with momentum SGD and record test accuracy after every epoch.

    Uses cross-entropy loss, weight decay 5e-4 and a constant learning rate
    (no schedule is applied).

    Args:
        net_c: model to train in place; its parameters must live on ``device``.
        ep: number of epochs.
        lr: learning rate.
        mom: momentum factor.
        train_loader: iterable of (inputs, labels) batches used for training;
            defaults to the module-level ``trainloader``.
        test_loader: iterable of (images, labels) batches used for evaluation;
            defaults to the module-level ``testloader``.
        device: device each batch is moved to (default 'cuda', matching the
            previous hard-coded ``.cuda()`` calls).

    Returns:
        List with one test accuracy (percent) per epoch. The model is left in
        eval mode.
    """
    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader

    criterion_c = nn.CrossEntropyLoss()
    optimizer_c = optim.SGD(net_c.parameters(), lr=lr, momentum=mom, weight_decay=5e-4)
    accuracy_c = []

    for epoch in range(ep):
        _sgd_train_epoch(net_c, train_loader, optimizer_c, criterion_c, device)
        accuracy = _sgd_evaluate(net_c, test_loader, device)
        accuracy_c.append(accuracy)
        print('epoch: %d, lr: %f, test accuracy: %f' %
              (epoch + 1, optimizer_c.param_groups[0]['lr'], accuracy))

    return accuracy_c
                                         

# Train VGG19 for 350 epochs at the command-line learning rate and save the
# per-epoch test accuracies as one CSV row.
outfile_test_accuracy_sgd = dir_path + '/results/vgg_sgd_' + str(sys_lr) + '.csv'
# Create the results directory before training so a missing folder cannot
# discard the outcome of a long run at the very end.
os.makedirs(os.path.dirname(outfile_test_accuracy_sgd), exist_ok=True)

# Move the model explicitly rather than relying only on the global default
# tensor type; this is a no-op if its parameters are already on the GPU.
net_sgd1 = VGG('VGG19').cuda()
acc_sgd1 = SGD(net_sgd1, ep=350, lr=sys_lr, mom=0.9)

# The csv module expects files opened with newline='' (avoids blank lines on Windows).
with open(outfile_test_accuracy_sgd, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(acc_sgd1)
