import copy
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler
from GradientVariance import GradientVariance
from BootstrapLoss import BootstrapLoss
from Trace import LimitNeighborLossScaledTrace
import time


def _flatten_trainable_params(model):
    """Return the trainable parameters of ``model`` concatenated into one detached 1-D tensor.

    Parameters are taken in ``model.parameters()`` order; those with
    ``requires_grad=False`` are skipped, matching how the trajectory has always
    been recorded. Returns an empty tensor when nothing is trainable.
    """
    # detach() keeps the stored trajectory out of the autograd graph: copying a
    # grad-requiring tensor into a plain buffer would otherwise attach a new
    # autograd node to that buffer on every write.
    pieces = [p.detach().reshape(-1) for p in model.parameters() if p.requires_grad]
    return torch.cat(pieces) if pieces else torch.zeros(0)


def _make_train_loader(dataset, N_train, bs, Replacement):
    """Build the minibatch loader used by ``SGD``.

    With ``Replacement`` each pass draws ``N_train`` indices with replacement;
    otherwise each pass is one shuffled sweep over ``dataset`` without replacement.
    """
    if Replacement:
        sampler = RandomSampler(dataset, replacement=True, num_samples=N_train)
        return DataLoader(dataset=dataset, batch_size=bs, sampler=sampler)
    return DataLoader(dataset=dataset, batch_size=bs, shuffle=True)


def SGD(ini_model, x_train, y_train, x_test, y_test, d, N_train, learningrate, eps, bs, Replacement, seed, K, KBS, ComputeGV, ComputeBL):
    """Train a copy of ``ini_model`` with minibatch SGD on an MSE loss and record diagnostics per step.

    ``ini_model`` itself is not modified (a deep copy is trained).

    Args:
        ini_model: torch module to copy and train.
        x_train, y_train: training inputs/targets as numpy arrays (converted with
            ``torch.from_numpy``; their dtype presumably must match the model's
            parameter dtype — confirm with callers).
        x_test, y_test: test inputs/targets as numpy arrays, same conventions.
        d: forwarded unchanged to ``LimitNeighborLossScaledTrace``.
        N_train: number of draws per pass when sampling with replacement; also
            forwarded to ``GradientVariance``, ``BootstrapLoss`` and
            ``LimitNeighborLossScaledTrace``.
        learningrate: SGD step size.
        eps: number of epochs (passes over the loader); must be an int.
        bs: minibatch size.
        Replacement: sample minibatch indices with replacement if truthy.
        seed: value passed to ``torch.manual_seed`` before training.
        K: every K steps the gradient-variance / Hessian diagnostics are
            recomputed; each value fills the next K slots of its array.
        KBS: same as ``K`` but for the bootstrap loss.
        ComputeGV: enable the gradient-variance / Hessian diagnostics.
        ComputeBL: enable the bootstrap-loss diagnostic.

    Returns:
        Tuple ``(ModelTraj, TrainLosses, TestLosses, GradientVariances,
        GradientNorms, BootstrapLosses, Products, Frobeniuses, HessianTraces,
        CovarianceTraces, AccProducts, Hessian)``. ``ModelTraj`` is an
        ``(iterations, n_params)`` tensor whose row ``i`` holds the trainable
        parameters *after* step ``i``. The numpy arrays have one entry per step;
        diagnostics that are disabled stay zero. ``Hessian`` is whatever the last
        ``LimitNeighborLossScaledTrace`` call returned, or a ``(n_params, n_params)``
        zero tensor if it was never called.
    """
    starting_time = time.time()
    torch.manual_seed(seed)
    model = copy.deepcopy(ini_model)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learningrate)

    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)
    test_inputs = torch.from_numpy(x_test)
    test_labels = torch.from_numpy(y_test)
    train_loader = _make_train_loader(TensorDataset(inputs, labels), N_train, bs, Replacement)

    # One optimizer step per minibatch. len(train_loader) counts the partial last
    # batch, so the buffers fit even when bs does not divide the sample count
    # (the old int(eps * N_train / bs) under-allocated and raised IndexError).
    ites = eps * len(train_loader)
    dim = sum(p.numel() for p in model.parameters())
    ModelTraj = torch.zeros((ites, dim))
    TrainLosses = np.zeros(ites)
    TestLosses = np.zeros(ites)
    GradientVariances = np.zeros(ites)
    GradientNorms = np.zeros(ites)
    BootstrapLosses = np.zeros(ites)
    Products = np.zeros(ites)
    Frobeniuses = np.zeros(ites)
    HessianTraces = np.zeros(ites)
    CovarianceTraces = np.zeros(ites)
    AccProducts = np.zeros(ites)
    Hessian = torch.zeros((dim, dim))
    # Running sum of every covariance returned by LimitNeighborLossScaledTrace.
    AccCovariance = torch.zeros(dim, dim)

    i = 0
    for _ in range(eps):
        for x, y in train_loader:
            # Fresh gradients for this minibatch only.
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            # Clear this step's .grad so the diagnostic helpers below do not see
            # (or accumulate onto) the minibatch gradient.
            optimizer.zero_grad()

            ModelTraj[i] = _flatten_trainable_params(model)
            # Full-dataset losses are evaluation only; no autograd graph needed.
            with torch.no_grad():
                TrainLosses[i] = criterion(model(inputs), labels).item()
                TestLosses[i] = criterion(model(test_inputs), test_labels).item()

            if ComputeGV and i % K == 0:
                GV, GN = GradientVariance(model, inputs, labels, N_train)
                GradientVariances[i:i + K] = GV
                GradientNorms[i:i + K] = GN
                Product, Frobenius, HessianTrace, CovarianceTrace, Hessian, NewCovariance = LimitNeighborLossScaledTrace(
                    test_model=model, inputs=x_train, labels=y_train, d=d, N_train=N_train, radiuses=[10e-12])
                Products[i:i + K] = Product
                Frobeniuses[i:i + K] = Frobenius
                HessianTraces[i:i + K] = HessianTrace
                CovarianceTraces[i:i + K] = CovarianceTrace
                AccCovariance += NewCovariance
                # .item() so the numpy write works even if the result tracks grad.
                AccProducts[i:i + K] = torch.trace(torch.mm(Hessian, AccCovariance)).item()
            if ComputeBL and i % KBS == 0:
                BootstrapLosses[i:i + KBS] = BootstrapLoss(model, inputs, labels, N_train, 1)
            i += 1

    print("SGD AccCovariance is {}".format(AccCovariance))
    print("SGD AccCovariance Trace is {}".format(torch.trace(AccCovariance)))
    print('SGD runtime is {}'.format(time.time() - starting_time))
    return ModelTraj, TrainLosses, TestLosses, GradientVariances, GradientNorms, BootstrapLosses, Products, Frobeniuses, HessianTraces, CovarianceTraces, AccProducts, Hessian