import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler


import copy
import sys
import numpy as np

import csv
import os
# Make CUDA float tensors the default tensor type for everything created below
# in this file (model parameters included), so the script requires a GPU.
# NOTE(review): newer PyTorch releases are reported to deprecate this call in
# favour of torch.set_default_dtype / torch.set_default_device -- confirm
# against the installed version before upgrading.
torch.set_default_tensor_type('torch.cuda.FloatTensor')


###########################
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut branch.

    Main path: conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm. The
    shortcut output is added to the main path and a final ReLU is applied.
    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride or
    the channel count changes, in which case it is a 1x1 convolution followed
    by BatchNorm.

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first convolution and of the projection shortcut.
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # A projection is only needed when the shortcut input and the main
        # path output would otherwise disagree in shape.
        shape_changes = stride != 1 or in_planes != out_planes
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        main += self.shortcut(x)
        return F.relu(main)


class ResNet(nn.Module):
    """Residual network: a 3x3 stem, four stages of blocks and a linear head.

    Args:
        block: block factory called as ``block(in_planes, planes, stride)``;
            it must expose an ``expansion`` attribute.
        num_blocks: indexable with the number of blocks for stages 0..3.
        num_classes: size of the output of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Channel count fed to the next block; _make_layer advances it.
        self.in_planes = 64

        # Stem: 3 input channels -> 64 channels, spatial size preserved.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Stages 2-4 start with a stride-2 block; the width doubles each time.
        self.layer1 = self._make_layer(block, planes=64, num_blocks=num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, planes=128, num_blocks=num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, planes=256, num_blocks=num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, planes=512, num_blocks=num_blocks[3], stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        stride_per_block = [stride]
        stride_per_block.extend(1 for _ in range(num_blocks - 1))

        stage = []
        for block_stride in stride_per_block:
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run stem, the four stages, 4x4 average pooling and the linear head."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = F.avg_pool2d(feats, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


def ResNet18():
    """Build a ResNet from BasicBlocks with two blocks in each of its four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage)


# Root directory: the dataset is stored under '<dir_path>/data'.
dir_path = '/path/to/working/dir'


def _parse_learning_rate(argv):
    """Return the initial learning rate given as the first command-line argument.

    Both ``name=value`` (e.g. ``lr=0.1``) and a bare ``value`` are accepted.
    The process exits with a usage message when the argument is missing or is
    not a valid float, instead of failing with a bare IndexError/ValueError.
    """
    if len(argv) < 2:
        sys.exit('usage: %s lr=<float>' % argv[0])
    raw = argv[1]
    # Same field the original expression picked: the text after the first '='.
    value = raw.split('=')[1] if '=' in raw else raw
    try:
        return float(value)
    except ValueError:
        sys.exit('invalid learning rate argument %r; expected lr=<float>' % raw)


sys_lr = _parse_learning_rate(sys.argv)

# Per-channel normalisation constants shared by the train and test pipelines.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training images are augmented (padded random crop + horizontal flip) before
# being converted to tensors and normalised.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Test images are only converted and normalised.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

trainset = torchvision.datasets.CIFAR10(root=dir_path + '/data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root=dir_path + '/data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False)

# Samplers over every index except the first floor(valid_size * N) of each set.
# NOTE(review): neither sampler is passed to the DataLoaders above, so both
# loaders iterate over the full datasets -- confirm whether that is intended.
valid_size = 0.0001
num_train = len(trainset)
ind_train = list(range(num_train))
split_train = int(np.floor(valid_size*num_train))
train_idx = ind_train[split_train:]
train_sampler = SubsetRandomSampler(train_idx)

num_test = len(testset)
ind_test = list(range(num_test))
split_test = int(np.floor(valid_size*num_test))
test_idx = ind_test[split_test:]
test_sampler = SubsetRandomSampler(test_idx)

# Human-readable class names (presumably in dataset label order -- confirm).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


################################                                         
# Weight decay used by every SGD optimizer created for SplitSGD.
_WEIGHT_DECAY = 5e-4


def _sgd_step(net, optimizer, criterion, inputs, labels):
    """Apply one optimizer update to ``net`` using a single mini-batch."""
    optimizer.zero_grad()
    loss = criterion(net(inputs), labels)
    loss.backward()
    optimizer.step()


def _test_accuracy(net):
    """Return the percentage of ``testloader`` samples ``net`` classifies correctly.

    The whole pass runs under ``torch.no_grad()`` so no autograd graph is kept.
    NOTE(review): the network's train/eval mode is not changed here (the
    original code never calls ``net.eval()``) -- confirm whether BatchNorm is
    meant to use batch statistics during testing.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.cuda(), labels.cuda()
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


def _fork_network(net, optimizer, lr, mom):
    """Return a new ResNet18 and SGD optimizer loaded with the given states."""
    clone = ResNet18()
    clone.load_state_dict(net.state_dict())
    clone_optimizer = optim.SGD(clone.parameters(), lr=lr, momentum=mom,
                                weight_decay=_WEIGHT_DECAY)
    clone_optimizer.load_state_dict(optimizer.state_dict())
    return clone, clone_optimizer


def _snapshot_parameters(net):
    """Return detached copies of all named parameters of ``net``, keyed by name."""
    return {name: param.detach().clone() for name, param in net.named_parameters()}


def SplitSGD(net, K, t1, q, gamma, lr, mom, w):
    """Train with SGD, decaying the learning rate when a splitting diagnostic fires.

    Each of the ``K`` rounds consists of:

    1. ``t1`` ordinary SGD epochs over ``trainloader``;
    2. one diagnostic pass over ``trainloader`` in which the network is
       duplicated and the two copies are trained on alternating mini-batches
       (even batch indices -> copy 1, odd -> copy 2). At the end of every
       window of ``int(len(trainloader) / w)`` batches, the dot product of the
       two copies' parameter displacements over that window is recorded for
       every parameter tensor;
    3. if at least ``q * (number of dot products)`` of them are negative, the
       learning rate is multiplied by ``gamma``;
    4. the two copies are averaged into a new ResNet18, which is trained in
       the next round with a freshly created optimizer (no optimizer state is
       carried over from the copies).

    Test accuracy on ``testloader`` is recorded after every SGD epoch and
    after every diagnostic pass.

    Args:
        net: network to start from. Batches are moved with ``.cuda()``, so it
            has to live on the GPU. After the first round training continues
            on new ResNet18 instances, so this object does not hold the final
            weights.
        K: number of rounds.
        t1: number of plain SGD epochs per round.
        q: fraction of negative dot products that triggers a decay.
        gamma: factor applied to the learning rate on decay.
        lr: initial learning rate.
        mom: SGD momentum.
        w: number of windows the diagnostic pass is divided into.

    Returns:
        List of test accuracies in percent, in chronological order.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom,
                          weight_decay=_WEIGHT_DECAY)

    accuracy = []
    real_epoch = 0
    # Batches per diagnostic window. Clamped to at least 1 so that
    # w > len(trainloader) no longer results in a modulo by zero.
    window = max(1, int(len(trainloader) / w))

    for _ in range(K):
        # Phase 1: ordinary SGD epochs, each followed by a test pass.
        for _ in range(t1):
            for inputs, labels in trainloader:
                inputs, labels = inputs.cuda(), labels.cuda()
                _sgd_step(net, optimizer, criterion, inputs, labels)

            real_epoch += 1
            accuracy.append(_test_accuracy(net))

            print('epoch: %d,  learning rate: %.5f, test accuracy: %.2f'
                  % (real_epoch, optimizer.param_groups[0]['lr'], accuracy[-1]))

        # Phase 2: split into two copies that share the current weights and
        # optimizer state.
        net1, optimizer1 = _fork_network(net, optimizer, lr, mom)
        net2, optimizer2 = _fork_network(net, optimizer, lr, mom)

        # Parameters at the start of the current window, one set per copy.
        start1 = _snapshot_parameters(net1)
        start2 = _snapshot_parameters(net2)

        # A single pass over the training data; the copies take turns.
        dot_products = []
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.cuda(), labels.cuda()
            if i % 2 == 0:
                _sgd_step(net1, optimizer1, criterion, inputs, labels)
            else:
                _sgd_step(net2, optimizer2, criterion, inputs, labels)

            if i % window == window - 1:
                # End of a window: compare how far each copy moved, tensor by
                # tensor, then restart the window from the current weights.
                state1 = net1.state_dict()
                state2 = net2.state_dict()
                for name in start1:
                    delta1 = state1[name] - start1[name]
                    delta2 = state2[name] - start2[name]
                    dot_products.append(torch.sum(delta1 * delta2).item())

                start1 = _snapshot_parameters(net1)
                start2 = _snapshot_parameters(net2)

        dot_prod = np.asarray(dot_products)
        num_negative = int(np.sum(dot_prod < 0))
        stationarity = num_negative >= q * len(dot_prod)
        if stationarity:
            lr = lr * gamma

        # Phase 3: merge the two copies by averaging every state_dict entry
        # (in place, into copy 2's tensors) and continue from the average.
        net = ResNet18()
        beta = 0.5
        params1 = net1.state_dict()
        params2 = net2.state_dict()
        for name, value1 in params1.items():
            if name in params2:
                params2[name].data.copy_(beta * value1.data + (1 - beta) * params2[name].data)

        net.load_state_dict(params2)
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom,
                              weight_decay=_WEIGHT_DECAY)

        # The diagnostic pass is counted as one epoch.
        real_epoch += 1
        accuracy.append(_test_accuracy(net))

        print('D -> epoch: %d, test accuracy: %.2f, stationarity: %s, negative dot products: %i out of %i, learning rate: %.5f' %
              (real_epoch, accuracy[-1], bool(stationarity), num_negative, len(dot_prod), lr))

    return accuracy
                                         
net_sp1 = ResNet18()

# Resolve the output file and create its directory *before* the long training
# run, so that a missing results folder cannot cause the accuracies to be lost
# when the file is opened at the very end.
outfile_test_accuracy_splitsgd = dir_path + '/results/resnet_splitsgd_' + str(sys_lr) + '.csv'
os.makedirs(os.path.dirname(outfile_test_accuracy_splitsgd), exist_ok=True)

acc_sp1 = SplitSGD(net_sp1, K=70, t1=4, q=0.25, gamma=0.5, lr=sys_lr, mom=0.9, w=4)

# All accuracies go into a single CSV row. newline='' lets the csv module
# control line endings itself, as its documentation requires.
with open(outfile_test_accuracy_splitsgd, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(acc_sp1)
                                      
