import copy
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler
from GradientVariance import GradientVariance
from BootstrapLoss import BootstrapLoss
from Trace import LimitNeighborLossScaledTrace

def _flat_grad(model):
    """Return the ``.grad`` of every trainable parameter of *model* as one 1-D tensor."""
    return torch.cat([p.grad.reshape(-1) for p in model.parameters() if p.requires_grad])


def _flat_params(model):
    """Return a detached 1-D copy of every trainable parameter of *model*."""
    # detach() keeps the stored trajectory free of autograd history.
    return torch.cat([p.detach().reshape(-1) for p in model.parameters() if p.requires_grad])


def _batch_grad(model, criterion, optimizer, x, y):
    """Return the flattened gradient of ``criterion(model(x), y)``.

    Gradient buffers are cleared before and after, so no gradient leaks
    between calls. ``torch.cat`` copies, so the returned tensor survives
    the final ``zero_grad``.
    """
    optimizer.zero_grad()
    criterion(model(x), y).backward()
    grad = _flat_grad(model)
    optimizer.zero_grad()
    return grad


def CVSGD(ini_model, x_train, y_train, x_test, y_test, d, N_train, learningrate, eps, bs, eva_bs, Replacement, seed, K, KBS, NumCan, ComputeGV, ComputeBL):
    """Run SGD where each step picks one of ``NumCan`` candidate mini-batches.

    At every step a batch of ``NumCan * bs`` training samples is drawn and
    split into ``NumCan`` disjoint candidate mini-batches of size ``bs``.
    A separate random "evaluation" batch of size ``eva_bs`` is drawn from
    the training set. The candidate whose loss gradient has the largest
    inner product with the evaluation-batch gradient is used for the SGD
    update. The loss is ``torch.nn.MSELoss``.

    Args:
        ini_model: initial ``torch.nn.Module``; it is deep-copied and not modified.
        x_train, y_train, x_test, y_test: numpy arrays converted with
            ``torch.from_numpy``; their dtype must match the model's.
        d: forwarded to ``LimitNeighborLossScaledTrace``.
        N_train: number of training samples, used to size the run
            (``int(eps * N_train / bs)`` steps) and for sampling with replacement.
        learningrate: SGD learning rate.
        eps: number of passes; the candidate loader is iterated ``eps * NumCan`` times.
        bs: size of each candidate mini-batch.
        eva_bs: size of the evaluation mini-batch.
        Replacement: if true, candidates are drawn with replacement
            (``N_train`` samples per loader pass); otherwise the data is shuffled.
        seed: value passed to ``torch.manual_seed``.
        K: every ``K`` steps (when ``ComputeGV``) the gradient-variance and
            trace diagnostics are recomputed and written over the next ``K`` entries.
        KBS: the same for the bootstrap loss (when ``ComputeBL``).
        NumCan: number of candidate mini-batches per step.
        ComputeGV, ComputeBL: enable the corresponding diagnostics.

    Returns:
        Tuple ``(ModelTraj, TrainLosses, TestLosses, GradientVariances,
        BootstrapLosses, Products, Frobeniuses, HessianTraces,
        CovarianceTraces)``. ``ModelTraj`` is a CPU tensor of shape
        ``(steps, n_trainable_params)`` holding the parameters after each
        step. The others are numpy arrays of length ``steps``. Loader
        batches smaller than ``NumCan * bs`` are dropped, so trailing
        entries stay zero when ``N_train`` is not a multiple of ``NumCan * bs``.
    """
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The data is placed on `device` below, so the model must live there too.
    model = copy.deepcopy(ini_model).to(device)

    ites = int(eps * N_train / bs)
    # Only trainable parameters are recorded, so size the trajectory accordingly.
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    ModelTraj = torch.zeros((ites, n_params))
    TrainLosses = np.zeros(ites)
    TestLosses = np.zeros(ites)
    GradientVariances = np.zeros(ites)
    BootstrapLosses = np.zeros(ites)
    Products = np.zeros(ites)
    Frobeniuses = np.zeros(ites)
    HessianTraces = np.zeros(ites)
    CovarianceTraces = np.zeros(ites)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learningrate)

    inputs = torch.from_numpy(x_train).to(device)
    labels = torch.from_numpy(y_train).to(device)
    test_inputs = torch.from_numpy(x_test).to(device)
    test_labels = torch.from_numpy(y_test).to(device)
    train_dataset = TensorDataset(inputs, labels)

    # drop_last: a partial trailing batch would leave some candidates short or
    # empty, and would yield more steps than `ites` rows are allocated for.
    if Replacement:
        sampler = RandomSampler(train_dataset, replacement=True, num_samples=N_train)
        train_loader = DataLoader(dataset=train_dataset, batch_size=NumCan * bs,
                                  sampler=sampler, drop_last=True)
    else:
        train_loader = DataLoader(dataset=train_dataset, batch_size=NumCan * bs,
                                  shuffle=True, drop_last=True)
    evaluate_loader = DataLoader(dataset=train_dataset, batch_size=eva_bs, shuffle=True)

    i = 0
    for _ in range(eps * NumCan):
        for x, y in train_loader:
            # Split the drawn batch into NumCan disjoint candidates of size bs.
            cand_x = [x[j * bs:(j + 1) * bs] for j in range(NumCan)]
            cand_y = [y[j * bs:(j + 1) * bs] for j in range(NumCan)]
            cand_grads = [_batch_grad(model, criterion, optimizer, cx, cy)
                          for cx, cy in zip(cand_x, cand_y)]

            # Gradient on a fresh random evaluation batch from the training set.
            x_eva, y_eva = next(iter(evaluate_loader))
            grad_eva = _batch_grad(model, criterion, optimizer, x_eva, y_eva)

            # Choose the candidate most aligned with the evaluation gradient.
            # (An alternative score, g_j . (N_train * full_grad - bs * g_j),
            # would additionally require the full-batch gradient.)
            scores = torch.stack([torch.dot(g, grad_eva) for g in cand_grads])
            choice = int(torch.argmax(scores))

            # SGD step on the chosen candidate.
            optimizer.zero_grad()
            criterion(model(cand_x[choice]), cand_y[choice]).backward()
            optimizer.step()

            # Record the trajectory and losses without building autograd graphs.
            ModelTraj[i] = _flat_params(model).cpu()
            with torch.no_grad():
                TrainLosses[i] = criterion(model(inputs), labels).item()
                TestLosses[i] = criterion(model(test_inputs), test_labels).item()
            optimizer.zero_grad()

            if ComputeGV and i % K == 0:
                GV, _ = GradientVariance(model, inputs, labels, N_train)
                GradientVariances[i:i + K] = GV
                Product, Frobenius, HessianTrace, CovarianceTrace = LimitNeighborLossScaledTrace(
                    test_model=model, inputs=x_train, labels=y_train, d=d,
                    N_train=N_train, radiuses=[10e-12])
                Products[i:i + K] = Product
                Frobeniuses[i:i + K] = Frobenius
                HessianTraces[i:i + K] = HessianTrace
                CovarianceTraces[i:i + K] = CovarianceTrace
            if ComputeBL and i % KBS == 0:
                BL = BootstrapLoss(model, inputs, labels, N_train, 1)
                BootstrapLosses[i:i + KBS] = BL
            i += 1

    return ModelTraj, TrainLosses, TestLosses, GradientVariances, BootstrapLosses, Products, Frobeniuses, HessianTraces, CovarianceTraces