import torch
import torchvision
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Variable
#from sklearn.metrics import accuracy_score, classification_report
import numpy as np
from Trace import LossScaledTrace
#from Batch_Trace import LossScaledTrace
#from ExpandedWeightsTrace import LossScaledTrace
#from Accel_Trace import LossScaledTrace
import time
import copy

# Suffix of the checkpoint directory: SGDwReg saves model state dicts to
# 'Saved01/SGDwReg{folder_num}/epoch{N}.pt'.
folder_num = '11'
def _evaluate(model, loader, loss_func, device):
    """Return ``(mean_batch_loss, accuracy)`` of ``model`` over ``loader``.

    The model is switched to eval mode for the pass and its previous
    train/eval mode is restored afterwards. The loss is the mean of the
    per-batch losses; the accuracy is ``correct / total`` over all samples.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            # The model returns a tuple; its first element holds the logits.
            logits = model(images)[0]
            loss_sum += loss_func(logits, labels).item()
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    model.train(was_training)
    return loss_sum / len(loader), correct / seen


def SGDwReg(model, train_data, test_data, lambda1, train_size=6400, bs=32, eps=50, learning_rate=0.01, decay_rate=1.0, weight_decay=0.0, starting_epoch=0, B=50, seed=1, num_workers=1):
    """Train a copy of ``model`` with SGD plus a gradient-norm penalty.

    Each step minimises
    ``CE(batch) + (learning_rate / 4) * lambda1 * ||grad CE(batch)||^2``,
    where the gradient is taken w.r.t. the trainable parameters. The input
    ``model`` is never modified: a deep copy is trained.

    Args:
        model: network whose forward returns a sequence with the logits as
            its first element (it is called as ``model(x)[0]``).
        train_data: dataset yielding ``(images, labels)`` pairs; used for
            training and for the per-epoch training-loss evaluation.
        test_data: dataset yielding ``(images, labels)`` pairs for testing.
        lambda1: multiplier of the squared-gradient-norm penalty.
        train_size: currently unused (only the disabled trace computation
            referenced it); kept for interface compatibility.
        bs: training batch size.
        eps: number of epochs to run.
        learning_rate: initial SGD learning rate; also scales the penalty
            (the initial value is used for the penalty in every epoch).
        decay_rate: multiplicative factor of the exponential LR schedule,
            applied once per epoch.
        weight_decay: weight decay passed to the SGD optimizer.
        starting_epoch: offset added to the reported/saved epoch numbers.
        B: currently unused (see ``train_size``).
        seed: seed passed to ``torch.manual_seed``.
        num_workers: worker processes used by every DataLoader.

    Returns:
        Tuple ``(Epochs, ProductTraces, Frobeniuses, HessianTraces,
        TrainLosses, TestLosses, TestEpochs, TestAccuracies)``. The first
        four lists are always empty because the LossScaledTrace computation
        is disabled; the others hold one entry per epoch.
    """
    import os  # local import: only needed to create the checkpoint directory

    start_time = time.time()
    torch.manual_seed(seed)
    num_epochs = eps
    eval_bs = 100
    test_bs = 100
    eval_every = 1    # evaluate train/test metrics every N epochs
    save_every = 10   # checkpoint every N epochs
    max_grad_norm = 20

    use_cuda = torch.cuda.is_available()
    print("Working on GPU" if use_cuda else "Working on CPU")
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Train a private copy, and make sure it lives on the same device the
    # batches are moved to.
    cnn = copy.deepcopy(model).to(device)

    # Pinned memory only helps (and is only available) with CUDA.
    loader_kwargs = dict(num_workers=num_workers, pin_memory=use_cuda)
    loaders = {
        'train': DataLoader(train_data, batch_size=bs, shuffle=True, **loader_kwargs),
        'eval': DataLoader(train_data, batch_size=eval_bs, shuffle=True, **loader_kwargs),
        'test': DataLoader(test_data, batch_size=test_bs, shuffle=True, **loader_kwargs),
    }

    # The trace-related lists stay empty: the LossScaledTrace computation
    # that used to fill them is disabled. They are still returned so the
    # return signature is unchanged.
    ProductTraces = []
    Frobeniuses = []
    HessianTraces = []
    Epochs = []
    TrainLosses = []
    TestEpochs = []
    TestLosses = []
    TestAccuracies = []

    loss_func = nn.CrossEntropyLoss()
    print(loss_func)

    optimizer = optim.SGD(cnn.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, decay_rate)
    print(optimizer)

    # Report the total and the trainable parameter counts.
    print(sum(p.numel() for p in cnn.parameters()))
    trainable = [p for p in cnn.parameters() if p.requires_grad]
    print(sum(p.numel() for p in trainable))

    penalty_coeff = learning_rate / 4 * lambda1
    total_step = len(loaders['train'])

    for ep in range(num_epochs):
        epoch = ep + (1 + starting_epoch)
        # Evaluation switches the model to eval mode; make sure every epoch
        # really trains in train mode.
        cnn.train()
        for i, (images, labels) in enumerate(loaders['train']):
            b_x = images.to(device)
            b_y = labels.to(device)

            # One forward pass serves both the loss and its gradient:
            # create_graph=True keeps the graph so the penalty on the
            # gradient can itself be back-propagated.
            loss = loss_func(cnn(b_x)[0], b_y)
            grads = torch.autograd.grad(loss, trainable, create_graph=True)
            sq_grad_norm = sum(g.pow(2).sum() for g in grads)
            loss_with_reg = loss + penalty_coeff * sq_grad_norm

            optimizer.zero_grad()
            loss_with_reg.backward()

            # Diagnostic on the second batch of each epoch: sum of the
            # per-parameter gradient norms (before clipping).
            if i == 1:
                norm = 0
                for para in trainable:
                    norm += torch.norm(para.grad)
                print('The norm of the gradient is {}'.format(norm))

            # Clip the gradients of the network actually being trained.
            torch.nn.utils.clip_grad_norm_(cnn.parameters(), max_grad_norm)
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.14f}'.format(epoch, num_epochs, i + 1, total_step, loss_with_reg.item()))

        if epoch % eval_every == 0:
            train_loss, _ = _evaluate(cnn, loaders['eval'], loss_func, device)
            TrainLosses.append(train_loss)

            # A single pass over the test set yields both loss and accuracy.
            test_loss, accuracy = _evaluate(cnn, loaders['test'], loss_func, device)
            TestEpochs.append(epoch)
            TestLosses.append(test_loss)
            TestAccuracies.append(accuracy)
            print('Epoch [{}/{}], Test Loss: {:.8f}'.format(epoch, num_epochs, test_loss))
            print('SGDwReg Test Accuracy of the model on the 10000 test images: %.4f' % accuracy)

        if epoch % save_every == 0:
            path = 'Saved01/SGDwReg{}/epoch{}.pt'.format(folder_num, epoch)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            torch.save(cnn.state_dict(), path)
        scheduler.step()

    # Final report after training.
    _, final_accuracy = _evaluate(cnn, loaders['test'], loss_func, device)
    print('SGDwReg Test Accuracy of the model on the 10000 test images: %.4f' % final_accuracy)
    print(f'SGDwReg RunTime: {time.time() - start_time:.2f}')

    return Epochs, ProductTraces, Frobeniuses, HessianTraces, TrainLosses, TestLosses, TestEpochs, TestAccuracies