import copy
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler
from GradientVariance import GradientVariance
from BootstrapLoss import BootstrapLoss
from Trace import LimitNeighborLossScaledTrace, TakeGradient
import time


def SGDwReg(ini_model, x_train, y_train, x_test, y_test, d, N_train, learningrate, eps, bs, Replacement, seed, K, KBS, ComputeGV, ComputeBL, lambda1, lambda2):
    """Run mini-batch SGD on an MSE loss with two gradient-norm penalties.

    At every step the objective that is differentiated is::

        MSE(batch)
        + lambda1 * learningrate / 4 * ||grad(MSE(batch) - MSE(full train set))||^2
        + lambda2 * learningrate / (4 * N_train) * sum_k ||grad(MSE(sample_k))||^2

    where the sum runs over every sample of the training set.  After each
    parameter update the flattened parameters and the full train/test losses
    are recorded.

    Args:
        ini_model: torch module to train; it is deep-copied and left untouched.
        x_train, y_train: NumPy arrays with the training inputs/targets
            (converted with ``torch.from_numpy``; dtype must be compatible
            with the model -- presumably float32, confirm with callers).
        x_test, y_test: NumPy arrays with the test inputs/targets.
        d: forwarded unchanged to ``LimitNeighborLossScaledTrace``.
        N_train: number of samples drawn per epoch when ``Replacement`` is
            true; also the normaliser of the per-sample penalty.
        learningrate: SGD step size, also used as penalty scale.
        eps: number of epochs (integer).
        bs: mini-batch size.
        Replacement: if true, batches are sampled with replacement.
        seed: value passed to ``torch.manual_seed``.
        K: period (in steps) of the gradient-variance / trace diagnostics.
        KBS: period (in steps) of the bootstrap-loss diagnostic.
        ComputeGV: enable the diagnostics computed every ``K`` steps.
        ComputeBL: enable the diagnostic computed every ``KBS`` steps.
        lambda1: weight of the batch-vs-full gradient penalty.
        lambda2: weight of the per-sample gradient penalty.

    Returns:
        Tuple ``(ModelTraj, TrainLosses, TestLosses, GradientVariances,
        GradientNorms, BootstrapLosses, Products, Frobeniuses, HessianTraces,
        CovarianceTraces, AccProducts, Hessian)``.  ``ModelTraj`` is a
        ``(steps, dim)`` tensor without autograd history, ``Hessian`` is the
        last value returned by ``LimitNeighborLossScaledTrace`` (zeros if it
        was never called), and the rest are NumPy arrays of length ``steps``.
    """
    starting_time = time.time()
    torch.manual_seed(seed)
    model = copy.deepcopy(ini_model)
    params = list(model.parameters())
    dim = sum(p.numel() for p in params)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(params, lr=learningrate)

    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)
    test_inputs = torch.from_numpy(x_test)
    test_labels = torch.from_numpy(y_test)
    dataset = TensorDataset(inputs, labels)

    if Replacement:
        sampler = RandomSampler(dataset, replacement=True, num_samples=N_train)
        train_loader = DataLoader(dataset=dataset, batch_size=bs, sampler=sampler)
    else:
        train_loader = DataLoader(dataset=dataset, batch_size=bs, shuffle=True)
    # Iterates the training set one sample at a time for the per-sample penalty.
    single_loader = DataLoader(dataset=dataset, batch_size=1)

    # Size the buffers from the loader itself: it yields a trailing partial
    # batch, so int(eps * N_train / bs) under-counts when bs does not divide
    # the number of samples per epoch.
    ites = eps * len(train_loader)
    ModelTraj = torch.zeros((ites, dim))
    TrainLosses = np.zeros(ites)
    TestLosses = np.zeros(ites)
    GradientVariances = np.zeros(ites)
    GradientNorms = np.zeros(ites)
    BootstrapLosses = np.zeros(ites)
    Products = np.zeros(ites)
    Frobeniuses = np.zeros(ites)
    HessianTraces = np.zeros(ites)
    CovarianceTraces = np.zeros(ites)
    AccProducts = np.zeros(ites)
    Hessian = torch.zeros((dim, dim))

    # Running sum of the covariance matrices returned by the trace diagnostic.
    AccCovariance = torch.zeros(dim, dim)
    i = 0
    for _ in range(eps):
        for x, y in train_loader:
            # Difference between the mini-batch loss and the full-batch loss;
            # its gradient is the stochastic-gradient noise at this step.
            loss_diff = criterion(model(x), y) - criterion(model(inputs), labels)
            loss_with_reg = criterion(model(x), y)

            # create_graph=True keeps the gradients differentiable so the
            # penalties contribute second-order terms in backward().
            grads = torch.autograd.grad(loss_diff, params, create_graph=True)
            for grad in grads:
                loss_with_reg = loss_with_reg + grad.pow(2).sum() * learningrate / 4 * lambda1

            for x_sample, y_sample in single_loader:
                loss_sample = criterion(model(x_sample), y_sample)
                grads = torch.autograd.grad(loss_sample, params, create_graph=True)
                for grad in grads:
                    loss_with_reg = loss_with_reg + grad.pow(2).sum() * learningrate / (4 * N_train) * lambda2

            # Clear stale .grad buffers (e.g. left by the diagnostics below).
            optimizer.zero_grad()
            loss_with_reg.backward()
            optimizer.step()

            # Logging must not extend the autograd graph: otherwise ModelTraj
            # would require grad and retain history for every step.
            with torch.no_grad():
                ModelTraj[i] = torch.cat([p.reshape(-1) for p in params if p.requires_grad])
                TrainLosses[i] = criterion(model(inputs), labels).item()
                TestLosses[i] = criterion(model(test_inputs), test_labels).item()
            # Hand the diagnostics a model with cleared gradients.
            optimizer.zero_grad()

            # Test the flag first so K / KBS are irrelevant (and may be 0)
            # when the corresponding diagnostic is disabled.
            if ComputeGV and i % K == 0:
                GV, GN = GradientVariance(model, inputs, labels, N_train)
                GradientVariances[i:i + K] = GV
                GradientNorms[i:i + K] = GN
                Product, Frobenius, HessianTrace, CovarianceTrace, Hessian, NewCovariance = LimitNeighborLossScaledTrace(
                    test_model=model, inputs=x_train, labels=y_train, d=d,
                    N_train=N_train, radiuses=[10e-12])
                Products[i:i + K] = Product
                Frobeniuses[i:i + K] = Frobenius
                HessianTraces[i:i + K] = HessianTrace
                CovarianceTraces[i:i + K] = CovarianceTrace
                AccCovariance += NewCovariance
                AccProducts[i:i + K] = torch.trace(torch.mm(Hessian, AccCovariance)).item()
            if ComputeBL and i % KBS == 0:
                BL = BootstrapLoss(model, inputs, labels, N_train, 1)
                BootstrapLosses[i:i + KBS] = BL
            i += 1

    print("SGDwReg AccCovariance Trace is {}".format(torch.trace(AccCovariance)))
    print('SGDwReg runtime is {}'.format(time.time()-starting_time))
    return ModelTraj, TrainLosses, TestLosses, GradientVariances, GradientNorms, BootstrapLosses, Products, Frobeniuses, HessianTraces, CovarianceTraces, AccProducts, Hessian
