import argparse
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Subset, DataLoader
import torch.optim as optim
from derivatives import derivatives
import matplotlib.pyplot as plt
from vgg import VGG11, VGG13, VGG16, VGG19
from resnet_kuang import ResNet18, ResNet34
import numpy as np

# Number of training examples in the fixed batch handed to `derivatives`.
derivs_size = 2000
# Batch size of the full-pass evaluation loaders; the loss helpers also multiply
# each per-batch mean loss by this value.
loss_size = 5000


# Independent random initialisations per (batch size, learning rate) pair.
num_trials = 5

# Hyper-parameter grid swept in every trial.
batchsizes = [2 ** k for k in range(6, 11)]  # 64, 128, 256, 512, 1024
lrs = [0.1, 0.01, 0.001]

# Loss used for training, evaluation and the derivative computations
# ('mean' is the CrossEntropyLoss default reduction).
criterion = nn.CrossEntropyLoss(reduction='mean')

def wd(net):
    """Return the explicit weight-decay penalty for ``net``.

    This is the sum, over every parameter tensor of ``net`` (weights, biases and
    any other registered parameters), of its squared L2 norm. The result is a
    differentiable tensor, or the integer ``0`` if ``net`` has no parameters.
    """
    squared_norms = (torch.linalg.vector_norm(param) ** 2 for param in net.parameters())
    return sum(squared_norms)

def train_loss_acc(net, loader, acc = False):
    """Return the mean ``criterion`` loss (and optionally top-1 accuracy) of ``net`` over ``loader``.

    Each batch's mean loss is weighted by that batch's actual size, so the result
    is the exact per-example mean over everything ``loader`` yields. This holds for
    any batch size, dataset size or short final batch.

    ``net`` is run in eval mode under ``torch.no_grad()``, so evaluation does not
    alter BatchNorm running statistics or apply dropout, if the model has such
    layers. Its previous train/eval mode is restored before returning.

    Args:
        net: model to evaluate; batches are moved to the module-level ``device``.
        loader: iterable of ``(inputs, targets)`` batches (here: the training set).
        acc: if True, also compute top-1 accuracy.

    Returns:
        The mean loss as a float, or ``(mean_loss, accuracy)`` when ``acc`` is True.

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    was_training = net.training
    net.eval()

    loss_sum = 0.
    total = 0
    correct = 0
    try:
        with torch.no_grad():
            for inputs, targets in loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = net(inputs)
                batch = targets.size(0)

                # criterion averages over the batch; re-weight to a per-batch sum.
                loss_sum += criterion(outputs, targets).item() * batch
                total += batch

                if acc:
                    _, predicted = torch.max(outputs, 1)
                    correct += (predicted == targets).sum().item()
    finally:
        net.train(was_training)

    if total == 0:
        raise ValueError('loader yielded no examples')

    mean_loss = loss_sum / total
    if acc:
        return mean_loss, correct / total
    return mean_loss

def test_loss_acc(net, loader, acc = False):
    """Return the mean ``criterion`` loss (and optionally top-1 accuracy) of ``net`` over ``loader``.

    Each batch's mean loss is weighted by that batch's actual size, so the result
    is the exact per-example mean over everything ``loader`` yields, rather than
    relying on a hard-coded dataset size.

    ``net`` is run in eval mode under ``torch.no_grad()``. This means evaluating on
    held-out data does not leak into BatchNorm running statistics, and does not
    apply dropout, if the model has such layers. The previous train/eval mode is
    restored before returning.

    Args:
        net: model to evaluate; batches are moved to the module-level ``device``.
        loader: iterable of ``(inputs, targets)`` batches (here: the test set).
        acc: if True, also compute top-1 accuracy.

    Returns:
        The mean loss as a float, or ``(mean_loss, accuracy)`` when ``acc`` is True.

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    was_training = net.training
    net.eval()

    loss_sum = 0.
    total = 0
    correct = 0
    try:
        with torch.no_grad():
            for inputs, targets in loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = net(inputs)
                batch = targets.size(0)

                # criterion averages over the batch; re-weight to a per-batch sum.
                loss_sum += criterion(outputs, targets).item() * batch
                total += batch

                if acc:
                    _, predicted = torch.max(outputs, 1)
                    correct += (predicted == targets).sum().item()
    finally:
        net.train(was_training)

    if total == 0:
        raise ValueError('loader yielded no examples')

    mean_loss = loss_sum / total
    if acc:
        return mean_loss, correct / total
    return mean_loss



# Device that every model and data batch is moved to.
device = 'cuda:0'

# Result grids indexed as [trial, batch-size index, learning-rate index].
# main() fills one cell per run and writes each grid to a .npy file.
_result_shape = (num_trials, len(batchsizes), len(lrs))
hessvals = np.zeros(_result_shape)        # d.power('H')
jacvals1train = np.zeros(_result_shape)   # d.power('jac1train')
jacvals2train = np.zeros(_result_shape)   # d.power('jac2train')
jacvals1eval = np.zeros(_result_shape)    # d.power('jac1eval')
jacvals2eval = np.zeros(_result_shape)    # d.power('jac2eval')
testacc = np.zeros(_result_shape)         # final test accuracy
trainacc = np.zeros(_result_shape)        # final train accuracy
finlosstrain = np.zeros(_result_shape)    # final train loss
finlosstest = np.zeros(_result_shape)     # final test loss
del _result_shape

# Smoothed training-loss level below which a run is considered converged, per dataset.
_LOSS_THRESHOLDS = {'cifar10': 1e-2, 'cifar100': 2e-2}

# Hard cap on optimiser steps per run; only checked at the end of an epoch.
_MAX_STEPS = 61000


def _load_data(dataset):
    """Build the train/test datasets and their full-pass evaluation loaders.

    Returns:
        ``(trainset, trainloader, testloader, num_classes)``. Both loaders use
        batches of ``loss_size`` and are not shuffled.

    Raises:
        ValueError: if ``dataset`` is not ``'cifar10'`` or ``'cifar100'``.
    """
    # name -> (dataset class, per-channel mean, per-channel std, number of classes)
    specs = {
        'cifar10': (datasets.CIFAR10, (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616), 10),
        'cifar100': (datasets.CIFAR100, (0.5071, 0.4866, 0.4409), (0.2673, 0.2564, 0.2762), 100),
    }
    if dataset not in specs:
        raise ValueError(f'unknown dataset {dataset!r}; expected one of {sorted(specs)}')
    dataset_cls, mean, std, num_classes = specs[dataset]

    data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    trainset = dataset_cls('./data', train=True, transform=data_transform)
    testset = dataset_cls('./data', train=False, transform=data_transform)

    trainloader = DataLoader(trainset, loss_size, num_workers=4)
    testloader = DataLoader(testset, loss_size, num_workers=4)
    return trainset, trainloader, testloader, num_classes


def _model_class(model):
    """Return the network constructor for architecture name ``model``.

    Raises:
        ValueError: if ``model`` is not one of the supported architectures.
    """
    constructors = {
        'vgg11': VGG11,
        'vgg13': VGG13,
        'vgg19': VGG19,
        'resnet18': ResNet18,
        'resnet34': ResNet34,
    }
    if model not in constructors:
        raise ValueError(f'unknown model {model!r}; expected one of {sorted(constructors)}')
    return constructors[model]


def _train_until_threshold(net, loader, trainloader, optimiser, weight_decay, threshold,
                           trial, batchsize, lr):
    """Train ``net`` in place until the smoothed training loss is at most ``threshold``.

    The minimised objective is ``criterion`` plus ``weight_decay * wd(net)``. Once the
    previous epoch's running train accuracy reaches 95%, the full training loss is
    evaluated every 100 steps, and ``lossavg`` becomes the mean of the last 10 such
    evaluations.

    The loop exits at the end of an epoch, either when ``lossavg <= threshold`` or
    when at least ``_MAX_STEPS`` steps have been taken in total. ``trial``,
    ``batchsize`` and ``lr`` are used only for progress logging.
    """
    net.train()

    losses = []
    lossavg = 2.  # start above every threshold so at least one epoch runs
    acc = 0.
    epoch = 0
    step = 0

    while lossavg > threshold:
        print('Trial: ', trial, 'Epoch: ', epoch, 'Batchsize: ', batchsize, 'Learning rate: ', lr)

        total = 0
        correct = 0

        for inputs, targets in loader:
            # `acc` is the previous epoch's running train accuracy.
            if step % 100 == 0 and acc >= 0.95:
                losses.append(train_loss_acc(net, trainloader))
                lossavg = np.array(losses[-10:]).mean()

            # Re-assert training mode in case the evaluation above changed it.
            net.train()
            optimiser.zero_grad()
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

            loss = criterion(outputs, targets) + weight_decay * wd(net)
            loss.backward()
            optimiser.step()

            step += 1

        epoch += 1
        if step >= _MAX_STEPS:
            break

        acc = correct / total
        print('Train acc: ', acc)


def _save_results(outdir):
    """Write every module-level result grid to a ``.npy`` file in ``outdir``."""
    results = {
        'hess.npy': hessvals,
        'jac1train.npy': jacvals1train,
        'jac2train.npy': jacvals2train,
        'jac1eval.npy': jacvals1eval,
        'jac2eval.npy': jacvals2eval,
        'trainloss.npy': finlosstrain,
        'testloss.npy': finlosstest,
        'trainacc.npy': trainacc,
        'testacc.npy': testacc,
    }
    for filename, values in results.items():
        np.save(os.path.join(outdir, filename), values)


def main(model, weight_decay, dataset):
    """Run the batch-size x learning-rate sweep for one model / weight decay / dataset.

    For each of ``num_trials`` random initialisations, every pair in
    ``batchsizes`` x ``lrs`` is trained with plain SGD from the same initial
    parameters until ``_train_until_threshold`` stops. For each run, the following
    are recorded into the module-level result grids:
      * final train/test loss and accuracy;
      * the five ``derivatives.power`` quantities, computed on a fixed per-trial
        batch of ``derivs_size`` training examples.

    The grids are written to ``./batch_lr/{dataset}/{model}/{weight_decay:g}/``
    after every run, so finished runs survive a later crash.

    Raises:
        ValueError: if ``model`` or ``dataset`` is not supported.
    """
    net_class = _model_class(model)
    trainset, trainloader, testloader, num_classes = _load_data(dataset)
    threshold = _LOSS_THRESHOLDS[dataset]

    outdir = f'./batch_lr/{dataset}/{model}/{weight_decay:g}/'
    os.makedirs(outdir, exist_ok=True)

    for i in range(num_trials):
        net = net_class(num_classes).to(device)
        # Every (batch size, lr) run of this trial restarts from this initialisation.
        init_path = os.path.join(outdir, f'params_{i}.pt')
        torch.save(net.state_dict(), init_path)

        # One fixed batch of training examples shared by all derivative measurements
        # in this trial. NOTE(review): torch.randint draws indices with replacement,
        # so the batch may contain duplicates; confirm this is intended.
        subset = Subset(trainset, torch.randint(0, len(trainset), (derivs_size,)))
        deriv_inputs, deriv_targets = next(iter(DataLoader(subset, derivs_size)))

        for b, batchsize in enumerate(batchsizes):
            for lr, learning_rate in enumerate(lrs):
                loader = DataLoader(trainset, batchsize, shuffle=True, drop_last=True)
                net.load_state_dict(torch.load(init_path))
                optimiser = torch.optim.SGD(net.parameters(), learning_rate)

                _train_until_threshold(net, loader, trainloader, optimiser, weight_decay,
                                       threshold, i, batchsize, learning_rate)

                finlosstrain[i, b, lr], trainacc[i, b, lr] = train_loss_acc(net, trainloader, acc=True)
                finlosstest[i, b, lr], testacc[i, b, lr] = test_loss_acc(net, testloader, acc=True)

                # Derivative quantities on the fixed batch (fresh device copies per run).
                d = derivatives(net, criterion, [derivs_size, 3, 32, 32], [derivs_size, num_classes], device)
                d.update((deriv_inputs.to(device), deriv_targets.to(device)))

                net.train()

                hessvals[i, b, lr] = d.power('H')
                jacvals1train[i, b, lr] = d.power('jac1train')
                jacvals2train[i, b, lr] = d.power('jac2train')
                jacvals1eval[i, b, lr] = d.power('jac1eval')
                jacvals2eval[i, b, lr] = d.power('jac2eval')

                # Checkpoint after every run; the final write equals the original end-of-sweep save.
                _save_results(outdir)
                
if __name__ == "__main__":
    # Command-line entry point: <script> MODEL WEIGHT_DECAY DATASET
    cli = argparse.ArgumentParser(description="Train using gradient descent.")
    cli.add_argument('model', type=str, choices=['vgg11', 'vgg13', 'vgg19', 'resnet18', 'resnet34'])
    cli.add_argument('weight_decay', type=float)
    cli.add_argument('dataset', type=str, choices=['cifar10', 'cifar100'])

    # The parsed namespace's attribute names match main()'s keyword parameters.
    main(**vars(cli.parse_args()))




                


        
# loader = DataLoader(trainset, 128, shuffle = True)

# size = 1000

# device = 'cuda:0'

# torch.manual_seed(0)
# net = VGG11().to(device)
# torch.save(net.state_dict(), 'params.pt')

# scalars = [0.0, 0.5, 0.75]
# # scalars = [1.0, 0.5, 0.25]

# # criterion = nn.CrossEntropyLoss()
# # criterion = nn.MSELoss()

# # d = derivatives(net, criterion, [size, 3, 32, 32], [size, 100], device)

# epochs = 1
# optimizer = optim.SGD(net.parameters(), lr = 0.1)
# # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 15, gamma = 0.1)


# for sc in scalars:

#     lossvals = []
#     hessvals = []
#     HmGNvals = []
#     ntkvals = []
#     GNvals = []
#     jac1train = []
#     jac2train = []
#     jac1eval = []
#     jac2eval = []

#     criterion = nn.CrossEntropyLoss(label_smoothing = sc)

#     d = derivatives(net, criterion, [size, 3, 32, 32], [size, 10], device)

#     net.load_state_dict(torch.load('params.pt'))

#     for e in range(epochs):

#         # net.load_state_dict(torch.load('params.pt'))

#         indices = torch.randint(0, 50000, size = [size])
#         # indices = [i for i in range(size)]
#         subset = Subset(trainset, indices)
#         subloader = DataLoader(subset, size)

#         for data, labels in subloader:
#             data = data.to(device)
#             labels = labels.to(device)
#             # onehotlabels = sc * F.one_hot(labels).type(torch.float32)

#         # data = (data, onehotlabels)
#         data = (data, labels)

#         if e == 0:
#             for i in range(100):
#                 net(data[0])
        
#         d.update(data)

#         print('Epoch: {}'.format(e))
#         correct = 0
#         total = 0

#         step = 0

#         for data in loader:

#             if step % 5 == 0:

#                 hessvals.append(d.power('H'))
#                 # ntkvals.append((1./size)*d.power('NTK'))
#                 # GNvals.append(d.power('GN'))
#                 # # HmGNvals.append(d.power('H-GN'))
#                 # # jac1train.append(d.power('jac1train'))
#                 # # jac2train.append(d.power('jac2train'))
#                 # jac1eval.append(d.power('jac1eval'))
#                 # jac2eval.append(d.power('jac2eval'))

#             net.train()

#             optimizer.zero_grad()

#             inputs, labels = data[0].to(device), data[1].to(device)
#             outputs = net(inputs)

#             _, predicted = torch.max(outputs, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#             loss = criterion(outputs, labels)
#             lossvals.append(loss.detach().cpu().numpy())
#             loss.backward()
#             optimizer.step()
#             step += 1

#         # scheduler.step()
#         print('Training accuracy: {}'.format(correct / total))

#     np.save(f'./vgg/smoothing/loss_{sc:g}.npy', np.array(lossvals))
#     np.save(f'./vgg/smoothing/hess_{sc:g}.npy', np.array(hessvals))
#     np.save(f'./vgg/smoothing/GN_{sc:g}.npy', np.array(GNvals))
#     np.save(f'./vgg/smoothing/HmGN_{sc:g}.npy', np.array(HmGNvals))
#     np.save(f'./vgg/smoothing/ntk_{sc:g}.npy', np.array(ntkvals))
#     np.save(f'./vgg/smoothing/jac1train_{sc:g}.npy', np.array(jac1train))
#     np.save(f'./vgg/smoothing/jac2train_{sc:g}.npy', np.array(jac2train))
#     np.save(f'./vgg/smoothing/jac1eval_{sc:g}.npy', np.array(jac1eval))
#     np.save(f'./vgg/smoothing/jac2eval_{sc:g}.npy', np.array(jac2eval))

