import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import copy

import numpy as np

import csv
import sys
import os
# Make CUDA float32 the default tensor type, so tensors and module parameters
# created after this point are allocated on the GPU without explicit
# `.cuda()` calls. The rest of the script assumes a CUDA device is available.
# NOTE(review): recent PyTorch releases appear to deprecate this call in
# favour of `torch.set_default_dtype` / `torch.set_default_device` — confirm
# against the installed version before upgrading.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

###########################


class BasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions and an additive shortcut.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its result is
    summed with the shortcut branch and passed through a final ReLU.

    Args:
        in_planes: Number of channels of the incoming feature map.
        planes: Number of channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution, and of the projection
            shortcut when one is required.
    """

    # Ratio between the block's output channel count and ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path: only the first convolution may change the spatial size.
        self.conv1 = nn.Conv2d(
            in_planes, planes,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut: identity (empty Sequential) when the input already has the
        # output shape, otherwise a 1x1 convolution + BN projection.
        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        # In-place accumulation into the freshly computed main-path tensor;
        # the input ``x`` itself is never modified.
        residual += self.shortcut(x)
        return F.relu(residual)


class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four residual stages and a linear classifier.

    The stem maps 3 input channels to 64. The four stages use widths
    64/128/256/512; the first stage keeps the resolution and each later stage
    starts with a stride-2 block.

    Args:
        block: Block class (or factory) called as ``block(in_planes, planes,
            stride)`` and exposing an ``expansion`` attribute.
        num_blocks: Sequence whose first four entries give the number of
            blocks in each stage.
        num_classes: Size of the classifier output.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Channel count fed to the next block; advanced by `_make_layer`.
        self.in_planes = 64

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual stages.
        self.layer1 = self._make_layer(block, planes=64, num_blocks=num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, planes=128, num_blocks=num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, planes=256, num_blocks=num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, planes=512, num_blocks=num_blocks[3], stride=2)

        # Classifier head.
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1.

        Updates ``self.in_planes`` so the next block/stage sees the right
        input width. A stage always contains at least one block, even when
        ``num_blocks`` is zero or negative.
        """
        block_strides = [stride]
        block_strides.extend(1 for _ in range(num_blocks - 1))

        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        # Fixed 4x4 average pool; the linear layer expects exactly
        # 512 * expansion features after flattening, i.e. a 1x1 pooled map
        # (presumably sized for 32x32 inputs — confirm for other resolutions).
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


def ResNet18():
    """Build a ResNet of ``BasicBlock`` stages with depths 2-2-2-2.

    The number of output classes is left at ``ResNet``'s default.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths)


# Working directory: the dataset is stored under `<dir_path>/data`.
dir_path = '/path/to/working/dir'

# Per-channel normalisation statistics shared by the train and test pipelines.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)
_DATA_ROOT = dir_path + '/data'

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip as augmentation, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Test pipeline: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# CIFAR-10 training split (downloaded on first use), shuffled every epoch.
trainset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT,
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

# CIFAR-10 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT,
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

# Human-readable class names.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

################################


def SGD(net_c, ep, mom, train_loader=None, test_loader=None):
    """Train ``net_c`` with SGD and measure test accuracy after every epoch.

    Optimises a cross-entropy loss with weight decay 5e-4. The learning rate
    starts at 0.1 and is set to 0.01 when the (0-based) epoch index reaches
    150 and to 0.001 when it reaches 250.

    The network is put in training mode for the optimisation pass and in
    evaluation mode for the test pass, so layers such as BatchNorm use their
    running statistics (and do not update them) while being evaluated. The
    mode the network had on entry is restored before returning.

    Args:
        net_c: Network to train in place. Batches are moved to the device of
            its first parameter.
        ep: Number of epochs to run.
        mom: Momentum passed to the SGD optimiser.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
            Defaults to the module-level ``trainloader``.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
            Defaults to the module-level ``testloader``.

    Returns:
        A list with one test accuracy (in percent) per epoch.
    """
    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader

    criterion_c = nn.CrossEntropyLoss()
    optimizer_c = optim.SGD(net_c.parameters(), lr=0.1, momentum=mom, weight_decay=5e-4)
    # Follow the network's device instead of assuming CUDA.
    device = next(net_c.parameters()).device
    # Epoch index -> learning rate that takes effect from that epoch on.
    lr_schedule = {150: 0.01, 250: 0.001}
    was_training = net_c.training
    accuracy_c = []

    for epoch in range(ep):  # loop over the dataset multiple times
        if epoch in lr_schedule:
            for g in optimizer_c.param_groups:
                g['lr'] = lr_schedule[epoch]
        print('optimization lr', optimizer_c.param_groups[0]['lr'])

        # --- optimisation pass ---
        net_c.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer_c.zero_grad()
            loss = criterion_c(net_c(inputs), labels)
            loss.backward()
            optimizer_c.step()

        # --- evaluation pass ---
        # eval() is essential here: without it BatchNorm would normalise with
        # test-batch statistics and fold test data into its running stats.
        net_c.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                _, predicted = torch.max(net_c(images), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy_c.append(100*correct/total)

        print('epoch: %d, lr: %f, test accuracy: %f' %
                (epoch + 1, optimizer_c.param_groups[0]['lr'], 100*correct / total))

    # Leave the network in the mode the caller handed it over in.
    net_c.train(was_training)
    return accuracy_c
                                         
                                         
                                         
# Destination for the per-epoch test accuracies of the SGD run.
outfile_test_accuracy_sgd = dir_path + '/results/resnet_sgd_3stage.csv'
# Create the results directory up front: failing on a missing directory only
# after the full training run would throw the results away.
os.makedirs(os.path.dirname(outfile_test_accuracy_sgd), exist_ok=True)

# Train a fresh ResNet-18 for 350 epochs with momentum 0.9.
net_sgd1 = ResNet18()
acc_sgd1 = SGD(net_sgd1, ep=350, mom=0.9)

# newline='' lets the csv writer control line endings itself (otherwise extra
# blank lines can appear on platforms that translate '\n').
with open(outfile_test_accuracy_sgd, 'w', newline='') as f:
    writer = csv.writer(f)
    # Single row: one accuracy value per epoch.
    writer.writerow(acc_sgd1)
