import numpy as np
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler
import copy
import csv
import os
import sys
# Make newly created float tensors (including model parameters) live on the GPU.
# NOTE(review): this API is reported as deprecated in recent PyTorch releases --
# confirm against the PyTorch version in use.
torch.set_default_tensor_type('torch.cuda.FloatTensor')


# Command-line hyperparameters: LEARNING_RATE T1 Q W.
# Fail with a usage message instead of a bare IndexError when one is missing.
if len(sys.argv) < 5:
    sys.exit("usage: {} LEARNING_RATE T1 Q W".format(sys.argv[0]))

learning_rate = float(sys.argv[1])  # initial SGD step size
t1 = int(sys.argv[2])               # ordinary training epochs per SplitSGD round
q = float(sys.argv[3])              # fraction of negative dot products that triggers a decay
w = int(sys.argv[4])                # number of windows in each diagnostic epoch

# w is used as a divisor when sizing the diagnostic windows; reject bad values
# now rather than after the first t1 epochs of training.
if w < 1:
    sys.exit("W must be a positive integer, got {}".format(w))

# Convert images to tensors and normalize the single channel with mean 0.5, std 0.5.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download (if needed) and load the training and test splits
trainset = datasets.FashionMNIST('Fashion_Mnist_data/', download = True, train = True, transform = transform)
testset = datasets.FashionMNIST('Fashion_Mnist_data/', download = True, train = False, transform = transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 1024, shuffle = True)
testloader = torch.utils.data.DataLoader(testset, batch_size = 1024, shuffle = True)


class Feedforward(nn.Module):
    """Fully connected classifier with layer widths 784 -> 256 -> 128 -> 64 -> 10."""

    def __init__(self):
        super().__init__()
        widths = (784, 256, 128, 64, 10)
        # Registers fc1..fc4 in order, one linear layer per consecutive width pair.
        for index, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, "fc{}".format(index), nn.Linear(n_in, n_out))

    def forward(self, x):
        """Return per-class log-probabilities for a batch of inputs."""
        # Collapse everything after the batch dimension into one feature axis.
        batch_size = x.shape[0]
        hidden = x.view(batch_size, -1)

        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))

        logits = self.fc4(hidden)
        return F.log_softmax(logits, dim=1)

# Negative log-likelihood loss. Feedforward.forward already applies log_softmax,
# so the network output is the log-probability input this loss is applied to.
criterion = nn.NLLLoss()


def _run_sgd_step(model, optimizer, images, labels):
    """Apply one SGD update of ``model`` on a mini-batch and return the batch loss."""
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    return loss.item()


def _test_accuracy(model, device):
    """Return the accuracy (in percent) of ``model`` over the global ``testloader``."""
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    # Evaluation needs no gradients; skipping the autograd graph saves memory.
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Restore whatever mode the caller had the model in.
    model.train(was_training)
    return 100 * correct / total


def _fork_thread(model, optimizer, lr, mom):
    """Return an independent copy of ``model`` and an SGD optimizer for that copy.

    The new optimizer starts from ``optimizer``'s state (e.g. momentum buffers).
    The state dict is deep-copied first so the copy can never share buffer
    tensors with the source optimizer or with a sibling thread.
    """
    replica = copy.deepcopy(model)
    replica_optimizer = optim.SGD(replica.parameters(), lr=lr, momentum=mom)
    replica_optimizer.load_state_dict(copy.deepcopy(optimizer.state_dict()))
    return replica, replica_optimizer


def _parameter_snapshot(model):
    """Return a name -> detached copy mapping of the current parameters of ``model``."""
    return {name: param.detach().clone() for name, param in model.named_parameters()}


def SplitSGD(net, K, t1, q, w, gamma, lr, mom):
    """Train ``net`` with SGD, decaying the learning rate via a two-thread diagnostic.

    Each of the ``K`` rounds consists of ``t1`` ordinary SGD epochs followed by
    one diagnostic epoch. In the diagnostic epoch two copies of the network
    ("threads") are trained on alternating mini-batches. At the end of every
    window, the dot product between the two threads' parameter displacements is
    recorded for each parameter tensor. If at least a fraction ``q`` of those dot
    products is negative, the learning rate is multiplied by ``gamma``. The two
    threads are then averaged, the result is loaded back into ``net`` (in place),
    and a fresh optimizer (momentum reset) is used for the next round.

    Reads the module-level ``trainloader``, ``testloader`` and ``criterion``.
    Batches are moved to the device that holds ``net``'s parameters.

    Args:
        net: model to train; it is updated in place.
        K: number of rounds.
        t1: number of ordinary training epochs per round.
        q: fraction of negative dot products required to decay the learning rate.
        w: requested number of windows per diagnostic epoch (positive integer).
        gamma: multiplicative learning-rate decay factor.
        lr: initial learning rate.
        mom: SGD momentum.

    Returns:
        dict with keys ``'train_loss'`` and ``'test_acc'``; each is a list with
        one entry per epoch (``K * (t1 + 1)`` entries in total).

    Raises:
        ValueError: if ``w`` is not positive or ``trainloader`` has no batches.
    """
    if w < 1:
        raise ValueError("w must be a positive integer, got {}".format(w))
    n_batches = len(trainloader)
    if n_batches == 0:
        raise ValueError("trainloader yields no batches")

    device = next(net.parameters()).device
    # A window must hold at least one batch per thread: with a single batch one
    # thread has not moved, so every dot product would be exactly zero. The
    # clamp also avoids a zero window length (modulo by zero) when w exceeds
    # the number of batches.
    window = max(2, n_batches // w)
    total_epochs = K * (t1 + 1)

    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom)
    real_epoch = 0
    test_acc = []
    train_loss = []

    for _ in range(K):
        # ---- ordinary SGD epochs -------------------------------------------
        for _ in range(t1):
            running_loss = 0
            for images, labels in trainloader:
                images, labels = images.to(device), labels.to(device)
                running_loss += _run_sgd_step(net, optimizer, images, labels)

            real_epoch += 1
            train_loss.append(running_loss / n_batches)

            accuracy = _test_accuracy(net, device)
            test_acc.append(accuracy)

            print("Epoch: {}/{}..".format(real_epoch, total_epochs),
                  "Training loss: {:.3f}..".format(running_loss / n_batches),
                  "Test Accuracy: {:.3f}..".format(accuracy),
                  "lr: {:.3f}".format(lr))

        # ---- diagnostic epoch: two threads on alternating batches ----------
        net1, optimizer1 = _fork_thread(net, optimizer, lr, mom)
        net2, optimizer2 = _fork_thread(net, optimizer, lr, mom)

        # Parameter values at the start of the current window.
        start1 = _parameter_snapshot(net1)
        start2 = _parameter_snapshot(net2)

        dot_prod = []
        running_loss = 0
        for i, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)
            if i % 2 == 0:
                running_loss += _run_sgd_step(net1, optimizer1, images, labels)
            else:
                running_loss += _run_sgd_step(net2, optimizer2, images, labels)

            if i % window == window - 1:
                # Window finished: compare how far each thread moved, per tensor.
                end1 = dict(net1.named_parameters())
                end2 = dict(net2.named_parameters())
                for name in start1:
                    p1 = end1[name].detach() - start1[name]
                    p2 = end2[name].detach() - start2[name]
                    dot_prod.append(torch.sum(p1 * p2).item())

                start1 = _parameter_snapshot(net1)
                start2 = _parameter_snapshot(net2)

        negative = sum(1 for value in dot_prod if value < 0)
        # With no completed window there is no evidence of stationarity.
        stationarity = bool(dot_prod) and negative >= q * len(dot_prod)
        if stationarity:
            lr = lr * gamma

        # Average the two threads and load the result into the caller's model,
        # so the object passed in always holds the latest weights.
        beta = 0.5
        state1 = net1.state_dict()
        state2 = net2.state_dict()
        averaged = {name: beta * state1[name] + (1 - beta) * state2[name]
                    for name in state1}
        net.load_state_dict(averaged)
        # Fresh optimizer: picks up the (possibly decayed) lr and resets momentum.
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom)

        real_epoch += 1

        accuracy = _test_accuracy(net, device)
        test_acc.append(accuracy)
        train_loss.append(running_loss / n_batches)

        print("D -> Epoch: {}/{}..".format(real_epoch, total_epochs),
              "Training loss: {:.3f}..".format(running_loss / n_batches),
              "Test Accuracy: {:.3f}..".format(accuracy),
              "Stationarity: {}..".format(bool(stationarity)),
              "Negative dot products: {}/{}..".format(negative, len(dot_prod)),
              "lr: {:.4f}".format(lr))

    return {'train_loss': train_loss, 'test_acc': test_acc}


    
# Make sure the results directory exists before the (long) training run, so
# writing the CSV files afterwards cannot fail on a missing folder.
os.makedirs('outputs', exist_ok=True)

net_sp = Feedforward()
sp = SplitSGD(net=net_sp, K=20, t1=t1, q=q, w=w, gamma=0.5, lr=learning_rate, mom=0.9)
splitsgd_test = sp['test_acc']
splitsgd_train = sp['train_loss']


# Each file holds a single CSV row with one value per epoch.
# newline='' lets the csv module control line endings itself.
outfile_test_accuracy_splitsgd = 'outputs/acc_splitsgd_FF_fashion_lr' + str(learning_rate) + '.csv'
with open(outfile_test_accuracy_splitsgd, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(splitsgd_test)

outfile_train_loss_splitsgd = 'outputs/loss_splitsgd_FF_fashion_lr' + str(learning_rate) + '.csv'
with open(outfile_train_loss_splitsgd, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(splitsgd_train)



