import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler


import copy
import sys
import numpy as np

import csv
import os
# Make CUDA float tensors the process-wide default tensor type. The rest of
# this script creates its models without an explicit device and relies on
# this call having run first.
# NOTE(review): recent PyTorch releases document this function as deprecated
# in favour of torch.set_default_device / torch.set_default_dtype -- confirm
# against the installed version before changing it.
torch.set_default_tensor_type('torch.cuda.FloatTensor')


###########################
# Layer layouts of the supported VGG variants, one list per variant and one
# row per resolution stage. As read by VGG._make_layers, an integer adds a
# convolution with that many output channels and 'M' adds a max-pooling step.
cfg = {
    'VGG11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'VGG19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}


class VGG(nn.Module):
    """VGG-style convolutional classifier built from a named layer layout.

    The convolutional feature extractor is assembled from `cfg[vgg_name]`
    and its flattened output is fed to a single linear layer.

    Args:
        vgg_name: Key into the module-level `cfg` dictionary selecting the
            layer layout. An unknown key raises `KeyError`.
        num_classes: Number of output logits of the classifier. Defaults to
            10, the value this class previously had hard-coded.
    """

    def __init__(self, vgg_name, num_classes=10):
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        # The classifier expects exactly 512 flattened features, i.e. a
        # 512-channel 1x1 feature map coming out of `self.features`.
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the class logits for the image batch `x`."""
        out = self.features(x)
        # Flatten everything except the batch dimension.
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        """Build the feature extractor described by the layout list `cfg`.

        Every integer entry adds a 3x3 convolution (padding 1) with that many
        output channels followed by batch normalisation and ReLU; every 'M'
        entry adds a 2x2 max-pooling with stride 2. The first convolution
        takes 3 input channels.

        Note that the parameter intentionally keeps its name `cfg`, which
        shadows the module-level dictionary inside this method.
        """
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.extend([
                    nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                    nn.BatchNorm2d(x),
                    nn.ReLU(inplace=True),
                ])
                in_channels = x
        # A 1x1 average pooling with stride 1 leaves the tensor unchanged; it
        # is kept so the layer list (and hence module indexing) stays the same.
        layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*layers)


# Root folder holding the dataset cache (and, further down, the results).
dir_path = '/path/to/working/dir'

# Initial learning rate, read from the first command-line argument, which is
# expected in `<name>=<value>` form: the field after the first '=' is parsed.
sys_lr = float(sys.argv[1].split('=')[1])

# Per-channel statistics used to normalise both the training and test images.
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop after 4-pixel padding and a random
# horizontal flip, followed by tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)

# Test pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)

# CIFAR-10 training split, downloaded into `<dir_path>/data` if necessary,
# served in shuffled mini-batches of 128.
trainset = torchvision.datasets.CIFAR10(
    root=dir_path + '/data',
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
)

# CIFAR-10 test split, served in order in mini-batches of 100.
testset = torchvision.datasets.CIFAR10(
    root=dir_path + '/data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
)

# Samplers that skip the first `valid_size` fraction of each split.
# NOTE: the two loaders above are constructed without these samplers.
valid_size = 0.0001

num_train = len(trainset)
ind_train = list(range(num_train))
split_train = int(np.floor(valid_size * num_train))
train_idx = ind_train[split_train:]
train_sampler = SubsetRandomSampler(train_idx)

num_test = len(testset)
ind_test = list(range(num_test))
split_test = int(np.floor(valid_size * num_test))
test_idx = ind_test[split_test:]
test_sampler = SubsetRandomSampler(test_idx)

# Human-readable class names, one per label index.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


################################                                         
def _evaluate(net, loader, device):
    """Return the top-1 accuracy of `net` on `loader`, in percent.

    The network is switched to evaluation mode and autograd is disabled for
    the duration of the pass, so layers such as batch normalisation use their
    running statistics and are not updated by the evaluation data. The
    previous training/evaluation mode is restored before returning.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            _, predicted = torch.max(net(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train(was_training)
    return 100 * correct / total


def _snapshot_parameters(net):
    """Return detached copies of all named parameters of `net`."""
    return {name: p.detach().clone() for name, p in net.named_parameters()}


def SplitSGD(net, K, t1, q, gamma, lr, mom, w):
    """Train `net` with SGD, decaying the learning rate via a splitting diagnostic.

    Each of the `K` rounds consists of `t1` ordinary SGD epochs followed by
    one diagnostic epoch. In the diagnostic epoch two copies of the network
    ("threads") are trained on alternating mini-batches. At the end of every
    window the parameter displacement of each thread since the start of the
    window is computed, and one dot product between the two displacements is
    recorded per parameter tensor. If at least a fraction `q` of the recorded
    dot products is negative, the learning rate is multiplied by `gamma`.
    The two threads are then averaged and training continues from the
    average with a fresh optimizer.

    Data comes from the module-level `trainloader` and `testloader`.

    Args:
        net: Model to train. It is updated in place, so after the call it
            holds the weights obtained at the end of the last round.
        K: Number of rounds.
        t1: Number of plain SGD epochs per round.
        q: Fraction of negative dot products required to decay the
            learning rate.
        gamma: Multiplicative learning-rate decay factor.
        lr: Initial learning rate.
        mom: SGD momentum.
        w: Number of windows the diagnostic epoch is divided into; a window
            spans int(len(trainloader) / w) consecutive mini-batches (at
            least one), counted over both threads.

    Returns:
        List of test accuracies in percent, one entry per epoch (training
        and diagnostic epochs alike), in chronological order.
    """
    # Batches are moved to wherever the model lives instead of assuming CUDA.
    device = next(net.parameters()).device

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom, weight_decay=5e-4)

    accuracy = []
    real_epoch = 0
    # Never let the window length reach zero (it is used as a modulus below).
    window_len = max(1, int(len(trainloader) / w))

    for _ in range(K):
        # ---- Phase 1: t1 epochs of ordinary SGD --------------------------
        net.train()
        for _ in range(t1):
            for inputs, labels in trainloader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                loss = criterion(net(inputs), labels)
                loss.backward()
                optimizer.step()

            real_epoch += 1
            accuracy.append(_evaluate(net, testloader, device))

            print('epoch: %d,  learning rate: %.5f, test accuracy: %.2f' 
                  % (real_epoch, optimizer.param_groups[0]['lr'], accuracy[-1]))

        # ---- Phase 2: split into two threads -----------------------------
        # Copy the model itself rather than instantiating a fixed
        # architecture, so any `net` passed in is supported.
        net1 = copy.deepcopy(net)
        net2 = copy.deepcopy(net)
        optimizer1 = optim.SGD(net1.parameters(), lr=lr, momentum=mom, weight_decay=5e-4)
        optimizer2 = optim.SGD(net2.parameters(), lr=lr, momentum=mom, weight_decay=5e-4)
        # The optimizer state is deep-copied explicitly so the two threads
        # can never share momentum buffers, whether or not load_state_dict
        # copies the tensors it is given.
        optimizer1.load_state_dict(copy.deepcopy(optimizer.state_dict()))
        optimizer2.load_state_dict(copy.deepcopy(optimizer.state_dict()))
        threads = ((net1, optimizer1), (net2, optimizer2))

        # Parameter values at the start of the current window.
        start1 = _snapshot_parameters(net1)
        start2 = _snapshot_parameters(net2)

        dot_prod = []
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            # Even-numbered batches train thread 1, odd-numbered thread 2.
            thread_net, thread_opt = threads[i % 2]
            thread_opt.zero_grad()
            loss = criterion(thread_net(inputs), labels)
            loss.backward()
            thread_opt.step()

            if i % window_len == window_len - 1:
                # End of a window: compare the displacement of both threads,
                # one dot product per parameter tensor.
                with torch.no_grad():
                    for (name, p1), (_, p2) in zip(net1.named_parameters(),
                                                   net2.named_parameters()):
                        d1 = p1 - start1[name]
                        d2 = p2 - start2[name]
                        dot_prod.append(torch.sum(d1 * d2).item())

                start1 = _snapshot_parameters(net1)
                start2 = _snapshot_parameters(net2)

        dot_prod = np.asarray(dot_prod)
        num_negative = int(np.count_nonzero(dot_prod < 0))
        # Without any recorded dot product there is no evidence of
        # stationarity, so the learning rate is left untouched.
        stationarity = len(dot_prod) > 0 and num_negative >= q * len(dot_prod)
        if stationarity:
            lr = lr * gamma

        # ---- Phase 3: merge the threads back into `net` ------------------
        beta = 0.5
        params1 = net1.state_dict()
        params2 = net2.state_dict()
        # Average every state entry (parameters and buffers) and cast back to
        # the entry's own dtype so integer buffers stay integer.
        averaged = {
            name: (beta * params1[name] + (1 - beta) * tensor).to(tensor.dtype)
            for name, tensor in params2.items()
        }
        net.load_state_dict(averaged)
        # A fresh optimizer: momentum buffers are not carried over the merge.
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom, weight_decay=5e-4)

        real_epoch += 1
        accuracy.append(_evaluate(net, testloader, device))

        print('D -> epoch: %d, test accuracy: %.2f, stationarity: %s, negative dot products: %i out of %i, learning rate: %.5f' %
                (real_epoch, accuracy[-1], bool(stationarity), num_negative, len(dot_prod), lr))

    return accuracy
                                         
# Destination of the per-epoch test accuracies. The results directory is
# created before training starts, so a missing folder cannot make the script
# fail only after the whole run has finished.
outfile_test_accuracy_splitsgd = dir_path + '/results/vgg_splitsgd_' + str(sys_lr) + '.csv'
os.makedirs(os.path.dirname(outfile_test_accuracy_splitsgd), exist_ok=True)

# Train a VGG19 with SplitSGD, starting from the learning rate given on the
# command line.
net_sp1 = VGG('VGG19')
acc_sp1 = SplitSGD(net_sp1, K=70, t1=4, q=0.25, gamma=0.5, lr=sys_lr, mom=0.9, w=4)

# Write all accuracies as a single CSV row. newline='' is what the csv module
# requires of files handed to csv.writer; without it extra blank lines can
# appear on platforms that translate line endings.
with open(outfile_test_accuracy_splitsgd, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(acc_sp1)
                                      
